In [1]:
!pip install -q pandas numpy torch matplotlib tqdm pytorch-lightning torchmetrics

import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchmetrics


import pytorch_lightning as pl
import os
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from torch.utils.data import Dataset , DataLoader
from pytorch_lightning.callbacks import LearningRateMonitor , ModelCheckpoint, EarlyStopping

  warn(f"Failed to load image Python extension: {e}")


In [2]:
# Dataset locations. Built with os.path.join rather than backslash literals:
# "\E" and "\d" are invalid escape sequences (SyntaxWarning on Python 3.12+),
# and a backslash only acts as a path separator on Windows.
DATA_DIR = os.path.join("data", "ECCB2017")
TRAIN_DATA_PATH = os.path.join(DATA_DIR, "dataset_Rfam_6320_13classes.fasta")
VAL_DATA_PATH = os.path.join(DATA_DIR, "dataset_Rfam_validated_2600_13classes.fasta")

In [3]:
def read_fasta(filepath):
    """Parse a FASTA file into a ``{header: sequence}`` dict.

    The header is the text after ``>`` on a header line (surrounding whitespace
    stripped). The sequence is the concatenation of all stripped lines up to the
    next header. Lines before the first header are ignored. A later record with
    the same header replaces an earlier one.

    Args:
        filepath: path to the FASTA file.

    Returns:
        dict mapping header -> sequence, or ``None`` if the file is missing or
        cannot be read/decoded (an error message is printed in that case).
    """
    try:
        sequences = {}
        name = None   # None means "no header seen yet"; '' is a valid (empty) header
        chunks = []   # sequence lines of the current record, joined once at the end
        with open(filepath, 'r', encoding='utf-8') as f:
            for raw_line in f:
                line = raw_line.strip()
                if line.startswith('>'):
                    if name is not None:
                        sequences[name] = ''.join(chunks)
                    name = line[1:]
                    chunks = []
                elif name is not None:
                    chunks.append(line)
            if name is not None:
                sequences[name] = ''.join(chunks)
        return sequences
    except FileNotFoundError:
        print(f"Error: File not found at {filepath}")
        return None
    except (OSError, UnicodeDecodeError) as e:
        # Only I/O and decoding problems are reported-and-skipped; programming
        # errors (e.g. a non-path argument) propagate instead of being hidden.
        print(f"An error occurred while reading the file: {e}")
        return None

def clean_sequence(sequence):
    """Upper-case ``sequence`` and keep only the RNA bases A, U, C and G.

    Upper-casing happens before filtering, so lower-case bases are kept rather
    than silently dropped. Every other character (gaps, 'N', IUPAC ambiguity
    codes, 'T', whitespace, ...) is removed.
    """
    return ''.join(base for base in sequence.upper() if base in 'AUCG')

def extract_rna_type(name):
    """Split a FASTA header into ``(identifier, rna_type)``.

    The identifier is the first whitespace-separated token and the RNA type is
    the second; any further tokens are ignored.

    Raises:
        ValueError: if the header has fewer than two tokens.
    """
    parts = name.split()
    if len(parts) < 2:
        raise ValueError(
            f"FASTA header {name!r} has no RNA type field (expected '<id> <type>')"
        )
    return parts[0], parts[1]

def process_fasta_data(fasta_data):
    """Convert ``{header: raw_sequence}`` into ``(sequence, identifier, rna_type)`` tuples.

    Each sequence is passed through ``clean_sequence`` and each header through
    ``extract_rna_type``. Output order follows the dict's iteration order.
    """
    def to_record(header, raw_seq):
        sequence = clean_sequence(raw_seq)
        identifier, rna_type = extract_rna_type(header)
        return sequence, identifier, rna_type

    return [to_record(header, raw_seq) for header, raw_seq in fasta_data.items()]

def load_and_process_fasta(filepath):
    """Read ``filepath`` as FASTA and tabulate it.

    Returns:
        DataFrame with columns ``Sequence``, ``Name`` and ``RNA_Type`` (one row
        per record), or ``None`` when ``read_fasta`` could not read the file.
    """
    fasta_records = read_fasta(filepath)
    if fasta_records is not None:
        rows = process_fasta_data(fasta_records)
        return pd.DataFrame(rows, columns=['Sequence', 'Name', 'RNA_Type'])
    return None

In [4]:
train_df = load_and_process_fasta(TRAIN_DATA_PATH)
test_df = load_and_process_fasta(VAL_DATA_PATH)  # held-out validation split
# load_and_process_fasta returns None on read errors; stop here with a clear
# message instead of failing later on `None['Sequence']`.
assert train_df is not None, f"could not load training data from {TRAIN_DATA_PATH}"
assert test_df is not None, f"could not load validation data from {VAL_DATA_PATH}"
# Number of distinct sequences in each split.
train_df['Sequence'].nunique(), test_df['Sequence'].nunique()

(6320, 2600)

In [5]:
class RNADataset(Dataset):
    """Map-style dataset yielding ``(sequence, label_index)`` pairs from a DataFrame.

    Args:
        dataframe: DataFrame with ``Sequence`` and ``RNA_Type`` columns.
        label_to_index: optional explicit ``{rna_type: index}`` mapping. Pass the
            training dataset's mapping when building evaluation datasets so both
            splits encode classes identically. When omitted, the distinct
            ``RNA_Type`` values are sorted and numbered 0..K-1.
    """

    def __init__(self, dataframe, label_to_index=None):
        self.dataframe = dataframe
        if label_to_index is None:
            # Sorted (not first-appearance order from .unique()) so the encoding
            # does not depend on row order: two splits with the same classes get
            # the same class indices.
            label_to_index = {
                label: idx
                for idx, label in enumerate(sorted(dataframe['RNA_Type'].unique()))
            }
        self.label_to_index = dict(label_to_index)

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        # Mapping is looked up at access time, so a label_to_index assigned after
        # construction is honoured.
        return row['Sequence'], self.label_to_index[row['RNA_Type']]

In [6]:
# Token -> embedding index. Indices must be contiguous in [0, len(vocab)):
# nn.Embedding is sized with vocab_size = len(vocab) and rejects negative or
# out-of-range ids. <PAD> is 0 and is used as the embedding padding_idx.
vocab = {
    'A': 1,
    'U': 2,
    'C': 3,
    'G': 4,
    '<UNK>': 5,
    '<PAD>': 0
}

def encode_sequence(sequence):
    """Return the list of ``vocab`` indices for each character of ``sequence``.

    Characters missing from ``vocab`` map to ``vocab['<UNK>']``.
    """
    return list(map(lambda ch: vocab.get(ch, vocab['<UNK>']), sequence))

def collate_fn(batch):
    """Encode and right-pad a batch of ``(sequence, label)`` pairs.

    Returns:
        (tokens, labels): ``tokens`` is a LongTensor of shape
        (batch, longest sequence) filled with ``vocab['<PAD>']`` after each
        sequence's end; ``labels`` is a LongTensor of shape (batch,).
    """
    sequences, labels = zip(*batch)
    token_lists = [encode_sequence(s) for s in sequences]

    width = max(map(len, token_lists))
    pad_id = vocab['<PAD>']
    rows = [tokens + [pad_id] * (width - len(tokens)) for tokens in token_lists]

    return torch.tensor(rows, dtype=torch.long), torch.tensor(labels, dtype=torch.long)


In [7]:
BATCH_SIZE = 64

train_dataset = RNADataset(train_df)
trainloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=True,
    collate_fn=collate_fn,
)

# Evaluate on the held-out split, not on the training data, and reuse the
# training label mapping so class indices mean the same thing in both splits.
test_dataset = RNADataset(test_df)
test_dataset.label_to_index = train_dataset.label_to_index
assert set(test_df['RNA_Type']) <= set(train_dataset.label_to_index), \
    "validation set contains RNA types not present in training data"
testloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,  # evaluation order does not need randomising
    pin_memory=True,
    collate_fn=collate_fn,
)


In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimpleBiGRU(nn.Module):
    """Bidirectional GRU sequence classifier with length-aware mean pooling.

    Token ids are embedded, run through a (single-layer) bidirectional GRU,
    averaged over the real (non-padding) time steps, and projected to
    ``output_dim`` logits.

    Args:
        hidden_dim: GRU hidden size per direction.
        output_dim: number of output logits.
        vocab_size: number of rows in the embedding table.
        embedding_dim: embedding width.
        padding_idx: token id used for padding. Defaults to ``vocab['<PAD>']``.
    """

    def __init__(self, hidden_dim: int, output_dim: int, vocab_size: int, embedding_dim: int,
                 padding_idx: "int | None" = None):
        super(SimpleBiGRU, self).__init__()
        if padding_idx is None:
            padding_idx = vocab['<PAD>']
        self.padding_idx = padding_idx
        self.embedding = nn.Embedding(vocab_size, embedding_dim, padding_idx=padding_idx)
        self.gru = nn.GRU(input_size=embedding_dim,
                          hidden_size=hidden_dim,
                          batch_first=True,
                          bidirectional=True)
        self.fc = nn.Linear(hidden_dim * 2, output_dim)  # x2: forward + backward states

    def forward(self, x):
        """Return logits of shape (batch_size, output_dim) for token ids ``x``.

        ``x`` is a (batch_size, seq_length) LongTensor. Padding is assumed to be
        trailing (right-padded), which is what ``collate_fn`` produces. Padded
        positions do not influence the result.
        """
        mask = x != self.padding_idx                 # (batch, seq_len); True on real tokens
        # Clamp so all-padding rows stay valid for packing and for the division below.
        lengths = mask.sum(dim=1).clamp(min=1)
        embedded = self.embedding(x)                 # (batch, seq_len, embedding_dim)
        # Packing keeps the GRU from reading padding. Without it the backward
        # direction starts on pad tokens, so outputs depend on how much padding
        # the batch happened to need.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, lengths.cpu(), batch_first=True, enforce_sorted=False)
        packed_out, _ = self.gru(packed)
        gru_out, _ = nn.utils.rnn.pad_packed_sequence(
            packed_out, batch_first=True, total_length=x.size(1))  # (batch, seq_len, 2*hidden_dim)
        # Masked mean over real time steps only.
        weights = mask.unsqueeze(-1).to(gru_out.dtype)
        pooled = (gru_out * weights).sum(dim=1) / lengths.unsqueeze(-1).to(gru_out.dtype)
        return self.fc(pooled)                       # (batch, output_dim)


In [9]:
class RNAClassifier(pl.LightningModule):
    """LightningModule training a ``SimpleBiGRU`` for multiclass RNA-type classification.

    Args:
        hidden_dim, output_dim, vocab_size, embedding_dim: forwarded to
            ``SimpleBiGRU``; ``output_dim`` is also the class count used by the metrics.
        learning_rate: initial AdamW learning rate.
    """

    def __init__(self, hidden_dim, output_dim, vocab_size, embedding_dim, learning_rate):
        super().__init__()
        # Record constructor args in self.hparams and in checkpoints so the model
        # can be rebuilt with RNAClassifier.load_from_checkpoint(path).
        self.save_hyperparameters()
        self.model = SimpleBiGRU(hidden_dim, output_dim, vocab_size, embedding_dim)
        self.criterion = nn.CrossEntropyLoss()
        self.learning_rate = learning_rate

        # Sized from output_dim (not a literal) so metrics always agree with the head.
        self.train_acc = torchmetrics.Accuracy(task='multiclass', num_classes=output_dim)
        self.val_acc = torchmetrics.Accuracy(task='multiclass', num_classes=output_dim)
        self.val_f1 = torchmetrics.F1Score(task='multiclass', num_classes=output_dim)

    def forward(self, x):
        """Return logits from the wrapped ``SimpleBiGRU``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Cross-entropy on one batch; logs loss per step and accuracy per epoch."""
        sequences, labels = batch
        outputs = self(sequences)
        loss = self.criterion(outputs, labels)

        self.train_acc(outputs, labels.int())
        self.log('train_loss', loss, on_step=True, on_epoch=False, prog_bar=True)
        self.log('train_acc', self.train_acc, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Log epoch-level validation loss, accuracy and F1."""
        sequences, labels = batch
        outputs = self(sequences)
        loss = self.criterion(outputs, labels)

        self.val_acc(outputs, labels.int())
        self.val_f1(outputs, labels.int())

        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val_f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)

    def configure_optimizers(self):
        """AdamW with a StepLR schedule (LR x0.1 every 10 scheduler steps).

        Under Lightning's default scheduler interval, one scheduler step is one epoch.
        """
        params = list(self.parameters())
        # The fused AdamW kernel targets CUDA tensors and some torch versions reject
        # it for CPU params, so only request it when every parameter is on a GPU.
        use_fused = all(p.is_cuda for p in params)
        optimizer = torch.optim.AdamW(params, lr=self.learning_rate, fused=use_fused)
        scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)
        return [optimizer], [scheduler]

In [16]:
# Model instantiation
pl.seed_everything(42, workers=True)  # reproducible weight init and shuffling

hidden_dim = 128
output_dim = len(train_dataset.label_to_index)  # one logit per RNA class seen in training
vocab_size = len(vocab)
embedding_dim = 128
learning_rate = 5e-4

model = RNAClassifier(hidden_dim, output_dim, vocab_size, embedding_dim, learning_rate)

checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',          # metric to monitor
    save_top_k=1,                # keep only the best checkpoint
    mode='min',                  # lower validation loss is better
    filename='best-checkpoint',  # checkpoint file name
    verbose=False,               # no per-save log messages
)

early_stopping_callback = EarlyStopping(
    monitor='val_loss',          # metric to monitor
    patience=30,                 # epochs without improvement before stopping
    mode='min',                  # lower validation loss is better
    verbose=False,               # no per-check log messages
)

lr_scheduler_callback = LearningRateMonitor(logging_interval='step')  # logs the current LR

trainer = pl.Trainer(
    max_epochs=100,
    accelerator='auto',          # GPU when available, otherwise CPU
    devices=1,
    callbacks=[lr_scheduler_callback, checkpoint_callback, early_stopping_callback],
)

trainer.fit(model, trainloader, testloader)