# Notebook 07: Train Hybrid CNN+LSTM Model (OASIS-2)

**Project Phase:** 1 (Model Training - Hybrid)
**Dataset:** OASIS-2 Longitudinal MRI & Clinical Data

**Purpose:**
This notebook trains the `ModularLateFusionLSTM` hybrid model (defined in `src/models.py`) to predict the CDR score at a subject's next visit. It combines longitudinal tabular clinical/demographic features with features that an internal 3D CNN extracts from T1-weighted MRI scans, in order to assess whether adding neuroimaging data improves prediction.
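The prediction target (the CDR recorded at the following visit) can be illustrated with pandas. This is a minimal sketch, not the logic inside `OASISDataset`: the helper `make_next_visit_targets` is hypothetical, and the column names `Subject ID`, `Visit`, and `CDR` are assumed to follow the OASIS-2 clinical table; the project's Parquet splits may use different names.

```python
import pandas as pd


def make_next_visit_targets(df: pd.DataFrame,
                            subject_col: str = "Subject ID",
                            visit_col: str = "Visit",
                            target_col: str = "CDR") -> pd.DataFrame:
    """Attach the next visit's CDR to each visit row (hypothetical helper).

    Rows are ordered by visit within each subject. Rows without a next-visit
    label (e.g. each subject's final visit) are dropped.
    """
    out = df.sort_values([subject_col, visit_col])
    # shift(-1) within each subject pulls the following visit's value up one row
    out["next_" + target_col] = out.groupby(subject_col)[target_col].shift(-1)
    return out.dropna(subset=["next_" + target_col]).reset_index(drop=True)


# Toy example with two subjects (values are illustrative, not real OASIS-2 data)
visits = pd.DataFrame({
    "Subject ID": ["OAS2_0001", "OAS2_0001", "OAS2_0002", "OAS2_0002", "OAS2_0002"],
    "Visit": [1, 2, 1, 2, 3],
    "CDR": [0.0, 0.0, 0.5, 0.5, 1.0],
})
pairs = make_next_visit_targets(visits)
print(pairs[["Subject ID", "Visit", "CDR", "next_CDR"]])
```

A sequence model like the one trained here would typically consume all visits up to a given point rather than single rows, but the label alignment is the same idea.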

**Workflow:**
1.  **Setup:** Load `config.json`, define training hyperparameters (`HP`), and resolve paths for data splits.
2.  **W&B Initialization:** Start a new W&B run for this training experiment, logging HPs and source configurations. Create a local output directory for model checkpoints.
3.  **Consume Preprocessor Artifacts & Fetch NB04 Config:** Through this run, download the fitted preprocessors (Scaler, Imputer) from W&B Artifacts (logged by NB04), then fetch the authoritative `features`, `preprocess`, `cnn_model_params`, and `preprocessing_config` sections from the Notebook 04 W&B run that produced these preprocessors.
4.  **Setup Device:** Set PyTorch device (CPU, CUDA, MPS).
5.  **Load Data & Create DataLoaders:** Instantiate `OASISDataset` for train/validation (with `include_mri=True`), passing downloaded preprocessor paths and fetched NB04 config. Create `DataLoader`s using `pad_collate_fn`.
6.  **Define Model, Loss, Optimizer:** Instantiate `ModularLateFusionLSTM`, `MSELoss`, and `Adam` optimizer.
7.  **Train Model:** Call the `train_model` utility, which handles epochs, training/validation (using `evaluate_model` with `model_type_flag="hybrid"`), W&B logging, checkpointing, and early stopping.
8.  **Evaluate Best Model on Test Set:** Load the best checkpoint, create test `DataLoader`, use `evaluate_model` for test metrics, and log to W&B summary.
9.  **Finalize W&B Run.**
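Step 7 delegates early stopping to `train_model`, driven by `HP['patience']`. Its implementation lives in `src/training_utils.py` and is not shown here; the following is only a minimal sketch of a common patience rule, where the class `EarlyStopping` and the `min_delta` threshold are assumptions:

```python
class EarlyStopping:
    """Patience-based early stopping on validation loss (illustrative sketch)."""

    def __init__(self, patience: int = 10, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.best_epoch = -1
        self.epochs_without_improvement = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch's validation loss; return True when training should stop."""
        if val_loss < self.best_loss - self.min_delta:
            # New best: this is where a "best" checkpoint would be saved
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
        else:
            self.epochs_without_improvement += 1
        return self.epochs_without_improvement >= self.patience


# Toy run with a short patience so the stop is visible
stopper = EarlyStopping(patience=2)
val_losses = [0.50, 0.40, 0.42, 0.41, 0.39]
stopped_at = None
for epoch, loss in enumerate(val_losses):
    if stopper.step(epoch, loss):
        stopped_at = epoch
        break
print(f"stopped at epoch {stopped_at}, best epoch {stopper.best_epoch}")
```

Under this rule, `patience=10` (the value in `HP`) would stop training after ten consecutive epochs without a new best validation loss, and the checkpoint from `best_epoch` is the one step 8 would reload.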

**Input:**
* `config.json`: Main project configuration file.
* Training, validation, and test data splits (Parquet files from Notebook 03), resolved locally via `get_dataset_paths`.
* W&B Artifact Names for fitted preprocessors (Scaler & Imputer - from Notebook 04).
* Directory containing preprocessed MRI scans.
* `src/` modules.

**Output:**
* **Local Files (in run-specific directory, e.g., `notebooks/outputs/07_Train_Hybrid_Model_OASIS2/<run_name>/`):**
    * Trained `ModularLateFusionLSTM` model checkpoints.
* **W&B Run:**
    * Logged HPs (including modality dropout rate, source NB04 config).
    * Metrics, best model artifact (`Training-HybridCNNLSTM-OASIS2-checkpoint:best`), periodic checkpoints, test metrics.

In [None]:
# In: notebooks/07_Train_Hybrid_Model.ipynb
# Purpose: Train the Hybrid (CNN+LSTM) model to predict next CDR score
#          using tabular clinical features and 3D MRI scan data.
#          Loads data processing configurations from the relevant NB04 W&B run.

## Library Imports

In [None]:
# --- Import Libraries ---
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import wandb
import json
from pathlib import Path
import time
import os
import sys 

## 1. Setup: Project Configuration, Paths, Utilities, and Hyperparameters

This section initializes the notebook environment:
* Determines `PROJECT_ROOT` and adds `src/` to `sys.path`.
* Imports all necessary custom utilities from `src/` modules.
* Loads the main project configuration (`base_config`) from `config.json`.
* Defines dataset and notebook-specific identifiers.
* **Uses `get_dataset_paths` to resolve paths for input data splits (train, validation, test from Notebook 03) and the general MRI data directory.** Paths to preprocessor `.joblib` files will be obtained later by downloading W&B artifacts.
* Defines initial training hyperparameters (`HP`) for the `ModularLateFusionLSTM` model. Key parameters like `tabular_input_size` and CNN configurations will be confirmed or updated after fetching the definitive configuration from Notebook 04.
* Defines the locator key from `config.json` for this notebook's local output directory (for model checkpoints).

In [None]:
# --- Project Setup, Configuration Loading, Utility Imports & Hyperparameters ---
print("--- Initializing Project Setup, Configuration, and HPs for NB07 (Hybrid Model) ---")

# Initialize
PROJECT_ROOT = None
base_config = {} 

try:
    current_notebook_path = Path.cwd() 
    potential_project_root = current_notebook_path.parent 
    if (potential_project_root / "src").is_dir() and (potential_project_root / "config.json").is_file():
        PROJECT_ROOT = potential_project_root
    else: 
        PROJECT_ROOT = current_notebook_path
    if not (PROJECT_ROOT / "src").is_dir() or not (PROJECT_ROOT / "config.json").is_file():
        raise FileNotFoundError(f"Could not find 'src' or 'config.json'. PROJECT_ROOT: {PROJECT_ROOT}")
    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
    print(f"PROJECT_ROOT: {PROJECT_ROOT}, added to sys.path.")

    # Import all necessary custom utilities
    from src.wandb_utils import initialize_wandb_run, load_model_from_wandb_artifact
    from src.paths_utils import get_dataset_paths, get_notebook_run_output_dir
    from src.datasets import OASISDataset, pad_collate_fn
    from src.models import ModularLateFusionLSTM # Import the HYBRID model
    from src.training_utils import train_model
    from src.evaluation_utils import evaluate_model
    print("Successfully imported all required custom utilities and classes.")

except Exception as e_setup:
    print(f"CRITICAL ERROR during initial setup or imports: {e_setup}")
    # exit()

# --- Load Main Project Configuration ---
print("\n--- Loading Main Project Configuration ---")
try:
    if PROJECT_ROOT is None: raise ValueError("PROJECT_ROOT not set.")
    CONFIG_PATH_MAIN = PROJECT_ROOT / 'config.json'
    with open(CONFIG_PATH_MAIN, 'r', encoding='utf-8') as f:
        base_config = json.load(f)
    print(f"Main project config loaded from: {CONFIG_PATH_MAIN}")
except Exception as e_cfg:
    print(f"CRITICAL ERROR loading main config.json: {e_cfg}")
    # exit() 

# --- Define Dataset, Notebook Specifics, and Resolve Data Split Paths ---
DATASET_IDENTIFIER = "oasis2" 
NOTEBOOK_MODULE_NAME = "07_Train_Hybrid_Model"
# Key from config.json locators for this notebook's output directory
NB07_OUTPUT_LOCATOR_KEY = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})\
                                     .get("train_hybrid_subdir_nb07_key", "train_hybrid_subdir_nb07")
                                     # Example key in config: "train_hybrid_subdir_nb07": "07_Train_Hybrid_Outputs"

TRAIN_DATA_PATH = None
VAL_DATA_PATH = None
TEST_DATA_PATH = None
MRI_DATA_DIR = None # For OASISDataset

try:
    if not base_config: raise ValueError("base_config is empty.")
    
    train_stage_paths = get_dataset_paths(
        PROJECT_ROOT, base_config, DATASET_IDENTIFIER, stage="training"
    )
    TRAIN_DATA_PATH = train_stage_paths.get('train_data_parquet')
    VAL_DATA_PATH = train_stage_paths.get('val_data_parquet')
    MRI_DATA_DIR = train_stage_paths.get('mri_data_dir') # MRI dir is needed for hybrid
    
    test_stage_paths = get_dataset_paths(
        PROJECT_ROOT, base_config, DATASET_IDENTIFIER, stage="testing"
    )
    TEST_DATA_PATH = test_stage_paths.get('test_data_parquet')
    
    if not all([TRAIN_DATA_PATH, VAL_DATA_PATH, TEST_DATA_PATH, MRI_DATA_DIR]):
        raise ValueError("Failed to resolve one or more critical data/MRI paths via get_dataset_paths.")

    print(f"\nKey data input paths for Notebook 07 ({DATASET_IDENTIFIER}):")
    print(f"  Training Data Parquet (from NB03): {TRAIN_DATA_PATH}")
    print(f"  Validation Data Parquet (from NB03): {VAL_DATA_PATH}")
    print(f"  Test Data Parquet (from NB03): {TEST_DATA_PATH}")
    print(f"  MRI Data Directory: {MRI_DATA_DIR}")
    
    for p_name, p_obj in [("Training Data", TRAIN_DATA_PATH), ("Validation Data", VAL_DATA_PATH), ("Test Data", TEST_DATA_PATH)]:
        if not p_obj.is_file(): raise FileNotFoundError(f"CRITICAL: {p_name} parquet file not found at: {p_obj}")
    if not MRI_DATA_DIR.is_dir(): raise FileNotFoundError(f"CRITICAL: MRI Data Directory not found at: {MRI_DATA_DIR}")
    print("All critical input data and MRI paths for NB07 verified.")

except (KeyError, ValueError, FileNotFoundError) as e_paths_nb07:
    print(f"CRITICAL ERROR during path setup for NB07: {e_paths_nb07}")
    # exit()
except Exception as e_general_nb07_setup:
    print(f"CRITICAL ERROR during setup for NB07: {e_general_nb07_setup}")
    # exit()

# --- Define Training Hyperparameters (HP) for Hybrid Model ---
print("\n--- Defining Training Hyperparameters (HP) for Hybrid Model ---")
HP = {
    'model_type': 'HybridCNNLSTM', 
    'dataset_identifier': DATASET_IDENTIFIER,
    'batch_size': 1, # Might need to be smaller than baseline due to MRI memory
    'learning_rate': 1e-4,     
    'epochs': 2, # Reduced for local testing, increase for full runs
    
    # LSTM parameters (can be shared or specific if fetched from NB04 config later)
    'lstm_hidden_size': 128, # General LSTM hidden size, can be overridden for mri/tabular streams
    'mri_lstm_hidden_size': 128, # Specific for MRI stream
    'tabular_lstm_hidden_size': 128, # Specific for Tabular stream
    'num_lstm_layers': 2,      
    'lstm_dropout_prob': 0.3,
    'modality_dropout_rate': 0.0, # Set to 0.0 for no MD, or e.g., 0.1, 0.2 for MD
    
    'dataloader_num_workers': 0, 
    'patience': 10,            
    'save_checkpoint_every_n_epochs': 5, 
    'random_seed': 42,
    
    # These will be updated/confirmed from NB04's config:
    'tabular_input_size': None, 
    'cnn_input_channels': None, # Will come from cnn_model_params in NB04 config
    'cnn_output_features': None # Will come from cnn_model_params in NB04 config
}
# Seed for reproducibility
np.random.seed(HP['random_seed'])
torch.manual_seed(HP['random_seed'])
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(HP['random_seed'])
print("Training hyperparameters (HP) for Hybrid model defined. Key sizes will be set after fetching NB04 config.")

## 2–3. Initialize W&B Run, Fetch Preprocessor Artifacts & Definitive Configuration from Notebook 04

This section follows the same procedure as Notebook 06:
1.  **Initialize W&B Run for Notebook 07:** A new W&B run is started for this `ModularLateFusionLSTM` training experiment using `initialize_wandb_run`.
2.  **Define Output Directory:** A unique local directory for this run's model checkpoints is created.
3.  **Consume Preprocessor Artifacts:** Fitted `StandardScaler` and `SimpleImputer` W&B Artifacts (from Notebook 04) are downloaded via `run.use_artifact()`.
4.  **Fetch Authoritative NB04 Configuration:** The W&B Run that produced these preprocessors is identified using `artifact.logged_by()`. Its configuration, containing the definitive `features` (time-varying, static), `preprocess` (imputation/scaling columns, strategies), `cnn_model_params` (for MRI input shape and CNN output size), and `preprocessing_config` (for MRI file suffix) is fetched. This becomes `config_for_dataset`.
5.  **Update `HP`:** Key hyperparameters like `HP['tabular_input_size']`, `HP['cnn_input_channels']`, and `HP['cnn_output_features']` are dynamically set or confirmed based on the fetched `config_for_dataset`.
6.  All source information is logged to the current Notebook 07 W&B run's configuration.
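The artifact-name derivation and the `HP` size update in the next cell are plain string and dictionary logic, so they can be factored into small pure functions and checked without a W&B connection. A sketch, where the function names are hypothetical and the example config values are illustrative rather than taken from an actual NB04 run:

```python
def preprocessor_artifact_names(scaling_strategy: str, imputation_strategy: str,
                                dataset_identifier: str) -> tuple[str, str]:
    """Mirror the logical-name scheme used below for the NB04 preprocessor artifacts."""
    # e.g. "standard_scaler" -> "standard"
    scaler_tag = scaling_strategy.lower().replace("scaler", "").replace("_", "")
    scaler_name = f"scaler_{scaler_tag}_{dataset_identifier}"
    imputer_name = f"simple_imputer_{imputation_strategy.lower()}_{dataset_identifier}"
    return scaler_name, imputer_name


def model_sizes_from_nb04_config(nb04_config: dict) -> dict:
    """Derive the size-related HP entries from a fetched NB04 run config."""
    features = nb04_config.get("features", {})
    cnn_params = nb04_config.get("cnn_model_params", {})
    return {
        # One tabular input column per time-varying or static feature
        "tabular_input_size": len(features.get("time_varying", [])) + len(features.get("static", [])),
        # input_shape is assumed to be [channels, depth, height, width]
        "cnn_input_channels": cnn_params.get("input_shape", [1, 0, 0, 0])[0],
        "cnn_output_features": cnn_params.get("output_features", 128),
    }


names = preprocessor_artifact_names("standard_scaler", "median", "oasis2")
example_nb04_config = {  # illustrative values only
    "features": {"time_varying": ["Age", "MMSE", "nWBV"], "static": ["M/F", "EDUC"]},
    "cnn_model_params": {"input_shape": [1, 64, 64, 64], "output_features": 256},
}
sizes = model_sizes_from_nb04_config(example_nb04_config)
print(names, sizes)
```

Keeping this logic in pure functions makes it easy to unit-test that NB07 looks up the same artifact names NB04 logged.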

In [None]:
# --- Initialize W&B Run, Fetch Preprocessor Artifacts & NB04 Config ---
print("\n--- Initializing W&B Run for NB07 & Fetching NB04 Preprocessing Setup ---")

run = None
run_output_dir_for_checkpoints = None
SCALER_PATH_FROM_ARTIFACT = None
IMPUTER_PATH_FROM_ARTIFACT = None
config_for_dataset = {} 
source_nb04_run_id_logged = "N/A"

# --- 1. Initialize W&B Run for THIS Notebook 07 execution ---
temp_hp_for_init = HP.copy() # Log initial HP; it will be updated with input_size etc.

nb_number_prefix_nb07 = NOTEBOOK_MODULE_NAME.split('_')[0] if '_' in NOTEBOOK_MODULE_NAME else "NB"
job_specific_type_nb07 = f"{nb_number_prefix_nb07}-HybridCNNLSTM-{DATASET_IDENTIFIER}"
mod_drop_suffix = f"MD{HP.get('modality_dropout_rate',0.0):.1f}" if HP.get('modality_dropout_rate',0.0) > 0 else "NoMD"

custom_elements_for_name_nb07 = [
    nb_number_prefix_nb07, DATASET_IDENTIFIER.upper(), "Hybrid", mod_drop_suffix,
    f"h{HP.get('mri_lstm_hidden_size', HP['lstm_hidden_size'])}", # Use specific or general
    f"nl{HP['num_lstm_layers']}", f"dp{HP['lstm_dropout_prob']:.1f}",
    f"lr{HP['learning_rate']:.0e}", f"bs{HP['batch_size']}"
]
run = initialize_wandb_run(
    base_project_config=base_config,
    job_group="Training",
    job_specific_type=job_specific_type_nb07,
    run_specific_config=temp_hp_for_init, 
    custom_run_name_elements=custom_elements_for_name_nb07,
    notes=f"Training ModularLateFusionLSTM on {DATASET_IDENTIFIER.upper()} (Modality Dropout: {HP.get('modality_dropout_rate',0.0)})."
)

# --- 2. Define Run-Specific Output Directory for Local Checkpoints ---
if run:
    print(f"W&B run '{run.name}' (Job Type: '{run.job_type}') initialized. View at: {run.url}")
    run_output_dir_for_checkpoints = get_notebook_run_output_dir(
        PROJECT_ROOT, base_config, NB07_OUTPUT_LOCATOR_KEY, run, DATASET_IDENTIFIER # Use NB07's locator key
    )
    print(f"Model checkpoints for this run will be saved locally to: {run_output_dir_for_checkpoints}")
    run.config.update({"run_outputs/local_checkpoint_dir": str(run_output_dir_for_checkpoints)}, allow_val_change=True)
else:
    print("W&B run initialization failed. Proceeding with local fallback for outputs.")
    default_nb_locator_nb07 = f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}_local_outputs"
    run_output_dir_for_checkpoints = get_notebook_run_output_dir(
        PROJECT_ROOT, base_config, NB07_OUTPUT_LOCATOR_KEY if base_config else default_nb_locator_nb07, 
        None, DATASET_IDENTIFIER
    )
    print(f"Model checkpoints (local fallback) will be saved to: {run_output_dir_for_checkpoints}")

if not isinstance(run_output_dir_for_checkpoints, Path):
    run_output_dir_for_checkpoints = Path(run_output_dir_for_checkpoints)

# --- 3. Fetch Preprocessor Artifacts & NB04 Config (requires active NB07 W&B run) ---
if run: 
    try:
        # Define expected preprocessor artifact names (must match what NB04 logged)
        cfg_preprocessing = base_config.get('preprocessing_config', {}) # Main project config
        cfg_locators = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {}) # Main project config

        scaling_strategy_name = cfg_preprocessing.get('scaling_strategy', 'standard_scaler')
        imputation_strategy_name = cfg_preprocessing.get('imputation_strategy', 'median')

        SCALER_ARTIFACT_LOGICAL_NAME = f"scaler_{scaling_strategy_name.lower().replace('scaler','').replace('_','')}_{DATASET_IDENTIFIER}"
        IMPUTER_ARTIFACT_LOGICAL_NAME = f"simple_imputer_{imputation_strategy_name.lower()}_{DATASET_IDENTIFIER}"
        
        PREPROCESSOR_ARTIFACT_TYPE = f"preprocessor_{DATASET_IDENTIFIER}"
        PREPROCESSOR_ARTIFACT_VERSION = "latest"

        # --- Download SCALER Artifact ---
        scaler_artifact_full_wandb_path = f"{base_config['wandb']['entity']}/{base_config['wandb']['project_name']}/{SCALER_ARTIFACT_LOGICAL_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}"
        print(f"  Attempting to use Scaler artifact: {scaler_artifact_full_wandb_path}")
        scaler_artifact = run.use_artifact(scaler_artifact_full_wandb_path, type=PREPROCESSOR_ARTIFACT_TYPE)
        scaler_artifact_dir = Path(scaler_artifact.download())
        scaler_filename_in_artifact = cfg_locators.get('scaler_fname_pattern', '{scaling_strategy}_{dataset_identifier}.joblib').format(
            scaling_strategy=scaling_strategy_name.lower(), dataset_identifier=DATASET_IDENTIFIER)
        SCALER_PATH_FROM_ARTIFACT = scaler_artifact_dir / scaler_filename_in_artifact
        if not SCALER_PATH_FROM_ARTIFACT.is_file(): # Fallback scan
            found_scalers = list(scaler_artifact_dir.glob("*.joblib"))
            if found_scalers: SCALER_PATH_FROM_ARTIFACT = found_scalers[0]
            else: raise FileNotFoundError(f"Scaler .joblib ('{scaler_filename_in_artifact}') not found in artifact {scaler_artifact.name}")
        print(f"  Scaler artifact '{scaler_artifact.name}' downloaded. Local path: {SCALER_PATH_FROM_ARTIFACT}")

        # --- Download IMPUTER Artifact ---
        imputer_artifact_full_wandb_path = f"{base_config['wandb']['entity']}/{base_config['wandb']['project_name']}/{IMPUTER_ARTIFACT_LOGICAL_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}"
        print(f"  Attempting to use Imputer artifact: {imputer_artifact_full_wandb_path}")
        imputer_artifact = run.use_artifact(imputer_artifact_full_wandb_path, type=PREPROCESSOR_ARTIFACT_TYPE)
        imputer_artifact_dir = Path(imputer_artifact.download())
        imputer_fname_pattern = cfg_locators.get('imputer_fname_pattern', 'simple_imputer_{imputation_strategy}_{dataset_identifier}.joblib')
        imputer_filename_in_artifact = imputer_fname_pattern.format(
            imputation_strategy=imputation_strategy_name.lower(), dataset_identifier=DATASET_IDENTIFIER)
        IMPUTER_PATH_FROM_ARTIFACT = imputer_artifact_dir / imputer_filename_in_artifact
        if not IMPUTER_PATH_FROM_ARTIFACT.is_file(): # Fallback scan
            found_imputers = list(imputer_artifact_dir.glob("*.joblib"))
            if found_imputers: IMPUTER_PATH_FROM_ARTIFACT = found_imputers[0]
            else: raise FileNotFoundError(f"Imputer .joblib ('{imputer_filename_in_artifact}') not found in artifact {imputer_artifact.name}")
        print(f"  Imputer artifact '{imputer_artifact.name}' downloaded. Local path: {IMPUTER_PATH_FROM_ARTIFACT}")

        # --- Fetch Full Configuration from the NB04 Run that Produced these Preprocessors ---
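        # Identify the NB04 run that logged the imputer; its config becomes the dataset config.
        nb04_producer_run = imputer_artifact.logged_by()
        scaler_producer_run = scaler_artifact.logged_by()
        if nb04_producer_run and scaler_producer_run and scaler_producer_run.id != nb04_producer_run.id:
            print(f"  WARNING: Scaler was logged by run {scaler_producer_run.id} but imputer by run "
                  f"{nb04_producer_run.id}; the preprocessors may come from different NB04 executions.")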
        if nb04_producer_run:
            source_nb04_run_id_logged = nb04_producer_run.id
            nb04_run_name = nb04_producer_run.name
            print(f"  Fetching config from NB04 producer run: '{nb04_run_name}' (ID: {source_nb04_run_id_logged})")
            config_from_nb04_run = dict(nb04_producer_run.config)
            if 'features' not in config_from_nb04_run or 'preprocess' not in config_from_nb04_run \
                or 'cnn_model_params' not in config_from_nb04_run \
                or 'preprocessing_config' not in config_from_nb04_run: # For MRI suffix
                raise ValueError("Fetched NB04 config missing critical sections: 'features', 'preprocess', 'cnn_model_params', or 'preprocessing_config'.")
            config_for_dataset = config_from_nb04_run 
            print("  Definitive config for OASISDataset fetched from NB04 run.")

            # --- Update HP with input_size and CNN params from fetched config_for_dataset ---
            time_varying_cfg = config_for_dataset.get('features', {}).get('time_varying', [])
            static_cfg = config_for_dataset.get('features', {}).get('static', [])
            HP['tabular_input_size'] = len(time_varying_cfg) + len(static_cfg) # Specific for hybrid model
            
            cnn_params_from_nb04_config = config_for_dataset.get('cnn_model_params', {})
            HP['cnn_input_channels'] = cnn_params_from_nb04_config.get('input_shape', [1,0,0,0])[0] # Default from Simple3DCNN
            HP['cnn_output_features'] = cnn_params_from_nb04_config.get('output_features', 128)   # Default from Simple3DCNN
            
            print(f"  Updated HP: 'tabular_input_size' to {HP['tabular_input_size']}")
            print(f"  Updated HP: 'cnn_input_channels' to {HP['cnn_input_channels']}")
            print(f"  Updated HP: 'cnn_output_features' to {HP['cnn_output_features']}")
            if HP['tabular_input_size'] == 0 or HP['cnn_input_channels'] == 0:
                print("  WARNING: Calculated tabular_input_size or cnn_input_channels is 0. Check NB04 config.")

            # Update current NB07 run's config with all this provenance and final HPs
            current_run_config_update = {
                "source_config_from_nb04/producer_run_id": source_nb04_run_id_logged,
                "source_config_from_nb04/producer_run_name": nb04_run_name,
                "source_config_from_nb04/scaler_artifact_used": scaler_artifact.name,
                "source_config_from_nb04/imputer_artifact_used": imputer_artifact.name,
                "dataset_config_used/features": config_for_dataset.get('features',{}),
                "dataset_config_used/preprocess": config_for_dataset.get('preprocess',{}),
                "dataset_config_used/cnn_model_params": config_for_dataset.get('cnn_model_params',{}),
                "dataset_config_used/preprocessing_config_mri": config_for_dataset.get('preprocessing_config',{}),
            }
            current_run_config_update.update(HP) # Add all HPs to the run config
            run.config.update(current_run_config_update, allow_val_change=True)
            print("  NB07 W&B run config updated with source NB04 info, final dataset config, and all HPs.")
        else:
            raise ConnectionError("Could not retrieve producer run from preprocessor artifact.")
    except Exception as e_fetch_nb04_all:
        print(f"CRITICAL ERROR fetching preprocessor artifacts or config from NB04: {e_fetch_nb04_all}")
        if run: run.finish(exit_code=1)
        # exit()
else: # W&B run for NB07 failed to initialize
    print("CRITICAL ERROR: W&B run for NB07 not initialized. Cannot fetch artifacts or proceed.")

# Final check
if not all([SCALER_PATH_FROM_ARTIFACT, IMPUTER_PATH_FROM_ARTIFACT, 
            config_for_dataset.get('features'), HP.get('tabular_input_size') is not None,
            HP.get('cnn_input_channels') is not None, HP.get('cnn_output_features') is not None]):
    print("CRITICAL ERROR: Essential components for model training are not properly set up after NB04 config/artifact fetch.")
    # exit()

## 4. Setup Device

Detect and configure the appropriate PyTorch device (CUDA GPU, MPS for Apple Silicon, or CPU) for model training and data handling. The selected device is logged to the current W&B run's configuration.

In [None]:
# --- Setup Device (GPU/CPU/MPS) ---
print("\n--- Setting up PyTorch Device ---")

# This logic prioritizes MPS if available, then CUDA, then CPU.
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    print("Using MPS (Apple Silicon GPU).")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("CUDA and MPS not available. Using CPU.")

# Log the determined device to the W&B run config for this training session
if run: # 'run' is the W&B run object for this NB07 execution
    run.config.update({'training_details/device_used': str(device)}, allow_val_change=True)
    print(f"Device '{str(device)}' logged to W&B run config.")
else:
    print(f"Device set to '{str(device)}' (W&B run not active for logging).")

## 5. Load Data & Create DataLoaders for Hybrid Model

This section instantiates the custom `OASISDataset` for the training and validation splits (`cohort_train.parquet` and `cohort_validation.parquet`). Three dataset inputs are critical:

* **Preprocessor Paths:** The paths to the *fitted* `StandardScaler` and `SimpleImputer` (obtained by downloading W&B Artifacts logged by Notebook 04) are provided to `OASISDataset`.
* **Feature & Preprocessing Configuration:** The `config_for_dataset` dictionary (fetched from the W&B configuration of the Notebook 04 run that produced the preprocessors) is passed to `OASISDataset`. This ensures it uses the authoritative lists of time-varying/static features and applies the corresponding imputation/scaling.
* **MRI Data:** For this hybrid model, `include_mri` is explicitly set to `True`, and the path to the preprocessed MRI scans (`MRI_DATA_DIR`) is provided.
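How `OASISDataset` applies the fitted NB04 preprocessors internally is not shown in this notebook. A minimal sketch, assuming imputation runs before scaling on a NumPy feature matrix whose column order matches the one used at fit time; the helper `apply_fitted_preprocessors` and the toy file names are hypothetical:

```python
import tempfile
from pathlib import Path

import joblib
import numpy as np
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler


def apply_fitted_preprocessors(features: np.ndarray, imputer_path: Path, scaler_path: Path) -> np.ndarray:
    """Load fitted preprocessors from .joblib files and transform a feature matrix.

    Only transform() is called, so statistics learned on the training split are
    reused unchanged for validation and test data (no refitting, no leakage).
    """
    imputer: SimpleImputer = joblib.load(imputer_path)
    scaler: StandardScaler = joblib.load(scaler_path)
    return scaler.transform(imputer.transform(features))


# Self-contained demo: fit on toy "training" data, save, reload, transform.
train = np.array([[70.0, 28.0], [75.0, np.nan], [80.0, 24.0]])
with tempfile.TemporaryDirectory() as tmp:
    imputer_file = Path(tmp) / "simple_imputer_median_demo.joblib"
    scaler_file = Path(tmp) / "standard_scaler_demo.joblib"
    fitted_imputer = SimpleImputer(strategy="median").fit(train)
    joblib.dump(fitted_imputer, imputer_file)
    joblib.dump(StandardScaler().fit(fitted_imputer.transform(train)), scaler_file)
    # Missing value is imputed with the training median, then standardized
    transformed = apply_fitted_preprocessors(np.array([[75.0, np.nan]]), imputer_file, scaler_file)
print(transformed)
```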

The instantiated `OASISDataset` objects are then wrapped in PyTorch `DataLoader`s, which handle batching and use the custom `pad_collate_fn` to manage variable sequence lengths and the dual-modality input. The model input sizes (tabular and CNN-derived) are also confirmed here.
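The actual `pad_collate_fn` in `src/datasets.py` is not shown here (the code comments indicate it handles 5-item tuples). As a simplified illustration of padding variable-length visit sequences for two modalities, here is a hypothetical collate function that assumes each item is a 3-tuple `(tabular_seq [T, F], mri_seq [T, C, D, H, W], target)`:

```python
import torch
from torch.nn.utils.rnn import pad_sequence


def sketch_pad_collate(batch):
    """Pad variable-length visit sequences for a two-modality batch (hypothetical format).

    Returns zero-padded tensors plus the true sequence lengths, which a model can
    use to pick each subject's last real time step or to pack sequences.
    """
    tabular_seqs, mri_seqs, targets = zip(*batch)
    lengths = torch.tensor([seq.shape[0] for seq in tabular_seqs], dtype=torch.long)
    tabular_padded = pad_sequence(list(tabular_seqs), batch_first=True)  # [B, T_max, F]
    mri_padded = pad_sequence(list(mri_seqs), batch_first=True)          # [B, T_max, C, D, H, W]
    return tabular_padded, mri_padded, lengths, torch.stack(targets)


# Two subjects with 3 and 1 visits; tiny 4x4x4 volumes stand in for real MRI scans
item_a = (torch.randn(3, 5), torch.randn(3, 1, 4, 4, 4), torch.tensor(0.5))
item_b = (torch.randn(1, 5), torch.randn(1, 1, 4, 4, 4), torch.tensor(1.0))
tab, mri, lengths, y = sketch_pad_collate([item_a, item_b])
print(tab.shape, mri.shape, lengths.tolist(), y.shape)
```

Passed as `collate_fn=sketch_pad_collate` to a `DataLoader`, it would produce one padded batch per step. Padding whole MRI volumes is memory-heavy, which is one reason the notebook keeps `batch_size` small.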

In [None]:
# --- Instantiate Datasets and DataLoaders for Hybrid Model Training ---
print("\n--- Loading Data and Creating DataLoaders for Hybrid Model Training ---")

# Initialize dataset and loader variables
train_dataset: OASISDataset | None = None
val_dataset: OASISDataset | None = None
train_loader: DataLoader | None = None
val_loader: DataLoader | None = None

# --- Prerequisite Variable Check ---
# These are expected from preceding cells:
#   TRAIN_DATA_PATH, VAL_DATA_PATH (Paths to .parquet files)
#   SCALER_PATH_FROM_ARTIFACT, IMPUTER_PATH_FROM_ARTIFACT (Paths to downloaded .joblib files)
#   config_for_dataset (Authoritative config from NB04 run for OASISDataset)
#   MRI_DATA_DIR (Path to preprocessed MRI scans)
#   HP (Hyperparameter dictionary for batch_size, num_workers, and model input sizes)
#   pad_collate_fn, OASISDataset (Imported classes)

required_vars_for_hybrid_loading = {
    'TRAIN_DATA_PATH': TRAIN_DATA_PATH, 'VAL_DATA_PATH': VAL_DATA_PATH,
    'SCALER_PATH_FROM_ARTIFACT': SCALER_PATH_FROM_ARTIFACT, 
    'IMPUTER_PATH_FROM_ARTIFACT': IMPUTER_PATH_FROM_ARTIFACT,
    'config_for_dataset': config_for_dataset, 
    'MRI_DATA_DIR': MRI_DATA_DIR, 'HP': HP
}

proceed_with_hybrid_loading = True
for var_name, var_value in required_vars_for_hybrid_loading.items():
    if var_value is None:
        print(f"  CRITICAL ERROR: Prerequisite variable '{var_name}' for DataLoaders is None.")
        proceed_with_hybrid_loading = False
    if var_name == 'config_for_dataset' and \
       (not var_value or 'features' not in var_value or 'preprocess' not in var_value or \
        'cnn_model_params' not in var_value or 'preprocessing_config' not in var_value): # MRI config also needed
        print(f"  CRITICAL ERROR: 'config_for_dataset' is empty or missing essential sections "
              "('features', 'preprocess', 'cnn_model_params', 'preprocessing_config').")
        proceed_with_hybrid_loading = False
    if var_name == 'HP' and (not var_value or HP.get('tabular_input_size') is None or \
                             HP.get('cnn_input_channels') is None or HP.get('cnn_output_features') is None):
        print(f"  CRITICAL ERROR: 'HP' dictionary missing key model input sizes "
              "('tabular_input_size', 'cnn_input_channels', 'cnn_output_features'). "
              "Ensure these were set after fetching NB04 config.")
        proceed_with_hybrid_loading = False

if not proceed_with_hybrid_loading:
    print("Halting DataLoader creation due to missing or invalid prerequisites.")
    if run: run.finish(exit_code=1)
    # exit()
else:
    try:
        print(f"\nInstantiating training dataset for {DATASET_IDENTIFIER.upper()} (Hybrid - MRI Included)...")
        train_dataset = OASISDataset(
            data_parquet_path=TRAIN_DATA_PATH,
            scaler_path=SCALER_PATH_FROM_ARTIFACT,
            imputer_path=IMPUTER_PATH_FROM_ARTIFACT,
            config=config_for_dataset, # Authoritative config from NB04 W&B run
            mri_data_dir=MRI_DATA_DIR,
            include_mri=True # Explicitly True for hybrid model
        )
        num_train_subjects = len(train_dataset)
        print(f"  Train dataset (hybrid) created. Number of subjects (sequences): {num_train_subjects}")

        print(f"\nInstantiating validation dataset for {DATASET_IDENTIFIER.upper()} (Hybrid - MRI Included)...")
        val_dataset = OASISDataset(
            data_parquet_path=VAL_DATA_PATH,
            scaler_path=SCALER_PATH_FROM_ARTIFACT, 
            imputer_path=IMPUTER_PATH_FROM_ARTIFACT,
            config=config_for_dataset, 
            mri_data_dir=MRI_DATA_DIR,
            include_mri=True 
        )
        num_val_subjects = len(val_dataset)
        print(f"  Validation dataset (hybrid) created. Number of subjects (sequences): {num_val_subjects}")

        # --- Create DataLoaders ---
        print("\nCreating DataLoaders for Hybrid Model...")
        num_workers = HP.get('dataloader_num_workers', 0)
        train_loader = DataLoader(
            train_dataset,
            batch_size=HP['batch_size'],
            shuffle=True,
            collate_fn=pad_collate_fn, # Handles 5-item tuples for hybrid
            num_workers=num_workers,
            persistent_workers=num_workers > 0  # Only valid when worker processes are used
        )

        val_loader = DataLoader(
            val_dataset,
            batch_size=HP['batch_size'],
            shuffle=False,
            collate_fn=pad_collate_fn, # Handles 5-item tuples for hybrid
            num_workers=num_workers,
            persistent_workers=num_workers > 0
        )
        print("Train and Validation DataLoaders for hybrid model created.")
        print(f"  Number of batches in train_loader: {len(train_loader)}")
        print(f"  Number of batches in val_loader: {len(val_loader)}")

        # Log dataset info to W&B for this training run
        if run:
            run.log({
                f'dataset_info_{DATASET_IDENTIFIER}/train_subjects_hybrid': num_train_subjects,
                f'dataset_info_{DATASET_IDENTIFIER}/val_subjects_hybrid': num_val_subjects,
                f'dataset_info_{DATASET_IDENTIFIER}/tabular_input_size_for_model': HP['tabular_input_size'],
                f'dataset_info_{DATASET_IDENTIFIER}/cnn_output_features_as_mri_lstm_input': HP['cnn_output_features'],
                f'dataset_info_{DATASET_IDENTIFIER}/batch_size_hybrid': HP['batch_size']
            })
        
        print(f"\nInput sizes for ModularLateFusionLSTM:")
        print(f"  Tabular Input Size (num_tabular_features): {HP['tabular_input_size']}")
        print(f"  CNN Input Channels: {HP['cnn_input_channels']}")
        print(f"  CNN Output Features (becomes MRI LSTM input_size): {HP['cnn_output_features']}")
        if HP['tabular_input_size'] == 0 or HP['cnn_input_channels'] == 0 or HP['cnn_output_features'] == 0:
            print("WARNING: One of the key input sizes for the model is 0. Model instantiation might fail or behave unexpectedly.")

    except FileNotFoundError as e_fnf_ds_nb07:
        print(f"CRITICAL ERROR during OASISDataset instantiation in NB07: File not found - {e_fnf_ds_nb07}")
        if run: run.finish(exit_code=1)
        # exit()
    except KeyError as e_key_ds_nb07:
        print(f"CRITICAL ERROR during OASISDataset instantiation in NB07: Missing key in 'config_for_dataset' - {e_key_ds_nb07}")
        if run: run.finish(exit_code=1)
        # exit()
    except Exception as e_ds_init_nb07:
        print(f"An unexpected CRITICAL ERROR occurred during OASISDataset or DataLoader instantiation in NB07: {e_ds_init_nb07}")
        import traceback
        traceback.print_exc()
        if run: run.finish(exit_code=1)
        # exit()

# Final check for subsequent cells
if train_loader is None or val_loader is None:
    print("CRITICAL ERROR: DataLoaders were not successfully created. Cannot proceed with model training.")
    # exit()

## 6. Define Hybrid Model, Loss Function, and Optimizer

This section sets up the components for training the `ModularLateFusionLSTM` model:
1.  **Model Instantiation:** The `ModularLateFusionLSTM` (imported from `src/models.py`) is instantiated from the `HP` dictionary, including `tabular_input_size` and the CNN parameters (`cnn_input_channels`, `cnn_output_features`) that were determined dynamically from the Notebook 04 configuration. It also takes the LSTM-specific parameters (hidden sizes, number of layers, dropout) and the `modality_dropout_rate`. The model is then moved to the configured PyTorch `device`.
2.  **Loss Function:** `nn.MSELoss` is used for this regression task.
3.  **Optimizer:** `optim.Adam` is selected.
4.  **W&B Model Watching:** `wandb.watch()` is called to monitor model gradients and parameters.
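The notebook does not show how `ModularLateFusionLSTM` applies `modality_dropout_rate` internally. As a hypothetical illustration only, the pure-Python sketch below implements one common form of modality dropout: during training, with probability `rate`, the embedding of one modality (MRI or tabular, chosen at random) is zeroed before late fusion by concatenation, so the model cannot rely exclusively on either input. The function names and the "drop exactly one modality" rule are assumptions, not the project's implementation. Because dropout is applied only when `training` is true, evaluation is deterministic, which is consistent with the test-set model being built with `modality_dropout_rate=0.0`.

```python
import random
from typing import List, Optional, Tuple


def apply_modality_dropout(
    mri_features: List[float],
    tabular_features: List[float],
    rate: float,
    training: bool,
    rng: Optional[random.Random] = None,
) -> Tuple[List[float], List[float]]:
    """Hypothetical modality dropout.

    With probability `rate` (training only), zero out exactly one modality's
    embedding, chosen uniformly at random. Both are never dropped together,
    so the fused vector always carries some signal.
    """
    rng = rng or random.Random()
    # No dropout at evaluation time, when disabled, or when the draw says "keep both"
    if not training or rate <= 0.0 or rng.random() >= rate:
        return mri_features, tabular_features
    if rng.random() < 0.5:
        return [0.0] * len(mri_features), tabular_features
    return mri_features, [0.0] * len(tabular_features)


def late_fusion(mri_features: List[float], tabular_features: List[float]) -> List[float]:
    """Late fusion by concatenating per-modality embeddings before the regression head."""
    return mri_features + tabular_features
```

With `rate=0.0` or `training=False` the inputs pass through unchanged, mirroring why the test model disables modality dropout.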

In [None]:
# --- Define Hybrid Model, Loss Function, and Optimizer ---
print("\n--- Defining Hybrid Model (ModularLateFusionLSTM), Loss Function, and Optimizer ---")

# These variables are expected from previous cells:
# HP (dictionary with all necessary HPs for ModularLateFusionLSTM)
# device (torch.device)
# ModularLateFusionLSTM (imported model class)
# run (active W&B run object for NB07, or None)

model: ModularLateFusionLSTM | None = None # Type hint for clarity
criterion: nn.Module | None = None
optimizer: optim.Optimizer | None = None

# Ensure the HP dictionary and critical model parameters are available.
# Per-modality LSTM hidden sizes may fall back to a shared 'lstm_hidden_size', so they are checked separately.
required_hp_keys_hybrid = [
    'tabular_input_size', 'cnn_input_channels', 'cnn_output_features',
    'num_lstm_layers', 'lstm_dropout_prob', 'learning_rate'
    # 'modality_dropout_rate' is optional (defaults to 0.0 in model __init__)
]
if 'HP' not in locals() or not isinstance(HP, dict):
    missing_keys_str = ['HP (dictionary not defined)']
else:
    missing_keys_str = [key for key in required_hp_keys_hybrid if HP.get(key) is None]
    for hidden_size_key in ('mri_lstm_hidden_size', 'tabular_lstm_hidden_size'):
        if HP.get(hidden_size_key, HP.get('lstm_hidden_size')) is None:
            missing_keys_str.append(f"{hidden_size_key} (or lstm_hidden_size)")

if missing_keys_str:
    print(f"CRITICAL ERROR: HP dictionary is not defined or is missing critical keys for ModularLateFusionLSTM: {missing_keys_str}")
    print("Ensure these were correctly set after fetching the NB04 config and defining HPs for NB07.")
    if run: run.finish(exit_code=1)
    # exit()
else:
    try:
        # 1. Instantiate the ModularLateFusionLSTM model
        print(f"  Instantiating ModularLateFusionLSTM with:")
        print(f"    Tabular Input Size: {HP['tabular_input_size']}")
        print(f"    CNN Input Channels: {HP['cnn_input_channels']}")
        print(f"    CNN Output Features (MRI LSTM Input): {HP['cnn_output_features']}")
        # Same fallback as used at instantiation below: per-modality size, else shared lstm_hidden_size
        print(f"    MRI LSTM Hidden Size: {HP.get('mri_lstm_hidden_size', HP.get('lstm_hidden_size'))}")
        print(f"    Tabular LSTM Hidden Size: {HP.get('tabular_lstm_hidden_size', HP.get('lstm_hidden_size'))}")
        print(f"    Num LSTM Layers: {HP['num_lstm_layers']}")
        print(f"    LSTM Dropout: {HP['lstm_dropout_prob']}")
        print(f"    Modality Dropout Rate: {HP.get('modality_dropout_rate', 0.0)}")
        
        model = ModularLateFusionLSTM(
            cnn_input_channels=HP['cnn_input_channels'],
            cnn_output_features=HP['cnn_output_features'],
            tabular_input_size=HP['tabular_input_size'],
            mri_lstm_hidden_size=HP.get('mri_lstm_hidden_size', HP.get('lstm_hidden_size')), # Use .get for safety
            tabular_lstm_hidden_size=HP.get('tabular_lstm_hidden_size', HP.get('lstm_hidden_size')),
            num_lstm_layers=HP['num_lstm_layers'],
            lstm_dropout_prob=HP['lstm_dropout_prob'],
            modality_dropout_rate=HP.get('modality_dropout_rate', 0.0), # Defaults to 0.0 if not in HP
            num_classes=1 # For regression of CDR score
        )
        model.to(device) # Move model to the selected device
        print(f"  ModularLateFusionLSTM model instantiated and moved to device: {device}.")
        # print(model) # Optional: print model architecture

        # 2. Define the Loss Function (Criterion)
        criterion = nn.MSELoss()
        print(f"  Loss function set to: {type(criterion).__name__}")

        # 3. Define the Optimizer
        optimizer = optim.Adam(model.parameters(), lr=HP['learning_rate'])
        print(f"  Optimizer set to: {type(optimizer).__name__} with Learning Rate = {HP['learning_rate']}")

        # 4. Optional: Watch model with W&B
        if run:
            try:
                wandb.watch(model, criterion=criterion, log='all', log_freq=100) 
                print("  W&B model watching enabled (gradients, parameters, etc.).")
            except Exception as e_watch_hybrid:
                print(f"  Warning: wandb.watch() for hybrid model failed. Error: {e_watch_hybrid}")

    except KeyError as e_key_model_hybrid_init:
        print(f"CRITICAL ERROR: Missing a key in HP dictionary required for ModularLateFusionLSTM: {e_key_model_hybrid_init}")
        if run: run.finish(exit_code=1)
        # exit()
    except Exception as e_model_setup_hybrid_generic:
        print(f"CRITICAL ERROR during hybrid model, loss, or optimizer setup: {e_model_setup_hybrid_generic}")
        import traceback
        traceback.print_exc()
        if run: run.finish(exit_code=1)
        # exit()

# Final check for subsequent cells
if model is None or criterion is None or optimizer is None:
    print("CRITICAL ERROR: Hybrid model, criterion, or optimizer was not successfully initialized. Cannot proceed.")
    # exit()

## 7. Train Hybrid Model using `train_model` Utility

The core training and validation loop for the `ModularLateFusionLSTM` is executed by calling the `train_model` utility function from `src/training_utils.py`. This utility is passed the instantiated hybrid model, data loaders, criterion, optimizer, and other training parameters.

**Crucially, `model_type_flag="hybrid"` is passed to `train_model`.** This flag ensures that the utility correctly unpacks the 5-item batches (tabular sequences, MRI sequences, lengths, targets, masks) produced by `pad_collate_fn` for hybrid data and calls the hybrid model's `forward` method with the appropriate arguments.
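The unpacking code inside `train_model` is not shown here. A minimal sketch of flag-based batch unpacking, assuming the 5-item hybrid order listed above; the `"baseline"` flag name and its 4-item layout are hypothetical placeholders for a tabular-only model, not names confirmed by this project:

```python
from typing import Any, Dict, Sequence, Tuple

# Field order for hybrid batches, as described above for pad_collate_fn.
HYBRID_FIELDS: Tuple[str, ...] = ("tabular_seqs", "mri_seqs", "lengths", "targets", "masks")
# Hypothetical layout for a tabular-only model (no MRI item); not confirmed by this notebook.
TABULAR_FIELDS: Tuple[str, ...] = ("tabular_seqs", "lengths", "targets", "masks")

FIELDS_BY_FLAG: Dict[str, Tuple[str, ...]] = {"hybrid": HYBRID_FIELDS, "baseline": TABULAR_FIELDS}


def unpack_batch(batch: Sequence[Any], model_type_flag: str) -> Dict[str, Any]:
    """Name the items of a collated batch according to the model type flag."""
    fields = FIELDS_BY_FLAG.get(model_type_flag)
    if fields is None:
        raise ValueError(f"Unknown model_type_flag: {model_type_flag!r}")
    if len(batch) != len(fields):
        raise ValueError(
            f"model_type_flag={model_type_flag!r} expects {len(fields)} items per batch, got {len(batch)}"
        )
    return dict(zip(fields, batch))


# Example with placeholder values standing in for tensors.
named = unpack_batch(("tab", "mri", [3, 2], "y", "mask"), "hybrid")
```

Checking the batch length against the flag makes a mismatch (for example, a tabular-only loader paired with the hybrid flag) fail immediately rather than silently shifting fields into the wrong model arguments.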

The `train_model` utility handles epoch iteration, metric calculation (via `evaluate_model`), W&B logging, best and periodic model checkpointing (saving `.pth` files locally and as W&B Artifacts), and early stopping.
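The internals of `train_model` are not shown. Below is a minimal sketch of the patience-based early stopping and best-checkpoint bookkeeping that the utility is described as performing. The parameter names `patience` and `min_delta` are assumptions; the actual `HP` keys may differ.

```python
import math
from dataclasses import dataclass


@dataclass
class EarlyStopper:
    """Track the best validation loss and signal when to stop.

    patience: epochs without improvement tolerated before stopping.
    min_delta: minimum decrease in validation loss that counts as improvement.
    """
    patience: int = 5
    min_delta: float = 0.0
    best_loss: float = math.inf
    best_epoch: int = -1
    epochs_without_improvement: int = 0

    def step(self, epoch: int, val_loss: float) -> bool:
        """Record one epoch; return True if it is a new best (i.e. save a checkpoint)."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            self.epochs_without_improvement = 0
            return True
        self.epochs_without_improvement += 1
        return False

    @property
    def should_stop(self) -> bool:
        return self.epochs_without_improvement >= self.patience


# Example: validation loss improves, then plateaus for two epochs.
stopper = EarlyStopper(patience=2)
saved_epochs = []
for epoch, val_loss in enumerate([0.9, 0.7, 0.72, 0.71, 0.5]):
    if stopper.step(epoch, val_loss):
        saved_epochs.append(epoch)  # a real loop would save model.state_dict() here
    if stopper.should_stop:
        break
```

Training stops after epoch 3, so the later 0.5 is never seen; a larger `patience` trades extra epochs for a chance to escape such plateaus. Saving only when `step` returns `True` keeps the checkpoint aligned with the lowest validation loss observed.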

In [None]:
# --- Train Hybrid Model using the train_model Utility from src/training_utils.py ---
print("\n--- Starting Hybrid Model Training using 'train_model' utility ---")

# Prerequisites from previous cells:
# model (ModularLateFusionLSTM instance), train_loader, val_loader, criterion, optimizer, 
# device, HP, run (W&B object), run_output_dir_for_checkpoints, evaluate_model (imported function)

path_to_best_checkpoint_local_hybrid: Path | None = None 
best_validation_loss_achieved_hybrid: float = float('inf')

# Check that all components needed to call train_model are ready.
# 'run' is optional: if it is None, train_model runs with W&B logging disabled.
training_prerequisites_met_hybrid = True
required_vars_for_hybrid_training_call = [
    'model', 'train_loader', 'val_loader', 'criterion', 'optimizer', 'device', 'HP',
    'run_output_dir_for_checkpoints', 'evaluate_model', 'train_model'
]
for var_name_check in required_vars_for_hybrid_training_call:
    if var_name_check not in locals() or locals()[var_name_check] is None:
        print(f"  CRITICAL ERROR: Prerequisite variable '{var_name_check}' for training the hybrid model is not defined or is None.")
        training_prerequisites_met_hybrid = False

if 'run' not in locals() or run is None:
    print("  Note: W&B run object is None. Training will proceed with logging disabled within train_model.")

if not training_prerequisites_met_hybrid:
    print("Halting hybrid model training due to missing or invalid prerequisites for 'train_model' utility.")
    if run: run.finish(exit_code=1)
    # exit()
else:
    try:
        print(f"  Calling 'train_model' for {HP['epochs']} epochs. Checkpoints: {run_output_dir_for_checkpoints}")
        
        path_to_best_checkpoint_local_hybrid, best_validation_loss_achieved_hybrid = train_model(
            model=model,
            train_loader=train_loader,
            val_loader=val_loader,
            criterion=criterion,
            optimizer=optimizer,
            device=device,
            num_epochs=HP['epochs'],
            wandb_run=run, 
            checkpoint_dir=run_output_dir_for_checkpoints, 
            evaluate_fn=evaluate_model, 
            model_type_flag="hybrid",
            hp_dict=HP, 
            best_val_loss_init=float('inf') 
        )

        print(f"\n--- Hybrid Model Training Complete (via 'train_model' utility) ---")
        print(f"Best validation loss achieved during training: {best_validation_loss_achieved_hybrid:.4f}")
        if path_to_best_checkpoint_local_hybrid and path_to_best_checkpoint_local_hybrid.exists():
            print(f"Best hybrid model checkpoint saved locally at: {path_to_best_checkpoint_local_hybrid}")
            if run:
                run.summary['best_model_local_path_hybrid'] = str(path_to_best_checkpoint_local_hybrid)
                run.summary['best_val_loss_final_hybrid'] = best_validation_loss_achieved_hybrid
        elif path_to_best_checkpoint_local_hybrid:
            print(f"Warning: train_model returned best checkpoint path, but file not found: {path_to_best_checkpoint_local_hybrid}")
        else:
            print("No best model checkpoint was saved for the hybrid model (e.g., no improvement or a training error).")

    except Exception as e_train_hybrid:
        print(f"CRITICAL ERROR occurred during hybrid model training: {e_train_hybrid}")
        import traceback
        traceback.print_exc()
        if run: run.finish(exit_code=1)
        # exit()

## 8. Evaluate Best Hybrid Model on Test Set

Following training, the best `ModularLateFusionLSTM` model (based on the lowest validation loss) is loaded from its saved checkpoint. This model is then evaluated on the held-out test set (`cohort_test.parquet`) to assess its generalization performance.

The `evaluate_model` utility is used with `model_name_for_batch_unpack="hybrid"` to ensure correct data handling. Test metrics (Loss, MSE, MAE, R²) are calculated and logged to this W&B training run's summary.
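The cell below prints both a criterion-based test loss and an MSE computed from predictions. Assuming `evaluate_model` averages the criterion over batches (not confirmed here), the two can differ because a mean of per-batch MSEs weights a small final batch as heavily as a full one, whereas a pooled MSE weights every prediction equally. A pure-Python sketch of pooled MSE, MAE and R² (R² = 1 - SS_res / SS_tot), with a toy example using CDR-like values:

```python
from typing import Dict, List, Sequence, Tuple


def regression_metrics(y_true: Sequence[float], y_pred: Sequence[float]) -> Dict[str, float]:
    """Pooled MSE, MAE and R^2 over all predictions at once."""
    n = len(y_true)
    if n == 0 or n != len(y_pred):
        raise ValueError("y_true and y_pred must be non-empty and of equal length")
    errors = [p - t for t, p in zip(y_true, y_pred)]
    ss_res = sum(e * e for e in errors)  # residual sum of squares
    mean_true = sum(y_true) / n
    ss_tot = sum((t - mean_true) ** 2 for t in y_true)  # total sum of squares
    return {
        "mse": ss_res / n,
        "mae": sum(abs(e) for e in errors) / n,
        # R^2 is undefined when all targets are identical
        "r2": 1.0 - ss_res / ss_tot if ss_tot > 0 else float("nan"),
    }


def mean_of_batch_mse(batches: List[Tuple[Sequence[float], Sequence[float]]]) -> float:
    """Unweighted average of per-batch MSEs (each batch counts equally)."""
    per_batch = [regression_metrics(t, p)["mse"] for t, p in batches]
    return sum(per_batch) / len(per_batch)


# Toy example: a full batch of 2 visits and a final batch of 1 visit.
batches = [([0.0, 0.5], [0.0, 0.5]), ([1.0], [0.0])]
pooled = regression_metrics([0.0, 0.5, 1.0], [0.0, 0.5, 0.0])
batch_avg = mean_of_batch_mse(batches)
```

Here the batch-averaged MSE (0.5) overstates the pooled MSE (about 0.333) because the single badly predicted visit forms a whole batch; the pooled R² of -1.0 means these predictions are worse than always predicting the mean target.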

In [None]:
# --- Evaluate Best Hybrid Model on Test Set ---
print("\n--- Evaluating Best Hybrid Model on Test Set ---")

# Ensure prerequisites for test evaluation are available
# path_to_best_checkpoint_local_hybrid, HP, config_for_dataset, device, TEST_DATA_PATH,
# SCALER_PATH_FROM_ARTIFACT, IMPUTER_PATH_FROM_ARTIFACT, MRI_DATA_DIR, criterion

if 'path_to_best_checkpoint_local_hybrid' in locals() and \
   path_to_best_checkpoint_local_hybrid is not None and \
   path_to_best_checkpoint_local_hybrid.is_file() and \
   'HP' in locals() and HP and \
   HP.get('tabular_input_size') is not None and \
   HP.get('cnn_input_channels') is not None and \
   HP.get('cnn_output_features') is not None and \
   'config_for_dataset' in locals() and config_for_dataset and \
   'TEST_DATA_PATH' in locals() and TEST_DATA_PATH.is_file() and \
   'SCALER_PATH_FROM_ARTIFACT' in locals() and SCALER_PATH_FROM_ARTIFACT.is_file() and \
   'IMPUTER_PATH_FROM_ARTIFACT' in locals() and IMPUTER_PATH_FROM_ARTIFACT.is_file() and \
   'MRI_DATA_DIR' in locals() and MRI_DATA_DIR.is_dir() and \
   'criterion' in locals() and criterion is not None and \
   'device' in locals() and device is not None:

    print(f"Loading best hybrid model for test evaluation from: {path_to_best_checkpoint_local_hybrid}")

    try:
        # 1. Instantiate a new ModularLateFusionLSTM model with the same HPs
        test_model_hybrid = ModularLateFusionLSTM(
            cnn_input_channels=HP['cnn_input_channels'],
            cnn_output_features=HP['cnn_output_features'],
            tabular_input_size=HP['tabular_input_size'],
            # Nested .get avoids a KeyError when only the per-modality sizes are set
            mri_lstm_hidden_size=HP.get('mri_lstm_hidden_size', HP.get('lstm_hidden_size')),
            tabular_lstm_hidden_size=HP.get('tabular_lstm_hidden_size', HP.get('lstm_hidden_size')),
            num_lstm_layers=HP['num_lstm_layers'],
            lstm_dropout_prob=HP['lstm_dropout_prob'],
            modality_dropout_rate=0.0, # Crucial: Set to 0 for deterministic test evaluation
            num_classes=1
        )
        
        # Load the saved state dictionary of the best model
        # The checkpoint saved by train_model utility contains only model.state_dict()
        test_model_hybrid.load_state_dict(torch.load(path_to_best_checkpoint_local_hybrid, map_location=device))
        test_model_hybrid.to(device)
        test_model_hybrid.eval() 
        print("  Best hybrid model loaded and set to evaluation mode.")

        # 2. Create DataLoader for the Test Set (with MRI data)
        print(f"  Instantiating test dataset from: {TEST_DATA_PATH} (Hybrid - MRI Included)...")
        test_dataset_hybrid = OASISDataset(
            data_parquet_path=TEST_DATA_PATH,
            scaler_path=SCALER_PATH_FROM_ARTIFACT,
            imputer_path=IMPUTER_PATH_FROM_ARTIFACT,
            config=config_for_dataset, # Use the same authoritative config from NB04
            mri_data_dir=MRI_DATA_DIR,
            include_mri=True # Explicitly True for the hybrid model
        )
        num_test_subjects_hybrid = len(test_dataset_hybrid)
        print(f"  Test dataset (hybrid) created with {num_test_subjects_hybrid} subjects.")

        test_loader_hybrid = DataLoader(
            test_dataset_hybrid,
            batch_size=HP['batch_size'], 
            shuffle=False, 
            collate_fn=pad_collate_fn, # Will yield 5-item tuples
            num_workers=HP.get('dataloader_num_workers', 0)
        )
        print(f"  Test DataLoader (hybrid) created. Number of batches: {len(test_loader_hybrid)}")

        # 3. Perform evaluation using the utility function
        #    (criterion is already guaranteed to be defined by the prerequisite check above)

        print(f"  Evaluating best hybrid model on {num_test_subjects_hybrid} test subjects...")
        test_set_metrics_dict_hybrid = evaluate_model(
            test_model_hybrid, 
            test_loader_hybrid, 
            criterion, 
            device, 
            model_name_for_batch_unpack="hybrid"
        )

        print(f"\n--- Test Set Performance of Best Hybrid Model ---")
        if 'best_validation_loss_achieved_hybrid' in locals():
            print(f"  (Model achieved best validation loss: {best_validation_loss_achieved_hybrid:.4f})")
        
        print(f"  Test Loss (Criterion): {test_set_metrics_dict_hybrid.get('loss', float('nan')):.4f}")
        print(f"  Test MSE (from preds): {test_set_metrics_dict_hybrid.get('mse', float('nan')):.4f}")
        print(f"  Test MAE:  {test_set_metrics_dict_hybrid.get('mae', float('nan')):.4f}")
        print(f"  Test R2:   {test_set_metrics_dict_hybrid.get('r2', float('nan')):.4f}")

        # 4. Log Test Metrics to W&B Summary
        if run:
            run.summary["test_set_hybrid/loss_criterion"] = test_set_metrics_dict_hybrid.get('loss')
            run.summary["test_set_hybrid/mse_from_preds"] = test_set_metrics_dict_hybrid.get('mse')
            run.summary["test_set_hybrid/mae"] = test_set_metrics_dict_hybrid.get('mae')
            run.summary["test_set_hybrid/r2"] = test_set_metrics_dict_hybrid.get('r2')
            if 'best_validation_loss_achieved_hybrid' in locals():
                run.summary["test_set_hybrid/checkpoint_achieved_best_val_loss"] = best_validation_loss_achieved_hybrid
            run.summary["test_set_hybrid/num_test_subjects"] = num_test_subjects_hybrid
            run.summary["test_set_hybrid/num_test_visits"] = len(test_dataset_hybrid.data_df) if hasattr(test_dataset_hybrid, 'data_df') else 'N/A'
            print("  Hybrid model test metrics logged to W&B run summary.")
            
    except Exception as e_test_eval_hybrid:
        print(f"CRITICAL ERROR during hybrid model test set evaluation: {e_test_eval_hybrid}")
        import traceback
        traceback.print_exc()
        if run: run.summary["test_set_hybrid/status"] = f"EvaluationError: {str(e_test_eval_hybrid)[:100]}"

else:
    print("Skipping hybrid model test set evaluation: Best model checkpoint not found/valid, "
          "or other prerequisite variables are missing.")
    if run: run.summary["test_set_hybrid/status"] = "SkippedMissingPrerequisites"

## 9. Finalize W&B Run

Complete the execution of this hybrid model training notebook and finish the associated Weights & Biases run. This ensures all queued logs, metrics, configurations, and model artifacts (including the best model checkpoint and any periodic checkpoints) are fully uploaded and synchronized with the W&B platform.
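Several error branches above call `run.finish(exit_code=1)` without halting the notebook (the `exit()` calls are commented out), so execution can reach the final `run.finish()` below with an already-finished run. A hypothetical guard that finalizes at most once and keeps the first exit code is sketched here; `RunFinalizer` and the `FakeRun` stand-in (used so the example runs without `wandb`) are illustrative names, not part of the project or the W&B API:

```python
from typing import List, Optional


class FakeRun:
    """Stand-in for a W&B run so this sketch runs without wandb installed."""

    def __init__(self) -> None:
        self.finish_calls: List[Optional[int]] = []

    def finish(self, exit_code: Optional[int] = None) -> None:
        self.finish_calls.append(exit_code)


class RunFinalizer:
    """Hypothetical helper: call run.finish() at most once, keeping the first exit code."""

    def __init__(self, run) -> None:
        self.run = run
        self.finished = False

    def finish(self, exit_code: int = 0) -> None:
        if self.run is None or self.finished:
            return  # nothing to finish, or already finalized by an earlier (error) path
        self.run.finish(exit_code=exit_code)
        self.finished = True


# Example: an error path finishes with exit_code=1; the final cell's call is then a no-op.
finalizer = RunFinalizer(FakeRun())
finalizer.finish(exit_code=1)
finalizer.finish()
```

Keeping the first exit code means a failed run stays marked as failed on the W&B side even if later cells run to completion.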

In [None]:
# --- Finish W&B Run for Notebook 07 (Hybrid Model Training) ---
print(f"\n--- {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} (Hybrid Training) complete. Finishing W&B run. ---")

if run: # Check if 'run' object exists and is an active W&B run
    try:
        # Ensure the final best_val_loss is in the summary (train_model also logs this when W&B is active)
        if 'best_validation_loss_achieved_hybrid' in locals() and 'best_val_loss_final_hybrid' not in run.summary:
            run.summary['best_val_loss_final_hybrid'] = best_validation_loss_achieved_hybrid

        # Capture the run's display name while the run is still active
        run_name_to_print = getattr(run, 'name', None) or getattr(run, 'id', None) or "current NB07 run"
        run.finish()
        print(f"W&B run '{run_name_to_print}' finished successfully.")
    except Exception as e_finish_run_nb07:
        print(f"Error during run.finish() for Notebook 07: {e_finish_run_nb07}")
        print("The run may not have finalized correctly on the W&B server.")
else:
    print("No active W&B run to finish for this training session.")

print(f"\n--- Notebook {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} execution finished. ---")