In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pathlib import Path
import gc
import logging
from tqdm import tqdm
import math 
# local imports
from src.data import RNADataset
from src.model import EnhancedRNAPredictor
from src.train import train_epoch, validate_epoch, train_model, create_padding_mask
from src.utils import inference, visualize_predictions
from src.objective import create_criterion
from src.checkpoint import save_checkpoint, load_checkpoint
import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
class RnaEncoder(nn.Module):
    def __init__(self, input_size=5, hidden_size=64):
        super().__init__()
        self.in_layer = nn.Linear(input_size, hidden_size//2)
        self.out_layer = nn.Linear(hidden_size//2, hidden_size)
        self.activation = nn.ReLU()
        
    def forward(self, x):
        out = self.activation(self.in_layer(x))
        out = self.out_layer(out)
        return out
    
class RnaDecoder(nn.Module):
    def __init__(self, hidden_size=64, output_size=61):
        super().__init__()
        self.in_layer = nn.Linear(hidden_size, hidden_size//2)
        self.out_layer = nn.Linear(hidden_size//2, output_size)
        self.activation = nn.ReLU()
        
    def forward(self, x):
        out = self.activation(self.in_layer(x))
        out =  nn.Softmax()(self.out_layer(out))
        return out
    
class SimpleRNAPredictor(nn.Module):
    def __init__(self, input_size=5, hidden_size=64, out_size=61, num_hidden=6):
        super().__init__()
        self.encoder = RnaEncoder()
        self.decoder = RnaDecoder()

    def forward(self, x):
        latent = self.encoder(x)
        out    = self.decoder(latent)
        return out

In [3]:
class EfficientRnaEncoder(nn.Module):
    def __init__(self, input_size=5, hidden_size=64, num_layers=2, dropout=0.1):
        super().__init__()
        
        self.input_proj = nn.Linear(input_size, hidden_size)
        self.layer_norm = nn.LayerNorm(hidden_size)
        
        self.layers = nn.ModuleList([
            nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.LayerNorm(hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout)
            ) for _ in range(num_layers)
        ])
        
    def forward(self, x):
        # Ensure input is float and requires grad
        x = x.float()
        x.requires_grad_(True)
        
        x = self.input_proj(x)
        x = self.layer_norm(x)
        
        for layer in self.layers:
            x = layer(x) + x  # Add residual connection
        return x

class EfficientRnaDecoder(nn.Module):
    def __init__(self, hidden_size=64, output_size=61, num_layers=2, dropout=0.1):
        super().__init__()
        
        self.layers = nn.ModuleList()
        current_size = hidden_size
        
        for i in range(num_layers):
            next_size = output_size if i == num_layers-1 else current_size
            self.layers.append(nn.Sequential(
                nn.Linear(current_size, next_size),
                nn.LayerNorm(next_size) if i < num_layers-1 else nn.Identity(),
                nn.ReLU() if i < num_layers-1 else nn.Identity(),
                nn.Dropout(dropout) if i < num_layers-1 else nn.Identity()
            ))
            current_size = next_size
            
    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x  # Note: No softmax here, will be handled by loss function

class RNAPredictor(nn.Module):
    def __init__(self, input_size=5, hidden_size=64, output_size=61, num_layers=2, dropout=0.1):
        super().__init__()
        
        self.encoder = EfficientRnaEncoder(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout
        )
        
        self.decoder = EfficientRnaDecoder(
            hidden_size=hidden_size,
            output_size=output_size,
            num_layers=num_layers,
            dropout=dropout
        )
        
        # Initialize weights properly
        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
                
    def forward(self, x):
        # Ensure model is in training mode when needed
        if self.training:
            self.train()
        
        latent = self.encoder(x)
        logits = self.decoder(latent)
        return logits

In [4]:
# Configuration

batch_size = 28
num_epochs = 3
learning_rate = 0.003
model_save_dir = Path('models')

In [5]:
# Create model save directory
model_save_dir.mkdir(parents=True, exist_ok=True)

# Load dataset
train_dataset = RNADataset.load("data/train_dataset.pkl")
val_dataset   = RNADataset.load("data/val_dataset.pkl")
test_dataset  = RNADataset.load("data/test_dataset.pkl")

struct_to_idx = train_dataset.struct_to_idx

In [6]:
# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    # collate_fn=lambda b: collate_fn(b, device)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    # collate_fn=lambda b: collate_fn(b, device)
)

In [7]:
# model = RNATransformer(
#     num_nucleotides=5,  # A, U, G, C, N
#     num_structure_labels=60,  
#     d_model=128
# ).to(device)
# predictor = RNAStructurePredictor(model)


# model = SimpleRNAPredictor(input_size=5).to(device)

In [8]:
model = EnhancedRNAPredictor(
    input_size=5,
    hidden_size=32,
    output_size=61,
    num_layers=3,
    num_heads=8,
    dropout=0.1
    
    ).to(device)


# Use mixed precision criterion
criterion = nn.CrossEntropyLoss(ignore_index=-100)
# criterion = create_criterion(struct_to_idx=struct_to_idx, ignore_index=-100, bracket_weight=0.5)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)

In [None]:
best_loss = float('inf')
validate_every = 10
for epoch in range(num_epochs):
    # Clear cache before each epoch
    gc.collect()
    torch.cuda.empty_cache()
    
    # Train
    train_loss = train_epoch(
        model,
        train_loader, 
        criterion, 
        optimizer,
        device=device
    )
    print(f'Epoch {epoch+1} - Training Loss: {train_loss:.4f}')
    
    if (epoch + 1) % validate_every == 0:
        val_loss, metrics = validate_epoch(model, val_loader, criterion, device=device)
        # print(20*"=" + " : Validation : " + 20*"=")
        print(f'Validation Loss: {val_loss:.4f}')
        # visualize_predictions(model, val_loader, device, total_examples=3)
        # print(20*"=" + " : Validation : " + 20*"=")
    
    # Update learning rate
    scheduler.step(train_loss)
    save_checkpoint(model, optimizer, epoch)
    
    # save model 
    if best_loss > train_loss:
        best_loss = train_loss
        save_checkpoint(model, optimizer, epoch, 'checkpoints/best.pt')
    

Training: 100%|██████████| 3289/3289 [9:03:54<00:00,  9.92s/it, loss=0.9554, avg_loss=0.9888]  


Epoch 1 - Training Loss: 0.9888
Saved checkpoint to checkpoints/checkpoint.pt
Saved checkpoint to checkpoints/best.pt


Training:  17%|█▋        | 570/3289 [1:34:43<7:29:43,  9.92s/it, loss=0.9437, avg_loss=0.9494] 

In [None]:

inference(model, val_loader, device)

Running inference...


100%|██████████| 183/183 [04:46<00:00,  1.56s/it]



Overall Metrics:
Accuracy: 0.4995
Precision: 0.5147
Recall: 0.4995
F1 Score: 0.4933

Per-Class Metrics:

Class .:
Precision: 0.6110
Recall: 0.5519
F1 Score: 0.5799
Support: 712513

Class (:
Precision: 0.4959
Recall: 0.3068
F1 Score: 0.3790
Support: 452842

Class ):
Precision: 0.0000
Recall: 0.0000
F1 Score: 0.0000
Support: 19


{'accuracy': 0.4994855787621452,
 'precision': 0.5146972939378416,
 'recall': 0.4994855787621452,
 'f1': 0.49329954855007085,
 'per_class': {'.': {'precision': 0.6109866411442813,
   'recall': 0.5519085265812694,
   'f1': 0.5799469225318203,
   'support': 712513},
  '(': {'precision': 0.4958522516348268,
   'recall': 0.3067604153325001,
   'f1': 0.37903175196522754,
   'support': 452842},
  ')': {'precision': 0.0, 'recall': 0.0, 'f1': 0.0, 'support': 19}}}

In [None]:
visualize_predictions(model, val_loader, device, num_examples=1)


Visualizing 1 example predictions:

Example   1:
Sequence:  GGGAUUGUAGUUCAAUUGGUCAGAGCACCGCCCUGUCAAGGCGGAAGCUGCGGGUUCGAGCCCCGUCAGUCCCG
Predicted: (((((((..((((.........))))((((((.......))))))...((((((......))))))))))))).
Ground Tr: (((((((..((((.........))))((((((.......))))))....(((((.......)))))))))))).

Example   2:
Sequence:  UUAAAACAGCUCUGGGGUUGCACCCACCCCAGAGGCCCACGUGGCGGCUAGUACUCCGGUAUUGCGGUACCCUUGUACGCCUGUUUUAUAC
Predicted: ....(((.((.(.((((((((...(.((((...........................))....)))....)).))).)))))).......)
Ground Tr: .(((((((.(((((((((.......)))))))))(((.....))).((..((((...((((......))))...)))))).)))))))...

Example   3:
Sequence:  GGGGCUGAUUCUGGAUUCGACGGGAUUUGCGAAACCCAAGGUGCAUGCCGAGGGGCGGUUGGCCUCGUAAAAAGCCGCAAAAAAAUAGUCGCAAACGACGAAACCUACGCUUUAGCAGCUUAAUAACCUGCUUAGAGCCCUCUCUCCCUAGCCUCCGCUCUUAGGACGGGGAUCAAGAGAGGUCAAACCCAAAAGAGAUCGCGCGGAUGCCCUGCCUGGGGUUGAAGCGUUAAAACGAAUCAGGCUAGUCUGGUAGUGGCGUGUCCGUCCGCAGGUGCCAGGCGAAUGUAAAGACUGACUAAGCAUGUAGUACCGAGGAUGUAGGAAUUUCGGACGCGGGUUCA