## Training a Neural Network Model

- using `pytorch`

https://pytorch.org/tutorials/beginner/basics/buildmodel_tutorial.html

In [1]:
import sys
import os

# Add the path to the 'code' directory
sys.path.append(os.path.abspath('../code'))

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torchvision import datasets, transforms, models
from torchvision.utils import make_grid
import pandas as pd
import numpy as np

# include feature engineering pipeline
from feature_eng_pipeline import pipeline_nn

In [3]:
# switch to using CUDA - GPU
device = (
    "cuda"
    if torch.cuda.is_available()
    else "mps"
    if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using {device} device")

Using cuda device


In [4]:
torch.cuda.is_available()

True

In [21]:
# loading data
data_path = '../../data/mean_with_labels.csv'

class RNANanoporeDataset(Dataset):
    """Dataset used to train and test RNA Nanopore data"""

    def __init__(self, csv_file):
        """Initializes instance of class RNANanoporeDataset.

        Args:
            csv_file (str): Path to the csv file with the nanopore data
        """

        self.df = pd.read_csv(csv_file)
        v, s, X_df, y_df = pipeline_nn(self.df)
        X_drop = X_df.drop(["transcript_name", "gene_id", "nucleotide_seq"], axis=1).reset_index(drop=True)  

        # TODO: for now we drop all trigram columns
        self.X = X_drop[["json_position", "dwelling_time_min1", "sd_min1", "mean_min1", "dwelling_time", "sd", "mean", "dwelling_time_plus1", "sd_plus1", "mean_plus1"]]
        self.y = y_df.reset_index(drop=True).squeeze()  

    def __len__(self):
        """Returns the size of the dataset"""
        return len(self.X)

    def __getitem__(self, idx):
    # Handle if idx is a tensor (converting to list if needed)
        if isinstance(idx, torch.Tensor):
            idx = idx.tolist()

        signal_features = self.X.iloc[idx].values  
        label = self.y.iloc[idx]  

        # Convert to tensors
        signal_features = torch.tensor(signal_features, dtype=torch.float32)
        label = torch.tensor(label, dtype=torch.float32)

        return signal_features, label


In [36]:
# preparing data for training using DataLoaders
from torch.utils.data.dataset import random_split

dataset = RNANanoporeDataset(data_path)
train_size = int(0.8 * len(dataset))
test_size = len(dataset) - train_size
trainset, testset = random_split(dataset, [train_size, test_size])

# Dataloaders
trainloader = DataLoader(trainset, batch_size=256, shuffle=True)
testloader = DataLoader(testset, batch_size=256, shuffle=False)

Final datasetp preview:    json_position  dwelling_time_min1   sd_min1  mean_min1  dwelling_time  \
0      -1.167740            0.205115 -0.093805  -1.129683       0.162595   
1      -1.167740           -1.243149 -0.135977  -1.120311      -0.879851   
2      -1.166286           -0.834348  2.163181   0.889406       1.298831   
3      -1.166286           -0.068915 -0.988879  -0.556731      -0.627451   
4      -1.165559            0.602928 -0.207387   0.883266      -0.869697   

         sd      mean  dwelling_time_plus1  sd_plus1  mean_plus1  
0  1.967826  0.531283            -0.868368  0.233028   -1.451873  
1  1.899411  0.572909            -0.352487  0.843539   -2.066144  
2 -0.595935 -1.297428            -0.779589 -1.095111    0.038191  
3 -0.497492 -0.996726            -1.390108  0.406858   -1.292909  
4  0.594147  0.764529             0.004848  0.661955   -0.154392  


In [37]:
class ModNet(nn.Module):
    def __init__(self, signal_input_dim):
        super(ModNet, self).__init__()

        # Read-level Encoder: MLP with two hidden layers
        self.encoder = nn.Sequential(
            nn.Linear(signal_input_dim, 150),  # Change hidden dimensions as needed
            nn.ReLU(),
            nn.Linear(150, 32),
            nn.ReLU(),
            nn.Linear(32, 1)  # Single output for binary classification
        )

    def forward(self, signal_features):
        read_level_probs = self.encoder(signal_features)
        return torch.sigmoid(read_level_probs)  # Apply sigmoid for probabilities


    def noisy_or_pooling(self, read_level_probs):
        """
        :param read_level_probs: Tensor of shape (batch_size, 1)
        :return: Site-level modification probability for each site (batch_size, 1)
        """
        site_level_probs = 1 - torch.prod(1 - read_level_probs, dim=1)
        return site_level_probs


In [24]:
# Set CUDA launch blocking for better error messages
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [39]:
# Assume that ModNet is already defined
model = ModNet(signal_input_dim=10) 
               #trigram_vocab_size=64,
               #embedding_dim=20)

# Loss function and optimizer
criterion = nn.BCEWithLogitsLoss()  # Binary Cross-Entropy Loss for binary classification

In [40]:
def train_model_with_checks(model, trainloader, criterion, optimizer, num_epochs=10, clip_value=1.0):
    """
    Training loop with additional checks to prevent NaN losses
    """
    model.to(device) 
    model.train()
    
    # Track losses for monitoring
    all_losses = []
    
    for epoch in range(num_epochs):
        running_loss = 0.0
        
        for i, data in enumerate(trainloader, 0):
            signal_features, labels = data
            signal_features = signal_features.to(device)
            labels = labels.to(device).float()

            # Zero gradients
            optimizer.zero_grad()
            
            try:
                # Forward pass
                outputs = model(signal_features)
                loss = criterion(outputs, labels.view(-1, 1))
                
                # Backward pass
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), clip_value)
                optimizer.step()
                
                # Track loss
                all_losses.append(loss.item())
                running_loss += loss.item()
                
                if i % 10 == 9:
                    avg_loss = running_loss / 10
                    print(f'Epoch [{epoch+1}/{num_epochs}], Batch [{i+1}], Loss: {avg_loss:.4f}')
                    running_loss = 0.0
                    
            except RuntimeError as e:
                print(f"Runtime Error in batch {i}:", e)
                continue
    
    return all_losses

# Additional helper functions for model debugging
def check_model_weights(model):
    """Check if model weights are properly initialized"""
    for name, param in model.named_parameters():
        print(f"\nLayer: {name}")
        print(f"Shape: {param.shape}")
        print(f"Mean: {param.mean().item():.4f}")
        print(f"Std: {param.std().item():.4f}")
        print(f"Min: {param.min().item():.4f}")
        print(f"Max: {param.max().item():.4f}")

def initialize_weights(model):
    """Initialize model weights properly"""
    for m in model.modules():
        if isinstance(m, (torch.nn.Linear, torch.nn.Conv1d)):
            torch.nn.init.kaiming_normal_(m.weight)
            if m.bias is not None:
                torch.nn.init.zeros_(m.bias)

# print("Checking model weights before training:")
# check_model_weights(model)

# # Initialize weights properly
# initialize_weights(model)

# print("\nChecking model weights after initialization:")
# check_model_weights(model)

# Modified optimizer with gradient clipping
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
losses = train_model_with_checks(model, trainloader, criterion, optimizer, num_epochs=10)

Epoch [1/10], Batch [10], Loss: 0.7113
Epoch [1/10], Batch [20], Loss: 0.6950
Epoch [1/10], Batch [30], Loss: 0.6853
Epoch [1/10], Batch [40], Loss: 0.6837
Epoch [1/10], Batch [50], Loss: 0.6781
Epoch [1/10], Batch [60], Loss: 0.6720
Epoch [1/10], Batch [70], Loss: 0.6659
Epoch [1/10], Batch [80], Loss: 0.6644
Epoch [1/10], Batch [90], Loss: 0.6586
Epoch [1/10], Batch [100], Loss: 0.6549
Epoch [1/10], Batch [110], Loss: 0.6553
Epoch [1/10], Batch [120], Loss: 0.6466
Epoch [1/10], Batch [130], Loss: 0.6384
Epoch [1/10], Batch [140], Loss: 0.6357
Epoch [1/10], Batch [150], Loss: 0.6356
Epoch [1/10], Batch [160], Loss: 0.6321
Epoch [1/10], Batch [170], Loss: 0.6276
Epoch [1/10], Batch [180], Loss: 0.6225
Epoch [1/10], Batch [190], Loss: 0.6255
Epoch [1/10], Batch [200], Loss: 0.6209
Epoch [1/10], Batch [210], Loss: 0.6272
Epoch [1/10], Batch [220], Loss: 0.6141
Epoch [1/10], Batch [230], Loss: 0.6128
Epoch [1/10], Batch [240], Loss: 0.6200
Epoch [1/10], Batch [250], Loss: 0.6136
Epoch [1/

In [41]:
from sklearn.metrics import roc_auc_score, precision_recall_curve, auc

# Function to evaluate on the test set
def evaluate_model(model, testloader, criterion):
    model.eval()  # Set model to evaluation mode
    total_loss = 0.0
    all_labels = []
    all_predictions = []

    with torch.no_grad():
        for data in testloader:
            signal_features, labels = data
            
            # Move data to device
            signal_features = signal_features.to(device)
            labels = labels.to(device)
            
            # Forward pass
            read_level_probs = model(signal_features)
            site_level_probs = model.noisy_or_pooling(read_level_probs).squeeze()  # Shape: (batch_size,)

            # Compute loss
            loss = criterion(site_level_probs, labels.float())
            total_loss += loss.item()

            # Collect predictions and labels for ROC and PR AUC
            all_labels.append(labels.cpu())
            all_predictions.append(site_level_probs.cpu())

    # Convert lists to tensors
    all_labels = torch.cat(all_labels)
    all_predictions = torch.cat(all_predictions)

    # Compute ROC-AUC
    roc_auc = roc_auc_score(all_labels, all_predictions)

    # Compute PR-AUC
    precision, recall, _ = precision_recall_curve(all_labels, all_predictions)
    pr_auc = auc(recall, precision)

    # Average loss
    avg_loss = total_loss / len(testloader)
    
    print(f'Test Loss: {avg_loss:.4f}, ROC-AUC: {roc_auc:.4f}, PR-AUC: {pr_auc:.4f}')
    
    return avg_loss, roc_auc, pr_auc

evaluate_model(model, testloader, criterion)

Test Loss: 0.5807, ROC-AUC: 0.9303, PR-AUC: 0.9249


(0.5806522011756897,
 np.float64(0.9303084699078698),
 np.float64(0.9249033300115246))