In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset
import pandas as pd
from typing import List, Tuple
from tqdm import tqdm

# Prepare dataset
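Instead of predicting the dot-bracket structure string directly, the model predicts a pairwise interaction matrix: entry (i, j) is 1 when positions i and j form a base pair.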

In [2]:
class RNADataset(Dataset):
    """RNA sequences paired with base-pair interaction matrices, read from a CSV.

    Each item is ``(sequence, tokenized_sequence, structure, interaction_matrix)``:

    * ``sequence``: the raw nucleotide string from the ``sequence`` column;
    * ``tokenized_sequence``: ``(L,)`` long tensor of ids from ``VOCABULARY``;
    * ``structure``: the dot-bracket string from the ``structure`` column;
    * ``interaction_matrix``: ``(L, L)`` float32 tensor with a 1 at ``[i, j]`` for
      every bracket opened at position ``i`` and closed at position ``j``
      (so only entries with ``i < j`` are ever set).
    """

    # Closing bracket -> matching opening bracket. Each bracket family keeps its own
    # stack, so crossing pairs written with different brackets (e.g. "(.[)]") are
    # matched correctly.
    BRACKET_PAIRS = {")": "(", "]": "[", "}": "{"}

    def __init__(self, path: str, indices: List[int]):
        """
        Args:
            path: path to a .csv file with ``sequence`` and ``structure`` columns.
            indices: row positions (``iloc``) of the CSV to keep in this dataset.
        """
        super().__init__()

        self.data = pd.read_csv(path).iloc[indices].reset_index(drop=True)

        # Id 0 is reserved for padding (``pad_idx``); nucleotides map to 1..4.
        self.VOCABULARY = {
            "A": 1,
            "U": 2,
            "C": 3,
            "G": 4,
        }
        self.pad_idx = 0

    def __len__(self) -> int:
        return len(self.data)

    def create_interaction_matrix(self, structure: str) -> torch.Tensor:
        """Convert a dot-bracket string into an ``(L, L)`` float32 pairing matrix.

        Characters that are not brackets (e.g. ".") are treated as unpaired and
        skipped. Opening brackets that are never closed are ignored.

        Raises:
            ValueError: if a closing bracket has no matching opening bracket.
        """
        open_positions = {opener: [] for opener in self.BRACKET_PAIRS.values()}
        # float32 so the target dtype matches the model's float32 logits in
        # BCEWithLogitsLoss; mixing a bfloat16 target with float32 logits reduces
        # the precision of the loss computation. 0/1 are exact in float32.
        matrix = torch.zeros(len(structure), len(structure), dtype=torch.float32)

        for position, char in enumerate(structure):
            if char in open_positions:
                open_positions[char].append(position)
            elif char in self.BRACKET_PAIRS:
                stack = open_positions[self.BRACKET_PAIRS[char]]
                if not stack:
                    raise ValueError(
                        f"Unmatched {char!r} at position {position} in structure {structure!r}"
                    )
                matrix[stack.pop(), position] = 1
        return matrix

    def tokenize(self, sequence: str) -> torch.Tensor:
        """Map each nucleotide to its vocabulary id; raises KeyError on unknown symbols."""
        return torch.tensor([self.VOCABULARY[base] for base in sequence], dtype=torch.long)

    def __getitem__(self, index):
        row = self.data.iloc[index]
        sequence = row["sequence"]
        structure = row["structure"]
        return (
            sequence,
            self.tokenize(sequence),
            structure,
            self.create_interaction_matrix(structure),
        )

# Create model 

In [8]:
class transformerRNA(nn.Module):
    """Transformer encoder that scores every pair of positions in an RNA sequence.

    Token ids are embedded, given sinusoidal position information, contextualised by
    a stack of ``nn.TransformerEncoderLayer``s, and scored pairwise with a bilinear
    form ``scores[..., i, j] = (W h_i + b) . h_j``. Scores are raw logits (no sigmoid),
    suitable for ``nn.BCEWithLogitsLoss``.
    """

    def __init__(
        self,
        hidden_dim: int = 1000,
        num_transformer_layers: int = 10,
        vocab_size: int = 5,
        n_head: int = 8,
        dropout: float = 0.1,
        use_positional_encoding: bool = True,
    ):
        """
        Args:
            hidden_dim: model width H (must be divisible by ``n_head``).
            num_transformer_layers: number of stacked encoder layers.
            vocab_size: number of token ids (id 0 is used for padding by the dataset).
            n_head: attention heads per layer.
            dropout: dropout inside each encoder layer.
            use_positional_encoding: add sinusoidal position information to the
                embeddings. It adds no parameters, so ``state_dict`` is unaffected.
        """
        super().__init__()
        self.use_positional_encoding = use_positional_encoding

        # (B, L) token ids -> (B, L, H)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dropout=dropout,
            batch_first=True,
        )
        # (B, L, H) -> (B, L, H)
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_transformer_layers,
        )

        # Bilinear pairing weight W (H x H) plus bias.
        self.pairwise = nn.Linear(hidden_dim, hidden_dim)

    @staticmethod
    def _sinusoidal_positions(
        length: int, dim: int, device: torch.device, dtype: torch.dtype
    ) -> torch.Tensor:
        """Return the ``(length, dim)`` sinusoidal position table of Vaswani et al. (2017)."""
        positions = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        inv_freq = 1.0 / (
            10000.0 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        angles = positions * inv_freq  # (length, ceil(dim / 2))
        table = torch.zeros(length, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        # An odd ``dim`` has one fewer cosine column than sine columns.
        table[:, 1::2] = torch.cos(angles)[:, : dim // 2]
        return table.to(dtype)

    def forward(self, x, src_key_padding_mask=None):
        """Score all position pairs.

        Args:
            x: ``(B, L)`` long tensor of token ids (an unbatched ``(L,)`` tensor also works).
            src_key_padding_mask: optional bool mask, True at padding positions, passed
                straight to ``nn.TransformerEncoder``.

        Returns:
            ``(B, L, L)`` (or ``(L, L)`` for unbatched input) tensor of pairing logits.
        """
        h = self.embedding(x)
        if self.use_positional_encoding:
            # Self-attention is permutation-equivariant. Without position information,
            # every occurrence of the same nucleotide gets an identical hidden state.
            # The pair scores could then only depend on nucleotide identities, not on
            # where the nucleotides sit in the sequence.
            h = h + self._sinusoidal_positions(h.size(-2), h.size(-1), h.device, h.dtype)
        h = self.encoder(h, src_key_padding_mask=src_key_padding_mask)
        # (B, L, H) @ (B, H, L) -> (B, L, L)
        return torch.matmul(self.pairwise(h), h.transpose(-2, -1))

# Train model

In [9]:
class RNATrainer:
    """Trains a pairwise-interaction model with ``BCEWithLogitsLoss`` and Adam.

    Dataloaders are expected to yield
    ``(sequence, tokenized_seq, structure, interaction_matrix)`` tuples.
    """

    def __init__(self, model: nn.Module, device: torch.device, lr: float = 1e-3):
        """
        Args:
            model: module mapping token ids to pairwise logits (e.g. ``transformerRNA``),
                already moved to ``device``.
            device: device that batches are moved to.
            lr: Adam learning rate (1e-3 is Adam's own default).
        """
        self.model = model
        self.device = device
        self.criterion = nn.BCEWithLogitsLoss()
        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=lr)

    def _batch_loss(self, tokenized_seq: torch.Tensor, interaction_matrix: torch.Tensor) -> torch.Tensor:
        """Forward one batch and return its loss."""
        logits = self.model(tokenized_seq.to(self.device))
        # Cast the target to the logits' dtype. The dataset may store it in lower
        # precision (e.g. bfloat16), which would otherwise drag the loss computation
        # down to that precision.
        target = interaction_matrix.to(self.device, dtype=logits.dtype)
        return self.criterion(logits, target)

    @staticmethod
    def _mean(total: float, loader: torch.utils.data.DataLoader) -> float:
        """Average ``total`` over the loader's batches; NaN for an empty loader."""
        num_batches = len(loader)
        return total / num_batches if num_batches else float("nan")

    def train_epoch(self, train_dataloader: torch.utils.data.DataLoader) -> float:
        """Run one optimisation pass over ``train_dataloader``; return the mean batch loss."""
        self.model.train()
        total_loss = 0.0

        for _, tokenized_seq, _, interaction_matrix in tqdm(train_dataloader, desc="Training"):
            loss = self._batch_loss(tokenized_seq, interaction_matrix)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            total_loss += loss.item()

        return self._mean(total_loss, train_dataloader)

    @torch.no_grad()
    def evaluate(self, test_dataloader: torch.utils.data.DataLoader) -> float:
        """Return the mean batch loss over ``test_dataloader`` without updating weights."""
        self.model.eval()
        total_loss = 0.0
        for _, tokenized_seq, _, interaction_matrix in tqdm(test_dataloader, desc="Evaluating"):
            total_loss += self._batch_loss(tokenized_seq, interaction_matrix).item()
        return self._mean(total_loss, test_dataloader)

    # TODO: test model with reconstructed structure from interaction matrix

    def train(
        self,
        train_dataloader: torch.utils.data.DataLoader,
        test_dataloader: torch.utils.data.DataLoader,
        num_epochs: int,
    ) -> None:
        """Train for ``num_epochs``, printing train and test loss after each epoch.

        Pass ``test_dataloader=None`` to skip evaluation.
        """
        for epoch in range(num_epochs):
            train_loss = self.train_epoch(train_dataloader)
            if test_dataloader is None:
                print(f"Epoch {epoch} train loss: {train_loss:.4f}")
            else:
                test_loss = self.evaluate(test_dataloader)
                print(f"Epoch {epoch} train loss: {train_loss:.4f}  test loss: {test_loss:.4f}")


In [13]:
# Seed the global torch RNG (CPU and all CUDA devices) so weight init, dropout
# and DataLoader shuffling are reproducible across Restart & Run All.
SEED = 42
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device

device(type='cuda')

In [14]:
DATA_PATH = "rna_dataset.csv"
TRAIN_SIZE = 11500
TEST_SIZE = 1646
# Keep at 1: the default collate_fn stacks tensors, which fails if sequences in a
# batch differ in length (presumably they do in this CSV — confirm). Larger batches
# also need a padding collate_fn plus padding-aware masking in the model and loss.
BATCH_SIZE = 1

# NOTE(review): this is a contiguous split (first TRAIN_SIZE rows train, next
# TEST_SIZE rows test). If the CSV is ordered (e.g. by family or length), the split
# is biased — confirm.
train_indices = list(range(TRAIN_SIZE))
test_indices = list(range(TRAIN_SIZE, TRAIN_SIZE + TEST_SIZE))

train_dataset = RNADataset(DATA_PATH, train_indices)
test_dataset = RNADataset(DATA_PATH, test_indices)

train_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

test_dataloader = torch.utils.data.DataLoader(
    test_dataset,
    shuffle=False,
    batch_size=BATCH_SIZE,
)

In [15]:
HIDDEN_DIM = 100
NUM_TRANSFORMER_LAYERS = 2
NUM_HEADS = 2  # must divide HIDDEN_DIM
NUM_EPOCHS = 10

model = transformerRNA(
    hidden_dim=HIDDEN_DIM,
    num_transformer_layers=NUM_TRANSFORMER_LAYERS,
    n_head=NUM_HEADS,
).to(device)
trainer = RNATrainer(model, device)
trainer.train(train_dataloader, test_dataloader, num_epochs=NUM_EPOCHS)

Training:  46%|████▌     | 5312/11500 [00:17<00:19, 310.73it/s]


KeyboardInterrupt: 