# Notebook 05: OASIS-2 Data Loading Pipeline Test

**Project Phase:** 1 (Data Processing - DataLoader Verification)
**Dataset:** OASIS-2 Longitudinal MRI & Clinical Data

**Purpose:**
This notebook is dedicated to verifying the integrity and correct behavior of the custom data loading pipeline, primarily the `OASISDataset` class and `pad_collate_fn` (defined in `src/datasets.py`). This involves:

1.  **Load Training & Validation Data Splits:** Use paths resolved by `get_dataset_paths` (driven by `config.json`) to access `cohort_train.parquet` and `cohort_validation.parquet` (outputs from Notebook 03).
2.  **Consume Preprocessor Artifacts & Fetch NB04 Config:**
    * Consume the versioned **fitted preprocessor W&B Artifacts** (e.g., `scaler_standard_oasis2:latest`, `simple_imputer_median_oasis2:latest`) produced by Notebook 04.
    * Download these artifacts to get local paths to the `.joblib` files.
    * Identify the W&B Run from Notebook 04 that **produced** these preprocessor artifacts using `artifact.logged_by()`.
    * Fetch the definitive **`features` (time-varying, static) and `preprocess` (imputation/scaling columns and strategies) configurations** directly from this producer Notebook 04 run's W&B config. This ensures `OASISDataset` is initialized with the authoritative settings.
3.  **Instantiate `OASISDataset`:** Create dataset instances for both training and validation data, ensuring they correctly use the downloaded preprocessors and the fetched feature configurations, with the option to `include_mri=True`.
4.  **Test `DataLoader` with `pad_collate_fn`:** Wrap the datasets in PyTorch `DataLoader`s.
5.  **Verify Batch Structure:** Iterate through a few batches from the `train_loader` and:
    * Unpack all expected components (padded tabular sequences, padded MRI sequences, lengths, targets, masks).
    * Print their shapes, data types, and example values.
    * Confirm that padding, masking, and data types are as expected for model input.
6.  **Log to W&B:** Record key batch characteristics and the test status in a new W&B run for this notebook.

**Input:**
* `config.json`: Main project configuration file.
* **W&B Artifact Names for Fitted Preprocessors:** e.g., `"scaler_standard_oasis2:latest"`, `"simple_imputer_median_oasis2:latest"` (produced by Notebook 04).
* `cohort_train.parquet`, `cohort_validation.parquet`: Data splits (output from Notebook 03, paths obtained via `get_dataset_paths`).
* `src/datasets.py`: Contains `OASISDataset` and `pad_collate_fn`.
* Directory containing preprocessed MRI scans (if `include_mri=True`).

**Output:**
* Console output displaying the properties (shapes, types, example values) of generated batches for verification.
* W&B Run: Logs the configuration used (including source NB04 run ID and consumed preprocessor artifact names) and key characteristics of the test batches.

In [None]:
# In: notebooks/05_Test_DataLoader.ipynb
# Purpose: Test the custom OASISDataset and pad_collate_fn
#          to ensure that data loading, preprocessing (using the saved preprocessor objects),
#          sequencing, padding, and batching work correctly.

In [None]:
# --- Import Libraries ---
import pandas as pd
import numpy as np
import torch 
from torch.utils.data import DataLoader
import joblib 
import json
from pathlib import Path
import sys
import os
import wandb 
import time

## 1. Setup: Project Configuration, Paths, and Utilities

Initialize the notebook environment:
* Determines the project root and adds the `src` directory to `sys.path`.
* Imports custom utilities: `initialize_wandb_run` and `get_dataset_paths`.
* Imports `OASISDataset` and `pad_collate_fn` from `src/datasets.py`.
* Loads the main project configuration (`base_config`) from `config.json`.
* Defines dataset and notebook-specific identifiers.
* **Uses `get_dataset_paths` to resolve paths for the input training and validation data splits (from Notebook 03) and the general MRI data directory.** Paths to preprocessor `.joblib` files will be obtained later by downloading W&B artifacts.
* Defines parameters for DataLoader testing.
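The setup cell below only checks the current working directory and its immediate parent. As a sketch of a more general alternative, assuming the same two markers (a `src/` directory and a `config.json` file), the hypothetical helper `find_project_root` walks up every ancestor directory and stops at the first match; it is not part of `src/`.

```python
from pathlib import Path


def find_project_root(start: Path) -> Path:
    """Return the nearest ancestor of `start` (inclusive) containing both `src/` and `config.json`.

    Hypothetical helper: uses the same markers as the setup cell, but searches all
    ancestors instead of only the current directory and its parent.
    """
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / "src").is_dir() and (candidate / "config.json").is_file():
            return candidate
    raise FileNotFoundError(f"No directory containing 'src/' and 'config.json' found above {start}")
```

Used as `PROJECT_ROOT = find_project_root(Path.cwd())`, it would work whether the notebook is launched from `notebooks/`, the project root, or a deeper subdirectory.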

In [None]:
# --- Project Setup, Configuration Loading, and Utility Imports ---
print("--- Initializing Project Setup & Configuration for NB05 ---")

# Initialize
PROJECT_ROOT = None
base_config = {} 

try:
    current_notebook_path = Path.cwd() 
    potential_project_root = current_notebook_path.parent 
    if (potential_project_root / "src").is_dir() and (potential_project_root / "config.json").is_file():
        PROJECT_ROOT = potential_project_root
    else: 
        PROJECT_ROOT = current_notebook_path
    if not (PROJECT_ROOT / "src").is_dir() or not (PROJECT_ROOT / "config.json").is_file():
        raise FileNotFoundError(f"Could not find 'src' or 'config.json'. PROJECT_ROOT: {PROJECT_ROOT}")
    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
    print(f"PROJECT_ROOT: {PROJECT_ROOT}, added to sys.path.")

    from src.wandb_utils import initialize_wandb_run
    from src.paths_utils import get_dataset_paths 
    from src.datasets import OASISDataset, pad_collate_fn
    print("Successfully imported custom utilities and dataset classes.")

except Exception as e_setup:
    print(f"CRITICAL ERROR during initial setup: {e_setup}")
    # exit()

# --- Load Main Project Configuration ---
print("\n--- Loading Main Project Configuration ---")
try:
    if PROJECT_ROOT is None: raise ValueError("PROJECT_ROOT not set.")
    CONFIG_PATH_MAIN = PROJECT_ROOT / 'config.json'
    with open(CONFIG_PATH_MAIN, 'r', encoding='utf-8') as f:
        base_config = json.load(f)
    print(f"Main project config loaded from: {CONFIG_PATH_MAIN}")

    WANDB_ENTITY = base_config.get('wandb', {}).get('entity')
    WANDB_PROJECT = base_config.get('wandb', {}).get('project_name')
    if not WANDB_ENTITY or not WANDB_PROJECT:
        raise KeyError("WANDB_ENTITY or WANDB_PROJECT not found in config.json ['wandb'] section.")
    print(f"W&B Entity: {WANDB_ENTITY}, Project: {WANDB_PROJECT}")

except Exception as e_cfg:
    print(f"CRITICAL ERROR loading main config.json: {e_cfg}")
    # exit() 

# --- Define Dataset, Notebook Specifics, and Resolve Key INPUT Paths ---
DATASET_IDENTIFIER = "oasis2" 
NOTEBOOK_MODULE_NAME = "05_Test_DataLoader"

# Paths for data splits (inputs from NB03)
TRAIN_DATA_PATH = None
VAL_DATA_PATH = None
MRI_DATA_DIR = None # General preprocessed MRI directory

# Preprocessor paths will be determined by downloading artifacts later
SCALER_PATH_FROM_ARTIFACT = None 
IMPUTER_PATH_FROM_ARTIFACT = None

try:
    if not base_config: raise ValueError("base_config is empty.")
    
    # Resolve the NB03 train/validation split paths and the MRI directory for the "training" stage.
    # Preprocessors are NOT loaded from local paths returned here; they are consumed as W&B artifacts
    # in Section 3 so that the exact fitted objects from Notebook 04 are used.
    pipeline_paths_for_nb05_inputs = get_dataset_paths(
        PROJECT_ROOT, 
        base_config, 
        DATASET_IDENTIFIER, 
        stage="training" # We test DataLoaders on training and validation data
    )
    TRAIN_DATA_PATH = pipeline_paths_for_nb05_inputs.get('train_data_parquet')
    VAL_DATA_PATH = pipeline_paths_for_nb05_inputs.get('val_data_parquet')
    MRI_DATA_DIR = pipeline_paths_for_nb05_inputs.get('mri_data_dir') # For OASISDataset

    if not all([TRAIN_DATA_PATH, VAL_DATA_PATH, MRI_DATA_DIR]):
        raise ValueError("Failed to resolve one or more critical data paths from get_dataset_paths.")

    print(f"\nKey input paths for Notebook 05 ({DATASET_IDENTIFIER}):")
    print(f"  Input Training Data Parquet (from NB03): {TRAIN_DATA_PATH}")
    print(f"  Input Validation Data Parquet (from NB03): {VAL_DATA_PATH}")
    print(f"  Input MRI Data Directory: {MRI_DATA_DIR}")
    
    # Verify existence of these critical input files/dirs
    for p_name, p_obj in [("Training Data", TRAIN_DATA_PATH), ("Validation Data", VAL_DATA_PATH)]:
        if not p_obj.is_file(): raise FileNotFoundError(f"CRITICAL: {p_name} not found at: {p_obj}")
    if not MRI_DATA_DIR.is_dir(): raise FileNotFoundError(f"CRITICAL: MRI Data Directory not found at: {MRI_DATA_DIR}")
    print("All critical input data paths for NB05 verified.")

except (KeyError, ValueError, FileNotFoundError) as e_paths_nb05:
    print(f"CRITICAL ERROR during path setup for NB05: {e_paths_nb05}")
    # exit()
except Exception as e_general_nb05_setup:
    print(f"CRITICAL ERROR during setup for NB05: {e_general_nb05_setup}")
    # exit()

# --- Parameters for DataLoader Testing ---
BATCH_SIZE_FOR_TESTING = 4 
NUM_BATCHES_TO_INSPECT = 2
# Flag to test MRI inclusion; set to False to test tabular-only path of OASISDataset/pad_collate_fn
TEST_WITH_MRI = True

## 2. Initialize W&B Run & Define Artifacts to Consume

A new W&B run is initiated for this "Test DataLoader" notebook. This run will log:
* The configuration parameters used by this notebook.
* The names of the W&B Preprocessor Artifacts (Scaler and Imputer from Notebook 04) that this test will consume.
* The W&B Run ID of the Notebook 04 execution that produced these preprocessors (obtained via artifact lineage).
* Key characteristics of the data batches generated by the `DataLoader` to verify the pipeline's integrity.

This step also defines the names and expected types of the preprocessor artifacts that will be downloaded from W&B.
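The next cell builds the artifact names from the strategies in `config.json`. Pulled out into pure functions (the names `scaler_artifact_name`, `imputer_artifact_name`, and `artifact_ref` are hypothetical), the convention can be checked without contacting W&B. This assumes Notebook 04 logged its artifacts with exactly these f-strings; if it did not, `run.use_artifact()` will not find them.

```python
def scaler_artifact_name(scaling_strategy: str, dataset_identifier: str) -> str:
    # Same transformation as the NB05 cell: drop the "_scaler" suffix, then any remaining underscores.
    short = scaling_strategy.lower().replace("_scaler", "").replace("_", "")
    return f"scaler_{short}_{dataset_identifier}"


def imputer_artifact_name(imputation_strategy: str, dataset_identifier: str) -> str:
    # Same pattern as the NB05 cell: "simple_imputer_<strategy>_<dataset>".
    return f"simple_imputer_{imputation_strategy.lower()}_{dataset_identifier}"


def artifact_ref(entity: str, project: str, name: str, version: str = "latest") -> str:
    # Fully qualified reference in the form the cell passes to run.use_artifact(): entity/project/name:alias
    return f"{entity}/{project}/{name}:{version}"
```

With the default strategies (`standard_scaler`, `median`) this gives `scaler_standard_oasis2` and `simple_imputer_median_oasis2`.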

In [None]:
# --- Initialize a NEW W&B Run for THIS Notebook 05 execution ---
print("\n--- Initializing a New Weights & Biases Run for NB05 (DataLoader Test) ---")

# --- Define W&B Artifact Names for INPUT Preprocessors (Outputs from NB04) ---
# These names MUST MATCH the artifact names used by Notebook 04 when it logged them.
# They depend on the strategies defined in config.json and the DATASET_IDENTIFIER.

# Get strategies from base_config to construct artifact names
preprocessing_cfg = base_config.get('preprocessing_config', {})
scaling_strategy_name = preprocessing_cfg.get('scaling_strategy', 'standard_scaler')
imputation_strategy_name = preprocessing_cfg.get('imputation_strategy', 'median')

# Construct artifact names as NB04 would have logged them
SCALER_ARTIFACT_NAME = f"scaler_{scaling_strategy_name.lower().replace('_scaler','').replace('_','')}_{DATASET_IDENTIFIER}"
IMPUTER_ARTIFACT_NAME = f"simple_imputer_{imputation_strategy_name.lower()}_{DATASET_IDENTIFIER}"
PREPROCESSOR_ARTIFACT_TYPE = f"preprocessor_{DATASET_IDENTIFIER}" # Matches type used by NB04
PREPROCESSOR_ARTIFACT_VERSION = "latest" # Use "latest" to get the newest fitted preprocessors

print(f"  Will attempt to use Scaler artifact: {SCALER_ARTIFACT_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}")
print(f"  Will attempt to use Imputer artifact: {IMPUTER_ARTIFACT_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}")

# Configuration specific to this NB05 run
nb05_run_config_log = {
    "notebook_name_code": f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}",
    "dataset_source": DATASET_IDENTIFIER,
    "input_train_data_path_used": str(TRAIN_DATA_PATH), 
    "input_val_data_path_used": str(VAL_DATA_PATH),     
    "mri_data_dir_used": str(MRI_DATA_DIR),
    "scaler_artifact_to_use": f"{SCALER_ARTIFACT_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}",          
    "imputer_artifact_to_use": f"{IMPUTER_ARTIFACT_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}",        
    "batch_size_tested": BATCH_SIZE_FOR_TESTING,               
    "include_mri_tested": TEST_WITH_MRI,
    "execution_timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    # Source NB04 run ID and its config (for features/preprocess lists) will be added after artifact loading
}

nb_number_prefix_nb05 = NOTEBOOK_MODULE_NAME.split('_')[0] if '_' in NOTEBOOK_MODULE_NAME else "NB"
job_specific_type_nb05 = f"{nb_number_prefix_nb05}-TestDataLoader-{DATASET_IDENTIFIER}"
custom_elements_for_name_nb05 = [nb_number_prefix_nb05, DATASET_IDENTIFIER.upper(), "DLTest", f"MRI_{str(TEST_WITH_MRI)}"]

run = initialize_wandb_run( # This 'run' object is for NB05's own logging
    base_project_config=base_config,
    job_group="Verification", 
    job_specific_type=job_specific_type_nb05,
    run_specific_config=nb05_run_config_log,
    custom_run_name_elements=custom_elements_for_name_nb05,
    notes=f"Testing OASISDataset & pad_collate_fn for {DATASET_IDENTIFIER} (MRI: {TEST_WITH_MRI}). Consumes preprocessor artifacts from NB04."
)

if run:
    print(f"New W&B run for NB05 '{run.name}' (Job Type: '{run.job_type}') initialized. View at: {run.url}")
else:
    print("Proceeding with DataLoader test without W&B logging for this NB05 execution.")

## 3. Load Preprocessor Artifacts & Fetch Definitive Configuration from Producer (NB04) Run

This step links the DataLoader test to the preprocessing decisions made in Notebook 04:
1.  The W&B Artifacts for the fitted `StandardScaler` and `SimpleImputer` (produced by Notebook 04) are **consumed via `run.use_artifact()`** and downloaded, yielding local paths to the `.joblib` files that `OASISDataset` will load.
2.  The W&B Run that **produced** the imputer artifact (i.e., the relevant Notebook 04 run) is identified with `artifact.logged_by()`.
3.  That run's W&B config is fetched. It contains the `features` and `preprocess` dictionaries that `OASISDataset` needs for consistent data handling (e.g., which columns were imputed and scaled, and the feature lists).

Together, these steps ensure that `OASISDataset` in this notebook uses the same fitted preprocessors and feature definitions as the Notebook 04 run that produced them.
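The cell below only checks that the top-level `features` and `preprocess` keys exist; a missing sub-key would surface later as a `KeyError` inside `OASISDataset`. A sketch of a deeper check follows. The required sub-keys are an assumption, mirrored from the fallback defaults used later in this notebook (`time_varying`/`static` and `imputation_cols`/`scaling_cols`); `OASISDataset` may need others, and `missing_config_keys` is a hypothetical helper.

```python
from collections.abc import Mapping, Sequence
from typing import Any

# ASSUMPTION: sub-keys mirror the fallback defaults used later in this notebook;
# OASISDataset may require additional keys.
REQUIRED_DATASET_KEYS: dict[str, Sequence[str]] = {
    "features": ("time_varying", "static"),
    "preprocess": ("imputation_cols", "scaling_cols"),
}


def missing_config_keys(cfg: Mapping[str, Any],
                        required: Mapping[str, Sequence[str]] = REQUIRED_DATASET_KEYS) -> list[str]:
    """Return dotted paths (e.g. 'preprocess.scaling_cols') missing from a fetched run config."""
    missing: list[str] = []
    for section, subkeys in required.items():
        block = cfg.get(section)
        if not isinstance(block, Mapping):
            missing.append(section)  # whole section absent (or not a dict)
            continue
        missing.extend(f"{section}.{key}" for key in subkeys if key not in block)
    return missing


fetched = {"features": {"time_varying": ["MMSE"], "static": ["M/F"]},
           "preprocess": {"imputation_cols": ["MMSE"]}}
print(missing_config_keys(fetched))  # ['preprocess.scaling_cols']
```

In the cell below it could run right after the top-level check, e.g. `if missing := missing_config_keys(config_from_nb04_producer_run): raise ValueError(f"NB04 config incomplete: {missing}")`.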

In [None]:
# --- Load Preprocessor Artifacts from W&B and Fetch NB04 Config for OASISDataset ---
print(f"\n--- Loading Preprocessor Artifacts & Fetching Definitive Config from NB04 Producer Run ---")

# These variables are expected to be defined in the W&B Initialization cell for NB05:
# SCALER_ARTIFACT_NAME, IMPUTER_ARTIFACT_NAME, PREPROCESSOR_ARTIFACT_TYPE, PREPROCESSOR_ARTIFACT_VERSION
# WANDB_ENTITY, WANDB_PROJECT (from base_config)
# DATASET_IDENTIFIER
# base_config (loaded main config.json)
# run (the W&B run object for this NB05 execution)

SCALER_PATH_FROM_ARTIFACT = None
IMPUTER_PATH_FROM_ARTIFACT = None
config_for_dataset_instance = {} # This will be populated from NB04's run config
source_nb04_run_id_for_config = "N/A" 
source_nb04_run_name_for_config = "N/A"

try:
    if run is None: 
        raise ConnectionError("W&B run for Notebook 05 not initialized. Cannot use W&B artifacts.")

    # --- Step 1: Use and Download SCALER Artifact ---
    scaler_artifact_full_name = f"{WANDB_ENTITY}/{WANDB_PROJECT}/{SCALER_ARTIFACT_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}"
    print(f"  Attempting to use and download Scaler artifact: {scaler_artifact_full_name} (Type: {PREPROCESSOR_ARTIFACT_TYPE})")
    scaler_artifact = run.use_artifact(scaler_artifact_full_name, type=PREPROCESSOR_ARTIFACT_TYPE)
    scaler_artifact_dir = Path(scaler_artifact.download())
    
    # Determine the scaler filename within the artifact (based on how NB04 saved it)
    scaler_filename_in_artifact = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})\
        .get('scaler_fname_pattern', '{scaling_strategy}_{dataset_identifier}.joblib').format(
            scaling_strategy=base_config.get('preprocessing_config',{}).get('scaling_strategy','standard_scaler').lower(),
            dataset_identifier=DATASET_IDENTIFIER
    )
    SCALER_PATH_FROM_ARTIFACT = scaler_artifact_dir / scaler_filename_in_artifact
    
    if not SCALER_PATH_FROM_ARTIFACT.is_file(): # Fallback if specific name not found
        joblib_files_scaler = list(scaler_artifact_dir.glob("*.joblib"))
        if joblib_files_scaler: 
            SCALER_PATH_FROM_ARTIFACT = joblib_files_scaler[0]
            print(f"  Warning: Specific scaler file '{scaler_filename_in_artifact}' not found, using first .joblib: {SCALER_PATH_FROM_ARTIFACT.name}")
        else: 
            raise FileNotFoundError(f"Scaler .joblib file ('{scaler_filename_in_artifact}' or any .joblib) "
                                    f"not found in downloaded scaler artifact directory: {scaler_artifact_dir}")
    print(f"  Scaler artifact '{scaler_artifact.name}' downloaded. Local path: {SCALER_PATH_FROM_ARTIFACT}")

    # --- Step 2: Use and Download IMPUTER Artifact ---
    imputer_artifact_full_name = f"{WANDB_ENTITY}/{WANDB_PROJECT}/{IMPUTER_ARTIFACT_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}"
    print(f"  Attempting to use and download Imputer artifact: {imputer_artifact_full_name} (Type: {PREPROCESSOR_ARTIFACT_TYPE})")
    imputer_artifact = run.use_artifact(imputer_artifact_full_name, type=PREPROCESSOR_ARTIFACT_TYPE)
    imputer_artifact_dir = Path(imputer_artifact.download())
    
    imputer_filename_in_artifact = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})\
        .get('imputer_fname_pattern', 'simple_imputer_{imputation_strategy}_{dataset_identifier}.joblib').format(
            imputation_strategy=base_config.get('preprocessing_config',{}).get('imputation_strategy','median'),
            dataset_identifier=DATASET_IDENTIFIER
    )
    IMPUTER_PATH_FROM_ARTIFACT = imputer_artifact_dir / imputer_filename_in_artifact
    
    if not IMPUTER_PATH_FROM_ARTIFACT.is_file(): # Fallback
        joblib_files_imputer = list(imputer_artifact_dir.glob("*.joblib"))
        if joblib_files_imputer: 
            IMPUTER_PATH_FROM_ARTIFACT = joblib_files_imputer[0]
            print(f"  Warning: Specific imputer file '{imputer_filename_in_artifact}' not found, using first .joblib: {IMPUTER_PATH_FROM_ARTIFACT.name}")
        else: 
            raise FileNotFoundError(f"Imputer .joblib file ('{imputer_filename_in_artifact}' or any .joblib) "
                                    f"not found in downloaded imputer artifact directory: {imputer_artifact_dir}")
    print(f"  Imputer artifact '{imputer_artifact.name}' downloaded. Local path: {IMPUTER_PATH_FROM_ARTIFACT}")

    # --- Step 3: Fetch Full Configuration from the NB04 Run that Produced these Preprocessors ---
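    # The imputer artifact's producer run (the NB04 run) is the source of the authoritative config.
    nb04_producer_run = imputer_artifact.logged_by()
    # Both artifacts were resolved via the version alias (e.g. "latest"), so they may come from different
    # NB04 runs if either was re-logged. Warn when the producers differ; pin explicit versions (e.g. ":v3") to fix.
    scaler_producer_run = scaler_artifact.logged_by()
    if nb04_producer_run and scaler_producer_run and scaler_producer_run.id != nb04_producer_run.id:
        print(f"  WARNING: Scaler was logged by run '{scaler_producer_run.id}' but imputer by run "
              f"'{nb04_producer_run.id}'. The preprocessors may not come from the same NB04 execution.")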
    if nb04_producer_run:
        source_nb04_run_id_for_config = nb04_producer_run.id
        source_nb04_run_name_for_config = nb04_producer_run.name
        print(f"  Fetching full config from producer NB04 run: '{source_nb04_run_name_for_config}' (ID: {source_nb04_run_id_for_config})")
        
        config_from_nb04_producer_run = dict(nb04_producer_run.config) # Convert W&B config to standard dict
        
        # Validate that the critical keys ('features', 'preprocess') are present in the fetched config
        # These keys should have been logged by Notebook 04.
        if 'features' not in config_from_nb04_producer_run or \
           'preprocess' not in config_from_nb04_producer_run:
            print("  ERROR: The config fetched from the producer NB04 run is missing critical 'features' or 'preprocess' sections.")
            print(f"  Available keys in fetched NB04 config: {list(config_from_nb04_producer_run.keys())}")
            raise ValueError("Incomplete configuration fetched from NB04 producer run.")
        
        config_for_dataset_instance = config_from_nb04_producer_run 
        # This now holds the authoritative config (features, preprocess, cnn_model_params, etc.)
        # that OASISDataset will use.
        
        print(f"  Successfully fetched configuration from NB04 producer run for OASISDataset.")
        
        # Update current NB05 run's config with info about the source NB04 run and artifacts used
        if run: # Ensure NB05's run object is valid
            run.config.update({
                "source_config_details/nb04_producer_run_id": source_nb04_run_id_for_config,
                "source_config_details/nb04_producer_run_name": source_nb04_run_name_for_config,
                "source_config_details/scaler_artifact_used_name_version": scaler_artifact.name, # Logs e.g. "scaler_standard_oasis2:vX"
                "source_config_details/imputer_artifact_used_name_version": imputer_artifact.name,
                # Log the key parts of the config that OASISDataset will actually use for clarity
                "dataset_config_used/features": config_for_dataset_instance.get('features',{}),
                "dataset_config_used/preprocess": config_for_dataset_instance.get('preprocess',{}),
                "dataset_config_used/cnn_model_params": config_for_dataset_instance.get('cnn_model_params',{}),
                "dataset_config_used/preprocessing_config_mri": config_for_dataset_instance.get('preprocessing_config',{})
            }, allow_val_change=True)
            print("  NB05 W&B run config updated with source NB04 run info and the dataset config to be used.")
    else: 
        raise ConnectionError("Could not retrieve the W&B run that produced the preprocessor artifacts "
                              "(via artifact.logged_by()). Definitive configuration cannot be fetched.")

except wandb.errors.CommError as e_wandb_comm:
    print(f"CRITICAL W&B Communication Error: {e_wandb_comm}")
    print("  This could be due to invalid artifact names, types, versions, or general W&B API issues.")
    print(f"  Attempted to use Scaler: {WANDB_ENTITY}/{WANDB_PROJECT}/{SCALER_ARTIFACT_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}")
    print(f"  Attempted to use Imputer: {WANDB_ENTITY}/{WANDB_PROJECT}/{IMPUTER_ARTIFACT_NAME}:{PREPROCESSOR_ARTIFACT_VERSION}")
    if run: run.finish(exit_code=1)
    # exit()
except FileNotFoundError as e_fnf_artifact:
    print(f"CRITICAL FileNotFoundError after attempting artifact download: {e_fnf_artifact}")
    if run: run.finish(exit_code=1)
    # exit()
except Exception as e_artifact_other:
    print(f"CRITICAL ERROR loading preprocessor artifacts or fetching config from NB04 run: {e_artifact_other}")
    import traceback
    traceback.print_exc()
    if run: run.finish(exit_code=1)
    # exit()

# Fallback: if the fetch above failed and execution was not halted, populate config_for_dataset_instance
# from base_config so OASISDataset can still be constructed. These fallback feature lists may NOT match
# what Notebook 04 actually used, so any results from this path should be treated as unverified.
if not config_for_dataset_instance or 'features' not in config_for_dataset_instance or 'preprocess' not in config_for_dataset_instance:
    print("CRITICAL WARNING: 'config_for_dataset_instance' (from NB04 W&B run) is not correctly populated.")
    print("  OASISDataset will likely fail or use incorrect default features.")
    # Define minimal fallback structure based on base_config if critical failure occurred
    config_for_dataset_instance.setdefault('features', base_config.get('feature_definitions_fallback', {}).get('features_for_model', {'time_varying': [], 'static': []}))
    config_for_dataset_instance.setdefault('preprocess', base_config.get('feature_definitions_fallback', {}).get('preprocess_details_for_model', {'imputation_cols': [], 'scaling_cols': []}))
    config_for_dataset_instance.setdefault('cnn_model_params', base_config.get('cnn_model_params', {})) 
    config_for_dataset_instance.setdefault('preprocessing_config', base_config.get('preprocessing_config', {})) 
    if run: run.config.update({"warning_nb04_config_fetch_failed": True}, allow_val_change=True)
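When the expected filename is missing, the cell above falls back to the first `.joblib` returned by `Path.glob`, whose ordering is not guaranteed; with more than one candidate the wrong object could be loaded silently. A stricter sketch (hypothetical `resolve_joblib_file`) accepts the fallback only when exactly one `.joblib` file is present.

```python
from pathlib import Path


def resolve_joblib_file(artifact_dir: Path, expected_name: str) -> Path:
    """Prefer `expected_name`; otherwise fall back only if the directory holds exactly one .joblib."""
    expected = artifact_dir / expected_name
    if expected.is_file():
        return expected
    candidates = sorted(artifact_dir.glob("*.joblib"))  # sorted for deterministic error messages
    if len(candidates) == 1:
        print(f"Warning: '{expected_name}' not found; using the only .joblib present: {candidates[0].name}")
        return candidates[0]
    if not candidates:
        raise FileNotFoundError(f"No .joblib file in {artifact_dir}")
    raise FileNotFoundError(
        f"'{expected_name}' not found and {len(candidates)} .joblib files are ambiguous: "
        f"{[c.name for c in candidates]}"
    )
```

It raises `FileNotFoundError` in every failure case, so the cell's existing `except FileNotFoundError` handler would still apply; `SCALER_PATH_FROM_ARTIFACT = resolve_joblib_file(scaler_artifact_dir, scaler_filename_in_artifact)` would replace the inline fallback.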

## 4. Instantiate `OASISDataset` Instances

With the data paths, the locally downloaded preprocessor artifacts, and the feature/preprocessing configuration fetched from the authoritative Notebook 04 W&B run all in place, we can instantiate the custom `OASISDataset` class.

Separate instances are created for the training and validation data splits (`cohort_train.parquet` and `cohort_validation.parquet`, respectively). The `OASISDataset` will:
* Load the specified Parquet data file for the split.
* Internally load the fitted imputer and scaler using the provided artifact paths.
* Apply imputation and scaling to the appropriate columns based on the fetched `config_for_dataset_instance`.
* Handle the selection of time-varying and static features as defined in `config_for_dataset_instance`.
* Encode categorical features (e.g., 'M/F' into 'M/F_encoded') if the feature configuration requires it.
* Optionally load and prepare corresponding MRI scan data if the `TEST_WITH_MRI` flag is set to `True`.

This step exercises the core data loading, preprocessing, and feature assembly logic of `OASISDataset`, using exactly the settings defined by Notebook 04.
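The order of these transformations matters: imputation has to run before scaling, because scikit-learn's `StandardScaler` keeps NaNs as NaNs in `transform`. The numpy sketch below illustrates the idea with stand-in statistics (the column names and values are hypothetical, not OASIS-2 results) in place of the fitted `SimpleImputer` and `StandardScaler`. The key property is that validation rows are transformed with statistics computed on the training split only.

```python
import numpy as np

# Hypothetical "fitted" statistics, standing in for a SimpleImputer(strategy="median")
# and a StandardScaler that were fitted on the training split only.
COLUMNS = ["Age", "MMSE"]             # hypothetical column order shared by imputer and scaler
TRAIN_MEDIANS = np.array([75.0, 28.0])
TRAIN_MEANS = np.array([76.0, 27.0])
TRAIN_STDS = np.array([8.0, 2.0])


def impute_then_scale(x: np.ndarray) -> np.ndarray:
    """Fill NaNs with training medians, then standardize with training mean/std."""
    filled = np.where(np.isnan(x), TRAIN_MEDIANS, x)  # per-column medians broadcast over rows
    return (filled - TRAIN_MEANS) / TRAIN_STDS


val_rows = np.array([[80.0, np.nan],
                     [np.nan, 25.0]])
print(impute_then_scale(val_rows))
```

Both `train_dataset` and `val_dataset` below receive the same `scaler_path` and `imputer_path` for exactly this reason. Column order also matters: scikit-learn transformers applied to plain arrays rely on the columns being in the order they were fitted with.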

In [None]:
# --- Instantiate OASISDataset for Training and Validation Sets ---
print("\n--- Instantiating Datasets ---")

# Initialize dataset variables to ensure they are defined
train_dataset: OASISDataset | None = None
val_dataset: OASISDataset | None = None

# --- Prerequisite Variable Check ---
# These variables are expected to be defined and populated from the preceding cells:
#   TRAIN_DATA_PATH (Path object to train parquet file)
#   VAL_DATA_PATH (Path object to validation parquet file)
#   SCALER_PATH_FROM_ARTIFACT (Path object to downloaded scaler.joblib)
#   IMPUTER_PATH_FROM_ARTIFACT (Path object to downloaded imputer.joblib)
#   config_for_dataset_instance (dict: config fetched from NB04 W&B run for OASISDataset)
#   MRI_DATA_DIR (Path object to preprocessed MRI scans directory)
#   TEST_WITH_MRI (bool: flag to include MRI data in this test run)
#   DATASET_IDENTIFIER (str: e.g., "oasis2") - for logging/consistency, though not directly used by OASISDataset here
#   run (wandb.Run object for NB05) - for potential logging within this cell

required_vars_for_dataset_instantiation = {
    'TRAIN_DATA_PATH': TRAIN_DATA_PATH,
    'VAL_DATA_PATH': VAL_DATA_PATH,
    'SCALER_PATH_FROM_ARTIFACT': SCALER_PATH_FROM_ARTIFACT,
    'IMPUTER_PATH_FROM_ARTIFACT': IMPUTER_PATH_FROM_ARTIFACT,
    'config_for_dataset_instance': config_for_dataset_instance,
    'MRI_DATA_DIR': MRI_DATA_DIR,
    'TEST_WITH_MRI': TEST_WITH_MRI 
    # Note: DATASET_IDENTIFIER and run are used for logging, not direct OASISDataset args
}

proceed_with_instantiation = True
for var_name, var_value in required_vars_for_dataset_instantiation.items():
    if var_value is None: # Check for None explicitly
        print(f"  CRITICAL ERROR: Prerequisite variable '{var_name}' is None.")
        proceed_with_instantiation = False
    # Also check if config_for_dataset_instance is empty or missing crucial keys
    if var_name == 'config_for_dataset_instance' and \
       (not var_value or 'features' not in var_value or 'preprocess' not in var_value):
        print(f"  CRITICAL ERROR: 'config_for_dataset_instance' is empty or missing 'features'/'preprocess' keys.")
        # from pprint import pprint; pprint(var_value) # Uncomment for detailed debug if needed
        proceed_with_instantiation = False
        
if not proceed_with_instantiation:
    print("Halting dataset instantiation due to missing or invalid prerequisites from previous cells.")
    if run: run.finish(exit_code=1) # Mark W&B run as failed
    # exit() # Or raise an error
else:
    try:
        print(f"\nInstantiating train_dataset for {DATASET_IDENTIFIER.upper()} (MRI Included: {TEST_WITH_MRI})...")
        train_dataset = OASISDataset(
            data_parquet_path=TRAIN_DATA_PATH,
            scaler_path=SCALER_PATH_FROM_ARTIFACT,   # Use path to downloaded artifact
            imputer_path=IMPUTER_PATH_FROM_ARTIFACT, # Use path to downloaded artifact
            config=config_for_dataset_instance,      # Use config fetched from NB04 W&B run
            mri_data_dir=MRI_DATA_DIR if TEST_WITH_MRI else None,
            include_mri=TEST_WITH_MRI 
        )
        num_train_subjects = len(train_dataset)
        print(f"  Train dataset created successfully. Number of subjects (sequences): {num_train_subjects}")
        if run: run.log({f'dataset_test_{DATASET_IDENTIFIER}/train_subject_count': num_train_subjects})


        print(f"\nInstantiating val_dataset for {DATASET_IDENTIFIER.upper()} (MRI Included: {TEST_WITH_MRI})...")
        val_dataset = OASISDataset(
            data_parquet_path=VAL_DATA_PATH,
            scaler_path=SCALER_PATH_FROM_ARTIFACT,   # Use the SAME scaler/imputer from training
            imputer_path=IMPUTER_PATH_FROM_ARTIFACT,
            config=config_for_dataset_instance,      # Use the SAME config from NB04 W&B run
            mri_data_dir=MRI_DATA_DIR if TEST_WITH_MRI else None,
            include_mri=TEST_WITH_MRI
        )
        num_val_subjects = len(val_dataset)
        print(f"  Validation dataset created successfully. Number of subjects (sequences): {num_val_subjects}")
        if run: run.log({f'dataset_test_{DATASET_IDENTIFIER}/val_subject_count': num_val_subjects})

    except FileNotFoundError as e_fnf_ds_init:
        print(f"CRITICAL ERROR during OASISDataset instantiation: A required file was not found - {e_fnf_ds_init}")
        print("  This could be an issue with TRAIN_DATA_PATH, VAL_DATA_PATH, or the downloaded artifact paths "
              "(SCALER_PATH_FROM_ARTIFACT, IMPUTER_PATH_FROM_ARTIFACT). Verify paths from previous cells.")
        if run: run.finish(exit_code=1)
        # exit()
    except KeyError as e_key_ds_init:
        print(f"CRITICAL ERROR during OASISDataset instantiation: Missing a key in 'config_for_dataset_instance' - {e_key_ds_init}")
        print("  Ensure the configuration fetched from the Notebook 04 W&B run contains the complete "
              "'features' and 'preprocess' dictionaries with all expected subkeys.")
        # from pprint import pprint; print("DEBUG: config_for_dataset_instance being passed to OASISDataset:"); pprint(config_for_dataset_instance)
        if run: run.finish(exit_code=1)
        # exit()
    except Exception as e_ds_init_other:
        print(f"An unexpected CRITICAL ERROR occurred during OASISDataset instantiation: {e_ds_init_other}")
        import traceback
        traceback.print_exc()
        if run: run.finish(exit_code=1)
        # exit()

# Final check to ensure datasets were created for the next cell
if train_dataset is None or val_dataset is None:
    print("CRITICAL ERROR: Dataset instantiation failed. Cannot proceed to DataLoader creation in the next cell.")
    # exit() # This would halt the notebook

## 5. Create DataLoaders

Wrap the `OASISDataset` instances (for training and validation sets) in PyTorch `DataLoader`s. The `DataLoader` is responsible for:
* Batching the data (grouping multiple subjects/sequences together).
* Shuffling the training data before each epoch (optional, but good practice).
* **Critically, using the custom `pad_collate_fn`** (from `src/datasets.py`). This function takes lists of sequences (which can have varying lengths within a batch) and pads them into uniform PyTorch tensors suitable for input into sequence models. It also generates sequence `lengths` and boolean `masks` indicating real vs. padded elements.
* Optionally using multiple worker processes to load data in the background (this test sets `num_workers=0`, which loads batches in the main process and keeps debugging simple).
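The padding step can be illustrated independently of PyTorch. The sketch below (numpy; `pad_collate_sketch` is hypothetical, and the real `pad_collate_fn` in `src/datasets.py` returns PyTorch tensors, also handles MRI sequences, and may order or name its outputs differently) pads each tabular sequence to the batch's maximum length, records the original lengths, and builds a boolean mask that is `True` for real time steps.

```python
import numpy as np


def pad_collate_sketch(sequences, targets, pad_value=0.0):
    """Pad a list of (T_i, F) arrays to (B, T_max, F) and build lengths/masks."""
    lengths = np.array([len(s) for s in sequences], dtype=np.int64)
    batch_size, t_max = len(sequences), int(lengths.max())
    n_features = sequences[0].shape[1]
    padded = np.full((batch_size, t_max, n_features), pad_value, dtype=np.float32)
    for i, seq in enumerate(sequences):
        padded[i, : len(seq)] = seq  # copy real time steps; the tail keeps pad_value
    masks = np.arange(t_max)[None, :] < lengths[:, None]  # (B, T_max), True = real step
    return padded, lengths, np.asarray(targets), masks


batch = [np.ones((3, 2)), np.ones((1, 2)) * 2]
padded, lengths, targets, masks = pad_collate_sketch(batch, targets=[0, 1])
print(padded.shape, lengths.tolist())  # (2, 3, 2) [3, 1]
```

In PyTorch the same padding is commonly done with `torch.nn.utils.rnn.pad_sequence(..., batch_first=True)`, and the `lengths` are what `pack_padded_sequence` expects (on the CPU when passed as a tensor).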

In [None]:
# --- Create DataLoaders ---
print("\n--- Creating DataLoaders ---")

# BATCH_SIZE_FOR_TESTING and TEST_WITH_MRI should be defined in a setup cell
# pad_collate_fn should be imported from src.datasets

train_loader = None
val_loader = None

if 'train_dataset' in locals() and train_dataset is not None and \
   'val_dataset' in locals() and val_dataset is not None:
    
    print(f"  Using BATCH_SIZE: {BATCH_SIZE_FOR_TESTING}")
    # num_workers=0 is often best for local debugging to avoid multiprocessing complexities.
    # persistent_workers=False if num_workers > 0, can help with some OS/environment issues.
    # For num_workers=0, persistent_workers has no effect and can be omitted.
    
    try:
        train_loader = DataLoader(
            train_dataset,
            batch_size=BATCH_SIZE_FOR_TESTING,
            shuffle=True, # Shuffle training data each epoch for better generalization
            collate_fn=pad_collate_fn,
            num_workers=0,  # Single-process loading keeps debugging simple
        )
        print(f"  Training DataLoader created. Number of batches: {len(train_loader)}")

        val_loader = DataLoader(
            val_dataset,
            batch_size=BATCH_SIZE_FOR_TESTING,
            shuffle=False, # No need to shuffle validation/test data
            collate_fn=pad_collate_fn,
            num_workers=0
        )
        print(f"  Validation DataLoader created. Number of batches: {len(val_loader)}")

    except Exception as e_dataloader:
        print(f"CRITICAL ERROR creating DataLoaders: {e_dataloader}")
        if run:
            run.finish(exit_code=1)
            run = None  # Later cells test `if run:`; this prevents logging to a finished run
        # exit()
else:
    print("Skipping DataLoader creation as datasets are not available.")

# Ensure loaders are defined for subsequent cells
if train_loader is None or val_loader is None:
    print("CRITICAL ERROR: DataLoader creation failed. Cannot proceed to batch iteration test.")
    # exit()
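
For a map-style dataset with the default sampler and `drop_last=False` (the `DataLoader` default), `len(loader)` is exact rather than approximate: it is the number of samples (subject sequences) divided by the batch size, rounded up. A quick sanity check, using made-up sample counts rather than the actual OASIS-2 split sizes:

```python
def expected_num_batches(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a map-style dataset with the default sampler."""
    if drop_last:
        return num_samples // batch_size  # the final partial batch is discarded
    # Integer ceiling division: the final partial batch is kept
    return (num_samples + batch_size - 1) // batch_size


# Illustrative numbers only, not the real cohort sizes.
print(expected_num_batches(101, 8))                  # 13
print(expected_num_batches(101, 8, drop_last=True))  # 12
```

Comparing this value with `len(train_loader)` is a cheap way to confirm that the dataset length seen by the loader matches the number of subjects in the split.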

## 6. Test Batch Iteration and Verify Batch Contents

To confirm that the entire data loading and preprocessing pipeline works as expected, this section iterates through a small number of batches yielded by the `train_loader`. For each test batch, it unpacks all components generated by `pad_collate_fn`:
* `sequences_tabular_padded`
* `sequences_mri_padded` (if MRI data is included)
* `lengths` (original sequence lengths)
* `targets`
* `masks` (boolean masks for padding)

The shapes, data types, and some example values or summaries of these components are printed. This allows for careful verification that the data is correctly formatted for input into a PyTorch sequence model.
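Because the tuple returned by `pad_collate_fn` has five items when MRI data is included and four otherwise, it can help to map each batch to named fields once instead of branching on tuple length inside every loop. The helper below is a hypothetical sketch that assumes the item order listed above; it is not part of `src/datasets.py`.

```python
from typing import Any, Dict, Sequence

# Assumed item order, matching the component list above.
_FIELDS_WITH_MRI = ("sequences_tabular_padded", "sequences_mri_padded", "lengths", "targets", "masks")
_FIELDS_TABULAR_ONLY = ("sequences_tabular_padded", "lengths", "targets", "masks")


def unpack_batch(batch: Sequence[Any], include_mri: bool) -> Dict[str, Any]:
    """Map a collated batch tuple to named fields; the MRI entry is None when not included."""
    fields = _FIELDS_WITH_MRI if include_mri else _FIELDS_TABULAR_ONLY
    if len(batch) != len(fields):
        raise ValueError(
            f"Expected {len(fields)} items (include_mri={include_mri}), got {len(batch)}"
        )
    named = dict(zip(fields, batch))
    named.setdefault("sequences_mri_padded", None)
    return named


# Placeholder strings stand in for tensors here.
parts = unpack_batch(("tab", "lens", "tgt", "mask"), include_mri=False)
print(parts["lengths"], parts["sequences_mri_padded"])  # lens None
```

Raising on a length mismatch, rather than skipping the batch, makes a mismatch between the collate function and the test settings impossible to overlook.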

**Key things to check in the output:**
* **No errors** during batch iteration.
* `sequences_tabular_padded` shape: `(batch_size, max_seq_len_in_batch, num_tabular_features)`. Dtype: `torch.float32`.
* `sequences_mri_padded` shape (if `TEST_WITH_MRI=True`): `(batch_size, max_seq_len_in_batch, C, D, H, W)`. Dtype: `torch.float32`.
* `lengths` shape: `(batch_size,)`. Dtype: `torch.int64`. Values should match actual pre-padding sequence lengths.
* `targets` shape: `(batch_size, 1)` or `(batch_size,)`. Dtype: `torch.float32`.
* `masks` shape: `(batch_size, max_seq_len_in_batch)`. Dtype: `torch.bool`. Values (`True`/`False`) must align with `lengths`.
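The shape and mask checks above can be automated rather than eyeballed. The sketch below works on plain nested lists (for tensors, pass `list(t.shape)` and `.tolist()` values) and assumes, as an unverified convention, that `True` marks a real time step and that sequences are padded to the longest sequence in the batch; flip the comparison if `pad_collate_fn` uses the opposite mask polarity.

```python
from typing import List, Sequence


def check_batch_consistency(
    tabular_shape: Sequence[int],
    lengths: Sequence[int],
    masks: Sequence[Sequence[bool]],
) -> List[str]:
    """Return a list of problems; an empty list means the batch looks consistent.

    tabular_shape is (batch_size, max_seq_len_in_batch, num_tabular_features).
    Assumes masks[i][t] is True exactly when t < lengths[i].
    """
    problems = []
    batch_size, max_len = tabular_shape[0], tabular_shape[1]
    if len(lengths) != batch_size:
        problems.append(f"lengths has {len(lengths)} entries, expected {batch_size}")
    if lengths and max(lengths) != max_len:
        problems.append(f"max(lengths)={max(lengths)} but padded length is {max_len}")
    if len(masks) != batch_size or any(len(row) != max_len for row in masks):
        problems.append("masks shape does not match (batch_size, max_seq_len_in_batch)")
        return problems
    for i, (length, row) in enumerate(zip(lengths, masks)):
        expected = [t < length for t in range(max_len)]
        if list(row) != expected:
            problems.append(f"sample {i}: mask does not match length {length}")
    return problems


print(check_batch_consistency((2, 3, 5), [3, 1], [[True, True, True], [True, False, False]]))
# []
print(check_batch_consistency((2, 3, 5), [3, 2], [[True, True, True], [True, False, False]]))
# ['sample 1: mask does not match length 2']
```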

In [None]:
# --- Test Batch Iteration from Train Loader ---
print("\n--- Testing Batch Iteration from Training DataLoader ---")

# NUM_BATCHES_TO_INSPECT and TEST_WITH_MRI should be defined in a setup cell

if 'train_loader' in locals() and train_loader is not None:
    if len(train_loader) == 0:
        print("  Warning: train_loader is empty. Cannot iterate through batches.")
    else:
        print(f"  Iterating through the first {NUM_BATCHES_TO_INSPECT} batch(es) "
              f"(MRI Included in this test: {TEST_WITH_MRI})...")
        
        # Initialize last-batch summaries for W&B logging so they exist even if every batch is skipped
        last_batch_tabular_shape_log = None
        last_batch_mri_shape_log = None
        last_batch_lengths_log = None
        last_batch_targets_shape_log = None

        for i, batch_content in enumerate(train_loader):
            if i >= NUM_BATCHES_TO_INSPECT:
                break

            print(f"\n--- Contents of Batch {i+1}/{NUM_BATCHES_TO_INSPECT} ---")
            
            # Unpack the batch based on whether MRI data is included
            # This logic must align with pad_collate_fn's return tuple
            if TEST_WITH_MRI: # Expecting 5 items
                if len(batch_content) == 5:
                    sequences_tabular_padded, sequences_mri_padded, lengths, targets, masks = batch_content
                else:
                    print(f"  ERROR: Expected 5 items in batch when TEST_WITH_MRI=True, but got {len(batch_content)}. Skipping this batch display.")
                    continue
            else: # Expecting 4 items (tabular only)
                if len(batch_content) == 4:
                    sequences_tabular_padded, lengths, targets, masks = batch_content
                    sequences_mri_padded = None # Explicitly None
                else:
                    print(f"  ERROR: Expected 4 items in batch when TEST_WITH_MRI=False, but got {len(batch_content)}. Skipping this batch display.")
                    continue

            print(f"  Tabular Sequences Tensor Shape: {sequences_tabular_padded.shape}") 
            print(f"  Tabular Sequences Tensor Type: {sequences_tabular_padded.dtype}") 
            
            if sequences_mri_padded is not None:
                print(f"  MRI Sequences Tensor Shape: {sequences_mri_padded.shape}") 
                print(f"  MRI Sequences Tensor Type: {sequences_mri_padded.dtype}")
                last_batch_mri_shape_log = list(sequences_mri_padded.shape)
            else:
                print("  MRI Sequences Tensor: Not included in this batch/test.")
                last_batch_mri_shape_log = "Not_Included"

            print(f"  Lengths Tensor Shape: {lengths.shape}") 
            print(f"  Lengths Tensor Type: {lengths.dtype}") 
            print(f"  Lengths Tensor Values (first {min(5, len(lengths))}): {lengths[:5].tolist()}")
            
            print(f"  Targets Tensor Shape: {targets.shape}") 
            print(f"  Targets Tensor Type: {targets.dtype}") 
            # Flatten before slicing: squeeze() on a batch of size 1 yields a 0-d tensor whose tolist() is a float
            print(f"  Targets Tensor Values (first {min(5, targets.shape[0])}): {targets.reshape(-1)[:5].tolist()}")
            
            print(f"  Masks Tensor Shape: {masks.shape}") 
            print(f"  Masks Tensor Type: {masks.dtype}") 

            # Store shapes from the last iterated batch for logging
            last_batch_tabular_shape_log = list(sequences_tabular_padded.shape)
            last_batch_lengths_log = lengths[:5].tolist()  # Log a sample
            last_batch_targets_shape_log = list(targets.shape)

            # Print a slice of a sequence and its mask to verify padding
            if sequences_tabular_padded.shape[0] > 0 and sequences_tabular_padded.shape[1] > 0:
                print("\n  Example Tabular Sequence (First Sample in Batch, First 5 steps, First 3 features):")
                print(sequences_tabular_padded[0, :min(5, sequences_tabular_padded.shape[1]), :min(3, sequences_tabular_padded.shape[2])])
                print("\n  Corresponding Mask (First Sample in Batch, First 5 steps):")
                print(masks[0, :min(5, masks.shape[1])])
                print(f"  (Original length for this first sample was: {lengths[0].item()})")
        
        # Log info about the last batch tested to W&B
        if run:
            log_payload_nb05 = {
                "dataloader_test/status": "completed_iteration",
                # i equals NUM_BATCHES_TO_INSPECT when the loop breaks early, so cap the count
                "dataloader_test/num_batches_inspected": min(i + 1, NUM_BATCHES_TO_INSPECT),
            }
            if last_batch_tabular_shape_log:
                log_payload_nb05['dataloader_test/last_batch_tab_shape'] = last_batch_tabular_shape_log
            if last_batch_mri_shape_log:
                log_payload_nb05['dataloader_test/last_batch_mri_shape'] = last_batch_mri_shape_log
            if last_batch_lengths_log:
                log_payload_nb05['dataloader_test/last_batch_lengths_example'] = last_batch_lengths_log
            if last_batch_targets_shape_log:
                log_payload_nb05['dataloader_test/last_batch_target_shape'] = last_batch_targets_shape_log
            run.log(log_payload_nb05)
            print("\n  Logged batch test summary to W&B.")

        print(f"\n--- Finished inspecting {min(i + 1, NUM_BATCHES_TO_INSPECT)} batch(es). ---")
        print("Review shapes, dtypes, lengths, and masks to ensure correctness.")
else:
    print("Skipping batch iteration test as train_loader is not available.")

## 7. Finalize W&B Run

Complete the execution for this DataLoader testing notebook (NB05) and finish the associated Weights & Biases run. This ensures all queued logs and information about the test are uploaded to the W&B platform.

In [None]:
# --- Finish W&B Run for Notebook 05 ---
print(f"\n--- {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} (DataLoader Test) complete. Finishing W&B run. ---")

if run: # Check if 'run' object exists and is an active W&B run
    try:
        # Record the final status and capture the run's display name before finishing
        run.summary.update({"overall_notebook_status": "Completed Successfully"})
        run_name_to_print = getattr(run, "name", None) or getattr(run, "id", None) or "current NB05 run"

        run.finish()
        print(f"W&B run '{run_name_to_print}' finished successfully.")
    except Exception as e_finish_run_nb05:
        print(f"Error while finishing the W&B run for Notebook 05: {e_finish_run_nb05}")
        print("The run may not have finalized correctly on the W&B server.")
else:
    print("No active W&B run to finish for this session.")

print(f"\n--- Notebook {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} execution finished. ---")

## 8. Conclusion

If batch iteration completes without errors and the printed shapes, data types, sequence lengths, targets, and masks look correct for the tabular stream (and, if tested, the MRI stream), the data loading pipeline (`OASISDataset` and `pad_collate_fn`) is working as expected with the preprocessors and configurations derived from Notebook 04. The batches are then in the format required as input to sequence models during training.