# Notebook 06: Train Baseline LSTM Model (OASIS-2)

**Project Phase:** 1 (Model Training - Baseline)
**Dataset:** OASIS-2 Longitudinal MRI & Clinical Data

**Purpose:**
This notebook trains the `BaselineLSTMRegressor` (defined in `src/models.py`) to predict the Clinical Dementia Rating (CDR) score at a subject's next visit. It uses only longitudinal tabular clinical and demographic features derived from the OASIS-2 dataset; no MRI data is used. The model serves as the baseline against which more complex, multimodal architectures are evaluated.

**Workflow:**
1.  **Setup:** Import necessary libraries, configure `sys.path` for `src/` utilities, and load the main project configuration (`config.json`). Define training hyperparameters (`HP`).
2.  **Path Resolution:** Use the `get_dataset_paths` utility to resolve paths for input data splits (`train`, `validation`, `test` Parquet files from Notebook 03).
3.  **W&B Initialization & Artifact/Config Ingestion:**
    * Initialize a new Weights & Biases (W&B) run for this training experiment using the `initialize_wandb_run` utility.
    * Consume the versioned **fitted preprocessor W&B Artifacts** (e.g., `scaler_standard_oasis2:latest`, `simple_imputer_median_oasis2:latest`) that were produced and logged by Notebook 04. This step downloads the `.joblib` preprocessor files.
    * From one of these consumed preprocessor artifacts, identify the W&B Run from Notebook 04 that **produced** them using `artifact.logged_by()`.
    * Fetch the **authoritative `features` (time-varying, static) and `preprocess` (imputation/scaling columns, strategies) configurations** directly from this producer Notebook 04 run's W&B config. This becomes the `config_for_dataset` passed to `OASISDataset`.
    * Update the current training run's `HP['input_size']` based on the fetched feature lists.
4.  **Setup Device:** Detect and set the appropriate PyTorch device (CPU, CUDA, MPS).
5.  **Load Data & Create DataLoaders:** Instantiate `OASISDataset` for training and validation splits, passing the local paths to the downloaded preprocessor artifacts and the definitive `config_for_dataset`. MRI data is explicitly excluded (`include_mri=False`). Wrap datasets in PyTorch `DataLoader`s using `pad_collate_fn`.
6.  **Define Model, Loss, Optimizer:** Instantiate `BaselineLSTMRegressor`, `MSELoss` criterion, and `Adam` optimizer using parameters from `HP`.
7.  **Train Model:** Execute the main training loop by calling the `train_model` utility function from `src/training_utils.py`. This utility handles epoch iteration, training/validation phases (using `evaluate_model`), W&B metric logging, best model checkpointing (local save + W&B artifact), periodic checkpointing, and early stopping.
8.  **Evaluate Best Model on Test Set:**
    * Load the best model checkpoint saved during training (identified by `train_model`).
    * Instantiate `OASISDataset` and `DataLoader` for the test set.
    * Use the `evaluate_model` utility to calculate performance metrics (Loss, MAE, R2, MSE) on the test set.
    * Log these test metrics to the W&B run's summary.
9.  **Finalize W&B Run.**

**Input:**
* `config.json`: Main project configuration file.
* W&B Artifact Names for training, validation, and test data splits (e.g., `cohort-split-train_oasis2:latest` - from Notebook 03).
* W&B Artifact Names for fitted preprocessors (Scaler & Imputer - from Notebook 04).
* `src/` modules: `datasets.py`, `models.py`, `training_utils.py`, `evaluation_utils.py`, `wandb_utils.py`, `paths_utils.py`.

**Output:**
* **Local Files (in run-specific output directory for this NB06 run, e.g., `notebooks/outputs/06_Train_Baseline_Model_OASIS2/<run_name>/`):**
    * Trained `BaselineLSTMRegressor` model checkpoints (`best_model_epoch_X.pth`, `model_epoch_X_periodic.pth`).
* **W&B Run:**
    * Logged hyperparameters (`HP`), including source NB04 run ID and resolved feature/preprocessing configurations.
    * Per-epoch training and validation metrics.
    * Best model checkpoint logged as a W&B Artifact (e.g., `Training-NB06-BaselineLSTM-OASIS2-checkpoint:best`).
    * Periodic model checkpoints logged as W&B Artifacts.
    * Final test set performance metrics in the run's summary.

In [None]:
# In: notebooks/06_Train_Baseline_Model.ipynb
# Purpose: Train the baseline LSTM regression model to predict next CDR score
#          using pre-computed features and the prepared data splits.
#          Loads configuration (feature lists, etc.) from the relevant NB04 W&B run.

In [None]:
# --- Standard Libraries & Imports ---
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import json
from pathlib import Path
import time
import sys

## 1. Setup: Project Configuration, Paths, Utilities, and Hyperparameters

This section initializes the notebook environment:
* Determines the project's root directory (`PROJECT_ROOT`) and adds the `src` directory to `sys.path` for custom module imports.
* Imports all necessary custom utility functions from the `src/` modules (`wandb_utils`, `paths_utils`, `datasets`, `models`, `training_utils`, `evaluation_utils`).
* Loads the main project configuration (`base_config`) from `config.json`.
* Defines dataset and notebook-specific identifiers (e.g., `DATASET_IDENTIFIER`, `NOTEBOOK_MODULE_NAME`).
* **Uses the `get_dataset_paths` utility to resolve paths to input data splits** (training, validation, and test `.parquet` files produced by Notebook 03). The paths to preprocessor `.joblib` files will be obtained later by downloading W&B artifacts from a Notebook 04 run.
* Defines the initial set of training hyperparameters (`HP`) for the `BaselineLSTMRegressor` model. The `input_size` hyperparameter will be dynamically updated later based on feature configurations fetched from Notebook 04.
* Defines the locator key from `config.json` that specifies the base output directory for this notebook's locally saved artifacts (e.g., model checkpoints).
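
The root detection described in the first bullet can be sketched as a small helper. This is a minimal sketch, assuming the notebook is launched either from the project root or from a direct subdirectory of it; `find_project_root` is a hypothetical name and is not part of the `src/` utilities.

```python
from pathlib import Path


def find_project_root(start: Path) -> Path:
    """Return the first candidate directory that holds both src/ and config.json."""
    # Same order as the notebook: the parent of the working directory first,
    # then the working directory itself.
    for candidate in (start.parent, start):
        if (candidate / "src").is_dir() and (candidate / "config.json").is_file():
            return candidate
    raise FileNotFoundError(f"No 'src' directory and 'config.json' found near {start}")
```

Raising immediately, rather than printing and continuing, keeps later cells from running with `PROJECT_ROOT = None`.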

In [None]:
# --- Project Setup, Configuration Loading, Utility Imports & Hyperparameters ---
print("--- Initializing Project Setup, Configuration, and HPs for NB06 ---")

# Initialize
PROJECT_ROOT = None
base_config = {} 

try:
    current_notebook_path = Path.cwd() 
    potential_project_root = current_notebook_path.parent 
    if (potential_project_root / "src").is_dir() and (potential_project_root / "config.json").is_file():
        PROJECT_ROOT = potential_project_root
    else: 
        PROJECT_ROOT = current_notebook_path
    if not (PROJECT_ROOT / "src").is_dir() or not (PROJECT_ROOT / "config.json").is_file():
        raise FileNotFoundError(f"Could not find 'src' or 'config.json'. PROJECT_ROOT: {PROJECT_ROOT}")
    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
    print(f"PROJECT_ROOT: {PROJECT_ROOT}, added to sys.path.")

    # Import all necessary custom utilities
    from src.wandb_utils import initialize_wandb_run, load_model_from_wandb_artifact # load_model for test phase
    from src.paths_utils import get_dataset_paths, get_notebook_run_output_dir
    from src.datasets import OASISDataset, pad_collate_fn
    from src.models import BaselineLSTMRegressor
    from src.training_utils import train_model
    from src.evaluation_utils import evaluate_model
    print("Successfully imported all required custom utilities and classes.")

except Exception as e_setup:
    print(f"CRITICAL ERROR during initial setup or imports: {e_setup}")
    # exit()

# --- Load Main Project Configuration ---
print("\n--- Loading Main Project Configuration ---")
try:
    if PROJECT_ROOT is None: raise ValueError("PROJECT_ROOT not set.")
    CONFIG_PATH_MAIN = PROJECT_ROOT / 'config.json'
    with open(CONFIG_PATH_MAIN, 'r', encoding='utf-8') as f:
        base_config = json.load(f)
    print(f"Main project config loaded from: {CONFIG_PATH_MAIN}")
except Exception as e_cfg:
    print(f"CRITICAL ERROR loading main config.json: {e_cfg}")
    # exit() 

# --- Define Dataset, Notebook Specifics, and Resolve Data Split Paths ---
DATASET_IDENTIFIER = "oasis2" 
NOTEBOOK_MODULE_NAME = "06_Train_Baseline_Model"
# Key from config.json locators for this notebook's *output directory* (for model checkpoints)
NB06_OUTPUT_LOCATOR_KEY = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})\
                                     .get("train_baseline_subdir_nb06_key", "train_baseline_subdir")
                                    

TRAIN_DATA_PATH = None
VAL_DATA_PATH = None
TEST_DATA_PATH = None
# Scaler and Imputer paths will be determined by downloading artifacts in the next cell

try:
    if not base_config: raise ValueError("base_config is empty.")
    
    # Get paths for data splits using the utility
    # Training data paths
    train_stage_paths = get_dataset_paths(
        PROJECT_ROOT, base_config, DATASET_IDENTIFIER, stage="training"
    )
    TRAIN_DATA_PATH = train_stage_paths.get('train_data_parquet')
    VAL_DATA_PATH = train_stage_paths.get('val_data_parquet')
    
    # Test data path
    test_stage_paths = get_dataset_paths(
        PROJECT_ROOT, base_config, DATASET_IDENTIFIER, stage="testing" # Or "analysis"
    )
    TEST_DATA_PATH = test_stage_paths.get('test_data_parquet')
    
    if not all([TRAIN_DATA_PATH, VAL_DATA_PATH, TEST_DATA_PATH]):
        raise ValueError("Failed to resolve one or more critical data split paths via get_dataset_paths.")

    print(f"\nKey data input paths for Notebook 06 ({DATASET_IDENTIFIER}):")
    print(f"  Training Data Parquet (from NB03): {TRAIN_DATA_PATH}")
    print(f"  Validation Data Parquet (from NB03): {VAL_DATA_PATH}")
    print(f"  Test Data Parquet (from NB03): {TEST_DATA_PATH}")
    
    for p_name, p_obj in [("Training Data", TRAIN_DATA_PATH), ("Validation Data", VAL_DATA_PATH), ("Test Data", TEST_DATA_PATH)]:
        if not p_obj.is_file(): raise FileNotFoundError(f"CRITICAL: {p_name} parquet file not found at: {p_obj}")
    print("All critical input data split paths for NB06 verified.")

except (KeyError, ValueError, FileNotFoundError) as e_paths_nb06:
    print(f"CRITICAL ERROR during path setup for NB06: {e_paths_nb06}")
    # exit()
except Exception as e_general_nb06_setup:
    print(f"CRITICAL ERROR during setup for NB06: {e_general_nb06_setup}")
    # exit()

# --- Define Training Hyperparameters (HP) ---
print("\n--- Defining Training Hyperparameters (HP) for BaselineLSTM ---")
HP = {
    'model_type': 'BaselineLSTM', 
    'dataset_identifier': DATASET_IDENTIFIER,
    'batch_size': 32,          
    'learning_rate': 1e-4,     
    'epochs': 50, # Max epochs, early stopping might finish sooner
    'lstm_hidden_size': 128,   
    'num_lstm_layers': 2,      
    'lstm_dropout_prob': 0.3,  
    'dataloader_num_workers': 0, 
    'patience': 10,            
    'save_checkpoint_every_n_epochs': 5, 
    'random_seed': 42,
    'input_size': None # CRITICAL: To be updated after fetching config from NB04
}
# Set seed for reproducibility early
np.random.seed(HP['random_seed'])
torch.manual_seed(HP['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(HP['random_seed'])
print("Training hyperparameters (HP) defined. 'input_size' will be set after fetching NB04 config.")

## 3. Initialize W&B Run, Fetch Preprocessor Artifacts & Definitive Configuration from Notebook 04

This step ensures that data handling during training matches the preprocessing fitted in Notebook 04:
1.  **Initialize W&B Run for Notebook 06:** A new W&B run is started for this specific `BaselineLSTMRegressor` training experiment using the `initialize_wandb_run` utility. This run will log all hyperparameters, configurations, metrics, and model artifacts.
2.  **Define Output Directory:** A unique local directory is created for this W&B run's outputs (e.g., model checkpoints) using `get_notebook_run_output_dir`.
3.  **Consume Preprocessor Artifacts:** The W&B Artifacts for the *fitted* `StandardScaler` and `SimpleImputer` (produced by a specific Notebook 04 execution) are identified by name and version (e.g., `:latest`) and downloaded using `run.use_artifact()`. This provides the local paths to the `.joblib` files.
4.  **Fetch Authoritative NB04 Configuration:** From one of the consumed preprocessor artifacts (the code below uses the imputer artifact), the W&B Run that *produced it* (the relevant Notebook 04 run) is identified via `artifact.logged_by()`.
5.  **Set Up `config_for_dataset`:** The W&B configuration of this producer Notebook 04 run is fetched. This configuration contains the **authoritative `features` (time-varying, static) and `preprocess` (imputation/scaling columns, strategies) dictionaries**. This fetched dictionary becomes `config_for_dataset`, which will be passed to `OASISDataset`.
6.  **Update `HP['input_size']`:** The `input_size` for the `BaselineLSTMRegressor` is dynamically determined from the length of the feature lists in the fetched `config_for_dataset`.
7.  All source information (NB04 run ID, consumed artifact versions, final feature/preprocess configs) is logged to the current Notebook 06 W&B run's configuration for complete traceability.
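
Steps 4 to 6 above can be sketched as one function. This is a minimal sketch under stated assumptions: `resolve_dataset_config` is a hypothetical helper, the artifact object only needs a `logged_by()` method, and the producer run only needs `id` and `config` attributes. Unlike the notebook cell, which only warns when the computed size is 0, this version raises.

```python
from typing import Any


def resolve_dataset_config(artifact: Any) -> tuple[dict, int, str]:
    """Return (config, input_size, producer_run_id) for a consumed artifact."""
    producer_run = artifact.logged_by()  # may be None if no producer run is recorded
    if producer_run is None:
        raise RuntimeError("Artifact has no recorded producer run.")

    config = dict(producer_run.config)
    missing = [key for key in ("features", "preprocess") if key not in config]
    if missing:
        raise ValueError(f"Producer run config is missing sections: {missing}")

    # input_size = number of time-varying features + number of static features
    features = config["features"]
    input_size = len(features.get("time_varying", [])) + len(features.get("static", []))
    if input_size == 0:
        raise ValueError("Feature lists are empty; input_size would be 0.")
    return config, input_size, producer_run.id
```

For example, a producer config listing three time-varying and two static features yields `input_size == 5`, which is the value the LSTM expects per time step.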

In [None]:
# --- Initialize W&B Run, Fetch Preprocessor Artifacts & NB04 Config ---
print("\n--- Initializing W&B Run for NB06 & Fetching NB04 Preprocessing Setup ---")

# Variables to be populated by this cell
run = None
run_output_dir_for_checkpoints = None
SCALER_PATH_FROM_ARTIFACT = None
IMPUTER_PATH_FROM_ARTIFACT = None
config_for_dataset = {} # This will be the authoritative config for OASISDataset
source_nb04_run_id_logged = "N/A"

# --- 1. Initialize W&B Run for THIS Notebook 06 execution ---
# We need the 'run' object to use artifacts, so initialize W&B first.
# The HP dictionary will be updated with input_size later and then logged more fully.
temp_hp_for_init = HP.copy() # Use a copy for initial logging

nb_number_prefix_nb06 = NOTEBOOK_MODULE_NAME.split('_')[0] if '_' in NOTEBOOK_MODULE_NAME else "NB"
job_specific_type_nb06 = f"{nb_number_prefix_nb06}-BaselineLSTM-{DATASET_IDENTIFIER}"
custom_elements_for_name_nb06 = [
    nb_number_prefix_nb06, DATASET_IDENTIFIER.upper(), "Baseline",
    f"h{HP['lstm_hidden_size']}", f"nl{HP['num_lstm_layers']}", f"dp{HP['lstm_dropout_prob']:.1f}", # Format dropout
    f"lr{HP['learning_rate']:.0e}", f"bs{HP['batch_size']}"
]
run = initialize_wandb_run(
    base_project_config=base_config,
    job_group="Training",
    job_specific_type=job_specific_type_nb06,
    run_specific_config=temp_hp_for_init, # Log initial HP
    custom_run_name_elements=custom_elements_for_name_nb06,
    notes=f"Training BaselineLSTMRegressor on {DATASET_IDENTIFIER.upper()}."
)

# --- 2. Define Run-Specific Output Directory for Local Checkpoints ---
if run:
    print(f"W&B run '{run.name}' (Job Type: '{run.job_type}') initialized. View at: {run.url}")
    run_output_dir_for_checkpoints = get_notebook_run_output_dir(
        PROJECT_ROOT, base_config, NB06_OUTPUT_LOCATOR_KEY, run, DATASET_IDENTIFIER
    )
    print(f"Model checkpoints for this run will be saved locally to: {run_output_dir_for_checkpoints}")
    run.config.update({"run_outputs/local_checkpoint_dir": str(run_output_dir_for_checkpoints)}, allow_val_change=True)
else:
    print("W&B run initialization failed. Proceeding with local fallback for outputs.")
    # Fallback local output directory if W&B init failed
    default_nb_locator = f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}_local_outputs"
    run_output_dir_for_checkpoints = get_notebook_run_output_dir(
        PROJECT_ROOT, base_config, NB06_OUTPUT_LOCATOR_KEY if base_config else default_nb_locator, 
        None, DATASET_IDENTIFIER
    )
    print(f"Model checkpoints (local fallback) will be saved to: {run_output_dir_for_checkpoints}")

if not isinstance(run_output_dir_for_checkpoints, Path): # Should be Path from utility
    run_output_dir_for_checkpoints = Path(run_output_dir_for_checkpoints)


# --- 3. Fetch Preprocessor Artifacts & NB04 Config (requires active NB06 W&B run) ---
if run: # Only proceed if W&B run for NB06 is active
    try:
        # Define expected preprocessor artifact names (must match what NB04 logged)
        preprocessing_cfg = base_config.get('preprocessing_config', {})
        scaling_strategy_name = preprocessing_cfg.get('scaling_strategy', 'standard_scaler')
        imputation_strategy_name = preprocessing_cfg.get('imputation_strategy', 'median')

        SCALER_ARTIFACT_LOGICAL_NAME = f"scaler_{scaling_strategy_name.lower().replace('scaler','').replace('_','')}_{DATASET_IDENTIFIER}"
        IMPUTER_ARTIFACT_LOGICAL_NAME = f"simple_imputer_{imputation_strategy_name.lower()}_{DATASET_IDENTIFIER}"
        PREPROCESSOR_ARTIFACT_TYPE = f"preprocessor_{DATASET_IDENTIFIER}"
        PREPROCESSOR_ARTIFACT_VERSION = "latest" # Or a specific pinned version

        # --- Download SCALER Artifact ---
        scaler_artifact_full_wandb_path = f"{base_config['wandb']['entity']}/{base_config['wandb']['project_name']}/{SCALER_ARTIFACT_LOGICAL_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}"
        print(f"  Attempting to use Scaler artifact: {scaler_artifact_full_wandb_path}")
        scaler_artifact = run.use_artifact(scaler_artifact_full_wandb_path, type=PREPROCESSOR_ARTIFACT_TYPE)
        scaler_artifact_dir = Path(scaler_artifact.download())
        scaler_fname_pattern = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})\
            .get('scaler_fname_pattern', '{scaling_strategy}_{dataset_identifier}.joblib')
        scaler_filename_in_artifact = scaler_fname_pattern.format(
            scaling_strategy=scaling_strategy_name.lower(), dataset_identifier=DATASET_IDENTIFIER)
        SCALER_PATH_FROM_ARTIFACT = scaler_artifact_dir / scaler_filename_in_artifact
        if not SCALER_PATH_FROM_ARTIFACT.is_file(): # Fallback scan
            found_scalers = list(scaler_artifact_dir.glob("*.joblib"))
            if found_scalers: SCALER_PATH_FROM_ARTIFACT = found_scalers[0]
            else: raise FileNotFoundError(f"Scaler .joblib ('{scaler_filename_in_artifact}') not found in artifact {scaler_artifact.name}")
        print(f"  Scaler artifact '{scaler_artifact.name}' downloaded. Local path: {SCALER_PATH_FROM_ARTIFACT}")

        # --- Download IMPUTER Artifact ---
        imputer_artifact_full_wandb_path = f"{base_config['wandb']['entity']}/{base_config['wandb']['project_name']}/{IMPUTER_ARTIFACT_LOGICAL_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}"
        print(f"  Attempting to use Imputer artifact: {imputer_artifact_full_wandb_path}")
        imputer_artifact = run.use_artifact(imputer_artifact_full_wandb_path, type=PREPROCESSOR_ARTIFACT_TYPE)
        imputer_artifact_dir = Path(imputer_artifact.download())
        imputer_fname_pattern = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})\
            .get('imputer_fname_pattern', 'simple_imputer_{imputation_strategy}_{dataset_identifier}.joblib')
        imputer_filename_in_artifact = imputer_fname_pattern.format(
            imputation_strategy=imputation_strategy_name.lower(), dataset_identifier=DATASET_IDENTIFIER)
        IMPUTER_PATH_FROM_ARTIFACT = imputer_artifact_dir / imputer_filename_in_artifact
        if not IMPUTER_PATH_FROM_ARTIFACT.is_file(): # Fallback scan
            found_imputers = list(imputer_artifact_dir.glob("*.joblib"))
            if found_imputers: IMPUTER_PATH_FROM_ARTIFACT = found_imputers[0]
            else: raise FileNotFoundError(f"Imputer .joblib ('{imputer_filename_in_artifact}') not found in artifact {imputer_artifact.name}")
        print(f"  Imputer artifact '{imputer_artifact.name}' downloaded. Local path: {IMPUTER_PATH_FROM_ARTIFACT}")

        # --- Fetch Full Configuration from the NB04 Run that Produced these Preprocessors ---
        nb04_producer_run = imputer_artifact.logged_by()  # Either artifact identifies the producer run; the imputer is used here. TODO: verify this is the intended NB04 run.
        if nb04_producer_run:
            source_nb04_run_id_logged = nb04_producer_run.id
            nb04_run_name = nb04_producer_run.name
            print(f"  Fetching config from NB04 producer run: '{nb04_run_name}' (ID: {source_nb04_run_id_logged})")
            config_from_nb04_run = dict(nb04_producer_run.config)
            if 'features' not in config_from_nb04_run or 'preprocess' not in config_from_nb04_run:
                raise ValueError("Fetched NB04 config missing 'features' or 'preprocess' sections.")
            config_for_dataset = config_from_nb04_run # This is the authoritative config for OASISDataset
            print("  Definitive 'features' and 'preprocess' config for OASISDataset fetched.")

            # Update HP with input_size from this fetched config
            time_varying_cfg = config_for_dataset.get('features', {}).get('time_varying', [])
            static_cfg = config_for_dataset.get('features', {}).get('static', [])
            # Note: 'M/F_encoded' is expected in static_cfg if 'M/F' was handled.
            # The input_size should match the number of features OASISDataset will output per time step.
            HP['input_size'] = len(time_varying_cfg) + len(static_cfg)
            print(f"  Updated HP['input_size'] to {HP['input_size']} based on fetched feature lists.")
            if HP['input_size'] == 0:
                print("  WARNING: Calculated input_size is 0. Check fetched feature lists from NB04.")

            # Update current NB06 run's config with source info and final dataset config used
            run.config.update({
                "source_config_from_nb04/producer_run_id": source_nb04_run_id_logged,
                "source_config_from_nb04/producer_run_name": nb04_run_name,
                "source_config_from_nb04/scaler_artifact_used": scaler_artifact.name,
                "source_config_from_nb04/imputer_artifact_used": imputer_artifact.name,
                "dataset_config_used/features": config_for_dataset.get('features',{}),
                "dataset_config_used/preprocess": config_for_dataset.get('preprocess',{}),
                "dataset_config_used/cnn_model_params": config_for_dataset.get('cnn_model_params',{}), # For OASISDataset
                "dataset_config_used/preprocessing_config_mri": config_for_dataset.get('preprocessing_config',{}), # For OASISDataset
                "model_input_size_resolved": HP['input_size'] # Log resolved input size
            }, allow_val_change=True)
            print("  NB06 W&B run config updated with source NB04 info and final dataset config.")
        else:
            raise ConnectionError("Could not retrieve producer run from preprocessor artifact (artifact.logged_by() failed).")

    except Exception as e_fetch_nb04_all:
        print(f"CRITICAL ERROR fetching preprocessor artifacts or config from NB04: {e_fetch_nb04_all}")
        if run: run.finish(exit_code=1)
        # exit()
else: # W&B run for NB06 failed to initialize
    print("CRITICAL ERROR: W&B run for NB06 not initialized. Cannot fetch artifacts or proceed with training setup.")
    # Define fallbacks if absolutely necessary for code to not break immediately, but this is a critical failure
    SCALER_PATH_FROM_ARTIFACT = Path(base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {}).get('scaler_path_fallback', 'scaler.joblib')) # Highly unlikely to work
    IMPUTER_PATH_FROM_ARTIFACT = Path(base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {}).get('imputer_path_fallback', 'imputer.joblib'))
    config_for_dataset.setdefault('features', {'time_varying': [], 'static': []})
    config_for_dataset.setdefault('preprocess', {'imputation_cols': [], 'scaling_cols': []})
    config_for_dataset.setdefault('cnn_model_params', base_config.get('cnn_model_params', {}))
    config_for_dataset.setdefault('preprocessing_config', base_config.get('preprocessing_config', {}))
    HP['input_size'] = 0 # Fallback
    # exit()

# Final check on critical variables for OASISDataset
if not all([SCALER_PATH_FROM_ARTIFACT, IMPUTER_PATH_FROM_ARTIFACT, config_for_dataset.get('features'), HP.get('input_size') is not None]):
    print("CRITICAL ERROR: Essential components for OASISDataset (preprocessor paths, feature config, or model input_size) are not properly set up.")
    # exit()

## 4. Setup Device

Detect and set the appropriate PyTorch device (CUDA GPU, MPS for Apple Silicon, or CPU) for model training and data handling. The chosen device is logged to W&B.
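
The cell below checks MPS first, then CUDA, then falls back to CPU. That priority order can be isolated in a tiny helper; this is an illustrative sketch, and `choose_device_name` is a hypothetical name that takes the availability flags as plain booleans so the logic can be checked without a GPU.

```python
def choose_device_name(mps_available: bool, cuda_available: bool) -> str:
    """Mirror the notebook's priority: MPS, then CUDA, then CPU."""
    if mps_available:
        return "mps"
    if cuda_available:
        return "cuda"
    return "cpu"
```

In the notebook this would be used as `torch.device(choose_device_name(torch.backends.mps.is_available(), torch.cuda.is_available()))`.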

In [None]:
# --- Setup Device (GPU/CPU/MPS) ---
print("\n--- Setting up PyTorch Device ---")
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    print("Using MPS (Apple Silicon GPU).")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("Using CPU.")

# Log the determined device to W&B run config
if run: 
    run.config.update({'training_details/device_used': str(device)}, allow_val_change=True)
else:
    print(f"Device set to {device} (W&B run not active for logging).")

## 5. Load Data & Create DataLoaders for Baseline Model

This section instantiates the custom `OASISDataset` for the training and validation data splits (`cohort_train.parquet` and `cohort_validation.parquet`). Key configurations for the dataset are critical for consistency:

* **Preprocessor Paths:** The paths to the *fitted* `StandardScaler` and `SimpleImputer` (obtained by downloading W&B Artifacts logged by Notebook 04) are provided to `OASISDataset`.
* **Feature & Preprocessing Configuration:** The `config_for_dataset` dictionary (fetched from the W&B configuration of the Notebook 04 run that produced the preprocessors) is passed to `OASISDataset`. This ensures it uses the authoritative lists of time-varying/static features and imputation/scaling columns.
* **MRI Data:** For this baseline model, `include_mri` is explicitly set to `False`.

The instantiated `OASISDataset` objects are then wrapped in PyTorch `DataLoader`s, which handle batching and use the custom `pad_collate_fn` to manage variable sequence lengths. The `input_size` hyperparameter for the model is also confirmed here.
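
Subjects have different numbers of visits, so sequences in a batch must be padded to a common length before they can be stacked. The project's `pad_collate_fn` is not shown in this notebook; the plain-Python sketch below only illustrates the idea, and `pad_sequences` is a hypothetical name. It assumes every sequence has at least one time step.

```python
def pad_sequences(batch: list[list[list[float]]], pad_value: float = 0.0):
    """Pad every sequence in the batch to the longest one; return (padded, lengths)."""
    lengths = [len(seq) for seq in batch]        # true number of visits per subject
    max_len = max(lengths)
    n_features = len(batch[0][0])                # features per time step
    padded = [
        [list(step) for step in seq]
        + [[pad_value] * n_features for _ in range(max_len - len(seq))]
        for seq in batch
    ]
    return padded, lengths
```

Returning the original lengths alongside the padded batch lets downstream code ignore padded time steps. On tensors, `torch.nn.utils.rnn.pad_sequence` with `batch_first=True` performs the equivalent padding.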

In [None]:
# --- Instantiate Datasets and DataLoaders for Baseline Model ---
print("\n--- Loading Data and Creating DataLoaders for Baseline Model Training ---")

# Initialize dataset and loader variables to ensure they are defined
train_dataset: OASISDataset | None = None
val_dataset: OASISDataset | None = None
train_loader: DataLoader | None = None
val_loader: DataLoader | None = None

# --- Prerequisite Variable Check ---
# These are expected to be defined and populated from preceding cells
required_vars_for_dataloading = {
    'TRAIN_DATA_PATH': TRAIN_DATA_PATH,
    'VAL_DATA_PATH': VAL_DATA_PATH,
    'SCALER_PATH_FROM_ARTIFACT': SCALER_PATH_FROM_ARTIFACT,
    'IMPUTER_PATH_FROM_ARTIFACT': IMPUTER_PATH_FROM_ARTIFACT,
    'config_for_dataset': config_for_dataset, # Authoritative config from NB04 run
    'HP': HP, # For batch_size, num_workers, and input_size
    'pad_collate_fn': pad_collate_fn,
    'OASISDataset': OASISDataset
}

proceed_with_loading = True
for var_name, var_value in required_vars_for_dataloading.items():
    if var_value is None:
        print(f"  CRITICAL ERROR: Prerequisite variable '{var_name}' for DataLoaders is None.")
        proceed_with_loading = False
    if var_name == 'config_for_dataset' and \
       (not var_value or 'features' not in var_value or 'preprocess' not in var_value):
        print(f"  CRITICAL ERROR: 'config_for_dataset' is empty or missing 'features'/'preprocess' keys.")
        proceed_with_loading = False
    if var_name == 'HP' and (not var_value or 'input_size' not in var_value or HP.get('input_size') is None):
        print(f"  CRITICAL ERROR: 'HP' dictionary or 'HP['input_size']' is not correctly set "
              "(should be updated after fetching NB04 config).")
        proceed_with_loading = False

if not proceed_with_loading:
    print("Halting DataLoader creation due to missing or invalid prerequisites from previous cells.")
    if run: run.finish(exit_code=1) # Mark W&B run as failed
    # exit() # Or raise an error
else:
    try:
        print(f"\nInstantiating training dataset for {DATASET_IDENTIFIER.upper()} (Baseline - MRI Excluded)...")
        train_dataset = OASISDataset(
            data_parquet_path=TRAIN_DATA_PATH,
            scaler_path=SCALER_PATH_FROM_ARTIFACT,   # Path to downloaded .joblib
            imputer_path=IMPUTER_PATH_FROM_ARTIFACT, # Path to downloaded .joblib
            config=config_for_dataset,               # Config from NB04 W&B run
            include_mri=False                        # Explicitly False for baseline model
            # mri_data_dir is not needed when include_mri is False
        )
        num_train_subjects = len(train_dataset)
        print(f"  Train dataset (baseline) created. Number of subjects (sequences): {num_train_subjects}")

        print(f"\nInstantiating validation dataset for {DATASET_IDENTIFIER.upper()} (Baseline - MRI Excluded)...")
        val_dataset = OASISDataset(
            data_parquet_path=VAL_DATA_PATH,
            scaler_path=SCALER_PATH_FROM_ARTIFACT,   # Use the SAME scaler/imputer
            imputer_path=IMPUTER_PATH_FROM_ARTIFACT,
            config=config_for_dataset,               # Use the SAME config
            include_mri=False
        )
        num_val_subjects = len(val_dataset)
        print(f"  Validation dataset (baseline) created. Number of subjects (sequences): {num_val_subjects}")

        # --- Create DataLoaders ---
        print("\nCreating DataLoaders...")
        train_loader = DataLoader(
            train_dataset, 
            batch_size=HP['batch_size'], 
            shuffle=True, # Shuffle training data each epoch
            collate_fn=pad_collate_fn,
            num_workers=HP.get('dataloader_num_workers', 0),
            persistent_workers=HP.get('dataloader_num_workers', 0) > 0  # Only valid when num_workers > 0
        )

        val_loader = DataLoader(
            val_dataset, 
            batch_size=HP['batch_size'], 
            shuffle=False, # No need to shuffle validation data
            collate_fn=pad_collate_fn,
            num_workers=HP.get('dataloader_num_workers', 0),
            persistent_workers=HP.get('dataloader_num_workers', 0) > 0  # Only valid when num_workers > 0
        )
        print("Train and Validation DataLoaders for baseline model created.")
        print(f"  Number of batches in train_loader: {len(train_loader)}")
        print(f"  Number of batches in val_loader: {len(val_loader)}")

        # Log dataset info to W&B for this training run
        if run:
            run.log({
                f'dataset_info_{DATASET_IDENTIFIER}/train_subjects': num_train_subjects,
                f'dataset_info_{DATASET_IDENTIFIER}/val_subjects': num_val_subjects,
                f'dataset_info_{DATASET_IDENTIFIER}/input_size_for_model': HP['input_size'],
                f'dataset_info_{DATASET_IDENTIFIER}/batch_size': HP['batch_size']
            })
        
        # Confirm input_size for the model (this was set in HP after fetching NB04 config)
        print(f"\nInput size (number of features) for BaselineLSTMRegressor: {HP['input_size']}")
        if HP['input_size'] is None or HP['input_size'] == 0:
            print("WARNING: HP['input_size'] is None or 0. The model will likely fail. "
                  "Check fetching of NB04 config and feature list processing.")

    except FileNotFoundError as e_fnf_ds_nb06:
        print(f"CRITICAL ERROR during OASISDataset instantiation in NB06: File not found - {e_fnf_ds_nb06}")
        print("  Verify paths to data splits and downloaded preprocessor artifacts from NB04.")
        if run: run.finish(exit_code=1)
        # exit()
    except KeyError as e_key_ds_nb06:
        print(f"CRITICAL ERROR during OASISDataset instantiation in NB06: Missing a key in 'config_for_dataset' - {e_key_ds_nb06}")
        print("  Ensure the configuration fetched from the Notebook 04 W&B run is complete.")
        if run: run.finish(exit_code=1)
        # exit()
    except Exception as e_ds_init_nb06:
        print(f"An unexpected CRITICAL ERROR occurred during OASISDataset or DataLoader instantiation in NB06: {e_ds_init_nb06}")
        import traceback
        traceback.print_exc()
        if run: run.finish(exit_code=1)
        # exit()

# Final check for subsequent cells
if train_loader is None or val_loader is None:
    print("CRITICAL ERROR: DataLoaders were not successfully created. Cannot proceed with model training.")
    # exit()

## 6. Define Model, Loss Function, and Optimizer

This section sets up the components required for training the `BaselineLSTMRegressor`:
1.  **Model Instantiation:** The `BaselineLSTMRegressor` (imported from `src/models.py`) is instantiated using the hyperparameters defined in the `HP` dictionary, including the `input_size` (number of input features per timestep, determined from the NB04 configuration), `lstm_hidden_size`, `num_lstm_layers`, and `lstm_dropout_prob`. The model is then moved to the configured PyTorch `device` (e.g., CUDA, MPS, or CPU).
2.  **Loss Function:** Mean Squared Error (`nn.MSELoss`) is used because the task is framed as regression, with the next visit's CDR score treated as a continuous target.
3.  **Optimizer:** The Adam optimizer (`optim.Adam`) is selected to update the model's weights during training, configured with the learning rate from `HP`.
4.  **W&B Model Watching (Optional):** If a W&B run is active, `wandb.watch()` is called with `log='all'` so that gradient and parameter histograms are logged to W&B every 100 batches during training.
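
As a rough illustration of steps 1 to 3, the sketch below wires a small LSTM regressor to `nn.MSELoss` and `optim.Adam` and runs one optimization step on random data. `TinyLSTMRegressor` and the values in `hp_sketch` are hypothetical stand-ins: the real `BaselineLSTMRegressor` lives in `src/models.py` and may differ, for example in how it handles padded sequences.

```python
import torch
from torch import nn, optim


class TinyLSTMRegressor(nn.Module):
    """Hypothetical stand-in for BaselineLSTMRegressor (not the project's class)."""

    def __init__(self, input_size, hidden_size, num_layers, dropout_prob):
        super().__init__()
        # nn.LSTM applies dropout only between stacked layers, so it is
        # switched off when there is a single layer.
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout_prob if num_layers > 1 else 0.0,
        )
        self.head = nn.Linear(hidden_size, 1)

    def forward(self, x):
        out, _ = self.lstm(x)            # (batch, seq_len, hidden_size)
        # Simplification: read the last timestep and ignore padding.
        return self.head(out[:, -1, :])  # (batch, 1)


# Assumed hyperparameter values, chosen only for this sketch
hp_sketch = {
    "input_size": 8,
    "lstm_hidden_size": 16,
    "num_lstm_layers": 2,
    "lstm_dropout_prob": 0.2,
    "learning_rate": 1e-3,
}

torch.manual_seed(0)
sketch_device = torch.device("cpu")
sketch_model = TinyLSTMRegressor(
    input_size=hp_sketch["input_size"],
    hidden_size=hp_sketch["lstm_hidden_size"],
    num_layers=hp_sketch["num_lstm_layers"],
    dropout_prob=hp_sketch["lstm_dropout_prob"],
).to(sketch_device)
sketch_criterion = nn.MSELoss()
sketch_optimizer = optim.Adam(sketch_model.parameters(), lr=hp_sketch["learning_rate"])

# One training step on random data: 4 subjects, 5 visits, 8 features per visit
features = torch.randn(4, 5, hp_sketch["input_size"], device=sketch_device)
targets = torch.rand(4, 1, device=sketch_device)

sketch_optimizer.zero_grad()
predictions = sketch_model(features)
sketch_loss = sketch_criterion(predictions, targets)
sketch_loss.backward()
sketch_optimizer.step()
```

The names are prefixed with `sketch_` so that running this in the notebook would not overwrite `model`, `criterion`, or `optimizer`.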

In [None]:
# --- Define Model, Loss Function, and Optimizer ---
print("\n--- Defining Model, Loss Function, and Optimizer for BaselineLSTM ---")

# These variables are expected to be defined from previous cells:
# HP (dictionary with model hyperparameters like input_size, lstm_hidden_size, etc.)
# device (torch.device)
# BaselineLSTMRegressor (imported model class)
# run (active W&B run object, or None)

model = None
criterion = None
optimizer = None

# Ensure HP dictionary and critical model parameters are available
if 'HP' not in locals() or not HP or HP.get('input_size') is None:
    print("CRITICAL ERROR: HP dictionary not defined or 'input_size' is missing or None. "
          "Ensure 'input_size' was correctly updated after fetching NB04 config.")
    if run: run.finish(exit_code=1)
    # exit() # Or raise error
else:
    try:
        # 1. Instantiate the BaselineLSTMRegressor model
        print(f"  Instantiating BaselineLSTMRegressor with input_size: {HP['input_size']}")
        model = BaselineLSTMRegressor(
            input_size=HP['input_size'],
            hidden_size=HP['lstm_hidden_size'],
            num_layers=HP['num_lstm_layers'],
            dropout_prob=HP['lstm_dropout_prob']
        )
        model.to(device) # Move model to the selected device
        print("  BaselineLSTMRegressor model instantiated and moved to device.")
        # print(model) # Optional: print model architecture

        # 2. Define the Loss Function (Criterion)
        criterion = nn.MSELoss()
        print(f"  Loss function set to: {type(criterion).__name__}")

        # 3. Define the Optimizer
        optimizer = optim.Adam(model.parameters(), lr=HP['learning_rate'])
        print(f"  Optimizer set to: {type(optimizer).__name__} with LR={HP['learning_rate']}")

        # 4. Optional: Watch model with W&B for gradients, parameters, etc.
        if run:
            wandb.watch(model, criterion=criterion, log='all', log_freq=100) # Log gradient and parameter histograms every 100 batches
            print("  W&B model watching enabled.")

    except KeyError as e_key_model:
        print(f"CRITICAL ERROR: Missing a key in HP dictionary needed for model instantiation: {e_key_model}")
        if run: run.finish(exit_code=1)
        # exit()
    except Exception as e_model_setup:
        print(f"CRITICAL ERROR during model, loss, or optimizer setup: {e_model_setup}")
        import traceback
        traceback.print_exc()
        if run: run.finish(exit_code=1)
        # exit()

# Check if model setup was successful for subsequent cells
if model is None or criterion is None or optimizer is None:
    print("CRITICAL ERROR: Model setup failed. Cannot proceed to training.")
    # exit()

## 7. Train Baseline Model using Utility Function

The core model training and validation process is now encapsulated within the `train_model` utility function (from `src/training_utils.py`). This function is called with the instantiated model, data loaders, criterion, optimizer, device, number of epochs, active W&B run object, checkpoint directory, and other relevant hyperparameters.

The `train_model` utility handles:
* The main loop over epochs.
* The training phase for each epoch (forward pass, loss calculation, backpropagation, optimizer step).
* The validation phase for each epoch (using the `evaluate_model` utility passed to it).
* Logging all training and validation metrics (Loss, MAE, R2, epoch duration) to W&B.
* **Best Model Checkpointing:** Saving the model state (`.pth` file) locally to the run-specific output directory (`run_output_dir_for_checkpoints`) whenever validation loss improves. This best model is also logged as a W&B Artifact with the 'best' alias.
* **Periodic Checkpointing:** Saving model state, optimizer state, and current metrics periodically (e.g., every `HP['save_checkpoint_every_n_epochs']` epochs) for resumable training. These are also logged as W&B Artifacts.
* **Early Stopping:** Monitoring validation loss and stopping training if no improvement is seen for a defined `patience` number of epochs.

The function returns the local path to the best model checkpoint and the best validation loss achieved.
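
The checkpointing and early-stopping rules listed above can be sketched in plain Python by replaying a list of validation losses. This is only an illustration under assumptions: `simulate_training_control` is a hypothetical helper, and the actual `train_model` may use a different improvement test (for example a minimum delta) or a different ordering of the checks.

```python
def simulate_training_control(val_losses, patience, save_every):
    """Replay per-epoch validation losses and report what would be saved.

    Hypothetical helper for illustration; not part of src/training_utils.py.
    """
    best_val_loss = float("inf")
    best_epoch = None
    epochs_without_improvement = 0
    periodic_checkpoint_epochs = []
    stopped_at_epoch = None

    for epoch, val_loss in enumerate(val_losses, start=1):
        if val_loss < best_val_loss:
            # Improvement: this is where the "best" checkpoint would be overwritten
            best_val_loss = val_loss
            best_epoch = epoch
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1

        if epoch % save_every == 0:
            # Periodic checkpoint for resumable training
            periodic_checkpoint_epochs.append(epoch)

        if epochs_without_improvement >= patience:
            # Early stopping: no improvement for `patience` consecutive epochs
            stopped_at_epoch = epoch
            break

    return {
        "best_epoch": best_epoch,
        "best_val_loss": best_val_loss,
        "stopped_at_epoch": stopped_at_epoch,
        "periodic_checkpoint_epochs": periodic_checkpoint_epochs,
    }


history = [0.90, 0.70, 0.65, 0.66, 0.68, 0.67, 0.60]
control_summary = simulate_training_control(history, patience=3, save_every=2)
print(control_summary)
```

With this history the best loss is reached at epoch 3, and training stops at epoch 6 after three epochs without improvement, so the lower loss at epoch 7 is never seen. That is the usual trade-off when choosing `patience`.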

In [None]:
# --- Train Baseline Model using the train_model Utility ---
print("\n--- Starting Baseline Model Training using 'train_model' utility ---")

# Ensure all necessary inputs for train_model are defined from previous cells:
# model, train_loader, val_loader, criterion, optimizer, device, HP, run, 
# run_output_dir_for_checkpoints (for saving .pth files), evaluate_model (imported function)

path_to_best_checkpoint_local = None
best_validation_loss_achieved = float('inf')

required_vars_for_training = [
    'model', 'train_loader', 'val_loader', 'criterion', 'optimizer', 'device', 'HP',
    'run_output_dir_for_checkpoints', 'evaluate_model' 
    # 'run' can be None if W&B init failed, train_model should handle that
]

proceed_with_training = True
for var_name in required_vars_for_training:
    if var_name not in locals() or locals()[var_name] is None:
        print(f"  CRITICAL ERROR: Prerequisite variable '{var_name}' for training is not defined or is None.")
        proceed_with_training = False

# 'run' is deliberately not in the required list: it may be None if W&B init failed
if locals().get('run') is None:
    print("  Note: W&B run object is None. Training will proceed without W&B logging within train_model.")

if not proceed_with_training:
    print("Halting model training due to missing or invalid prerequisites.")
    if run: run.finish(exit_code=1)
    # exit()
else:
    try:
        path_to_best_checkpoint_local, best_validation_loss_achieved = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            num_epochs=HP['epochs'],
            wandb_run=run, # Pass the active W&B run object (can be None)
            checkpoint_dir=run_output_dir_for_checkpoints, 
            evaluate_fn=evaluate_model, # Pass the imported evaluate_model function
            model_type_flag="baseline", # Critical for batch unpacking in evaluate_fn & train_model
            hp_dict=HP, # Pass the HP dict for patience, save_every_n_epochs
            best_val_loss_init=float('inf') # Start with fresh best loss tracking
        )

        print("\n--- Training Complete (via 'train_model' utility) ---")
        print(f"Best validation loss achieved during training: {best_validation_loss_achieved:.4f}")
        if path_to_best_checkpoint_local and path_to_best_checkpoint_local.exists():
            print(f"Best model checkpoint saved locally at: {path_to_best_checkpoint_local}")
            # Log this path to W&B summary for easy reference if run is active
            if run:
                run.summary['best_model_local_path'] = str(path_to_best_checkpoint_local)
                run.summary['best_val_loss_final'] = best_validation_loss_achieved # Ensure final best_val_loss is in summary
        elif path_to_best_checkpoint_local: # Path was returned but file doesn't exist (should not happen if train_model is correct)
            print(f"Warning: train_model returned a best checkpoint path, but file not found: {path_to_best_checkpoint_local}")
        else: # No checkpoint saved (e.g., 0 epochs, or error in checkpointing, or no improvement)
            print("No best model checkpoint was saved during training (e.g., no improvement in val_loss or training error).")

    except Exception as e_train:
        print(f"CRITICAL ERROR occurred during model training: {e_train}")
        import traceback
        traceback.print_exc()
        if run: run.finish(exit_code=1)
        # exit()

## 8. Evaluate Best Model on Test Set

After training is complete, the best performing model (based on the lowest validation loss achieved during training) is loaded from its saved checkpoint. This model is then evaluated on the held-out test set (`cohort_test.parquet`) to assess its generalization performance.
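
A minimal sketch of that reload pattern, assuming the best checkpoint stores only a `state_dict`: the weights are saved, a fresh module with the same architecture is built, and the weights are loaded into it. A bare `nn.LSTM` and an in-memory buffer stand in for the project's model and its `.pth` file.

```python
import io

import torch
from torch import nn

torch.manual_seed(0)

# Stand-in for the trained model; the real one is BaselineLSTMRegressor
trained = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)

# Save only the parameters; the buffer plays the role of the .pth file
buffer = io.BytesIO()
torch.save(trained.state_dict(), buffer)
buffer.seek(0)

# Rebuild the same architecture, then load the saved weights into it
restored = nn.LSTM(input_size=4, hidden_size=8, batch_first=True)
restored.load_state_dict(torch.load(buffer, map_location=torch.device("cpu")))

# Evaluation mode plus no_grad: no dropout, no gradient tracking
trained.eval()
restored.eval()
sample = torch.randn(2, 3, 4)
with torch.no_grad():
    out_trained, _ = trained(sample)
    out_restored, _ = restored(sample)

outputs_match = torch.allclose(out_trained, out_restored)
print(outputs_match)
```

Loading fails with a key or shape mismatch if the fresh module is built with different sizes, which is why the test model should be created from the same `HP` values used for training.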

The evaluation uses the `evaluate_model` utility, providing metrics such as Test Loss, Mean Squared Error (MSE), Mean Absolute Error (MAE), and R-squared (R²). These final test metrics are logged to the W&B run's summary for this training experiment.
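
For reference, the three regression metrics can be computed by hand as in the sketch below. `regression_metrics` is a hypothetical helper written for this illustration, and the CDR values are made up; the project's `evaluate_model` may compute these differently (for instance with scikit-learn).

```python
def regression_metrics(y_true, y_pred):
    """Return MSE, MAE and R^2 for two equal-length sequences of floats."""
    n = len(y_true)
    errors = [pred - true for true, pred in zip(y_true, y_pred)]

    mse = sum(e * e for e in errors) / n
    mae = sum(abs(e) for e in errors) / n

    # R^2 = 1 - SS_res / SS_tot, undefined when the targets are constant
    mean_true = sum(y_true) / n
    ss_tot = sum((true - mean_true) ** 2 for true in y_true)
    ss_res = sum(e * e for e in errors)
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan")

    return {"mse": mse, "mae": mae, "r2": r2}


# Made-up next-visit CDR scores and predictions for four subjects
true_cdr = [0.0, 0.5, 1.0, 0.5]
predicted_cdr = [0.0, 0.5, 0.5, 1.0]

example_metrics = regression_metrics(true_cdr, predicted_cdr)
print(example_metrics)  # {'mse': 0.125, 'mae': 0.25, 'r2': 0.0}
```

An R² of 0 here means the predictions do no better than always predicting the mean target. Because the criterion is `nn.MSELoss`, the reported test loss and the MSE computed from predictions should be close; they may differ slightly if the loss is averaged per batch and the batches have unequal sizes.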

In [None]:
# --- Evaluate Best Model on Test Set ---
print("\n--- Evaluating Best Model on Test Set ---")

# Ensure path_to_best_checkpoint_local, HP, config_for_dataset, device are available
# Also TEST_DATA_PATH, SCALER_PATH_FROM_ARTIFACT, IMPUTER_PATH_FROM_ARTIFACT

if 'path_to_best_checkpoint_local' in locals() and \
   path_to_best_checkpoint_local is not None and \
   path_to_best_checkpoint_local.exists() and \
   'HP' in locals() and HP and \
   'config_for_dataset' in locals() and config_for_dataset and \
   'TEST_DATA_PATH' in locals() and TEST_DATA_PATH.is_file() and \
   'SCALER_PATH_FROM_ARTIFACT' in locals() and SCALER_PATH_FROM_ARTIFACT.is_file() and \
   'IMPUTER_PATH_FROM_ARTIFACT' in locals() and IMPUTER_PATH_FROM_ARTIFACT.is_file():

    print(f"Loading best model for test evaluation from: {path_to_best_checkpoint_local}")

    # 1. Instantiate a new model of the same architecture
    # Ensure HP contains the necessary parameters for model instantiation
    try:
        test_model = BaselineLSTMRegressor(
            input_size=HP['input_size'], # Should be correctly set from NB04 config via HP
            hidden_size=HP['lstm_hidden_size'],
            num_layers=HP['num_lstm_layers'],
            dropout_prob=HP['lstm_dropout_prob'] # Use training dropout, will be disabled by model.eval()
        )
        
        # Load the saved state dictionary
        test_model.load_state_dict(torch.load(path_to_best_checkpoint_local, map_location=device))
        test_model.to(device)
        # model.eval() is called inside evaluate_model, but good practice here too
        test_model.eval() 
        print("  Best model loaded and set to evaluation mode.")

        # 2. Create DataLoader for the Test Set
        print(f"  Instantiating test dataset from: {TEST_DATA_PATH} (Baseline - MRI Excluded)...")
        test_dataset = OASISDataset(
            data_parquet_path=TEST_DATA_PATH,
            scaler_path=SCALER_PATH_FROM_ARTIFACT,
            imputer_path=IMPUTER_PATH_FROM_ARTIFACT,
            config=config_for_dataset, # Use the same authoritative config from NB04
            include_mri=False # Explicitly False for baseline model
        )
        print(f"  Test dataset created with {len(test_dataset)} subjects.")

        test_loader = DataLoader(
            test_dataset,
            batch_size=HP['batch_size'], # Can use training batch size or a different one
            shuffle=False, # No need to shuffle test data
            collate_fn=pad_collate_fn,
            num_workers=HP.get('dataloader_num_workers', 0)
        )
        print(f"  Test DataLoader created. Number of batches: ~{len(test_loader)}")

        # 3. Perform Evaluation using the utility function
        # Ensure criterion is defined (it was defined before calling train_model)
        if 'criterion' not in locals() or criterion is None: 
            print("  Warning: Criterion not defined, re-initializing to MSELoss for test evaluation.")
            criterion = nn.MSELoss()

        print(f"  Evaluating best model on {len(test_dataset)} test subjects...")
        test_set_metrics = evaluate_model(
            test_model, 
            test_loader, 
            criterion, 
            device, 
            model_name_for_batch_unpack="baseline"
        )

        print("\n--- Test Set Performance of Best Model ---")
        print(f"  (Model from epoch with validation loss: {best_validation_loss_achieved:.4f})") # From train_model output
        print(f"  Test Loss (Criterion): {test_set_metrics['loss']:.4f}")
        print(f"  Test MSE (from preds): {test_set_metrics['mse']:.4f}")
        print(f"  Test MAE:  {test_set_metrics['mae']:.4f}")
        print(f"  Test R2:   {test_set_metrics['r2']:.4f}")

        # 4. Log Test Metrics to W&B Summary for this training run
        if run:
            run.summary["test_set/loss"] = test_set_metrics['loss']
            run.summary["test_set/mse"] = test_set_metrics['mse']
            run.summary["test_set/mae"] = test_set_metrics['mae']
            run.summary["test_set/r2"] = test_set_metrics['r2']
            # best_validation_loss_achieved was already written to run.summary['best_val_loss_final'] after training,
            # but it is recorded here too so the test metrics are tied to the checkpoint they came from.
            run.summary["test_set/checkpoint_best_val_loss"] = best_validation_loss_achieved 
            print("  Test metrics logged to W&B run summary.")
            
    except Exception as e_test_eval:
        print(f"CRITICAL ERROR during test set evaluation: {e_test_eval}")
        import traceback
        traceback.print_exc()
        if run: run.summary["test_set/status"] = "EvaluationError"

else:
    print("Skipping test set evaluation: Best model checkpoint not found, not saved, or other prerequisites missing.")
    if run: run.summary["test_set/status"] = "SkippedEvaluation"

## 9. Finalize W&B Run

Complete the execution of this training notebook and finish the associated Weights & Biases run. This ensures all queued logs, metrics, configurations, and model artifacts are fully uploaded and synchronized with the W&B platform.

In [None]:
# --- Finish W&B Run for Notebook 06 ---
print(f"\n--- {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} (Baseline Training) complete. Finishing W&B run. ---")

if run: # Check if 'run' object exists and is an active W&B run
    try:
        # Ensure final best_val_loss is in summary if not already covered by train_model or test summary
        if 'best_validation_loss_achieved' in locals() and 'best_val_loss_final' not in run.summary:
            run.summary['best_val_loss_final'] = best_validation_loss_achieved
        
        run.finish()
        run_name_to_print = run.name_synced if hasattr(run, 'name_synced') and run.name_synced else \
                            run.name if hasattr(run, 'name') and run.name else \
                            run.id if hasattr(run, 'id') else "current NB06 run"
        print(f"W&B run '{run_name_to_print}' finished successfully.")
    except Exception as e_finish_run_nb06:
        print(f"Error during run.finish() for Notebook 06: {e_finish_run_nb06}")
        print("The run may not have finalized correctly on the W&B server.")
else:
    print("No active W&B run to finish for this training session.")

print(f"\n--- Notebook {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} execution finished. ---")