In [None]:
# In: notebooks/02_Cohort_Definition.ipynb
# Purpose: Define the analysis cohort for longitudinal prediction based on OASIS-2 data.
#          Applies criteria for baseline status, minimum visits, and MRI availability.
#          Logs decisions and outputs the final cohort definition.

In [None]:
# --- Import Libraries ---
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import os
import time
import json
from pathlib import Path

In [None]:
# --- Config Loading ---
# Reads ../config.json, derives the input/output paths used by the rest of the
# notebook, and creates this notebook's output directory.
print("--- Loading Configuration ---")
CONFIG_PATH = Path('../config.json') # Path relative to the notebook location
try:
    PROJECT_ROOT = CONFIG_PATH.parent.resolve()
    print(f"Project Root detected as: {PROJECT_ROOT}")

    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = json.load(f)
    print("Configuration loaded successfully.")

    # Define key variables from config
    INPUT_DATA_PATH = PROJECT_ROOT / config['data']['clinical_excel_path']
    OUTPUT_DIR_BASE = PROJECT_ROOT / config['data']['output_dir_base']
    WANDB_PROJECT = config['wandb']['project_name']
    WANDB_ENTITY = config['wandb'].get('entity', None)

    # Define specific output dir for this notebook and create it
    NOTEBOOK_NAME = "02_Cohort_Definition"
    output_dir = OUTPUT_DIR_BASE / NOTEBOOK_NAME
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Outputs will be saved to: {output_dir}")

    # Define path to verification results from Notebook 01
    NB01_OUTPUT_DIR = OUTPUT_DIR_BASE / "01_Data_Exploration"
    VERIFICATION_CSV_PATH = NB01_OUTPUT_DIR / "verification_details.csv" # ASSUMPTION: NB01 saved this

# Each handler prints a short diagnosis and then re-raises. The previous
# `exit()` calls are interactive-shell helpers: inside a Jupyter kernel they
# request a kernel shutdown (losing all state) rather than stopping at this
# cell, and in a plain script they exit with status 0 despite the error.
except FileNotFoundError:
    print(f"Error: Configuration file not found at {CONFIG_PATH}")
    raise
except KeyError as e:
    print(f"Error: Missing key {e} in configuration file.")
    raise
except Exception as e:
    print(f"An error occurred loading the config file: {e}")
    raise

In [None]:
# --- Helper Functions ---
def finalize_plot(fig, run, wandb_key, save_path):
    """Log a figure to W&B, save it locally, display it, and close it.

    Args:
        fig: Figure object to finalize; it is passed to ``wandb.Image`` and
            ``plt.close`` and its ``savefig`` method is called.
        run: Active W&B run, or a falsy value to skip W&B logging.
        wandb_key: Key under which the image is logged to W&B.
        save_path: Destination passed to ``fig.savefig``; falsy to skip saving.

    W&B logging and local saving are both best-effort: a failure in either is
    reported as a warning and does not prevent the remaining steps. The figure
    is always closed, even if an earlier step raises.
    """
    try:
        if run:
            # Guarded like the other W&B calls in this notebook: tracking is
            # optional and must not stop the plot from being saved and shown.
            try:
                run.log({wandb_key: wandb.Image(fig)})
            except Exception as e:
                print(f"Warning: Could not log plot '{wandb_key}' to W&B. Error: {e}")
        if save_path:
            try:
                fig.savefig(save_path, bbox_inches='tight')
            except Exception as e:
                print(f"Warning: Could not save plot to {save_path}. Error: {e}")
        plt.show()
    finally:
        # Always release the figure so a failed step does not leak it.
        plt.close(fig)

In [None]:
# --- Initialize W&B Run ---
# W&B tracking is optional: when initialisation fails, `run` stays None and the
# later cells skip their logging via `if run:` checks.
print("\n--- Initializing Weights & Biases Run ---")
run = None
try:
    run = wandb.init(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        job_type="cohort-definition",
        name=f"{NOTEBOOK_NAME}-run-{time.strftime('%Y%m%d-%H%M')}",
        # Only the data locations are recorded up front; the cohort criteria
        # are added later through wandb.config.update().
        config=dict(
            source_data_path=str(INPUT_DATA_PATH),
            verification_data_path=str(VERIFICATION_CSV_PATH),
        ),
    )
    print(f"W&B run '{run.name}' initialized successfully. View at: {run.url}")
except Exception as init_error:
    print(f"Error initializing W&B: {init_error}")
    print("Proceeding without W&B logging.")

## Load Clinical Data

Load the raw longitudinal clinical and demographic data from the specified Excel file using pandas. The Weights & Biases run initialized above is used to log the number of raw rows loaded.

In [None]:
# Load the raw clinical spreadsheet and record its row count in W&B.
print(f"\n--- Loading Raw Clinical Data from: {INPUT_DATA_PATH} ---")
try:
    if not INPUT_DATA_PATH.is_file():
        raise FileNotFoundError(f"Input data file not found at {INPUT_DATA_PATH}")
    clinical_df_raw = pd.read_excel(INPUT_DATA_PATH)
    print(f"Raw clinical data loaded successfully. Shape: {clinical_df_raw.shape}")
    if run: run.log({'cohort_definition/00_raw_rows': len(clinical_df_raw)})

except Exception as e:
    print(f"Error loading clinical data: {e}")
    if run: run.finish()
    # Re-raise rather than exit(): in a Jupyter kernel exit() requests a kernel
    # shutdown instead of halting at this cell, and hides the traceback.
    raise

## Load MRI Verification Results

Load the detailed verification results (`verification_details.csv`) saved by Notebook 01. This file contains information on which `MRI ID`s correspond to successfully located raw scan files (`.img` + `.hdr` pairs) on the local disk. This is needed to ensure our final cohort only includes visits with available imaging data. Note that the filter built in the next cell uses only the `mri_folder_exists` flag from this file.

In [None]:
# Load Notebook 01's verification table and build the set of MRI IDs that are
# allowed into the cohort.
print(f"\n--- Loading MRI Verification Results from: {VERIFICATION_CSV_PATH} ---")
try:
    if not VERIFICATION_CSV_PATH.is_file():
        raise FileNotFoundError(f"Verification results file not found at {VERIFICATION_CSV_PATH}. Please ensure Notebook 01 saved it.")
    verification_df = pd.read_csv(VERIFICATION_CSV_PATH)
    # --- Get the set of MRI IDs that passed verification ---
    # Only folder existence is checked here.
    # Adjust criteria if needed, e.g., check mprs_found_count > 0
    verified_mri_ids = set(verification_df[verification_df['mri_folder_exists'] == True]['mri_id'].unique())
    print(f"Loaded verification results. Found {len(verified_mri_ids)} unique MRI IDs with existing folders.")
    if run: run.log({'cohort_definition/00_verified_mri_ids': len(verified_mri_ids)})

# Both handlers close the W&B run and re-raise. exit() is avoided because in a
# Jupyter kernel it requests a kernel shutdown instead of halting at this cell.
except FileNotFoundError as e:
    print(f"Error: {e}")
    print("Cannot proceed without verification results to filter the cohort.")
    if run: run.finish()
    raise
except Exception as e:
    print(f"An error occurred loading the verification results: {e}")
    if run: run.finish()
    raise

## Cohort Definition Step 1: Filter by Baseline CDR

Apply the first inclusion criterion based on the subject's cognitive status at their first available visit. We include subjects whose baseline CDR score was 0.0 (Cognitively Normal) or 0.5 (Mild Cognitive Impairment). Log the number of subjects and visits remaining after this filter.

In [None]:
print("\n--- Applying Baseline CDR Filter (Keeping CDR=0 and CDR=0.5) ---")
baseline_cdr_criteria = [0.0, 0.5]

# Locate each subject's first visit: the row holding the smallest 'Visit'
# value within that subject's group.
clinical_df_raw_sorted = clinical_df_raw.sort_values(['Subject ID', 'Visit'])
first_visit_index = clinical_df_raw_sorted.groupby('Subject ID')['Visit'].idxmin()
first_visit_data = clinical_df_raw_sorted.loc[first_visit_index]

# Subjects whose first-visit CDR is one of the accepted baseline values
baseline_cdr_ok = first_visit_data['CDR'].isin(baseline_cdr_criteria)
subjects_meeting_baseline_cdr = first_visit_data.loc[baseline_cdr_ok, 'Subject ID'].unique()
num_subjects_baseline_criteria = len(subjects_meeting_baseline_cdr)
print(f"Found {num_subjects_baseline_criteria} unique subjects with baseline CDR in {baseline_cdr_criteria}.")

# Keep every visit (not only the first) belonging to the selected subjects
in_baseline_cohort = clinical_df_raw['Subject ID'].isin(subjects_meeting_baseline_cdr)
df_baseline_filtered = clinical_df_raw.loc[in_baseline_cohort].copy()
print(f"DataFrame shape after baseline CDR filter: {df_baseline_filtered.shape}")

if run:
    # Record the criterion as run config and the resulting sizes as metrics
    wandb.config.update({'cohort_criteria/baseline_cdr_included': baseline_cdr_criteria})
    run.log({
        'cohort_definition/01_subjects_after_baseline_cdr_filter': num_subjects_baseline_criteria,
        'cohort_definition/01_visits_after_baseline_cdr_filter': len(df_baseline_filtered),
    })

## Cohort Definition Step 2: Check and Apply Minimum Visits Filter

Analyze the distribution of the total number of visits for the subjects selected in Step 1. Based on this distribution (aiming to balance longitudinal information content with cohort size), make a data-driven decision for the minimum number of visits required per subject (e.g., >=2 or >=3). Apply this filter and log the chosen criterion and resulting cohort size.

In [None]:
print("\n--- Checking and Applying Minimum Visits Filter ---")

if df_baseline_filtered.empty:
    print("No subjects remaining after baseline filter. Stopping.")
    if run: run.finish()
    # Raise rather than exit(): in a Jupyter kernel exit() only requests a
    # kernel shutdown and does not reliably stop the rest of this cell.
    raise RuntimeError("No subjects remaining after baseline filter.")

# Count visits per subject *within the baseline-filtered group*
visits_per_subject_filtered = df_baseline_filtered.groupby('Subject ID')['Visit'].count()

# --- Analyze visit counts before setting threshold ---
print("\nDistribution of visit counts (for subjects meeting baseline criteria):")
visit_counts_dist = visits_per_subject_filtered.value_counts().sort_index()
print(visit_counts_dist)

total_subjects_step1 = num_subjects_baseline_criteria
# Series.sum() on the boolean mask, cast to plain int for printing and logging
count_ge_2 = int((visits_per_subject_filtered >= 2).sum())
count_ge_3 = int((visits_per_subject_filtered >= 3).sum())
count_ge_4 = int((visits_per_subject_filtered >= 4).sum())

percent_ge_2 = count_ge_2 / total_subjects_step1 if total_subjects_step1 > 0 else 0
percent_ge_3 = count_ge_3 / total_subjects_step1 if total_subjects_step1 > 0 else 0
percent_ge_4 = count_ge_4 / total_subjects_step1 if total_subjects_step1 > 0 else 0

print(f"\nSubjects meeting baseline criteria: {total_subjects_step1}")
print(f"Number with >= 2 visits: {count_ge_2} ({percent_ge_2:.1%})")
print(f"Number with >= 3 visits: {count_ge_3} ({percent_ge_3:.1%})")
print(f"Number with >= 4 visits: {count_ge_4} ({percent_ge_4:.1%})")

# Log cohort check stats
if run:
    run.log({
        'cohort_check/total_baseline_criteria_subjects': total_subjects_step1,
        'cohort_check/subjects_ge_2_visits': count_ge_2,
        'cohort_check/subjects_ge_3_visits': count_ge_3,
        'cohort_check/subjects_ge_4_visits': count_ge_4,
        'cohort_check/percent_ge_2_visits': percent_ge_2,
        'cohort_check/percent_ge_3_visits': percent_ge_3,
        'cohort_check/percent_ge_4_visits': percent_ge_4
    })
    # Log distribution table
    try:
        # Name the index and the values explicitly. The column names produced
        # by value_counts().reset_index() differ between pandas 1.x and 2.x,
        # so renaming by the old names mislabels the columns on pandas >= 2.0.
        visit_counts_df = visit_counts_dist.rename_axis('num_visits').reset_index(name='subject_count')
        visit_counts_table = wandb.Table(dataframe=visit_counts_df)
        run.log({"cohort_check/visit_count_distribution": visit_counts_table})
    except Exception as e:
        print(f"Warning: Could not log visit count table to W&B. Error: {e}")


# Visualize
fig, ax = plt.subplots(figsize=(10, 5))
sns.countplot(x=visits_per_subject_filtered, ax=ax, color='skyblue') # Use countplot directly on the Series
ax.set_title(f'Number of Visits per Subject (Baseline CDR in {baseline_cdr_criteria})')
ax.set_xlabel('Number of Visits')
ax.set_ylabel('Number of Subjects')
finalize_plot(fig, run, "charts/cohort_check/visits_per_subject", output_dir / 'cohort_visit_counts.png')

In [None]:
# --- Make data-driven decision on min_visits_required ---
# Thresholds driving the decision, named so they can be tuned in one place and
# recorded with the run instead of living as bare numbers in the conditions.
MIN_FRACTION_WITH_3_VISITS = 0.40   # share of subjects with 3+ visits needed to require 3
MIN_SUBJECTS_WITH_2_VISITS = 50     # subjects with 2+ visits needed to fall back to 2

# Example logic: Require >=3 if enough subjects have it, otherwise require >=2 if cohort size is decent
if percent_ge_3 >= MIN_FRACTION_WITH_3_VISITS:
    min_visits_required = 3
    print(f"\nDecision: >= {MIN_FRACTION_WITH_3_VISITS:.0%} ({percent_ge_3:.1%}) of subjects have 3+ visits. Setting min_visits_required = 3.")
elif count_ge_2 > MIN_SUBJECTS_WITH_2_VISITS: # Ensure at least a reasonable number of subjects have >= 2 visits
    min_visits_required = 2
    print(f"\nDecision: Less than {MIN_FRACTION_WITH_3_VISITS:.0%} ({percent_ge_3:.1%}) of subjects have 3+ visits.")
    print(f"Relaxing criterion to min_visits_required = 2 ({percent_ge_2:.1%} have >=2 visits).")
else:
    min_visits_required = 3 # Default to 3 if cohort is very small anyway or percentages are odd
    print(f"\nWarning: Low number of subjects with multiple visits ({percent_ge_2:.1%} have >=2). Defaulting to min_visits_required = 3.")
    print("Consider re-evaluating baseline criteria if cohort size is too small.")

# Log final decision together with the thresholds that produced it
if run:
    wandb.config.update({
        'cohort_criteria/min_visits_required': min_visits_required,
        'cohort_criteria/min_fraction_with_3_visits': MIN_FRACTION_WITH_3_VISITS,
        'cohort_criteria/min_subjects_with_2_visits': MIN_SUBJECTS_WITH_2_VISITS,
    })

# Identify subjects meeting the final minimum visit count
subjects_with_enough_visits = visits_per_subject_filtered[visits_per_subject_filtered >= min_visits_required].index.unique()
num_subjects_min_visits = len(subjects_with_enough_visits)
print(f"\nFound {num_subjects_min_visits} unique subjects meeting baseline CDR and >= {min_visits_required} visits criteria.")

# Filter the DataFrame further
df_min_visits_filtered = df_baseline_filtered[df_baseline_filtered['Subject ID'].isin(subjects_with_enough_visits)].copy()
print(f"DataFrame shape after min visits filter: {df_min_visits_filtered.shape}")

if run:
    run.log({
        'cohort_definition/02_subjects_after_min_visits_filter': num_subjects_min_visits,
        'cohort_definition/02_visits_after_min_visits_filter': len(df_min_visits_filtered)
    })

## Cohort Definition Step 3: Filter by MRI Availability

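Apply the final inclusion criterion by ensuring that only visits with verified corresponding MRI scan files (based on the results loaded from Notebook 01) are retained in the cohort. Log the number of visits removed at this stage. Because this filter drops individual visits rather than whole subjects, some subjects may end up with fewer visits than the minimum required in Step 2.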

In [None]:
print("\n--- Applying MRI Verification Filter ---")

if df_min_visits_filtered.empty:
    print("No subjects remaining after minimum visits filter. Stopping.")
    if run: run.finish()
    # Raise rather than exit(): in a Jupyter kernel exit() only requests a
    # kernel shutdown and does not reliably stop the rest of this cell.
    raise RuntimeError("No subjects remaining after minimum visits filter.")

# Keep only the visits whose MRI ID is in the verified set
initial_visits_step3 = len(df_min_visits_filtered)
cohort_df_final = df_min_visits_filtered[df_min_visits_filtered['MRI ID'].isin(verified_mri_ids)].copy()
final_subjects = cohort_df_final['Subject ID'].nunique()
final_visits = len(cohort_df_final)
visits_removed_mri = initial_visits_step3 - final_visits

print(f"Removed {visits_removed_mri} visits due to missing/unverified MRI scans.")
print(f"Final cohort for modeling: {final_visits} visits from {final_subjects} subjects.")

# This filter removes individual visits, so a subject kept by the minimum
# visits step can now hold fewer visits than required. Report it rather than
# let it pass silently; the cohort itself is left unchanged. Subjects who lost
# every visit are already absent from cohort_df_final and are not counted here.
visits_after_mri = cohort_df_final.groupby('Subject ID')['Visit'].count()
subjects_below_min_visits = int((visits_after_mri < min_visits_required).sum())
if subjects_below_min_visits > 0:
    print(f"Warning: {subjects_below_min_visits} subjects now have fewer than {min_visits_required} visits after the MRI filter.")

if run:
    run.log({
        'cohort_definition/03_visits_before_mri_filter': initial_visits_step3,
        'cohort_definition/03_visits_removed_for_mri': visits_removed_mri,
        'cohort_definition/03_final_visits': final_visits,
        'cohort_definition/03_final_subjects': final_subjects,
        'cohort_definition/03_subjects_below_min_visits_after_mri': subjects_below_min_visits
    })

if final_visits == 0:
    print("Error: No visits remaining after applying all filters. Check data and criteria.")
    if run: run.finish()
    raise RuntimeError("No visits remaining after applying all filters.")


## Final Cohort Summary and Saving

Print a summary of the final cohort characteristics (number of subjects, visits) after applying all filters. Save the resulting cohort DataFrame (`cohort_df_final`) locally as `final_analysis_cohort.csv` (in this notebook's output directory). Log this final cohort DataFrame as a versioned artifact to Weights & Biases for downstream use.

In [None]:
# Summarise the final cohort, write it to CSV, and log the CSV as a W&B artifact.
print("\n--- Final Cohort Defined ---")
print(f"Total Subjects: {final_subjects}")
print(f"Total Visits (Scan Sessions): {final_visits}")
print(f"Baseline CDR criteria: {baseline_cdr_criteria}")
print(f"Minimum Visits criteria: >= {min_visits_required}")
# The verified-ID set is built from the `mri_folder_exists` flag only, so the
# summary states that criterion rather than claiming an img/hdr pair check.
print("MRI Verified criteria: MRI ID has an existing MRI folder per Notebook 01 verification (mri_folder_exists).")

# Save the final cohort DataFrame
final_cohort_path = output_dir / "final_analysis_cohort.csv"
try:
    cohort_df_final.to_csv(final_cohort_path, index=False)
    print(f"Final cohort DataFrame saved locally to: {final_cohort_path}")

    # Log final cohort as W&B artifact
    if run:
        print("Logging final cohort DataFrame as W&B artifact...")
        # The artifact name encodes the criteria, e.g. CDR_0.0_0.5 and MinV_<n>
        cohort_artifact = wandb.Artifact(f"analysis_cohort-OASIS2-CDR_{'_'.join(map(str, baseline_cdr_criteria))}-MinV_{min_visits_required}",
                                         type="analysis-dataset",
                                         description=f"Final cohort data after inclusion/exclusion and MRI verification. Baseline CDR={baseline_cdr_criteria}, MinVisits>={min_visits_required}.",
                                         metadata={'num_subjects': final_subjects, 'num_visits': final_visits,
                                                   'baseline_cdr_criteria': baseline_cdr_criteria, 'min_visits_required': min_visits_required})
        cohort_artifact.add_file(str(final_cohort_path))
        run.log_artifact(cohort_artifact)
        print("Final cohort artifact logged.")

except Exception as e:
    # Best-effort: a failed save or artifact upload is reported, not raised
    print(f"Warning: Could not save or log final cohort DataFrame. Error: {e}")

In [None]:
# --- Define Prediction Task ---
# Prints a section banner only; the task itself is described in the markdown
# that follows this cell.
print("\n--- Defining Prediction Task ---")

### Prediction Task Definition (Phase 1 & 2)


* **Target Variable:** Predict the **CDR score** at the next available visit (visit `k+1`).
* **Input Features Strategy:** Use features from all available prior visits up to and including the current visit (visit `k`).
* **Feature Types:**
    * **Time-Varying:** Clinical scores (e.g., MMSE at visit `k`), Age (at visit `k`), time since baseline/previous visit, MRI-derived features (from scan at visit `k`).
    * **Static (Planned):** Baseline CDR, Baseline MMSE, Sex, Education (EDUC), SES. These will be concatenated to the input at each time step.


In [None]:
if run:
    # Record the planned prediction setup alongside the cohort criteria
    prediction_task_config = {
        'prediction/target_variable': 'CDR_next_visit',
        'prediction/input_strategy': 'all_prior_visits_plus_static',
        # Example list of per-visit inputs
        'prediction/time_varying_features': ['Age_visit', 'MMSE_visit', 'MRI_features_visit', 'Time_interval'],
        # Planned static inputs
        'prediction/static_features_planned': ['Baseline_CDR', 'Baseline_MMSE', 'Sex', 'EDUC', 'SES'],
    }
    wandb.config.update(prediction_task_config)
    print("Prediction task configuration logged to W&B.")


In [None]:
# --- Note on Next Steps: Preprocessing ---
# Prints a section banner only; the preprocessing steps are listed in the
# markdown that follows this cell.
print("\n--- Next Steps: Preprocessing ---")

The next stage involves preprocessing the `final_analysis_cohort.csv` data to make it suitable for sequence modeling:


1.  **Feature Engineering:** Create necessary features like 'Time since baseline/previous visit'. Extract Baseline CDR/MMSE to be used as static features.
2.  **Sequence Creation:** Group data by subject and create sequences of visits. Define the input sequence (visits 1 to k) and target (CDR at visit k+1) for each prediction point. Handle sequences of varying lengths (padding/masking).
3.  **Data Splitting:** Split subjects into Training, Validation, and Test sets *before* any scaling or imputation that involves learning parameters from the data. Ensure subjects from the same family (if applicable) stay in the same split.
4.  **Clinical Feature Scaling:** Scale numerical clinical features (e.g., Age, MMSE) appropriately (e.g., StandardScaler fit on the training set).
5.  **Missing Value Imputation (Within Sequence):** Decide on a strategy for handling missing values *within* the time-varying features of a sequence (e.g., forward fill, mean imputation based on training set, model-based imputation).
6.  **MRI Preprocessing:** Define and implement the pipeline to process the verified T1w NIfTI files (e.g., registration, skull stripping, feature extraction using 3D CNN or ViT). This is a major separate step.
7.  **Combine & Save Processed Data:** Integrate clinical sequences and pointers to processed MRI features, saving the final model-ready data splits.

## Finalize Run

Finish the Weights & Biases run associated with this cohort definition process.

In [None]:
print("\n--- Cohort Definition complete. Finishing W&B run. ---")
# `run` is None (or otherwise falsy) when W&B initialisation failed earlier
if not run:
    print("No active W&B run to finish.")
else:
    run.finish()
    print("W&B run finished.")

print("\nScript execution finished.")