In [10]:
import torch
from torch import nn
from torch.utils.data import Dataset
import pandas as pd
from typing import List, Tuple
from tqdm import tqdm

# Tokenizers

In [11]:
class SequenceTokenizer:
    def __init__(self):
        self.VOCABULARY = {
            "A": 1,
            "U": 2,
            "C": 3,
            "G": 4,
        }
        self.INVERTED_VOCABULARY = {self.VOCABULARY[i]: i for i in self.VOCABULARY}


    def tokenize(self, sequence: str) -> torch.tensor:
        return torch.tensor([self.VOCABULARY[i] for i in sequence], dtype=torch.long)
    
    def detokenize(self, tokens: torch.tensor) -> str:
        return "".join([self.INVERTED_VOCABULARY[i] for i in tokens.tolist()])
    
class StructureTokenizer:
    def __init__(self):
        self.VOCABULARY = {
            ".": 0,
            "(": 1,
            ")": 2,
            "[": 3,
            "]": 4,
            "{": 5,
            "}": 6
        }
        self.INVERTED_VOCABULARY = {self.VOCABULARY[i]: i for i in self.VOCABULARY}


    def tokenize(self, structure: str) -> torch.tensor:
        return torch.tensor([self.VOCABULARY[i] for i in structure], dtype=torch.long)
    
    def detokenize(self, tokens: torch.tensor) -> str:
        return "".join([self.INVERTED_VOCABULARY[i] for i in tokens.tolist()])

# Prepare dataset
instead of predicting sequence it predicts pairwise interaction matrix

In [12]:
class RNADataset(Dataset):
    def __init__(self, path: str, indices: List[int]):
        """path: path to .csv file with sequences and structures"""
        super().__init__()
        
        self.data = pd.read_csv(path)
        self.data = self.data.iloc[indices].reset_index(drop=True)

        self.sequence_tokenizer = SequenceTokenizer()
        self.structure_tokenizer = StructureTokenizer()

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.iloc[index]
        sequence = row["sequence"]
        tokenized_sequence = self.sequence_tokenizer.tokenize(sequence)
        structure = row["structure"]
        tokenized_structure = self.structure_tokenizer.tokenize(structure)
        
        return sequence, tokenized_sequence, structure, tokenized_structure

In [13]:
# Add this new cell after imports
def collate_fn(batch):
    """Custom collate function to pad variable-length sequences"""
    sequences, tokenized_seqs, structures, interaction_matrices = zip(*batch)
    
    # Find max length in this batch
    max_len = max(len(seq) for seq in tokenized_seqs)
    
    # Pad tokenized sequences
    padded_seqs = []
    padding_masks = []
    padded_matrices = []
    
    for tokenized_seq, interaction_matrix in zip(tokenized_seqs, interaction_matrices):
        seq_len = len(tokenized_seq)
        
        # Pad sequence (0 is pad_idx)
        padded_seq = torch.cat([
            tokenized_seq,
            torch.zeros(max_len - seq_len, dtype=torch.long)
        ])
        padded_seqs.append(padded_seq)
        
        # Create padding mask (True for padded positions)
        mask = torch.cat([
            torch.zeros(seq_len, dtype=torch.bool),
            torch.ones(max_len - seq_len, dtype=torch.bool)
        ])
        padding_masks.append(mask)
        
        # Pad interaction matrix
        padded_matrix = torch.zeros(max_len, max_len, dtype=torch.bfloat16)
        padded_matrix[:seq_len, :seq_len] = interaction_matrix
        padded_matrices.append(padded_matrix)
    
    return (
        sequences,
        torch.stack(padded_seqs),
        structures,
        torch.stack(padded_matrices),
        torch.stack(padding_masks)
    )

# Create model 

In [14]:
class transformerRNA(nn.Module):
    def __init__(
        self, 
        hidden_dim: int=1000, 
        num_transformer_layers: int=10, 
        n_head: int=8, 
        dropout: float=0.1,
    ):
        super().__init__()

        # output (L, H)
        self.embedding = nn.Embedding(num_embeddings=5, embedding_dim=hidden_dim, padding_idx=0)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            batch_first=True
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_transformer_layers
        )

        self.output_head = nn.Linear(hidden_dim, 7)

    def forward(self, x, src_key_padding_mask=None):
        x = self.embedding(x)
        # (L, H)
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        # (L, H)
        x = self.output_head(x)
        return x

# Train model

In [41]:
class RNATrainer:
    def __init__(self, model: transformerRNA, device: torch.device):
        self.model = model
        self.device = device

        self.criterion = nn.CrossEntropyLoss()

        self.optimizer = torch.optim.Adam(self.model.parameters())

    def train_epoch(self, train_dataloader: torch.utils.data.DataLoader) -> float:
        self.model.train()
        total_loss = 0

        for sequence, tokenized_sequence, structure, tokenized_structure in tqdm(train_dataloader, desc="Training"):
            tokenized_sequence = tokenized_sequence.to(self.device)
            tokenized_structure = tokenized_structure.to(self.device)
            # print(tokenized_sequence)
            # padding_mask = padding_mask.to(self.device)

            out_logits = self.model(tokenized_sequence)

            # print(out_logits.shape)

            loss = self.criterion(out_logits.view(-1, 7), tokenized_structure.view(-1))

            # valid_mask = (~padding_mask).unsqueeze(-1) & (~padding_mask).unsqueeze(-2)
            # loss = (loss * valid_mask).sum() / valid_mask.sum()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        avg_losss = total_loss / len(train_dataloader)
        return avg_losss

    # def test_model(self, test_loader)
    # TODO: test model with reconstructed structure from interaction matrix

    def train(
        self, 
        train_dataloader: torch.utils.data.DataLoader, 
        test_dataloader: torch.utils.data.DataLoader,
        num_epochs: int
    ) -> None:
        for epoch in range(num_epochs):
            avg_loss = self.train_epoch(train_dataloader)
            print(f"Epoch {epoch} current loss: {avg_loss}")


In [42]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
device

device(type='cuda')

In [43]:
TRAIN_SIZE = 11500
TEST_SIZE = 1646

train_indices = list(range(TRAIN_SIZE))
test_indices = list(range(TRAIN_SIZE, TRAIN_SIZE+TEST_SIZE))

train_dataset = RNADataset("rna_dataset.csv", train_indices)
test_dataset = RNADataset("rna_dataset.csv", test_indices)


train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=1,          # try 16–64; tune to your VRAM
    # collate_fn=collate_fn
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=1,
    # collate_fn=collate_fn
)

In [45]:
model = transformerRNA(hidden_dim=50, num_transformer_layers=2, n_head=1)
model = model.to(device)
trainer = RNATrainer(model, device)
trainer.train(train_dataloader, test_dataloader, 50)

Training: 100%|██████████| 11500/11500 [00:31<00:00, 366.01it/s]


Epoch 0 current loss: 1.0229959918156915


Training: 100%|██████████| 11500/11500 [00:32<00:00, 354.30it/s]


Epoch 1 current loss: 1.0179621970083403


Training:   2%|▏         | 227/11500 [00:00<00:31, 353.60it/s]


KeyboardInterrupt: 

In [None]:
def get_model_memory(model):
    """Calculate model memory in MB"""
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    total_size_mb = (param_size + buffer_size) / 1024**2
    return total_size_mb

# Check your current model
model_memory = get_model_memory(model)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model memory: {model_memory:.2f} MB")

Model parameters: 25,484,288
Model memory: 97.21 MB


In [34]:
import matplotlib.pyplot as plt
import seaborn as sns

def generate_heatmap(matrix) -> None:
    if isinstance(matrix, torch.Tensor):
        matrix = matrix.float().cpu().numpy()
    sns.heatmap(matrix)
    plt.show()

In [46]:
model.eval()
with torch.no_grad():
    for sequence, tokenized_sequence, structure, tokenized_structure in test_dataloader:
        print(tokenized_structure.shape)


        tokenized_structure = tokenized_structure.view(-1)
        print(tokenized_structure)
        tokenized_sequence = tokenized_sequence.to(device)
        out_logits = model(tokenized_sequence)
        print(out_logits.shape)
        out_logits = out_logits.view(-1, 7)
        print(out_logits)
        print(out_logits.shape)
        # out_logits = torch.sigmoid(out_logits)
        # print(out_logits)

        # generate_heatmap(interaction_matrix[0])
        # generate_heatmap(out_logits[0])


        break

torch.Size([1, 84])
tensor([1, 1, 1, 1, 1, 1, 1, 0, 0, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2,
        2, 2, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0])
torch.Size([1, 84, 7])
tensor([[ 1.3313,  0.9573,  0.8748, -3.4629, -3.6553, -7.6936, -7.8399],
        [ 1.2822,  0.8976,  0.9294, -3.5669, -3.4107, -7.5125, -8.0027],
        [ 1.2822,  0.8976,  0.9294, -3.5669, -3.4107, -7.5125, -8.0027],
        [ 1.5409,  0.8641,  0.7817, -3.4292, -3.6989, -7.8867, -7.9812],
        [ 1.5409,  0.8641,  0.7817, -3.4292, -3.6989, -7.8867, -7.9812],
        [ 1.2822,  0.8976,  0.9294, -3.5669, -3.4107, -7.5125, -8.0027],
        [ 1.3313,  0.9573,  0.8748, -3.4629, -3.6553, -7.6936, -7.8399],
        [ 1.5409,  0.8641,  0.7817, -3.4292, -3.6989, -7.8867, -7.9812],
        [ 1.3313,  0.9573,  0.8748, -3.4629, -3.6553, -7.6936, -7.8399],
        [ 2.026

In [None]:
train_dataset.sequence_tokenizer.tokenize("AUGC")

tensor([1, 2, 4, 3], dtype=torch.int8)