# Genetic Mutation Prioritization - Model Training

This notebook provides an interactive environment to train and evaluate the mutation prioritization model. It covers:
- Data loading via `src.utils.data_loader`
- Model initialization (`MLP`)
- A training loop with `tqdm` progress bars and per-epoch validation loss and AUC
- Loss-curve visualization after training
- Saving the trained weights as a checkpoint

In [None]:
import os
import sys

# Intel OpenMP workaround: tolerate more than one copy of the OpenMP runtime in
# the process (e.g. one bundled with PyTorch and one with NumPy/MKL) instead of
# aborting with "OMP: Error #15". Set it BEFORE importing libraries that may
# load an OpenMP runtime so the flag is in place whenever the runtime starts.
# NOTE(review): this only silences the check; it can mask a genuine conflict.
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

# Seed the global torch and NumPy RNGs before any data loading, weight
# initialisation or dropout happens, so Restart & Run All reproduces results.
# (torch.manual_seed also seeds CUDA; full GPU determinism needs more settings.)
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Make the project's `src` package importable. Assumes the notebook is run from
# a direct subdirectory of the project root (e.g. `<root>/notebooks/`).
project_root = os.path.abspath(os.path.join(os.getcwd(), '..'))
if project_root not in sys.path:
    sys.path.append(project_root)

from src.config.data_config import Config
from src.utils.data_loader import get_data_loaders
from src.models.mlp import MLP
from src.evaluation.eval_metrics import calculate_metrics

## 1. Load Configuration

In [None]:
# Load hyperparameters from the project's YAML config (path is relative to the
# project root resolved in the setup cell).
config_path = os.path.join(project_root, 'src/config/config.yaml')
config = Config(config_path)

# Echo the settings that drive the rest of the notebook.
model_type = config.model['type']
print("Model Type:", model_type)

n_epochs_configured = config.training['epochs']
print("Epochs:", n_epochs_configured)

device_name = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device_name)

## 2. Load Data

In [None]:
# Build the train/val/test loaders from the config. On failure, print a hint
# and re-raise: every later cell needs these names, so swallowing the error
# would only turn it into a confusing NameError further down the notebook.
try:
    train_loader, val_loader, test_loader, input_dim = get_data_loaders(config)
except Exception as exc:
    print(f"Error loading data: {exc}")
    # NOTE(review): hint path is hard-coded; confirm it matches the config.
    print("Ensure 'data/processed/feature_matrix.csv' exists.")
    raise

print(f"Data loaded successfully. Input feature dimension: {input_dim}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

## 3. Initialize Model

In [None]:
# Choose the compute device once; the model and every batch are moved here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Architecture hyperparameters come from the `model -> mlp` config section.
# NOTE(review): `config.model['type']` is printed earlier but not consulted
# here; this cell always builds an MLP.
mlp_settings = config.model['mlp']
model = MLP(
    input_dim=input_dim,
    hidden_layers=mlp_settings['hidden_layers'],
    dropout=mlp_settings['dropout'],
)
model = model.to(device)

# BCEWithLogitsLoss applies the sigmoid internally, so it expects raw logits.
# Presumably MLP has no final sigmoid — TODO confirm in src/models/mlp.py.
learning_rate = config.training['learning_rate']
optimizer = optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.BCEWithLogitsLoss()

## 4. Training Loop

In [None]:
def _train_one_epoch(model, loader, optimizer, criterion, device, desc):
    """Run one optimisation pass over `loader` and return the mean training loss.

    The mean is weighted by batch size, so a short final batch counts in
    proportion to its samples (a plain average of per-batch means would
    over-weight it). Assumes `criterion` uses reduction='mean', the
    BCEWithLogitsLoss default, and that dim 0 of each input batch is the batch.

    Returns:
        float: sample-weighted mean loss over the epoch.
    """
    model.train()
    loss_sum, n_seen = 0.0, 0
    pbar = tqdm(loader, desc=desc, leave=False)
    for X_batch, y_batch in pbar:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        loss = criterion(model(X_batch), y_batch)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        batch_size = X_batch.size(0)
        loss_sum += batch_loss * batch_size
        n_seen += batch_size
        pbar.set_postfix({'loss': batch_loss})
    return loss_sum / n_seen


def _evaluate(model, loader, criterion, device):
    """Score `loader` in eval mode without gradients.

    Returns:
        tuple: (sample-weighted mean loss, concatenated labels as a NumPy
        array, concatenated sigmoid probabilities as a NumPy array).
    """
    model.eval()
    loss_sum, n_seen = 0.0, 0
    label_chunks, prob_chunks = [], []
    with torch.no_grad():
        for X_batch, y_batch in loader:
            X_batch, y_batch = X_batch.to(device), y_batch.to(device)
            outputs = model(X_batch)
            batch_size = X_batch.size(0)
            loss_sum += criterion(outputs, y_batch).item() * batch_size
            n_seen += batch_size
            label_chunks.append(y_batch.cpu().numpy())
            # Model outputs are logits; convert to probabilities for metrics.
            prob_chunks.append(torch.sigmoid(outputs).cpu().numpy())
    return loss_sum / n_seen, np.concatenate(label_chunks), np.concatenate(prob_chunks)


epochs = config.training['epochs']
train_losses = []
val_losses = []

# Fail early with a clear message instead of a ZeroDivisionError mid-loop.
if len(train_loader) == 0 or len(val_loader) == 0:
    raise ValueError("train_loader and val_loader must each yield at least one batch")

# NOTE: re-running this cell continues training the current `model`/`optimizer`
# state (and resets the loss history); re-run section 3 first for a fresh run.
print("Starting training...")

for epoch in range(epochs):
    avg_train_loss = _train_one_epoch(
        model, train_loader, optimizer, criterion, device,
        desc=f"Epoch {epoch+1}/{epochs}",
    )
    train_losses.append(avg_train_loss)

    avg_val_loss, val_labels, val_probs = _evaluate(model, val_loader, criterion, device)
    val_losses.append(avg_val_loss)

    metrics = calculate_metrics(val_labels, val_probs)

    print(f"Epoch {epoch+1}: Train Loss={avg_train_loss:.4f}, Val Loss={avg_val_loss:.4f}, Val AUC={metrics['auc']:.4f}")

## 5. Visualization

In [None]:
# Plot per-epoch mean losses. Epochs are numbered from 1 to match the
# "Epoch N" lines printed by the training loop (a bare plt.plot(y) would put
# epoch 1 at x=0). Each series gets its own axis in case a run was interrupted
# between the train and validation steps of an epoch.
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
ax.plot(range(1, len(val_losses) + 1), val_losses, label='Validation Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Training Convergence')
ax.legend()
ax.grid(True)
plt.show()

## 6. Save Model
Save the trained weights (`state_dict`) as a checkpoint, or proceed to evaluation.

In [None]:
# Persist only the learned weights (state_dict). Reloading requires building an
# MLP with the same constructor arguments first. Overwrites any file already at
# this path.
save_path = os.path.join(project_root, 'reports/results/checkpoints/manual_notebook_model.pth')

checkpoint_dir = os.path.dirname(save_path)
os.makedirs(checkpoint_dir, exist_ok=True)  # no-op when the folder exists

weights = model.state_dict()
torch.save(weights, save_path)
print(f"Model saved to {save_path}")