# Notebook 02: OASIS-2 Cohort Definition for Longitudinal Analysis

**Project Phase:** 1 (Data Processing - Cohort Definition)
**Dataset:** OASIS-2 Longitudinal MRI & Clinical Data

**Purpose:**
This notebook defines the final analysis cohort from the OASIS-2 dataset based on specific inclusion and exclusion criteria. It builds upon the raw clinical data and the MRI scan file verification results from Notebook 01. The key objectives are:
1.  Load the raw clinical dataset and the MRI verification details.
2.  Apply inclusion criteria:
    * Baseline CDR score: subjects with CDR 0.0 (Cognitively Normal) or CDR 0.5 (Very Mild Cognitive Impairment) at their first visit.
    * Minimum number of longitudinal visits per subject to ensure sufficient data for sequence modeling.
    * Availability of verified MRI scan files for the selected visits.
3.  Log all cohort definition criteria and step-by-step filtering statistics to Weights & Biases (W&B).
4.  Save the final cohort DataFrame locally as `final_analysis_cohort_oasis2.csv`.
5.  Log the final cohort DataFrame as a versioned artifact to W&B for use in downstream feature engineering and modeling notebooks.

**Workflow:**
1.  **Setup:** Import libraries, set up `src` path, load `config.json`, and define initial paths including inputs from Notebook 01.
2.  **W&B Initialization:** Start a new W&B run for this cohort definition task using the `initialize_wandb_run` utility. Define the output directory for this notebook.
3.  **Load Input Data:** Load the raw clinical data (from original Excel) and the `verification_details.csv` (from Notebook 01).
4.  **Filter by Baseline CDR:** Select subjects based on their CDR score at their first available visit. Log counts.
5.  **Filter by Minimum Visits:** Analyze visit counts for the remaining subjects, decide on a minimum visit threshold, and apply the filter. Log counts and criteria.
6.  **Filter by MRI Availability:** Retain only those visits from the filtered cohort that have a corresponding verified MRI scan (based on `verification_details.csv`). Log counts.
7.  **Summarize & Save Final Cohort:** Print characteristics of the final cohort, save it locally, and log it as a W&B artifact.
8.  **Define Prediction Task (High-Level):** Briefly log the intended prediction task for which this cohort is being prepared.
9.  **Finalize W&B Run.**

**Input:**
* `config.json`: Main project configuration file.
* Raw OASIS-2 Clinical Data Excel file (path from `config.json`).
* `verification_details.csv`: Output from Notebook 01, containing MRI file verification status (path constructed based on `config.json` and Notebook 01's output structure).
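
For orientation, the sketch below shows the parts of `config.json` that the setup cell reads. The key names are taken from that cell; every value is a hypothetical placeholder and not the project's actual configuration.

```python
import json

# Key names mirror what the setup cell reads; all values are hypothetical placeholders.
example_config = {
    "data": {
        "clinical_excel_path": "data/clinical.xlsx",      # hypothetical path
        "output_dir_base": "notebooks/outputs",           # hypothetical path
    },
    "pipeline_artefact_locators_oasis2": {
        "exploration_subdir": "01_Data_Exploration_OASIS2",    # hypothetical name
        "verification_csv_fname": "verification_details.csv",
        "cohort_def_subdir": "02_Cohort_Definition_OASIS2",
    },
}

print(json.dumps(example_config, indent=2))
```

All paths are resolved relative to the project root, so the file itself can stay free of machine-specific absolute paths.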

**Output:**
* **Local Files (in notebook-specific output directory, e.g., `notebooks/outputs/02_Cohort_Definition_OASIS2/`):**
    * `final_analysis_cohort_oasis2.csv` (Key input for Notebook 03)
    * Plots related to visit count distributions (if any generated in this notebook).
* **W&B Run:**
    * Logged run configuration (including input paths and cohort criteria).
    * Step-by-step subject/visit counts after each filtering stage.
    * Final cohort criteria (baseline CDR, min visits).
    * The `final_analysis_cohort_oasis2.csv` logged as a W&B Artifact (e.g., `analysis-cohort-oasis2-BCDR_0p0_0p5-MinV_Y`).

In [None]:
# In: notebooks/02_Cohort_Definition.ipynb
# Purpose: Define the analysis cohort for longitudinal prediction based on OASIS-2 data.
#          Applies criteria for baseline status, minimum visits, and MRI availability.
#          Logs decisions and outputs the final cohort definition.

In [None]:
# --- Import Libraries ---
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import sys
import os
import time
import json
from pathlib import Path

## 1. Setup: Project Configuration, Paths, and Utilities

This section initializes the notebook environment:
* Determines the project's root directory to enable access to shared resources and modules.
* Adds the project root to the Python system path so that custom utility functions can be imported from the `src` package.
* Imports necessary custom utilities, particularly for W&B run initialization and path management.
* Loads the main project configuration from `config.json`.
* Defines key dataset identifiers and notebook-specific parameters.
* Resolves and prints essential input paths (raw clinical data, MRI verification results from Notebook 01) and sets up the primary output directory for this notebook's locally saved files (e.g., the final cohort CSV).
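
The setup cell checks only the notebook's parent directory and the current working directory. As a minimal sketch of a more general variant, the hypothetical helper `find_project_root` (not part of `src`) walks up through all ancestor directories until it finds both `src/` and `config.json`.

```python
from pathlib import Path
import tempfile


def find_project_root(start: Path) -> Path:
    """Return the first directory at or above `start` that contains src/ and config.json."""
    start = start.resolve()
    for candidate in (start, *start.parents):
        if (candidate / "src").is_dir() and (candidate / "config.json").is_file():
            return candidate
    raise FileNotFoundError(
        f"No directory containing 'src/' and 'config.json' found at or above {start}"
    )


# Demonstration on a throwaway directory tree
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp).resolve()
    (root / "src").mkdir()
    (root / "config.json").write_text("{}", encoding="utf-8")
    notebooks_dir = root / "notebooks"
    notebooks_dir.mkdir()
    root_matches = find_project_root(notebooks_dir) == root
    print(root_matches)
```

This variant keeps working if the notebook is later moved into a nested subfolder, at the cost of possibly matching an unrelated ancestor that happens to contain both names.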

In [None]:
# --- Project Setup, Configuration Loading, and Utility Imports ---
print("--- Initializing Project Setup & Configuration ---")

# Initialize
PROJECT_ROOT = None
base_config = {}

try:
    # --- 1. Determine Project Root and Add src to Python Path ---
    current_notebook_path = Path.cwd() # Assumes notebook is run from its directory
    # Try to find project root assuming standard structure: <PROJECT_ROOT>/notebooks/02_...
    potential_project_root = current_notebook_path.parent 
    if (potential_project_root / "src").is_dir() and \
       (potential_project_root / "config.json").is_file():
        PROJECT_ROOT = potential_project_root
    else: # Fallback: assume current working directory IS the project root
        PROJECT_ROOT = current_notebook_path 
    
    if not (PROJECT_ROOT / "src").is_dir() or not (PROJECT_ROOT / "config.json").is_file():
        raise FileNotFoundError(
            f"CRITICAL: 'src' directory or 'config.json' not found relative to determined "
            f"PROJECT_ROOT: {PROJECT_ROOT}. Ensure 'config.json' is at the project root."
        )

    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT)) # Add project root to allow 'from src...'
    print(f"PROJECT_ROOT successfully set to: {PROJECT_ROOT}")
    print(f"Added '{str(PROJECT_ROOT)}' to sys.path.")

    # --- 2. Import Custom Utilities ---
    from src.wandb_utils import initialize_wandb_run
    from src.plotting_utils import finalize_plot 
    print("Successfully imported custom utilities")

    # --- 3. Load Main Project Configuration ---
    CONFIG_PATH_MAIN = PROJECT_ROOT / 'config.json'
    with open(CONFIG_PATH_MAIN, 'r', encoding='utf-8') as f:
        base_config = json.load(f)
    print(f"Main project configuration loaded from: {CONFIG_PATH_MAIN}")

    # --- 4. Define Dataset, Notebook Specifics, and Key Paths for NB02 ---
    DATASET_IDENTIFIER = "oasis2" # Specific to this notebook's current focus
    NOTEBOOK_MODULE_NAME = "02_Cohort_Definition" # For job_type and output folder naming
    
    # Key from config.json's 'pipeline_artefact_locators_oasis2' for this notebook's output subfolder
    NB02_OUTPUT_LOCATOR_KEY = "cohort_def_subdir" 

    # Resolve critical INPUT paths from base_config
    INPUT_DATA_PATH_CLINICAL = PROJECT_ROOT / base_config['data']['clinical_excel_path']
    
    output_dir_base_from_config = PROJECT_ROOT / base_config['data']['output_dir_base']
    dataset_locators = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})
    if not dataset_locators:
        raise KeyError(f"pipeline_artefact_locators_{DATASET_IDENTIFIER} section not found or empty in config.json.")

    # Path to verification_details.csv (output from Notebook 01)
    nb01_output_subdir_name_key = "exploration_subdir"
    nb01_verification_fname_key = "verification_csv_fname"
    
    nb01_output_subdir_name = dataset_locators.get(nb01_output_subdir_name_key)
    nb01_verification_fname = dataset_locators.get(nb01_verification_fname_key)

    if not nb01_output_subdir_name or not nb01_verification_fname:
        raise KeyError(f"Missing '{nb01_output_subdir_name_key}' or '{nb01_verification_fname_key}' "
                       f"in 'pipeline_artefact_locators_{DATASET_IDENTIFIER}' of config.json.")
    VERIFICATION_CSV_PATH_NB02_INPUT = output_dir_base_from_config / nb01_output_subdir_name / nb01_verification_fname

    # Define the main OUTPUT directory for THIS notebook's files (e.g., final_analysis_cohort.csv)
    # This uses a static output directory pattern for NB02, derived from config.
    notebook_output_folder_name_from_locators = dataset_locators.get(
        NB02_OUTPUT_LOCATOR_KEY, 
        f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}_default_outputs" # Fallback name
    )
    output_dir = output_dir_base_from_config / notebook_output_folder_name_from_locators
    output_dir.mkdir(parents=True, exist_ok=True) # Create if it doesn't exist

    # --- Print and Verify Key Paths ---
    print(f"\nKey paths defined for Notebook 02 ({DATASET_IDENTIFIER}):")
    print(f"  Input Raw Clinical Data Excel: {INPUT_DATA_PATH_CLINICAL}")
    print(f"  Input MRI Verification CSV (from NB01): {VERIFICATION_CSV_PATH_NB02_INPUT}")
    print(f"  Notebook Output Directory (for local saves like final_analysis_cohort.csv): {output_dir}")
    
    # Critical Input File Checks
    if not INPUT_DATA_PATH_CLINICAL.is_file():
        raise FileNotFoundError(f"CRITICAL: Input clinical data excel file not found at specified path: {INPUT_DATA_PATH_CLINICAL}")
    if not VERIFICATION_CSV_PATH_NB02_INPUT.is_file():
        raise FileNotFoundError(f"CRITICAL: MRI verification CSV (from NB01) not found at specified path: {VERIFICATION_CSV_PATH_NB02_INPUT}. "
                                "Ensure Notebook 01 ran successfully and config.json locators are correct.")
    print("All critical input paths for NB02 verified.")

except (FileNotFoundError, KeyError, ValueError) as e_setup: # Catch specific expected errors
    print(f"CRITICAL ERROR during setup in Notebook 02: {e_setup}")
    raise # Stop here: later cells depend on the paths and config defined above
except Exception as e_general:
    print(f"An unexpected CRITICAL ERROR occurred during setup in Notebook 02: {e_general}")
    raise

## 2. Initialize Weights & Biases Run

A new W&B run is initialized for this cohort definition notebook. This run will track the configuration parameters used for defining the cohort, step-by-step filtering statistics, and the final cohort dataset as an artifact.
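
The `initialize_wandb_run` utility lives in `src.wandb_utils` and is not shown in this notebook. The sketch below only illustrates how its arguments might map onto the keyword arguments of `wandb.init`. The helper `build_wandb_init_kwargs`, the `config["wandb"]["project"]` key, and the run-name format are all assumptions made for illustration.

```python
import time


def build_wandb_init_kwargs(base_project_config, job_group, job_specific_type,
                            run_specific_config, custom_run_name_elements, notes=""):
    """Assemble keyword arguments for wandb.init (hypothetical mapping)."""
    # Assumption: the W&B project name is stored under config["wandb"]["project"]
    project = base_project_config.get("wandb", {}).get("project", "unnamed-project")
    # Assumption: the run name is the custom elements plus a timestamp
    run_name = "-".join([*custom_run_name_elements, time.strftime("%Y%m%d-%H%M%S")])
    return {
        "project": project,
        "group": job_group,
        "job_type": job_specific_type,
        "name": run_name,
        "config": dict(run_specific_config),
        "notes": notes,
    }


kwargs = build_wandb_init_kwargs(
    base_project_config={"wandb": {"project": "demo-project"}},  # hypothetical config
    job_group="DataProcessing",
    job_specific_type="02-CohortDefinition-oasis2",
    run_specific_config={"dataset_source": "oasis2"},
    custom_run_name_elements=["02", "OASIS2", "CohortDef"],
)
print(kwargs["name"])
# run = wandb.init(**kwargs)  # requires `import wandb` and a configured W&B account
```

Building the arguments separately from the `wandb.init` call makes the naming logic easy to inspect without starting a run.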

In [None]:
# --- Initialize W&B Run for THIS Notebook Execution (NB02) ---
print("\n--- Initializing Weights & Biases Run for Notebook 02 ---")

# Configuration specific to this NB02 run
nb02_run_config_log = {
    "notebook_name_code": f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}",
    "dataset_source": DATASET_IDENTIFIER,
    "input_raw_clinical_data_path": str(INPUT_DATA_PATH_CLINICAL), # From Cell 3
    "input_mri_verification_csv_path": str(VERIFICATION_CSV_PATH_NB02_INPUT), # From Cell 3
    "output_dir_for_local_saves": str(output_dir), # From Cell 3
    "execution_timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
    # Cohort selection criteria (baseline CDR, min_visits) will be added later via run.config.update()
}

# Extract notebook number for naming and job type
nb_number_prefix_nb02 = NOTEBOOK_MODULE_NAME.split('_')[0] if '_' in NOTEBOOK_MODULE_NAME else "NB"
job_specific_type_nb02 = f"{nb_number_prefix_nb02}-CohortDefinition-{DATASET_IDENTIFIER}"
custom_elements_for_name_nb02 = [nb_number_prefix_nb02, DATASET_IDENTIFIER.upper(), "CohortDef"]

run = initialize_wandb_run(
    base_project_config=base_config,
    job_group="DataProcessing",
    job_specific_type=job_specific_type_nb02,
    run_specific_config=nb02_run_config_log,
    custom_run_name_elements=custom_elements_for_name_nb02,
    notes=f"{DATASET_IDENTIFIER.upper()}: Defining analysis cohort based on inclusion criteria."
)

if run:
    print(f"W&B run '{run.name}' (Job Type: '{run.job_type}') initialized. View at: {run.url}")
else:
    print("Proceeding without W&B logging for this session (W&B run initialization failed).")
    # output_dir should still be defined from Cell 3 for local saves.

## 3. Load Raw Clinical Data

Load the raw longitudinal clinical and demographic data from the Excel file specified in `config.json`. This dataset forms the basis from which the analysis cohort will be derived. The shape of the loaded data and a W&B artifact for the raw data are logged.
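
Later cells rely on the columns `Subject ID`, `MRI ID`, `Visit`, and `CDR`. A minimal sketch of an upfront check is shown below; `missing_columns` is a hypothetical helper and the DataFrame is toy data, not OASIS-2 records.

```python
import pandas as pd

# Columns referenced by the filtering cells of this notebook
REQUIRED_COLUMNS = ["Subject ID", "MRI ID", "Visit", "CDR"]


def missing_columns(df: pd.DataFrame, required=REQUIRED_COLUMNS) -> list:
    """Return the required column names that are absent from df."""
    return [col for col in required if col not in df.columns]


# Toy frame that lacks the 'MRI ID' column
toy_df = pd.DataFrame({"Subject ID": ["S1"], "Visit": [1], "CDR": [0.0]})
print(missing_columns(toy_df))
```

Running such a check right after loading would surface a renamed or missing column before any filtering statistics are logged.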

In [None]:
# --- Load Raw Clinical Data ---
# INPUT_DATA_PATH_CLINICAL should be defined in the setup cell
print(f"\n--- Loading Raw Clinical Data from: {INPUT_DATA_PATH_CLINICAL} ---")
clinical_df_raw = None # Initialize to ensure it's defined in case of errors

try:
    if INPUT_DATA_PATH_CLINICAL is None or not INPUT_DATA_PATH_CLINICAL.is_file():
         raise FileNotFoundError(f"Input clinical data file path not defined or file not found: {INPUT_DATA_PATH_CLINICAL}")
    
    clinical_df_raw = pd.read_excel(INPUT_DATA_PATH_CLINICAL)
    print(f"Raw clinical data loaded successfully. Shape: {clinical_df_raw.shape}")

    if clinical_df_raw.empty:
        # Raise so that the logging below is skipped; the generic handler finishes the W&B run
        raise ValueError("Loaded raw clinical dataframe is empty.")

    # Log initial raw data characteristics to W&B
    if run: 
        run.log({'cohort_definition/00_raw_clinical_rows': clinical_df_raw.shape[0],
                 'cohort_definition/00_raw_clinical_columns': clinical_df_raw.shape[1]})
        
        # Log the raw data file as a W&B artifact for complete traceability
        print("Logging raw clinical data file as W&B artifact...")
        raw_data_artifact_name = f"raw_clinical_data_source_{DATASET_IDENTIFIER}"
        raw_data_artifact = wandb.Artifact(
            raw_data_artifact_name, 
            type=f"raw_dataset_{DATASET_IDENTIFIER}", # Use a consistent type for raw datasets
            description=f"Raw clinical and demographic data for {DATASET_IDENTIFIER} dataset, "
                        f"loaded from {INPUT_DATA_PATH_CLINICAL.name}.",
            metadata={
                "source_file_path": str(INPUT_DATA_PATH_CLINICAL), 
                "shape_rows": clinical_df_raw.shape[0],
                "shape_columns": clinical_df_raw.shape[1],
                "load_timestamp": time.strftime("%Y-%m-%d %H:%M:%S") # From import time
            }
        )
        raw_data_artifact.add_file(str(INPUT_DATA_PATH_CLINICAL), name="source_excel_file.xlsx")
        run.log_artifact(raw_data_artifact, aliases=["original", time.strftime('%Y%m%d')])
        print(f"Raw data artifact '{raw_data_artifact_name}' logged to W&B.")

except FileNotFoundError as e_fnf:
    print(f"CRITICAL ERROR: {e_fnf}")
    if run: run.finish(exit_code=1)
    # exit()
except ImportError: # For pd.read_excel if 'openpyxl' is missing
    print("CRITICAL ERROR loading Excel file: Missing 'openpyxl' library. Please install it: `pip install openpyxl`")
    if run: run.finish(exit_code=1)
    # exit()
except Exception as e_load_raw:
    print(f"CRITICAL ERROR occurred while loading the raw clinical data: {e_load_raw}")
    if run: run.finish(exit_code=1)
    # exit()

# Ensure clinical_df_raw is defined for subsequent cells, even if empty after an error (if not exiting)
if clinical_df_raw is None:
    clinical_df_raw = pd.DataFrame()

## 4. Load MRI Scan Verification Results

Load the `verification_details.csv` file generated by Notebook 01. This CSV records, for each `MRI ID` (scan session), whether raw scan files (`.img` + `.hdr` pairs) were located on the local filesystem. The cell below currently treats an `MRI ID` as verified when its `mri_folder_exists` flag is true. This step ensures that the cohort only includes visits for which MRI data is actually available for preprocessing.
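
The extraction of verified IDs can be sketched on toy data as follows. The columns `mri_id` and `mri_folder_exists` are the ones the loading cell uses; `mprs_found_count` is only mentioned in a comment there as a possible stricter criterion, so its presence in the CSV is an assumption. The helper name and the IDs are made up.

```python
import pandas as pd


def get_verified_mri_ids(verification_df: pd.DataFrame, require_scans: bool = False) -> set:
    """Return MRI IDs whose folder exists, optionally requiring at least one scan found."""
    # Explicit comparison treats missing values as not verified
    mask = verification_df["mri_folder_exists"] == True  # noqa: E712
    if require_scans:
        # Assumption: 'mprs_found_count' holds the number of scans found per session
        mask &= verification_df["mprs_found_count"] > 0
    return set(verification_df.loc[mask, "mri_id"])


toy_verification = pd.DataFrame({
    "mri_id": ["A_MR1", "A_MR2", "B_MR1"],      # made-up IDs
    "mri_folder_exists": [True, True, False],
    "mprs_found_count": [3, 0, 0],
})
lenient_ids = get_verified_mri_ids(toy_verification)
strict_ids = get_verified_mri_ids(toy_verification, require_scans=True)
print(sorted(lenient_ids), sorted(strict_ids))
```

In the toy data, `A_MR2` has a folder but no scans, so it passes the lenient check and fails the strict one.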

In [None]:
# --- Load MRI Scan Verification Results (Output from Notebook 01) ---
# VERIFICATION_CSV_PATH_NB02_INPUT should be defined in the setup cell
print(f"\n--- Loading MRI Verification Results from: {VERIFICATION_CSV_PATH_NB02_INPUT} ---")
verification_df = None # Initialize
verified_mri_ids = set() # Initialize as an empty set

try:
    if VERIFICATION_CSV_PATH_NB02_INPUT is None or not VERIFICATION_CSV_PATH_NB02_INPUT.is_file():
         raise FileNotFoundError(f"MRI verification CSV file path not defined or file not found: {VERIFICATION_CSV_PATH_NB02_INPUT}. "
                                 "Please ensure Notebook 01 ran successfully and saved this file, "
                                 "and that config.json locators are correct.")
    
    verification_df = pd.read_csv(VERIFICATION_CSV_PATH_NB02_INPUT)
    print(f"MRI verification data loaded successfully. Shape: {verification_df.shape}")

    # Extract the set of MRI IDs that passed verification (e.g., folder exists and contains valid scans)
    # The criteria for "passed" might be more stringent, e.g., 'mprs_found_count' > 0
    # For now, using 'mri_folder_exists'
    if 'mri_folder_exists' in verification_df.columns and 'mri_id' in verification_df.columns:
        verified_mri_ids = set(verification_df[verification_df['mri_folder_exists'] == True]['mri_id'].unique())
        print(f"Found {len(verified_mri_ids)} unique MRI IDs with existing scan folders based on verification file.")
    else:
        print("Warning: 'mri_folder_exists' or 'mri_id' column not found in verification CSV. Cannot determine verified MRI IDs.")
        # verified_mri_ids remains an empty set

    if run: 
        run.log({'cohort_definition/00_input_verified_mri_ids_count': len(verified_mri_ids)})
        # Log the verification_details.csv used as an input artifact to this run
        verif_input_artifact_name = f"input_verification_details_{DATASET_IDENTIFIER}"
        verif_input_artifact = wandb.Artifact(
            verif_input_artifact_name,
            type="input_metadata",
            description=f"MRI verification details CSV used as input for cohort definition. From NB01 output: {VERIFICATION_CSV_PATH_NB02_INPUT.name}"
        )
        verif_input_artifact.add_file(str(VERIFICATION_CSV_PATH_NB02_INPUT))
        run.log_artifact(verif_input_artifact)
        print(f"Logged '{VERIFICATION_CSV_PATH_NB02_INPUT.name}' as an input artifact to W&B.")


except FileNotFoundError as e_fnf_verif:
    print(f"CRITICAL ERROR: {e_fnf_verif}")
    print("Cannot proceed with cohort definition without MRI verification results.")
    if run: run.finish(exit_code=1)
    # exit()
except Exception as e_load_verif:
    print(f"CRITICAL ERROR occurred while loading MRI verification results: {e_load_verif}")
    if run: run.finish(exit_code=1)
    # exit()

# Ensure verification_df and verified_mri_ids are defined for subsequent cells
if verification_df is None:
    verification_df = pd.DataFrame()
# verified_mri_ids is already initialized to an empty set

## 5. Cohort Definition - Step 1: Filter by Baseline CDR

Apply the first inclusion criterion based on the subject's cognitive status at their *first available visit within the raw clinical data*. Subjects are included if their baseline CDR score is 0.0 (Cognitively Normal) or 0.5 (Very Mild Cognitive Impairment). The number of unique subjects and total visits remaining after this filter are logged.
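
The baseline filter can be illustrated on a small toy table (made-up subjects, not OASIS-2 data). It uses the same column names and the same `groupby(...).idxmin()` approach as the cell below.

```python
import pandas as pd

toy_clinical = pd.DataFrame({
    "Subject ID": ["S1", "S1", "S2", "S2", "S3"],
    "Visit":      [1,    2,    1,    2,    1],
    "CDR":        [0.0,  0.5,  1.0,  1.0,  0.5],
})
baseline_cdr_criteria = [0.0, 0.5]

# Row of the earliest visit for each subject
first_visits = toy_clinical.loc[toy_clinical.groupby("Subject ID")["Visit"].idxmin()]
eligible = first_visits.loc[first_visits["CDR"].isin(baseline_cdr_criteria), "Subject ID"]

# Keep every visit of eligible subjects, including later visits with a higher CDR
baseline_filtered = toy_clinical[toy_clinical["Subject ID"].isin(eligible)]
kept_subjects = sorted(baseline_filtered["Subject ID"].unique())
print(kept_subjects, len(baseline_filtered))
```

Subject `S2` is dropped because its first visit has CDR 1.0, while both visits of `S1` are kept even though the second has CDR 0.5.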

In [None]:
# --- Cohort Definition Step 1: Filter by Baseline CDR ---
print("\n--- Applying Baseline CDR Filter (Keeping CDR=0.0 and CDR=0.5) ---")

df_baseline_filtered = pd.DataFrame() # Initialize
num_subjects_baseline_criteria = 0

if 'clinical_df_raw' in locals() and not clinical_df_raw.empty:
    # Define the baseline CDR criteria
    baseline_cdr_criteria = [0.0, 0.5]
    print(f"Baseline CDR inclusion criteria: {baseline_cdr_criteria}")

    # Ensure data is sorted by 'Subject ID' and 'Visit' to correctly identify the first visit
    # Using a copy to avoid modifying clinical_df_raw if it's used elsewhere unfiltered
    clinical_df_for_baseline_check = clinical_df_raw.sort_values(by=['Subject ID', 'Visit'])
    
    # Get data for the first visit of each subject
    # .loc[...idxmin()] is a robust way to get the entire row for the first visit
    first_visit_data_df = clinical_df_for_baseline_check.loc[
        clinical_df_for_baseline_check.groupby('Subject ID')['Visit'].idxmin()
    ]

    # Identify subjects meeting the baseline CDR criteria
    if 'CDR' in first_visit_data_df.columns and 'Subject ID' in first_visit_data_df.columns:
        subjects_meeting_criteria_arr = first_visit_data_df[
            first_visit_data_df['CDR'].isin(baseline_cdr_criteria)
        ]['Subject ID'].unique()
        num_subjects_baseline_criteria = len(subjects_meeting_criteria_arr)
        print(f"Found {num_subjects_baseline_criteria} unique subjects meeting baseline CDR criteria.")

        # Filter the original raw dataframe to keep all visits of these selected subjects
        df_baseline_filtered = clinical_df_raw[
            clinical_df_raw['Subject ID'].isin(subjects_meeting_criteria_arr)
        ].copy() # Use .copy() to avoid SettingWithCopyWarning on slices
        print(f"DataFrame shape after baseline CDR filter: {df_baseline_filtered.shape} "
              f"({df_baseline_filtered['Subject ID'].nunique()} subjects)")
    else:
        print("Warning: 'CDR' or 'Subject ID' column not found in first_visit_data_df. Cannot apply baseline CDR filter.")
        df_baseline_filtered = clinical_df_raw.copy() # Proceed with unfiltered data if CDR filter fails

    # Log results to W&B
    if run:
        # Log the criteria used to W&B config for this run
        run.config.update({'cohort_criteria/baseline_cdr_included': baseline_cdr_criteria}, allow_val_change=True)
        run.log({
            'cohort_definition/01_subjects_meeting_baseline_cdr': num_subjects_baseline_criteria,
            'cohort_definition/01_visits_after_baseline_cdr_filter': len(df_baseline_filtered),
            'cohort_definition/01_subjects_after_baseline_cdr_filter': df_baseline_filtered['Subject ID'].nunique()
        })
else:
    print("Skipping baseline CDR filter as raw clinical data (clinical_df_raw) is not available or empty.")
    # df_baseline_filtered remains an empty DataFrame

## 6. Cohort Definition - Step 2: Filter by Minimum Number of Visits

To ensure sufficient longitudinal data for sequence modeling, subjects are filtered based on a minimum number of recorded visits. This section first analyzes the distribution of visit counts for subjects who met the baseline CDR criteria. Based on this distribution, a data-informed threshold for the minimum number of visits is determined and applied. Statistics before and after this filter, along with the chosen threshold, are logged to W&B. A visualization of the visit count distribution is also generated.
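
The threshold rule used in the cell below can be written as a small function. This is a sketch with the same cut-offs (at least 40% of subjects and more than 20 subjects with three or more visits); `choose_min_visits` is a hypothetical name and the visit counts are toy data.

```python
import pandas as pd


def choose_min_visits(visits_per_subject: pd.Series,
                      min_fraction: float = 0.40, min_subjects: int = 20) -> int:
    """Return 3 if enough subjects have three or more visits, otherwise 2."""
    n_total = len(visits_per_subject)
    n_ge_3 = int((visits_per_subject >= 3).sum())
    fraction_ge_3 = n_ge_3 / n_total if n_total else 0.0
    if fraction_ge_3 >= min_fraction and n_ge_3 > min_subjects:
        return 3
    return 2


# Toy cohort A: 30 subjects with 3 visits, 20 subjects with 2 visits
threshold_a = choose_min_visits(pd.Series([3] * 30 + [2] * 20))
# Toy cohort B: 10 subjects with 3 visits, 40 subjects with 2 visits
threshold_b = choose_min_visits(pd.Series([3] * 10 + [2] * 40))
print(threshold_a, threshold_b)
```

Cohort A passes both conditions (60% and 30 subjects), so the threshold is 3. Cohort B fails the fraction test (20%) and falls back to 2.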

In [None]:
# --- Cohort Definition Step 2: Check and Apply Minimum Visits Filter ---
print("\n--- Checking and Applying Minimum Visits Filter ---")

df_min_visits_filtered = pd.DataFrame() # Initialize
num_subjects_min_visits = 0
min_visits_required = 0 # Initialize

# Ensure df_baseline_filtered exists and is not empty from the previous step
if 'df_baseline_filtered' in locals() and not df_baseline_filtered.empty:
    # Count visits per subject *within the baseline-filtered group*
    visits_per_subject_filtered = df_baseline_filtered.groupby('Subject ID')['Visit'].count()

    # --- Analyze visit counts before setting threshold ---
    print("\nDistribution of visit counts (for subjects meeting baseline CDR criteria):")
    visit_counts_distribution = visits_per_subject_filtered.value_counts().sort_index()
    print(visit_counts_distribution)

    total_subjects_after_cdr_filter = df_baseline_filtered['Subject ID'].nunique() # More robust count
    
    # Calculate counts and fractions for different visit thresholds
    # (percent_* variables hold fractions in [0, 1]; they are formatted as percentages when printed)
    count_ge_2_visits = int((visits_per_subject_filtered >= 2).sum())
    count_ge_3_visits = int((visits_per_subject_filtered >= 3).sum())
    count_ge_4_visits = int((visits_per_subject_filtered >= 4).sum())

    percent_ge_2_visits = count_ge_2_visits / total_subjects_after_cdr_filter if total_subjects_after_cdr_filter > 0 else 0
    percent_ge_3_visits = count_ge_3_visits / total_subjects_after_cdr_filter if total_subjects_after_cdr_filter > 0 else 0
    percent_ge_4_visits = count_ge_4_visits / total_subjects_after_cdr_filter if total_subjects_after_cdr_filter > 0 else 0

    print(f"\nSubjects meeting baseline CDR criteria: {total_subjects_after_cdr_filter}")
    print(f"  Number with >= 2 visits: {count_ge_2_visits} ({percent_ge_2_visits:.1%})")
    print(f"  Number with >= 3 visits: {count_ge_3_visits} ({percent_ge_3_visits:.1%})")
    print(f"  Number with >= 4 visits: {count_ge_4_visits} ({percent_ge_4_visits:.1%})")

    # Log cohort check stats to W&B
    if run:
        run.log({
            'cohort_check/01_total_baseline_criteria_subjects': total_subjects_after_cdr_filter,
            'cohort_check/02_subjects_ge_2_visits': count_ge_2_visits,
            'cohort_check/03_subjects_ge_3_visits': count_ge_3_visits,
            'cohort_check/04_subjects_ge_4_visits': count_ge_4_visits,
            'cohort_check/05_percent_ge_2_visits': percent_ge_2_visits,
            'cohort_check/06_percent_ge_3_visits': percent_ge_3_visits,
            'cohort_check/07_percent_ge_4_visits': percent_ge_4_visits
        })
        try:
            visit_counts_table_df = visit_counts_distribution.reset_index()
            visit_counts_table_df.columns = ['number_of_visits', 'subject_count'] # Rename for clarity
            visit_counts_wandb_table = wandb.Table(dataframe=visit_counts_table_df)
            run.log({"cohort_check/08_visit_count_distribution_table": visit_counts_wandb_table})
        except Exception as e_wandb_table:
            print(f"Warning: Could not log visit count distribution table to W&B. Error: {e_wandb_table}")

    # Visualize visit counts distribution
    fig_visit_counts, ax_visit_counts = plt.subplots(figsize=(10, 6))
    sns.countplot(x=visits_per_subject_filtered, ax=ax_visit_counts, color='skyblue') # Counts are the default statistic
    ax_visit_counts.set_title(f'Visit Counts per Subject (After Baseline CDR Filter: {baseline_cdr_criteria})')
    ax_visit_counts.set_xlabel('Number of Visits Recorded per Subject')
    ax_visit_counts.set_ylabel('Number of Subjects')
    finalize_plot(fig_visit_counts, plt, run, 
                  f"charts_cohort_def_{DATASET_IDENTIFIER}/01_visit_counts_distribution", 
                  output_dir / '01_cohort_visit_counts.png')

    # --- Make data-driven decision on min_visits_required ---
    # Example logic: Prioritize >=3 visits if a substantial portion of the cohort has them.
    # Adjust thresholds and cohort size checks (e.g., count_ge_3_visits > 20) as needed for the specific dataset.
    if percent_ge_3_visits >= 0.40 and count_ge_3_visits > 20: # If >=40% have 3+ visits AND it's a reasonable number
        min_visits_required = 3
    elif count_ge_2_visits > 20: # Else, if enough subjects have at least 2 visits
        min_visits_required = 2
    else: # Fallback if cohort becomes very small
        min_visits_required = 2 # Or 1, depending on absolute minimum for modeling
        print(f"Warning: Low number of subjects with multiple visits. Setting min_visits_required = {min_visits_required}. "
              "Consider re-evaluating baseline criteria if final cohort size is too small.")
    print(f"\nDecision: Setting minimum visits required per subject = {min_visits_required}.")

    if run:
        run.config.update({'cohort_criteria/min_visits_required': min_visits_required}, allow_val_change=True)

    # Identify subjects meeting the minimum visit count
    subjects_with_enough_visits = visits_per_subject_filtered[visits_per_subject_filtered >= min_visits_required].index
    num_subjects_min_visits = len(subjects_with_enough_visits)
    print(f"Found {num_subjects_min_visits} unique subjects meeting baseline CDR and >= {min_visits_required} visits criteria.")

    # Filter the DataFrame to keep all visits of these subjects
    df_min_visits_filtered = df_baseline_filtered[df_baseline_filtered['Subject ID'].isin(subjects_with_enough_visits)].copy()
    print(f"DataFrame shape after min visits filter: {df_min_visits_filtered.shape} ({df_min_visits_filtered['Subject ID'].nunique()} subjects)")

    if run:
        run.log({
            'cohort_definition/02_subjects_after_min_visits_filter': num_subjects_min_visits,
            'cohort_definition/02_visits_after_min_visits_filter': len(df_min_visits_filtered)
        })
else:
    print("Skipping minimum visits filter as 'df_baseline_filtered' is empty or not defined.")
    # Ensure df_min_visits_filtered is an empty DataFrame for subsequent steps if pipeline were to continue
    df_min_visits_filtered = pd.DataFrame()

## 7. Cohort Definition - Step 3: Filter by MRI Availability

The final cohort inclusion step ensures that all selected visits have corresponding, successfully verified MRI scan files. This filter uses the set of `verified_mri_ids` built earlier in this notebook from Notebook 01's `verification_details.csv`. Visits without a verified MRI are removed. The number of visits removed and the final cohort size (subjects and visits) are logged.
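
Because this filter removes individual visits rather than whole subjects, a subject can end up with fewer visits than the minimum applied in the previous step. The sketch below, on made-up IDs, applies the same `isin` filter and then lists such subjects. The follow-up check is an illustrative addition and is not performed by the cell below.

```python
import pandas as pd

toy_visits = pd.DataFrame({
    "Subject ID": ["S1", "S1", "S2", "S2"],
    "MRI ID":     ["S1_MR1", "S1_MR2", "S2_MR1", "S2_MR2"],  # made-up IDs
})
verified_mri_ids = {"S1_MR1", "S1_MR2", "S2_MR1"}
min_visits_required = 2

# Keep only visits whose MRI ID was verified
cohort = toy_visits[toy_visits["MRI ID"].isin(verified_mri_ids)]

# Subjects whose remaining visit count dropped below the threshold
remaining = cohort.groupby("Subject ID").size()
below_threshold = sorted(remaining[remaining < min_visits_required].index)
print(len(cohort), below_threshold)
```

Here `S2` loses one visit and is left with a single scan, which may matter for sequence models that expect at least `min_visits_required` time points.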

In [None]:
# --- Cohort Definition Step 3: Filter by MRI Availability ---
print("\n--- Applying MRI Verification Filter ---")

cohort_df_final = pd.DataFrame() # Initialize
final_subjects = 0
final_visits = 0

# Ensure df_min_visits_filtered exists and is not empty from the previous step
# Also ensure verified_mri_ids was loaded and is a set
if 'df_min_visits_filtered' in locals() and not df_min_visits_filtered.empty and \
   'verified_mri_ids' in locals() and isinstance(verified_mri_ids, set):

    initial_visits_before_mri_filter = len(df_min_visits_filtered)
    
    # Filter to keep only rows where 'MRI ID' is in the set of verified_mri_ids
    # Assumes 'MRI ID' column exists in df_min_visits_filtered
    if 'MRI ID' in df_min_visits_filtered.columns:
        cohort_df_final = df_min_visits_filtered[df_min_visits_filtered['MRI ID'].isin(verified_mri_ids)].copy()
        
        final_subjects = cohort_df_final['Subject ID'].nunique() if 'Subject ID' in cohort_df_final.columns else 0
        final_visits = len(cohort_df_final)
        visits_removed_by_mri_filter = initial_visits_before_mri_filter - final_visits

        print(f"Removed {visits_removed_by_mri_filter} visits due to missing or unverified MRI scans.")
        print(f"Final cohort for modeling: {final_visits} visits from {final_subjects} subjects.")

        if run:
            run.log({
                'cohort_definition/03_visits_before_mri_filter': initial_visits_before_mri_filter,
                'cohort_definition/03_visits_removed_by_mri_filter': visits_removed_by_mri_filter,
                'cohort_definition/03_final_subject_count': final_subjects,
                'cohort_definition/03_final_visit_count': final_visits
            })
    else:
        print("Warning: 'MRI ID' column not found in DataFrame. Cannot apply MRI verification filter.")
        cohort_df_final = df_min_visits_filtered.copy() # No filtering done
        final_subjects = cohort_df_final['Subject ID'].nunique() if 'Subject ID' in cohort_df_final.columns else 0
        final_visits = len(cohort_df_final)
        print(f"Proceeding with {final_visits} visits from {final_subjects} subjects without MRI filter.")


    if final_visits == 0:
        print("CRITICAL ERROR: No visits remaining after applying all filters. "
              "Check data, verification CSV, and cohort definition criteria.")
        if run: run.finish(exit_code=1)
        # exit()
else:
    print("Skipping MRI verification filter as preceding DataFrame ('df_min_visits_filtered') is empty or not defined, "
          "or 'verified_mri_ids' not available.")
    # Ensure cohort_df_final is an empty DataFrame if pipeline were to continue
    cohort_df_final = pd.DataFrame()

## 8. Final Cohort Summary and Saving

Summarize the characteristics of the `cohort_df_final` after all inclusion/exclusion criteria have been applied. This final cohort DataFrame is then saved locally to a CSV file (e.g., `final_analysis_cohort_oasis2.csv`) in this notebook's designated output directory. This CSV file is also logged as a versioned data artifact to Weights & Biases, making it easily accessible for downstream notebooks (e.g., Notebook 03 for feature engineering).
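
The artifact name encodes the cohort criteria. The sketch below reproduces the naming pattern from the saving cell as a standalone function; `cohort_artifact_name` is a hypothetical helper, and the minimum-visits value of 2 is only an example.

```python
def cohort_artifact_name(dataset_id, baseline_cdr_criteria, min_visits_required):
    """Build the artifact name using the same pattern as the saving cell."""
    # Dots are replaced with 'p' so the CDR values stay readable in the name
    cdr_tag = "_".join(map(str, baseline_cdr_criteria)).replace(".", "p")
    return f"analysis-cohort-{dataset_id}-BCDR_{cdr_tag}-MinV_{min_visits_required}"


name = cohort_artifact_name("oasis2", [0.0, 0.5], 2)
print(name)
```

Changing either criterion yields a different artifact name, so cohorts built under different rules are stored as separate artifacts.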

In [None]:
# --- Summarize, Save Final Cohort, and Log as W&B Artifact ---
print("\n--- Final Cohort Defined & Saving ---")

# Ensure cohort_df_final and other necessary variables are defined
if 'cohort_df_final' in locals() and not cohort_df_final.empty and \
   'final_subjects' in locals() and 'final_visits' in locals() and \
   'baseline_cdr_criteria' in locals() and 'min_visits_required' in locals():

    print(f"Final Cohort Summary for {DATASET_IDENTIFIER.upper()}:")
    print(f"  Total Unique Subjects: {final_subjects}")
    print(f"  Total Visits (Scan Sessions): {final_visits}")
    print(f"  Applied Baseline CDR criteria: {baseline_cdr_criteria}")
    print(f"  Applied Minimum Visits criteria: >= {min_visits_required}")
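    # The MRI filter is skipped upstream when the 'MRI ID' column is absent, so report its actual status
    mri_filter_applied = 'MRI ID' in cohort_df_final.columns
    mri_filter_status = ("Yes (visits kept only if MRI was verified in NB01)"
                         if mri_filter_applied else "No ('MRI ID' column missing)")
    print(f"  Applied MRI Verified criteria: {mri_filter_status}")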

    # Define path for saving the final cohort DataFrame
    final_cohort_filename = f"final_analysis_cohort_{DATASET_IDENTIFIER}.csv"
    final_cohort_path = output_dir / final_cohort_filename # output_dir defined in Cell 3
    
    try:
        cohort_df_final.to_csv(final_cohort_path, index=False)
        print(f"\nFinal cohort DataFrame saved locally to: {final_cohort_path}")

        # Log final cohort DataFrame as a W&B artifact
        if run:
            print("Logging final cohort DataFrame as W&B artifact...")
            # Create a descriptive artifact name including key criteria
            artifact_name_cohort = (f"analysis-cohort-{DATASET_IDENTIFIER}"
                                    f"-BCDR_{'_'.join(map(str, baseline_cdr_criteria)).replace('.', 'p')}"
                                    f"-MinV_{min_visits_required}")
            
            cohort_artifact_description = (
                f"Final analysis cohort for {DATASET_IDENTIFIER} after all filtering criteria. "
                f"Baseline CDR in {baseline_cdr_criteria}, Minimum Visits >= {min_visits_required}, MRI verified."
            )
            cohort_artifact_metadata = {
                'num_subjects': final_subjects, 
                'num_visits': final_visits,
                'baseline_cdr_criteria_applied': baseline_cdr_criteria, 
                'min_visits_required_applied': min_visits_required,
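                # The MRI filter is only applied when the 'MRI ID' column exists (see the filter step above)
                'mri_verified_filter_applied': 'MRI ID' in cohort_df_final.columns,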
                'source_notebook': f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}",
                'dataset_identifier': DATASET_IDENTIFIER
            }
            
            final_cohort_wandb_artifact = wandb.Artifact(
                artifact_name_cohort,
                type=f"processed_dataset_{DATASET_IDENTIFIER}", # Consistent type for processed datasets
                description=cohort_artifact_description,
                metadata=cohort_artifact_metadata
            )
            final_cohort_wandb_artifact.add_file(str(final_cohort_path), name=final_cohort_filename) # Add the saved CSV
            run.log_artifact(final_cohort_wandb_artifact, aliases=["latest_cohort", f"final_{time.strftime('%Y%m%d')}"])
            print(f"Final cohort artifact '{artifact_name_cohort}' logged to W&B.")

    except Exception as e_save_final_cohort:
        print(f"Warning: Could not save or log final cohort DataFrame. Error: {e_save_final_cohort}")
else:
    print("Final cohort (cohort_df_final) is empty or key variables for summary are missing. Skipping final summary and save.")
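
Because the MRI availability filter runs after the minimum-visits filter, some subjects may end up with fewer visits than the threshold in the saved cohort. The following is a minimal sketch of a sanity check for this, assuming the `Subject ID` and `MRI ID` column names used above. The helper `visits_per_subject_below` and the toy data are hypothetical and only illustrate the idea.

```python
import pandas as pd


def visits_per_subject_below(df: pd.DataFrame, min_visits: int,
                             subject_col: str = "Subject ID") -> pd.Series:
    """Return visit counts for subjects with fewer than `min_visits` rows.

    Hypothetical helper: not part of the project's `src` utilities.
    """
    counts = df.groupby(subject_col).size()
    return counts[counts < min_visits]


# Toy data standing in for the saved cohort CSV (values are made up).
toy_cohort = pd.DataFrame({
    "Subject ID": ["S1", "S1", "S2", "S2", "S2", "S3"],
    "MRI ID": ["S1_MR1", "S1_MR2", "S2_MR1", "S2_MR2", "S2_MR3", "S3_MR1"],
})

short_subjects = visits_per_subject_below(toy_cohort, min_visits=2)
print(short_subjects)  # Subjects that fell below the threshold after filtering
```

In the notebook itself, the same check could be run on `cohort_df_final` with `min_visits_required`. Whether such subjects should be dropped or kept is a modeling decision that this sketch does not make.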

In [None]:
# --- Define Prediction Task ---
print("\n--- Defining Prediction Task ---")

### Prediction Task Definition (Phase 1 & 2)


* **Target Variable:** Predict the **CDR score** at the next available visit (visit `k+1`).
* **Input Features Strategy:** Use features from all available prior visits up to and including the current visit (visit `k`).
* **Feature Types:**
    * **Time-Varying:** Clinical scores (e.g., MMSE at visit `k`), Age (at visit `k`), time since baseline/previous visit, MRI-derived features (from scan at visit `k`).
    * **Static (Planned):** Baseline CDR, Baseline MMSE, Sex, Education (EDUC), SES. These will be concatenated to the input at each time step.
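
The task definition above can be sketched as a function that turns each subject's ordered visits into (input sequence, target) pairs: the first `k` visits are the input and the CDR at visit `k+1` is the target. This is an illustrative sketch only. The column names `Visit`, `CDR`, `Age`, and `MMSE` are assumptions about the cohort CSV, the helper `build_prediction_pairs` is hypothetical, and static features and MRI-derived features are left out for brevity.

```python
import pandas as pd


def build_prediction_pairs(df, subject_col="Subject ID", visit_col="Visit",
                           target_col="CDR", feature_cols=("Age", "MMSE")):
    """Build (subject, features of visits 1..k, CDR at visit k+1) tuples.

    Hypothetical helper; column names are assumptions about the cohort file.
    """
    pairs = []
    ordered = df.sort_values([subject_col, visit_col])
    for subject_id, visits in ordered.groupby(subject_col):
        features = visits[list(feature_cols)].to_numpy()
        targets = visits[target_col].to_numpy()
        # A subject with n visits yields n - 1 prediction points.
        for k in range(1, len(visits)):
            pairs.append((subject_id, features[:k], targets[k]))
    return pairs


# Toy data (made-up values): S1 has three visits, S2 has two.
toy_visits = pd.DataFrame({
    "Subject ID": ["S1", "S1", "S1", "S2", "S2"],
    "Visit": [1, 2, 3, 1, 2],
    "Age": [70.0, 71.0, 73.0, 80.0, 82.0],
    "MMSE": [29.0, 28.0, 26.0, 30.0, 30.0],
    "CDR": [0.0, 0.0, 0.5, 0.0, 0.0],
})

pairs = build_prediction_pairs(toy_visits)
for subject_id, x_seq, y_next in pairs:
    print(subject_id, x_seq.shape, y_next)
```

Sequences produced this way have different lengths, so a later padding or masking step is still needed before batching.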


In [None]:
if run:
    run.config.update({
        'prediction/target_variable': 'CDR_next_visit',
        'prediction/input_strategy': 'all_prior_visits_plus_static',
        'prediction/time_varying_features': ['Age_visit', 'MMSE_visit', 'MRI_features_visit', 'Time_interval'], # Example list
        'prediction/static_features_planned': ['Baseline_CDR', 'Baseline_MMSE', 'Sex', 'EDUC', 'SES'] # Planned
    })
    print("Prediction task configuration logged to W&B.")


In [None]:
# --- Note on Next Steps: Preprocessing ---
print("\n--- Next Steps: Preprocessing ---")

The next stage preprocesses the saved cohort file (`final_analysis_cohort_{DATASET_IDENTIFIER}.csv`) to make it suitable for sequence modeling:


1.  **Feature Engineering:** Create necessary features like 'Time since baseline/previous visit'. Extract Baseline CDR/MMSE to be used as static features.
2.  **Sequence Creation:** Group data by subject and create sequences of visits. Define the input sequence (visits 1 to k) and target (CDR at visit k+1) for each prediction point. Handle sequences of varying lengths (padding/masking).
3.  **Data Splitting:** Split subjects into Training, Validation, and Test sets *before* any scaling or imputation that involves learning parameters from the data. Ensure subjects from the same family (if applicable) stay in the same split.
4.  **Clinical Feature Scaling:** Scale numerical clinical features (e.g., Age, MMSE) appropriately (e.g., StandardScaler fit on the training set).
5.  **Missing Value Imputation (Within Sequence):** Decide on a strategy for handling missing values *within* the time-varying features of a sequence (e.g., forward fill, mean imputation based on training set, model-based imputation).
6.  **MRI Preprocessing:** Define and implement the pipeline to process the verified T1w NIfTI files (e.g., registration, skull stripping, feature extraction using 3D CNN or ViT). This is a major separate step.
7.  **Combine & Save Processed Data:** Integrate clinical sequences and pointers to processed MRI features, saving the final model-ready data splits.
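
Steps 3 and 4 above can be sketched as follows: subjects (not visits) are assigned to splits, and scaling parameters are learned from the training rows only. This is a minimal sketch under stated assumptions, not the project's implementation. The helpers `split_subjects`, `fit_standardizer`, and `apply_standardizer` are hypothetical, the split fractions and seed are arbitrary, and the plain-pandas standardization (population standard deviation, `ddof=0`) stands in for what a `StandardScaler` fit on the training set would compute.

```python
import numpy as np
import pandas as pd


def split_subjects(subject_ids, val_frac=0.2, test_frac=0.2, seed=42):
    """Assign whole subjects to train/val/test so no subject spans two splits."""
    unique_ids = sorted(set(subject_ids))
    order = np.random.default_rng(seed).permutation(len(unique_ids))
    shuffled = [unique_ids[i] for i in order]
    n_test = int(round(len(shuffled) * test_frac))
    n_val = int(round(len(shuffled) * val_frac))
    test_ids = set(shuffled[:n_test])
    val_ids = set(shuffled[n_test:n_test + n_val])
    train_ids = set(shuffled[n_test + n_val:])
    return train_ids, val_ids, test_ids


def fit_standardizer(train_df, cols):
    """Learn per-column mean and std from training rows only."""
    mean = train_df[cols].mean()
    std = train_df[cols].std(ddof=0).replace(0, 1.0)  # Guard against constant columns
    return mean, std


def apply_standardizer(df, cols, mean, std):
    """Scale `cols` with previously fitted parameters; returns a copy."""
    out = df.copy()
    out[cols] = (out[cols] - mean) / std
    return out


# Toy cohort (made-up values): 10 subjects with 2 visits each.
toy = pd.DataFrame({
    "Subject ID": [f"S{i:02d}" for i in range(10) for _ in range(2)],
    "Age": [70.0 + i + v for i in range(10) for v in range(2)],
    "MMSE": [30.0 - (i % 5) for i in range(10) for _ in range(2)],
})
feature_cols = ["Age", "MMSE"]

train_ids, val_ids, test_ids = split_subjects(toy["Subject ID"])
train_df = toy[toy["Subject ID"].isin(train_ids)]
val_df = toy[toy["Subject ID"].isin(val_ids)]

mean, std = fit_standardizer(train_df, feature_cols)
train_scaled = apply_standardizer(train_df, feature_cols, mean, std)
val_scaled = apply_standardizer(val_df, feature_cols, mean, std)  # Reuses training statistics
print(len(train_ids), len(val_ids), len(test_ids))
```

Splitting by subject before fitting the scaler keeps validation and test statistics out of the learned parameters, which is the leakage concern that step 3 raises.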

## Finalize Run

Finish the Weights & Biases run associated with this cohort definition process.

In [None]:
print("\n--- Cohort Definition complete. Finishing W&B run. ---")
if run:
    run.finish()
    print("W&B run finished.")
else:
    print("No active W&B run to finish.")

print("\nNotebook execution finished.")