# Notebook 01: OASIS-2 Dataset - Initial Exploration & MRI File Verification

**Project Phase:** 1 (Data Ingestion, Exploration, and Validation)
**Dataset:** OASIS-2 Longitudinal MRI & Clinical Data

**Purpose:**
This notebook serves as the initial step in the OASIS-2 data processing pipeline. Its primary objectives are:
1.  Load the raw OASIS-2 longitudinal clinical and demographic data from the specified Excel file.
2.  Perform a comprehensive exploratory data analysis to understand data structure, variable types, distributions, missing values, and basic relationships between key variables.
3.  Verify the local file system presence of raw T1w MPRAGE scan files (`.img` + `.hdr` pairs) corresponding to entries in the clinical data, based on paths defined in `config.json`.
4.  Log summary statistics, generated plots, and the detailed MRI scan verification results to Weights & Biases (W&B) for experiment tracking and reproducibility.
5.  Save key outputs locally, notably `verification_details.csv`, which lists the status of each scan file and will be used in subsequent pipeline stages.

**Workflow:**
1.  **Setup:** Import libraries, configure `sys.path` to access `src/` utilities, load the main project configuration (`config.json`), and create this notebook's local output directory.
2.  **W&B Initialization:** Start a new W&B run for this data exploration task using the `initialize_wandb_run` utility.
3.  **Load Clinical Data:** Load the OASIS-2 clinical data Excel file. Log the raw dataset table as a W&B artifact.
4.  **Initial Data Inspection:** Examine DataFrame info, descriptive statistics, and missing value patterns. Log summaries to W&B and save tables locally.
5.  **Variable Distribution Analysis:** Visualize distributions of key variables (Age, MMSE, nWBV, Group, Gender, Visits per Subject) using histograms and count plots. Log plots to W&B and save locally.
6.  **Relationship Analysis:** Explore pairwise relationships (e.g., Age vs. MMSE) using scatter plots. Visualize and log a correlation matrix.
7.  **Group-wise Analysis:** Compare variable distributions across clinical groups.
8.  **Longitudinal Aspect Exploration:** Analyze visit intervals, baseline cohort characteristics, and plot average trends and example individual trajectories for key scores.
9.  **MRI Scan File Verification:** Iterate through clinical records, check for corresponding raw MRI files (`.img`/`.hdr`) based on expected directory structure, and log verification status.
10. **Summarize & Save Verification Outputs:** Compile verification results into a DataFrame, save `verification_details.csv` and `missing_mri_folders.csv` (if any) locally. Log these as a W&B Table and Artifacts.
11. **Finalize W&B Run.**

**Input:**
* `config.json`: Main project configuration file (specifies paths, W&B details, etc.).
* OASIS-2 Clinical Data Excel file (path specified in `config.json`).
* Local directories containing raw OASIS-2 MRI scans (paths specified in `config.json`).

**Output:**
* **Local Files (in this notebook's output directory, `<output_dir_base>/01_Data_Exploration_oasis2/`, where `output_dir_base` is set in `config.json`):**
    * `descriptive_stats_raw_data.csv`
    * `missing_values_summary_raw_data.csv`
    * `verification_details.csv` (Key input for Notebook 02 & MRI preprocessing scripts)
    * `missing_mri_folders.csv` (If applicable)
    * Plots as PNG files.
* **W&B Run:**
    * Logged run configuration (including paths used by this notebook).
    * Summary statistics of the dataset.
    * All generated EDA plots as `wandb.Image`.
    * `verification_details.csv` as a `wandb.Table`.
    * Raw clinical data table (`INPUT_DATA_PATH`) and `missing_mri_folders.csv` as W&B Artifacts.
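Downstream stages read `verification_details.csv` back in. One detail worth checking there: list-valued fields such as `mpr_labels_found` are written by `to_csv` as their string representation, so a reader has to parse them back into lists. The sketch below assumes the column names of the log dictionary built by `verify_scan_files()`; the two rows are invented for illustration, and the selection rule (RAW folder present, at least one `.img`/`.hdr` pair) is only an example of how a later stage might filter sessions.

```python
import ast
import io

import pandas as pd

# Invented excerpt of verification_details.csv; column names follow verify_scan_files().
csv_text = """mri_id,subject_id,mri_folder_exists,mprs_found_count,mpr_labels_found,found_three_or_more_mprs
OAS2_0001_MR1,OAS2_0001,True,3,"['mpr-1', 'mpr-2', 'mpr-3']",True
OAS2_0002_MR1,OAS2_0002,False,0,[],False
"""

details = pd.read_csv(io.StringIO(csv_text))

# to_csv stores list cells as their string repr; literal_eval safely turns them back into lists.
details["mpr_labels_found"] = details["mpr_labels_found"].apply(ast.literal_eval)

# Example filter: RAW folder found and at least one complete .img/.hdr pair.
usable = details[details["mri_folder_exists"] & (details["mprs_found_count"] > 0)]
print(usable[["mri_id", "mpr_labels_found"]])
```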

In [None]:
# In: notebooks/01_Data_Exploration_OASIS2.ipynb
# Purpose: Load OASIS-2 clinical data using config, perform detailed exploration
#          (including longitudinal aspects), verify MPRAGE file presence,
#          log results efficiently to W&B, and save key outputs locally.

In [None]:
# --- Import Libraries ---
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import sys
import os
import re
import time
import json
from tqdm.auto import tqdm
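from pathlib import Path
from IPython.display import display  # Used by the inspection cell; makes Jupyter's implicit builtin explicit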

## 1. Setup: Project Configuration and Paths

This section handles the initial setup for the notebook:
* Determines the project's root directory.
* Adds the project root (the directory that contains `src/`) to `sys.path` so the custom utility modules can be imported as `src.<module>`.
* Imports necessary custom utility functions.
* Loads the main project configuration from `config.json`.
* Defines key dataset identifiers and notebook-specific parameters.
* Resolves and prints essential input paths for clinical data and MRI base directories, and sets up the primary output directory for this notebook's locally saved files.
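The setup cell looks for `src/` and `config.json` in the working directory's parent and then in the working directory itself. A slightly more general variant, sketched below as a hypothetical `find_project_root` helper (it is not part of `src/`), walks up every parent directory, so the notebook would also resolve the root when launched from a deeper subfolder.

```python
import tempfile
from pathlib import Path


def find_project_root(start: Path) -> Path:
    """Hypothetical helper: return the nearest directory at or above `start`
    that contains both a `src/` directory and a `config.json` file."""
    for candidate in (start, *start.parents):
        if (candidate / "src").is_dir() and (candidate / "config.json").is_file():
            return candidate
    raise FileNotFoundError(f"No 'src/' + 'config.json' found at or above {start}")


# Demo on a throwaway layout: <root>/src, <root>/config.json, <root>/notebooks/deep
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "src").mkdir()
    (root / "config.json").write_text("{}")
    nested = root / "notebooks" / "deep"
    nested.mkdir(parents=True)
    found = find_project_root(nested)
    print(found == root)
```

In the notebook this would be called as `PROJECT_ROOT = find_project_root(Path.cwd())`, replacing the two-step parent/cwd check.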

In [None]:
# --- Project Setup, Configuration Loading, and Utility Imports ---
print("--- Initializing Project Setup & Configuration ---")

# Initialize
PROJECT_ROOT = None
base_config = {}

try:
    # Determine project root assuming this notebook is in a 'notebooks' subdirectory
    current_notebook_path = Path.cwd() 
    potential_project_root = current_notebook_path.parent 
    if (potential_project_root / "src").is_dir() and (potential_project_root / "config.json").is_file():
        PROJECT_ROOT = potential_project_root
    else: # Fallback: assume current working directory IS the project root
        PROJECT_ROOT = current_notebook_path
    
    if not (PROJECT_ROOT / "src").is_dir() or not (PROJECT_ROOT / "config.json").is_file():
        raise FileNotFoundError(f"Could not reliably find 'src' directory or 'config.json'. "
                                f"PROJECT_ROOT determined as: {PROJECT_ROOT}. "
                                "Ensure 'config.json' is at the project root and 'src' dir exists.")

    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT)) # Add project root to path for src imports
    print(f"PROJECT_ROOT: {PROJECT_ROOT}")
    print(f"Added '{str(PROJECT_ROOT)}' to sys.path for src imports.")

    # Import custom utilities AFTER path setup
    from src.wandb_utils import initialize_wandb_run 
    from src.plotting_utils import finalize_plot
    print("Successfully imported custom utilities.")

except FileNotFoundError as e_path:
    print(f"CRITICAL ERROR in project setup (paths or src): {e_path}")
    raise  # Stop here: every later cell depends on PROJECT_ROOT and the src imports
except ImportError as e_imp:
    print(f"CRITICAL ERROR: Could not import custom utilities: {e_imp}")
    print("Ensure src/wandb_utils.py and src/plotting_utils.py exist and are error-free.")
    raise
except Exception as e_general_setup:
    print(f"CRITICAL ERROR during initial setup: {e_general_setup}")
    raise


# --- Load Main Project Configuration ---
print("\n--- Loading Main Project Configuration from config.json ---")
try:
    if PROJECT_ROOT is None: # Should have been caught above, but as a safeguard
        raise ValueError("PROJECT_ROOT was not successfully defined. Cannot load configuration.")
    CONFIG_PATH_MAIN = PROJECT_ROOT / 'config.json'
    with open(CONFIG_PATH_MAIN, 'r', encoding='utf-8') as f:
        base_config = json.load(f)
    print(f"Main project config loaded successfully from: {CONFIG_PATH_MAIN}")
except Exception as e_cfg:
    print(f"CRITICAL ERROR loading main config.json: {e_cfg}")
    raise  # Without the configuration, no paths or W&B settings can be defined

# --- Define Dataset, Notebook Specifics, and Key Paths ---
# These are based on the loaded base_config.
DATASET_IDENTIFIER = "oasis2" # Specific to this notebook's focus
NOTEBOOK_MODULE_NAME = "01_Data_Exploration" # For naming outputs and W&B job type

# Initialize paths to None or empty lists for safety before assignment
INPUT_DATA_PATH = None
MRI_BASE_PATHS_CONFIG_RELATIVE = []
MRI_BASE_PATHS_ABSOLUTE = []
MPR_IMG_PATTERN_CONFIG = None
output_dir = None # This will be the static output dir for NB01

try:
    if not base_config: # Check if base_config loaded
        raise ValueError("base_config is empty or not loaded. Cannot define paths.")

    # Resolve input paths from base_config
    INPUT_DATA_PATH = PROJECT_ROOT / base_config['data']['clinical_excel_path']
    MRI_BASE_PATHS_CONFIG_RELATIVE = base_config['data'].get('mri_base_paths', [])
    MRI_BASE_PATHS_ABSOLUTE = [PROJECT_ROOT / p for p in MRI_BASE_PATHS_CONFIG_RELATIVE]
    MPR_IMG_PATTERN_CONFIG = re.compile(base_config['mri_verification']['mpr_img_pattern'])

    # Define the main output directory for THIS notebook (static, not W&B run-specific subfolder)
    output_dir_base_from_config = PROJECT_ROOT / base_config['data']['output_dir_base']
    output_dir = output_dir_base_from_config / f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}"
    output_dir.mkdir(parents=True, exist_ok=True) # Create if it doesn't exist

    print("\nKey paths defined:")
    print(f"  Input clinical data: {INPUT_DATA_PATH}")
    print(f"  MRI base paths: {[str(p) for p in MRI_BASE_PATHS_ABSOLUTE]}")
    print(f"  Notebook output directory (local saves): {output_dir}")

except KeyError as e_key:
    print(f"CRITICAL ERROR: Missing key {e_key} in base_config needed for path definitions.")
    raise  # Later cells need these paths; stop instead of continuing with None values
except Exception as e_paths:
    print(f"CRITICAL ERROR defining paths: {e_paths}")
    raise

## 2. Helper Function

Define the helper function used in this notebook to verify that the raw MRI scan files are present.
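The core of the check is pairing each MPRAGE `.img` with its `.hdr`. The sketch below reproduces that pairing logic on a throwaway directory. It assumes the configured `mpr_img_pattern` is `^(mpr-\d+)\.nifti\.img$` and that headers are named `<label>.nifti.hdr`, mirroring the `<MRI ID>/RAW/` layout the helper expects; the real pattern lives in `config.json` and may differ.

```python
import re
import tempfile
from pathlib import Path

# Assumed value of base_config['mri_verification']['mpr_img_pattern'];
# group 1 captures the "mpr-N" label.
MPR_IMG_PATTERN = re.compile(r"^(mpr-\d+)\.nifti\.img$")


def find_mpr_pairs(raw_dir: Path, pattern: re.Pattern) -> dict:
    """Return {label: (img_path, hdr_path)} for every matching .img whose .hdr exists."""
    pairs = {}
    for item in raw_dir.iterdir():
        match = pattern.match(item.name)
        if match:
            label = match.group(1)
            hdr = item.with_name(f"{label}.nifti.hdr")
            if hdr.is_file():
                pairs[label] = (str(item), str(hdr))
    return pairs


with tempfile.TemporaryDirectory() as tmp:
    raw_dir = Path(tmp) / "OAS2_0001_MR1" / "RAW"
    raw_dir.mkdir(parents=True)
    for name in ["mpr-1.nifti.img", "mpr-1.nifti.hdr",  # complete pair
                 "mpr-2.nifti.img",                       # .hdr missing -> not counted
                 "mpr-3.nifti.hdr"]:                      # .img missing -> never matched
        (raw_dir / name).touch()
    labels = sorted(find_mpr_pairs(raw_dir, MPR_IMG_PATTERN))
    print(labels)
```

Only complete pairs count, which is why `mprs_found_count` can be lower than the number of `.img` files in a `RAW/` folder.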

In [None]:
# --- Helper Function ---
print("\n--- Defining Helper Function ---")

# --- verify_scan_files function ---
# Call with MRI_BASE_PATHS_ABSOLUTE and MPR_IMG_PATTERN_CONFIG from the setup cell
def verify_scan_files(row: pd.Series, 
                      abs_mri_base_paths: list[Path], 
                      mpr_pattern: re.Pattern) -> dict:
    """
    Verifies presence of Analyze 7.5 .img/.hdr pairs for a given clinical data row
    by checking expected file structures within the provided MRI base paths.

    Args:
        row (pd.Series): A row from the clinical DataFrame, expected to contain
                         'Subject ID' and 'MRI ID'. Optional: 'Visit', 'Group'.
        abs_mri_base_paths (list[Path]): List of absolute Path objects for MRI base directories.
        mpr_pattern (re.Pattern): Compiled regex pattern to match MPRAGE .img filenames
                                  (e.g., to extract "mpr-1", "mpr-2").

    Returns:
        dict: A log entry dictionary with verification details for the MRI ID in the row,
              including 'mri_folder_exists', 'mprs_found_count', 'mpr_labels_found', etc.
    """
    subject_id = row['Subject ID']
    mri_id = row['MRI ID'] # This is the MRI Session ID
    mri_scan_folder_path = None # Path object to the specific .../<MRI ID>/RAW/ folder
    base_path_where_found = None # Which of the MRI_BASE_PATHS it was found under

    # Prepare initial log entry
    log_details = {
        'mri_id': mri_id,
        'subject_id': subject_id,
        'visit': row.get('Visit'), 
        'group': row.get('Group'),
        'mri_base_path_used': None,
        'mri_folder_path_checked': None, # For logging which exact folder was checked
        'mri_folder_exists': False,
        # 'mri_folder_is_dir' field removed as exists + is_dir check is implicit if folder_path is set
        'mprs_found_count': 0,
        'mpr_labels_found': [], # e.g., ['mpr-1', 'mpr-2']
        'found_three_or_more_mprs': False,
        'error_listing_dir': None # Store any OSError during directory listing
    }

    # Search for the specific MRI ID's RAW folder across the provided base paths
    for current_mri_base_path in abs_mri_base_paths:
        potential_folder = current_mri_base_path / mri_id / 'RAW'
        log_details['mri_folder_path_checked'] = str(potential_folder) # Log path being checked
        if potential_folder.is_dir():
            mri_scan_folder_path = potential_folder
            base_path_where_found = str(current_mri_base_path)
            break 
    
    log_details['mri_base_path_used'] = base_path_where_found

    if mri_scan_folder_path is None:
        # mri_folder_exists remains False, other counts remain 0/empty
        return log_details

    log_details['mri_folder_exists'] = True
    
    found_mpr_scan_pairs = {} # Stores {label: (img_path, hdr_path)}
    try:
        # Iterate through files in the .../<MRI ID>/RAW/ directory
        for file_item in mri_scan_folder_path.iterdir():
            match_result = mpr_pattern.match(file_item.name)
            if match_result: # If filename matches the MPRAGE .img pattern
                mpr_label_found = match_result.group(1) # Extracts "mpr-X"
                # Construct corresponding .hdr filename path and check for its existence
                hdr_file_path = file_item.with_name(f"{mpr_label_found}.nifti.hdr")
                if hdr_file_path.is_file(): # Check if the .hdr pair exists
                    found_mpr_scan_pairs[mpr_label_found] = (str(file_item), str(hdr_file_path))
    except OSError as e_os:
        log_details['error_listing_dir'] = str(e_os) # Log error if directory listing fails
        return log_details # Return early as we can't proceed with this scan

    num_valid_pairs_found = len(found_mpr_scan_pairs)
    log_details.update({
        'mprs_found_count': num_valid_pairs_found,
        'mpr_labels_found': sorted(list(found_mpr_scan_pairs.keys())),
        'found_three_or_more_mprs': num_valid_pairs_found >= 3
    })
    return log_details

print("Helper function defined.")

## 3. Initialize Weights & Biases Run

A new W&B run is initialized for this data exploration and MRI verification notebook.

In [None]:
# --- Initialize W&B Run ---
print("\n--- Initializing Weights & Biases Run for Notebook 01 ---")

# Configuration specific to this NB01 run to be logged to W&B
# Using resolved absolute paths for clarity in W&B logs where appropriate,
# and relative (from config) for others to show what was configured.
nb01_run_config_for_wandb = {
    "notebook_name_code": f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}", # e.g., 01_Data_Exploration_oasis2
    "dataset_source": DATASET_IDENTIFIER,
    "input_clinical_data_path_actual": str(INPUT_DATA_PATH), # Actual path used
    "mri_base_paths_configured": MRI_BASE_PATHS_CONFIG_RELATIVE, # What was in config
    "mpr_img_pattern_configured": MPR_IMG_PATTERN_CONFIG.pattern if MPR_IMG_PATTERN_CONFIG else None, # Regex pattern used
    "output_dir_for_local_saves": str(output_dir), # Static output dir for this notebook
    "execution_timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}

# Initialize W&B run using the utility from src/wandb_utils.py
# The 'run' object will be global for this notebook if successful
run = initialize_wandb_run(
    base_project_config=base_config, # Contains WANDB_ENTITY and WANDB_PROJECT
    job_group="DataProcessing",       # Broad category for this type of notebook
    job_specific_type=f"ExploreValidate-{DATASET_IDENTIFIER}", 
    run_specific_config=nb01_run_config_for_wandb, # Config dict for this specific run
    custom_run_name_elements=["EDA"], # Keeps run name concise and informative
    notes=f"{DATASET_IDENTIFIER.upper()}: Data exploration, EDA, and raw MRI file verification."
)

if run:
    print(f"W&B run '{run.name}' (Job Type: '{run.job_type}') initialized. View at: {run.url}")
    # output_dir was created in the setup cell (Section 1) and is reused for all local saves,
    # so no run-specific output directory is needed here.
else:
    print("Proceeding without W&B logging for this session (W&B run initialization failed).")
    # Local saves still go to output_dir, which does not depend on W&B.

## 4. Load Clinical Data

Load the raw longitudinal clinical and demographic data from the Excel file specified in `config.json` using pandas. If the W&B run initialized in Section 3 is active, the source file (plus a 100-row preview table) is also logged as a W&B artifact for reproducibility.
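`verify_scan_files()` reads `row['Subject ID']` and `row['MRI ID']` directly, so a missing column would only surface later as a `KeyError` inside the verification loop. A quick check right after loading, sketched here as a hypothetical `find_clinical_df_problems` function, fails earlier and also flags duplicated MRI IDs, which would otherwise be verified twice. Treating `MRI ID` as unique per row is an assumption about the clinical sheet.

```python
import pandas as pd

# Columns that verify_scan_files() indexes directly.
REQUIRED_COLUMNS = ["Subject ID", "MRI ID"]


def find_clinical_df_problems(df: pd.DataFrame) -> list[str]:
    """Hypothetical pre-flight check; returns human-readable problems (empty list if none)."""
    problems = []
    missing = [col for col in REQUIRED_COLUMNS if col not in df.columns]
    if missing:
        problems.append(f"missing required columns: {missing}")
    if "MRI ID" in df.columns:
        # Assumes one row per MRI session, so every MRI ID should appear exactly once.
        dupes = sorted(df.loc[df["MRI ID"].duplicated(), "MRI ID"].unique().tolist())
        if dupes:
            problems.append(f"duplicated MRI IDs: {dupes}")
    return problems


toy = pd.DataFrame({"Subject ID": ["OAS2_0001", "OAS2_0001", "OAS2_0002"],
                    "MRI ID": ["OAS2_0001_MR1", "OAS2_0001_MR2", "OAS2_0001_MR2"]})
print(find_clinical_df_problems(toy))
```

In the notebook, a non-empty result from `find_clinical_df_problems(clinical_df)` could be raised as a `ValueError` inside the loading cell's `try` block, so the existing handlers close the W&B run.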

In [None]:
# --- Load Clinical Data ---
print(f"\n--- Loading Clinical Data from: {INPUT_DATA_PATH} ---")
clinical_df = None # Initialize
try:
    if INPUT_DATA_PATH is None or not INPUT_DATA_PATH.is_file(): # Check if path was defined and exists
        raise FileNotFoundError(f"Input data file not found or path not defined: {INPUT_DATA_PATH}")

    clinical_df = pd.read_excel(INPUT_DATA_PATH)
    print(f"Raw clinical data loaded successfully. Shape: {clinical_df.shape}")

    if clinical_df.empty:
        # Raise so the handler below closes the W&B run; continuing would log to a finished run
        raise ValueError("Loaded clinical dataframe is empty.")

    # Log initial dataset characteristics to W&B run if active
    if run:
        run.log({'dataset_raw/num_rows': clinical_df.shape[0],
                 'dataset_raw/num_columns': clinical_df.shape[1]})
        
        # Log a preview of the raw data table as a W&B artifact
        print("Logging raw clinical data table as W&B artifact...")
        # Use a more descriptive artifact name, incorporating dataset identifier
        raw_data_artifact_name = f"raw_clinical_data_{DATASET_IDENTIFIER}"
        raw_data_artifact_description = (
            f"Raw clinical and demographic data for the {DATASET_IDENTIFIER} dataset, "
            f"loaded from {INPUT_DATA_PATH.name}."
        )
        raw_data_table_artifact = wandb.Artifact(
            raw_data_artifact_name, 
            type="raw-dataset", # Consistent type for raw datasets
            description=raw_data_artifact_description,
            metadata={"source_file": str(INPUT_DATA_PATH), 
                      "shape": list(clinical_df.shape),
                      "load_timestamp": time.strftime("%Y-%m-%d %H:%M:%S")}
        )
        # Add the actual data file to the artifact for full reproducibility
        raw_data_table_artifact.add_file(str(INPUT_DATA_PATH), name="source_excel_file.xlsx")
        
        # Add a sample of the table directly for quick preview in W&B UI
        raw_data_wandb_table = wandb.Table(dataframe=clinical_df.head(100)) # Log first 100 rows
        raw_data_table_artifact.add(raw_data_wandb_table, name="raw_data_preview")
        
        run.log_artifact(raw_data_table_artifact, aliases=["initial_load", time.strftime('%Y%m%d')])
        print(f"Raw data artifact '{raw_data_artifact_name}' logged to W&B.")

except FileNotFoundError as e_fnf:
    print(f"CRITICAL ERROR: {e_fnf}")
    if run:
        run.finish(exit_code=1)
        run = None  # Later cells check `if run:`; avoid logging to a finished run
except ImportError:
    # pd.read_excel raises ImportError when the optional 'openpyxl' engine is missing
    print("CRITICAL ERROR loading Excel file: Missing 'openpyxl' library.")
    print("Please install 'openpyxl': pip install openpyxl")
    if run:
        run.finish(exit_code=1)
        run = None
except Exception as e_load_data:
    print(f"CRITICAL ERROR occurred while loading the clinical data: {e_load_data}")
    if run:
        run.finish(exit_code=1)
        run = None

# Ensure clinical_df is defined for subsequent cells, even if empty after an error (if not exiting)
if clinical_df is None:
    clinical_df = pd.DataFrame()

## 5. Initial Data Inspection: Structure, Statistics, and Missing Values

Perform basic checks on the loaded `clinical_df` DataFrame to understand its structure and identify immediate data quality issues. This includes:
* Viewing data types, non-null counts, and memory usage using `.info()`.
* Calculating descriptive statistics for all columns using `.describe(include='all')`.
* Identifying and counting missing values per column using `.isnull().sum()`.

Summaries of these inspections are printed and saved locally, and key statistics are logged to W&B.

### Missing Value Strategy Note
The missing values identified here (especially in columns like MMSE, SES) will be addressed more formally during the preprocessing stage in Notebook 04. Preprocessing, such as imputation, will be performed *after* splitting the data into training, validation, and test sets to prevent data leakage from the validation/test data into the training data's preprocessing steps.
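The leakage argument comes down to where the imputation statistics are computed. In this minimal sketch, median imputation is an assumed strategy (Notebook 04 may use a different one), and the toy frames stand in for the real splits: the medians are fitted on the training split only and then reused for every split.

```python
import pandas as pd

# Toy splits standing in for the real train/test partitions (values invented).
train = pd.DataFrame({"MMSE": [30.0, 28.0, None, 26.0], "SES": [2.0, None, 3.0, 1.0]})
test = pd.DataFrame({"MMSE": [None, 15.0], "SES": [None, 4.0]})

# Fit: imputation statistics come from the training split only (NaNs are skipped).
train_medians = train.median()

# Apply: the same training medians fill the gaps in every split.
train_imputed = train.fillna(train_medians)
test_imputed = test.fillna(train_medians)

# For contrast: a median computed on the pooled data is pulled by the test rows.
pooled_mmse_median = pd.concat([train, test])["MMSE"].median()
print(train_medians.to_dict(), pooled_mmse_median)
```

Here pooling shifts the MMSE median from 28.0 to 27.0, so the low test score would leak into the value used to fill training gaps.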

In [None]:
# --- Initial Data Inspection: DataFrame Info and Descriptive Statistics ---
if not clinical_df.empty:
    print("\n--- Basic Data Information ---")
    print("DataFrame Info (clinical_df.info()):")
    clinical_df.info() # Prints to console

    print("\nDescriptive Statistics (clinical_df.describe(include='all')):")
    desc_stats = clinical_df.describe(include='all')
    display(desc_stats) # For better formatting
    #print(desc_stats) # Standard print

    # Save descriptive statistics locally to the notebook's output directory
    desc_stats_path = output_dir / 'descriptive_stats_raw_data.csv' # Use 'output_dir'
    try:
        desc_stats.to_csv(desc_stats_path)
        print(f"\nDescriptive statistics saved locally to: {desc_stats_path}")
    except Exception as e_save_desc:
        print(f"Warning: Could not save descriptive statistics locally. Error: {e_save_desc}")
else:
    print("\nSkipping initial data inspection as clinical_df is empty.")
    desc_stats = pd.DataFrame() # Ensure desc_stats is defined for later W&B logging cell

In [None]:
# --- Initial Data Inspection: Missing Values ---
if not clinical_df.empty:
    print("\nMissing Values per Column (showing columns with >0 missing values):")
    missing_values_series = clinical_df.isnull().sum()
    missing_values_filtered_df = missing_values_series[missing_values_series > 0].sort_values(ascending=False).reset_index()
    missing_values_filtered_df.columns = ['column_name', 'missing_count']
    
    if not missing_values_filtered_df.empty:
        # display(missing_values_filtered_df)
        print(missing_values_filtered_df)
        # Save missing values summary locally
        missing_values_path = output_dir / 'missing_values_summary_raw_data.csv' # Use 'output_dir'
        try:
            missing_values_filtered_df.to_csv(missing_values_path, index=False)
            print(f"Missing values summary saved locally to: {missing_values_path}")
        except Exception as e_save_missing:
            print(f"Warning: Could not save missing values summary locally. Error: {e_save_missing}")
    else:
        print("No missing values found in the raw clinical dataset.")
        missing_values_filtered_df = pd.DataFrame(columns=['column_name', 'missing_count']) # Ensure defined for W&B log
        missing_values_series = pd.Series(dtype=int) # Ensure defined for W&B log

else:
    print("\nSkipping missing value analysis as clinical_df is empty.")
    missing_values_filtered_df = pd.DataFrame(columns=['column_name', 'missing_count'])
    missing_values_series = pd.Series(dtype=int)

## 6. Log Summary Statistics to Weights & Biases

Key summary statistics derived from the initial data inspection are logged to the active W&B run. This provides a quick overview of the dataset characteristics within the W&B interface for this exploration stage.
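`describe(include='all')` mixes numeric rows (`mean`, `std`, quartiles) with categorical rows (`unique`, `top`, `freq`), and the categorical rows are `NaN` for numeric columns. The toy example below (invented data) shows how one column's statistics flatten into W&B-style keys such as `stats_raw/Age_mean` and `stats_raw/Age_50pct` while skipping those `NaN` placeholders, following the naming scheme used in the cell below.

```python
import pandas as pd

# Invented mini cohort with one numeric and one categorical column.
toy = pd.DataFrame({"Age": [70, 75, 80, 85],
                    "Group": ["Nondemented", "Demented", "Nondemented", "Converted"]})
desc = toy.describe(include="all")

summary = {}
for col in ["Age", "MMSE"]:  # 'MMSE' is absent from the toy frame and is skipped
    if col not in desc.columns:
        continue
    for stat_name, stat_val in desc[col].items():
        # Keep finite numbers only; 'unique'/'top'/'freq' are NaN for a numeric column.
        if pd.api.types.is_number(stat_val) and not pd.isna(stat_val):
            summary[f"stats_raw/{col}_{stat_name.replace('%', 'pct')}"] = float(stat_val)

print(summary)
```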

In [None]:
# --- Log Key Dataset Statistics to W&B ---
if run and not clinical_df.empty: # Check if W&B run is active and DataFrame is not empty
    print("\n--- Logging Summary Dataset Statistics to W&B ---")
    
    # Ensure desc_stats, missing_values_series, and missing_values_filtered_df are defined
    # (they are initialized to empty DataFrames/Series if clinical_df was empty)
    if 'desc_stats' not in locals(): desc_stats = pd.DataFrame()
    if 'missing_values_series' not in locals(): missing_values_series = pd.Series(dtype=int)
    if 'missing_values_filtered_df' not in locals(): missing_values_filtered_df = pd.DataFrame()

    log_summary_dict = {
        'dataset/initial_num_rows': clinical_df.shape[0],
        'dataset/initial_num_columns': clinical_df.shape[1],
        'dataset/total_missing_values': int(missing_values_series.sum()), # Use the series before filtering
        'dataset/num_columns_with_missing_values': int(len(missing_values_filtered_df)),
        'dataset/memory_usage_mb': float(clinical_df.memory_usage(deep=True).sum() / (1024**2))
    }
    
    # Add descriptive stats for key numerical columns if they exist
    # These columns are commonly used in Alzheimer's research.
    key_numerical_cols_for_stats = ['Age', 'EDUC', 'SES', 'MMSE', 'CDR', 'eTIV', 'nWBV', 'ASF']
    for col_name in key_numerical_cols_for_stats:
        if col_name not in desc_stats.columns:
            continue
        # describe(include='all') mixes numeric stats (count, mean, std, quartiles) with
        # categorical ones (unique, top, freq). Log every finite numeric value, which also
        # covers 'count' for non-numeric columns.
        for stat_name, stat_val in desc_stats[col_name].items():
            if pd.api.types.is_number(stat_val) and not pd.isna(stat_val):
                log_summary_dict[f'stats_raw/{col_name}_{stat_name.replace("%", "pct")}'] = float(stat_val)


    try:
        run.log(log_summary_dict)
        print("Summary dataset statistics logged to W&B.")
    except Exception as e_wandb_log_stats:
        print(f"Warning: Could not log summary stats to W&B. Error: {e_wandb_log_stats}")
        
elif not clinical_df.empty:
    print("\nSkipping W&B logging of summary statistics (W&B run not active).")
else:
    print("\nSkipping W&B logging of summary statistics (clinical_df is empty).")

## 7. Analyze Variable Distributions

Visualize the distributions to understand the overall characteristics of the cohort across all recorded visits. This includes:
* Histograms for numerical variables (Age, MMSE, nWBV) to observe their spread and central tendency.
* Count plots for categorical or discrete variables (Number of Visits per Subject, Clinical Group, Gender) to understand category frequencies.

Each plot is saved locally to this notebook's output directory and logged to the active W&B run.
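The visits-per-subject plot is a count of counts: first the number of visits per subject, then the number of subjects for each visit count, which is what `sns.countplot(x=visits_per_subject)` draws in the cell below. With invented subject IDs, the same aggregation in plain pandas looks like this.

```python
import pandas as pd

# Invented visit table: one row per MRI session.
visits = pd.DataFrame({"Subject ID": ["S1", "S1", "S2", "S3", "S3", "S3", "S4"]})

# Step 1: visits per subject (S3 -> 3, S1 -> 2, S2 -> 1, S4 -> 1).
visits_per_subject = visits["Subject ID"].value_counts()

# Step 2: how many subjects share each visit count.
subjects_per_visit_count = visits_per_subject.value_counts().sort_index()
print(subjects_per_visit_count)
```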

In [None]:
# --- Analyze and Visualize Key Variable Distributions ---
print("\n--- Analyzing Variable Distributions ---")

if not clinical_df.empty:
    # Define a consistent order for 'Group' if it exists, for plotting
    group_order_for_plots = None
    if 'Group' in clinical_df.columns:
        group_order_for_plots = clinical_df['Group'].value_counts().index.tolist()

    # Distribution of Age
    print("Plotting distribution of Age...")
    fig_age_dist, ax_age_dist = plt.subplots(figsize=(10, 5))
    sns.histplot(data=clinical_df, x='Age', kde=True, bins=20, ax=ax_age_dist)
    ax_age_dist.set_title('Distribution of Age (All Visits)')
    ax_age_dist.set_xlabel('Age (Years)')
    ax_age_dist.set_ylabel('Frequency')
    finalize_plot(fig_age_dist, plt, run, 
                  f"charts_eda_{DATASET_IDENTIFIER}/distribution/age", 
                  output_dir / '01_age_distribution.png')

    # Distribution of MMSE scores (ensure MMSE column exists)
    if 'MMSE' in clinical_df.columns:
        print("Plotting distribution of MMSE Scores...")
        fig_mmse_dist, ax_mmse_dist = plt.subplots(figsize=(10, 5))
        sns.histplot(data=clinical_df.dropna(subset=['MMSE']), x='MMSE', kde=True, bins=15, ax=ax_mmse_dist)
        ax_mmse_dist.set_title('Distribution of MMSE Scores (All Visits)')
        ax_mmse_dist.set_xlabel('MMSE Score')
        ax_mmse_dist.set_ylabel('Frequency')
        finalize_plot(fig_mmse_dist, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/distribution/mmse", 
                      output_dir / '02_mmse_distribution.png')
    else:
        print("MMSE column not found, skipping MMSE distribution plot.")

    # Distribution of nWBV (Normalized Whole Brain Volume)
    if 'nWBV' in clinical_df.columns:
        print("Plotting distribution of nWBV...")
        fig_nwbv_dist, ax_nwbv_dist = plt.subplots(figsize=(10, 5))
        sns.histplot(data=clinical_df.dropna(subset=['nWBV']), x='nWBV', kde=True, bins=20, ax=ax_nwbv_dist)
        ax_nwbv_dist.set_title('Distribution of Normalized Whole Brain Volume (nWBV - All Visits)')
        ax_nwbv_dist.set_xlabel('nWBV')
        ax_nwbv_dist.set_ylabel('Frequency')
        finalize_plot(fig_nwbv_dist, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/distribution/nwbv", 
                      output_dir / '03_nwbv_distribution.png')
    else:
        print("nWBV column not found, skipping nWBV distribution plot.")

    # Count of Visits per Subject
    if 'Subject ID' in clinical_df.columns:
        print("Plotting number of visits per subject...")
        visits_per_subject = clinical_df['Subject ID'].value_counts()
        fig_visits_subj, ax_visits_subj = plt.subplots(figsize=(10, 5))
        # Plot the counts of these counts for a clearer distribution
        sns.countplot(x=visits_per_subject, ax=ax_visits_subj, color='skyblue', stat='count')
        ax_visits_subj.set_title('Distribution of Visit Counts per Subject')
        ax_visits_subj.set_xlabel('Number of Visits Recorded for a Subject')
        ax_visits_subj.set_ylabel('Number of Subjects')
        finalize_plot(fig_visits_subj, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/distribution/visits_per_subject", 
                      output_dir / '04_visits_per_subject_distribution.png')
    else:
        print("'Subject ID' column not found, skipping visits per subject plot.")

    # Distribution of Clinical Groups (if 'Group' column exists)
    if 'Group' in clinical_df.columns:
        print("Plotting distribution of clinical groups...")
        fig_group_dist, ax_group_dist = plt.subplots(figsize=(8, 5))
        sns.countplot(data=clinical_df, x='Group', order=group_order_for_plots, ax=ax_group_dist)
        ax_group_dist.set_title('Distribution of Clinical Groups (All Visits)')
        ax_group_dist.set_xlabel('Clinical Group')
        ax_group_dist.set_ylabel('Number of Visits/Scans')
        finalize_plot(fig_group_dist, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/distribution/group", 
                      output_dir / '05_clinical_groups_distribution.png')
    else:
        print("'Group' column not found, skipping group distribution plot.")
        
    # Distribution of Gender (M/F)
    if 'M/F' in clinical_df.columns:
        print("Plotting distribution of gender...")
        fig_gender_dist, ax_gender_dist = plt.subplots(figsize=(6, 4))
        sns.countplot(data=clinical_df, x='M/F', ax=ax_gender_dist)
        ax_gender_dist.set_title('Distribution of Gender (All Visits)')
        ax_gender_dist.set_xlabel('Gender')
        ax_gender_dist.set_ylabel('Number of Visits/Scans')
        finalize_plot(fig_gender_dist, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/distribution/gender", 
                      output_dir / '06_gender_distribution.png')
    else:
        print("'M/F' column not found, skipping gender distribution plot.")
else:
    print("Skipping variable distribution analysis as clinical_df is empty.")

## 8. Analyze Relationships Between Key Variables

Explore potential pairwise relationships between important numerical and categorical variables to identify correlations or trends. This includes:
* Scatter plots to visualize relationships (e.g., Age vs. MMSE, Age vs. nWBV), often colored by clinical group to reveal group-specific patterns.
* A correlation matrix heatmap for key numerical variables to quantify linear associations.

These visualizations help in understanding how different factors co-vary within the dataset. Plots are saved locally and logged to W&B.
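Pearson correlation measures only linear association: an exactly linear pair scores -1 or 1, while a pair that is monotonic but only roughly linear scores slightly less in magnitude. Spearman correlation (Pearson on ranks) captures any monotonic relationship. The toy frame below uses invented values to illustrate the matrix that the heatmap visualizes.

```python
import pandas as pd

# Invented values: nWBV falls exactly linearly with age; MMSE falls steadily but unevenly.
toy = pd.DataFrame({
    "Age":  [65, 70, 75, 80, 85],
    "MMSE": [30, 29, 27, 26, 24],
    "nWBV": [0.78, 0.76, 0.74, 0.72, 0.70],
})

# Pearson: linear association between each pair of columns.
pearson = toy.corr(method="pearson")
# Spearman: Pearson computed on ranks, i.e. monotonic association.
spearman = toy.corr(method="spearman")

print(pearson.round(2))
print("Age-MMSE pearson:", round(pearson.loc["Age", "MMSE"], 3),
      "spearman:", spearman.loc["Age", "MMSE"])
```

On the real `clinical_df`, select the numeric columns explicitly (as in this sketch) or pass `numeric_only=True`, because since pandas 2.0 `DataFrame.corr()` raises on non-numeric columns such as `Group` or `M/F`. The resulting matrix can then be drawn with `sns.heatmap(pearson, vmin=-1, vmax=1, annot=True)`.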

In [None]:
# --- Analyze and Visualize Relationships Between Key Variables ---
print("\n--- Analyzing Relationships Between Variables ---")

if not clinical_df.empty:
    # Recompute the group order (most frequent first) so this cell also runs on its own
    current_group_order = None
    if 'Group' in clinical_df.columns:
        current_group_order = clinical_df['Group'].value_counts().index.tolist()

    # Age vs. MMSE score
    if 'Age' in clinical_df.columns and 'MMSE' in clinical_df.columns and 'Group' in clinical_df.columns:
        print("Plotting Age vs. MMSE Score by Group...")
        fig_age_mmse, ax_age_mmse = plt.subplots(figsize=(10, 6))
        sns.scatterplot(data=clinical_df.dropna(subset=['Age', 'MMSE']), 
                        x='Age', y='MMSE', hue='Group', alpha=0.6, 
                        ax=ax_age_mmse, hue_order=current_group_order)
        ax_age_mmse.set_title('Age vs. MMSE Score by Clinical Group (All Visits)')
        ax_age_mmse.set_xlabel('Age (Years)')
        ax_age_mmse.set_ylabel('MMSE Score')
        finalize_plot(fig_age_mmse, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/relationship/age_vs_mmse", 
                      output_dir / '07_age_vs_mmse_by_group.png')
    else:
        print("Skipping Age vs. MMSE plot due to missing columns (Age, MMSE, or Group).")

    # Age vs. nWBV
    if 'Age' in clinical_df.columns and 'nWBV' in clinical_df.columns and 'Group' in clinical_df.columns:
        print("Plotting Age vs. nWBV by Group...")
        fig_age_nwbv, ax_age_nwbv = plt.subplots(figsize=(10, 6))
        sns.scatterplot(data=clinical_df.dropna(subset=['Age', 'nWBV']), 
                        x='Age', y='nWBV', hue='Group', alpha=0.6, 
                        ax=ax_age_nwbv, hue_order=current_group_order)
        ax_age_nwbv.set_title('Age vs. Normalized Whole Brain Volume (nWBV) by Group (All Visits)')
        ax_age_nwbv.set_xlabel('Age (Years)')
        ax_age_nwbv.set_ylabel('nWBV')
        finalize_plot(fig_age_nwbv, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/relationship/age_vs_nwbv", 
                      output_dir / '08_age_vs_nwbv_by_group.png')
    else:
        print("Skipping Age vs. nWBV plot due to missing columns (Age, nWBV, or Group).")

    # MMSE vs. nWBV
    if 'MMSE' in clinical_df.columns and 'nWBV' in clinical_df.columns and 'Group' in clinical_df.columns:
        print("Plotting MMSE vs. nWBV by Group...")
        fig_mmse_nwbv, ax_mmse_nwbv = plt.subplots(figsize=(10, 6))
        sns.scatterplot(data=clinical_df.dropna(subset=['MMSE', 'nWBV']), 
                        x='MMSE', y='nWBV', hue='Group', alpha=0.6, 
                        ax=ax_mmse_nwbv, hue_order=current_group_order)
        ax_mmse_nwbv.set_title('MMSE Score vs. nWBV by Group (All Visits)')
        ax_mmse_nwbv.set_xlabel('MMSE Score')
        ax_mmse_nwbv.set_ylabel('nWBV')
        finalize_plot(fig_mmse_nwbv, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/relationship/mmse_vs_nwbv", 
                      output_dir / '09_mmse_vs_nwbv_by_group.png')
    else:
        print("Skipping MMSE vs. nWBV plot due to missing columns (MMSE, nWBV, or Group).")

    # Correlation Heatmap for Key Numerical Variables
    print("Generating Correlation Heatmap...")
    # Define numerical columns intended for correlation analysis
    # Ensure these are present in the DataFrame and are indeed numeric
    numerical_cols_for_corr = ['Age', 'EDUC', 'SES', 'MMSE', 'CDR', 'Visit', 'eTIV', 'nWBV', 'ASF']
    valid_numerical_cols_for_corr = [
        col for col in numerical_cols_for_corr 
        if col in clinical_df.columns and pd.api.types.is_numeric_dtype(clinical_df[col])
    ]
    
    if valid_numerical_cols_for_corr:
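        # Pearson correlation (pandas default), computed on pairwise-complete observations
        correlation_matrix = clinical_df[valid_numerical_cols_for_corr].corr(method='pearson')
        fig_corr_matrix, ax_corr_matrix = plt.subplots(figsize=(12, 10))
        # Fix the color scale to [-1, 1] so that 0 maps to the neutral midpoint of 'coolwarm'
        sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', fmt=".2f", vmin=-1, vmax=1,
                    linewidths=.5, ax=ax_corr_matrix, annot_kws={"size": 8})
        ax_corr_matrix.set_title('Pearson Correlation Matrix of Key Numerical Variables (All Visits)')
        # Rotate tick labels on this axes explicitly (not via the implicit pyplot "current axes")
        plt.setp(ax_corr_matrix.get_xticklabels(), rotation=45, ha='right')
        plt.setp(ax_corr_matrix.get_yticklabels(), rotation=0)
        # No explicit tight_layout() here: finalize_plot saves with bbox_inches='tight'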
        finalize_plot(fig_corr_matrix, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/relationship/correlation_heatmap", 
                      output_dir / '10_correlation_matrix.png')
    else:
        print("No valid numerical columns found for correlation heatmap.")
else:
    print("Skipping variable relationship analysis as clinical_df is empty.")
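pandas `DataFrame.corr` computes Pearson coefficients by default and excludes missing values pair by pair, so different cells of the matrix can rest on different numbers of visits. Because CDR and SES are ordinal, a rank-based Spearman matrix is a reasonable companion view. A minimal sketch on made-up toy values (not OASIS-2 data), reusing the numeric-column filter from the cell above:

```python
import numpy as np
import pandas as pd

# Toy visit-level table with made-up values, for illustration only.
df = pd.DataFrame({
    "Age":   [70, 72, 75, 80, 81],
    "MMSE":  [30, 29, np.nan, 24, 22],
    "CDR":   [0.0, 0.0, 0.5, 0.5, 1.0],
    "Group": ["Nondemented", "Nondemented", "Converted", "Demented", "Demented"],
})

# Same filter as the notebook: keep requested columns that exist and are numeric.
candidate_cols = ["Age", "MMSE", "CDR", "SES"]  # 'SES' is absent here on purpose
numeric_cols = [c for c in candidate_cols
                if c in df.columns and pd.api.types.is_numeric_dtype(df[c])]

# Pearson (default): the Age/MMSE cell uses only the 4 rows where both are present.
pearson = df[numeric_cols].corr()
# Spearman: correlation of ranks, often a better fit for ordinal scores such as CDR.
spearman = df[numeric_cols].corr(method="spearman")

print(pearson.round(2))
print(spearman.round(2))
```

Passing `method="spearman"` to the heatmap cell would be a one-argument change; whether it suits this analysis better is a judgment call.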

## 9. Analyze Differences Between Clinical Groups

Compare the distributions of key continuous variables (MMSE, nWBV, and Age) across the clinical groups defined in the dataset ('Nondemented', 'Converted', 'Demented'). Box plots are used for MMSE and nWBV and a violin plot for Age, showing how these markers vary with cognitive status. All visits are pooled, so subjects with more visits contribute more observations.

In [None]:
# --- Analyze Differences Between Clinical Groups ---
print("\n--- Analyzing Differences Between Clinical Groups ---")

if not clinical_df.empty and 'Group' in clinical_df.columns:
    # Define a consistent order for clinical groups for plotting
    # This ensures 'Nondemented' typically comes first if present.
    # Adjust if your group names or desired order differ.
    group_order_for_analysis = ['Nondemented', 'Converted', 'Demented'] 
    # Filter to include only groups present in the data, maintaining desired order
    present_groups_ordered = [g for g in group_order_for_analysis if g in clinical_df['Group'].unique()]
    if not present_groups_ordered: # Fallback if predefined order has no matches
        present_groups_ordered = clinical_df['Group'].value_counts().index.tolist()


    # MMSE scores by group
    if 'MMSE' in clinical_df.columns:
        print("Plotting MMSE Scores by Clinical Group...")
        fig_mmse_group, ax_mmse_group = plt.subplots(figsize=(8, 6))
        sns.boxplot(data=clinical_df.dropna(subset=['MMSE']), 
                    x='Group', y='MMSE', order=present_groups_ordered, ax=ax_mmse_group)
        ax_mmse_group.set_title('MMSE Scores by Clinical Group (All Visits)')
        ax_mmse_group.set_xlabel('Clinical Group')
        ax_mmse_group.set_ylabel('MMSE Score')
        finalize_plot(fig_mmse_group, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/group_comparison/mmse_boxplot", 
                      output_dir / '11_mmse_by_group_boxplot.png')
    else:
        print("MMSE column not found, skipping MMSE by group plot.")

    # nWBV by group
    if 'nWBV' in clinical_df.columns:
        print("Plotting nWBV by Clinical Group...")
        fig_nwbv_group, ax_nwbv_group = plt.subplots(figsize=(8, 6))
        sns.boxplot(data=clinical_df.dropna(subset=['nWBV']), 
                    x='Group', y='nWBV', order=present_groups_ordered, ax=ax_nwbv_group)
        ax_nwbv_group.set_title('nWBV by Clinical Group (All Visits)')
        ax_nwbv_group.set_xlabel('Clinical Group')
        ax_nwbv_group.set_ylabel('Normalized Whole Brain Volume (nWBV)')
        finalize_plot(fig_nwbv_group, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/group_comparison/nwbv_boxplot", 
                      output_dir / '12_nwbv_by_group_boxplot.png')
    else:
        print("nWBV column not found, skipping nWBV by group plot.")

    # Age distribution by group
    if 'Age' in clinical_df.columns:
        print("Plotting Age Distribution by Clinical Group...")
        fig_age_group, ax_age_group = plt.subplots(figsize=(8, 6))
        sns.violinplot(data=clinical_df.dropna(subset=['Age']), 
                       x='Group', y='Age', order=present_groups_ordered, ax=ax_age_group)
        ax_age_group.set_title('Age Distribution by Clinical Group (All Visits)')
        ax_age_group.set_xlabel('Clinical Group')
        ax_age_group.set_ylabel('Age (Years)')
        finalize_plot(fig_age_group, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/group_comparison/age_violinplot", 
                      output_dir / '13_age_by_group_violinplot.png')
    else:
        print("Age column not found, skipping Age by group plot.")
else:
    print("Skipping group difference analysis as clinical_df is empty or 'Group' column is missing.")
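The group-ordering logic used above (preferred order, restricted to groups actually present, with a frequency-order fallback) could be factored into one helper and shared by every plot in this notebook. The sketch below is hypothetical (`order_present_groups` and `PREFERRED_GROUP_ORDER` are illustrative names, not project utilities). Unlike the cell above, it appends any unexpected labels after the preferred ones instead of silently dropping them from the plot.

```python
import pandas as pd

# Hypothetical constant mirroring group_order_for_analysis in the cell above.
PREFERRED_GROUP_ORDER = ("Nondemented", "Converted", "Demented")


def order_present_groups(groups, preferred=PREFERRED_GROUP_ORDER):
    """Return preferred labels that are present (in preferred order), followed by
    any other labels in descending frequency. Missing values are ignored."""
    values = pd.Series(list(groups), dtype=object).dropna()
    present = set(values.unique())
    ordered = [g for g in preferred if g in present]
    # Labels outside the preferred list are kept rather than dropped from plots.
    extras = [g for g in values.value_counts().index if g not in ordered]
    return ordered + extras


print(order_present_groups(["Demented", "Nondemented", "Demented", "Converted"]))
# ['Nondemented', 'Converted', 'Demented']
```

In the notebook this would be used as, for example, `order=order_present_groups(clinical_df['Group'])` in `sns.boxplot`.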

## 10. Deeper Longitudinal Analysis

To better understand the temporal aspects of the OASIS-2 dataset relevant to cognitive decline, this section explores:
1.  **Visit Intervals:** Calculation and distribution of the approximate time elapsed between consecutive visits for each subject. This helps understand the follow-up patterns.
2.  **Baseline Characteristics:** Analysis of the subject cohort specifically at their first recorded visit (`Visit == 1`), looking at clinical group distribution, MMSE scores, and age.
3.  **Average Longitudinal Trends:** Plotting the average MMSE and nWBV scores across visit numbers, stratified by clinical group, to observe general progression patterns.
4.  **Individual Trajectories:** Visualizing MMSE score changes over visits for a few example subjects to highlight inter-subject variability.

All generated plots and key statistics are saved locally and logged to W&B.
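The visit-interval calculation in the cell below comes down to a per-subject `groupby(...).diff()`. A minimal sketch on made-up rows, assuming (as the cell does) that `MR Delay` is in days relative to a per-subject baseline and that `Age` is the coarser fallback. Leaving the first visit as NaN, rather than filling it with 0, keeps it distinguishable from a genuinely missing interval.

```python
import numpy as np
import pandas as pd

# Made-up longitudinal rows; 'MR Delay' = days since the subject's first scan.
visits = pd.DataFrame({
    "Subject ID": ["S1", "S1", "S1", "S2", "S2"],
    "Visit":      [1, 2, 3, 1, 2],
    "MR Delay":   [0, 730, 1461, 0, 548],
    "Age":        [75, 77, 79, 80, 81],
}).sort_values(["Subject ID", "Visit"])

by_subject = visits.groupby("Subject ID")

# Days since each subject's earliest scan in this table.
days_from_baseline = visits["MR Delay"] - by_subject["MR Delay"].transform("min")

# diff() within each subject leaves NaN on that subject's first visit.
years_since_prev = by_subject["MR Delay"].diff() / 365.25

# Age-based fallback: only whole-year resolution if Age is stored as integers.
years_since_prev_age = by_subject["Age"].diff()

print(pd.DataFrame({"days": days_from_baseline,
                    "years_mr_delay": years_since_prev.round(2),
                    "years_age": years_since_prev_age}))
```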

In [None]:
# --- Perform Deeper Longitudinal Analysis ---
print("\n--- Performing Deeper Longitudinal Analysis ---")

if not clinical_df.empty:
    # Ensure data is sorted by Subject ID and Visit for time-based calculations
    # Using .copy() to avoid SettingWithCopyWarning on the original clinical_df
    df_for_longitudinal = clinical_df.sort_values(by=['Subject ID', 'Visit']).copy()

    # --- 1. Calculate Time Between Visits ---
    # Prefer 'MR Delay' (days relative to a per-subject baseline) when it is numeric and less than 50% missing;
    # otherwise approximate the time features from 'Age'.
    print("Calculating time intervals between visits...")
    time_feature_source_used = "Unknown" # Will be updated based on method used

    if 'MR Delay' in df_for_longitudinal.columns and \
       df_for_longitudinal['MR Delay'].isnull().sum() < len(df_for_longitudinal) * 0.5 and \
       pd.api.types.is_numeric_dtype(df_for_longitudinal['MR Delay']):
        
        print("  Using 'MR Delay' (days from study baseline) to calculate time features.")
        # Ensure MR Delay is numeric
        df_for_longitudinal['MR Delay_numeric'] = pd.to_numeric(df_for_longitudinal['MR Delay'], errors='coerce')
        # Days_from_Baseline for each subject is relative to their *first visit in this dataset*
        df_for_longitudinal['BaselineVisitMRDelay'] = df_for_longitudinal.groupby('Subject ID')['MR Delay_numeric'].transform('min')
        df_for_longitudinal['Days_from_ThisCohort_Baseline'] = df_for_longitudinal['MR Delay_numeric'] - df_for_longitudinal['BaselineVisitMRDelay']
        df_for_longitudinal['Years_Since_Prev_Visit'] = df_for_longitudinal.groupby('Subject ID')['Days_from_ThisCohort_Baseline'].diff() / 365.25
        time_feature_source_used = 'MR Delay'
    
    elif 'Age' in df_for_longitudinal.columns and pd.api.types.is_numeric_dtype(df_for_longitudinal['Age']):
        print("  Warning: 'MR Delay' not suitable. Using 'Age' to approximate time features.")
        # Days_from_Baseline relative to age at first visit in this cohort
        df_for_longitudinal['BaselineAge_ThisCohort'] = df_for_longitudinal.groupby('Subject ID')['Age'].transform('min')
        df_for_longitudinal['Days_from_ThisCohort_Baseline'] = (df_for_longitudinal['Age'] - df_for_longitudinal['BaselineAge_ThisCohort']) * 365.25
        df_for_longitudinal['Years_Since_Prev_Visit'] = df_for_longitudinal.groupby('Subject ID')['Age'].diff() # Already in years
        time_feature_source_used = 'Age_approx'
    else:
        print("  Error: Neither 'MR Delay' nor 'Age' are suitable for calculating time features. Skipping interval analysis.")
        df_for_longitudinal['Years_Since_Prev_Visit'] = np.nan # Ensure column exists if later steps expect it
        df_for_longitudinal['Days_from_ThisCohort_Baseline'] = np.nan


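    # The first visit of each subject has no predecessor, so its interval is set to 0 explicitly.
    # Other missing intervals (e.g., missing 'MR Delay') stay NaN instead of being masked as 0.
    is_first_visit = df_for_longitudinal.groupby('Subject ID').cumcount() == 0
    df_for_longitudinal.loc[is_first_visit, 'Years_Since_Prev_Visit'] = 0.0

    # Analyze the distribution of positive intervals (first visits and NaNs are excluded).
    # If 'Age' is recorded in whole years, same-year revisits give 0 under the Age fallback and are excluded too.
    visit_intervals_years = df_for_longitudinal.loc[df_for_longitudinal['Years_Since_Prev_Visit'] > 0, 'Years_Since_Prev_Visit']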
    if not visit_intervals_years.empty:
        fig_intervals, ax_intervals = plt.subplots(figsize=(10, 5))
        sns.histplot(visit_intervals_years, kde=True, bins=20, ax=ax_intervals)
        ax_intervals.set_title(f'Distribution of Approx. Years Between Visits (Source: {time_feature_source_used})')
        ax_intervals.set_xlabel('Approx. Years Since Previous Visit')
        ax_intervals.set_ylabel('Frequency')
        finalize_plot(fig_intervals, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/longitudinal/visit_interval_distribution", 
                      output_dir / '14_visit_interval_distribution.png')

        interval_stats = visit_intervals_years.describe()
        print("\nApproximate Visit Interval Stats (Years, excluding first visits):")
        print(interval_stats)
        if run:
            run.log({f'stats_longitudinal/visit_interval_{k.replace("%","pct")}': v for k, v in interval_stats.items()})
            run.config.update({'eda/time_feature_source_for_intervals': time_feature_source_used}, allow_val_change=True)
    else:
        print("\nCould not calculate meaningful visit intervals (e.g., only single visits per subject or missing time data).")

    # --- 2. Baseline Characteristics (Visit == 1 in the original Visit numbering) ---
    print("\nAnalyzing Baseline Characteristics (Original Visit == 1)...")
    baseline_df = df_for_longitudinal[df_for_longitudinal['Visit'] == 1].copy()

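    # Preferred clinical-group order restricted to groups present at baseline,
    # falling back to frequency order if none of the preferred labels occur
    group_order_for_analysis = ['Nondemented', 'Converted', 'Demented']
    present_groups_baseline = []
    if 'Group' in baseline_df.columns:
        present_groups_baseline = [g for g in group_order_for_analysis if g in baseline_df['Group'].unique()]
        if not present_groups_baseline:
            present_groups_baseline = baseline_df['Group'].value_counts().index.tolist()
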
    if not baseline_df.empty:
        num_baseline_subjects = baseline_df['Subject ID'].nunique()
        print(f"Number of subjects at baseline (Visit 1): {num_baseline_subjects}")
        if run: run.log({'dataset_baseline/num_subjects': num_baseline_subjects})

        if 'Group' in baseline_df.columns:
            fig_bl_group, ax_bl_group = plt.subplots(figsize=(8, 5))
            sns.countplot(data=baseline_df, x='Group', order=present_groups_baseline, ax=ax_bl_group)
            ax_bl_group.set_title('Distribution of Clinical Groups at Baseline (Visit 1)')
            finalize_plot(fig_bl_group, plt, run, 
                          f"charts_eda_{DATASET_IDENTIFIER}/baseline/group_distribution", 
                          output_dir / '15_baseline_group_distribution.png')

        if 'MMSE' in baseline_df.columns and 'Group' in baseline_df.columns:
            fig_bl_mmse, ax_bl_mmse = plt.subplots(figsize=(8, 6))
            sns.boxplot(data=baseline_df.dropna(subset=['MMSE']), x='Group', y='MMSE', order=present_groups_baseline, ax=ax_bl_mmse)
            ax_bl_mmse.set_title('MMSE Scores at Baseline (Visit 1) by Group')
            finalize_plot(fig_bl_mmse, plt, run, 
                          f"charts_eda_{DATASET_IDENTIFIER}/baseline/mmse_boxplot", 
                          output_dir / '16_baseline_mmse_boxplot.png')
        
        if 'Age' in baseline_df.columns and 'Group' in baseline_df.columns:
            fig_bl_age, ax_bl_age = plt.subplots(figsize=(8, 6))
            sns.violinplot(data=baseline_df.dropna(subset=['Age']), x='Group', y='Age', order=present_groups_baseline, ax=ax_bl_age)
            ax_bl_age.set_title('Age Distribution at Baseline (Visit 1) by Group')
            finalize_plot(fig_bl_age, plt, run, 
                          f"charts_eda_{DATASET_IDENTIFIER}/baseline/age_violinplot", 
                          output_dir / '17_baseline_age_violinplot.png')
    else:
        print("No data found for Visit == 1 to analyze baseline characteristics.")

    # --- 3. Average MMSE/nWBV Trends Over Visits ---
    print("\nPlotting Average Longitudinal Trends by Group...")
    if 'MMSE' in df_for_longitudinal.columns and 'Visit' in df_for_longitudinal.columns and 'Group' in df_for_longitudinal.columns:
        fig_mmse_trend, ax_mmse_trend = plt.subplots(figsize=(10, 6))
        sns.lineplot(data=df_for_longitudinal.dropna(subset=['MMSE']), 
                     x='Visit', y='MMSE', hue='Group', marker='o', 
                     errorbar='sd', ax=ax_mmse_trend, hue_order=present_groups_baseline)
        ax_mmse_trend.set_title('Average MMSE Score Trend over Visits by Group')
        ax_mmse_trend.set_xlabel('Visit Number')
        ax_mmse_trend.set_ylabel('Average MMSE Score (+/- SD)')
        if not df_for_longitudinal['Visit'].empty: ax_mmse_trend.set_xticks(sorted(df_for_longitudinal['Visit'].unique()))
        finalize_plot(fig_mmse_trend, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/longitudinal/mmse_trend_by_group", 
                      output_dir / '18_mmse_trend_by_group.png')
    else:
        print("Skipping MMSE trend plot due to missing columns (MMSE, Visit, or Group).")

    if 'nWBV' in df_for_longitudinal.columns and 'Visit' in df_for_longitudinal.columns and 'Group' in df_for_longitudinal.columns:
        fig_nwbv_trend, ax_nwbv_trend = plt.subplots(figsize=(10, 6))
        sns.lineplot(data=df_for_longitudinal.dropna(subset=['nWBV']), 
                     x='Visit', y='nWBV', hue='Group', marker='o', 
                     errorbar='sd', ax=ax_nwbv_trend, hue_order=present_groups_baseline)
        ax_nwbv_trend.set_title('Average nWBV Trend over Visits by Group')
        ax_nwbv_trend.set_xlabel('Visit Number')
        ax_nwbv_trend.set_ylabel('Average nWBV (+/- SD)')
        if not df_for_longitudinal['Visit'].empty: ax_nwbv_trend.set_xticks(sorted(df_for_longitudinal['Visit'].unique()))
        finalize_plot(fig_nwbv_trend, plt, run, 
                      f"charts_eda_{DATASET_IDENTIFIER}/longitudinal/nwbv_trend_by_group", 
                      output_dir / '19_nwbv_trend_by_group.png')
    else:
        print("Skipping nWBV trend plot due to missing columns (nWBV, Visit, or Group).")

    # --- 4. Individual MMSE Trajectories (Spaghetti Plot) ---
    print("\nPlotting Example Individual MMSE Trajectories...")
    if 'Subject ID' in df_for_longitudinal.columns and 'MMSE' in df_for_longitudinal.columns and 'Visit' in df_for_longitudinal.columns:
        # Select a few subjects (up to two per clinical group where possible).
        # A generator with a fixed, arbitrary seed keeps the selection reproducible across runs.
        rng = np.random.default_rng(42)
        example_subject_ids = []
        if 'Group' in df_for_longitudinal.columns:
            groups_to_sample = present_groups_baseline if present_groups_baseline else df_for_longitudinal['Group'].dropna().unique()
            for grp_val in groups_to_sample:
                subjects_in_grp = df_for_longitudinal.loc[df_for_longitudinal['Group'] == grp_val, 'Subject ID'].unique()
                if len(subjects_in_grp) > 0:
                    example_subject_ids.extend(rng.choice(subjects_in_grp, size=min(2, len(subjects_in_grp)), replace=False))
        else:  # If no group, just take some random subjects
            all_subjects = df_for_longitudinal['Subject ID'].unique()
            if len(all_subjects) > 0:
                example_subject_ids.extend(rng.choice(all_subjects, size=min(5, len(all_subjects)), replace=False))

        example_subject_ids = list(dict.fromkeys(example_subject_ids))  # Unique subjects, order preserved

        if example_subject_ids:
            example_trajectories_df = df_for_longitudinal[df_for_longitudinal['Subject ID'].isin(example_subject_ids)].copy()
            
            fig_ind_traj, ax_ind_traj = plt.subplots(figsize=(12, 7))
            sns.lineplot(data=example_trajectories_df.dropna(subset=['MMSE']),
                         x='Visit', y='MMSE', hue='Subject ID', 
                         style='Group' if 'Group' in example_trajectories_df.columns else None, 
                         marker='o', ax=ax_ind_traj, legend="brief") # 'full' legend can be too large
            ax_ind_traj.set_title('Individual MMSE Score Trends for Example Subjects')
            ax_ind_traj.set_xlabel('Visit Number')
            ax_ind_traj.set_ylabel('MMSE Score')
            if not example_trajectories_df['Visit'].empty: ax_ind_traj.set_xticks(sorted(example_trajectories_df['Visit'].unique()))
            
            # Improve legend placement if many subjects
            if len(example_subject_ids) > 5:
                ax_ind_traj.legend(title="Subject (Style=Group)", bbox_to_anchor=(1.05, 1), loc='upper left')
                fig_ind_traj.tight_layout(rect=[0, 0, 0.85, 1]) # Adjust layout for external legend
            else:
                ax_ind_traj.legend(title="Subject (Style=Group)")
                fig_ind_traj.tight_layout()

            finalize_plot(fig_ind_traj, plt, run, 
                          f"charts_eda_{DATASET_IDENTIFIER}/longitudinal/mmse_individual_examples", 
                          output_dir / '20_mmse_individual_examples.png')
        else:
            print("Could not select example subjects for individual trajectory plot.")
    else:
        print("Skipping individual MMSE trajectories plot due to missing columns.")
else:
    print("Skipping deeper longitudinal analysis as clinical_df is empty.")
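`sns.lineplot(..., errorbar='sd')` draws, for each group and visit number, the mean with a +/- 1 SD band. Tabulating the same kind of summary, together with the number of observations behind each point, can make the trend plots easier to interpret, since higher visit numbers typically include fewer subjects. A minimal sketch on made-up values (pandas `std` is the sample standard deviation, `ddof=1`):

```python
import pandas as pd

# Made-up values for illustration only.
visits = pd.DataFrame({
    "Group": ["Nondemented"] * 4 + ["Demented"] * 4,
    "Visit": [1, 1, 2, 2, 1, 1, 2, 2],
    "MMSE":  [30, 28, 29, 29, 24, 22, 21, None],
})

# Count, mean, and SD of MMSE per (Group, Visit) cell, ignoring missing scores.
trend = (
    visits.dropna(subset=["MMSE"])
    .groupby(["Group", "Visit"])["MMSE"]
    .agg(n="count", mean="mean", sd="std")
    .reset_index()
)
print(trend)
```

A cell backed by a single observation has an undefined SD (NaN), so its error band simply disappears in the plot.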

## 11. Verify MRI Scan File Availability

This section verifies that the raw MRI scan files, corresponding to the `MRI ID` and `Subject ID` in the clinical dataset, exist in the local file system. This step:
* Iterates through each visit record in the loaded `clinical_df`.
* For each record, constructs the expected file path for the raw T1w MPRAGE scan data (looking for `.nifti.img` and `.nifti.hdr` pairs within a `RAW` subfolder, across the base MRI data paths defined in `config.json`).
* Checks for the existence of the scan folder and the required image file pairs.
* Records detailed verification status for each scan session.

*Note on OASIS-2 MPRAGE Scans:* The OASIS dataset typically includes 3-4 individual T1w MPRAGE acquisitions per imaging session. This verification checks for these and notes if at least three are found, as this often indicates a complete acquisition.
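The per-session check can be sketched with `pathlib` alone. This is a hypothetical stand-in for the project's `verify_scan_files` (whose implementation lives in `src/` and is not shown here): the regex `mpr-(\d+)\.nifti\.img`, the helper name `find_mpr_pairs`, and the "first base path containing the folder wins" rule are assumptions based on the expected structure `<mri_base_path>/<MRI ID>/RAW/mpr-<#>.nifti.{img,hdr}` printed by the cell below.

```python
import re
import tempfile
from pathlib import Path

# Assumed filename pattern; the real MPR_IMG_PATTERN_CONFIG comes from config.json and may differ.
MPR_IMG_PATTERN = re.compile(r"^mpr-(\d+)\.nifti\.img$")


def find_mpr_pairs(base_paths, mri_id, pattern=MPR_IMG_PATTERN):
    """Hypothetical helper: search each base path in order for <base>/<mri_id>/RAW.
    For the first folder found, return (raw_dir, labels), where labels lists each
    mpr-<n> whose .img file has a matching .hdr. Return (None, []) if none exists."""
    for base in base_paths:
        raw_dir = Path(base) / mri_id / "RAW"
        if not raw_dir.is_dir():
            continue
        numbers = []
        for img_path in raw_dir.iterdir():
            match = pattern.match(img_path.name)
            # with_suffix swaps only the last suffix: mpr-1.nifti.img -> mpr-1.nifti.hdr
            if match and img_path.is_file() and img_path.with_suffix(".hdr").is_file():
                numbers.append(int(match.group(1)))
        return raw_dir, [f"mpr-{n}" for n in sorted(numbers)]
    return None, []


# Usage demo on a throwaway directory tree (example IDs, no real scan data).
with tempfile.TemporaryDirectory() as tmp:
    base_a, base_b = Path(tmp) / "part1", Path(tmp) / "part2"
    base_a.mkdir()
    demo_raw = base_b / "OAS2_0001_MR1" / "RAW"
    demo_raw.mkdir(parents=True)
    for n in (1, 2, 3):
        (demo_raw / f"mpr-{n}.nifti.img").touch()
        (demo_raw / f"mpr-{n}.nifti.hdr").touch()
    (demo_raw / "mpr-4.nifti.img").touch()  # .img without a .hdr: not counted

    demo_folder, demo_labels = find_mpr_pairs([base_a, base_b], "OAS2_0001_MR1")
    demo_missing = find_mpr_pairs([base_a, base_b], "OAS2_0002_MR1")

# found_three_or_more_mprs would then be len(demo_labels) >= 3
print(demo_folder.name, demo_labels, len(demo_labels) >= 3, demo_missing)
```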

In [None]:
# --- Verify Local Availability of Raw MRI Scan Files ---
print("\n--- Verifying Local MPRAGE Files (.img + .hdr pairs) ---")
print(f"Using MRI Base Paths: {[str(p) for p in MRI_BASE_PATHS_ABSOLUTE]}") # From Cell 3
print(f"Using Pattern: {MPR_IMG_PATTERN_CONFIG.pattern}") # From Cell 3
print(f"Expecting structure like: <mri_base_path>/<MRI ID>/RAW/mpr-<#>.nifti.{{img,hdr}}")

verification_log_entries = [] # To store detailed log dictionaries for each scan

if not clinical_df.empty and MRI_BASE_PATHS_ABSOLUTE:
    # Check if *at least one* base path actually exists (as a directory)
    any_base_path_is_valid = any(p.is_dir() for p in MRI_BASE_PATHS_ABSOLUTE)
    if not any_base_path_is_valid:
        print("CRITICAL ERROR: None of the configured base MRI paths exist or are directories.")
        print(f"  Checked paths: {[str(p) for p in MRI_BASE_PATHS_ABSOLUTE]}")
        print("Skipping image file verification. Subsequent preprocessing will likely fail.")
        if run:
            run.log({'verification/error_any_mri_base_path_invalid': True})
            # run.finish(exit_code=1) # Consider exiting if this is critical
        # To allow notebook to complete if this is an error, we create an empty verification_df
        verification_df = pd.DataFrame(columns=['mri_id', 'subject_id', 'visit', 'group', 
                                                'mri_base_path_used', 'mri_folder_path_checked', 
                                                'mri_folder_exists', 'mprs_found_count', 
                                                'mpr_labels_found', 'found_three_or_more_mprs', 
                                                'error_listing_dir'])
    else:
        if run: run.log({'verification/info_mri_base_paths_valid': True})

        print("\nStarting scan file verification process...")
        verification_start_time = time.time()

        # Iterate through each row (visit) in the clinical dataframe
        for index, row in tqdm(clinical_df.iterrows(), total=len(clinical_df), desc="Verifying Scan Files"):
            log_entry = verify_scan_files(row, MRI_BASE_PATHS_ABSOLUTE, MPR_IMG_PATTERN_CONFIG)
            verification_log_entries.append(log_entry)
            
            # Optional: Print immediate feedback for missing folders if verbose
            # if not log_entry['mri_folder_exists']:
            #     print(f"  [Missing Folder] MRI ID {log_entry['mri_id']} ({log_entry['subject_id']}): "
            #           f"Checked {log_entry['mri_folder_path_checked']}")
        
        verification_df = pd.DataFrame(verification_log_entries) # Create DataFrame from all log entries
        verification_duration = time.time() - verification_start_time
        print(f"\nFinished scan verification in {verification_duration:.2f} seconds.")
        if run: run.log({'verification/duration_seconds': verification_duration})
else:
    print("Skipping MRI scan file verification: clinical_df is empty or MRI_BASE_PATHS not defined.")
    # Ensure verification_df exists and is empty for downstream cells
    verification_df = pd.DataFrame(columns=['mri_id', 'subject_id', 'visit', 'group', 
                                            'mri_base_path_used', 'mri_folder_path_checked',
                                            'mri_folder_exists', 'mprs_found_count', 
                                            'mpr_labels_found', 'found_three_or_more_mprs',
                                            'error_listing_dir'])

## 12. Final MRI Verification Summary, Logging, and Output Saving

This section summarizes the results of the MRI file verification process:
* Total MRI IDs processed from the clinical data.
* Number of expected scan folders found versus missing.
* Total count of valid MPRAGE scan pairs (`.img` + `.hdr`) located.
* Number of unique subjects for whom at least one scan pair was found.
* Number of unique subjects who have at least one scan session with three or more MPRAGE pairs (indicative of a complete OASIS-2 acquisition).

These summary statistics are printed and logged to W&B. The detailed per-scan verification results are saved locally as `verification_details_<DATASET_IDENTIFIER>.csv` in this notebook's output directory, and are logged both as a `wandb.Table` (for inspection in the W&B interface) and as a W&B artifact (for direct download). The list of MRI IDs whose scan folders were not found is saved locally as `missing_mri_folders_<DATASET_IDENTIFIER>.csv` and also logged as a W&B artifact.
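The summary counts can be sketched on a toy verification table. Deduplicating on `mri_id` first keeps found + missing equal to the number of unique IDs even if an ID appears twice, and wrapping NumPy scalars in `int()` yields plain Python integers for logging. Column names follow `verification_df`; the rows and IDs are made up.

```python
import pandas as pd

# Made-up per-session verification rows; OAS2_0002_MR1 appears twice on purpose.
verification = pd.DataFrame({
    "mri_id":            ["OAS2_0001_MR1", "OAS2_0001_MR2", "OAS2_0002_MR1", "OAS2_0002_MR1"],
    "subject_id":        ["OAS2_0001", "OAS2_0001", "OAS2_0002", "OAS2_0002"],
    "mri_folder_exists": [True, True, False, False],
    "mprs_found_count":  [3, 2, 0, 0],
})
verification["found_three_or_more_mprs"] = verification["mprs_found_count"] >= 3

# One row per MRI ID so that every count below refers to unique scan sessions.
per_id = verification.drop_duplicates(subset="mri_id")

summary = {
    "total_mri_ids_checked": int(per_id["mri_id"].nunique()),
    "scan_folders_found": int(per_id["mri_folder_exists"].sum()),
    "scan_folders_missing": int((~per_id["mri_folder_exists"]).sum()),
    "total_mpr_pairs_located": int(per_id["mprs_found_count"].sum()),
    "subjects_with_any_mprs": int(
        per_id.loc[per_id["mprs_found_count"] > 0, "subject_id"].nunique()),
    "subjects_with_min_3_mprs_in_a_scan": int(
        per_id.loc[per_id["found_three_or_more_mprs"], "subject_id"].nunique()),
}
print(summary)
```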

In [None]:
# --- Process and Summarize MRI Verification Results ---
print("\n--- Processing and Summarizing MRI Verification Results ---")

if not verification_df.empty:
    # Calculate summary statistics from verification_df.
    # Folder and pair counts are computed per unique MRI ID so that found + missing == total,
    # even if an MRI ID appears more than once in clinical_df.
    verification_per_mri_id = verification_df.drop_duplicates(subset='mri_id')
    num_mri_ids_in_clinical_df = int(verification_per_mri_id['mri_id'].nunique())
    num_folders_found = int(verification_per_mri_id['mri_folder_exists'].astype(bool).sum())
    num_folders_missing = num_mri_ids_in_clinical_df - num_folders_found

    total_mpr_pairs_located = int(verification_per_mri_id['mprs_found_count'].sum())

    # Subjects for whom at least one scan file (any MPRAGE pair) was found
    subjects_with_any_verified_mprs = verification_df.loc[verification_df['mprs_found_count'] > 0, 'subject_id'].nunique()
    # Subjects with at least one scan session where 3+ MPRAGE pairs were found
    subjects_with_at_least_one_complete_scan_session = verification_df.loc[verification_df['found_three_or_more_mprs'].astype(bool), 'subject_id'].nunique()

    print("\n--- Final MRI Scan File Verification Summary ---")
    print(f"Total unique MRI IDs from clinical data processed: {num_mri_ids_in_clinical_df}")
    print(f"  Corresponding scan folders found: {num_folders_found}")
    print(f"  Corresponding scan folders missing: {num_folders_missing}")
    print(f"Total complete MPRAGE pairs (.img + .hdr) located across all found folders: {total_mpr_pairs_located}")
    print(f"Unique subjects with at least one MPRAGE pair found: {subjects_with_any_verified_mprs}")
    print(f"Unique subjects with at least one scan session having >= 3 MPRAGE pairs: {subjects_with_at_least_one_complete_scan_session}")

    # Save detailed verification results DataFrame locally
    verification_output_filename = f"verification_details_{DATASET_IDENTIFIER}.csv"
    verification_output_path = output_dir / verification_output_filename
    try:
        verification_df.to_csv(verification_output_path, index=False)
        print(f"\nDetailed verification results saved locally to: {verification_output_path}")
    except Exception as e_save_verif:
        print(f"Warning: Could not save detailed verification results locally. Error: {e_save_verif}")
    
    # Save list of missing folders locally
    missing_folders_df = verification_df.loc[~verification_df['mri_folder_exists'].astype(bool), ['mri_id', 'subject_id', 'visit', 'group', 'mri_folder_path_checked']]
    if not missing_folders_df.empty:
        missing_folders_filename = f"missing_mri_folders_{DATASET_IDENTIFIER}.csv"
        missing_folders_path = output_dir / missing_folders_filename
        try:
            missing_folders_df.to_csv(missing_folders_path, index=False)
            print(f"List of {len(missing_folders_df)} missing MRI folders saved locally to: {missing_folders_path}")
        except Exception as e_save_missing_folders:
            print(f"Warning: Could not save missing MRI folders list locally. Error: {e_save_missing_folders}")

    # Log to W&B
    if run:
        print("\nLogging verification summary and details to W&B...")
        run.log({
            'verification/total_mri_ids_checked': num_mri_ids_in_clinical_df,
            'verification/scan_folders_found': num_folders_found,
            'verification/scan_folders_missing': num_folders_missing,
            'verification/total_mpr_pairs_located': total_mpr_pairs_located,
            'verification/subjects_with_any_mprs': subjects_with_any_verified_mprs,
            'verification/subjects_with_min_3_mprs_in_a_scan': subjects_with_at_least_one_complete_scan_session
        })
        
        # Log the detailed verification DataFrame as a W&B Table
        try:
            verification_wandb_table = wandb.Table(dataframe=verification_df)
            run.log({"verification/scan_verification_details_table": verification_wandb_table})
            print("Detailed verification results logged as W&B Table.")
        except Exception as e_wandb_table:
            print(f"Warning: Could not log verification details table to W&B. Error: {e_wandb_table}")

        # Log the verification_details.csv itself as an artifact for direct download
        try:
            verif_details_artifact_name = f"verification_details_{DATASET_IDENTIFIER}"
            verif_details_artifact = wandb.Artifact(
                verif_details_artifact_name, 
                type="dataset_verification_report",
                description=f"Detailed MRI scan file verification status for {DATASET_IDENTIFIER}."
            )
            verif_details_artifact.add_file(str(verification_output_path))
            run.log_artifact(verif_details_artifact, aliases=["latest"])
            print(f"'{verification_output_filename}' logged as W&B Artifact.")
        except Exception as e_wandb_art_verif:
            print(f"Warning: Could not log verification_details.csv as W&B artifact. Error: {e_wandb_art_verif}")

        # Log missing folders list as a W&B artifact if it exists and was saved
        if not missing_folders_df.empty and missing_folders_path.exists():
            try:
                missing_folders_artifact_name = f"missing_mri_folders_report_{DATASET_IDENTIFIER}"
                missing_folders_artifact = wandb.Artifact(
                    missing_folders_artifact_name, 
                    type="analysis_output",
                    description=f"List of MRI IDs for {DATASET_IDENTIFIER} whose scan folders were not found."
                )
                missing_folders_artifact.add_file(str(missing_folders_path))
                run.log_artifact(missing_folders_artifact)
                print(f"List of {len(missing_folders_df)} missing folders logged as W&B Artifact.")
            except Exception as e_wandb_art_missing:
                print(f"Warning: Could not log missing_folders.csv as W&B artifact. Error: {e_wandb_art_missing}")
        elif missing_folders_df.empty:
            print("No missing MRI folders found to log as artifact.")
else:
    print("Verification DataFrame is empty. Skipping summary, saving, and W&B logging of verification results.")

## Finalize Run

Complete the execution for this notebook and finish the associated Weights & Biases run.

In [None]:
# --- Finish W&B Run ---
print("\n--- Exploration and Verification complete. Finishing W&B run. ---")
if run:
    run.finish()
    print("W&B run finished.")
else:
    print("No active W&B run to finish.")

print("\nScript execution finished.")