In [None]:
# In: notebooks/01_Data_Exploration_OASIS2.ipynb
# Purpose: Load OASIS-2 clinical data using config, perform detailed exploration
#          (including longitudinal aspects), verify MPRAGE file presence,
#          log results efficiently to W&B, and save key outputs locally.

In [None]:
# --- Import Libraries ---
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import wandb
import os
import re
import time
import json
from pathlib import Path

In [None]:
# --- Config Loading ---
# Reads ../config.json (relative to the notebook's working directory), resolves
# every data path against the directory that contains it, compiles the MPRAGE
# filename pattern, and creates this notebook's output directory.
#
# On failure a hint is printed and the original exception is re-raised so that
# "Run All" stops at this cell with a traceback. exit() is deliberately not
# used: inside a Jupyter kernel it requests a kernel shutdown instead of simply
# ending the cell.
print("--- Loading Configuration ---")
CONFIG_PATH = Path('../config.json')
try:
    # --- Determine Project Root ---
    # Assumes config.json is in the project root directory
    PROJECT_ROOT = CONFIG_PATH.parent.resolve()
    print(f"Project Root detected as: {PROJECT_ROOT}")

    with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
        config = json.load(f)
    print("Configuration loaded successfully.")

    # --- Resolve paths relative to PROJECT_ROOT ---
    INPUT_DATA_PATH = PROJECT_ROOT / config['data']['clinical_excel_path']
    MRI_BASE_PATHS = [PROJECT_ROOT / p for p in config['data']['mri_base_paths']]
    OUTPUT_DIR_BASE = PROJECT_ROOT / config['data']['output_dir_base']

    MPR_IMG_PATTERN = re.compile(config['mri_verification']['mpr_img_pattern'])
    WANDB_PROJECT = config['wandb']['project_name']
    WANDB_ENTITY = config['wandb'].get('entity', None)  # optional key

    NOTEBOOK_NAME = "01_Data_Exploration"
    output_dir = OUTPUT_DIR_BASE / NOTEBOOK_NAME
    output_dir.mkdir(parents=True, exist_ok=True)
    print(f"Outputs will be saved to: {output_dir}")

except FileNotFoundError:
    print(f"Error: Configuration file not found at {CONFIG_PATH}")
    print("Please ensure 'config.json' exists in the project root.")
    raise
except KeyError as e:
    print(f"Error: Missing key {e} in configuration file.")
    raise
except Exception as e:
    # Covers malformed JSON, an invalid regex in mpr_img_pattern, permission errors, etc.
    print(f"An error occurred loading the config file: {e}")
    raise

In [None]:
# --- Helper Functions ---
print("\n--- Defining Helper Functions ---")

def finalize_plot(fig, run, wandb_key, save_path):
    """Log a figure to W&B, save it to disk, display it, and always close it.

    Args:
        fig: Matplotlib figure to finalize.
        run: Active W&B run, or a falsy value (e.g. None) to skip W&B logging.
        wandb_key: Key under which the figure image is logged to W&B.
        save_path: Destination file for ``fig.savefig``; a falsy value skips saving.

    Logging and saving are both best-effort: a failure prints a warning and
    the remaining steps still run. The figure is closed even if an
    unexpected error occurs, so figures are not leaked.
    """
    try:
        if run:
            try:
                run.log({wandb_key: wandb.Image(fig)})
            except Exception as e:
                print(f"Warning: Could not log plot '{wandb_key}' to W&B. Error: {e}")
        if save_path:
            try:
                # bbox_inches='tight' keeps legends placed outside the axes in the saved file
                fig.savefig(save_path, bbox_inches='tight')
            except Exception as e:
                print(f"Warning: Could not save plot to {save_path}. Error: {e}")
        plt.show()
    finally:
        plt.close(fig)  # Close the specific figure, not whichever one is current

def verify_scan_files(row, base_paths, pattern):
    """Check which MPRAGE image/header pairs exist on disk for one clinical row.

    Args:
        row: Mapping-like record (e.g. a pandas Series from ``iterrows``)
            providing 'Subject ID' and 'MRI ID'. 'Visit' and 'Group' are read
            with ``.get`` and default to None.
        base_paths: Iterable of ``pathlib.Path`` directories, searched in
            order. The first one containing ``<MRI ID>/RAW`` is used.
        pattern: Compiled regex applied with ``.match`` to every entry name in
            the RAW folder. Group 1 is taken as the scan label, and the header
            is expected next to the image as ``<label>.nifti.hdr``.

    Returns:
        dict with keys 'mri_id', 'subject_id', 'visit', 'group',
        'mri_base_path_used', 'mri_folder_exists', 'mri_folder_is_dir',
        'mprs_found_count', 'mpr_labels_found' (sorted list) and
        'found_three_or_more_mprs'. 'error_listing_dir' is added only when the
        RAW folder could not be listed.
    """
    subject_id = row['Subject ID']
    mri_id = row['MRI ID']
    mri_folder = None
    found_in_path = None

    log_entry = {
        'mri_id': mri_id,
        'subject_id': subject_id,
        'visit': row.get('Visit', None),
        'group': row.get('Group', None),
    }

    # A blank spreadsheet cell arrives as NaN/None. Joining that onto a Path
    # raises TypeError, so treat a missing or blank ID as "folder not found"
    # instead of aborting the caller's loop.
    has_usable_id = not pd.isna(mri_id) and bool(str(mri_id).strip())

    if has_usable_id:
        for base_path in base_paths:
            potential_folder = base_path / str(mri_id) / 'RAW'
            if potential_folder.is_dir():
                mri_folder = potential_folder
                found_in_path = str(base_path)  # Store as string for logging
                break

    log_entry['mri_base_path_used'] = found_in_path

    if mri_folder is None:
        log_entry.update({
            'mri_folder_exists': False,
            'mri_folder_is_dir': False,
            'mprs_found_count': 0,
            'mpr_labels_found': [],
            'found_three_or_more_mprs': False
        })
        return log_entry

    log_entry['mri_folder_exists'] = True
    log_entry['mri_folder_is_dir'] = True
    found_mpr_pairs = {}

    try:
        filenames = [f.name for f in mri_folder.iterdir()]
    except OSError as e:
        # Folder exists but cannot be read (e.g. permissions); record why.
        log_entry.update({
            'mprs_found_count': 0,
            'error_listing_dir': str(e),
            'mpr_labels_found': [],
            'found_three_or_more_mprs': False
        })
        return log_entry

    for filename in filenames:
        match = pattern.match(filename)
        if not match:
            continue
        mpr_label = match.group(1)
        img_path = mri_folder / filename
        hdr_path = mri_folder / f"{mpr_label}.nifti.hdr"

        # A pair counts only when both the image and its header are regular files.
        if img_path.is_file() and hdr_path.is_file():
            found_mpr_pairs[mpr_label] = (str(img_path), str(hdr_path))

    num_found = len(found_mpr_pairs)
    log_entry.update({
        'mprs_found_count': num_found,
        'mpr_labels_found': sorted(found_mpr_pairs),
        'found_three_or_more_mprs': num_found >= 3
    })
    return log_entry

In [None]:
# --- Initialize W&B Run ---
# Best-effort: if the run cannot be started, `run` is set to None and every
# later cell skips its W&B logging.
print("\n--- Initializing Weights & Biases Run ---")
try:
    wandb_init_kwargs = dict(
        project=WANDB_PROJECT,
        entity=WANDB_ENTITY,
        job_type="data-exploration-validation",
        name=f"{NOTEBOOK_NAME}-run-{time.strftime('%Y%m%d-%H%M')}",
        # Parameters derived from the loaded config, recorded with the run
        config={
            "input_data_path": str(INPUT_DATA_PATH),
            "mri_base_paths": [str(p) for p in MRI_BASE_PATHS],
            "mpr_img_pattern": MPR_IMG_PATTERN.pattern,
            "output_dir": str(output_dir),
            "execution_date": time.strftime("%Y-%m-%d %H:%M:%S"),
        },
    )
    run = wandb.init(**wandb_init_kwargs)
except Exception as e:
    print(f"Error initializing W&B: {e}")
    print("Proceeding without W&B logging.")
    run = None
else:
    print(f"W&B run '{run.name}' initialized successfully (ID: {run.id}). View at: {run.url}")


## Load Clinical Data

Load the raw longitudinal clinical and demographic data from the specified Excel file using pandas. If the Weights & Biases run initialized above is active, the source file is also logged as an artifact for reproducibility.

In [None]:
# Load the clinical spreadsheet into `clinical_df`.
# Any failure finishes the W&B run (if one is active) and re-raises, so that
# "Run All" stops here with a traceback. exit() is avoided because inside a
# Jupyter kernel it requests a kernel shutdown rather than ending the cell.
print(f"\n--- Loading Clinical Data from: {INPUT_DATA_PATH} ---")
try:
    # Make sure the file path exists before trying to read
    if not INPUT_DATA_PATH.is_file():
        raise FileNotFoundError(f"Input data file not found at {INPUT_DATA_PATH}")

    clinical_df = pd.read_excel(INPUT_DATA_PATH)
except FileNotFoundError as e:
    print(f"Error: {e}")
    if run:
        run.finish()
    raise
except ImportError:
    # pandas raises ImportError when the Excel reader backend is not installed
    print("Error loading Excel file: Missing library.")
    print("You might need to install 'openpyxl' (`pip install openpyxl`)")
    if run:
        run.finish()
    raise
except Exception as e:
    print(f"An error occurred loading the data file: {e}")
    if run:
        run.finish()
    raise

print(f"Data loaded successfully. Shape: {clinical_df.shape}")

if clinical_df.empty:
    print("Error: Loaded dataframe is empty.")
    if run:
        run.finish()
    raise ValueError(f"No rows were loaded from {INPUT_DATA_PATH}")

if run:
    # Artifact logging is best-effort: a failure here is only a warning and is
    # not reported as a data loading error.
    print("Logging raw data as W&B artifact...")
    try:
        raw_data_at = wandb.Artifact(f"{INPUT_DATA_PATH.stem}_raw", type="dataset",
                                     description=f"Raw clinical data from {INPUT_DATA_PATH.name}",
                                     metadata={"shape": clinical_df.shape, "source_path": str(INPUT_DATA_PATH)})
        # Adding the file to the artifact (W&B handles upload)
        raw_data_at.add_file(str(INPUT_DATA_PATH))
        run.log_artifact(raw_data_at)
        print("Raw data artifact logged.")
    except Exception as e:
        print(f"Warning: Could not add file to W&B artifact. Error: {e}")

## Initial Data Inspection

Perform basic checks on the loaded `clinical_df` DataFrame:
* View data types, non-null counts, memory usage (`.info()`).
* Calculate descriptive statistics for all columns (`.describe(include='all')`).
* Identify and count missing values per column (`.isnull().sum()`).
Log summary statistics derived from these checks to Weights & Biases.

In [None]:
# Structural overview (dtypes, non-null counts, memory) followed by summary
# statistics for every column; the statistics table is also written to disk.
print("\n--- Basic Data Information ---")
print("DataFrame Info:")
clinical_df.info()

print("\nDescriptive Statistics:")
desc_stats = clinical_df.describe(include='all')
print(desc_stats)

# Persist the statistics next to the other outputs of this notebook (best-effort)
desc_stats_path = output_dir / 'descriptive_stats.csv'
try:
    desc_stats.to_csv(desc_stats_path)
except Exception as e:
    print(f"Warning: Could not save descriptive stats. Error: {e}")
else:
    print(f"Descriptive stats saved to {desc_stats_path}")

### Missing Value Strategy Note

Acknowledge the identified missing values (especially in columns like MMSE, SES). Note that a specific strategy for handling these (e.g., imputation) will be decided upon and implemented during the preprocessing stage *after* the data has been split into train/validation/test sets to prevent data leakage.

In [None]:
# Count nulls per column, report only the columns that have any, and write
# that filtered summary to disk.
print("\nMissing Values per Column (Only showing columns with missing values):")
missing_values = clinical_df.isnull().sum()
missing_values_filtered = missing_values.loc[missing_values.gt(0)]
print(missing_values_filtered)

# Local copy of the summary (best-effort)
missing_values_path = output_dir / 'missing_values_summary.csv'
try:
    missing_values_filtered.to_csv(missing_values_path, header=['missing_count'])
except Exception as e:
    print(f"Warning: Could not save missing values summary. Error: {e}")
else:
    print(f"Missing values summary saved to {missing_values_path}")

In [None]:
# --- Log key stats to W&B ---
# Sends dataset-level counts plus the numeric descriptive statistics of the
# key clinical columns in a single run.log call.
if run:
    print("\nLogging summary statistics to W&B...")
    n_rows, n_cols = clinical_df.shape
    log_dict = {
        'dataset/num_rows': n_rows,
        'dataset/num_columns': n_cols,
        'dataset/missing_values_total': int(missing_values.sum()),
        'dataset/columns_with_missing_values': int(len(missing_values_filtered)),
        'dataset/memory_usage_MB': clinical_df.memory_usage(deep=True).sum() / (1024**2)
    }

    for col in ['Age', 'EDUC', 'SES', 'MMSE', 'CDR', 'eTIV', 'nWBV', 'ASF']:
        if col not in desc_stats.columns:
            continue
        # describe(include='all') also yields 'unique'/'top'/'freq' rows; keep
        # only entries that are present and numeric.
        for stat_name, stat_value in desc_stats[col].dropna().items():
            if pd.api.types.is_number(stat_value):
                log_dict[f'stats/{col}_{stat_name}'] = stat_value

    run.log(log_dict)
    print("Summary statistics logged to W&B.")

## Analyze Variable Distributions

Visualize the distributions of key individual variables using histograms (for numerical data like Age, MMSE, nWBV) and count plots (for categorical data like Group, Gender, or discrete counts like VisitsPerSubject). This helps understand the overall characteristics of the cohort across all recorded visits. Plots are logged to W&B and saved locally to the notebook's output directory.

In [None]:
print("\n--- Analyzing Variable Distributions ---")

# Histograms of the numeric variables, described as
# (column, drop rows missing the column first?, bins, title, x-label, W&B key, file name).
histogram_specs = [
    ('Age', False, 20, 'Distribution of Age', 'Age (Years)',
     "charts/distribution/age", 'age_distribution.png'),
    ('MMSE', True, 15, 'Distribution of MMSE Scores', 'MMSE Score',
     "charts/distribution/mmse", 'mmse_distribution.png'),
    ('nWBV', True, 20, 'Distribution of Normalized Whole Brain Volume (nWBV)', 'nWBV',
     "charts/distribution/nwbv", 'nwbv_distribution.png'),
]
for hist_col, hist_dropna, hist_bins, hist_title, hist_xlabel, hist_key, hist_file in histogram_specs:
    hist_data = clinical_df.dropna(subset=[hist_col]) if hist_dropna else clinical_df
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.histplot(data=hist_data, x=hist_col, kde=True, bins=hist_bins, ax=ax)
    ax.set(title=hist_title, xlabel=hist_xlabel, ylabel='Frequency')
    finalize_plot(fig, run, hist_key, output_dir / hist_file)

# How many subjects have 1, 2, 3, ... recorded visits
visits_per_subject = clinical_df['Subject ID'].value_counts()
fig, ax = plt.subplots(figsize=(10, 5))
sns.countplot(x=visits_per_subject, ax=ax)
ax.set(title='Number of Visits per Subject', xlabel='Number of Visits', ylabel='Number of Subjects')
finalize_plot(fig, run, "charts/distribution/visits_per_subject", output_dir / 'visits_per_subj.png')

# Clinical groups across all visits, most frequent group first.
# `group_order` is reused by the relationship plots in the next section.
group_order = clinical_df['Group'].value_counts().index
fig, ax = plt.subplots(figsize=(8, 5))
sns.countplot(data=clinical_df, x='Group', order=group_order, ax=ax)
ax.set(title='Distribution of Clinical Groups (All Visits)', xlabel='Group', ylabel='Number of Visits/Scans')
finalize_plot(fig, run, "charts/distribution/group", output_dir / 'clinical_groups_distribution.png')

# Gender across all visits
fig, ax = plt.subplots(figsize=(6, 4))
sns.countplot(data=clinical_df, x='M/F', ax=ax)
ax.set(title='Distribution of Gender (All Visits)', xlabel='Gender', ylabel='Number of Visits/Scans')
finalize_plot(fig, run, "charts/distribution/gender", output_dir / 'gender_distribution.png')

## Analyze Relationships Between Variables

Explore potential pairwise relationships between important variables. This includes:
* Scatter plots to visualize relationships like Age vs. MMSE, Age vs. nWBV, colored by clinical group.
* A correlation matrix heatmap for key numerical variables to quantify linear associations.
Plots are logged to W&B and saved locally.

In [None]:
print("\n--- Analyzing Relationships Between Variables ---")

# Scatter plots coloured by clinical group, described as
# (x, y, columns to drop missing rows on or None, title, x-label, y-label, W&B key, file name).
scatter_specs = [
    ('Age', 'MMSE', None, 'Age vs. MMSE Score by Group', 'Age (Years)', 'MMSE Score',
     "charts/relationship/age_vs_mmse", 'age_mmse.png'),
    ('Age', 'nWBV', None, 'Age vs. Normalized Whole Brain Volume (nWBV) by Group', 'Age (Years)', 'nWBV',
     "charts/relationship/age_vs_nwbv", 'age_nwbv.png'),
    ('MMSE', 'nWBV', ['MMSE', 'nWBV'], 'MMSE vs. Normalized Whole Brain Volume (nWBV) by Group', 'MMSE Score', 'nWBV',
     "charts/relationship/mmse_vs_nwbv", 'mmse_nwbv.png'),
]
for sc_x, sc_y, sc_dropna, sc_title, sc_xlabel, sc_ylabel, sc_key, sc_file in scatter_specs:
    fig, ax = plt.subplots(figsize=(10, 6))
    sc_data = clinical_df if sc_dropna is None else clinical_df.dropna(subset=sc_dropna)
    sns.scatterplot(data=sc_data, x=sc_x, y=sc_y, hue='Group', alpha=0.6, ax=ax, hue_order=group_order)
    ax.set(title=sc_title, xlabel=sc_xlabel, ylabel=sc_ylabel)
    finalize_plot(fig, run, sc_key, output_dir / sc_file)

# Correlation heatmap over whichever of the candidate columns exist and are numeric
numerical_cols = ['Age', 'EDUC', 'SES', 'MMSE', 'CDR', 'eTIV', 'nWBV', 'ASF', 'Visit']
valid_numerical_cols = [
    col for col in numerical_cols
    if col in clinical_df.columns and pd.api.types.is_numeric_dtype(clinical_df[col])
]
if valid_numerical_cols:
    fig, ax = plt.subplots(figsize=(10, 8))
    correlation_matrix = clinical_df[valid_numerical_cols].corr()
    sns.heatmap(
        correlation_matrix,
        annot=True,
        cmap='coolwarm',
        fmt=".2f",
        ax=ax,
        annot_kws={"size": 8},  # small font so every cell's value fits
    )
    ax.set_title('Correlation Matrix of Key Numerical Variables')
    plt.xticks(rotation=45, ha='right')
    plt.yticks(rotation=0)
    fig.tight_layout()
    finalize_plot(fig, run, "charts/relationship/correlation_heatmap", output_dir / 'corr_matrix.png')

## Analyze Differences Between Clinical Groups

Compare the distributions of key variables (e.g., MMSE, nWBV, Age) across the different clinical groups ('Nondemented', 'Converted', 'Demented') defined in the dataset. Box plots and violin plots are used for visualization. Plots are logged to W&B and saved locally.

In [None]:
print("\n--- Analyzing Differences Between Clinical Groups ---")
# Fixed left-to-right order so every group comparison plot reads the same way
group_order_analysis = ['Nondemented', 'Converted', 'Demented']

# One plot per entry: (seaborn plotter, y column, title, y-label, W&B key, file name)
group_plot_specs = [
    (sns.boxplot, 'MMSE', 'MMSE Scores by Clinical Group', 'MMSE Score',
     "charts/group_comparison/mmse_boxplot", 'mmse_clinical_boxplot.png'),
    (sns.boxplot, 'nWBV', 'nWBV by Clinical Group', 'Normalized Whole Brain Volume (nWBV)',
     "charts/group_comparison/nwbv_boxplot", 'nwbv_clinical_boxplot.png'),
    (sns.violinplot, 'Age', 'Age Distribution by Clinical Group', 'Age (Years)',
     "charts/group_comparison/age_violinplot", 'age_clinical_violinplot.png'),
]
for grp_plotter, grp_y, grp_title, grp_ylabel, grp_key, grp_file in group_plot_specs:
    fig, ax = plt.subplots(figsize=(8, 6))
    grp_plotter(data=clinical_df, x='Group', y=grp_y, order=group_order_analysis, ax=ax)
    ax.set(title=grp_title, xlabel='Clinical Group', ylabel=grp_ylabel)
    finalize_plot(fig, run, grp_key, output_dir / grp_file)

## Deeper Longitudinal Analysis

Explore the time-based aspects of the data more explicitly:
1.  **Visit Intervals:** Calculate the approximate time elapsed between consecutive visits for each subject and visualize the distribution of these intervals.
2.  **Baseline Characteristics:** Filter the data to include only the first visit (`Visit == 1`) for each subject and analyze the distributions (Group, MMSE, Age) for this baseline cohort specifically.
3.  **Average Trends:** Plot the average MMSE and nWBV scores across visit numbers, separated by clinical group.
4.  **Individual Examples:** Plot the MMSE trajectories for a few example subjects to visualize individual variability (spaghetti plot).
Plots and relevant statistics are logged to W&B and saved locally.

In [None]:
print("\n--- Performing Deeper Longitudinal Analysis ---")

# Time between visits is approximated by the change in 'Age' between a
# subject's consecutive rows after sorting by visit number. This can be noisy
# when a birthday falls between two visits; a dedicated days-from-baseline
# column would be a better source if one is available.
print("Calculating time intervals between visits...")
clinical_df_sorted = clinical_df.sort_values(by=['Subject ID', 'Visit'])
first_visit_mask = clinical_df_sorted['Visit'] == 1
age_diff_within_subject = clinical_df_sorted.groupby('Subject ID')['Age'].diff()
clinical_df_sorted = clinical_df_sorted.assign(
    Age_Diff=age_diff_within_subject,
    # A first visit has no predecessor (NaN difference); record 0 for it
    Years_Since_Prev_Visit=age_diff_within_subject.mask(first_visit_mask, 0),
)

# Only follow-up visits contribute to the interval distribution
visit_intervals = clinical_df_sorted.loc[
    clinical_df_sorted['Visit'] > 1, 'Years_Since_Prev_Visit'
].dropna()

if visit_intervals.empty:
    print("\nCould not calculate meaningful visit intervals (e.g., only single visits).")
else:
    fig, ax = plt.subplots(figsize=(10, 5))
    sns.histplot(visit_intervals, kde=True, bins=20, ax=ax)
    ax.set(
        title='Distribution of Approximate Years Between Consecutive Visits',
        xlabel='Approx. Years Since Previous Visit',
        ylabel='Frequency',
    )
    finalize_plot(fig, run, "charts/longitudinal/visit_interval_distribution", output_dir / 'visit_interval_distribution.png')

    interval_stats = visit_intervals.describe()
    print("\nApproximate Visit Interval Stats (Years):")
    print(interval_stats)
    if run:
        interval_log = {}
        for k, v in interval_stats.items():
            interval_log[f'stats/visit_interval_{k}'] = v
        run.log(interval_log)

# --- Baseline Analysis (Visit == 1) ---
# Restricts the data to first-visit rows and repeats the group-level plots on them.
print("\nAnalyzing Baseline Characteristics (Visit == 1)...")
baseline_df = clinical_df.loc[clinical_df['Visit'] == 1].copy()  # independent copy of the baseline rows

if baseline_df.empty:
    print("No data found for Visit == 1.")
else:
    print(f"Number of subjects at baseline (Visit 1): {baseline_df['Subject ID'].nunique()}")
    if run:
        run.log({'dataset/num_subjects_baseline': baseline_df['Subject ID'].nunique()})

    # Group counts at baseline, most frequent group first
    baseline_group_order = baseline_df['Group'].value_counts().index
    fig, ax = plt.subplots(figsize=(8, 5))
    sns.countplot(data=baseline_df, x='Group', order=baseline_group_order, ax=ax)
    ax.set(
        title='Distribution of Clinical Groups at Baseline (Visit 1)',
        xlabel='Group at Visit 1',
        ylabel='Number of Subjects',
    )
    finalize_plot(fig, run, "charts/baseline/group_distribution", output_dir / 'baseline_group_distribution.png')

    # Per-group comparisons at baseline: (plotter, y column, title, y-label, W&B key, file name)
    baseline_plot_specs = [
        (sns.boxplot, 'MMSE', 'MMSE Scores at Baseline (Visit 1) by Group', 'MMSE Score',
         "charts/baseline/mmse_boxplot", 'baseline_mmse_boxplot.png'),
        (sns.violinplot, 'Age', 'Age Distribution at Baseline (Visit 1) by Group', 'Age (Years)',
         "charts/baseline/age_violinplot", 'baseline_age_violinplot.png'),
    ]
    for bl_plotter, bl_y, bl_title, bl_ylabel, bl_key, bl_file in baseline_plot_specs:
        fig, ax = plt.subplots(figsize=(8, 6))
        bl_plotter(data=baseline_df, x='Group', y=bl_y, order=group_order_analysis, ax=ax)
        ax.set(title=bl_title, xlabel='Clinical Group at Visit 1', ylabel=bl_ylabel)
        finalize_plot(fig, run, bl_key, output_dir / bl_file)


# --- Average MMSE/nWBV Trends ---
# Mean value per visit number for each clinical group, with a standard
# deviation band. Described as (column, title, y-label, W&B key, file name).
print("\nPlotting Average Trends Over Visits...")
trend_specs = [
    ('MMSE', 'Average MMSE Score Trend over Visits by Group', 'Average MMSE Score',
     "charts/longitudinal/mmse_trend_by_visit_group", 'mmse_trend_by_visit_group.png'),
    ('nWBV', 'Average nWBV Trend over Visits by Group', 'Average nWBV',
     "charts/longitudinal/nwbv_trend_by_visit_group", 'nwbv_trend_by_visit_group.png'),
]
for trend_col, trend_title, trend_ylabel, trend_key, trend_file in trend_specs:
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.lineplot(
        data=clinical_df.dropna(subset=[trend_col]),
        x='Visit',
        y=trend_col,
        hue='Group',
        marker='o',
        errorbar='sd',
        ax=ax,
        hue_order=group_order_analysis,
    )
    ax.set(title=trend_title, xlabel='Visit Number', ylabel=trend_ylabel)
    # One tick per visit number that occurs in the data
    ax.set_xticks(sorted(clinical_df['Visit'].unique()))
    finalize_plot(fig, run, trend_key, output_dir / trend_file)

# Example Individual Subject Plot
# Picks up to three subjects (in order of first appearance) from each group
# present in the data and draws one MMSE line per subject.
present_groups = clinical_df['Group'].unique()
example_subjects = [
    subject
    for grp in present_groups
    for subject in clinical_df.loc[clinical_df['Group'] == grp, 'Subject ID'].unique()[:3]
]

if not example_subjects:
    print("Could not generate example individual subject plot (no subjects found).")
else:
    example_df = clinical_df.loc[clinical_df['Subject ID'].isin(example_subjects)].copy()

    fig, ax = plt.subplots(figsize=(12, 7))
    sns.lineplot(
        data=example_df.dropna(subset=['MMSE']),
        x='Visit',
        y='MMSE',
        hue='Subject ID',
        style='Group',
        marker='o',
        ax=ax,
    )
    ax.set(
        title='Individual MMSE Score Trends for Example Subjects',
        xlabel='Visit Number',
        ylabel='MMSE Score',
    )
    ax.set_xticks(sorted(example_df['Visit'].unique()))
    # Move the legend outside the axes and reserve room for it on the right
    handles, labels = ax.get_legend_handles_labels()
    ax.legend(handles=handles, labels=labels, title="Subject (Style=Group)", bbox_to_anchor=(1.05, 1), loc='upper left')
    fig.tight_layout(rect=[0, 0, 0.85, 1])
    finalize_plot(fig, run, "charts/longitudinal/mmse_individual_trends_example", output_dir / 'mmse_individual_trends_example.png')

## Verify MRI Scan File Availability

Iterate through each visit record in the loaded clinical data (`clinical_df`). For each record:
1.  Construct the expected file path(s) for the raw T1w MPRAGE scan data (specifically looking for `.nifti.img` and `.nifti.hdr` pairs within the `RAW` subfolder of the directory named by `MRI ID`, checking across the base paths defined in `config.json`).
2.  Check if the expected folder and both files (`.img` + `.hdr`) actually exist in the local filesystem.
3.  Record the verification status (folder found, number of valid pairs found) for each `MRI ID`.

NOTE ON MULTIPLE MPRAGE SCANS PER SESSION: The OASIS documentation notes that 3-4 individual T1w MPRAGE scans were typically acquired during each imaging session. Our check for sessions with >= 3 pairs helps confirm we are finding data consistent with this expected acquisition protocol.

In [None]:
# Checks every clinical row against the configured MRI base paths and collects
# one result dict per row in `verification_log_entries`.
print("\n--- Verifying Local MPRAGE Files (.img + .hdr) in specified base paths ---")
print(f"Expecting structure like: <base_path>/<MRI ID>/RAW/mpr-<#>.nifti.{{img,hdr}}")

# At least one base path must exist, otherwise every row would be "missing".
any_base_path_exists = any(p.is_dir() for p in MRI_BASE_PATHS)
if not any_base_path_exists:
    checked_paths = [str(p) for p in MRI_BASE_PATHS]
    print(f"Error: None of the base MRI paths {checked_paths} exist or are directories.")
    print("Skipping image file verification.")
    if run:
        run.log({'verification/any_mri_base_path_exists': False, 'config/checked_mri_paths': checked_paths})
        run.finish()
    # Raise instead of exit(): inside a Jupyter kernel exit() requests a kernel
    # shutdown, whereas an exception stops "Run All" at this cell.
    raise FileNotFoundError(f"None of the configured MRI base paths exist: {checked_paths}")

if run:
    run.log({'verification/any_mri_base_path_exists': True})

# Running counters for a quick summary while the loop executes
mri_ids_processed = 0
folders_missing = 0
folders_found = 0
mprs_found_total = 0
subjects_with_any_mprs = set()
subjects_with_three_or_more_mprs = set()
verification_log_entries = []  # Detailed per-row results (later turned into a DataFrame)

print("Starting scan verification...")
start_time = time.time()
total_rows = len(clinical_df)

# --- Loop and Verify using Helper Function ---
# enumerate() supplies the progress position, so the message does not depend
# on the DataFrame having a default 0..n-1 integer index.
for position, (index, row) in enumerate(clinical_df.iterrows(), start=1):
    log_entry = verify_scan_files(row, MRI_BASE_PATHS, MPR_IMG_PATTERN)
    verification_log_entries.append(log_entry)

    mri_ids_processed += 1
    if not log_entry['mri_folder_exists']:
        folders_missing += 1
        potential_paths_str = ", ".join(str(p / str(row['MRI ID']) / 'RAW') for p in MRI_BASE_PATHS)
        print(f"[{position}/{total_rows}] [Missing Folder] MRI ID {log_entry['mri_id']} ({log_entry['subject_id']}): Expected folder not found. Searched: {potential_paths_str}")
        continue

    folders_found += 1
    num_found = log_entry['mprs_found_count']
    mprs_found_total += num_found
    if num_found > 0:
        subjects_with_any_mprs.add(log_entry['subject_id'])
        if num_found >= 3:
            subjects_with_three_or_more_mprs.add(log_entry['subject_id'])

## Final Verification Summary, Logging, and Output Saving

Summarize the results of the MRI file verification (number of IDs processed, folders found/missing, total pairs found, subject counts meeting criteria). Log these summary statistics and the detailed per-scan verification results (as a `wandb.Table`) to Weights & Biases. Save the detailed verification results DataFrame (`verification_df`) locally as `verification_details.csv` for use in the next notebook. Also, log the list of missing folders as a W&B artifact.

In [None]:
# Summarizes the per-scan verification results, saves them locally, and logs
# them to W&B when a run is active.
# NOTE: start_time is set in the verification cell, so this duration also
# includes any pause between running that cell and this one.
end_time = time.time()
verification_duration = end_time - start_time
print(f"\nFinished scan verification in {verification_duration:.2f} seconds.")


print("Processing verification results...")
verification_df = pd.DataFrame(verification_log_entries)

# Final summary statistics, derived from the DataFrame rather than the loop counters
mri_ids_processed = len(verification_df)
if verification_df.empty:
    # An empty frame has no columns, so every column lookup must be skipped.
    print("Note: Verification DataFrame is empty (no entries processed).")
    folders_found = 0
    mprs_found_total = 0
    subjects_with_any_mprs = set()
    subjects_with_three_or_more_mprs = set()
    missing_folders_df = pd.DataFrame(columns=['mri_id', 'subject_id', 'visit'])
else:
    folder_found_mask = verification_df['mri_folder_exists'].astype(bool)
    three_plus_mask = verification_df['found_three_or_more_mprs'].astype(bool)
    folders_found = int(folder_found_mask.sum())
    mprs_found_total = int(verification_df['mprs_found_count'].sum())
    subjects_with_any_mprs = set(
        verification_df.loc[verification_df['mprs_found_count'] > 0, 'subject_id'].unique()
    )
    subjects_with_three_or_more_mprs = set(verification_df.loc[three_plus_mask, 'subject_id'].unique())
    missing_folders_df = verification_df.loc[~folder_found_mask, ['mri_id', 'subject_id', 'visit']]

folders_missing = mri_ids_processed - folders_found
any_subject_with_all_mprs = len(subjects_with_three_or_more_mprs) > 0

print("\n--- Final Verification Summary ---")
print(f"Total MRI IDs processed: {mri_ids_processed}")
print(f"Expected folders found: {folders_found}")
print(f"Expected folders missing: {folders_missing}")
print(f"Total complete MPR pairs (.img + .hdr) found: {mprs_found_total}")
print(f"Unique subjects with >= 1 MPR pair found: {len(subjects_with_any_mprs)}")
print(f"Unique subjects with >= 1 scan having >= 3 MPR pairs: {len(subjects_with_three_or_more_mprs)}")

verification_output_path = output_dir / "verification_details.csv"
try:
    # index=False avoids writing the default DataFrame index
    verification_df.to_csv(verification_output_path, index=False)
    print(f"\nDetailed verification results saved locally to: {verification_output_path}")
except Exception as e:
    print(f"Warning: Could not save verification details locally. Error: {e}")

if folders_missing > 0:
    print(f"⚠️ Warning: {folders_missing} expected MRI folders were not found.")
if not any_subject_with_all_mprs:
    print("⚠️ Warning: No single MRI scan session was found containing 3 or more valid MPR pairs.")

# Save the list of missing folders locally whether or not W&B is active
missing_file_path = output_dir / "missing_mri_folders.csv"
missing_list_saved = False
if missing_folders_df.empty:
    print("No missing MRI folders found to log/save.")
else:
    try:
        missing_folders_df.to_csv(missing_file_path, index=False)
        missing_list_saved = True
        print(f"List of {len(missing_folders_df)} missing folders saved locally to: {missing_file_path}")
    except Exception as e:
        print(f"Warning: Could not save missing folders list. Error: {e}")


if run:
    print("\nLogging verification results to W&B...")
    # Summary stats
    run.log({
        'verification/duration_seconds': verification_duration,
        'verification/total_mri_ids_processed': mri_ids_processed,
        'verification/mri_folders_found': folders_found,
        'verification/mri_folders_missing': folders_missing,
        'verification/total_mpr_pairs_found': mprs_found_total,
        'verification/unique_subjects_with_any_mprs': len(subjects_with_any_mprs),
        'verification/unique_subjects_with_three_or_more_mprs': len(subjects_with_three_or_more_mprs),
    })

    # Detailed per-scan table
    if not verification_df.empty:
        verification_table = wandb.Table(dataframe=verification_df)
        run.log({"verification/details_per_scan": verification_table})
        print("Detailed verification results logged as W&B Table.")
    else:
        print("Skipping detailed table logging as verification DataFrame is empty.")

    # Missing-folder list as an artifact (only if the local file was written)
    if missing_list_saved:
        try:
            missing_artifact = wandb.Artifact("missing_mri_folders", type="analysis_output",
                                              description="List of MRI IDs whose scan folders were not found.")
            missing_artifact.add_file(str(missing_file_path))
            run.log_artifact(missing_artifact)
            print(f"List of {len(missing_folders_df)} missing folders logged as W&B Artifact.")
        except Exception as e:
            print(f"Warning: Could not log missing folders list to W&B. Error: {e}")


## Finalize Run

Complete the execution for this notebook and finish the associated Weights & Biases run.

In [None]:
# --- Finish W&B Run ---
# Closes the W&B run if one was started; otherwise only reports that none is active.
print("\n--- Exploration and Verification complete. Finishing W&B run. ---")
if not run:
    print("No active W&B run to finish.")
else:
    run.finish()
    print("W&B run finished.")

print("\nScript execution finished.")