# DTI Model Training Pipeline

This notebook imports all model, data, and training logic from the project's `src/` package to run a reproducible training pipeline. It contains no class or function definitions of its own.

In [None]:
import sys
import os
import yaml
import torch
import pandas as pd
import numpy as np
import logging
from pathlib import Path
from torch.utils.data import DataLoader
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import matplotlib.pyplot as plt

# --- Add Project Root to Python Path ---
# This allows the notebook to find and import modules from the 'src' folder.
#
# os.getcwd() is the Jupyter kernel's working directory. That is usually the
# folder containing this notebook (e.g. 'notebooks/'), but it is not
# guaranteed. If the kernel was started from the project root, stepping one
# level up would point at the wrong directory.
notebooks_dir = os.path.abspath(os.getcwd())
if os.path.isdir(os.path.join(notebooks_dir, "src")):
    # The working directory already contains 'src', so it is the project root.
    project_root = notebooks_dir
else:
    # Otherwise assume we are one level below the project root (e.g. 'notebooks/').
    project_root = os.path.dirname(notebooks_dir)

# Prepend rather than append. 'src' is a very common top-level package name,
# and this project's copy must take priority over any other importable 'src'.
if project_root not in sys.path:
    sys.path.insert(0, project_root)
# ---------------------------------------


# --- Import all components from our 'src' library ---
# Every model, data, and training component used below lives in 'src'.
# If any import fails, print where 'src' was expected and re-raise so the
# notebook stops here instead of failing with a NameError in a later cell.
try:
    from src.utils.config_loader import load_config
    from src.utils.logger import setup_logging
    from src.models.dti_model import DTIDataset, collate_fn, DTIModel, Trainer
    from src.preprocessing.feature_engineer import ProteinEmbedder
    from src.molecular_3d.conformer_generator import smiles_to_3d_graph
except ImportError as import_err:
    expected_src_dir = os.path.join(project_root, 'src')
    print("Error: Failed to import modules from 'src'.")
    print(f"Make sure 'src' directory exists at: {expected_src_dir}")
    print(f"Details: {import_err}")
    raise
else:
    print("'src' modules imported successfully.")

## 1. Setup and Configuration

Load the config, set up logging, and select the compute device. `load_config` resolves all paths in the config relative to the project root.

In [None]:
# Load the run configuration.
# `load_config` comes from src/utils/config_loader.py. Any failure is reported
# and re-raised so nothing below runs without a config.
try:
    config = load_config("config/config.yaml")
    print("Config file loaded and paths resolved.")
except Exception as config_err:
    print(f"Error loading configuration: {config_err}")
    raise

# Configure logging; fall back to logs/train.log when the config has no 'log_file'.
setup_logging(log_path=config['paths'].get('log_file', 'logs/train.log'))
log = logging.getLogger(__name__)
log.info("--- Starting New Training Run (from Notebook) ---")

# Use the configured device only when CUDA is available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device(config['training']['device'])
else:
    device = torch.device("cpu")
device_msg = f"Using device: {device}"
log.info(device_msg)
print(device_msg)

# Create the output directory (and any missing parents) if it does not exist yet.
output_dir = Path(config['paths']['output_dir'])
output_dir.mkdir(exist_ok=True, parents=True)
log.info(f"Outputs will be saved to: {output_dir}")

## 2. Load and Prepare Data

Load the dataset and create training/validation splits.

In [None]:
log.info("Loading dataset...")
# `load_config` has already resolved this path relative to the project root.
data_path = config['paths']['dataset']
if not Path(data_path).exists():
    log.error(f"Dataset file not found at: {data_path}")
    raise FileNotFoundError(f"Dataset file not found at: {data_path}")

# One seed drives the shuffle, the split, and the global torch/numpy RNGs.
# It is read from `training.seed` and defaults to the previously hard-coded 42,
# so configs without that key reproduce exactly the same split as before.
seed = config['training'].get('seed', 42)
# Seed the global RNGs as well: DataLoader shuffling and model initialization
# in later cells draw from them, and nothing visible in this notebook seeds them.
# NOTE(review): if src/ seeds elsewhere, that later call simply overrides this.
torch.manual_seed(seed)
np.random.seed(seed)

# Tab-separated input. Keep only the columns the pipeline uses, drop rows
# missing any of them, then shuffle deterministically.
df = pd.read_csv(data_path, sep='\t')
df = df[['SMILES', 'sequence', 'label']].dropna().sample(frac=1, random_state=seed)
log.info(f"Loaded {len(df)} total data points.")

# Stratify on the label so both splits keep the same label proportions.
train_df, val_df = train_test_split(
    df,
    test_size=config['training']['val_split'],
    random_state=seed,
    stratify=df['label'],
)
log.info(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}")
print(f"Training samples: {len(train_df)}, Validation samples: {len(val_df)}")

train_df.head()

## 3. Initialize DataLoaders

Datasets are built with the `DTIDataset` class from `src.models.dti_model`. The `ProteinEmbedder` (ESM-2) is expensive to load, so it is initialized once and shared by the training and validation datasets.

In [None]:
log.info("Initializing Protein Embedder (ESM-2)...")
print("Initializing Protein Embedder (ESM-2)... This may take a moment.")
try:
    # Expensive to construct, so build it once and share it between both datasets.
    protein_embedder = ProteinEmbedder(config)
    print("Protein Embedder loaded.")
except Exception as e:
    # log.exception also records the traceback in the log file.
    log.exception(f"Failed to initialize ProteinEmbedder: {e}")
    raise

log.info("Creating datasets and dataloaders...")

# Both splits share the embedder and the SMILES -> 3D graph function.
train_dataset = DTIDataset(
    df=train_df,
    config=config,
    protein_embedder=protein_embedder,
    graph_gen_func=smiles_to_3d_graph,
)

val_dataset = DTIDataset(
    df=val_df,
    config=config,
    protein_embedder=protein_embedder,
    graph_gen_func=smiles_to_3d_graph,
)

# Settings shared by both loaders. Pinned host memory only helps host-to-GPU
# copies; on a CPU-only run PyTorch warns that pinned memory won't be used,
# so enable it only when training on CUDA.
# NOTE(review): with num_workers > 0 each worker process gets its own copy of
# the dataset, including protein_embedder. Confirm that is acceptable, and
# CUDA-safe if the embedder lives on the GPU.
loader_kwargs = dict(
    batch_size=config['training']['batch_size'],
    num_workers=config['training'].get('num_workers', 0),
    collate_fn=collate_fn,  # batching logic from src.models.dti_model
    pin_memory=(device.type == 'cuda'),
)

# Shuffle only the training data; validation order stays fixed.
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

log.info("DataLoaders created successfully.")
print("DataLoaders created.")

## 4. Initialize Model

The model is the `DTIModel` class from `src.models.dti_model`.

In [None]:
log.info("Initializing DTIModel...")
try:
    # Build the model from the config and move it to the selected device.
    model = DTIModel(config).to(device)
    # Count only parameters the optimizer can update.
    param_count = sum(param.numel() for param in model.parameters() if param.requires_grad)
    init_msg = f"Model initialized with {param_count:,} trainable parameters."
    log.info(init_msg)
    print(init_msg)
except Exception as model_err:
    log.error(f"Failed to initialize DTIModel: {model_err}")
    raise

# Show the module hierarchy for a quick sanity check.
print(model)

## 5. Run Training

Training is driven by the `Trainer` class from `src.models.dti_model`.

In [None]:
log.info("Initializing Trainer...")
try:
    trainer = Trainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        config=config,
    )
    print("Trainer initialized.")
except Exception as e:
    # log.exception also records the traceback in the log file.
    log.exception(f"Failed to initialize Trainer: {e}")
    raise

log.info("--- Starting training loop ---")
print("--- Starting training loop ---")
# Runs for the number of epochs specified in config.yaml.
trainer.train()

log.info("--- Training complete ---")
print("--- Training complete ---")

# Write the headline result to the log file as well as the notebook output,
# matching how the other summary messages in this pipeline are reported.
best_auroc_msg = f"Best Validation AUROC: {trainer.best_val_auroc:.4f}"
log.info(best_auroc_msg)
print(best_auroc_msg)

## 6. Plot Metrics and Save

Visualize the training and validation performance.

In [None]:
# Ensure plots appear inline in the notebook
%matplotlib inline

log.info("Plotting metrics...")
print("Plotting metrics...")
try:
    # This is the correct method name from src/models/dti_model.py
    fig = trainer.plot_metrics() 
    
    # Save plot
    plot_path = output_dir / "training_metrics.png"
    fig.savefig(plot_path)
    log.info(f"Training metrics plot saved to {plot_path}")
    print(f"Training metrics plot saved to {plot_path}")
    
    # Display the plot in the notebook
    plt.show()

except Exception as e:
    log.error(f"Failed to plot metrics: {e}")
    print(f"Failed to plot metrics: {e}")

---
 
End of pipeline. The trained model is saved at `models/dti_model.pth` (or as specified in your config).
 
---