# Project MOML - RBP Siamese



##gdrive

In [1]:
import os
os.chdir('/content/drive/MyDrive/pbg-shortcuts/rna/aidrugsx-siamese network')

In [2]:
!ls

condacolab_install.log	       notes.gdoc	  rna_motif_emb.npy
final_attract_db_with_emb.csv  rbp_seqs_dict.pkl  siamese.ipynb


In [3]:
!pip install fair-esm



In [4]:
##setting up ESM
import torch
import esm

esm_model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
esm_model.eval()  # disables dropout for deterministic results
esm_model.cuda() #push model to gpu

ESM2(
  (embed_tokens): Embedding(33, 1280, padding_idx=1)
  (layers): ModuleList(
    (0-32): 33 x TransformerLayer(
      (self_attn): MultiheadAttention(
        (k_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (v_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (q_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (out_proj): Linear(in_features=1280, out_features=1280, bias=True)
        (rot_emb): RotaryEmbedding()
      )
      (self_attn_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
      (fc1): Linear(in_features=1280, out_features=5120, bias=True)
      (fc2): Linear(in_features=5120, out_features=1280, bias=True)
      (final_layer_norm): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
    )
  )
  (contact_head): ContactPredictionHead(
    (regression): Linear(in_features=660, out_features=1, bias=True)
    (activation): Sigmoid()
  )
  (emb_layer_norm_after): LayerNorm((1280,), eps=1

## generate dataset

In [6]:
import pandas as pd

In [7]:
data = pd.read_csv('final_attract_db_with_emb.csv')

In [8]:
import pickle
# Load the dictionary back from the pickle file.
with open("rbp_seqs_dict.pkl", "rb") as f:
    rbp_seqs_dict = pickle.load(f)

  return torch.load(io.BytesIO(b))


In [9]:
import numpy as np
rna_motif_emb = np.load('rna_motif_emb.npy', allow_pickle=True)

In [10]:
len(rna_motif_emb[1])

7

In [11]:
data = data.drop(columns=['rna_motif_emb', 'rbp_esm_emb'])

In [12]:
data['rna_motif_emb'] = rna_motif_emb

In [13]:
data['rbp_esm_emb'] = data['RBP_sequence'].map(rbp_seqs_dict)

In [14]:
# Convert list of tensors to numpy array
def tensors_to_numpy(tensor_list):
    return np.stack([t.numpy() for t in tensor_list])

# Apply the conversion to the 'rbp_esm_emb' column
data['rbp_esm_emb'] = data['rbp_esm_emb'].apply(tensors_to_numpy)

In [15]:
data.head()

Unnamed: 0,Gene_name,Gene_id,Motif,RBP_sequence,rna_motif_emb,rbp_esm_emb
0,A1CF,ENSG00000148584,UGAUCAGUAUA,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[0.02515462040901184, -0.1693081259727478, 1....","[[0.102689214, -0.18220823, -0.05008613, 0.156..."
1,A1CF,ENSG00000148584,AUAAUUA,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[-0.005481339991092682, -0.008752591907978058...","[[0.102689214, -0.18220823, -0.05008613, 0.156..."
2,A1CF,ENSG00000148584,UUAAUUA,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[-0.006157027557492256, -0.04564734920859337,...","[[0.102689214, -0.18220823, -0.05008613, 0.156..."
3,A1CF,ENSG00000148584,AUAAUUG,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[0.0016195997595787048, 0.0697508156299591, 1...","[[0.102689214, -0.18220823, -0.05008613, 0.156..."
4,A1CF,ENSG00000148584,UUAAUUG,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[0.02757852151989937, 0.006868686527013779, 3...","[[0.102689214, -0.18220823, -0.05008613, 0.156..."


## mmseqs2 on dataset

In [None]:
# install condacolab
!pip install -q condacolab
import condacolab
condacolab.install()

✨🍰✨ Everything looks OK!


In [None]:
!conda install -c conda-forge -c bioconda mmseqs2

In [None]:
!mmseqs createdb sequences.fasta DB

SyntaxError: invalid syntax (<ipython-input-2-10b17da785be>, line 1)

## utils

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
import esm

In [6]:
class MultiHeadAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.num_heads = config.num_heads
        self.d_model = config.d_model
        self.d_k = config.d_model // config.num_heads

        self.q_linear = nn.Linear(config.d_model, config.d_model)
        self.k_linear = nn.Linear(config.d_model, config.d_model)
        self.v_linear = nn.Linear(config.d_model, config.d_model)
        self.out = nn.Linear(config.d_model, config.d_model)

    def forward(self, x):
        bs = x.size(0)
        q = self.q_linear(x).view(bs, -1, self.num_heads, self.d_k).transpose(1, 2)
        k = self.k_linear(x).view(bs, -1, self.num_heads, self.d_k).transpose(1, 2)
        v = self.v_linear(x).view(bs, -1, self.num_heads, self.d_k).transpose(1, 2)

        scores = torch.matmul(q, k.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.d_k, dtype=torch.float32))
        attn = F.softmax(scores, dim=-1)
        context = torch.matmul(attn, v)

        context = context.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
        output = self.out(context)
        return output

In [7]:
contrastive_loss = nn.CosineEmbeddingLoss()
ce_loss = nn.CrossEntropyLoss(ignore_index=0)

## siamese network

In [8]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np

In [9]:
class SiameseNetwork(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.fusion_dim,
            nhead=config.num_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout_rate,
            activation=F.gelu,
            batch_first=True,
            norm_first=True
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=config.num_layers)
        self.attention_pool = nn.MultiheadAttention(config.fusion_dim, config.num_heads, batch_first=True)

        self.projection_layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(config.fusion_dim, config.fusion_dim),
                nn.LayerNorm(config.fusion_dim),
                nn.GELU(),
                nn.Dropout(config.dropout_rate)
            ) for _ in range(4)
        ])

        self.latent_projection = nn.Linear(config.fusion_dim, config.latent_dim)
        self.final_layer_norm = nn.LayerNorm(config.latent_dim)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, x):
        x = self.encoder(x)
        attn_output, attn_weights = self.attention_pool(x, x, x)
        weighted_sum = x + torch.bmm(attn_weights, x)

        for layer in self.projection_layers:
            weighted_sum = layer(weighted_sum) + weighted_sum  # Residual connection

        pooled = weighted_sum.mean(dim=1)
        latent = self.latent_projection(pooled)
        latent = self.final_layer_norm(self.dropout(F.gelu(latent)))
        return latent

class SiameseDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.decoder_layer = nn.TransformerDecoderLayer(
            d_model=config.latent_dim,
            nhead=config.num_heads,
            dim_feedforward=config.d_ff,
            dropout=config.dropout_rate,
            activation=F.gelu,
            batch_first=True,
            norm_first=True
        )
        self.decoder = nn.TransformerDecoder(self.decoder_layer, num_layers=config.num_decoder_layers)

        # Custom output projection
        self.output_projection = nn.Sequential(
            nn.Linear(config.latent_dim, config.latent_dim),
            nn.LayerNorm(config.latent_dim),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.latent_dim, config.latent_dim),
            nn.LayerNorm(config.latent_dim),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.latent_dim, config.vocab_size)
        )

        # Projection to 1280 dimensions for ESM-2 compatibility
        self.esm_projection = nn.Sequential(
            nn.Linear(config.latent_dim, config.latent_dim),
            nn.LayerNorm(config.latent_dim),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.latent_dim, config.latent_dim),
            nn.LayerNorm(config.latent_dim),
            nn.GELU(),
            nn.Dropout(config.dropout_rate),
            nn.Linear(config.latent_dim, 1280)
        )


        # ESM-2 language model head
        esm_model, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self.lm_head = esm_model.lm_head

    def forward(self, latent, use_lm_head=True):
        batch_size = latent.size(0)
        seq_len = self.config.max_protein_length

        decoded = self.decoder(
            tgt=torch.zeros(batch_size, seq_len, self.config.latent_dim).to(latent.device),
            memory=latent.unsqueeze(1).repeat(1, seq_len, 1)
        )

        if use_lm_head:
            decoded_projected = self.esm_projection(decoded)
            logits = self.lm_head(decoded_projected)
        else:
            logits = self.output_projection(decoded)

        return logits

class CombinedModel(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.rna_projection = self.create_projection_layers(config.rna_embed_dim, config.fusion_dim)
        self.protein_projection = self.create_projection_layers(config.protein_embed_dim, config.fusion_dim)
        self.siamese_network = SiameseNetwork(config)
        self.siamese_decoder = SiameseDecoder(config)
        esm_model, _ = esm.pretrained.esm2_t33_650M_UR50D()
        self.lm_head = esm_model.lm_head

    def create_projection_layers(self, input_dim, output_dim):
        layers = []
        current_dim = input_dim
        for _ in range(3):  # 3 intermediate layers
            layers.extend([
                nn.Linear(current_dim, output_dim),
                nn.LayerNorm(output_dim),
                nn.GELU(),
                nn.Dropout(self.config.dropout_rate)
            ])
            current_dim = output_dim
        layers.append(nn.Linear(current_dim, output_dim))  # Final projection
        return nn.Sequential(*layers)

    def forward(self, rna_emb, protein_emb=None, use_lm_head=True):
        rna_proj = self.rna_projection(rna_emb)
        rna_latent = self.siamese_network(rna_proj)
        print('rna siamese done...')
        # print('rna_proj:',rna_proj.shape)
        # print('rna_latent:',rna_latent.shape)
        # print()

        if protein_emb is not None:
            protein_proj = self.protein_projection(protein_emb)
            protein_latent = self.siamese_network(protein_proj)
        else:
            protein_latent = None

        print('protein siamese done...')
        # print('protein_proj:',protein_proj.shape)
        # print('protein_latent:',protein_latent.shape)
        # print()

        # print('decoder starting inputs...')
        # print('rna_latent:',rna_latent.shape)
        # print('use_lm_head:',use_lm_head)

        generated_sequence = self.siamese_decoder(rna_latent, use_lm_head)
        print('decoder done...')
        print()

        return rna_latent, protein_latent, generated_sequence

def generate_peptides(model, token_representations, num_samples, sample_variances):
    generated_peptides = []
    aa_toks = list("ARNDCEQGHILKMFPSTWYV")
    aa_idxs = [alphabet.get_idx(aa) for aa in aa_toks]

    for i in sample_variances:
        for j in range(num_samples):
            gen_pep = token_representations + torch.randn(token_representations.shape) * i * token_representations.var()
            aa_logits = model.lm_head(gen_pep.cuda())[:, :, aa_idxs]
            predictions = torch.argmax(aa_logits, dim=2).tolist()[0]
            generated_pep_seq = "".join([aa_toks[i] for i in predictions])
            generated_peptides.append(generated_pep_seq[1:-1])

    return generated_peptides

## training (dummy dataset)



In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np
import esm
import torch.nn.functional as F

class Config:
    vocab_size = 33  # ESM-2 vocabulary size
    pad_token = 0
    start_token = 1
    end_token = 2
    rna_embed_dim = 120
    protein_embed_dim = 1280
    fusion_dim = 512
    latent_dim = 256
    d_model = 512
    num_heads = 8
    num_layers = 6
    num_decoder_layers = 6
    d_ff = 2048
    dropout_rate = 0.1
    max_rna_length = 100
    max_protein_length = 200
    batch_size = 32
    num_epochs = 10
    learning_rate = 3e-4

config = Config()

class DummyDataset(Dataset):
    def __init__(self, num_samples=1000):
        self.num_samples = num_samples
        self.data = self.generate_dummy_data()

    def generate_dummy_data(self):
        data = []
        for _ in range(self.num_samples):
            rna_length = random.randint(20, config.max_rna_length)
            protein_length = random.randint(20, config.max_protein_length)
            rna_emb = torch.randn(rna_length, config.rna_embed_dim)
            protein_emb = torch.randn(protein_length, config.protein_embed_dim)
            protein_seq = torch.randint(0, config.vocab_size, (protein_length,))
            label = random.randint(0, 1)
            data.append((rna_emb, protein_emb, protein_seq, label))
        return data

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    rna_embs, protein_embs, protein_seqs, labels = zip(*batch)

    max_rna_len = max(emb.size(0) for emb in rna_embs)
    max_protein_len = max(emb.size(0) for emb in protein_embs)
    print(max_protein_len)

    padded_rna_emb = torch.zeros(len(batch), max_rna_len, config.rna_embed_dim)
    padded_protein_emb = torch.zeros(len(batch), max_protein_len, config.protein_embed_dim)
    padded_protein_seq = torch.full((len(batch), max_protein_len), config.vocab_size - 1)

    for i, (rna_emb, protein_emb, protein_seq) in enumerate(zip(rna_embs, protein_embs, protein_seqs)):
        padded_rna_emb[i, :rna_emb.size(0)] = rna_emb
        padded_protein_emb[i, :protein_emb.size(0)] = protein_emb
        padded_protein_seq[i, :protein_seq.size(0)] = protein_seq

    return padded_rna_emb, padded_protein_emb, padded_protein_seq, torch.tensor(labels)

train_dataset = DummyDataset()
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)

model = CombinedModel(config)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

contrastive_loss = nn.CosineEmbeddingLoss()
ce_loss = nn.CrossEntropyLoss(ignore_index=config.vocab_size - 1)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print('model prep done...')

model prep done...


In [19]:
for epoch in range(config.num_epochs):
    model.train()
    total_loss=0
    for rna_embs,protein_embs,target_seqs,labels in train_loader:
        rna_embs,protein_embs,target_seqs,labels=rna_embs.to(device),protein_embs.to(device),target_seqs.to(device),labels.to(device)

        optimizer.zero_grad()

        rna_latent,protein_latent,generated_sequence=model(rna_embs,protein_embs,use_lm_head=True)
        print("generated sequence",generated_sequence.shape)
        print("target seqs",target_seqs.shape)

        loss_contrastive=contrastive_loss(rna_latent,protein_latent,(labels*2-1).float())
        loss_generation=ce_loss(generated_sequence.view(-1,config.vocab_size),target_seqs.view(-1))

        loss=loss_contrastive+loss_generation
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(),1.0)

        optimizer.step()

        total_loss+=loss.item()

    avg_loss=total_loss/len(train_loader)
    print(f"Epoch {epoch+1}/{config.num_epochs}, Loss: {avg_loss:.4f}")


184
rna siamese done...
protein siamese done...
decoder done...

generated sequence torch.Size([32, 200, 33])
target seqs torch.Size([32, 184])


ValueError: Expected input batch_size (6400) to match target batch_size (5888).