In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
import pandas as pd
from typing import List, Tuple
from tqdm import tqdm

# Tokenizers

In [2]:
class SequenceTokenizer:
    def __init__(self):
        self.VOCABULARY = {
            "A": 1,
            "U": 2,
            "C": 3,
            "G": 4,
        }
        self.INVERTED_VOCABULARY = {self.VOCABULARY[i]: i for i in self.VOCABULARY}


    def tokenize(self, sequence: str) -> torch.tensor:
        return torch.tensor([self.VOCABULARY[i] for i in sequence], dtype=torch.long)
    
    def detokenize(self, tokens: torch.tensor) -> str:
        return "".join([self.INVERTED_VOCABULARY[i] for i in tokens.tolist()])
    
class StructureTokenizer:
    def __init__(self):
        self.VOCABULARY = {
            "(": 1,
            ")": 2,
            "[": 3,
            "]": 4,
            "{": 5,
            "}": 6,
            ".": 7,
        }
        self.INVERTED_VOCABULARY = {self.VOCABULARY[i]: i for i in self.VOCABULARY}


    def tokenize(self, structure: str) -> torch.tensor:
        return torch.tensor([self.VOCABULARY[i] for i in structure], dtype=torch.long)
    
    def detokenize(self, tokens: torch.tensor) -> str:
        return "".join([self.INVERTED_VOCABULARY[i] for i in tokens.tolist()])

# Prepare dataset
instead of predicting sequence it predicts pairwise interaction matrix

In [3]:
class RNADataset(Dataset):
    def __init__(self, path: str, indices: List[int]):
        """path: path to .csv file with sequences and structures"""
        super().__init__()
        
        self.data = pd.read_csv(path)
        self.data = self.data.iloc[indices].reset_index(drop=True)

        self.sequence_tokenizer = SequenceTokenizer()
        self.structure_tokenizer = StructureTokenizer()

    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        row = self.data.iloc[index]
        sequence = row["sequence"]
        tokenized_sequence = self.sequence_tokenizer.tokenize(sequence)
        structure = row["structure"]
        tokenized_structure = self.structure_tokenizer.tokenize(structure)
        
        return sequence, tokenized_sequence, structure, tokenized_structure

In [4]:
def collate_fn(batch):
    sequences, tokenized_seqs, structures, tokenized_structures = zip(*batch)
    max_len = max(len(seq) for seq in tokenized_seqs)

    padded_tokenized_sequences = []
    padded_tokenized_structures = []

    for tokenized_sequence, tokenized_structure in zip(tokenized_seqs, tokenized_structures):
        sequence_lenght = len(tokenized_sequence)
        padded_seq = torch.cat([
            tokenized_sequence,
            torch.zeros(max_len - sequence_lenght, dtype=torch.long)
        ])
        padded_structure = torch.cat([
            tokenized_structure,
            torch.zeros(max_len - sequence_lenght, dtype=torch.long)
        ])
        padded_tokenized_sequences.append(padded_seq)
        padded_tokenized_structures.append(padded_structure)

    padded_tokenized_sequences = torch.stack(padded_tokenized_sequences, dim=0)
    padded_tokenized_structures = torch.stack(padded_tokenized_structures, dim=0)

    return sequences, padded_tokenized_sequences, structures, padded_tokenized_structures


# Create model 

In [5]:
class transformerRNA(nn.Module):
    def __init__(
        self, 
        hidden_dim: int=1000, 
        num_transformer_layers: int=10, 
        n_head: int=8, 
        dropout: float=0.1,
    ):
        super().__init__()

        # output (L, H)
        self.embedding = nn.Embedding(num_embeddings=5, embedding_dim=hidden_dim, padding_idx=0)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            batch_first=True
        )

        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_transformer_layers
        )

        self.output_head = nn.Linear(hidden_dim, 8)

    def forward(self, x, src_key_padding_mask=None):
        padding_mask = (x == 0)
        x = self.embedding(x)
        # (L, H)
        x = self.encoder(x, src_key_padding_mask=padding_mask)
        # (L, H)
        x = self.output_head(x)
        return x

# Train model

In [None]:
class RNATrainer:
    def __init__(self, model: transformerRNA, device: torch.device, balance_loss_weight: float=0.1):
        self.model = model
        self.device = device

        class_weights = torch.tensor([
            0.0,   # padding (ignored anyway)
            2.6,   # ( 
            2.6,   # )
            98.8,   # [
            98.8,   # ]
            8064.5,   # {
            8064.5,   # }
            1.07,   # .
        ], dtype=torch.float).to(device)

        self.balance_loss_weight = balance_loss_weight

        self.criterion = nn.CrossEntropyLoss(ignore_index=0, weight=class_weights)
        self.optimizer = torch.optim.Adam(self.model.parameters())

        self.structure_tokenizer = StructureTokenizer()


    def compute_bracket_balance(self, predicted_structure) -> torch.Tensor:
        batch_size = predicted_structure.shape[0]
        balance_loss = 0.0

        for i in range(int(batch_size)):
            open_round = (predicted_structure[i] == 1).sum().float()   # (
            close_round = (predicted_structure[i] == 2).sum().float()  # )
            open_square = (predicted_structure[i] == 3).sum().float()  # [
            close_square = (predicted_structure[i] == 4).sum().float() # ]
            open_curly = (predicted_structure[i] == 5).sum().float()   # {
            close_curly = (predicted_structure[i] == 6).sum().float()  # }

            balance_loss += torch.abs(open_round - close_round)
            balance_loss += torch.abs(open_square - close_square)
            balance_loss += torch.abs(open_curly - close_curly)

        return balance_loss / batch_size


    def train_epoch(self, train_dataloader: torch.utils.data.DataLoader) -> Tuple[float, float, float]:
        self.model.train()
        total_loss = 0
        total_ce_loss = 0
        total_bracket_balance_loss = 0

        for sequence, tokenized_sequence, structure, tokenized_structure in tqdm(train_dataloader, desc="Training"):
            tokenized_sequence = tokenized_sequence.to(self.device)
            tokenized_structure = tokenized_structure.to(self.device)
            # print(tokenized_sequence)
            # padding_mask = padding_mask.to(self.device)

            out_logits = self.model(tokenized_sequence)

            # print(out_logits.shape)

            cross_entropy_loss = self.criterion(out_logits.view(-1, 8), tokenized_structure.view(-1))
            total_ce_loss += cross_entropy_loss.item()

            # bracket balance loss
            predicted_tokens = torch.argmax(out_logits, dim=-1)
            balance_loss = self.compute_bracket_balance(predicted_tokens)
            total_bracket_balance_loss += balance_loss.item()

            
            # valid_mask = (~padding_mask).unsqueeze(-1) & (~padding_mask).unsqueeze(-2)
            loss = cross_entropy_loss + self.balance_loss_weight * balance_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return float(total_loss), float(total_ce_loss), float(total_bracket_balance_loss)

    def test_model(self, test_dataloader):
        self.model.eval()

        total_correct = 0
        total_tokens = 0


        with torch.no_grad():
            for idx, (sequence, tokenized_sequence, structure, tokenized_structure) in enumerate(test_dataloader):
                tokenized_sequence = tokenized_sequence.to(self.device)
                tokenized_structure = tokenized_structure.to(self.device)

                out_logits = self.model(tokenized_sequence)
                
                predicted_tokens = torch.argmax(out_logits, dim=-1)

                # Calculate accuracy (excluding padding)
                non_padding_mask = (tokenized_structure != 0)
                correct = (predicted_tokens == tokenized_structure) & non_padding_mask
                total_correct += correct.sum().item()
                total_tokens += non_padding_mask.sum().item()
                
                # show example 
                if idx == 0:
                    print(sequence[0])
                    print(structure[0])
                    non_padded_tokens = predicted_tokens[0][tokenized_structure[0] != 0]
                    detokenized_predicted_structure = self.structure_tokenizer.detokenize(non_padded_tokens)
                    print(detokenized_predicted_structure)
                    print(len(sequence[0]), len(structure[0]), len(detokenized_predicted_structure))

        accuracy = total_correct / total_tokens if total_tokens > 0 else 0.0
        return accuracy, total_correct, total_tokens


    def train(
        self, 
        train_dataloader: torch.utils.data.DataLoader, 
        test_dataloader: torch.utils.data.DataLoader,
        num_epochs: int
    ) -> None:
        best_loss = float("inf")
        for epoch in range(num_epochs):
            avg_loss, ce_loss, bracket_loss = self.train_epoch(train_dataloader)
            print(f"Epoch {epoch} current loss: {avg_loss} cross-entropy loss: {ce_loss} bracket_loss {bracket_loss}")
            
            accuracy, total_correct, total_tokens = self.test_model(test_dataloader)
            print(f"\nTest Accuracy: {(accuracy*100):.4f}% ({total_correct}/{total_tokens})")

            best_loss = min(best_loss, avg_loss)
        print("Best loss", best_loss)


In [58]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# device = torch.device("cpu")
device

device(type='cuda')

In [59]:
import random
TRAIN_SIZE = 12149
TEST_SIZE = 1000

all_indexes = list(range(TRAIN_SIZE+TEST_SIZE))
test_indices = random.choices(all_indexes, k=TEST_SIZE)
train_indices = [i for i in all_indexes if i not in test_indices]

train_dataset = RNADataset("rna_dataset.csv", train_indices)
test_dataset = RNADataset("rna_dataset.csv", test_indices)


train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=64,          # try 16–64; tune to your VRAM
    collate_fn=collate_fn
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=100,
    collate_fn=collate_fn
)

# Train 2 example models

In [None]:
model = transformerRNA(hidden_dim=256, num_transformer_layers=4, n_head=4)
model = model.to(device)
trainer = RNATrainer(model, device, balance_loss_weight=0.02)
trainer.train(train_dataloader, test_dataloader, 1)

Training: 100%|██████████| 191/191 [00:06<00:00, 30.75it/s]

Epoch 0 current loss: 1669.0869679450989 cross-entropy loss: 369.7909541130066 bracket_loss 64964.80209350586





In [None]:
model = transformerRNA(hidden_dim=384, num_transformer_layers=6, n_head=1)
model = model.to(device)
trainer = RNATrainer(model, device, balance_loss_weight=0.02)
trainer.train(train_dataloader, test_dataloader, 30)