# Cell 1: Setup and Imports

This cell imports all necessary libraries. We import standard data science tools (`pandas`, `sklearn`), PyTorch, and crucially, PyTorch Geometric (`torch_geometric`) for handling graph data. We also import `os` and `sys` to manage our project paths.

In [None]:
import os
import sys
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import random_split
from torch_geometric.data import Dataset, Batch
from torch_geometric.loader import DataLoader
from sklearn.metrics import roc_auc_score, f1_score, accuracy_score, precision_recall_curve, auc
from tqdm.notebook import tqdm
import numpy as np

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

# Cell 2: Path Configuration & Custom Module Imports

This is a critical step for running notebooks inside a project. We add the `src` directory (which is one level up, `../`) to the system path. This allows us to import our own custom Python modules (`FeatureEngineer`, `DTIPredictor`, `load_config`) just like in the Streamlit app.
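
The path logic can be wrapped in a small helper so that re-running the cell never adds duplicate entries. This is a minimal sketch: `add_to_sys_path` is a hypothetical name, and it assumes the notebook's working directory sits one level below the project root.

```python
import os
import sys


def add_to_sys_path(directory: str) -> str:
    """Add `directory` to sys.path once and return its absolute form."""
    absolute = os.path.abspath(directory)
    if absolute not in sys.path:
        sys.path.append(absolute)
    return absolute


# Assumed layout: the notebook runs from a folder next to `src`.
src_path = add_to_sys_path(os.path.join(os.getcwd(), "..", "src"))
add_to_sys_path(src_path)  # Second call is a no-op
print(src_path, sys.path.count(src_path))
```

Because the path is derived from `os.getcwd()`, starting Jupyter from a different directory changes where `src` is looked up.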

In [None]:
SRC_PATH = os.path.abspath(os.path.join(os.getcwd(), '..', 'src'))
if SRC_PATH not in sys.path:
    sys.path.append(SRC_PATH)

try:
    from preprocessing.feature_engineer import FeatureEngineer
    from models.dti_model import DTIPredictor
    from utils.config_loader import load_config
except ImportError as e:
    print(f"Error importing custom modules: {e}")
    print("Please ensure __init__.py files exist in src subdirectories and all requirements are installed.")
    raise  # Later cells cannot run without these modules

# Cell 3: Load Configuration and Set Hyperparameters

Instead of hard-coding parameters, we load the *exact same* `config.yaml` file the Streamlit app uses. This ensures consistency between training and deployment. We define our training-specific hyperparameters (epochs, batch size, etc.) here.
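
The later cells read `config['model']['hyperparameters']` and `config['model']['featurization']`, so it can help to check for those sections right after loading. The sketch below is illustrative only: the values in `example_config` are made up, and `check_model_config` is a hypothetical helper, not part of the project.

```python
from typing import Any

# Hypothetical contents; the real values live in config/config.yaml.
example_config: dict[str, Any] = {
    "model": {
        "hyperparameters": {"hidden_dim": 128, "dropout": 0.2},
        "featurization": {"max_sequence_length": 1000},
    }
}

REQUIRED_SECTIONS = ("hyperparameters", "featurization")


def check_model_config(config: dict[str, Any]) -> list[str]:
    """Return the names of required `model` sections that are missing."""
    model_section = config.get("model", {})
    return [name for name in REQUIRED_SECTIONS if name not in model_section]


print(check_model_config(example_config))
print(check_model_config({"model": {"hyperparameters": {}}}))
```

An empty list means both sections are present; any names returned point at the part of the YAML file to fix.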

In [None]:
# Load configuration
CONFIG_PATH = '../config/config.yaml'
config = load_config(CONFIG_PATH)

# Training Hyperparameters
EPOCHS = 50
BATCH_SIZE = 32
LEARNING_RATE = 1e-4
VALIDATION_SPLIT = 0.2
RANDOM_SEED = 42

# Set device (GPU if available, else CPU)
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Using device: {DEVICE}")
print(f"Loaded model hyperparameters: {config['model']['hyperparameters']}")

# Cell 4: Load Training Data

We load the `relationships.tsv` file. This file should contain the ground-truth data, including SMILES strings for drugs, amino acid sequences for proteins, and the binary label (1 for interaction, 0 for no interaction).
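
A quick structural check before training can catch a wrong header or a stray label value early. The following sketch uses only the standard library and an in-memory sample. The column names match the ones this notebook assumes, `validate_tsv` is a hypothetical helper, and the three rows are invented for illustration.

```python
import csv
import io

REQUIRED_COLUMNS = {"drug_smiles", "protein_sequence", "label"}  # Assumed names


def validate_tsv(handle) -> dict:
    """Check the header and count binary labels in a tab-separated file."""
    reader = csv.DictReader(handle, delimiter="\t")
    missing = REQUIRED_COLUMNS - set(reader.fieldnames or [])
    if missing:
        raise KeyError(f"Missing columns: {sorted(missing)}")
    counts = {"0": 0, "1": 0}
    # The header is line 1, so data rows start at line 2.
    for line_number, row in enumerate(reader, start=2):
        label = row["label"].strip()
        if label not in counts:
            raise ValueError(f"Line {line_number}: unexpected label {label!r}")
        counts[label] += 1
    return counts


sample = io.StringIO(
    "drug_smiles\tprotein_sequence\tlabel\n"
    "CCO\tMKTAYIAK\t1\n"
    "c1ccccc1\tMVLSPADK\t0\n"
    "CC(=O)O\tMKTAYIAK\t0\n"
)
label_counts = validate_tsv(sample)
print(label_counts)
```

The same function accepts an open file handle, for example `open(DATA_PATH, newline="")`.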

In [None]:
DATA_PATH = '../data/relationships.tsv'
try:
    df = pd.read_csv(DATA_PATH, sep='\t')
    # We assume columns are named 'drug_smiles', 'protein_sequence', 'label'
    # Adjust column names if your .tsv file is different
    print(f"Loaded data with {len(df)} samples.")
    print(df.head())
    print("\nData Info:")
    df.info()
    print("\nLabel distribution:")
    print(df['label'].value_counts(normalize=True))
except FileNotFoundError:
    print(f"Error: Data file not found at {DATA_PATH}")
    raise  # Without `df`, the following cells cannot run
except KeyError as e:
    print(f"Error: Missing expected column {e}. Please check your .tsv file.")
    raise

# Cell 5: Define Custom PyTorch Geometric Dataset

This is the core of our data pipeline. We create a custom `Dataset` class that integrates our `FeatureEngineer`.

1.  `__init__`: Stores the DataFrame and an instance of our `FeatureEngineer`.
2.  `len`: Returns the total number of samples.
3.  `get`: Builds one sample. For a given index `idx`:
    * It fetches the SMILES string and protein sequence.
    * It uses `self.feature_engineer.featurize()` to convert them into a PyG `Data` object (a graph).
    * It attaches the `label` to the graph object as `data.y`.
    * It catches any exception raised during featurization and returns `None`, since featurization of complex molecules can fail. These `None` entries must be filtered out when batches are assembled.
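
Because failures are swallowed silently, it is worth counting them once before training. The sketch below mirrors the try/except pattern of `get` with plain dictionaries. `toy_featurize` is a hypothetical stand-in for `FeatureEngineer.featurize()`, whose real behavior is not shown in this notebook.

```python
from typing import Callable, Optional


def toy_featurize(smiles: str, sequence: str) -> dict:
    """Hypothetical stand-in for FeatureEngineer.featurize()."""
    if not smiles or not sequence:
        raise ValueError("empty input")
    return {"num_atoms": len(smiles), "num_residues": len(sequence)}


def safe_get(row: dict, featurize: Callable[[str, str], dict]) -> Optional[dict]:
    """Return a labeled sample, or None if anything goes wrong."""
    try:
        data = featurize(row["drug_smiles"], row["protein_sequence"])
        data["y"] = float(row["label"])
        return data
    except Exception:
        return None


rows = [
    {"drug_smiles": "CCO", "protein_sequence": "MKTAYIAK", "label": 1},
    {"drug_smiles": "", "protein_sequence": "MVLSPADK", "label": 0},  # Empty SMILES
    {"drug_smiles": "CC(=O)O", "protein_sequence": "MKTAYIAK"},  # No label
]
samples = [safe_get(row, toy_featurize) for row in rows]
failed = [index for index, sample in enumerate(samples) if sample is None]
print(f"{len(failed)} of {len(rows)} rows failed: {failed}")
```

A high failure count usually points at a data or column-name problem rather than at genuinely difficult molecules.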

In [None]:
class DTIDataset(Dataset):
    def __init__(self, df, feature_engineer):
        super().__init__()
        self.df = df
        self.feature_engineer = feature_engineer

    def len(self):
        return len(self.df)

    def get(self, idx):
        row = self.df.iloc[idx]
        try:
            # Adjust column names if different
            smiles = row['drug_smiles']
            sequence = row['protein_sequence']
            label = row['label']

            # Use the feature engineer to create the graph data object
            data = self.feature_engineer.featurize(smiles, sequence)
            
            # Attach the label
            data.y = torch.tensor([label], dtype=torch.float)
            return data
        
        except Exception as e:
            #print(f"Warning: Skipping index {idx}. Failed to featurize: {e}")
            return None # Will be filtered by our custom collate function


# Cell 6: Instantiate Dataset and Create DataLoaders

1.  **Instantiate `FeatureEngineer`**: We use the settings from our `config.yaml`.
2.  **Instantiate `DTIDataset`**: We pass the DataFrame and the feature engineer to it.
3.  **Split Data**: We split the full dataset into training and validation sets.
4.  **Define `collate_fn`**: This function is crucial. It gathers a list of `Data` objects into a `Batch` and filters out any `None` values that resulted from featurization errors.
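5.  **Create `DataLoaders`**: We create `DataLoader`s for both train and validation sets, using our `collate_fn`. Note that PyG's `torch_geometric.loader.DataLoader` installs its own collater and discards a user-supplied `collate_fn`; the `None` filtering only takes effect with `torch.utils.data.DataLoader`.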
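
The split arithmetic and the `None` filtering in steps 3 and 4 can be illustrated without PyTorch. This is a sketch with plain lists: `split_indices` and `drop_failed` are hypothetical names, and the real cell uses `random_split` and `Batch.from_data_list` instead.

```python
import random


def split_indices(num_samples: int, validation_split: float, seed: int):
    """Shuffle indices with a fixed seed and cut off the validation share."""
    val_size = int(num_samples * validation_split)  # Rounds down
    indices = list(range(num_samples))
    random.Random(seed).shuffle(indices)
    return indices[val_size:], indices[:val_size]


def drop_failed(batch: list):
    """Remove None entries; return None if nothing is left."""
    kept = [item for item in batch if item is not None]
    return kept or None


train_idx, val_idx = split_indices(num_samples=1003, validation_split=0.2, seed=42)
print(len(train_idx), len(val_idx))
print(drop_failed([None, "graph_a", None, "graph_b"]))
print(drop_failed([None, None]))
```

Since the validation size is rounded down, any remainder goes to the training set, and a batch made up entirely of failed samples becomes `None`, which is why the training loop checks for it.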

In [None]:
# 1. Instantiate FeatureEngineer
feature_engineer = FeatureEngineer(config['model']['featurization'])

# 2. Instantiate Dataset
print("Initializing dataset (samples are featurized lazily, when accessed)...")
dataset = DTIDataset(df, feature_engineer)
print("Dataset initialized.")

# 3. Split Data
torch.manual_seed(RANDOM_SEED)
val_size = int(len(dataset) * VALIDATION_SPLIT)
train_size = len(dataset) - val_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
print(f"Training samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

# 4. Define custom collate function to filter None values
def collate_fn(batch):
    batch = [item for item in batch if item is not None]
    if not batch:
        return None
    return Batch.from_data_list(batch)

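# 5. Create DataLoaders
# PyG's torch_geometric.loader.DataLoader replaces any user-supplied collate_fn
# with its own collater, so we use the plain PyTorch DataLoader to keep ours.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn, num_workers=2, pin_memory=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn, num_workers=2, pin_memory=True)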

print("DataLoaders created.")

# Cell 7: Define Model, Loss, and Optimizer

1.  **Model**: We instantiate our `DTIPredictor` using the hyperparameters from `config.yaml` and move it to the `DEVICE`.
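2.  **Loss Function**: We use `BCEWithLogitsLoss`, which is standard for binary classification. It fuses the sigmoid and the binary cross-entropy into one numerically stable operation, so the model must output raw logits rather than probabilities.
3.  **Optimizer**: We use `Adam` with a small weight decay (`1e-5`) as mild L2 regularization.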
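
To see why the loss is computed on logits, compare a naive sigmoid-then-log implementation with the standard stable rewrite $\max(x, 0) - x\,y + \log(1 + e^{-|x|})$. This is a pure-Python sketch for a single sample, meant to illustrate the idea rather than reproduce PyTorch's internals.

```python
import math


def bce_with_logits(logit: float, target: float) -> float:
    """Binary cross-entropy on a raw logit, written to avoid overflow."""
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))


def naive_bce(logit: float, target: float) -> float:
    """Sigmoid followed by log; breaks once the sigmoid saturates."""
    p = 1.0 / (1.0 + math.exp(-logit))
    return -(target * math.log(p) + (1.0 - target) * math.log(1.0 - p))


# For moderate logits the two agree.
print(bce_with_logits(2.0, 1.0), naive_bce(2.0, 1.0))

# For a large logit the sigmoid rounds to exactly 1.0 and log(0) fails.
try:
    naive_bce(1000.0, 0.0)
except ValueError as error:
    print(f"naive version fails: {error}")
print(bce_with_logits(1000.0, 0.0))
```

This is also why `torch.sigmoid` is applied only when computing metrics, not before the loss.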

In [None]:
# 1. Instantiate Model
model_params = config['model']['hyperparameters']
model = DTIPredictor(**model_params).to(DEVICE)

# 2. Define Loss Function
criterion = nn.BCEWithLogitsLoss() # For binary classification

# 3. Define Optimizer
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE, weight_decay=1e-5)

print(f"Model loaded on {DEVICE}.")
print(model)

# Cell 8: Training and Validation Loop

This is the main training logic. For each epoch, we:

1.  **Train (`model.train()`):**
    * Iterate through the `train_loader`.
    * Move the batch of graphs to the `DEVICE`.
    * Perform the forward pass: `output = model(batch)`.
    * Calculate the loss.
    * Perform the backward pass (`loss.backward()`) and update weights (`optimizer.step()`).

2.  **Validate (`model.eval()`):**
    * Iterate through the `val_loader` with `torch.no_grad()`.
    * Collect all predictions (`all_preds`) and true labels (`all_labels`).
    * Calculate metrics: Loss, Accuracy, ROC-AUC, and F1-Score.

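3.  **Save Best Model**: We track the best validation AUC and save the model weights (`.pt` file) only when it improves. This does not stop the model from overfitting, but it ensures the saved checkpoint comes from the epoch that performed best on the validation set rather than from the last epoch.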
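
The two headline metrics can be made concrete with a tiny pure-Python version. ROC-AUC is the probability that a randomly chosen positive is scored above a randomly chosen negative (ties count half), and F1 is computed after thresholding at 0.5. This is a sketch for intuition only: it compares every positive with every negative, so the notebook relies on scikit-learn for real data.

```python
def roc_auc(labels, scores) -> float:
    """Fraction of (positive, negative) pairs ranked correctly."""
    positives = [s for l, s in zip(labels, scores) if l == 1]
    negatives = [s for l, s in zip(labels, scores) if l == 0]
    if not positives or not negatives:
        raise ValueError("ROC-AUC needs at least one sample of each class")
    wins = sum(
        1.0 if p > n else 0.5 if p == n else 0.0
        for p in positives
        for n in negatives
    )
    return wins / (len(positives) * len(negatives))


def f1_at_threshold(labels, scores, threshold: float = 0.5) -> float:
    """F1 score after converting scores to 0/1 predictions."""
    preds = [1 if s > threshold else 0 for s in scores]
    tp = sum(1 for l, p in zip(labels, preds) if l == 1 and p == 1)
    fp = sum(1 for l, p in zip(labels, preds) if l == 0 and p == 1)
    fn = sum(1 for l, p in zip(labels, preds) if l == 1 and p == 0)
    denominator = 2 * tp + fp + fn
    return 2 * tp / denominator if denominator else 0.0


labels = [0, 0, 1, 1]
scores = [0.1, 0.4, 0.35, 0.8]
print(f"AUC: {roc_auc(labels, scores):.4f}")
print(f"F1:  {f1_at_threshold(labels, scores):.4f}")
```

AUC is independent of the threshold while F1 is not, which is one reason to select checkpoints on AUC when the label distribution is imbalanced.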

In [None]:
MODEL_SAVE_DIR = '../models'
MODEL_SAVE_PATH = os.path.join(MODEL_SAVE_DIR, 'dti_model.pt')

# Ensure the model directory exists
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)

best_val_auc = 0.0
print("Starting training...")

for epoch in range(EPOCHS):
    # --- Training Phase ---
    model.train()
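    train_loss = 0.0
    train_samples = 0  # Graphs actually seen (failed samples are dropped from batches)
    train_loop = tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]", leave=False)

    for batch in train_loop:
        if batch is None: # Skip bad batches
            continue

        batch = batch.to(DEVICE)
        optimizer.zero_grad()

        # Flatten so logits and labels share shape [num_graphs]
        # (assumes the model emits one logit per graph)
        output = model(batch).reshape(-1)
        loss = criterion(output, batch.y.reshape(-1))

        loss.backward()
        optimizer.step()
        train_loss += loss.item() * batch.num_graphs
        train_samples += batch.num_graphs

    # Average over the samples that were processed, not the nominal dataset size
    avg_train_loss = train_loss / max(train_samples, 1)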

    # --- Validation Phase ---
    model.eval()
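    val_loss = 0.0
    val_samples = 0  # Graphs actually evaluated
    all_preds = []
    all_labels = []
    val_loop = tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]", leave=False)

    with torch.no_grad():
        for batch in val_loop:
            if batch is None: # Skip bad batches
                continue

            batch = batch.to(DEVICE)
            # Flatten so logits and labels share shape [num_graphs]
            # (assumes the model emits one logit per graph)
            output = model(batch).reshape(-1)
            labels = batch.y.reshape(-1)
            loss = criterion(output, labels)
            val_loss += loss.item() * batch.num_graphs
            val_samples += batch.num_graphs

            preds = torch.sigmoid(output)
            all_preds.append(preds.cpu())
            all_labels.append(labels.cpu())

    # Average over the samples that were processed, not the nominal dataset size
    avg_val_loss = val_loss / max(val_samples, 1)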
    
    # Calculate metrics
    if not all_labels or not all_preds:
        print(f"Epoch {epoch+1}/{EPOCHS} - Validation data empty. Skipping metrics.")
        continue
        
    all_preds = torch.cat(all_preds).numpy().flatten()
    all_labels = torch.cat(all_labels).numpy().flatten()
    
    val_auc = roc_auc_score(all_labels, all_preds)
    val_preds_binary = (all_preds > 0.5).astype(int)
    val_accuracy = accuracy_score(all_labels, val_preds_binary)
    val_f1 = f1_score(all_labels, val_preds_binary)

    print(f"Epoch {epoch+1}/{EPOCHS} | Train Loss: {avg_train_loss:.4f} | Val Loss: {avg_val_loss:.4f} | Val AUC: {val_auc:.4f} | Val F1: {val_f1:.4f}")

    # Save the best model
    if val_auc > best_val_auc:
        best_val_auc = val_auc
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print(f"  -> New best model saved to {MODEL_SAVE_PATH} (Val AUC: {best_val_auc:.4f})")

print("\nTraining complete.")
print(f"Best validation AUC achieved: {best_val_auc:.4f}")
print(f"Best model saved at: {MODEL_SAVE_PATH}")

# Cell 9: Final Check

This cell just verifies that the model file was created in the correct location. Once this notebook is run, your Streamlit application will have the `dti_model.pt` file it needs to load the `CoreProcessor`.

In [None]:
if os.path.exists(MODEL_SAVE_PATH):
    print(f"SUCCESS: Model file found at {MODEL_SAVE_PATH}")
    print(f"File size: {os.path.getsize(MODEL_SAVE_PATH) / (1024*1024):.2f} MB")
else:
    print(f"ERROR: Model file was NOT created at {MODEL_SAVE_PATH}")