In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from pathlib import Path
import gc
import logging
from tqdm import tqdm

# local imports
from src.data import RNADataset
from src.model import RNATransformer, RNAStructurePredictor
from src.train import train_epoch, validate_epoch, train_model, create_padding_mask

import warnings
warnings.filterwarnings('ignore')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [16]:
class SimpleRNAPredictor(nn.Module):
    def __init__(self, input_size=5, hidden_size=64, out_size=61, num_hidden=6):
        super().__init__()
        self.in_layer = nn.Linear(input_size, hidden_size)
        self.out_layer = nn.Linear(hidden_size, out_size)
        self.hidden_layers = []
        
        # for _ in range(num_hidden):
        #     self.hidden_layers.append(nn.Linear(hidden_size, hidden_size))
        
        self.activation = nn.ReLU()

    def forward(self, x):
        out = self.activation(self.in_layer(x))
        # for layer in self.hidden_layers:
        #     out = self.activation(layer(out))
        out = nn.Softmax()(self.out_layer(out))
        return out

In [17]:
# Configuration

batch_size = 32
num_epochs = 40
learning_rate = 0.3
model_save_dir = Path('models')

In [18]:
# Create model save directory
model_save_dir.mkdir(parents=True, exist_ok=True)

# Load dataset
train_dataset = RNADataset.load("data/train_dataset.pkl")
val_dataset = RNADataset.load("data/val_dataset.pkl")
test_dataset = RNADataset.load("data/test_dataset.pkl")

In [19]:
# Create data loaders
train_loader = DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=True,
    # collate_fn=lambda b: collate_fn(b, device)
)

val_loader = DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
    # collate_fn=lambda b: collate_fn(b, device)
)

In [20]:
# model = RNATransformer(
#     num_nucleotides=5,  # A, U, G, C, N
#     num_structure_labels=60,  
#     d_model=128
# ).to(device)
# predictor = RNAStructurePredictor(model)


# model = SimpleRNAPredictor(input_size=5).to(device)

In [21]:
model = SimpleRNAPredictor().to(device)


# Use mixed precision criterion
criterion = nn.CrossEntropyLoss(ignore_index=-100)
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5, verbose=True
)

In [22]:
best_loss = float('inf')
validate_every = 5
for epoch in range(num_epochs):
    # Clear cache before each epoch
    gc.collect()
    torch.cuda.empty_cache()
    
    # Train
    train_loss = train_epoch(
        model,
        train_loader, 
        criterion, 
        optimizer,
        device=device
    )
    print(f'Epoch {epoch+1} - Training Loss: {train_loss:.4f}')
    
    if (epoch + 1) % validate_every == 0:
        
        val_loss, metrics = validate_epoch(model, val_loader, criterion, device=device)
        # print(20*"=" + " : Validation : " + 20*"=")
        print(f'Validation Loss: {val_loss:.4f}')
        # print(20*"=" + " : Validation : " + 20*"=")
    
    # Update learning rate
    scheduler.step(train_loss)
    
    # Save checkpoint
    if train_loss < best_loss:
        best_loss = train_loss
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': best_loss,
        }, 'best_model.pt')
    
    # Optional: clear cache after each epoch
    gc.collect()
    torch.cuda.empty_cache()

Training: 100%|██████████| 160/160 [00:09<00:00, 16.08it/s, loss=3.8872, avg_loss=3.9361]


Epoch 1 - Training Loss: 3.9115


Training: 100%|██████████| 160/160 [00:13<00:00, 12.30it/s, loss=3.9870, avg_loss=3.9071]


Epoch 2 - Training Loss: 3.9071


Training: 100%|██████████| 160/160 [00:12<00:00, 12.91it/s, loss=3.8550, avg_loss=3.9541]


Epoch 3 - Training Loss: 3.9046


Training: 100%|██████████| 160/160 [00:08<00:00, 19.94it/s, loss=3.9719, avg_loss=3.9093]


Epoch 4 - Training Loss: 3.9093


Training: 100%|██████████| 160/160 [00:07<00:00, 20.01it/s, loss=3.9526, avg_loss=3.9557]


Epoch 5 - Training Loss: 3.9062


Validation: 100%|██████████| 160/160 [00:07<00:00, 20.91it/s, loss=3.8562, avg_loss=3.9572]


Validation Loss: 3.9077


Training: 100%|██████████| 160/160 [00:07<00:00, 20.23it/s, loss=3.9440, avg_loss=3.9351]


Epoch 6 - Training Loss: 3.9105


Training: 100%|██████████| 160/160 [00:07<00:00, 20.19it/s, loss=3.9045, avg_loss=3.9452]


Epoch 7 - Training Loss: 3.9205


Training: 100%|██████████| 160/160 [00:11<00:00, 14.29it/s, loss=3.7986, avg_loss=3.9597]


Epoch 8 - Training Loss: 3.9350


Training: 100%|██████████| 160/160 [00:12<00:00, 12.40it/s, loss=3.9720, avg_loss=3.9357]


Epoch 9 - Training Loss: 3.9357


Training: 100%|██████████| 160/160 [00:12<00:00, 12.40it/s, loss=3.9806, avg_loss=3.9589]


Epoch 10 - Training Loss: 3.9342


Validation: 100%|██████████| 160/160 [00:12<00:00, 12.72it/s, loss=3.8951, avg_loss=3.9598]


Validation Loss: 3.9351


Training: 100%|██████████| 160/160 [00:12<00:00, 12.40it/s, loss=3.9621, avg_loss=3.9350]


Epoch 11 - Training Loss: 3.9350


Training: 100%|██████████| 160/160 [00:12<00:00, 12.40it/s, loss=3.9453, avg_loss=3.9344]


Epoch 12 - Training Loss: 3.9344


Training: 100%|██████████| 160/160 [00:12<00:00, 12.41it/s, loss=3.9212, avg_loss=3.9346]


Epoch 13 - Training Loss: 3.9346


Training: 100%|██████████| 160/160 [00:12<00:00, 12.42it/s, loss=3.9282, avg_loss=3.9605]


Epoch 14 - Training Loss: 3.9358


Training: 100%|██████████| 160/160 [00:12<00:00, 12.43it/s, loss=3.8872, avg_loss=3.9343]


Epoch 15 - Training Loss: 3.9343


Validation: 100%|██████████| 160/160 [00:11<00:00, 14.36it/s, loss=3.8951, avg_loss=3.9598]


Validation Loss: 3.9351


Training: 100%|██████████| 160/160 [00:07<00:00, 20.24it/s, loss=3.9383, avg_loss=3.9347]


Epoch 16 - Training Loss: 3.9347


Training: 100%|██████████| 160/160 [00:07<00:00, 20.27it/s, loss=3.8960, avg_loss=3.9587]


Epoch 17 - Training Loss: 3.9339


Training: 100%|██████████| 160/160 [00:07<00:00, 20.30it/s, loss=3.8741, avg_loss=3.9352]


Epoch 18 - Training Loss: 3.9352


Training: 100%|██████████| 160/160 [00:07<00:00, 20.32it/s, loss=3.7495, avg_loss=3.9838]


Epoch 19 - Training Loss: 3.9340


Training: 100%|██████████| 160/160 [00:11<00:00, 13.56it/s, loss=3.9105, avg_loss=3.9835]


Epoch 20 - Training Loss: 3.9337


Validation: 100%|██████████| 160/160 [00:11<00:00, 13.66it/s, loss=3.8951, avg_loss=3.9351]


Validation Loss: 3.9351


Training: 100%|██████████| 160/160 [00:12<00:00, 12.77it/s, loss=3.9432, avg_loss=3.9578]


Epoch 21 - Training Loss: 3.9330


Training: 100%|██████████| 160/160 [00:12<00:00, 12.41it/s, loss=3.8710, avg_loss=3.9600]


Epoch 22 - Training Loss: 3.9353


Training: 100%|██████████| 160/160 [00:12<00:00, 12.42it/s, loss=3.9349, avg_loss=3.9353]


Epoch 23 - Training Loss: 3.9353


Training: 100%|██████████| 160/160 [00:12<00:00, 12.41it/s, loss=3.8966, avg_loss=3.9601]


Epoch 24 - Training Loss: 3.9353


Training: 100%|██████████| 160/160 [00:12<00:00, 12.42it/s, loss=3.9203, avg_loss=3.9598]


Epoch 25 - Training Loss: 3.9351


Validation: 100%|██████████| 160/160 [00:12<00:00, 12.72it/s, loss=3.8951, avg_loss=3.9598]


Validation Loss: 3.9351


Training: 100%|██████████| 160/160 [00:12<00:00, 12.43it/s, loss=3.9237, avg_loss=3.9582]


Epoch 26 - Training Loss: 3.9335


Training: 100%|██████████| 160/160 [00:12<00:00, 12.39it/s, loss=3.9008, avg_loss=3.9600]


Epoch 27 - Training Loss: 3.9353


Training: 100%|██████████| 160/160 [00:12<00:00, 12.39it/s, loss=3.9255, avg_loss=3.9349]


Epoch 28 - Training Loss: 3.9349


Training: 100%|██████████| 160/160 [00:12<00:00, 12.38it/s, loss=3.9279, avg_loss=3.9584]


Epoch 29 - Training Loss: 3.9336


Training: 100%|██████████| 160/160 [00:08<00:00, 18.40it/s, loss=3.8346, avg_loss=3.9337]


Epoch 30 - Training Loss: 3.9337


Validation: 100%|██████████| 160/160 [00:07<00:00, 20.94it/s, loss=3.8951, avg_loss=3.9351]


Validation Loss: 3.9351


Training: 100%|██████████| 160/160 [00:07<00:00, 20.28it/s, loss=3.9242, avg_loss=3.9580]


Epoch 31 - Training Loss: 3.9333


Training: 100%|██████████| 160/160 [00:07<00:00, 20.28it/s, loss=3.9322, avg_loss=3.9336]


Epoch 32 - Training Loss: 3.9336


Training: 100%|██████████| 160/160 [00:07<00:00, 20.30it/s, loss=3.9304, avg_loss=3.9596]


Epoch 33 - Training Loss: 3.9349


Training: 100%|██████████| 160/160 [00:07<00:00, 20.29it/s, loss=3.9214, avg_loss=3.9353]


Epoch 34 - Training Loss: 3.9353


Training: 100%|██████████| 160/160 [00:07<00:00, 20.30it/s, loss=3.9228, avg_loss=3.9582]


Epoch 35 - Training Loss: 3.9334


Validation: 100%|██████████| 160/160 [00:07<00:00, 20.94it/s, loss=3.8951, avg_loss=3.9849]


Validation Loss: 3.9351


Training: 100%|██████████| 160/160 [00:07<00:00, 20.32it/s, loss=3.8971, avg_loss=3.9341]


Epoch 36 - Training Loss: 3.9341


Training: 100%|██████████| 160/160 [00:07<00:00, 20.33it/s, loss=3.9070, avg_loss=3.9591]


Epoch 37 - Training Loss: 3.9343


Training: 100%|██████████| 160/160 [00:07<00:00, 20.34it/s, loss=3.8849, avg_loss=3.9593]


Epoch 38 - Training Loss: 3.9346


Training: 100%|██████████| 160/160 [00:07<00:00, 20.33it/s, loss=3.9364, avg_loss=3.9615]


Epoch 39 - Training Loss: 3.9367


Training: 100%|██████████| 160/160 [00:07<00:00, 20.30it/s, loss=3.9040, avg_loss=3.9832]


Epoch 40 - Training Loss: 3.9334


Validation: 100%|██████████| 160/160 [00:07<00:00, 20.99it/s, loss=3.8951, avg_loss=3.9849]


Validation Loss: 3.9351


AttributeError: 'SimpleRNAPredictor' object has no attribute 'device'