# Notebook 03: OASIS-2 Feature Engineering & Data Splitting

**Project Phase:** 1 (Data Processing - Feature Engineering & Splitting)
**Dataset:** OASIS-2 Longitudinal MRI & Clinical Data

**Purpose:**
This notebook takes the defined analysis cohort (output from Notebook 02) and performs critical feature engineering and data splitting steps to prepare for model training. The objectives are:
1.  Load the `final_analysis_cohort.csv` data.
2.  Engineer time-based features essential for longitudinal modeling, such as `Days_from_Baseline` (relative to the first visit *within the current cohort*) and `Time_since_Last_Visit_Days`.
3.  Extract baseline clinical scores (e.g., `Baseline_CDR`, `Baseline_MMSE`) for each subject to be used as static input features.
4.  Select the final set of columns (identifiers, time-varying features, static features, and the base target variable 'CDR') for the modeling dataset.
5.  Create the primary target variable: `CDR_next_visit` (the CDR score at the subsequent visit). Rows corresponding to a subject's last visit (which have no next CDR) are dropped.
6.  Perform a subject-level stratified train/validation/test split. Stratification is based on `Baseline_CDR` to ensure balanced representation of baseline cognitive status across splits.
7.  Save the resulting `train_df`, `val_df`, and `test_df` DataFrames locally as Parquet files.
8.  Log these data splits as versioned artifacts to Weights & Biases (W&B) for use in subsequent modeling notebooks.
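
Objective 6 is the step most prone to leakage, because all visits from one subject must land in the same split. The following is a minimal sketch on synthetic data, assuming scikit-learn's `train_test_split` is applied to one row per subject. The toy subjects, the string-cast strata, and the ratio rescaling are illustrative assumptions, and the notebook's own splitting cell may differ in detail.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Toy visit-level data: 40 subjects, 2 rows each, with a per-subject Baseline_CDR.
subjects = [f"S{i:02d}" for i in range(40)]
baseline = [0.0 if i % 5 < 3 else 0.5 for i in range(40)]
visits = pd.DataFrame({
    "Subject ID": [s for s in subjects for _ in range(2)],
    "Visit": [1, 2] * 40,
    "Baseline_CDR": [b for b in baseline for _ in range(2)],
})

# One row per subject, so that splitting happens on subjects rather than visits.
per_subject = visits.drop_duplicates("Subject ID")[["Subject ID", "Baseline_CDR"]]
strata = per_subject["Baseline_CDR"].astype(str)  # discrete labels for stratify

test_ratio, val_ratio, seed = 0.15, 0.15, 42
train_val_ids, test_ids = train_test_split(
    per_subject["Subject ID"], test_size=test_ratio,
    stratify=strata, random_state=seed,
)
# Rescale so validation is roughly 15% of all subjects, not 15% of the remainder.
train_ids, val_ids = train_test_split(
    train_val_ids, test_size=val_ratio / (1 - test_ratio),
    stratify=strata.loc[train_val_ids.index], random_state=seed,
)

# Map the subject-level assignment back onto the visit-level rows.
train_df = visits[visits["Subject ID"].isin(train_ids)]
val_df = visits[visits["Subject ID"].isin(val_ids)]
test_df = visits[visits["Subject ID"].isin(test_ids)]
print(len(train_ids), len(val_ids), len(test_ids))
```

Splitting the subject IDs first and filtering visits afterwards guarantees that no subject contributes rows to more than one split.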

**Workflow:**
1.  **Setup:** Import libraries, configure `sys.path`, load `config.json`, define input/output paths.
2.  **W&B Initialization:** Start a new W&B run using `initialize_wandb_run`.
3.  **Load Cohort Data:** Load `final_analysis_cohort.csv` (from NB02).
4.  **Engineer Time Features:** Calculate `Days_from_Baseline` and `Time_since_Last_Visit_Days`.
5.  **Prepare Static Features & Select Columns:** Extract baseline scores, define and filter feature lists.
6.  **Create Target Variable:** Generate `CDR_next_visit`.
7.  **Data Splitting:** Perform subject-level stratified train/validation/test split. Log split statistics.
8.  **Save Splits & Log Artifacts:** Save DataFrames locally and log to W&B.
9.  **Finalize W&B Run.**

**Input:**
* `config.json`: Main project configuration file.
* `final_analysis_cohort.csv`: Output from Notebook 02 (path constructed using `config.json`).

**Output:**
* **Local Files (in notebook-specific output directory, e.g., `notebooks/outputs/03_Feature_Engineering_Splitting_OASIS2/`):**
    * `cohort_train.parquet`
    * `cohort_validation.parquet`
    * `cohort_test.parquet`
* **W&B Run:**
    * Logged run configuration (including input paths, split ratios).
    * Statistics on feature engineering and data splitting.
    * Data splits (`train`, `validation`, `test`) logged as W&B Artifacts (e.g., `cohort-split-train-oasis2`).

In [1]:
# In: notebooks/03_Feature_Engineering_Splitting.ipynb
# Purpose: Load the defined cohort, engineer time-based features,
#          select features (incl. pre-computed MRI metrics) for baseline modeling,
#          perform subject-level stratified split, and save the split datasets.

In [2]:
# --- Import Libraries ---
import pandas as pd
import numpy as np
import wandb
import json
from pathlib import Path
import time
import sys 
import os
from sklearn.model_selection import train_test_split

## 1. Setup: Project Configuration, Paths, and Utilities

This section initializes the notebook environment:
* Determines the project's root directory and adds the `src` directory to the Python system path.
* Imports necessary custom utility functions (primarily for W&B run initialization).
* Loads the main project configuration from `config.json`.
* Defines dataset identifiers and notebook-specific parameters.
* Resolves the input path for the cohort data (from Notebook 02) and sets up the output directory for this notebook's generated data splits, using `config.json` for consistency.
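
The path resolution described above depends on a handful of `config.json` keys. Below is a minimal sketch of the fragment the setup cell reads. The key names match those used in the code, but every value, and the helper `resolve_nb03_paths`, is a hypothetical placeholder for illustration.

```python
from pathlib import Path

# Hypothetical config fragment: key names match those read by the setup cell,
# values are placeholders chosen for illustration only.
example_config = {
    "data": {"output_dir_base": "notebooks/outputs"},
    "pipeline_artefact_locators_oasis2": {
        "cohort_def_subdir": "02_Cohort_Definition_oasis2",
        "final_cohort_fname": "final_analysis_cohort.csv",
        "feature_eng_subdir": "03_Feature_Engineering_Splitting_oasis2",
    },
}

def resolve_nb03_paths(project_root: Path, config: dict, dataset_id: str = "oasis2"):
    """Return (input_cohort_csv, output_dir) built the same way as the setup cell."""
    base = project_root / config["data"]["output_dir_base"]
    locators = config[f"pipeline_artefact_locators_{dataset_id}"]
    cohort_csv = base / locators["cohort_def_subdir"] / locators["final_cohort_fname"]
    output_dir = base / locators["feature_eng_subdir"]
    return cohort_csv, output_dir

cohort_csv, out_dir = resolve_nb03_paths(Path("/project"), example_config)
print(cohort_csv)
print(out_dir)
```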

In [3]:
# --- Project Setup, Configuration Loading, and Utility Imports ---
print("--- Initializing Project Setup & Configuration for NB03 ---")

# Initialize
PROJECT_ROOT = None
base_config = {}

try:
    # Determine project root 
    current_notebook_path = Path.cwd() 
    potential_project_root = current_notebook_path.parent 
    if (potential_project_root / "src").is_dir() and (potential_project_root / "config.json").is_file():
        PROJECT_ROOT = potential_project_root
    else: 
        PROJECT_ROOT = current_notebook_path
    
    if not (PROJECT_ROOT / "src").is_dir() or not (PROJECT_ROOT / "config.json").is_file():
        raise FileNotFoundError(f"Could not reliably find 'src' or 'config.json'. PROJECT_ROOT: {PROJECT_ROOT}")

    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
    print(f"PROJECT_ROOT: {PROJECT_ROOT}")
    print(f"Added '{str(PROJECT_ROOT)}' to sys.path.")

    # Import custom utilities
    from src.wandb_utils import initialize_wandb_run 
    print("Successfully imported 'initialize_wandb_run' from src.wandb_utils.")

except FileNotFoundError as e_path:
    print(f"CRITICAL ERROR in project setup (paths or src): {e_path}")
    # exit() 
except ImportError as e_imp:
    print(f"CRITICAL ERROR: Could not import custom utilities: {e_imp}")
    # exit()
except Exception as e_general_setup:
    print(f"CRITICAL ERROR during initial setup: {e_general_setup}")
    # exit()

# --- Load Main Project Configuration ---
print("\n--- Loading Main Project Configuration from config.json ---")
try:
    if PROJECT_ROOT is None:
        raise ValueError("PROJECT_ROOT was not successfully defined.")
    CONFIG_PATH_MAIN = PROJECT_ROOT / 'config.json'
    with open(CONFIG_PATH_MAIN, 'r', encoding='utf-8') as f:
        base_config = json.load(f)
    print(f"Main project config loaded from: {CONFIG_PATH_MAIN}")
except Exception as e_cfg:
    print(f"CRITICAL ERROR loading main config.json: {e_cfg}")
    # exit() 

# --- Define Dataset, Notebook Specifics, and Key Paths ---
DATASET_IDENTIFIER = "oasis2" 
NOTEBOOK_MODULE_NAME = "03_Feature_Engineering_Splitting"
# Key for this notebook's output subdir in config.json's locators
NB03_OUTPUT_LOCATOR_KEY = "feature_eng_subdir"

COHORT_CSV_PATH_NB03_INPUT = None # Input from NB02
output_dir = None                 # Static output dir for this notebook's Parquet files

try:
    if not base_config:
        raise ValueError("base_config is empty. Cannot define paths.")

    output_dir_base_from_config = PROJECT_ROOT / base_config['data']['output_dir_base']
    dataset_locators = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})
    if not dataset_locators:
        raise KeyError(f"pipeline_artefact_locators_{DATASET_IDENTIFIER} section not found in config.json.")

    # Path to final_analysis_cohort.csv (output from Notebook 02)
    nb02_output_subdir_name = dataset_locators.get("cohort_def_subdir")
    nb02_final_cohort_fname = dataset_locators.get("final_cohort_fname")
    if not nb02_output_subdir_name or not nb02_final_cohort_fname:
        raise KeyError("Missing 'cohort_def_subdir' or 'final_cohort_fname' in locators config.")
    COHORT_CSV_PATH_NB03_INPUT = output_dir_base_from_config / nb02_output_subdir_name / nb02_final_cohort_fname

    # Define the main OUTPUT directory for THIS notebook's files (e.g., cohort_train.parquet)
    notebook_output_folder_name = dataset_locators.get(
        NB03_OUTPUT_LOCATOR_KEY, # i.e., "feature_eng_subdir"
        f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}_default_outputs" # Fallback
    )
    output_dir = output_dir_base_from_config / notebook_output_folder_name
    output_dir.mkdir(parents=True, exist_ok=True)

    print(f"\nKey paths for Notebook 03 ({DATASET_IDENTIFIER}):")
    print(f"  Input Cohort CSV (from NB02): {COHORT_CSV_PATH_NB03_INPUT}")
    print(f"  Notebook Output Directory (for data splits): {output_dir}")
    
    if not COHORT_CSV_PATH_NB03_INPUT.is_file():
        raise FileNotFoundError(f"CRITICAL: Input cohort CSV from NB02 not found: {COHORT_CSV_PATH_NB03_INPUT}. "
                                "Ensure Notebook 02 ran successfully and config.json locators are correct.")
    print("All critical input paths for NB03 verified.")

except KeyError as e_key:
    print(f"CRITICAL ERROR: Missing key {e_key} in config.json or locators section.")
    # exit()
except FileNotFoundError as e_fnf:
    print(f"CRITICAL ERROR: {e_fnf}")
    # exit()
except Exception as e_paths_nb03:
    print(f"CRITICAL ERROR defining paths for NB03: {e_paths_nb03}")
    # exit()

## 2. Initialize Weights & Biases Run

A new W&B run is started for this Feature Engineering and Data Splitting notebook. The run will log:
* The configuration parameters used (input data path, time feature strategy, split ratios).
* Summary statistics about the engineered features and data splits.
* The final train, validation, and test data splits as versioned W&B Artifacts.

In [4]:
# --- Initialize W&B Run for THIS Notebook Execution (NB03) ---
print("\n--- Initializing Weights & Biases Run for Notebook 03 ---")

# Define constants for splitting to be logged to W&B config
TEST_SET_RATIO_CONFIG = 0.15
VAL_SET_RATIO_CONFIG = 0.15
RANDOM_STATE_CONFIG = 42

nb03_run_config_log = {
    "notebook_name_code": f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}",
    "dataset_source": DATASET_IDENTIFIER,
    "input_cohort_data_path": str(COHORT_CSV_PATH_NB03_INPUT),
    "output_dir_for_local_saves": str(output_dir),
    "time_feature_source_planned": "'MR Delay' preferred, fallback 'Age'", # Documenting strategy
    "split_test_ratio_target": TEST_SET_RATIO_CONFIG, # Log planned ratios
    "split_validation_ratio_target": VAL_SET_RATIO_CONFIG,
    "split_random_state": RANDOM_STATE_CONFIG,
    "execution_timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    # Actual feature lists used will be logged after they are defined.
}

nb_number_prefix_nb03 = NOTEBOOK_MODULE_NAME.split('_')[0] if '_' in NOTEBOOK_MODULE_NAME else "NB"
job_specific_type_nb03 = f"{nb_number_prefix_nb03}-FeatEngSplit-{DATASET_IDENTIFIER}"
custom_elements_for_name_nb03 = [nb_number_prefix_nb03, DATASET_IDENTIFIER.upper(), "FeatEngSplit"]

run = initialize_wandb_run(
    base_project_config=base_config,
    job_group="DataProcessing",
    job_specific_type=job_specific_type_nb03,
    run_specific_config=nb03_run_config_log,
    custom_run_name_elements=custom_elements_for_name_nb03,
    notes=f"{DATASET_IDENTIFIER.upper()}: Feature engineering and train/validation/test data splitting."
)

if run:
    print(f"W&B run '{run.name}' (Job Type: '{run.job_type}') initialized. View at: {run.url}")
else:
    print("Proceeding without W&B logging for this session (W&B run initialization failed).")
    # output_dir should still be defined for local saves.

## 3. Load Defined Cohort Data

Load the `final_analysis_cohort.csv` dataset which was generated and saved by Notebook 02. This dataset contains the subjects and their longitudinal visits that have met all initial inclusion criteria (baseline cognitive status, minimum number of visits, and MRI scan availability). Basic information about this cohort is printed, and the input dataset is logged as a "used" artifact in Weights & Biases for lineage tracking.

In [None]:
# --- Load Defined Cohort Data (Output from Notebook 02) ---
# COHORT_CSV_PATH_NB03_INPUT should be defined in the setup cell
print(f"\n--- Loading Defined Cohort Data from: {COHORT_CSV_PATH_NB03_INPUT} ---")
cohort_df = None # Initialize to ensure it's defined

try:
    if COHORT_CSV_PATH_NB03_INPUT is None or not COHORT_CSV_PATH_NB03_INPUT.is_file():
         raise FileNotFoundError(f"Defined cohort data file path not set or file not found: {COHORT_CSV_PATH_NB03_INPUT}. "
                                 "Ensure Notebook 02 ran successfully and saved its output, "
                                 "and config.json locators are correct.")
    
    cohort_df = pd.read_csv(COHORT_CSV_PATH_NB03_INPUT)
    print(f"Defined cohort data loaded successfully. Initial Shape: {cohort_df.shape}")

    if cohort_df.empty:
        # Raise so the generic handler below reports the error and finishes the W&B run once,
        # instead of continuing to log to a run that has already been closed.
        raise ValueError(f"Loaded cohort DataFrame from {COHORT_CSV_PATH_NB03_INPUT} is empty. Cannot proceed.")

    if run: # Log input stats and artifact if W&B run is active
        # Log characteristics of the input cohort
        num_input_subjects = cohort_df['Subject ID'].nunique() if 'Subject ID' in cohort_df.columns else 0
        run.log({'feature_engineering_input/cohort_rows': cohort_df.shape[0],
                 'feature_engineering_input/cohort_columns': cohort_df.shape[1],
                 'feature_engineering_input/cohort_subjects': num_input_subjects
                })
        
        # Log the input cohort CSV as a "used" artifact by this run for traceability
        input_cohort_artifact_name = f"input_cohort_{DATASET_IDENTIFIER}_NB03_from_NB02" # Clear, specific name
        input_cohort_artifact_description = (
            f"Final analysis cohort data (output of Notebook 02) used as input for Notebook 03 (Feature Engineering). "
            f"Source file: {COHORT_CSV_PATH_NB03_INPUT.name}"
        )
        input_cohort_wandb_artifact = wandb.Artifact(
            input_cohort_artifact_name,
            type=f"processed_dataset_{DATASET_IDENTIFIER}", # Matches type used by NB02 for its output artifact
            description=input_cohort_artifact_description,
            metadata={"source_notebook": "02_Cohort_Definition", 
                      "shape_rows": cohort_df.shape[0],
                      "shape_columns": cohort_df.shape[1],
                      "num_subjects": num_input_subjects,
                      "path_used_by_nb03": str(COHORT_CSV_PATH_NB03_INPUT)}
        )
        # Add the actual file that was used
        input_cohort_wandb_artifact.add_file(str(COHORT_CSV_PATH_NB03_INPUT), name=COHORT_CSV_PATH_NB03_INPUT.name)
        run.use_artifact(input_cohort_wandb_artifact) # Log that this run *used* this artifact
        print(f"Input cohort artifact '{input_cohort_artifact_name}' (from NB02) logged as used by this W&B run.")

    print("\nCohort DataFrame Head (first 5 rows):")
    #display(cohort_df.head()) # For better display
    print(cohort_df.head())
    print("\nCohort DataFrame Info:")
    cohort_df.info(verbose=True, show_counts=True) # Provides detailed info including non-null counts

except FileNotFoundError as e_fnf_cohort_nb03:
    print(f"CRITICAL ERROR: {e_fnf_cohort_nb03}")
    if run: run.finish(exit_code=1)
    # exit()
except Exception as e_load_cohort_nb03:
    print(f"CRITICAL ERROR occurred while loading the defined cohort data: {e_load_cohort_nb03}")
    if run: run.finish(exit_code=1)
    # exit()

# Ensure cohort_df is defined for subsequent cells, even if empty after an error (if not exiting)
if cohort_df is None: 
    cohort_df = pd.DataFrame()

## 4. Engineer Time-Based Features

For effective longitudinal modeling, features representing the passage of time and intervals between visits are essential. This section calculates:

* **`Days_from_Baseline`**: The number of days elapsed since the subject's first recorded visit *within the current analysis cohort*. This normalizes the timeline for each subject.
* **`Time_since_Last_Visit_Days`**: The number of days that have passed since the subject's immediately preceding visit. For a subject's first visit in the cohort, this value is set to 0.

The calculation prioritizes using the `MR Delay` column (assumed to be days from a consistent study baseline). If `MR Delay` is unreliable or largely missing, `Age` is used as a fallback to approximate these temporal features. The method employed for time feature generation is logged to W&B.
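
The two features can be illustrated on a toy table. This sketch uses made-up `MR Delay` values and includes a subject (`S2`) whose first visit in the cohort is not at delay 0, to show why the per-subject minimum is subtracted.

```python
import pandas as pd

# Toy data: S2's first visit in the cohort does not start at MR Delay 0.
toy = pd.DataFrame({
    "Subject ID": ["S1", "S1", "S1", "S2", "S2"],
    "Visit": [1, 2, 3, 2, 3],
    "MR Delay": [0, 400, 900, 350, 800],
})
toy = toy.sort_values(["Subject ID", "Visit"]).reset_index(drop=True)

# Days since each subject's earliest visit in the cohort.
first_delay = toy.groupby("Subject ID")["MR Delay"].transform("min")
toy["Days_from_Baseline"] = toy["MR Delay"] - first_delay

# Gap to the previous visit; the first visit has no predecessor, so fill with 0.
toy["Time_since_Last_Visit_Days"] = (
    toy.groupby("Subject ID")["Days_from_Baseline"].diff().fillna(0.0)
)
print(toy)
```

For `S2`, the result is `Days_from_Baseline` of 0 and 450 rather than the raw delays of 350 and 800.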

In [6]:
# --- Engineer Time-Based Features ---
print("\n--- Engineering Time-Based Features ---")

# Ensure cohort_df is not empty and is defined from the previous cell
if 'cohort_df' not in locals() or cohort_df.empty:
    print("CRITICAL ERROR: cohort_df is not defined or is empty. Cannot proceed with feature engineering.")
    if run: run.finish(exit_code=1)
    # exit() # Or raise error
else:
    # Sort by subject and visit, then copy so the loaded cohort_df is left untouched
    # until the engineered features are assigned back at the end of this cell.
    df_with_time_features = cohort_df.sort_values(by=['Subject ID', 'Visit']).copy()
    
    time_feature_source_method = "Not_Calculated" # Default status

    # Prioritize 'MR Delay' (days from study entry) if it exists, is numeric, and is mostly non-null
    if 'MR Delay' in df_with_time_features.columns and \
       pd.api.types.is_numeric_dtype(df_with_time_features['MR Delay']) and \
       df_with_time_features['MR Delay'].notnull().sum() > (len(df_with_time_features) * 0.5): # More than 50% non-null
        
        print("  Using 'MR Delay' (days from study baseline) to calculate longitudinal time features.")
        # 'MR Delay' is assumed to be days from a consistent project/study baseline for each subject.
        # To get Days_from_Baseline relative to their *first visit included in this cohort*:
        df_with_time_features['Min_MR_Delay_In_Cohort'] = df_with_time_features.groupby('Subject ID')['MR Delay'].transform('min')
        df_with_time_features['Days_from_Baseline'] = df_with_time_features['MR Delay'] - df_with_time_features['Min_MR_Delay_In_Cohort']
        
        # Time since last visit (for this subject, in this cohort)
        df_with_time_features['Time_since_Last_Visit_Days'] = df_with_time_features.groupby('Subject ID')['Days_from_Baseline'].diff()
        time_feature_source_method = 'MR_Delay_Col_Used'

    elif 'Age' in df_with_time_features.columns and pd.api.types.is_numeric_dtype(df_with_time_features['Age']):
        print("  Warning: 'MR Delay' not suitable or unavailable. Using 'Age' to approximate longitudinal time features.")
        # Calculate Days_from_Baseline based on age difference from first visit age in this cohort
        df_with_time_features['Min_Age_In_Cohort'] = df_with_time_features.groupby('Subject ID')['Age'].transform('min')
        df_with_time_features['Days_from_Baseline'] = (df_with_time_features['Age'] - df_with_time_features['Min_Age_In_Cohort']) * 365.25 # Approximate days
        
        # Time since last visit (for this subject, in this cohort) based on Age difference
        df_with_time_features['Time_since_Last_Visit_Days'] = df_with_time_features.groupby('Subject ID')['Days_from_Baseline'].diff()
        time_feature_source_method = 'Age_Approximation_Used'
    else:
        print("  CRITICAL ERROR: Neither 'MR Delay' nor 'Age' column is suitable for calculating time features. Stopping.")
        if run: run.finish(exit_code=1)
        # exit() 
        # If not exiting, ensure columns exist to prevent downstream errors, though they'll be all NaN
        df_with_time_features['Days_from_Baseline'] = np.nan
        df_with_time_features['Time_since_Last_Visit_Days'] = np.nan

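    # For the first visit of each subject, Time_since_Last_Visit_Days will be NaN; fill with 0.0.
    # Skip the fill and the success message when neither source column was usable,
    # so a failed calculation is not masked by zeros.
    if time_feature_source_method != "Not_Calculated":
        df_with_time_features['Time_since_Last_Visit_Days'] = df_with_time_features['Time_since_Last_Visit_Days'].fillna(0.0)
        print("  Successfully calculated 'Days_from_Baseline' and 'Time_since_Last_Visit_Days'.")
    else:
        print("  Time features were NOT calculated; placeholder columns contain NaN.")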

    # Display example results for verification
    print("\n  Example Engineered Time Features (first 5 rows of processed data):")
    cols_to_show_time_example = ['Subject ID', 'Visit', 
                                 'MR Delay' if 'MR Delay' in df_with_time_features.columns else 'Age', 
                                 'Days_from_Baseline', 'Time_since_Last_Visit_Days']
    # Ensure all columns in the example list actually exist in the DataFrame before trying to display
    cols_to_show_time_example = [col for col in cols_to_show_time_example if col in df_with_time_features.columns]
    if cols_to_show_time_example:
        # display(df_with_time_features[cols_to_show_time_example].head())
        print(df_with_time_features[cols_to_show_time_example].head())
    else:
        print("    Could not display example time features (required columns for example view are missing).")

    # Log basic statistics of the new time features to W&B
    if 'Days_from_Baseline' in df_with_time_features.columns and \
       'Time_since_Last_Visit_Days' in df_with_time_features.columns and run:
        
        print("\n  Engineered Time Feature Statistics:")
        desc_engineered_time_features = df_with_time_features[['Days_from_Baseline', 'Time_since_Last_Visit_Days']].describe()
        # display(desc_engineered_time_features)
        print(desc_engineered_time_features)
        
        # Log which method was used to create time features to W&B config
        run.config.update({'feature_engineering_details/time_feature_source_method': time_feature_source_method}, allow_val_change=True)
        
        # Log statistics of these new features in a single call so they share one W&B step
        time_feature_stats_to_log = {}
        for col_stat_name in desc_engineered_time_features.columns:
            for idx_stat_name in desc_engineered_time_features.index: # e.g. 'mean', 'std', 'min', '25%'
                # Sanitize the statistic name for use in a W&B key (e.g., '25%' -> '25pct')
                safe_stat_name = idx_stat_name.replace('%', 'pct')
                time_feature_stats_to_log[f'stats_engineered_features/time/{col_stat_name}_{safe_stat_name}'] = desc_engineered_time_features.loc[idx_stat_name, col_stat_name]
        run.log(time_feature_stats_to_log)
        print("  Engineered time feature statistics logged to W&B.")
    elif not ('Days_from_Baseline' in df_with_time_features.columns and 'Time_since_Last_Visit_Days' in df_with_time_features.columns):
        print("  Could not calculate or log statistics for engineered time features (columns not created).")
        if run: run.config.update({'feature_engineering_details/time_feature_source_method': 'Failed_or_Skipped'}, allow_val_change=True)

# Replace original cohort_df with the processed one for subsequent steps in this notebook
if 'df_with_time_features' in locals() and not df_with_time_features.empty:
    cohort_df = df_with_time_features 
else:
    print("Warning: Time feature engineering resulted in an empty or undefined DataFrame. Original cohort_df may be used or errors might occur.")
    # cohort_df remains as it was, potentially without the new time features.

## 5. Prepare Static Features and Select Final Columns for Modeling

This step finalizes the feature set that will be used for building sequences:
1.  **Extract Baseline Clinical Scores:** For each subject, their Clinical Dementia Rating (CDR) and Mini-Mental State Examination (MMSE) scores from their *first visit present in this cohort* are extracted. These serve as static (time-invariant) features, providing crucial baseline cognitive context for each subject's entire sequence of visits.
2.  **Define Feature Categories:** Features are categorized into `identifiers`, `time_varying_features`, `static_features`, and the `target_base` ('CDR', used to derive the actual prediction target).
3.  **Select and Filter:** These predefined lists of desired features are filtered against the columns actually available in the current DataFrame. Only existing features are retained for the modeling dataset.
4.  **Create `feature_df`:** A new DataFrame is created containing only these selected columns.

The final lists of time-varying and static features intended for model input are logged to this W&B run's configuration. This configuration can then be fetched by downstream notebooks (e.g., Notebook 04 for fitting preprocessors, Notebook 06/07 for training models) to ensure consistent feature usage.
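
The baseline extraction in step 1 can be sketched on toy data. One detail worth knowing: pandas' GroupBy `first` skips missing values by default, so a subject with no MMSE at the first visit receives the first non-missing MMSE instead. The values below are invented for illustration.

```python
import numpy as np
import pandas as pd

toy = pd.DataFrame({
    "Subject ID": ["S1", "S1", "S2", "S2"],
    "Visit": [1, 2, 1, 2],
    "CDR": [0.0, 0.5, 0.5, 1.0],
    "MMSE": [29.0, 27.0, np.nan, 24.0],  # S2 lacks MMSE at its first visit
}).sort_values(["Subject ID", "Visit"])

# GroupBy 'first' returns the first non-null value per group, and transform
# broadcasts it to every row of that subject.
toy["Baseline_CDR"] = toy.groupby("Subject ID")["CDR"].transform("first")
toy["Baseline_MMSE"] = toy.groupby("Subject ID")["MMSE"].transform("first")
print(toy)
```

Because `transform` already repeats the value on every visit row, no additional forward or backward fill is required to make the feature static per subject.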

In [None]:
# --- Prepare Static Features & Select Final Columns for Modeling Dataset ---
print("\n--- Preparing Static Features & Selecting Final Columns ---")

feature_df = pd.DataFrame() # Initialize to ensure it's defined

# Ensure cohort_df (now potentially df_with_time_features) exists and is not empty
if 'cohort_df' in locals() and not cohort_df.empty:
    # Work on a copy to avoid modifying the DataFrame used for baseline feature extraction
    df_for_final_selection = cohort_df.copy()

    # --- Extract Baseline Clinical Scores (as static features for each subject) ---
    # These are derived from the *first visit present in this cohort* for each subject.
    # The DataFrame should already be sorted by Subject ID and Visit from previous cell.
    print("  Extracting Baseline_CDR and Baseline_MMSE as static features...")
    if 'CDR' in df_for_final_selection.columns and 'Subject ID' in df_for_final_selection.columns:
        df_for_final_selection['Baseline_CDR'] = df_for_final_selection.groupby('Subject ID')['CDR'].transform('first')
    else:
        print("  Warning: 'CDR' or 'Subject ID' column not found. 'Baseline_CDR' will be missing.")
        df_for_final_selection['Baseline_CDR'] = np.nan # Ensure column exists if expected

    if 'MMSE' in df_for_final_selection.columns and 'Subject ID' in df_for_final_selection.columns:
        # GroupBy 'first' returns each subject's first non-null MMSE, and transform() already
        # broadcasts it to every visit of that subject, so no extra forward/backward fill is needed.
        # (An ungrouped .bfill() here would copy values across subject boundaries.)
        df_for_final_selection['Baseline_MMSE'] = df_for_final_selection.groupby('Subject ID')['MMSE'].transform('first')
    else:
        print("  Warning: 'MMSE' or 'Subject ID' column not found. 'Baseline_MMSE' will be missing.")
        df_for_final_selection['Baseline_MMSE'] = np.nan # Ensure column exists
    print("  Baseline clinical scores extracted and propagated.")

    # --- Define Feature Categories for the Model ---
    # These are the *original* column names we aim to use or have engineered.
    # They will be filtered by actual availability in df_for_final_selection.

    # Days_from_Baseline measures the time elapsed relative to each subject's 
    # first visit that is part of the analysis cohort. MR Delay represents the 
    # number of days from a subject's first imaging session.
    
    # Time-varying features (dynamic per visit)
    time_varying_features = [
        'Age', 'MMSE', 'nWBV', 
        'Days_from_Baseline', 'Time_since_Last_Visit_Days'
    ]
    # Static features (constant per subject)
    static_features = [
        'M/F', # Original gender column for encoding later by OASISDataset
        'EDUC', 'SES', 
        'Baseline_CDR', 'Baseline_MMSE', # Engineered static features
        'eTIV', 'ASF'
    ]
    # Identifier columns crucial for data structure and linking
    identifiers = ['Subject ID', 'Visit', 'MRI ID']
    # Base column from which the actual target ('CDR_next_visit') will be derived
    target_base_col = ['CDR']


    # --- Filter Defined Feature Lists by Actual Availability in the DataFrame ---
    available_cols_in_current_df = df_for_final_selection.columns.tolist()
    
    selected_identifiers = [f for f in identifiers if f in available_cols_in_current_df]
    selected_time_varying = [f for f in time_varying_features if f in available_cols_in_current_df]
    selected_static = [f for f in static_features if f in available_cols_in_current_df]
    selected_target_base = [f for f in target_base_col if f in available_cols_in_current_df]

    # Validate presence of essential columns
    if not all(id_col in selected_identifiers for id_col in ['Subject ID', 'Visit']) or not selected_target_base:
         print("CRITICAL ERROR: Essential identifier ('Subject ID', 'Visit') or base target ('CDR') columns "
               "are missing from the DataFrame after feature selection. Check feature definitions and input data.")
         if run: run.finish(exit_code=1)
         # exit()

    # Combine all columns to keep for the modeling dataframe
    # This feature_df will be used for target creation and then splitting.
    final_columns_to_keep_for_feature_df = sorted(list(set(
        selected_identifiers + 
        selected_time_varying + 
        selected_static + 
        selected_target_base
    )))

    print(f"\nSelected final columns for 'feature_df' (count: {len(final_columns_to_keep_for_feature_df)}):")
    print(f"  {final_columns_to_keep_for_feature_df}")

    # Create the feature DataFrame with only these selected columns
    feature_df = df_for_final_selection[final_columns_to_keep_for_feature_df].copy()
    print(f"Shape of 'feature_df' before target variable creation: {feature_df.shape}")

    # --- Log Selected Feature Lists to W&B Config for Downstream Use ---
    # This is the configuration that OASISDataset will expect to find when this run's
    # config is fetched by Notebook 04 (Fit Preprocessors) or Notebooks 06/07 (Train Models).
    if run:
        # Feature lists as selected here, before any encoding of 'M/F'
        # (OASISDataset handles the encoding if 'M/F_encoded' is in its static feature config).
        # Scaling and imputation columns are determined in NB04, which logs the definitive
        # 'features' and 'preprocess' config consumed by OASISDataset. NB03 only records
        # the features it prepared and selected.
        features_config_for_downstream = {
            'time_varying': selected_time_varying,
            'static': selected_static
        }

        run.config.update({
            "feature_selection/final_identifiers_kept": selected_identifiers,
            "feature_selection/final_time_varying_kept": selected_time_varying,
            "feature_selection/final_static_kept": selected_static,
            "feature_selection/final_target_base_kept": selected_target_base,
            # For direct use by OASISDataset if NB04 is skipped OR for documentation:
            "features_prepared_in_nb03": features_config_for_downstream 
        }, allow_val_change=True)
        print("Selected feature categories (input to target creation & splitting) logged to W&B config.")
else:
    print("Skipping static feature preparation and column selection as input DataFrame ('cohort_df') is empty.")
    feature_df = pd.DataFrame() # Ensure feature_df is defined

## 6. Create Target Variable (`CDR_next_visit`)

Generate the primary target variable for the longitudinal prediction task. The goal is to predict the Clinical Dementia Rating (CDR) score of the *next available visit* for each subject. To achieve this:
1.  The `feature_df` (containing selected features and the 'CDR' column) is sorted by `Subject ID` and `Visit`.
2.  For each subject, the 'CDR' values are shifted by one position (`.shift(-1)`). This assigns the CDR of visit `k+1` as the `CDR_next_visit` for visit `k`.
3.  Rows corresponding to a subject's last recorded visit will have a `NaN` value for `CDR_next_visit` (as there is no subsequent visit). These rows are dropped from the `feature_df` because they cannot be used for training or evaluating this specific prediction task.

The number of rows before and after this operation is logged.
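
The shift-and-drop logic in steps 1 to 3 can be sketched on a toy table with invented CDR values:

```python
import pandas as pd

toy = pd.DataFrame({
    "Subject ID": ["S1", "S1", "S1", "S2", "S2"],
    "Visit": [1, 2, 3, 1, 2],
    "CDR": [0.0, 0.0, 0.5, 0.5, 1.0],
}).sort_values(["Subject ID", "Visit"])

# Within each subject, pull the following visit's CDR onto the current row.
toy["CDR_next_visit"] = toy.groupby("Subject ID")["CDR"].shift(-1)

# Each subject's last visit has no successor, so its target is NaN and is dropped.
with_target = toy.dropna(subset=["CDR_next_visit"]).reset_index(drop=True)
print(with_target)
```

Grouping before shifting matters: an ungrouped `shift(-1)` would assign the first CDR of `S2` as the target for the last visit of `S1`.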

In [8]:
# --- Create Target Variable (CDR score at the next visit) ---
print("\n--- Creating Target Variable: 'CDR_next_visit' ---")

# Ensure feature_df exists and is not empty from the previous cell
if 'feature_df' in locals() and not feature_df.empty:
    # Ensure DataFrame is sorted by Subject ID and then by Visit for correct shifting
    # This copy is important if feature_df is used elsewhere before this modification.
    df_for_target_creation = feature_df.sort_values(by=['Subject ID', 'Visit']).copy()

    if 'CDR' in df_for_target_creation.columns and 'Subject ID' in df_for_target_creation.columns:
        # Shift CDR scores up within each subject's group to get the next visit's CDR
        df_for_target_creation['CDR_next_visit'] = df_for_target_creation.groupby('Subject ID')['CDR'].shift(-1)

        # Rows corresponding to the last visit of each subject will now have NaN for 'CDR_next_visit'
        initial_rows_before_target_dropna = len(df_for_target_creation)
        
        # Drop rows where 'CDR_next_visit' is NaN (i.e., the last visit for each subject)
        df_with_target = df_for_target_creation.dropna(subset=['CDR_next_visit']).copy()
        # It's good practice to ensure the target is of a float type for regression models
        df_with_target['CDR_next_visit'] = df_with_target['CDR_next_visit'].astype(float)
        
        rows_after_target_dropna = len(df_with_target)
        rows_dropped_for_last_visit = initial_rows_before_target_dropna - rows_after_target_dropna

        print(f"Removed {rows_dropped_for_last_visit} rows (last visit of a subject or where next CDR was NaN).")
        print(f"Shape of DataFrame after creating target 'CDR_next_visit' and dropping NaNs: {df_with_target.shape}")

        if run: # Log statistics related to target creation
            run.log({
                'target_creation/rows_before_dropna_for_target': initial_rows_before_target_dropna,
                'target_creation/rows_dropped_for_missing_target': rows_dropped_for_last_visit,
                'target_creation/final_rows_with_target': rows_after_target_dropna,
                'target_creation/final_subjects_with_target': df_with_target['Subject ID'].nunique() if 'Subject ID' in df_with_target.columns else 0
            })

        if df_with_target.empty:
            print("CRITICAL ERROR: No data remaining after creating target variable and dropping NaNs. "
                  "This might happen if all subjects have only one visit after previous filters.")
            if run: run.finish(exit_code=1)
            # exit()
        
        feature_df = df_with_target # Update feature_df to be this version with the target
    else:
        print("Warning: 'CDR' or 'Subject ID' column not found. Cannot create 'CDR_next_visit' target.")
        # feature_df remains as it was, subsequent steps will likely fail or produce empty results.
else:
    print("Skipping target variable creation as 'feature_df' is empty or not defined.")
    # feature_df remains an empty DataFrame if it was initialized as such.

## 7. Perform Subject-Level Stratified Train/Validation/Test Split

The `feature_df` is split into training, validation, and test sets. This split is performed at the **subject level** to prevent data leakage, ensuring that all visits from a single subject belong exclusively to one set (train, validation, or test).

Stratification is based on the **`Baseline_CDR`** score of each subject. This keeps the distribution of baseline cognitive impairment levels similar across the splits, so validation and test metrics are computed on a baseline mix comparable to the training data. The split ratios and random state are defined in the setup section, which makes the split reproducible. Counts of subjects and visits in each split, along with the per-split `Baseline_CDR` proportions, are logged to W&B.
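
The two-stage, subject-level split can be sketched on a toy subject table. This is an illustration under stated assumptions rather than the notebook's actual run: the subject IDs, the ratios (`TEST_RATIO`, `VAL_RATIO`), and the random state are made-up values, and only the use of scikit-learn's `train_test_split` with `stratify` mirrors the cell below.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

# Made-up subject table: 20 subjects, 10 per baseline CDR group.
subjects = pd.Series([f"S{i:02d}" for i in range(20)])
baseline_cdr = pd.Series([0.0] * 10 + [0.5] * 10)

# Illustrative values only, not the notebook's configuration.
TEST_RATIO = 0.2
VAL_RATIO = 0.2
RANDOM_STATE = 42

# Stage 1: hold out the test subjects, stratified by baseline CDR.
train_val_ids, test_ids, train_val_cdr, _ = train_test_split(
    subjects, baseline_cdr,
    test_size=TEST_RATIO, random_state=RANDOM_STATE, stratify=baseline_cdr,
)

# Stage 2: the validation share is rescaled to the subjects that remain,
# because 20% of the original cohort is 25% of the remaining 80%.
relative_val_size = VAL_RATIO / (1.0 - TEST_RATIO)
train_ids, val_ids, _, _ = train_test_split(
    train_val_ids, train_val_cdr,
    test_size=relative_val_size, random_state=RANDOM_STATE, stratify=train_val_cdr,
)

# Subject-level splitting: no subject may appear in more than one set.
train_set, val_set, test_set = set(train_ids), set(val_ids), set(test_ids)
no_overlap = (
    train_set.isdisjoint(val_set)
    and train_set.isdisjoint(test_set)
    and val_set.isdisjoint(test_set)
)
print(len(train_set), len(val_set), len(test_set), no_overlap)
```

Because the split operates on subject IDs rather than on visit rows, filtering the visit-level DataFrame with `isin(...)` afterwards places every visit of a subject in the same set.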

In [9]:
# --- Perform Subject-Level Stratified Train/Validation/Test Split ---
print("\n--- Performing Stratified Train/Validation/Test Split by Subject ---")

# NOTE ON STRATIFICATION:
# Currently, the subject-level split is stratified based on 'Baseline_CDR' to ensure 
# representation of initial cognitive states across train, validation, and test sets.
# For future enhancements or if dealing with datasets where other baseline factors 
# (e.g., Sex, specific Age groups, Education levels, or future genetic markers like APOE for ADNI) 
# are known to have very strong confounding effects and imbalanced distributions, 
# more complex stratification could be considered. This might involve creating a 
# composite stratification key from multiple baseline variables. However, this also 
# increases the risk of having strata with too few samples, especially in smaller datasets.
# For this phase, stratification by Baseline_CDR is deemed the most critical.

# Initialize split DataFrames to ensure they are defined even if splitting fails or df is empty
train_df = pd.DataFrame()
val_df = pd.DataFrame()
test_df = pd.DataFrame()

if 'feature_df' in locals() and not feature_df.empty:
    # Get unique subjects and their baseline CDR (from the first visit in feature_df) for stratification
    # Ensure 'Visit' and 'Baseline_CDR' columns exist
    if not all(col in feature_df.columns for col in ['Subject ID', 'Visit', 'Baseline_CDR']):
        print("CRITICAL ERROR: 'Subject ID', 'Visit', or 'Baseline_CDR' missing from feature_df. Cannot stratify.")
        if run: run.finish(exit_code=1)
        # exit()
    else:
        # Get the Baseline_CDR for each subject (should be constant per subject already due to transform('first'))
        # We need one value per unique subject for stratification.
        subject_stratification_info = feature_df.drop_duplicates(subset=['Subject ID'])[['Subject ID', 'Baseline_CDR']].copy()
        subject_stratification_info.set_index('Subject ID', inplace=True)
        
        unique_subjects_for_split = subject_stratification_info.index.unique()
        subject_baseline_cdr_labels = subject_stratification_info['Baseline_CDR']

        if len(unique_subjects_for_split) < 10: # Arbitrary small number, adjust as needed
            print(f"Warning: Very small number of unique subjects ({len(unique_subjects_for_split)}) for splitting. "
                  "Splits might be very small or unstable.")

        # Use pre-defined split ratios and random state (defined in setup, logged to W&B config)
        # TEST_SET_RATIO_CONFIG, VAL_SET_RATIO_CONFIG, RANDOM_STATE_CONFIG
        
        # Calculate relative validation size for the second split (from (1 - test_ratio) of data)
        if (1.0 - TEST_SET_RATIO_CONFIG) <= 0:
             print("CRITICAL ERROR: TEST_SET_RATIO_CONFIG must be less than 1.0.")
             if run: run.finish(exit_code=1)
             # exit()
        else:
            relative_val_size_for_split = VAL_SET_RATIO_CONFIG / (1.0 - TEST_SET_RATIO_CONFIG)

            print(f"Attempting to split {len(unique_subjects_for_split)} unique subjects.")
            print(f"  Target Test Ratio: {TEST_SET_RATIO_CONFIG:.2%}")
            print(f"  Target Validation Ratio (of original): {VAL_SET_RATIO_CONFIG:.2%}")
            print(f"  Target Train Ratio (of original): {1.0 - TEST_SET_RATIO_CONFIG - VAL_SET_RATIO_CONFIG:.2%}")
            print(f"  Stratifying by 'Baseline_CDR'. Random State: {RANDOM_STATE_CONFIG}")

            try:
                # First split: separate out the test set
                train_val_subjects, test_subjects, train_val_labels, _ = train_test_split(
                    unique_subjects_for_split, subject_baseline_cdr_labels,
                    test_size=TEST_SET_RATIO_CONFIG,
                    random_state=RANDOM_STATE_CONFIG,
                    stratify=subject_baseline_cdr_labels 
                )

                # Second split: split the remaining (train_val_subjects) into train and validation
                # Handle cases where train_val_subjects might be too small or lack diversity for stratification
                if len(train_val_subjects) < 2 or len(np.unique(train_val_labels)) < 2: # Check if stratification is possible
                     print("  Warning: Train+Validation set too small or lacks label diversity for stratified validation split. "
                           "Performing non-stratified split for Train/Validation from the remainder.")
                     if len(train_val_subjects) < 2: # Can't split further
                          train_subjects = train_val_subjects
                          val_subjects = np.array([]) # No validation subjects
                     else: # Can do non-stratified split
                          train_subjects, val_subjects = train_test_split(
                               train_val_subjects,
                               test_size=relative_val_size_for_split, # test_size here refers to proportion for val_subjects
                               random_state=RANDOM_STATE_CONFIG
                               # No stratify=train_val_labels here
                          )
                else: # Proceed with stratified split for train/validation
                     train_subjects, val_subjects, _, _ = train_test_split(
                         train_val_subjects, train_val_labels,
                         test_size=relative_val_size_for_split, 
                         random_state=RANDOM_STATE_CONFIG,
                         stratify=train_val_labels
                     )

                print(f"\nSplit completed:")
                print(f"  Train subjects count: {len(train_subjects)}")
                print(f"  Validation subjects count: {len(val_subjects)}")
                print(f"  Test subjects count: {len(test_subjects)}")
                total_split_subjects = len(train_subjects) + len(val_subjects) + len(test_subjects)
                print(f"  Total subjects accounted for in splits: {total_split_subjects} (Expected: {len(unique_subjects_for_split)})")
                if total_split_subjects != len(unique_subjects_for_split):
                    print("  WARNING: Discrepancy in total split subjects vs. unique subjects!")


                # --- Create the actual split DataFrames using the lists of subject IDs ---
                train_df = feature_df[feature_df['Subject ID'].isin(train_subjects)].copy()
                val_df = feature_df[feature_df['Subject ID'].isin(val_subjects)].copy()
                test_df = feature_df[feature_df['Subject ID'].isin(test_subjects)].copy()

                print(f"\nSplit DataFrames created:")
                print(f"  Train DataFrame shape: {train_df.shape} ({train_df['Subject ID'].nunique()} subjects)")
                print(f"  Validation DataFrame shape: {val_df.shape} ({val_df['Subject ID'].nunique()} subjects)")
                print(f"  Test DataFrame shape: {test_df.shape} ({test_df['Subject ID'].nunique()} subjects)")

                # Log split details to W&B
                if run:
                    run.log({
                        'split_counts/subjects_train': len(train_subjects),
                        'split_counts/subjects_val': len(val_subjects),
                        'split_counts/subjects_test': len(test_subjects),
                        'split_counts/visits_train': len(train_df),
                        'split_counts/visits_val': len(val_df),
                        'split_counts/visits_test': len(test_df),
                    })
                    # Config for split ratios was already logged during W&B init for this notebook
                    # run.config.update({
                    #     'split_details/test_set_ratio_used': TEST_SET_RATIO_CONFIG,
                    #     'split_details/validation_set_ratio_used': VAL_SET_RATIO_CONFIG,
                    #     'split_details/stratify_by_column': 'Baseline_CDR',
                    #     'split_details/random_state_used': RANDOM_STATE_CONFIG
                    # }, allow_val_change=True) # Already logged

                    # Verify stratification proportions in W&B
                    print("  Logging stratification check to W&B...")
                    stratification_props_to_log = {}
                    for split_name_log, df_log_split in [('Train', train_df), ('Validation', val_df), ('Test', test_df)]:
                         if not df_log_split.empty and 'Subject ID' in df_log_split.columns and 'Baseline_CDR' in df_log_split.columns:
                             baseline_dist_log = df_log_split.drop_duplicates(subset=['Subject ID'])['Baseline_CDR'].value_counts(normalize=True).sort_index()
                             for cdr_val, prop in baseline_dist_log.items():
                                 stratification_props_to_log[f'split_stratification_check/{split_name_log}_prop_baseline_cdr_{str(cdr_val).replace(".","p")}'] = prop
                    if stratification_props_to_log:
                        run.log(stratification_props_to_log) # Single call so all proportions share one W&B step
                    print("  Stratification check logged.")

            except ValueError as e_split_value: # Typically from stratify if a class has too few members
                print(f"\nCRITICAL ERROR during stratified split: {e_split_value}")
                print("  This often happens if a group for stratification (e.g., a Baseline_CDR value) "
                      "has too few members (stratified splitting needs at least 2 subjects per group).")
                print("  Consider adjusting split ratios, ensuring enough samples per stratum, or using non-stratified split as a fallback if cohort is very small.")
                if run: run.finish(exit_code=1)
                # exit()
            except Exception as e_split_general:
                 print(f"An unexpected CRITICAL ERROR occurred during data splitting: {e_split_general}")
                 if run: run.finish(exit_code=1)
                 # exit()
else:
    print("Skipping data splitting as input 'feature_df' is empty or not defined.")
    # Ensure train_df, val_df, test_df are empty DataFrames if not created
    train_df = pd.DataFrame()
    val_df = pd.DataFrame()
    test_df = pd.DataFrame()

## 8. Save Split DataFrames and Log to W&B Artifacts

The resulting train, validation, and test DataFrames, which now include engineered time features, selected static features, identifiers, and the `CDR_next_visit` target variable, are saved locally as Parquet files in this notebook's output directory. Each split is also logged as a versioned artifact to Weights & Biases, so the exact datasets used for training, validation, and testing can be retrieved and reproduced in later modeling and analysis stages (e.g., Notebook 04 for fitting preprocessors, Notebooks 06 and 07 for model training).
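
A downstream notebook could retrieve one of these splits roughly as sketched below. Both helper functions are hypothetical and not part of this project; the name patterns are copied from the cell that follows, and `"oasis2"` is only the example identifier mentioned in that cell's comment. The sketch assumes an active W&B run object is passed in and that a Parquet engine such as `pyarrow` is installed.

```python
from pathlib import Path

import pandas as pd


def split_artifact_names(split_name: str, dataset_identifier: str) -> dict:
    """Rebuild the artifact, type, and file names used for one split."""
    return {
        "artifact": f"cohort_split_{split_name}_{dataset_identifier}",
        "artifact_type": f"data_split_{dataset_identifier}",
        "file": f"cohort_{split_name}_{dataset_identifier}.parquet",
    }


def load_split_from_wandb(run, split_name: str, dataset_identifier: str) -> pd.DataFrame:
    """Hypothetical helper: fetch the latest version of one split via a W&B run."""
    names = split_artifact_names(split_name, dataset_identifier)
    # Declare the artifact as an input of the consuming run, then download it.
    artifact = run.use_artifact(f"{names['artifact']}:latest", type=names["artifact_type"])
    download_dir = Path(artifact.download())
    return pd.read_parquet(download_dir / names["file"])


# Only the name construction runs here; loading needs a live W&B run.
print(split_artifact_names("train", "oasis2"))
```

Calling `use_artifact` from the consuming run also records the dependency in W&B, so the lineage from this notebook's outputs to later training runs stays traceable.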

In [10]:
# --- Save Split DataFrames Locally and Log as W&B Artifacts ---
print("\n--- Saving Split DataFrames and Logging to W&B Artifacts ---")

# Dictionary of DataFrames to process
# This ensures train_df, val_df, test_df are defined (even if empty from a failed split)
split_dataframes_to_save = {
    "train": train_df if 'train_df' in locals() else pd.DataFrame(),
    "validation": val_df if 'val_df' in locals() else pd.DataFrame(),
    "test": test_df if 'test_df' in locals() else pd.DataFrame()
}
# Define a consistent W&B artifact type for these data splits
wandb_artifact_type_for_splits = f"data_split_{DATASET_IDENTIFIER}" 

for split_name, df_to_save in split_dataframes_to_save.items():
    if df_to_save is not None and not df_to_save.empty:
        # Define local save path and W&B artifact name
        local_file_name = f"cohort_{split_name}_{DATASET_IDENTIFIER}.parquet"
        local_file_path = output_dir / local_file_name # output_dir is NB03's static output dir
        
        # W&B artifact name, e.g., "cohort_split_train_oasis2"
        wandb_artifact_name_for_split = f"cohort_split_{split_name}_{DATASET_IDENTIFIER}"
        
        try:
            # Save locally as Parquet
            df_to_save.to_parquet(local_file_path, index=False)
            print(f"{split_name.capitalize()} DataFrame saved locally to: {local_file_path} (Shape: {df_to_save.shape})")

            # Log as W&B artifact
            if run: # Check if W&B run is active
                print(f"  Logging {split_name} DataFrame as W&B artifact: '{wandb_artifact_name_for_split}'...")
                split_description = (
                    f"{split_name.capitalize()} data split for {DATASET_IDENTIFIER}. "
                    f"Contains {df_to_save['Subject ID'].nunique() if 'Subject ID' in df_to_save.columns else 'N/A'} subjects, "
                    f"{len(df_to_save)} visits. Features engineered, target 'CDR_next_visit' created."
                )
                split_metadata = {
                    'dataset_identifier': DATASET_IDENTIFIER,
                    'split_type': split_name,
                    'num_rows': len(df_to_save),
                    'num_columns': len(df_to_save.columns),
                    'num_subjects': df_to_save['Subject ID'].nunique() if 'Subject ID' in df_to_save.columns else 'N/A',
                    'columns_list': df_to_save.columns.tolist()
                }
                
                artifact_to_log = wandb.Artifact(
                    wandb_artifact_name_for_split, 
                    type=wandb_artifact_type_for_splits, 
                    description=split_description,
                    metadata=split_metadata
                )
                artifact_to_log.add_file(str(local_file_path), name=local_file_name) # Add the saved parquet file
                run.log_artifact(artifact_to_log, aliases=["latest", f"{split_name}_{time.strftime('%Y%m%d')}"])
                print(f"  {split_name.capitalize()} data split artifact logged to W&B.")

        except Exception as e_save_split:
            print(f"Warning: Could not save or log {split_name} DataFrame. Error: {e_save_split}")
    else:
        print(f"Skipping saving/logging for empty or None '{split_name}' DataFrame.")

## 9. Next Steps

The split DataFrames (`cohort_train_{DATASET_IDENTIFIER}.parquet`, `cohort_validation_{DATASET_IDENTIFIER}.parquet`, `cohort_test_{DATASET_IDENTIFIER}.parquet`), which contain the engineered and selected features together with the `CDR_next_visit` target, are now saved locally and logged as W&B artifacts. They are the primary inputs for Notebook 04 (Fitting Preprocessors) and, subsequently, for model training (Notebooks 06 & 07).

The subsequent stages involve:
1.  **Fitting Preprocessors (Notebook 04):** Fit imputers and scalers *only* on the training split (`cohort_train_{DATASET_IDENTIFIER}.parquet`).
2.  **Data Loading Pipeline (`src/datasets.py`):** The `OASISDataset` class will load these Parquet files, apply the fitted preprocessors, handle sequence creation, padding, and batching.
3.  **MRI Feature Extraction:** The `run_preprocessing.py` script processes raw MRIs. The `OASISDataset` (for hybrid models) will load these preprocessed MRIs.
4.  **Model Training & Evaluation (Notebooks 06, 07):** Train models using the DataLoaders.
5.  **Model Analysis (Notebook 08):** Perform interpretability and uncertainty analysis on trained models.
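
The "fit only on the training split" rule in step 1 can be sketched as follows. This is a simplified stand-in, not Notebook 04's implementation: plain pandas statistics (median imputation and z-score scaling) replace whatever imputers and scalers that notebook actually fits, the toy values are made up, and `MMSE` is used purely as an illustrative numeric column.

```python
import pandas as pd

# Made-up toy splits with one numeric feature and some missing values.
toy_train = pd.DataFrame({"MMSE": [30.0, 28.0, None, 26.0]})
toy_val = pd.DataFrame({"MMSE": [None, 20.0]})

# "Fit" step: every statistic is computed from the training split only.
train_median = toy_train["MMSE"].median()
train_filled = toy_train["MMSE"].fillna(train_median)
train_mean = train_filled.mean()
train_std = train_filled.std(ddof=0)


def apply_preprocessing(df: pd.DataFrame) -> pd.DataFrame:
    """Impute and scale a split using the statistics learned from training data."""
    out = df.copy()
    out["MMSE"] = (out["MMSE"].fillna(train_median) - train_mean) / train_std
    return out


# "Transform" step: the same training statistics are applied to every split.
train_scaled = apply_preprocessing(toy_train)
val_scaled = apply_preprocessing(toy_val)
print(val_scaled)
```

Computing the median, mean, and standard deviation on validation or test rows would leak information about held-out subjects into the features, which is why only the training split is used for fitting.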

## 10. Finalize W&B Run

Complete the execution of this notebook and finish the associated Weights & Biases run, ensuring all logs and artifacts are uploaded.

In [11]:
# --- Finish W&B Run ---
print(f"\n--- {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} complete. Finishing W&B run. ---")
if run: # Check if 'run' object exists and is an active run
    try:
        # Log final counts of the created DataFrames to summary, if they exist
        if 'train_df' in locals() and not train_df.empty: run.summary['final_train_df_rows'] = len(train_df)
        if 'val_df' in locals() and not val_df.empty: run.summary['final_val_df_rows'] = len(val_df)
        if 'test_df' in locals() and not test_df.empty: run.summary['final_test_df_rows'] = len(test_df)
        
        run.finish()
        run_name_to_print = run.name_synced if hasattr(run, 'name_synced') and run.name_synced else \
                            run.name if hasattr(run, 'name') and run.name else \
                            run.id if hasattr(run, 'id') else "current run"
        print(f"W&B run '{run_name_to_print}' finished successfully.")
    except Exception as e_finish_run:
        print(f"Error while finalizing the W&B run (summary update or run.finish()): {e_finish_run}")
else:
    print("No active W&B run to finish for this session.")

print(f"\n--- Notebook {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} execution finished. ---")