In [1]:
import torch
import torch.nn as nn

# ----------------------------------------------
# Import all model components from your repo
# ----------------------------------------------
from models.dna_model import NucleotideTransformerEmbedder
from models.rna_model import RNAFMEmbedder
from models.protein_model import ESM2Embedder
from models.text_model import TextEmbedder

from models.projection_heads import ProjectionHead
from models.fusion_concat import FusionConcat
from models.fusion_mil import FusionMIL
from models.fusion_xattn import FusionCrossAttention

from models.prediction_head import TextCNNHead, MLPHead
from models.lora_adapter import LoRAAdapter


# Pick the compute device once. Later cells pass this plain string as
# `device=device` to the encoders and to `.to(device)` on the projection heads.
device = "cuda" if torch.cuda.is_available() else "cpu"

# "\u2699" is U+2699 (GEAR). The glyph was previously stored mojibake-encoded
# ("âš™" = its UTF-8 bytes read as cp1252); the escape keeps the source ASCII-only
# so it cannot be re-mangled by a wrong-encoding save.
print("\u2699 Running on:", device.upper())


ModuleNotFoundError: No module named 'models'

In [None]:
# Minimal toy inputs (one per modality) used only to smoke-test the encoders.
dna_seq = "ATGCGTACGTAGCTAGCTAGCTA"      # DNA alphabet: A/C/G/T
rna_seq = "AUGGCUACUGAACCUUAGCUGGAAA"    # RNA alphabet: U in place of T
protein_seq = "MKTLLIALAVAAALA"          # one-letter amino-acid codes
text_info = "This is a sample description of an mRNA sequence."

# Values forwarded as `max_len` to the encoder constructors.
SEQ_MAX_LEN = 200   # DNA / RNA / protein encoders
TEXT_MAX_LEN = 100  # text encoder

# One encoder per modality, each handed the device chosen in the setup cell.
dna_enc = NucleotideTransformerEmbedder(max_len=SEQ_MAX_LEN, device=device)
rna_enc = RNAFMEmbedder(max_len=SEQ_MAX_LEN, device=device)
protein_enc = ESM2Embedder(max_len=SEQ_MAX_LEN, device=device)
text_enc = TextEmbedder(max_len=TEXT_MAX_LEN, device=device)

print("Encoders initialized.")


In [None]:
# Run every encoder once with autograd disabled; only the output shapes are
# inspected here, so no gradient graph is needed.
with torch.no_grad():
    dna_emb = dna_enc(dna_seq)
    rna_emb = rna_enc(rna_seq)
    protein_emb = protein_enc(protein_seq)
    text_emb = text_enc(text_info)

# Report the embedding shape produced for each modality, in a fixed order.
for modality, embedding in (
    ("DNA", dna_emb),
    ("RNA", rna_emb),
    ("Protein", protein_emb),
    ("Text", text_emb),
):
    print(f"{modality} embedding:", embedding.shape)


In [None]:
latent_dim = 256  # width of the shared space every modality is projected into


def _projection_for(embedding):
    """Return a ProjectionHead from `embedding.shape[-1]` to `latent_dim`, moved via `.to(device)`."""
    return ProjectionHead(input_dim=embedding.shape[-1], output_dim=latent_dim).to(device)


# Build all four heads first (same order as before), then run the projections.
proj_dna = _projection_for(dna_emb)
proj_rna = _projection_for(rna_emb)
proj_prot = _projection_for(protein_emb)
proj_text = _projection_for(text_emb)

# NOTE(review): these forward passes run outside torch.no_grad(), so the outputs
# carry autograd history — fine if later cells backprop; otherwise consider no_grad.
dna_z = proj_dna(dna_emb)
rna_z = proj_rna(rna_emb)
prot_z = proj_prot(protein_emb)
text_z = proj_text(text_emb)

print("Projected shapes:")
for modality, latent in (("DNA", dna_z), ("RNA", rna_z), ("Protein", prot_z), ("Text", text_z)):
    print(f"{modality}:", latent.shape)


In [2]:
# Fuse the projected latents with the concatenation strategy.
# NOTE(review): text_z is computed earlier but not passed here — confirm that
# leaving the text modality out of this fusion is intentional.
fusion = FusionConcat()
concat_inputs = (dna_z, rna_z, prot_z)
fused = fusion(*concat_inputs)

print("Fused (concat) shape:", fused.shape)


NameError: name 'FusionConcat' is not defined