In [None]:
import torch
from torch import nn
from torch.utils.data import Dataset
import pandas as pd

# Prepare dataset
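Instead of predicting a sequence, the model predicts a pairwise interaction matrix. Entry $(i, j)$ is 1 when position $i$ opens a bracket that position $j$ closes in the dot-bracket structure, and 0 otherwise.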

In [None]:
class RNADataset(Dataset):
    """Map-style dataset of RNA sequences paired with dot-bracket structures.

    Each item is a ``(tokens, interaction_matrix)`` tuple:

    * ``tokens`` -- 1-D ``torch.long`` tensor of nucleotide ids (A=0, U=1, C=2, G=3),
      directly usable as ``nn.Embedding`` indices.
    * ``interaction_matrix`` -- ``(L, L)`` ``torch.bfloat16`` tensor with a 1 at
      ``[i, j]`` (``i < j``) when position ``i`` opens a bracket that position ``j`` closes.
    """

    def __init__(self, path: str):
        """Load the dataset from disk.

        Args:
            path: path to a .csv file with ``sequence`` and ``structure`` columns.
        """
        super().__init__()

        self.data = pd.read_csv(path)

    def __len__(self):
        return len(self.data)

    @staticmethod
    def create_interaction_matrix(structure: str) -> torch.Tensor:
        """Convert a dot-bracket string into a pairwise interaction matrix.

        ``()``, ``[]`` and ``{}`` are matched independently (one stack per bracket
        type), so pairs written with different bracket types may cross. Every other
        character, such as ``.``, is treated as an unpaired position.

        Only the upper triangle is filled: ``matrix[i, j] == 1`` with ``i < j`` for the
        opening position ``i`` and closing position ``j``; ``matrix[j, i]`` stays 0.

        Args:
            structure: dot-bracket secondary structure string.

        Returns:
            ``(len(structure), len(structure))`` bfloat16 tensor of 0/1 values.

        Raises:
            ValueError: if a closing bracket has no matching opener, or an opener is
                never closed.
        """
        # Closing bracket -> the opening bracket it must match.
        opener_for = {")": "(", "]": "[", "}": "{"}
        open_positions = {opener: [] for opener in opener_for.values()}
        size = len(structure)
        matrix = torch.zeros(size, size, dtype=torch.bfloat16)

        for position, symbol in enumerate(structure):
            if symbol in open_positions:
                open_positions[symbol].append(position)
            elif symbol in opener_for:
                stack = open_positions[opener_for[symbol]]
                if not stack:
                    raise ValueError(
                        f"Unmatched {symbol!r} at position {position} in structure {structure!r}"
                    )
                matrix[stack.pop(), position] = 1

        # Leftover openers mean the label is malformed; silently dropping them would
        # produce a target with missing pairs.
        unclosed = sorted(pos for stack in open_positions.values() for pos in stack)
        if unclosed:
            raise ValueError(
                f"Unclosed bracket(s) at position(s) {unclosed} in structure {structure!r}"
            )
        return matrix

    @staticmethod
    def tokenize(sequence: str) -> torch.Tensor:
        """Encode an RNA sequence as token ids.

        Args:
            sequence: string over the alphabet ``A``, ``U``, ``C``, ``G``.

        Returns:
            1-D ``torch.long`` tensor of length ``len(sequence)``. ``long`` is used
            because ``nn.Embedding`` only accepts integer (int32/int64) indices.

        Raises:
            ValueError: if the sequence contains a character outside the vocabulary.
        """
        vocabulary = {"A": 0, "U": 1, "C": 2, "G": 3}
        unknown = sorted(set(sequence) - set(vocabulary))
        if unknown:
            raise ValueError(
                f"Unsupported nucleotide(s) {unknown} in sequence {sequence!r}; "
                f"expected only {sorted(vocabulary)}"
            )
        return torch.tensor([vocabulary[base] for base in sequence], dtype=torch.long)

    def __getitem__(self, index):
        """Return ``(tokens, interaction_matrix)`` for the row at ``index``.

        Raises:
            ValueError: if the row's sequence and structure differ in length, or if
                either one is malformed.
        """
        row = self.data.iloc[index]
        sequence = row["sequence"]
        structure = row["structure"]
        # Token count and matrix size must agree, or the model's (L, L) output and the
        # target would be misaligned.
        if len(sequence) != len(structure):
            raise ValueError(
                f"Row {index}: sequence length {len(sequence)} != structure length {len(structure)}"
            )
        tokenized_sequence = self.tokenize(sequence)
        matrix = self.create_interaction_matrix(structure)

        return tokenized_sequence, matrix

# Create model 

In [34]:
class transformerRNA(nn.Module):
    """Transformer encoder that scores every pair of positions in an RNA sequence.

    The pipeline has four steps:

    1. Tokens are embedded.
    2. Fixed sinusoidal positional encodings are added.
    3. The result passes through a stack of ``nn.TransformerEncoderLayer`` blocks.
    4. A bilinear form, ``scores[i, j] = (W h_i + b) . h_j``, turns it into an
       ``(L, L)`` matrix.

    The scores are unnormalised logits and are in general not symmetric in ``i``
    and ``j``.

    The lower-case class name is kept for backward compatibility with existing callers.
    """

    def __init__(
        self,
        hidden_dim: int = 1000,
        num_transformer_layers: int = 10,
        vocab_size: int = 4,
        n_head: int = 8,
        dropout: float = 0.1,
        positional_encoding: bool = True,
    ):
        """
        Args:
            hidden_dim: model width H; must be divisible by ``n_head``.
            num_transformer_layers: number of stacked encoder layers.
            vocab_size: number of distinct token ids fed to the embedding.
            n_head: attention heads per encoder layer.
            dropout: dropout probability inside each encoder layer.
            positional_encoding: add fixed sinusoidal position encodings to the token
                embeddings. Without them self-attention is permutation-equivariant, so
                identical tokens at different positions get identical representations
                and the model cannot use sequence order or distance. The encodings have
                no parameters, so ``state_dict`` keys are unaffected. Set False only to
                reproduce the original position-blind model.
        """
        super().__init__()
        self.positional_encoding = positional_encoding

        # (..., L) token ids -> (..., L, H)
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=hidden_dim)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=hidden_dim,
            nhead=n_head,
            dropout=dropout,
            batch_first=True,
        )

        # (B, L, H) -> (B, L, H)  (batch_first=True)
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_transformer_layers,
        )

        # W and b of the bilinear pair scorer: (H) -> (H)
        self.pairwise = nn.Linear(hidden_dim, hidden_dim)

    @staticmethod
    def _sinusoidal_encoding(length: int, dim: int, device, dtype) -> torch.Tensor:
        """Return the ``(length, dim)`` sine/cosine position table of Vaswani et al. (2017)."""
        positions = torch.arange(length, device=device, dtype=torch.float32).unsqueeze(1)
        frequencies = 1.0 / (
            10000.0 ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim)
        )
        angles = positions * frequencies  # (length, ceil(dim / 2))
        table = torch.zeros(length, dim, device=device, dtype=torch.float32)
        table[:, 0::2] = torch.sin(angles)
        # For odd dim there is one fewer cosine column than sine column.
        table[:, 1::2] = torch.cos(angles[:, : dim // 2])
        return table.to(dtype)

    def forward(self, x: torch.Tensor, src_key_padding_mask: "torch.Tensor | None" = None) -> torch.Tensor:
        """Score all position pairs.

        Args:
            x: ``(B, L)`` integer token ids. An unbatched ``(L,)`` tensor also works.
            src_key_padding_mask: optional ``(B, L)`` bool tensor, passed through to
                ``nn.TransformerEncoder``; ``True`` marks padding positions that
                attention must ignore. Needed when batching sequences of
                different lengths.

        Returns:
            ``(B, L, L)`` score logits, or ``(L, L)`` for unbatched input.
        """
        h = self.embedding(x)  # (B, L, H)

        if self.positional_encoding:
            # (L, H) table broadcasts over the batch dimension.
            h = h + self._sinusoidal_encoding(h.size(-2), h.size(-1), h.device, h.dtype)

        h = self.encoder(h, src_key_padding_mask=src_key_padding_mask)  # (B, L, H)

        projected = self.pairwise(h)  # (B, L, H)

        # (B, L, H) @ (B, H, L) -> (B, L, L)
        scores = torch.matmul(projected, h.transpose(-2, -1))

        return scores

In [38]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
device

device(type='cuda')