# Project MOML - RBP Siamese

RNA-RBP Interaction Prediction and RBP Generation Process:

Input (RNA + RBP sequences)

  ↓

a) Embedding Layer (Fusion of rnabert+ESM)

  ↓

b) T5 Encoder Stacks (RNA and RBP)

  ↓

c) Latent Space Projection
  
  ↓

d) Cross-Attention Layer

  ↓

e) Fusion Layer

  ↓

f) T5 Decoder Stack

  ↓

g) Language Modeling Head (ESM)

  ↓

Output (Interaction Prediction + Generated RBP sequence)

Detailed Description:

The process begins with input RNA and RBP sequences, which are tokenized and converted to integer indices. These indices are passed through a shared embedding layer (a) that transforms them into dense vector representations. The embedded sequences are then encoded by a single T5 encoder stack (b) whose weights are shared between the RNA and RBP branches (the Siamese design); it consists of multiple layers of self-attention and feed-forward networks and captures the contextual information within each sequence. The encoder outputs are projected into a shared latent space (c), where a cross-attention layer (d) allows for interaction between the RNA and RBP representations. This interaction is crucial for capturing the relationship between RNA motifs and their binding proteins. The cross-attended representations are then combined in a fusion layer (e), producing a single representation that encapsulates the joint RNA-RBP information. This fused representation serves as input to the T5 Decoder Stack (f), which generates the RBP sequence through a series of self-attention, cross-attention, and feed-forward layers. Finally, a language modeling head (g) projects the decoder output to the vocabulary space, producing logits that a softmax turns into probabilities for each amino acid at each position. The model outputs both an interaction prediction (derived from the similarity of the latent representations) and a generated RBP sequence. Throughout this process, the model learns to encode meaningful representations of RNA and RBP sequences, predict their interactions, and generate novel RBP sequences tailored to specific RNA inputs.

## gdrive

In [3]:
import os
# Work from the project folder on the mounted Google Drive
# (presumably a Colab mount at /content/drive — adjust the path for other environments).
os.chdir('/content/drive/MyDrive/rna/aidrugsx-siamese network')

In [4]:
# IPython shell escape: list the files in the working directory.
!ls

final_attract_db_with_emb.csv  rbp_seqs_dict.pkl  rna_motif_emb.npy  siamese.ipynb


## generate dataset

In [5]:
import pandas as pd

In [6]:
# Load the motif/RBP table; its stored embedding columns are dropped and rebuilt in later cells.
data = pd.read_csv('final_attract_db_with_emb.csv')

In [7]:
import pickle
# Load the dictionary back from the pickle file.
# Keys are matched against data['RBP_sequence'] in a later cell; values are iterated
# as tensors there (presumably per-sequence ESM embeddings — confirm).
# NOTE(review): pickle.load can execute arbitrary code — only load files you trust.
with open("rbp_seqs_dict.pkl", "rb") as f:
    rbp_seqs_dict = pickle.load(f)

  return torch.load(io.BytesIO(b))


In [8]:
import numpy as np
# allow_pickle=True is needed because the array stores Python objects (it appears to hold
# one embedding array per motif). Like pickle, only load trusted files this way.
rna_motif_emb = np.load('rna_motif_emb.npy', allow_pickle=True)

In [9]:
# Length of the embedding for motif index 1 (printed 7; the motif at row 1 is 'AUAAUUA',
# 7 nucleotides — presumably one embedding row per nucleotide).
len(rna_motif_emb[1])

7

In [10]:
# Discard the embedding columns stored in the CSV; they are re-attached from the .npy/.pkl sources.
data = data.drop(columns=['rna_motif_emb', 'rbp_esm_emb'])

In [11]:
# Assumes rna_motif_emb has one entry per row of `data`, in the same order — TODO confirm.
data['rna_motif_emb'] = rna_motif_emb

In [12]:
# Look up each row's RBP embedding by its sequence; sequences missing from the dict become NaN.
data['rbp_esm_emb'] = data['RBP_sequence'].map(rbp_seqs_dict)

In [13]:
# Convert list of tensors to numpy array
def tensors_to_numpy(tensor_list):
    """Stack an iterable of torch tensors into a single NumPy array.

    Each tensor is detached and moved to the CPU first: ``Tensor.numpy()`` on its
    own raises for tensors that require grad or live on a GPU.

    Args:
        tensor_list: iterable of equally-shaped tensors (a single tensor is
            iterated along its first axis).

    Returns:
        np.ndarray with a new leading axis, one entry per input tensor.
    """
    return np.stack([t.detach().cpu().numpy() for t in tensor_list])

# Apply the conversion to the 'rbp_esm_emb' column.
# Each cell must be an iterable of tensors; a NaN left by an unmatched sequence in the
# .map step would raise here (a float is not iterable).
data['rbp_esm_emb'] = data['rbp_esm_emb'].apply(tensors_to_numpy)

In [14]:
# Preview the first rows of the assembled table.
data.head()

Unnamed: 0,Gene_name,Gene_id,Motif,RBP_sequence,rna_motif_emb,rbp_esm_emb
0,A1CF,ENSG00000148584,UGAUCAGUAUA,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[0.02515462040901184, -0.1693081259727478, 1....","[[0.102689214, -0.18220823, -0.05008613, 0.156..."
1,A1CF,ENSG00000148584,AUAAUUA,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[-0.005481339991092682, -0.008752591907978058...","[[0.102689214, -0.18220823, -0.05008613, 0.156..."
2,A1CF,ENSG00000148584,UUAAUUA,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[-0.006157027557492256, -0.04564734920859337,...","[[0.102689214, -0.18220823, -0.05008613, 0.156..."
3,A1CF,ENSG00000148584,AUAAUUG,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[0.0016195997595787048, 0.0697508156299591, 1...","[[0.102689214, -0.18220823, -0.05008613, 0.156..."
4,A1CF,ENSG00000148584,UUAAUUG,MESNHKSGDGLSGTQKEAALRALVQRTGYSLVQENGQRKYGGPPPG...,"[[0.02757852151989937, 0.006868686527013779, 3...","[[0.102689214, -0.18220823, -0.05008613, 0.156..."


## fusion layer

In [14]:
# fusion

In [None]:
# This cell runs before the cell that imports torch, so bring the names it needs into scope here.
import torch
import torch.nn as nn
import torch.nn.functional as F


class EnhancedEmbeddingFusionModule(nn.Module):
    """Fuse RNA and protein embeddings with bidirectional cross-attention and gating.

    Both inputs are linearly projected to ``fusion_dim``. One shared
    ``MultiHeadAttention`` module is then used twice: RNA attends to protein, and
    protein attends to RNA. The two attended sequences are concatenated
    feature-wise and passed through an MLP. Finally, ``EnhancedGatingMechanism``
    blends the attended RNA, attended protein and fused features.

    Args:
        rna_dim: feature size of the RNA embeddings (default 120).
        protein_dim: feature size of the protein embeddings (default 1280).
        fusion_dim: common hidden size; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
    """

    def __init__(self, rna_dim=120, protein_dim=1280, fusion_dim=512, num_heads=8):
        super(EnhancedEmbeddingFusionModule, self).__init__()

        self.rna_projection = nn.Linear(rna_dim, fusion_dim)
        self.protein_projection = nn.Linear(protein_dim, fusion_dim)

        # One attention module serves both directions (weights shared).
        self.cross_attention = MultiHeadAttention(fusion_dim, num_heads)

        self.fusion_layer = nn.Sequential(
            nn.Linear(fusion_dim * 2, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.ReLU(),
            nn.Linear(fusion_dim, fusion_dim)
        )

        self.enhanced_gating_mechanism = EnhancedGatingMechanism(fusion_dim)

    def forward(self, rna_emb, protein_emb):
        """Return the gated fusion of ``rna_emb`` and ``protein_emb``.

        Inputs are assumed to be ``(batch, seq, dim)`` — TODO confirm against callers.
        The RNA and protein sequences must have the same length, because the
        attended outputs are combined position by position.

        Raises:
            ValueError: if the RNA and protein sequence lengths differ.
        """
        rna_proj = self.rna_projection(rna_emb)
        protein_proj = self.protein_projection(protein_emb)

        rna_attended = self.cross_attention(rna_proj, protein_proj, protein_proj)
        protein_attended = self.cross_attention(protein_proj, rna_proj, rna_proj)

        # Each attended output keeps its *query's* length, so the position-wise concat
        # and gating below only make sense when both lengths agree.
        if rna_attended.shape[1] != protein_attended.shape[1]:
            raise ValueError(
                "EnhancedEmbeddingFusionModule needs equal RNA and protein sequence "
                f"lengths, got {rna_attended.shape[1]} and {protein_attended.shape[1]}"
            )

        concat_features = torch.cat([rna_attended, protein_attended], dim=-1)
        fused_features = self.fusion_layer(concat_features)

        gated_output = self.enhanced_gating_mechanism(rna_attended, protein_attended, fused_features)

        return gated_output

class EnhancedGatingMechanism(nn.Module):
    """Blend RNA, protein and fused features with element-wise and branch-level gates.

    Each branch first gets its own sigmoid gate, applied element-wise with values in (0, 1).
    A softmax head then turns the three gated branches into per-position mixing
    weights that sum to 1. The weighted sum is passed through a final MLP.

    The three inputs must share one shape whose last axis is ``fusion_dim``; any
    leading rank works, e.g. ``(batch, seq, fusion_dim)`` or ``(batch, fusion_dim)``.
    """

    def __init__(self, fusion_dim):
        super(EnhancedGatingMechanism, self).__init__()

        self.rna_gate = self._sigmoid_gate(fusion_dim)
        self.protein_gate = self._sigmoid_gate(fusion_dim)
        self.fusion_gate = self._sigmoid_gate(fusion_dim)

        # Three mixing weights per position (softmax => they sum to 1).
        self.final_gate = nn.Sequential(
            nn.Linear(fusion_dim * 3, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.ReLU(),
            nn.Linear(fusion_dim, 3),
            nn.Softmax(dim=-1)
        )

        self.output_projection = nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.ReLU(),
            nn.Linear(fusion_dim, fusion_dim)
        )

    @staticmethod
    def _sigmoid_gate(fusion_dim):
        """Build a two-layer MLP ending in a sigmoid (element-wise gate values in (0, 1))."""
        return nn.Sequential(
            nn.Linear(fusion_dim, fusion_dim),
            nn.LayerNorm(fusion_dim),
            nn.ReLU(),
            nn.Linear(fusion_dim, fusion_dim),
            nn.Sigmoid()
        )

    def forward(self, rna_features, protein_features, fused_features):
        """Return the gated, mixed and projected features (same shape as the inputs)."""
        gated_rna = self.rna_gate(rna_features) * rna_features
        gated_protein = self.protein_gate(protein_features) * protein_features
        gated_fusion = self.fusion_gate(fused_features) * fused_features

        weights = self.final_gate(torch.cat([gated_rna, gated_protein, gated_fusion], dim=-1))

        # Slice with `...` and keep a size-1 last axis so the weights broadcast for any
        # input rank (the previous [:, :, k] indexing only worked for 3-D inputs).
        gated_output = (weights[..., 0:1] * gated_rna +
                        weights[..., 1:2] * gated_protein +
                        weights[..., 2:3] * gated_fusion)

        return self.output_projection(gated_output)

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    Queries may have a different sequence length from keys/values; the output has
    the query's length and feature size ``d_model``.

    Args:
        d_model: model width; must be divisible by ``num_heads``.
        num_heads: number of attention heads.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super(MultiHeadAttention, self).__init__()
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if d_model % num_heads != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        """softmax(Q K^T / sqrt(d_k)) V, with positions where ``mask == 0`` suppressed."""
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) / (self.d_k ** 0.5)
        if mask is not None:
            attn_scores = attn_scores.masked_fill(mask == 0, -1e9)
        attn_probs = F.softmax(attn_scores, dim=-1)
        output = torch.matmul(attn_probs, V)
        return output

    def forward(self, Q, K, V, mask=None):
        """Attend from ``Q`` to ``K``/``V``.

        Inputs are reshaped to ``(batch, -1, num_heads, d_k)``, so each must have
        batch first and last dimension ``d_model``. ``mask`` (optional) must
        broadcast to ``(batch, num_heads, q_len, k_len)``; zeros mark blocked keys.
        """
        batch_size = Q.size(0)

        Q = self.W_q(Q).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(K).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(V).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = output.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)
        return self.W_o(output)


## siamese t5

Architecture Brief:
The Siamese T5 Generator is an innovative architecture designed for RNA-protein binding prediction and generation. It combines the power of the T5 (Text-to-Text Transfer Transformer) model with a Siamese network structure, enabling both sequence comparison and generation tasks.
The core of the architecture is a single T5-based encoder stack whose weights are shared by the RNA and protein branches (the Siamese part): it encodes each sequence separately with the same parameters. The encoder uses self-attention and feed-forward networks and incorporates relative position biases so that attention can account for the distance between tokens. The encoded representations are projected into a shared latent space, where a cross-attention mechanism facilitates information exchange between RNA and protein features. This interaction is crucial for capturing the nuanced relationships between RNA motifs and their binding proteins.
The decoder stack, also based on the T5 architecture, takes the fused representation of RNA and protein information as input. It employs a combination of self-attention and cross-attention layers to generate novel protein sequences conditioned on the input RNA motif. The use of gated feed-forward layers and layer normalization throughout the network enhances its expressive power and training stability. The final output is produced through a language modeling head, enabling the model to generate amino acid sequences. This architecture allows not only binding prediction but also the de novo design of proteins tailored to specific RNA motifs.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SiameseT5Generator(nn.Module):
    """Siamese T5 encoder plus T5 decoder for RNA-RBP binding and RBP generation.

    A single encoder stack with shared weights encodes the RNA and the RBP token
    sequences, and both encodings are projected to ``config.latent_dim``. The
    decoder generates RBP tokens while cross-attending to a fused summary of the
    two latent sequences. One embedding table is shared by the encoder, the
    decoder and both inputs.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.shared = nn.Embedding(config.vocab_size, config.d_model)
        self.encoder_stack = T5Stack(self._stack_config(config, is_decoder=False), is_decoder=False)
        self.decoder_stack = T5Stack(self._stack_config(config, is_decoder=True), is_decoder=True)
        # Both stacks embed integer ids through their `shared` attribute; point it at
        # this model's single table so all token inputs share one embedding.
        self.encoder_stack.shared = self.shared
        self.decoder_stack.shared = self.shared
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
        self.latent_projector = nn.Linear(config.d_model, config.latent_dim)
        self.fusion_layer = nn.Linear(config.latent_dim * 2, config.latent_dim)
        # Decoder cross-attention projects keys/values from d_model-sized inputs, so the
        # fused latent is mapped back to d_model before serving as encoder memory.
        self.memory_projector = nn.Linear(config.latent_dim, config.d_model)

    @staticmethod
    def _stack_config(config, is_decoder):
        """Return a shallow copy of ``config`` with ``is_decoder`` set and T5 defaults filled in.

        The T5 blocks/attention read ``is_decoder``, ``d_kv`` and
        ``relative_attention_num_buckets`` from the config; copying keeps the encoder
        and decoder from sharing one ``is_decoder`` flag.
        """
        import copy

        stack_config = copy.copy(config)
        stack_config.is_decoder = is_decoder
        if not hasattr(stack_config, "d_kv"):
            stack_config.d_kv = stack_config.d_model // stack_config.num_heads
        if not hasattr(stack_config, "relative_attention_num_buckets"):
            stack_config.relative_attention_num_buckets = 32
        return stack_config

    def _padding_mask(self, keep):
        """Additive key mask of shape ``(batch, 1, 1, len)``: 0 for real tokens, a large negative for padding."""
        dtype = self.shared.weight.dtype
        return (~keep)[:, None, None, :].to(dtype) * torch.finfo(dtype).min

    @staticmethod
    def _masked_mean(states, keep):
        """Mean of ``states`` (batch, len, dim) over positions where ``keep`` is True."""
        weights = keep.unsqueeze(-1).to(states.dtype)
        return (states * weights).sum(dim=1) / weights.sum(dim=1).clamp(min=1.0)

    def forward(self, rna_ids, rbp_ids, decoder_input_ids=None):
        """Encode both sequences and decode RBP tokens.

        Args:
            rna_ids: ``(batch, rna_len)`` integer RNA token ids.
            rbp_ids: ``(batch, rbp_len)`` integer RBP token ids.
            decoder_input_ids: optional decoder inputs; defaults to ``rbp_ids``
                shifted right (teacher forcing).

        Positions equal to ``config.pad_token_id`` (default 0) are treated as padding.

        Returns:
            ``(rna_latent, rbp_latent, lm_logits)``: per-token latents of shape
            ``(batch, rna_len, latent_dim)`` and ``(batch, rbp_len, latent_dim)``, and
            decoder logits of shape ``(batch, dec_len, vocab_size)``.
        """
        pad_id = getattr(self.config, "pad_token_id", 0)
        rna_keep = rna_ids != pad_id
        rbp_keep = rbp_ids != pad_id

        # Pass token ids (not embeddings): the stack embeds them itself via `shared`.
        rna_enc = self.encoder_stack(input_ids=rna_ids, attention_mask=self._padding_mask(rna_keep))
        rbp_enc = self.encoder_stack(input_ids=rbp_ids, attention_mask=self._padding_mask(rbp_keep))

        rna_latent = self.latent_projector(rna_enc[0])
        rbp_latent = self.latent_projector(rbp_enc[0])

        # RNA and RBP lengths generally differ, so fuse padding-aware mean-pooled
        # summaries instead of concatenating the two sequences position by position.
        fused = self.fusion_layer(
            torch.cat([self._masked_mean(rna_latent, rna_keep),
                       self._masked_mean(rbp_latent, rbp_keep)], dim=-1)
        )
        encoder_memory = self.memory_projector(fused).unsqueeze(1)  # (batch, 1, d_model)
        # NOTE(review): the memory includes a summary of rbp_ids, which are also the
        # decoder targets under teacher forcing — confirm this leakage is intended.

        if decoder_input_ids is None:
            decoder_input_ids = self._shift_right(rbp_ids)

        decoder_outputs = self.decoder_stack(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_memory,
            encoder_attention_mask=None
        )

        lm_logits = self.lm_head(decoder_outputs[0])

        return rna_latent, rbp_latent, lm_logits

    def _shift_right(self, input_ids):
        """Prepend ``config.decoder_start_token_id`` and drop the last token (teacher-forcing inputs)."""
        shifted_input_ids = input_ids.new_zeros(input_ids.shape)
        shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
        shifted_input_ids[..., 0] = self.config.decoder_start_token_id
        return shifted_input_ids

class T5Stack(nn.Module):
    """A stack of T5 blocks acting as the encoder (``is_decoder=False``) or the decoder.

    Args:
        config: model config. Its blocks receive a shallow copy with ``is_decoder``
            set to this stack's role and with ``d_kv`` /
            ``relative_attention_num_buckets`` filled in when missing.
        is_decoder: build decoder blocks (causal self-attention plus cross-attention).
        embed_tokens: optional embedding for integer inputs. If omitted, a new
            ``nn.Embedding(config.vocab_size, config.d_model)`` is created. It is
            stored as ``self.shared`` so an owning model can swap in a shared table.
    """

    def __init__(self, config, is_decoder=False, embed_tokens=None):
        super().__init__()
        self.is_decoder = is_decoder
        stack_config = self._stack_config(config, is_decoder)
        self.shared = embed_tokens if embed_tokens is not None else nn.Embedding(config.vocab_size, config.d_model)
        num_layers = getattr(config, "num_decoder_layers", config.num_layers) if is_decoder else config.num_layers
        # Only the first block owns a relative-position bias table.
        self.block = nn.ModuleList(
            [T5Block(stack_config, has_relative_attention_bias=bool(i == 0)) for i in range(num_layers)]
        )
        self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    @staticmethod
    def _stack_config(config, is_decoder):
        """Return a shallow copy of ``config`` with ``is_decoder`` set and T5 defaults filled in."""
        import copy

        stack_config = copy.copy(config)
        stack_config.is_decoder = is_decoder
        if not hasattr(stack_config, "d_kv"):
            stack_config.d_kv = stack_config.d_model // stack_config.num_heads
        if not hasattr(stack_config, "relative_attention_num_buckets"):
            stack_config.relative_attention_num_buckets = 32
        return stack_config

    @staticmethod
    def _add_causal_mask(attention_mask, hidden_states):
        """Add an additive causal mask ``(1, 1, len, len)`` so position i cannot see positions > i."""
        seq_len = hidden_states.shape[1]
        neg = torch.finfo(hidden_states.dtype).min
        causal = torch.full(
            (seq_len, seq_len), neg, dtype=hidden_states.dtype, device=hidden_states.device
        ).triu(diagonal=1)[None, None, :, :]
        return causal if attention_mask is None else attention_mask + causal

    def forward(self, input_ids=None, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
        """Run the stack.

        Args:
            input_ids: integer token ids ``(batch, len)``. A floating-point tensor is
                instead treated as pre-computed embeddings ``(batch, len, d_model)``.
            attention_mask: optional additive self-attention mask given to every block
                (presumably 0 = keep, large negative = masked). In decoder mode a
                causal mask is added to it.
            encoder_hidden_states: memory for the decoder's cross-attention.
            encoder_attention_mask: optional mask for that memory.

        Returns:
            1-tuple ``(hidden_states,)`` of shape ``(batch, len, d_model)``.

        Raises:
            ValueError: if ``input_ids`` is None.
        """
        if input_ids is None:
            raise ValueError("T5Stack.forward requires input_ids (token ids or embeddings)")
        if torch.is_floating_point(input_ids):
            inputs_embeds = input_ids
        else:
            inputs_embeds = self.shared(input_ids)
        hidden_states = self.dropout(inputs_embeds)

        if self.is_decoder:
            # Without this, teacher-forced decoding could attend to the tokens it must predict.
            attention_mask = self._add_causal_mask(attention_mask, hidden_states)

        for layer_module in self.block:
            layer_outputs = layer_module(
                hidden_states,
                attention_mask=attention_mask,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
            )
            hidden_states = layer_outputs[0]

        hidden_states = self.final_layer_norm(hidden_states)
        return (hidden_states,)

class T5Block(nn.Module):
    """One transformer block: self-attention, optional cross-attention, feed-forward.

    Sub-layers are kept in ``self.layer`` in that order. Index 0 is self-attention.
    Index 1 is cross-attention, present only when ``config.is_decoder`` is true.
    The last entry is always the feed-forward layer.
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        sublayers = [T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias)]
        if self.is_decoder:
            sublayers.append(T5LayerCrossAttention(config))
        sublayers.append(T5LayerFF(config))
        self.layer = nn.ModuleList(sublayers)

    def forward(self, hidden_states, attention_mask=None, encoder_hidden_states=None, encoder_attention_mask=None):
        """Apply the sub-layers in order and return a 1-tuple ``(hidden_states,)``."""
        hidden_states = self.layer[0](hidden_states, attention_mask=attention_mask)[0]

        # Cross-attention runs only in decoder blocks and only when memory is supplied.
        wants_cross = self.is_decoder and encoder_hidden_states is not None
        if wants_cross:
            hidden_states = self.layer[1](hidden_states, encoder_hidden_states, encoder_attention_mask)[0]

        hidden_states = self.layer[-1](hidden_states)[0]
        return (hidden_states,)

class T5LayerSelfAttention(nn.Module):
    """Pre-norm self-attention sub-layer with a residual connection.

    Computes ``x + dropout(SelfAttention(layer_norm(x)))`` and discards the
    attention's second output (the position bias).
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.SelfAttention = T5Attention(config, has_relative_attention_bias=has_relative_attention_bias)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states, attention_mask=None):
        attended = self.SelfAttention(self.layer_norm(hidden_states), mask=attention_mask)[0]
        residual_sum = hidden_states + self.dropout(attended)
        return (residual_sum,)

class T5LayerCrossAttention(nn.Module):
    """Pre-norm encoder-decoder attention sub-layer with a residual connection.

    Queries come from the normalised decoder states; keys/values come from
    ``encoder_hidden_states``.
    """

    def __init__(self, config):
        super().__init__()
        self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask):
        attended = self.EncDecAttention(
            self.layer_norm(hidden_states),
            mask=encoder_attention_mask,
            key_value_states=encoder_hidden_states,
        )[0]
        return (hidden_states + self.dropout(attended),)

class T5LayerFF(nn.Module):
    """Pre-norm feed-forward sub-layer with a residual connection.

    The attribute is named ``DenseReluDense`` (the upstream T5 name), but it holds
    a gated-GELU ``T5DenseGatedActDense``.
    """

    def __init__(self, config):
        super().__init__()
        self.DenseReluDense = T5DenseGatedActDense(config)
        self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

    def forward(self, hidden_states):
        update = self.DenseReluDense(self.layer_norm(hidden_states))
        return (hidden_states + self.dropout(update),)

class T5DenseGatedActDense(nn.Module):
    """Gated-GELU feed-forward: ``wo(dropout(gelu(wi_0 x) * wi_1 x))``, all linears bias-free."""

    def __init__(self, config):
        super().__init__()
        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        self.act = F.gelu

    def forward(self, hidden_states):
        gate = self.act(self.wi_0(hidden_states))
        return self.wo(self.dropout(gate * self.wi_1(hidden_states)))

class T5LayerNorm(nn.Module):
    """RMS layer norm as in T5: rescale by the root mean square, no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Mean square is accumulated in float32 for numerical stability.
        mean_square = hidden_states.float().square().mean(dim=-1, keepdim=True)
        inv_rms = (mean_square + self.variance_epsilon).rsqrt()
        return self.weight * (hidden_states * inv_rms)

class T5Attention(nn.Module):
    """T5-style multi-head attention with an optional learned relative-position bias.

    Handles self-attention (``key_value_states is None``) and encoder-decoder
    attention (keys/values taken from ``key_value_states``). As in T5, scores are
    not divided by sqrt(d_kv).

    Config fields read: ``d_model``, ``num_heads`` and ``dropout_rate``. The
    following fall back to defaults when absent (this notebook's ``Config`` may not
    define them):

    - ``is_decoder``: False
    - ``d_kv``: ``d_model // num_heads``
    - ``relative_attention_num_buckets``: 32
    - ``relative_attention_max_distance``: 128
    """

    def __init__(self, config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = getattr(config, "is_decoder", False)
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = getattr(config, "relative_attention_num_buckets", 32)
        self.relative_attention_max_distance = getattr(config, "relative_attention_max_distance", 128)
        self.d_model = config.d_model
        self.key_value_proj_dim = getattr(config, "d_kv", config.d_model // config.num_heads)
        self.n_heads = config.num_heads
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim

        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)

    def forward(self, hidden_states, mask=None, key_value_states=None, position_bias=None):
        """Attend from ``hidden_states`` to itself or to ``key_value_states``.

        Args:
            hidden_states: queries, ``(batch, q_len, d_model)``.
            mask: optional additive mask broadcastable to ``(batch, heads, q_len, k_len)``
                (presumably 0 = keep, large negative = block); folded into the bias.
            key_value_states: optional ``(batch, k_len, d_model)`` keys/values source.
            position_bias: optional precomputed bias; when given, ``mask`` is not added
                (it is assumed to be folded in already).

        Returns:
            ``(attn_output, position_bias)``; ``attn_output`` is ``(batch, q_len, d_model)``.
        """
        batch_size, seq_length = hidden_states.shape[:2]

        def shape(states):
            """(batch, len, inner_dim) -> (batch, heads, len, d_kv)."""
            return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """(batch, heads, len, d_kv) -> (batch, len, inner_dim)."""
            return states.transpose(1, 2).contiguous().view(batch_size, -1, self.inner_dim)

        kv_source = hidden_states if key_value_states is None else key_value_states
        query_states = shape(self.q(hidden_states))
        key_states = shape(self.k(kv_source))
        value_states = shape(self.v(kv_source))
        key_length = key_states.shape[2]

        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            if self.has_relative_attention_bias:
                position_bias = self.compute_bias(seq_length, key_length).to(dtype=scores.dtype)
            else:
                # Must match the score shape (q_len x k_len); these differ for cross-attention.
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length),
                    device=scores.device,
                    dtype=scores.dtype
                )

            if mask is not None:
                position_bias = position_bias + mask

        scores = scores + position_bias
        attn_weights = F.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        attn_output = unshape(torch.matmul(attn_weights, value_states))
        attn_output = self.o(attn_output)

        return (attn_output, position_bias)

    def compute_bias(self, seq_length, key_length=None):
        """Return the learned relative-position bias of shape ``(1, heads, seq_length, key_length)``.

        ``key_length`` defaults to ``seq_length`` (self-attention). Positions are built
        on the bias table's device so the lookup works on GPU as well as CPU.
        """
        if key_length is None:
            key_length = seq_length
        device = self.relative_attention_bias.weight.device
        context_position = torch.arange(seq_length, dtype=torch.long, device=device)[:, None]
        memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
        relative_position = memory_position - context_position
        relative_position_bucket = self._relative_position_bucket(
            relative_position,
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(relative_position_bucket)
        values = values.permute([2, 0, 1]).unsqueeze(0)
        return values

    @staticmethod
    def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
        """Map signed relative positions (key - query) to bucket ids in ``[0, num_buckets)``.

        Small distances get one bucket each; larger ones share log-spaced buckets up to
        ``max_distance``. In bidirectional mode half of the buckets encode positive
        offsets. In unidirectional mode only past positions (offset <= 0) are
        distinguished, and future offsets map to bucket 0.
        """
        import math

        relative_buckets = 0
        if bidirectional:
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # log(0) = -inf for zero offsets here; those entries are discarded by torch.where below.
        relative_position_if_large = max_exact + (
            torch.log(relative_position.float() / max_exact)
            / math.log(max_distance / max_exact)
            * (num_buckets - max_exact)
        ).to(torch.long)
        relative_position_if_large = torch.min(
            relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
        )
        relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
        return relative_buckets

## training (dummy dataset)

Dummy dataset examples:

RNA sequence: "AUGGCUAUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCU"
RBP sequence: "MKVILWAALVITFLAGCQAETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQELRALM"
Label: 1 (positive interaction)

RNA sequence: "CGAUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCU"
RBP sequence: "MKVLWAALLVTFLAGCQAKVEQAVETEPEPELRQQTEWQSGQRWELALGRFWDYLRWVQTLSEQVQEELLSSQVTQ"
Label: 0 (negative interaction)

RNA sequence: "UAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGCUAGC"
RBP sequence: "MELKAYKSELEEQLTPVAEETRARLSKELQAAQARLGADVLASHGRLVQYRGEVQAMLGQSTEELRVRLASHLRKL"
Label: 1 (positive interaction)

In these examples:

RNA sequences are strings of A, U, G, and C.
RBP sequences are strings of amino acid single-letter codes.
The label is either 0 (no interaction) or 1 (interaction).
Sequence lengths vary but are within the specified maximum length.

In the actual implementation, these would be converted to tensor representations of integer indices corresponding to each nucleotide or amino acid (tokenization).

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import random
import numpy as np

class Config:
    """Hyper-parameters for SiameseT5Generator and the dummy training run.

    The T5 layers read ``d_kv``, ``relative_attention_num_buckets`` and
    ``is_decoder`` from the config, so they are defined here. ``is_decoder`` is
    only a default; a decoder stack needs it set to True on its own config.
    """
    vocab_size = 30
    d_model = 512
    d_ff = 2048
    num_layers = 6
    num_heads = 8
    d_kv = 64  # per-head key/value size = d_model // num_heads
    relative_attention_num_buckets = 32
    relative_attention_max_distance = 128
    is_decoder = False
    dropout_rate = 0.1
    layer_norm_epsilon = 1e-6
    max_seq_length = 100
    latent_dim = 256
    num_decoder_layers = 6
    pad_token_id = 0  # collate_fn pads with 0 and ce_loss uses ignore_index=0
    decoder_start_token_id = 0
    batch_size = 32
    num_epochs = 10
    learning_rate = 3e-4

config = Config()

class DummyDataset(Dataset):
    """Random ``(rna_ids, rbp_ids, label)`` triples for smoke-testing the training loop.

    Both sequence lengths are drawn uniformly from ``[20, max_length]``. Token ids are
    drawn from ``[1, config.vocab_size)``, so 0 never appears; collate_fn pads with 0.
    Labels are random 0/1. Samples are generated once, in ``__init__``.
    """

    def __init__(self, num_samples=1000, max_length=100):
        self.num_samples = num_samples
        self.max_length = max_length
        self.data = self.generate_dummy_data()

    def generate_dummy_data(self):
        def make_sample():
            rna_len = random.randint(20, self.max_length)
            rbp_len = random.randint(20, self.max_length)
            rna_tokens = torch.randint(1, config.vocab_size, (rna_len,))
            rbp_tokens = torch.randint(1, config.vocab_size, (rbp_len,))
            return rna_tokens, rbp_tokens, random.randint(0, 1)

        return [make_sample() for _ in range(self.num_samples)]

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        return self.data[idx]

def collate_fn(batch):
    """Right-pad a list of ``(rna_ids, rbp_ids, label)`` triples with 0 into batch tensors.

    Returns:
        ``(padded_rna, padded_rbp, labels)``. The padded tensors are int64 of shape
        ``(batch, longest_sequence)``; ``labels`` is ``torch.tensor`` of the labels.
    """
    rna_seqs, rbp_seqs, labels = zip(*batch)
    pad = torch.nn.utils.rnn.pad_sequence
    padded_rna = pad(list(rna_seqs), batch_first=True, padding_value=0).to(torch.long)
    padded_rbp = pad(list(rbp_seqs), batch_first=True, padding_value=0).to(torch.long)
    return padded_rna, padded_rbp, torch.tensor(labels)

# Dummy data and loader; collate_fn zero-pads each batch to its longest sequences.
train_dataset = DummyDataset()
train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, collate_fn=collate_fn)

model = SiameseT5Generator(config)
optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

# Latent-similarity loss; its target must be +1 (binds) or -1 (does not bind).
contrastive_loss = nn.CosineEmbeddingLoss()
# Token-level generation loss; id 0 (the padding value used by collate_fn) is ignored.
ce_loss = nn.CrossEntropyLoss(ignore_index=0)

def _masked_mean_pool(latent, ids, pad_id=0):
    """Mean-pool ``latent`` (batch, len, dim) over non-padding positions (``ids != pad_id``).

    2-D inputs (already one vector per example) are returned unchanged.
    """
    if latent.dim() == 2:
        return latent
    keep = (ids != pad_id).unsqueeze(-1).to(latent.dtype)
    return (latent * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)


def train_epoch(model, dataloader, optimizer, device):
    """Run one training epoch and return the mean per-batch loss.

    The loss per batch is the sum of two terms. The first is ``contrastive_loss`` on
    the padding-aware mean-pooled RNA and RBP latents, with labels {0, 1} mapped to
    targets {-1, +1}. The second is ``ce_loss`` on the decoder logits against
    ``rbp_ids``.

    Returns:
        Mean loss over batches, or ``nan`` if the dataloader yields no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0

    for rna_ids, rbp_ids, labels in dataloader:
        rna_ids, rbp_ids, labels = rna_ids.to(device), rbp_ids.to(device), labels.to(device)

        optimizer.zero_grad()

        rna_latent, rbp_latent, lm_logits = model(rna_ids, rbp_ids)

        # CosineEmbeddingLoss needs one vector per example, but the model returns
        # per-token latents whose RNA and RBP lengths differ — pool them first.
        rna_vec = _masked_mean_pool(rna_latent, rna_ids)
        rbp_vec = _masked_mean_pool(rbp_latent, rbp_ids)
        loss_contrastive = contrastive_loss(rna_vec, rbp_vec, (labels * 2 - 1).float())

        loss_ce = ce_loss(lm_logits.reshape(-1, lm_logits.size(-1)), rbp_ids.reshape(-1))

        loss = loss_contrastive + loss_ce
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1

    # An empty loader would otherwise raise ZeroDivisionError.
    return total_loss / num_batches if num_batches else float("nan")

# Use the GPU when available and move the model there before training.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Train for config.num_epochs epochs, printing each epoch's mean batch loss.
for epoch in range(config.num_epochs):
    avg_loss = train_epoch(model, train_loader, optimizer, device)
    print(f"Epoch {epoch+1}/{config.num_epochs}, Loss: {avg_loss:.4f}")

def generate_rbp(model, rna_sequence, max_length=100, device=None):
    """Greedily decode an RBP token sequence for one RNA sequence.

    At each step, the tokens generated so far are passed both as the model's RBP
    input and, unshifted, as ``decoder_input_ids``. The prefix already begins with
    the start token, so the model must not shift it again. The arg-max of the last
    position's logits is then appended.

    Args:
        model: module returning ``(rna_latent, rbp_latent, lm_logits)`` and exposing
            ``config.decoder_start_token_id`` (e.g. ``SiameseT5Generator``).
        rna_sequence: 1-D tensor of RNA token ids.
        max_length: maximum number of tokens to generate.
        device: device for the inputs; defaults to the model's parameter device.

    Returns:
        1-D NumPy array of token ids that starts with the decoder start token. Decoding
        stops early, after appending it, when the start token id is generated again.
    """
    if device is None:
        device = next(model.parameters()).device
    start_id = model.config.decoder_start_token_id
    model.eval()
    with torch.no_grad():
        rna_ids = rna_sequence.unsqueeze(0).to(device)
        decoder_input = torch.tensor([[start_id]], device=device)

        for _ in range(max_length):
            # Pass the prefix explicitly; with decoder_input_ids=None the model would
            # shift it right again and predict from a misaligned input.
            _, _, lm_logits = model(rna_ids, decoder_input, decoder_input_ids=decoder_input)
            next_token = lm_logits[:, -1, :].argmax(dim=-1).unsqueeze(-1)
            decoder_input = torch.cat([decoder_input, next_token], dim=-1)

            # In this notebook the start id is 0, the same value used for padding,
            # so it is treated as end-of-sequence.
            if next_token.item() == start_id:
                break

    # squeeze(0) keeps a 1-D result even when nothing was generated.
    return decoder_input.squeeze(0).cpu().numpy()

# Smoke test: generate an RBP token sequence for a random 50-token RNA (ids 1..vocab_size-1).
test_rna = torch.randint(1, config.vocab_size, (50,)).to(device)
generated_rbp = generate_rbp(model, test_rna)
print("Generated RBP sequence:", generated_rbp)

def predict_binding(model, rna_sequence, rbp_sequence, device=None):
    """Score RNA-RBP binding as the cosine similarity of pooled latent representations.

    The model's per-token latents are mean-pooled over non-padding positions (id 0)
    before comparison, because the RNA and RBP sequences generally differ in length.

    Args:
        model: module returning ``(rna_latent, rbp_latent, lm_logits)``.
        rna_sequence: 1-D tensor of RNA token ids.
        rbp_sequence: 1-D tensor of RBP token ids.
        device: device for the inputs; defaults to the model's parameter device.

    Returns:
        Cosine similarity as a Python float in [-1, 1].
    """
    def pool(latent, ids):
        # 2-D latents are already one vector per example.
        if latent.dim() == 2:
            return latent
        keep = (ids != 0).unsqueeze(-1).to(latent.dtype)
        return (latent * keep).sum(dim=1) / keep.sum(dim=1).clamp(min=1.0)

    if device is None:
        device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        rna_ids = rna_sequence.unsqueeze(0).to(device)
        rbp_ids = rbp_sequence.unsqueeze(0).to(device)
        rna_latent, rbp_latent, _ = model(rna_ids, rbp_ids)
        similarity = nn.functional.cosine_similarity(
            pool(rna_latent, rna_ids), pool(rbp_latent, rbp_ids), dim=-1
        )
    return similarity.item()

# Smoke test: binding score for a random RNA/RBP pair (50 tokens each, ids 1..vocab_size-1).
test_rna = torch.randint(1, config.vocab_size, (50,)).to(device)
test_rbp = torch.randint(1, config.vocab_size, (50,)).to(device)
binding_score = predict_binding(model, test_rna, test_rbp)
print("Binding prediction score:", binding_score)