In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pathlib import Path
import gc
import logging
from tqdm import tqdm

# local imports
from src.data import RNADataset
from src.model import RNAStructurePredictor
from src.train import train_epoch, validate_epoch

import warnings
warnings.filterwarnings('ignore')

In [18]:
class SimpleRNAPredictor(nn.Module):
    def __init__(self, input_size=5, hidden_size=64):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=1,
            bidirectional=True,
            batch_first=True
        )
        self.fc = nn.Linear(hidden_size * 2, 3)  # 3 classes: '.', '(', ')'

    def forward(self, x, lengths):
        # Pack sequence
        packed = nn.utils.rnn.pack_padded_sequence(
            x, lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        # LSTM
        output, _ = self.lstm(packed)
        # Unpack
        output, _ = nn.utils.rnn.pad_packed_sequence(output, batch_first=True)
        # Predict
        return self.fc(output)

In [19]:
def collate_fn(batch, device):
    # Sort batch by length for packed sequence
    batch = sorted(batch, key=lambda x: x['length'], reverse=True)
    
    # Get max length in batch
    max_len = max(item['length'] for item in batch)
    
    # Pad sequences
    sequences = torch.stack([
        torch.nn.functional.pad(
            item['sequence'],
            (0, 0, 0, max_len - item['length'])
        ) for item in batch
    ])
    
    structures = torch.stack([
        torch.nn.functional.pad(
            item['structure'],
            (0, max_len - item['length']),
            value=-100  # ignore_index for CrossEntropyLoss
        ) for item in batch
    ])
    
    lengths = torch.tensor([item['length'] for item in batch])
    
    return {
        'sequences': sequences.to(device),
        'structures': structures.to(device),
        'lengths': lengths.to(device)
    }

In [20]:
# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
batch_size = 16
num_epochs = 50
learning_rate = 0.001
val_split = 0.1
gradient_accumulation_steps = 4 
model_save_dir = Path('models')

In [25]:
# Create model save directory
model_save_dir.mkdir(parents=True, exist_ok=True)

# Load dataset
train_dataset = RNADataset.load("data/train_dataset.pkl")
val_dataset = RNADataset.load("data/val_dataset.pkl")
test_dataset = RNADataset.load("data/test_dataset.pkl")

In [26]:
# Create data loaders
test_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=lambda b: collate_fn(b, device)
)

val_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    collate_fn=lambda b: collate_fn(b, device)
)

In [23]:
model = SimpleRNAPredictor().to(device)
    
    # Use mixed precision criterion
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)

In [24]:
best_loss = float('inf')
for epoch in range(num_epochs):
    # Clear cache before each epoch
    gc.collect()
    torch.cuda.empty_cache()
    
    # Train
    train_loss = train_epoch(
        model, 
        val_loader,
        # train_loader, 
        criterion, 
        optimizer
    )
    print(f'Epoch {epoch+1} - Training Loss: {train_loss:.4f}')
    
    # Update learning rate
    scheduler.step(train_loss)
    
    # Save checkpoint
    if train_loss < best_loss:
        best_loss = train_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, 'best_model.pt')
    
    # Optional: clear cache after each epoch
    gc.collect()
    torch.cuda.empty_cache()

Training: 100%|██████████| 320/320 [00:20<00:00, 15.37it/s, loss=0.9985]


Epoch 1 - Training Loss: 0.9723


Training: 100%|██████████| 320/320 [00:20<00:00, 15.30it/s, loss=1.0210]


Epoch 2 - Training Loss: 0.8915


Training: 100%|██████████| 320/320 [00:21<00:00, 15.16it/s, loss=0.9719]


Epoch 3 - Training Loss: 0.8084


Training: 100%|██████████| 320/320 [00:22<00:00, 14.32it/s, loss=0.9570]


Epoch 4 - Training Loss: 0.7412


Training: 100%|██████████| 320/320 [00:22<00:00, 14.44it/s, loss=0.9550]


Epoch 5 - Training Loss: 0.6876


Training: 100%|██████████| 320/320 [00:21<00:00, 14.55it/s, loss=0.9267]


Epoch 6 - Training Loss: 0.6365


Training: 100%|██████████| 320/320 [00:21<00:00, 14.59it/s, loss=0.9058]


Epoch 7 - Training Loss: 0.5959


Training: 100%|██████████| 320/320 [00:21<00:00, 14.67it/s, loss=0.9266]


Epoch 8 - Training Loss: 0.5636


Training: 100%|██████████| 320/320 [00:22<00:00, 14.51it/s, loss=0.8863]


Epoch 9 - Training Loss: 0.5343


Training: 100%|██████████| 320/320 [00:22<00:00, 14.47it/s, loss=0.8269]


Epoch 10 - Training Loss: 0.5072


Training: 100%|██████████| 320/320 [00:22<00:00, 14.44it/s, loss=0.8090]


Epoch 11 - Training Loss: 0.4878


Training: 100%|██████████| 320/320 [00:21<00:00, 14.58it/s, loss=0.8181]


Epoch 12 - Training Loss: 0.4802


Training: 100%|██████████| 320/320 [00:22<00:00, 14.36it/s, loss=0.7943]


Epoch 13 - Training Loss: 0.4516


Training: 100%|██████████| 320/320 [00:22<00:00, 14.31it/s, loss=0.7865]


Epoch 14 - Training Loss: 0.4359


Training: 100%|██████████| 320/320 [00:22<00:00, 14.52it/s, loss=0.7837]


Epoch 15 - Training Loss: 0.4218


Training: 100%|██████████| 320/320 [00:22<00:00, 14.51it/s, loss=0.7627]


Epoch 16 - Training Loss: 0.4105


Training: 100%|██████████| 320/320 [00:22<00:00, 14.53it/s, loss=0.7682]


Epoch 17 - Training Loss: 0.4156


Training: 100%|██████████| 320/320 [00:21<00:00, 14.63it/s, loss=0.7632]


Epoch 18 - Training Loss: 0.3930


Training: 100%|██████████| 320/320 [00:22<00:00, 14.44it/s, loss=0.7596]


Epoch 19 - Training Loss: 0.3964


Training: 100%|██████████| 320/320 [00:22<00:00, 14.47it/s, loss=0.7554]


Epoch 20 - Training Loss: 0.3834


Training: 100%|██████████| 320/320 [00:21<00:00, 14.60it/s, loss=0.7639]


Epoch 21 - Training Loss: 0.3759


Training: 100%|██████████| 320/320 [00:21<00:00, 14.55it/s, loss=0.7604]


Epoch 22 - Training Loss: 0.3713


Training: 100%|██████████| 320/320 [00:22<00:00, 14.54it/s, loss=0.7495]


Epoch 23 - Training Loss: 0.3690


Training: 100%|██████████| 320/320 [00:22<00:00, 14.40it/s, loss=0.7542]


Epoch 24 - Training Loss: 0.3631


Training: 100%|██████████| 320/320 [00:22<00:00, 14.39it/s, loss=0.7545]


Epoch 25 - Training Loss: 0.3599


Training: 100%|██████████| 320/320 [00:22<00:00, 14.46it/s, loss=0.7330]


Epoch 26 - Training Loss: 0.3586


Training: 100%|██████████| 320/320 [00:22<00:00, 14.52it/s, loss=0.7421]


Epoch 27 - Training Loss: 0.3449


Training: 100%|██████████| 320/320 [00:21<00:00, 14.62it/s, loss=0.7339]


Epoch 28 - Training Loss: 0.3502


Training: 100%|██████████| 320/320 [00:22<00:00, 14.40it/s, loss=0.7461]


Epoch 29 - Training Loss: 0.3403


Training: 100%|██████████| 320/320 [00:22<00:00, 14.51it/s, loss=0.7220]


Epoch 30 - Training Loss: 0.3369


Training: 100%|██████████| 320/320 [00:22<00:00, 14.46it/s, loss=0.7255]


Epoch 31 - Training Loss: 0.3354


Training: 100%|██████████| 320/320 [00:21<00:00, 14.60it/s, loss=0.7261]


Epoch 32 - Training Loss: 0.3290


Training: 100%|██████████| 320/320 [00:21<00:00, 14.63it/s, loss=0.7400]


Epoch 33 - Training Loss: 0.3318


Training: 100%|██████████| 320/320 [00:22<00:00, 14.51it/s, loss=0.7226]


Epoch 34 - Training Loss: 0.3253


Training: 100%|██████████| 320/320 [00:21<00:00, 14.55it/s, loss=0.7210]


Epoch 35 - Training Loss: 0.3236


Training: 100%|██████████| 320/320 [00:22<00:00, 14.50it/s, loss=0.7095]


Epoch 36 - Training Loss: 0.3218


Training: 100%|██████████| 320/320 [00:22<00:00, 14.48it/s, loss=0.7083]


Epoch 37 - Training Loss: 0.3175


Training: 100%|██████████| 320/320 [00:22<00:00, 14.49it/s, loss=0.7136]


Epoch 38 - Training Loss: 0.3241


Training: 100%|██████████| 320/320 [00:22<00:00, 14.52it/s, loss=0.7072]


Epoch 39 - Training Loss: 0.3262


Training: 100%|██████████| 320/320 [00:21<00:00, 14.56it/s, loss=0.7096]


Epoch 40 - Training Loss: 0.3163


Training: 100%|██████████| 320/320 [00:22<00:00, 14.54it/s, loss=0.7077]


Epoch 41 - Training Loss: 0.3153


Training: 100%|██████████| 320/320 [00:22<00:00, 14.53it/s, loss=0.7042]


Epoch 42 - Training Loss: 0.3184


Training: 100%|██████████| 320/320 [00:22<00:00, 14.46it/s, loss=0.7027]


Epoch 43 - Training Loss: 0.3145


Training: 100%|██████████| 320/320 [00:22<00:00, 14.48it/s, loss=0.7077]


Epoch 44 - Training Loss: 0.3107


Training: 100%|██████████| 320/320 [00:22<00:00, 14.53it/s, loss=0.7067]


Epoch 45 - Training Loss: 0.3068


Training: 100%|██████████| 320/320 [00:22<00:00, 14.52it/s, loss=0.7022]


Epoch 46 - Training Loss: 0.3027


Training: 100%|██████████| 320/320 [00:21<00:00, 14.57it/s, loss=0.6987]


Epoch 47 - Training Loss: 0.3027


Training: 100%|██████████| 320/320 [00:21<00:00, 14.59it/s, loss=0.6960]


Epoch 48 - Training Loss: 0.3045


Training: 100%|██████████| 320/320 [00:22<00:00, 14.50it/s, loss=0.7107]


Epoch 49 - Training Loss: 0.3018


Training: 100%|██████████| 320/320 [00:21<00:00, 14.62it/s, loss=0.7067]


Epoch 50 - Training Loss: 0.3003


In [30]:
model.eval()
    
all_preds = []
all_targets = []

print("Running inference...")
with torch.no_grad():
    for batch in tqdm(test_loader):
        try:
            sequences = batch['sequences'].to(device)
            structures = batch['structures'].to(device)
            lengths = batch['lengths'].to(device)
            
            # Forward pass
            logits = model(sequences, lengths)
            
            # Get predictions
            preds = torch.argmax(logits, dim=-1)  # (batch_size, seq_len)
            
            # Convert to flat arrays for evaluation, ignoring padding
            for pred, target, length in zip(preds, structures, lengths):
                # Only consider predictions up to actual sequence length
                pred = pred[:length].cpu().numpy()
                target = target[:length].cpu().numpy()
                
                # Ignore padded positions (target == -100)
                mask = target != -100
                all_preds.append(pred[mask])
                all_targets.append(target[mask])
        except:
            break


# Calculate per-class metrics
per_class_metrics = {}
labels = ['.', '(', ')']
class_precision, class_recall, class_f1, support = precision_recall_fscore_support(
    all_targets, 
    all_preds
)

Running inference...


 11%|█         | 34/320 [00:01<00:12, 22.66it/s]


KeyError: 'A'

In [34]:
all_preds[0]

array([0, 0, 0, ..., 2, 2, 0])

In [35]:
all_targets[0]

array([0, 0, 0, ..., 2, 2, 0])