In [None]:
# In: notebooks/06_Train_Baseline_Model.ipynb
# Purpose: Train the baseline LSTM regression model to predict next CDR score
#          using pre-computed features and the prepared data splits.
#          Loads configuration (feature lists, etc.) from the relevant NB04 W&B run.

# Notebook 06: Train Baseline LSTM Model

**Purpose:** Train the baseline LSTM regression model, defined in `src/models.py`, to predict the next CDR score based on longitudinal clinical/demographic features.

**Workflow:**
1.  **Setup:** Load base configuration (`config.json`), define training hyperparameters (`HP`), define paths to data splits and preprocessors.
2.  **Fetch Prior Config:** Use the W&B API to load the definitive configuration (feature lists, preprocessing column lists) logged by the successful run of Notebook 04 (`04_Fit_Preprocessors.ipynb`). This ensures consistency.
3.  **Initialize W&B Run:** Start a *new* W&B run specifically for this training job, logging all hyperparameters (`HP`) including the fetched configuration details.
4.  **Setup Device:** Detect and set the appropriate device (GPU or CPU) for PyTorch.
5.  **Load Data:** Instantiate `OASISDataset` (using the fetched config) for train and validation splits and create PyTorch `DataLoader`s using the custom `pad_collate_fn`.
6.  **Define Model:** Instantiate the `BaselineLSTMRegressor` model, loss function (`MSELoss`), and optimizer (`Adam`).
7.  **Train & Validate:** Run the main training loop:
     * Iterate through epochs.
     * Perform training step (forward pass, loss calculation, backward pass, optimizer step).
     * Perform validation step (forward pass, loss calculation, no gradients).
     * Calculate metrics (MSE, MAE, R2) for train and validation sets.
     * Log metrics to W&B.
     * Implement model checkpointing (save best model based on validation loss) locally and as a W&B artifact.
     * Implement early stopping based on validation loss patience.
8.  **Finalize:** Finish the W&B run, recording the best validation loss as a summary metric.

**Input:**
* `cohort_{train|validation}.parquet` (from NB 03)
* Scaler and imputer fitted by NB 04. The file names are derived from `HP['scaling_strategy']` and `HP['imputation_strategy']`, e.g. `standardscaler.joblib` and `simple_imputer_median.joblib`.
* Configuration logged by the NB 04 run (fetched via the W&B API)
* `src/datasets.py`, `src/models.py`

**Output:**
* Trained model checkpoints (`best_model_epoch_<N>.pth`), saved locally under `<output_dir_base>/06_Train_Baseline_Model/<run name>/`.
* W&B run: logs hyperparameters, metrics (loss, MAE, R2), and the best model checkpoint as an artifact.


## Setup: Imports, Paths, Config, Hyperparameters

In [None]:
# --- Standard Libraries & Imports ---
import json
import os
import random
import sys
import time
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import wandb
from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error
from torch.utils.data import DataLoader

In [None]:
# --- Add the project root to sys.path so the local `src` package is importable ---
# The notebook lives one level below the project root (notebooks/), so '..' is the
# directory that contains `src/`. Import failures are reported and then re-raised:
# calling exit() here would shut down the Jupyter kernel instead of simply stopping
# "Run All" at this cell with a readable traceback.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)
    print(f"Added {module_path} to sys.path")

try:
    # --- Import custom classes/functions ---
    from src.datasets import OASISDataset, pad_collate_fn
    from src.models import BaselineLSTMRegressor  # Import the model definition
    print("Successfully imported custom Dataset and Model.")
except ModuleNotFoundError:
    # Re-raising keeps the traceback, which names the module that is actually missing
    # (it may be a dependency of src/, not src/ itself).
    print("Error: Could not import from src directory.")
    print("Ensure src/datasets.py and src/models.py exist.")
    raise
except Exception as e:
    print(f"An unexpected error occurred during import: {e}")
    raise

In [None]:
# --- Config Loading, Hyperparameters & Fetching Prior Config ---
# 1. Load config.json (output paths, W&B project/entity).
# 2. Define the training hyperparameters (HP) and paths to the NB03 splits / NB04 preprocessors.
# 3. Fetch the config logged by the NB04 W&B run and copy its feature and preprocessing
#    column lists into HP, so they are logged again with this training run.
# 4. Seed the random number generators.
# Errors propagate (instead of calling exit(), which shuts down a Jupyter kernel), so
# "Run All" stops at this cell with a full traceback.
print("\n--- Loading Configuration, Setting Hyperparameters, Fetching Prior Config ---")
CONFIG_PATH = Path('../config.json')
nb04_config = {}  # To store config fetched from NB04 run

# W&B run of NB04 that logged the feature/preprocessing config. Either a bare run id
# (combined with WANDB_ENTITY/WANDB_PROJECT below) or a full "entity/project/run_id" path.
# *** IMPORTANT: replace the placeholder, or set the NB04_RUN_ID environment variable ***
NB04_RUN_ID_PLACEHOLDER = "RUN_PATH_FROM_NB04"
nb04_run_id = os.environ.get("NB04_RUN_ID", NB04_RUN_ID_PLACEHOLDER)
WANDB_API_TIMEOUT_SEC = 19  # Increase if the W&B API is slow to respond

# --- Load base config (paths, W&B project/entity) ---
PROJECT_ROOT = CONFIG_PATH.parent.resolve()
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    base_config = json.load(f)
WANDB_PROJECT = base_config['wandb']['project_name']
WANDB_ENTITY = base_config['wandb'].get('entity', None)
print("Base configuration loaded successfully.")

# --- Define Training Hyperparameters ---
HP = {
    'batch_size': 32,
    'learning_rate': 1e-4,
    'epochs': 50,
    'lstm_hidden_size': 128,
    'lstm_num_layers': 2,
    'lstm_dropout_prob': 0.3,
    'num_workers': 0,
    'random_seed': 42,
    # Preprocessing strategies below should ideally match NB04 config; used primarily to build the preprocessor paths
    'imputation_strategy': 'median',
    'scaling_strategy': 'StandardScaler',
}
print("Training hyperparameters set.")

# --- Define Paths ---
OUTPUT_DIR_BASE = PROJECT_ROOT / base_config['data']['output_dir_base']
NB03_OUTPUT_DIR = OUTPUT_DIR_BASE / "03_Feature_Engineering_Splitting"
NB04_OUTPUT_DIR = OUTPUT_DIR_BASE / "04_Fit_Preprocessors"
NB06_OUTPUT_DIR = OUTPUT_DIR_BASE / "06_Train_Baseline_Model"
NB06_OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

TRAIN_DATA_PATH = NB03_OUTPUT_DIR / "cohort_train.parquet"
VAL_DATA_PATH = NB03_OUTPUT_DIR / "cohort_validation.parquet"
TEST_DATA_PATH = NB03_OUTPUT_DIR / "cohort_test.parquet"  # Defined for reference; evaluated in NB07
SCALER_PATH = NB04_OUTPUT_DIR / f"{HP['scaling_strategy'].lower()}.joblib"  # e.g. standardscaler.joblib
IMPUTER_PATH = NB04_OUTPUT_DIR / f"simple_imputer_{HP['imputation_strategy']}.joblib"

# NOTE(review): NB04 may have saved the scaler under a snake_case name
# (e.g. `standard_scaler.joblib`) rather than the lower-cased name built above.
# If only the snake_case file exists, use it; otherwise keep SCALER_PATH and let the
# Dataset raise FileNotFoundError in the data-loading cell.
if not SCALER_PATH.exists():
    scaler_snake_name = ''.join(
        f"_{ch.lower()}" if ch.isupper() and idx > 0 else ch.lower()
        for idx, ch in enumerate(HP['scaling_strategy'])
    )
    alt_scaler_path = NB04_OUTPUT_DIR / f"{scaler_snake_name}.joblib"
    if alt_scaler_path.exists():
        print(f"Scaler not found at {SCALER_PATH}; using {alt_scaler_path} instead.")
        SCALER_PATH = alt_scaler_path
    else:
        print(f"Warning: scaler file not found at {SCALER_PATH} (nor at {alt_scaler_path}).")
if not IMPUTER_PATH.exists():
    print(f"Warning: imputer file not found at {IMPUTER_PATH}.")

# --- Fetch Configuration from NB04 W&B Run ---
print("\nFetching configuration from NB04 W&B Run...")
if not WANDB_PROJECT or not WANDB_ENTITY:
    raise ValueError("WANDB_PROJECT or WANDB_ENTITY not defined. Check config loading.")
if nb04_run_id == NB04_RUN_ID_PLACEHOLDER:
    raise ValueError(
        "nb04_run_id is still the placeholder. Set it (or the NB04_RUN_ID environment "
        "variable) to the W&B run id of the successful NB04 run."
    )
# Accept either a full "entity/project/run_id" path or a bare run id.
if nb04_run_id.count('/') == 2:
    nb04_run_path = nb04_run_id
else:
    nb04_run_path = f"{WANDB_ENTITY}/{WANDB_PROJECT}/{nb04_run_id}"

try:
    # Requires W&B credentials (`wandb login` or WANDB_API_KEY).
    api = wandb.Api(timeout=WANDB_API_TIMEOUT_SEC)
    nb04_run = api.run(nb04_run_path)
    nb04_config = nb04_run.config  # The entire config dict logged by NB04
except Exception:
    print(f"Error fetching config from W&B run '{nb04_run_path}'.")
    print("This notebook requires the configuration (esp. feature lists) logged by NB04.")
    print("Please ensure NB04 ran successfully, logged the config correctly, and update 'nb04_run_id' (or set NB04_RUN_ID).")
    raise
print(f"Successfully fetched config from NB04 run: {nb04_run_path}")

# Copy the feature/preprocessing lists into HP so they are logged with this run.
# .get() tolerates missing sections; emptiness is validated right below.
features_dict = nb04_config.get('features', {})
preprocess_dict = nb04_config.get('preprocess', {})

HP['features_time_varying'] = features_dict.get('time_varying', [])
HP['features_static'] = features_dict.get('static', [])
HP['input_size'] = len(HP['features_time_varying']) + len(HP['features_static'])
HP['preprocess_scaling_cols'] = preprocess_dict.get('scaling_cols', [])
HP['preprocess_imputation_cols'] = preprocess_dict.get('imputation_cols', [])
HP['source_config_run_id'] = nb04_run_path  # Track provenance

# Validate before reporting, so a bad config is reported as such (not as a fetch error).
if not HP['features_time_varying'] or not HP['features_static'] or HP['input_size'] == 0:
    raise ValueError(
        f"Loaded feature lists from NB04 config (run '{nb04_run_path}') are empty or invalid."
    )

print("-" * 20)
print("Config fetched from W&B (NB04 run)")
print(f"  Time Varying ({len(HP['features_time_varying'])}): {HP['features_time_varying']}")
print(f"  Static ({len(HP['features_static'])}): {HP['features_static']}")
print(f"  Input size (number of model input features): {HP['input_size']}")
print("-" * 20)

# --- Seed for reproducibility ---
random.seed(HP['random_seed'])
np.random.seed(HP['random_seed'])
torch.manual_seed(HP['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(HP['random_seed'])

## Initialize W&B Training Run

Start a new Weights & Biases run specifically for this training experiment. We log all hyperparameters defined in the `HP` dictionary, which now includes the feature lists and preprocessing details fetched from the NB04 run, ensuring full traceability. A unique directory is created for saving model checkpoints locally for this run.


In [None]:
# --- Initialize W&B Run for THIS Training Job ---
# Starts a dedicated W&B run for this training job, logging HP as its config (HP
# includes the feature lists and source run fetched from NB04), and creates a per-run
# folder for local checkpoints. If anything fails, training continues locally: `run`
# is reset to None so every later `if run:` guard skips W&B calls, and checkpoints go
# to a timestamped local_run_* folder.
print("\n--- Initializing Weights & Biases Run for Training ---")
run = None
try:
    run = wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        job_type="train-baseline-lstm",
        name=f"train-lstm-hs{HP['lstm_hidden_size']}-nl{HP['lstm_num_layers']}-dp{HP['lstm_dropout_prob']}-{time.strftime('%Y%m%d-%H%M')}",
        config=HP  # Log all hyperparameters, including fetched feature lists and source run ID
    )
    print(f"W&B run '{run.name}' initialized successfully. View at: {run.url}")
    # Output directory for this specific run's checkpoints
    run_output_dir = NB06_OUTPUT_DIR / run.name
    run_output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Checkpoints for this run will be saved to: {run_output_dir}")

except Exception as e:
    print(f"Error initializing W&B: {e}")
    print("Proceeding without W&B logging.")
    if run is not None:
        # wandb.init succeeded but a later step failed: close the run so the message
        # above holds and later cells do not log into a half-configured run.
        run.finish()
        run = None
    run_output_dir = NB06_OUTPUT_DIR / f"local_run_{time.strftime('%Y%m%d-%H%M%S')}"
    run_output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Proceeding locally. Checkpoints will be saved to: {run_output_dir}")

## Setup Device

Check for CUDA availability and set the PyTorch device accordingly (GPU or CPU).


In [None]:
# --- Setup Device (GPU/CPU) ---
# Use the first CUDA GPU whenever PyTorch can see one; otherwise fall back to the CPU.
print("\n--- Setting up Device ---")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if device.type == "cuda":
    print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
else:
    print("CUDA not available. Using CPU.")

# Record the chosen device alongside the other hyperparameters of the active W&B run.
if run:
    run.config.update({'device': str(device)}, allow_val_change=True)

## Load Data & Create DataLoaders

Instantiate the custom `OASISDataset` for the training and validation data splits. Crucially, pass the configuration dictionary (`nb04_config`) fetched from the NB04 W&B run to ensure the Dataset uses the correct feature lists and applies the corresponding preprocessors internally. Then, wrap these datasets in PyTorch `DataLoader`s, specifying batch size and using the `pad_collate_fn` to handle variable sequence lengths.


In [None]:
# --- Instantiate Datasets and DataLoaders ---
# Builds the train/validation OASISDataset objects from the NB03 splits. Each dataset gets
# the NB04 scaler/imputer paths and the config fetched from the NB04 W&B run
# (`nb04_config`). The datasets are then wrapped in DataLoaders that batch variable-length
# sequences with pad_collate_fn. On failure the W&B run (if any) is finished and the
# error re-raised, so "Run All" stops here. (exit() would shut down the Jupyter kernel.)
print("\n--- Loading Data and Creating DataLoaders ---")


def _make_loader(dataset, shuffle):
    """Wrap `dataset` in a DataLoader using the batching settings shared via HP."""
    return DataLoader(
        dataset,
        batch_size=HP['batch_size'],
        shuffle=shuffle,
        collate_fn=pad_collate_fn,
        num_workers=HP['num_workers'],
        # persistent_workers is only meaningful (and only allowed) with worker processes.
        persistent_workers=(HP['num_workers'] > 0),
    )


try:
    print("Instantiating training dataset...")
    train_dataset = OASISDataset(TRAIN_DATA_PATH, SCALER_PATH, IMPUTER_PATH, config=nb04_config)

    print("Instantiating validation dataset...")
    val_dataset = OASISDataset(VAL_DATA_PATH, SCALER_PATH, IMPUTER_PATH, config=nb04_config)

    # An empty split would make an epoch run zero batches. Fail here with a clear
    # message instead of silently training or validating on nothing.
    for split_name, split_dataset in (("training", train_dataset), ("validation", val_dataset)):
        if len(split_dataset) == 0:
            raise ValueError(f"The {split_name} dataset is empty. Check the NB03 split outputs.")

    train_loader = _make_loader(train_dataset, shuffle=True)
    val_loader = _make_loader(val_dataset, shuffle=False)
    print("Train and Validation DataLoaders created.")
    # input_size was calculated from the fetched NB04 feature lists and stored in HP
    print(f"Input size (num features) for model: {HP['input_size']}")

except FileNotFoundError as e:
    print(f"Error: Data file or preprocessor file not found: {e}")
    print("Ensure notebooks 01-04 ran successfully and paths in config are correct.")
    if run:
        run.finish()
    raise
except Exception as e:
    print(f"Error creating Datasets/DataLoaders: {e}")
    if run:
        run.finish()
    raise

## Define Model, Loss, Optimizer

Instantiate the `BaselineLSTMRegressor` model using the input size determined from the fetched configuration. Define the Mean Squared Error loss function (suitable for regression) and the Adam optimizer. When a W&B run is active, `wandb.watch` also logs model gradients and parameters during training.


In [None]:



# --- Instantiate Model, Loss, Optimizer ---
# Builds the BaselineLSTMRegressor with the input size derived from the NB04 feature
# lists (HP['input_size']), an MSE loss and an Adam optimizer. On failure the W&B run
# (if any) is finished and the error re-raised, so "Run All" stops here instead of
# calling exit(), which would shut down the Jupyter kernel.
print("\n--- Initializing Model, Loss Function, and Optimizer ---")
try:
    model = BaselineLSTMRegressor(
        input_size=HP['input_size'],
        hidden_size=HP['lstm_hidden_size'],
        num_layers=HP['lstm_num_layers'],
        dropout_prob=HP['lstm_dropout_prob'],
    )
    model.to(device)  # Move model to GPU if available
    print("Model instantiated:")
    print(model)

    # Mean Squared Error: regression loss for the (continuous) next-CDR target
    criterion = nn.MSELoss()
    print(f"Loss function: {type(criterion).__name__}")

    optimizer = optim.Adam(model.parameters(), lr=HP['learning_rate'])
    print(f"Optimizer: {type(optimizer).__name__} with LR={HP['learning_rate']}")

    # Let W&B log gradients and parameters every 100 batches
    if run:
        wandb.watch(model, log='all', log_freq=100)

except Exception as e:
    print(f"Error initializing model/loss/optimizer: {e}")
    if run:
        run.finish()
    raise


## Train and Validate Model

This section contains the main training loop over the specified number of epochs.

**Inside each epoch:**
* **Training Phase:**
    * Set the model to `train()` mode.
    * Iterate through batches from `train_loader`.
    * Move sequences and targets to the selected device.
    * Perform the forward pass, calculate the loss, run the backward pass, and step the optimizer.
    * Accumulate loss and predictions/targets for the epoch metrics.
    * Log the batch loss to W&B every 5 batches (when a run is active).
* **Validation Phase:**
    * Set the model to `eval()` mode.
    * Disable gradient tracking.
    * Iterate through batches from `val_loader`.
    * Perform the forward pass and calculate the loss.
    * Accumulate loss and predictions/targets for the epoch metrics.
* **End of Epoch:**
    * Calculate and print the average train/validation loss, MAE, and R2.
    * Log the epoch metrics to W&B.
    * **Checkpointing:** If the validation loss improved, save the model's state dictionary locally and log it as a W&B artifact with the `best` alias.
    * **Early Stopping:** If the validation loss has not improved for `patience` consecutive epochs, stop training.


In [None]:
# --- Training and Validation Loop ---
# Trains for up to HP['epochs'] epochs, validating after each one. Whenever the epoch
# validation loss improves, the model's state_dict is saved to run_output_dir and, if a
# W&B run is active, logged as a 'model' artifact with the 'best' alias. Training stops
# early after `patience` consecutive epochs without improvement.
#
# Variables left for later cells:
#   best_val_loss        -- lowest epoch validation loss seen (inf if none was recorded)
#   best_checkpoint_path -- Path of the best saved checkpoint, or None if none was saved
print("\n--- Starting Model Training ---")

# Progress bars: tqdm.auto renders a widget in notebooks and a text bar elsewhere.
try:
    from tqdm.auto import tqdm
except ImportError:
    print("Consider installing tqdm for progress bars: pip install tqdm")

    class tqdm:  # noqa: N801 -- stand-in for tqdm.tqdm
        """No-op progress bar: iterates the wrapped iterable and ignores postfix updates.

        `set_postfix` must exist because the loop calls it on every batch; returning a
        bare iterable would raise AttributeError there and abort every batch.
        """

        def __init__(self, iterable, **kwargs):
            self._iterable = iterable

        def __iter__(self):
            return iter(self._iterable)

        def set_postfix(self, *args, **kwargs):
            pass


# Fail fast if an earlier cell did not run (e.g. after a kernel restart).
required_vars = ['model', 'criterion', 'optimizer', 'train_loader', 'val_loader',
                 'device', 'HP', 'run_output_dir', 'run']
missing_vars = [name for name in required_vars if name not in globals()]
if missing_vars:
    raise NameError(
        f"Required variables not defined: {missing_vars}. "
        "Please ensure the previous cells executed correctly."
    )

best_val_loss = float('inf')  # Lowest epoch validation loss seen so far
best_checkpoint_path = None   # Checkpoint file saved for best_val_loss
epochs_no_improve = 0         # Consecutive epochs without improvement (early stopping)
patience = HP.get('early_stopping_patience', 10)  # Epochs without improvement before stopping
if run:
    run.config.update({'early_stopping_patience': patience}, allow_val_change=True)


def _run_epoch(loader, epoch, train):
    """Run one training or validation pass over `loader`.

    Args:
        loader: DataLoader yielding (sequences_padded, lengths, targets, masks) batches
            from pad_collate_fn.
        epoch: 0-based epoch index (progress-bar label and W&B batch-log x value).
        train: True -> train mode, back-propagate, step the optimizer and log the batch
            loss to W&B every 5 batches. False -> eval mode with gradients disabled.

    Returns:
        (avg_loss, mae, r2) for the batches that completed. avg_loss weights each
        batch's mean loss by its number of target elements, so with the mean-reduced
        MSELoss it is the MSE over the whole epoch. Batches that raise are reported and
        excluded from all metrics. If no data was processed, all metrics are NaN (not 0),
        so an empty or fully failing validation pass never counts as an improvement.
    """
    phase_label, phase_name = ("Train", "training") if train else ("Val", "validation")
    model.train(train)  # train(False) is equivalent to model.eval()
    loss_sum = 0.0      # sum of (mean batch loss * number of target elements)
    n_elements = 0
    targets_all, preds_all = [], []

    pbar = tqdm(enumerate(loader), total=len(loader), desc=f"Epoch {epoch+1} {phase_label}")
    with torch.set_grad_enabled(train):
        for i, batch in pbar:
            try:
                sequences_padded, lengths, targets, _masks = batch
                sequences_padded = sequences_padded.to(device)
                targets = targets.to(device)
                # `lengths` stays on its original device; presumably the model packs
                # sequences with it (which needs CPU lengths) -- confirm in src/models.py.
                if train:
                    optimizer.zero_grad()
                predictions = model(sequences_padded, lengths)
                loss = criterion(predictions, targets)
                if train:
                    loss.backward()
                    optimizer.step()
                batch_loss = loss.item()
                batch_targets = targets.detach().cpu().numpy().ravel()
                batch_preds = predictions.detach().cpu().numpy().ravel()
            except Exception as e:
                print(f"\nError during {phase_name} batch {i}: {e}")
                continue  # Skip batch; it is excluded from the epoch metrics

            loss_sum += batch_loss * batch_targets.size
            n_elements += batch_targets.size
            targets_all.append(batch_targets)
            preds_all.append(batch_preds)

            pbar.set_postfix({'batch_loss': batch_loss})
            if train and run and (i % 5 == 0):  # Log every 5 batches
                wandb.log({'train/batch_loss': batch_loss, 'epoch': epoch + i / len(loader)})

    if n_elements == 0:
        print(f"Warning: no {phase_name} data was processed successfully in epoch {epoch+1}.")
        nan = float('nan')
        return nan, nan, nan

    y_true = np.concatenate(targets_all)
    y_pred = np.concatenate(preds_all)
    return loss_sum / n_elements, mean_absolute_error(y_true, y_pred), r2_score(y_true, y_pred)


for epoch in range(HP['epochs']):
    print(f"\n--- Epoch {epoch+1}/{HP['epochs']} ---")
    epoch_start_time = time.time()

    avg_train_loss, train_mae, train_r2 = _run_epoch(train_loader, epoch, train=True)
    avg_val_loss, val_mae, val_r2 = _run_epoch(val_loader, epoch, train=False)

    # --- End of Epoch ---
    epoch_duration = time.time() - epoch_start_time
    print(f"Epoch {epoch+1}/{HP['epochs']} finished in {epoch_duration:.2f}s.")
    print(f"  Train Loss: {avg_train_loss:.4f} | Train MAE: {train_mae:.4f} | Train R2: {train_r2:.4f}")
    print(f"  Val Loss:   {avg_val_loss:.4f} | Val MAE:   {val_mae:.4f} | Val R2:   {val_r2:.4f}")

    if run:
        wandb.log({
            'epoch': epoch + 1,
            'train/epoch_loss': avg_train_loss,
            'train/epoch_mae': train_mae,
            'train/epoch_r2': train_r2,
            'val/epoch_loss': avg_val_loss,
            'val/epoch_mae': val_mae,
            'val/epoch_r2': val_r2,
            'epoch_duration_sec': epoch_duration
        })

    # --- Model Checkpointing (a NaN validation loss never compares as an improvement) ---
    if avg_val_loss < best_val_loss:
        best_val_loss = avg_val_loss
        epochs_no_improve = 0
        checkpoint_path = run_output_dir / f"best_model_epoch_{epoch+1}.pth"
        try:
            torch.save(model.state_dict(), checkpoint_path)
            best_checkpoint_path = checkpoint_path
            print(f"  Validation loss improved to {best_val_loss:.4f}. Saving model to {checkpoint_path}")

            if run:
                print("  Logging checkpoint artifact to W&B...")
                model_artifact = wandb.Artifact(
                    'baseline-lstm-model-checkpoint',
                    type='model',
                    description=f"Model checkpoint at Epoch {epoch+1} with best val_loss: {best_val_loss:.4f}",
                    metadata={'epoch': epoch+1, 'val_loss': best_val_loss,
                              'val_mae': val_mae, 'val_r2': val_r2},
                )
                model_artifact.add_file(str(checkpoint_path))
                # Each improvement creates a new artifact version; the 'best' alias moves to it.
                run.log_artifact(model_artifact, aliases=['best', f'epoch_{epoch+1}'])
                print("  Checkpoint artifact logged.")

        except Exception as e:
            print(f"  Error saving checkpoint: {e}")

    else:
        epochs_no_improve += 1
        print(f"  Validation loss did not improve from {best_val_loss:.4f} ({epochs_no_improve}/{patience}).")

    # --- Early Stopping Check ---
    if epochs_no_improve >= patience:
        print(f"\nEarly stopping triggered after {patience} epochs with no improvement.")
        break

## End of Training

Report the best (lowest) validation loss achieved during training.

In [None]:
# --- End of Training ---
# Report the lowest epoch-level validation loss reached by the training loop.
print("\n--- Training Complete ---")
print("Best validation loss achieved: {:.4f}".format(best_val_loss))

In [None]:
# --- Final step: load the best model and evaluate it on the test set (better done in a separate notebook, e.g. NB07) ---
# print("\nLoading best model for final test set evaluation...")
# best_model_path = run_output_dir / "best_model_epoch_<N>.pth" # checkpoint from the epoch with the lowest val loss
# if best_model_path.exists():
#      model.load_state_dict(torch.load(best_model_path, map_location=device))
#      print("Best model loaded.")
#      # ... Add test set evaluation logic here ...
# else:
#      print("Warning: Best model checkpoint not found.")

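## Finish W&B Run

Record the best validation loss in the W&B run summary and close the run. The best model checkpoint (`best_model_epoch_*.pth`) was already saved locally during training and, when a W&B run is active, logged as an artifact with the `best` alias.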

In [None]:
# --- Finish W&B Run ---
# Store the best validation loss in the run summary, then close the run so pending
# metrics and artifacts are uploaded. Without an active run, just say so.
print("\n--- Finishing W&B run ---")
if not run:
    print("No active W&B run to finish.")
else:
    run.summary["best_val_loss"] = best_val_loss
    run.finish()
    print("W&B run finished.")

print("\nScript execution finished.")