In [1]:
import sys
import os

# Go up one level: notebooks → project root.
# The path is resolved against the current working directory, so this assumes
# the kernel was started from the notebooks folder.
PROJECT_ROOT = os.path.abspath("..")
# Guard the append so that re-running this cell does not pile up duplicate
# entries on sys.path.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)


import torch
import torch.nn as nn

# ----------------------------------------------
# Import all model components from your repo
# ----------------------------------------------

from models.dna_model import NucleotideTransformerEmbedder
from models.rna_model import RNAFMEmbedder
from models.protein_model import ESM2Embedder
from models.text_model import TextEmbedder

from models.projection_heads import ProjectionHead
from models.fusion_concat import FusionConcat
from models.fusion_mil import FusionMIL
from models.fusion_xattn import FusionCrossAttention

from models.prediction_head import TextCNNHead, MLPHead
from models.lora_adapter import LoRAAdapter


# Pick the compute device once: CUDA when PyTorch reports it as available,
# otherwise the CPU. The value is a plain string ("cuda" / "cpu").
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

print("⚙ Running on:", device.upper())


⚙ Running on: CPU


In [2]:
# Tiny hand-written inputs, one per modality, used only to exercise the stack.
dna_seq, rna_seq, protein_seq = (
    "ATGCGTACGTAGCTAGCTAGCTA",
    "AUGGCUACUGAACCUUAGCUGGAAA",
    "MKTLLIALAVAAALA",
)
text_info = "This is a sample description of an mRNA sequence."

# One embedder per modality. The three sequence encoders share the same
# settings (max_len=200); the text encoder uses max_len=100.
_seq_encoder_kwargs = {"max_len": 200, "device": device}
dna_enc = NucleotideTransformerEmbedder(**_seq_encoder_kwargs)
rna_enc = RNAFMEmbedder(**_seq_encoder_kwargs)
protein_enc = ESM2Embedder(**_seq_encoder_kwargs)
text_enc = TextEmbedder(max_len=100, device=device)

print("Encoders initialized.")


Loading public DNA model: zhihan1996/DNABERT-2-117M


Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertModel: ['cls.predictions.decoder.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertModel were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['bert.pooler.dense.weight', 'bert.pooler.dense.bias']
You should prob

Encoders initialized.


In [3]:
# Run every encoder on its toy input with autograd disabled; these calls are
# only used to inspect the embeddings, not to train anything.
with torch.no_grad():
    dna_emb = dna_enc(dna_seq)
    rna_emb = rna_enc(rna_seq)
    protein_emb = protein_enc(protein_seq)
    text_emb = text_enc(text_info)

# Report the shape of each embedding, in the same order they were computed.
for _caption, _embedding in (
    ("DNA embedding:", dna_emb),
    ("RNA embedding:", rna_emb),
    ("Protein embedding:", protein_emb),
    ("Text embedding:", text_emb),
):
    print(_caption, _embedding.shape)


DNA embedding: torch.Size([200, 768])
RNA embedding: torch.Size([200, 320])
Protein embedding: torch.Size([200, 640])
Text embedding: torch.Size([100, 768])


In [4]:
# Width of the shared latent space every modality is projected into.
latent_dim = 256


def _projection_for(embedding):
    """Build a ProjectionHead from `embedding`'s last dimension to `latent_dim`, moved to `device`."""
    head = ProjectionHead(input_dim=embedding.shape[-1], output_dim=latent_dim)
    return head.to(device)


# Build all four heads first (DNA, RNA, protein, text), then apply them.
proj_dna = _projection_for(dna_emb)
proj_rna = _projection_for(rna_emb)
proj_prot = _projection_for(protein_emb)
proj_text = _projection_for(text_emb)

dna_z = proj_dna(dna_emb)
rna_z = proj_rna(rna_emb)
prot_z = proj_prot(protein_emb)
text_z = proj_text(text_emb)

print("Projected shapes:")
for _caption, _projected in (
    ("DNA:", dna_z),
    ("RNA:", rna_z),
    ("Protein:", prot_z),
    ("Text:", text_z),
):
    print(_caption, _projected.shape)


Projected shapes:
DNA: torch.Size([200, 256])
RNA: torch.Size([200, 256])
Protein: torch.Size([200, 256])
Text: torch.Size([100, 256])


In [6]:
# Concatenation-based fusion of the projected DNA / RNA / protein features.
#
# `latent_dim` is deliberately NOT re-assigned here: it is set once in the
# projection cell, and resetting it in this cell would silently override a
# value changed there while later cells still read it.

# Feature widths after ProjectionHead, read from the tensors themselves.
dDNA = dna_z.shape[-1]
dRNA = rna_z.shape[-1]
dProt = prot_z.shape[-1]

# Move the fusion module to the same device as the projection heads; without
# this the forward pass mixes devices whenever `device` is "cuda".
# NOTE(review): assumes FusionConcat is an nn.Module (exposes .to) — confirm.
fusion = FusionConcat(
    dDNA=dDNA,
    dRNA=dRNA,
    dProt=dProt,
    dDNA_proj=256     # or 128 to reduce DNA dominance
).to(device)

fused = fusion(dna_z, rna_z, prot_z)
print("Fused concat shape:", fused.shape)

Fused concat shape: torch.Size([200, 768])


In [8]:
# FusionMIL over the projected DNA / RNA / protein features.

# Feature widths after ProjectionHead, read from the tensors themselves.
dDNA = dna_z.shape[-1]
dRNA = rna_z.shape[-1]
dProt = prot_z.shape[-1]

# Move the fusion module to the same device as the projection heads; without
# this the forward pass mixes devices whenever `device` is "cuda".
# NOTE(review): assumes FusionMIL is an nn.Module (exposes .to) — confirm.
fusion = FusionMIL(
    dDNA=dDNA,
    dRNA=dRNA,
    dProt=dProt,
    d_model=latent_dim,  # shared dimension; the prediction heads are sized with latent_dim too
    d_attn=128           # gating attention dimension
).to(device)

fused = fusion(dna_z, rna_z, prot_z)

print("Fused (MIL) shape:", fused.shape)

Fused (MIL) shape: torch.Size([200, 256])


In [10]:
# FusionCrossAttention over the projected DNA / RNA / protein features.
#
# The input widths are read from the projected tensors right here, so this
# cell does not depend on dDNA / dRNA / dProt left behind by another fusion
# cell.
# The module is moved to the same device as the projection heads; without
# this the forward pass mixes devices whenever `device` is "cuda".
# NOTE(review): assumes FusionCrossAttention is an nn.Module (exposes .to) — confirm.
fusion = FusionCrossAttention(
    dDNA=dna_z.shape[-1],
    dRNA=rna_z.shape[-1],
    dProt=prot_z.shape[-1],
    d_model=latent_dim,
    num_heads=4
).to(device)

fused = fusion(dna_z, rna_z, prot_z)
print("Fused (xAttn) shape:", fused.shape)

Fused (xAttn) shape: torch.Size([200, 256])


In [11]:
# TextCNN prediction head applied to whichever `fused` tensor was computed
# last; it is built for latent_dim-wide inputs and a single output class.
pred_head = TextCNNHead(embed_dim=latent_dim, num_classes=1)
pred_head = pred_head.to(device)
prediction = pred_head(fused)

print("Prediction shape (TextCNN):", prediction.shape)


Prediction shape (TextCNN): torch.Size([1, 1])


In [12]:
# MLP prediction head applied to the same `fused` tensor. This rebinds
# `pred_head` and `prediction`, replacing the values from the TextCNN cell.
pred_head = MLPHead(input_dim=latent_dim, num_classes=1)
pred_head = pred_head.to(device)
prediction = pred_head(fused)

print("Prediction shape (MLP):", prediction.shape)


Prediction shape (MLP): torch.Size([1])


In [13]:
# Wrap a square linear layer (latent_dim -> latent_dim) in a rank-8
# LoRAAdapter and push a random batch of 10 rows through it.
linear = nn.Linear(in_features=latent_dim, out_features=latent_dim)
lora = LoRAAdapter(linear, rank=8)

x = torch.randn((10, latent_dim))
y = lora(x)

print("LoRA output shape:", y.shape)


LoRA output shape: torch.Size([10, 256])


In [14]:
# Recap of the tensor shapes produced along the pipeline. `fused` and
# `prediction` refer to whichever fusion / head cell ran last.
_summary_rows = (
    ("DNA emb:", dna_emb),
    ("RNA emb:", rna_emb),
    ("Protein emb:", protein_emb),
    ("Text emb:", text_emb),
    ("After projection:", dna_z),
    ("Fused representation:", fused),
    ("Final prediction:", prediction),
)

print("\n====== SUMMARY ======")
for _caption, _tensor in _summary_rows:
    print(_caption, _tensor.shape)
print("======================")



DNA emb: torch.Size([200, 768])
RNA emb: torch.Size([200, 320])
Protein emb: torch.Size([200, 640])
Text emb: torch.Size([100, 768])
After projection: torch.Size([200, 256])
Fused representation: torch.Size([200, 256])
Final prediction: torch.Size([1])
