In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
import pandas as pd
from typing import List, Tuple
from tqdm import tqdm

# Tokenizers

In [2]:
class SequenceTokenizer:
    def __init__(self):
        self.VOCABULARY = {
            "A": 1,
            "U": 2,
            "C": 3,
            "G": 4,
        }
        self.INVERTED_VOCABULARY = {self.VOCABULARY[i]: i for i in self.VOCABULARY}


    def tokenize(self, sequence: str) -> torch.tensor:
        return torch.tensor([self.VOCABULARY[i] for i in sequence], dtype=torch.long)
    
    def detokenize(self, tokens: torch.tensor) -> str:
        return "".join([self.INVERTED_VOCABULARY[i] for i in tokens.tolist()])
    
class StructureTokenizer:
    def __init__(self):
        self.VOCABULARY = {
            "(": 1,
            ")": 2,
            "[": 3,
            "]": 4,
            "{": 5,
            "}": 6,
            ".": 7,
        }
        self.INVERTED_VOCABULARY = {self.VOCABULARY[i]: i for i in self.VOCABULARY}


    def tokenize(self, structure: str) -> torch.tensor:
        return torch.tensor([self.VOCABULARY[i] for i in structure], dtype=torch.long)
    
    def detokenize(self, tokens: torch.tensor) -> str:
        return "".join([self.INVERTED_VOCABULARY[i] for i in tokens.tolist()])

# Prepare dataset
instead of predicting sequence it predicts pairwise interaction matrix

In [3]:
class RNADataset(Dataset):
    def __init__(self, path: str, indices: List[int]):
        """path: path to .csv file with sequences and structures"""
        super().__init__()
        
        self.data = pd.read_csv(path)
        self.data = self.data.iloc[indices].reset_index(drop=True)

        self.sequence_tokenizer = SequenceTokenizer()
        self.structure_tokenizer = StructureTokenizer()

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.iloc[index]
        sequence = row["sequence"]
        tokenized_sequence = self.sequence_tokenizer.tokenize(sequence)
        structure = row["structure"]
        tokenized_structure = self.structure_tokenizer.tokenize(structure)
        
        return sequence, tokenized_sequence, structure, tokenized_structure

In [None]:
def collate_fn(batch):
    sequences, tokenized_seqs, structures, tokenized_structures = zip(*batch)
    max_len = max(len(seq) for seq in tokenized_seqs)

    padded_tokenized_sequences = []
    padded_tokenized_structures = []

    for tokenized_sequence, tokenized_structure in zip(tokenized_seqs, tokenized_structures):
        sequence_lenght = len(tokenized_sequence)
        padded_seq = torch.cat([
            tokenized_sequence,
            torch.zeros(max_len - sequence_lenght, dtype=torch.long)
        ])
        padded_structure = torch.cat([
            tokenized_structure,
            torch.zeros(max_len - sequence_lenght, dtype=torch.long)
        ])
        padded_tokenized_sequences.append(padded_seq)
        padded_tokenized_structures.append(padded_structure)

    padded_tokenized_sequences = torch.stack(padded_tokenized_sequences, dim=0)
    padded_tokenized_structures = torch.stack(padded_tokenized_structures, dim=0)

    return sequences, padded_tokenized_sequences, structures, padded_tokenized_structures


# Create model 

In [5]:
class transformerRNA(nn.Module):
    def __init__(
        self, 
        hidden_dim: int=1000, 
        num_transformer_layers: int=10, 
        n_head: int=8, 
        dropout: float=0.1,
    ):
        super().__init__()

        # output (L, H)
        self.embedding = nn.Embedding(num_embeddings=5, embedding_dim=hidden_dim, padding_idx=0)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            batch_first=True
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_transformer_layers
        )

        self.output_head = nn.Linear(hidden_dim, 8)

    def forward(self, x, src_key_padding_mask=None):
        padding_mask = (x == 0)
        x = self.embedding(x)
        # (L, H)
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        # (L, H)
        x = self.output_head(x)
        return x

# Train model

In [18]:
class RNATrainer:
    def __init__(self, model: transformerRNA, device: torch.device):
        self.model = model
        self.device = device

        class_weights = torch.tensor([
            0.0,   # padding (ignored anyway)
            5.0,   # ( 
            5.0,   # )
            5.0,   # [
            5.0,   # ]
            5.0,   # {
            5.0,   # }
            1.0,   # .
        ], dtype=torch.float).to(device)

        self.criterion = nn.CrossEntropyLoss(ignore_index=0, weight=class_weights)

        self.optimizer = torch.optim.Adam(self.model.parameters())

    def train_epoch(self, train_dataloader: torch.utils.data.DataLoader) -> float:
        self.model.train()
        total_loss = 0

        for sequence, tokenized_sequence, structure, tokenized_structure in tqdm(train_dataloader, desc="Training"):
            tokenized_sequence = tokenized_sequence.to(self.device)
            tokenized_structure = tokenized_structure.to(self.device)
            # print(tokenized_sequence)
            # padding_mask = padding_mask.to(self.device)

            out_logits = self.model(tokenized_sequence)

            # print(out_logits.shape)

            loss = self.criterion(out_logits.view(-1, 8), tokenized_structure.view(-1))

            # valid_mask = (~padding_mask).unsqueeze(-1) & (~padding_mask).unsqueeze(-2)
            # loss = (loss * valid_mask).sum() / valid_mask.sum()

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        avg_losss = total_loss / len(train_dataloader)
        return avg_losss

    # def test_model(self, test_loader)
    # TODO: test model with reconstructed structure from interaction matrix

    def train(
        self, 
        train_dataloader: torch.utils.data.DataLoader, 
        test_dataloader: torch.utils.data.DataLoader,
        num_epochs: int
    ) -> None:
        for epoch in range(num_epochs):
            avg_loss = self.train_epoch(train_dataloader)
            print(f"Epoch {epoch} current loss: {avg_loss}")


In [19]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
device

device(type='cuda')

In [20]:
TRAIN_SIZE = 11500
TEST_SIZE = 1646

train_indices = list(range(TRAIN_SIZE))
test_indices = list(range(TRAIN_SIZE, TRAIN_SIZE+TEST_SIZE))

train_dataset = RNADataset("rna_dataset.csv", train_indices)
test_dataset = RNADataset("rna_dataset.csv", test_indices)


train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=16,          # try 16–64; tune to your VRAM
    collate_fn=collate_fn
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=1,
    collate_fn=collate_fn
)

In [9]:
model = transformerRNA(hidden_dim=512, num_transformer_layers=8, n_head=8)
model = model.to(device)
trainer = RNATrainer(model, device)
trainer.train(train_dataloader, test_dataloader, 10)

Training: 100%|██████████| 719/719 [02:17<00:00,  5.23it/s]


Epoch 0 current loss: 1.0670220129844707


Training: 100%|██████████| 719/719 [02:15<00:00,  5.29it/s]


Epoch 1 current loss: 1.052749472798492


Training: 100%|██████████| 719/719 [02:14<00:00,  5.34it/s]


Epoch 2 current loss: 1.0513404981648973


Training: 100%|██████████| 719/719 [02:15<00:00,  5.32it/s]


Epoch 3 current loss: 1.0504548284374127


Training: 100%|██████████| 719/719 [02:17<00:00,  5.23it/s]


Epoch 4 current loss: 1.0500512302999536


Training: 100%|██████████| 719/719 [02:15<00:00,  5.30it/s]


Epoch 5 current loss: 1.050414896475589


Training: 100%|██████████| 719/719 [02:15<00:00,  5.30it/s]


Epoch 6 current loss: 1.0501475552192816


Training: 100%|██████████| 719/719 [02:13<00:00,  5.39it/s]


Epoch 7 current loss: 1.0493770580795776


Training: 100%|██████████| 719/719 [02:14<00:00,  5.36it/s]


Epoch 8 current loss: 1.05007279698474


Training: 100%|██████████| 719/719 [02:14<00:00,  5.35it/s]

Epoch 9 current loss: 1.0493873958793236





In [14]:
def get_model_memory(model):
    """Calculate model memory in MB"""
    param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
    buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
    total_size_mb = (param_size + buffer_size) / 1024**2
    return total_size_mb

# Check your current model
model_memory = get_model_memory(model)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")
print(f"Model memory: {model_memory:.2f} MB")

Model parameters: 25,225,736
Model memory: 96.23 MB


In [11]:
import matplotlib.pyplot as plt
import seaborn as sns

def generate_heatmap(matrix) -> None:
    if isinstance(matrix, torch.Tensor):
        matrix = matrix.float().cpu().numpy()
    sns.heatmap(matrix)
    plt.show()

In [None]:
model.eval()
detokenizer = StructureTokenizer()

with torch.no_grad():
    for sequence, tokenized_sequence, structure, tokenized_structure in test_dataloader:
        print(tokenized_structure.shape)


        tokenized_structure = tokenized_structure.view(-1)
        tokenized_sequence = tokenized_sequence.to(device)
        out_logits = model(tokenized_sequence)
        out_probs = torch.softmax(out_logits, dim=-1)
        predicted_structure = torch.argmax(out_probs, dim=-1, )
        print(predicted_structure)
        print("Predicted structure tokens:", detokenizer.detokenize(predicted_structure[0]))
        print("Ground truth structure tokens:", structure[0])

        

        break

torch.Size([1, 84])
tensor([[7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7,
         7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7]], device='cuda:0')
Predicted structure tokens: ....................................................................................
Ground truth structure tokens: (((((((..(((...........))).(((((.......)))))............(((((.......))))))))))))....


In [13]:
train_dataset.sequence_tokenizer.tokenize("AUGC")

tensor([1, 2, 4, 3])