# Notebook 08: OASIS-2 Model Analysis - Interpretability & Uncertainty

**Project Phase:** 1 (Model Analysis & Evaluation)
**Dataset:** OASIS-2 Longitudinal MRI & Clinical Data

**Purpose:**
This notebook is dedicated to in-depth analysis of trained models (both baseline LSTM and hybrid CNN+LSTM) for predicting CDR progression using the OASIS-2 dataset. Key objectives include:
1.  Loading pre-trained model checkpoints and their original training configurations from Weights & Biases (W&B) artifacts and runs.
2.  Instantiating the `OASISDataset` with the correct feature sets and preprocessor configurations. These come from the W&B config of the Notebook 04 run that fitted the preprocessors, which is reached through the training run's config.
3.  Performing **Uncertainty Quantification** using MC Dropout to understand model confidence.
4.  Conducting **Model Interpretability** analyses:
    * **Permutation Feature Importance** for tabular features.
    * **Integrated Gradients** (or other saliency methods via Captum) for the 3D CNN component of the hybrid model to identify salient input MRI regions.
    * **SHAP (SHapley Additive exPlanations)** to understand feature contributions, focusing on the baseline LSTM and the fusion stage of the hybrid model.
5.  Logging all analysis results, visualizations, and summary tables to a new W&B run for this notebook, enabling comparison and reporting in Notebook 09.
6.  Saving generated plots and analysis summaries locally in a run-specific output directory.
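
As a rough illustration of objective 3: MC Dropout keeps dropout active at prediction time, repeats the forward pass several times, and reads the spread of the outputs as an uncertainty estimate. The framework-free toy below is only a sketch under that description. `dropout_forward` and `mc_dropout_predict` are hypothetical names, and the toy linear model stands in for the real networks; the notebook itself relies on `get_mc_dropout_predictions` from `src/uncertainty_utils`, whose internals are not shown here.

```python
import random
import statistics


def dropout_forward(x, weights, p_drop, rng):
    """One stochastic pass of a toy linear model with inverted dropout on its inputs."""
    keep = 1.0 - p_drop
    total = 0.0
    for xi, wi in zip(x, weights):
        if rng.random() < keep:           # unit survives this pass
            total += wi * xi / keep       # rescale so the expected output is unchanged
    return total


def mc_dropout_predict(x, weights, p_drop=0.2, n_samples=30, seed=0):
    """Run n_samples stochastic passes and summarise them as (mean, std)."""
    rng = random.Random(seed)
    draws = [dropout_forward(x, weights, p_drop, rng) for _ in range(n_samples)]
    return statistics.fmean(draws), statistics.stdev(draws)


mean_pred, std_pred = mc_dropout_predict([1.0, 2.0], [0.5, 0.25], p_drop=0.5)
```

With `p_drop=0.0` every pass is identical, so the standard deviation collapses to zero. In PyTorch the same behaviour is usually obtained by leaving the dropout layers in training mode while the rest of the model is in evaluation mode.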

**Workflow:**
1.  **Setup:** Import libraries, configure `sys.path`, load `config.json`. Define analysis parameters.
2.  **Path Resolution:** Use `get_dataset_paths` to resolve paths for test data, training data (for SHAP background), and preprocessor files (though preprocessor paths will be confirmed/overridden by the model's training config).
3.  **W&B Initialization:** Start a new W&B run for this analysis notebook using `initialize_wandb_run`. Define a run-specific output directory.
4.  **Main Analysis Loop (Iterate through models to analyze):**
    * **Dynamic Model Selection:** Fetch W&B run paths for models to analyze (e.g., based on tags or job types, or from a predefined list).
    * **Load Model & Training Config:** Use `load_model_from_wandb_artifact` to load the model and its original training `config`.
    * **Prepare `config_for_this_model_dataset`:** Combine the loaded model's training config (which includes feature lists from NB03 and preprocessor choices from NB04) with `base_config` to create the definitive configuration for `OASISDataset`.
    * **Instantiate `analysis_dataset` and `analysis_loader`:** Use the test data.
    * **Run Selected Analyses (based on flags):**
        * MC Dropout Analysis.
        * Permutation Feature Importance.
        * CNN Interpretability (Integrated Gradients) - for hybrid models.
        * SHAP Analysis (Baseline LSTM, Hybrid Fusion Stage).
    * Store results for each model.
5.  **Summarize & Log Overall Results:** Create a summary table of all analyses for all models and log to W&B.
6.  **Finalize W&B Run.**
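
The configuration step in the workflow above can be sketched as a small merge function. It assumes, as the later cells of this notebook do, that the training run logged its dataset settings under `dataset_config_used/...` keys. `build_dataset_config` is a hypothetical helper written for illustration, not part of `src/`.

```python
def build_dataset_config(train_cfg, base_cfg):
    """Merge a training run's logged dataset settings with project-level defaults.

    Returns the merged config plus the names of the sections that fell back
    to a default because the training config had no (or an empty) entry.
    """
    prefix = "dataset_config_used/"
    # (section in merged config, key suffix in the training config, default value)
    sections = [
        ("features", "features", {"time_varying": [], "static": []}),
        ("preprocess", "preprocess", {"imputation_cols": [], "scaling_cols": []}),
        ("cnn_model_params", "cnn_model_params", base_cfg.get("cnn_model_params", {})),
        ("preprocessing_config", "preprocessing_config_mri",
         base_cfg.get("preprocessing_config", {})),
    ]
    merged, fallbacks = {}, []
    for name, suffix, default in sections:
        value = train_cfg.get(prefix + suffix)
        if not value:                 # missing or empty section
            value = default
            fallbacks.append(name)
        merged[name] = value
    return merged, fallbacks
```

Returning the list of fallback sections makes it easy to warn, or to abort, when `features` or `preprocess` could not be recovered from the training run.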

**Input:**
* `config.json`: Main project configuration file.
* **W&B Run Paths/IDs of trained models (from Notebook 06 & 07)**: These runs contain the model artifacts and their training configurations (which in turn reference the NB04 run's config for features/preprocessing).
* `cohort_test.parquet` (for main analysis) and `cohort_train.parquet` (for SHAP background) - outputs from Notebook 03.
* Paths to preprocessor `.joblib` files - these will be confirmed/derived from the loaded model's `original_training_config` which should point back to NB04's artifacts/config.
* `src/` utility modules.

**Output:**
* **Local Files (in run-specific output directory for this NB08 run, e.g., `notebooks/outputs/08_Model_Analysis_OASIS2/<run_name>/`):**
    * Plots from MC Dropout, Permutation Importance, Integrated Gradients, SHAP.
    * Potentially, tables of results.
* **W&B Run (for this Notebook 08 execution):**
    * Logged analysis configuration.
    * For each analyzed model:
        * MC Dropout metrics and plots.
        * Permutation Importance scores and plot.
        * Integrated Gradients visualizations (for hybrid).
        * SHAP summary plots.
    * A final summary table comparing key analysis results across models.

In [None]:
# In: notebooks/08_Model_Analysis_Interpretability_Uncertainty_OASIS2.ipynb
# Purpose: Load trained models, perform uncertainty quantification (MC Dropout),
#          and various interpretability analyses (Permutation Importance, Integrated Gradients, SHAP).

In [None]:
# --- Standard Libraries & Imports ---
import wandb
import torch
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import sys
import os
# joblib might be used by OASISDataset if it were to load preprocessors directly,
# but here paths passed to it are from downloaded artifacts or resolved.
from tqdm.auto import tqdm # tqdm is used by utility functions
import time
import shap
from torch.utils.data import DataLoader
import torch.nn as nn # For criterion definition if needed
from sklearn.metrics import r2_score

In [None]:
# --- Add src directory to Python path ---

# Initialize
PROJECT_ROOT = None 
try:
    current_notebook_path = Path.cwd() 
    PROJECT_ROOT = current_notebook_path.parent 
    if not (PROJECT_ROOT / "src").exists(): 
        PROJECT_ROOT = current_notebook_path
    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
    print(f"PROJECT_ROOT: {PROJECT_ROOT}")
    if not (PROJECT_ROOT / "src").exists():
        raise FileNotFoundError("Could not reliably find 'src' directory from PROJECT_ROOT.")
except Exception as e_path:
    print(f"Error setting up PROJECT_ROOT and sys.path: {e_path}")
    # exit() 

# --- Import Custom Modules & Functions ---
try:
    from src.datasets import OASISDataset, pad_collate_fn
    from src.models import BaselineLSTMRegressor, ModularLateFusionLSTM 
    from src.wandb_utils import initialize_wandb_run, load_model_from_wandb_artifact
    from src.plotting_utils import finalize_plot
    from src.paths_utils import get_dataset_paths, get_notebook_run_output_dir
    from src.uncertainty_utils import get_mc_dropout_predictions, calculate_uncertainty_metrics 
    from src.interpretability_utils import (
         calculate_permutation_importance,
         generate_integrated_gradients_cnn,
         explain_lstm_with_shap,
         explain_hybrid_fusion_with_shap
     )
    from src.evaluation_utils import evaluate_model 
    print("Successfully imported all required custom modules and classes.")
except ModuleNotFoundError as e_mod:
    print(f"ModuleNotFoundError: {e_mod}. Ensure all src/ files exist and sys.path is correct.")
    # exit()
except ImportError as e_imp:
    print(f"ImportError: {e_imp}. Check for circular dependencies or errors within src modules.")
    # exit()
except Exception as e_gen_imp:
    print(f"An unexpected error occurred during custom module imports: {e_gen_imp}")
    # exit()


# --- Load Main Project Configuration ---
print("\n--- Loading Main Project Configuration ---")
base_config = {}
WANDB_ENTITY = None
WANDB_PROJECT = None
try:
    if PROJECT_ROOT is None: raise ValueError("PROJECT_ROOT not defined.")
    CONFIG_PATH_MAIN = PROJECT_ROOT / 'config.json'
    with open(CONFIG_PATH_MAIN, 'r', encoding='utf-8') as f:
        base_config = json.load(f)
    WANDB_ENTITY = base_config.get('wandb', {}).get('entity')
    WANDB_PROJECT = base_config.get('wandb', {}).get('project_name')
    if not WANDB_ENTITY or not WANDB_PROJECT:
        raise KeyError("'wandb:entity' or 'wandb:project_name' not found in config.json.")
    print(f"Main project config loaded. W&B Entity: {WANDB_ENTITY}, Project: {WANDB_PROJECT}")
except Exception as e_cfg:
    print(f"CRITICAL ERROR loading main config.json or W&B details: {e_cfg}")
    # exit()


# --- Device Setup ---
print("\n--- Setting up PyTorch Device ---")
if torch.backends.mps.is_available() and torch.backends.mps.is_built():
    device = torch.device("mps")
    print("Using MPS (Apple Silicon GPU).")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print(f"CUDA is available. Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    print("CUDA and MPS not available. Using CPU.")

## 2. Configuration for Notebook 08 Analysis

This section defines the parameters and settings specific to this analysis notebook:
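* **Models to Analyze (`MODELS_TO_ANALYZE_FETCH`):** A dictionary mapping a model nickname to the W&B Run Path of the trained model to load and analyze. It is populated dynamically by querying W&B for the most recent finished run matching each model's job type (and optional tags). A manually defined `MODELS_TO_ANALYZE` dictionary, if present, serves as a fallback.
* **MC Dropout Parameters:** `N_MC_SAMPLES` (number of Monte Carlo forward passes).
* **Detailed Analysis Sample Size:** `N_SAMPLES_FOR_DETAILED_ANALYSIS` (number of test samples used for per-sample analyses, e.g., SHAP or IG visualizations).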
* **Analysis Hyperparameters (`HP_analysis`):** A dictionary holding other analysis-specific settings, including:
    * `analysis_batch_size`: Batch size for the `analysis_loader`.
    * Boolean flags to control which analyses are run (e.g., `RUN_MC_DROPOUT`, `RUN_PERMUTATION_IMPORTANCE`, `RUN_INTEGRATED_GRADIENTS`, `RUN_SHAP_BASELINE`, `RUN_SHAP_HYBRID_FUSION`).
    * Parameters for specific interpretability methods (e.g., `shap_num_background_samples`, `ig_n_steps`).
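
The model selection described above amounts to building one MongoDB-style filter per model and keeping the newest matching run. A minimal sketch of the filter construction follows. `build_run_filters` is a hypothetical helper, and the filter key names simply mirror the ones used in this notebook's configuration cell, so they should be checked against the W&B public API before relying on them.

```python
def build_run_filters(criteria):
    """Build a W&B-style run filter from a criteria dict.

    Only finished runs are considered; `job_type` and `tags` are optional.
    A single tag string is wrapped in a list so it can be used with "$in".
    """
    clauses = [{"state": "finished"}]
    job_type = criteria.get("job_type")
    if job_type:
        clauses.append({"job_type": job_type})
    tags = criteria.get("tags")
    if tags:
        if isinstance(tags, str):
            tags = [tags]
        clauses.append({"tags": {"$in": list(tags)}})
    return {"$and": clauses}


# Example criteria in the same shape as `model_run_criteria` in this notebook.
example_filters = build_run_filters(
    {"job_type": "Training-06-BaselineLSTM-oasis2", "tags": "best_candidate_baseline"}
)
```

Keeping the filter construction in a pure function makes it possible to check the query shape without contacting the W&B API.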

In [None]:
# --- Configuration for Notebook 08 Analysis ---
print("\n--- Configuring Notebook 08 Analysis Parameters ---")

DATASET_IDENTIFIER = "oasis2" # Ensure this is consistent
NOTEBOOK_MODULE_NAME = "08_Model_Analysis_Interpretability_Uncertainty"
# Key from config.json locators for this notebook's output subfolder
NB08_OUTPUT_LOCATOR_KEY = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})\
                                     .get("analysis_subdir", "08_Model_Analysis_Default_Outputs")


# --- Dynamic Fetching of Models to Analyze (Replaces hardcoded MODELS_TO_ANALYZE_FETCH) ---
print("\n--- Dynamically Fetching Trained Model Runs from W&B for Analysis ---")
MODELS_TO_ANALYZE_FETCH = {}
try:
    if WANDB_ENTITY and WANDB_PROJECT: # Ensure these are defined from base_config
        wandb_api = wandb.Api(timeout=29) # Set a timeout for API calls
        
        # Define criteria for runs to analyze
        # This uses the job_type defined by initialize_wandb_run in NB06 and NB07
        model_run_criteria = {
            "BaselineLSTM_OASIS2_Best": {
                "job_type": f"Training-06-BaselineLSTM-{DATASET_IDENTIFIER}", 
                # Add tags here if you use them, e.g., "tags": "best_candidate_baseline"
            },
            "HybridCNNLSTM_OASIS2_Best": {
                "job_type": f"Training-07-HybridCNNLSTM-{DATASET_IDENTIFIER}", 
                # Add tags here if you use them, e.g., "tags": "best_candidate_hybrid"
            }
            # Add more criteria for other models if needed
        }

        for nickname, criteria in model_run_criteria.items():
            filters = {"$and": [{"state": "finished"}]} # Start with finished runs
            if "job_type" in criteria and criteria["job_type"]: # Check if job_type is not None or empty
                filters["$and"].append({"job_type": criteria["job_type"]})
            if "tags" in criteria and criteria["tags"]: # Check if tags are not None or empty
                # Ensure tags are passed as a list if it's a single tag string for "$in"
                tags_to_filter = criteria["tags"]
                if isinstance(tags_to_filter, str):
                    tags_to_filter = [tags_to_filter]
                filters["$and"].append({"tags": {"$in": tags_to_filter}})
            
            print(f"  Querying W&B for '{nickname}' with Job Type: '{criteria.get('job_type', 'Any')}' and Tags: '{criteria.get('tags', 'Any')}'")
            
            selected_runs = wandb_api.runs(
                path=f"{WANDB_ENTITY}/{WANDB_PROJECT}",
                filters=filters,
                order="-created_at" # Get the most recent run satisfying criteria
            )
            
            if selected_runs:
                # Take the first (most recent) run. `Run.path` is a list
                # [entity, project, run_id], so join it into "entity/project/run_id".
                latest_run_path = "/".join(selected_runs[0].path)
                MODELS_TO_ANALYZE_FETCH[nickname] = latest_run_path
                print(f"    Found for '{nickname}': {selected_runs[0].name} (Run Path: {latest_run_path})")
            else:
                print(f"    WARNING: No runs found via W&B API for '{nickname}' with specified criteria.")
                # Fallback to your manually defined MODELS_TO_ANALYZE if dynamic fetch fails
                if 'MODELS_TO_ANALYZE' in locals() and isinstance(MODELS_TO_ANALYZE, dict):
                    manual_fallback_path = MODELS_TO_ANALYZE.get(nickname)
                    if manual_fallback_path and "RUN_ID_" not in manual_fallback_path.upper(): # Check if not placeholder
                        MODELS_TO_ANALYZE_FETCH[nickname] = manual_fallback_path
                        print(f"      Using manually specified fallback path for '{nickname}': {manual_fallback_path}")
                    else:
                        print(f"      No valid manual fallback path found for '{nickname}' in MODELS_TO_ANALYZE.")
        
        if not MODELS_TO_ANALYZE_FETCH:
            print("WARNING: No models were successfully fetched dynamically or found as valid manual fallbacks. "
                  "The main analysis loop will be empty.")
        else:
            print("\nFinal list of models to be analyzed in this session:")
            for name, path in MODELS_TO_ANALYZE_FETCH.items(): 
                print(f"  - {name}: {path}")

    else: # WANDB_ENTITY or WANDB_PROJECT were not defined
        print("WARNING: WANDB_ENTITY or WANDB_PROJECT not defined from base_config. Cannot dynamically fetch runs.")
        # Fallback to using the manually defined MODELS_TO_ANALYZE dictionary
        if 'MODELS_TO_ANALYZE' in locals() and isinstance(MODELS_TO_ANALYZE, dict):
            MODELS_TO_ANALYZE_FETCH = {
                name: path for name, path in MODELS_TO_ANALYZE.items() 
                if "RUN_ID_" not in path.upper() # Ensure it's not a placeholder
            }
            if MODELS_TO_ANALYZE_FETCH:
                print(f"Using manually defined MODELS_TO_ANALYZE entries: {MODELS_TO_ANALYZE_FETCH}")
            else:
                print("Manually defined MODELS_TO_ANALYZE is also empty or contains only placeholders.")
        else:
            print("Manually defined MODELS_TO_ANALYZE dictionary not found.")

except Exception as e_fetch_runs:
    print(f"An error occurred during dynamic fetching of W&B runs: {e_fetch_runs}")
    print("Falling back to manually defined MODELS_TO_ANALYZE if available; otherwise MODELS_TO_ANALYZE_FETCH stays empty.")
    # Ensure MODELS_TO_ANALYZE_FETCH exists, even if empty, to prevent later NameErrors
    if 'MODELS_TO_ANALYZE_FETCH' not in locals():
        MODELS_TO_ANALYZE_FETCH = {}
    if not MODELS_TO_ANALYZE_FETCH and 'MODELS_TO_ANALYZE' in locals() and isinstance(MODELS_TO_ANALYZE, dict):
        MODELS_TO_ANALYZE_FETCH = {
            name: path for name, path in MODELS_TO_ANALYZE.items() 
            if "RUN_ID_" not in path.upper()
        }
        if MODELS_TO_ANALYZE_FETCH:
            print(f"Using manually defined MODELS_TO_ANALYZE entries due to error: {MODELS_TO_ANALYZE_FETCH}")


# --- Analysis-Specific Hyperparameters & Control Flags ---
N_MC_SAMPLES = 30  
N_SAMPLES_FOR_DETAILED_ANALYSIS = 5 # Reduced for quicker local analysis runs

HP_analysis = {
    'analysis_batch_size': 4,        # Batch size for DataLoaders created in this notebook
    'num_detailed_samples': N_SAMPLES_FOR_DETAILED_ANALYSIS,

    # --- Flags to control which analyses are run ---
    'RUN_MC_DROPOUT': True,
    'RUN_PERMUTATION_IMPORTANCE': True,
    'RUN_INTEGRATED_GRADIENTS': True, # For CNN part of hybrid models
    'RUN_SHAP_BASELINE': True,        # SHAP for baseline LSTM
    'RUN_SHAP_HYBRID_FUSION': True,   # SHAP for hybrid model's fusion stage

    # --- Parameters for specific interpretability methods ---
    'pfi_num_permutations': 5, # For Permutation Feature Importance
    'ig_n_steps': 10,          # n_steps for Integrated Gradients (keep low for local runs)
    'ig_captum_internal_batch_size': None, # Optional for Captum's IG

    'shap_num_background_samples': 20, # Max samples for SHAP background dataset
    'shap_background_batch_size': 10,  # Batch size for creating SHAP background tensor
    'shap_explain_batch_size': N_SAMPLES_FOR_DETAILED_ANALYSIS, # How many test instances to explain with SHAP
    'shap_kernel_nsamples': 50, # nsamples for KernelExplainer (for hybrid fusion SHAP)
    'shap_kmeans_k': 5         # k for shap.kmeans for KernelExplainer background
}
print("\nAnalysis hyperparameters (HP_analysis) and control flags defined.")
# from pprint import pprint; print("HP_analysis:"); pprint(HP_analysis) # Optional debug

## 3. Define Data Paths for Analysis

Resolve paths to the necessary data files using the `get_dataset_paths` utility. This notebook primarily requires:
* **Test Data (`cohort_test.parquet`):** For evaluating model performance and running most interpretability analyses.
* **Training Data (`cohort_train.parquet`):** Specifically needed as background data for SHAP explanations.
* **Preprocessor Paths (`scaler.joblib`, `imputer.joblib`):** Resolved here from the *current main `config.json`* as reference paths. The feature and preprocessing settings passed to `OASISDataset` come from each model's *training run configuration* (NB04 via NB06/07) through `config_for_this_model_dataset`, which is prepared later. The `.joblib` paths resolved here are the ones used to instantiate `OASISDataset`, including for SHAP background data, on the assumption that all analyzed models share the same fitted preprocessors.
* **MRI Data Directory:** For hybrid models.
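
The resolve-then-verify pattern used for these paths can be sketched as follows. This is an assumption-laden illustration: `resolve_required_paths` is a hypothetical helper, the key names are the ones this notebook reads from `get_dataset_paths`, and the actual return type of that utility is not shown here, so every value is wrapped in `Path` defensively.

```python
from pathlib import Path


def resolve_required_paths(paths, required, fallback_keys=None):
    """Return {key: Path} for every required key, or raise listing what is missing.

    `fallback_keys` maps a required key to an alternative key to try when the
    primary one is absent or empty.
    """
    fallback_keys = fallback_keys or {}
    resolved, missing = {}, []
    for key in required:
        value = paths.get(key)
        if not value and key in fallback_keys:
            value = paths.get(fallback_keys[key])
        if value:
            resolved[key] = Path(value)   # accept both str and Path inputs
        else:
            missing.append(key)
    if missing:
        raise ValueError(f"Unresolved paths: {', '.join(missing)}")
    return resolved
```

For example, `train_data_for_analysis_bg` could fall back to `train_data_parquet`, as the path cell in this notebook does. Existence checks such as `is_file()` and `is_dir()` can then run on the returned `Path` objects.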

In [None]:
# --- Define Data Paths for Analysis ---
print("\n--- Defining Data Paths for Analysis using utility ---")

# DATASET_IDENTIFIER should be "oasis2" (defined in setup)
# PROJECT_ROOT and base_config should be available from setup

TEST_DATA_PATH_NB08 = None
TRAIN_DATA_PATH_FOR_SHAP_BG_NB08 = None
# SCALER_PATH_NB08 and IMPUTER_PATH_NB08 will be derived from the specific model's training config later
# However, we can resolve the *expected* paths based on current config for direct use if needed (e.g. SHAP bg dataset)
EXPECTED_SCALER_PATH_NB08 = None
EXPECTED_IMPUTER_PATH_NB08 = None
MRI_DATA_DIR_NB08 = None

try:
    # Get paths for the "analysis" stage, which includes test data and also train data (for SHAP bg)
    # get_dataset_paths with stage="analysis" should return 'test_data_parquet' and 'train_data_for_analysis_bg'
    # It also returns scaler_path and imputer_path based on current config.json, which can serve as a reference
    # or for instantiating OASISDataset directly for SHAP background if needed.
    analysis_stage_paths = get_dataset_paths(PROJECT_ROOT, base_config, DATASET_IDENTIFIER, stage="analysis")
    
    TEST_DATA_PATH_NB08 = analysis_stage_paths.get('test_data_parquet')
    TRAIN_DATA_PATH_FOR_SHAP_BG_NB08 = analysis_stage_paths.get('train_data_for_analysis_bg') # This key is from paths_utils
    if TRAIN_DATA_PATH_FOR_SHAP_BG_NB08 is None: # Fallback if 'train_data_for_analysis_bg' isn't a key
        TRAIN_DATA_PATH_FOR_SHAP_BG_NB08 = analysis_stage_paths.get('train_data_parquet')

    EXPECTED_SCALER_PATH_NB08 = analysis_stage_paths.get('scaler_path') 
    EXPECTED_IMPUTER_PATH_NB08 = analysis_stage_paths.get('imputer_path')
    MRI_DATA_DIR_NB08 = analysis_stage_paths.get('mri_data_dir')

    if not all([TEST_DATA_PATH_NB08, TRAIN_DATA_PATH_FOR_SHAP_BG_NB08, 
                EXPECTED_SCALER_PATH_NB08, EXPECTED_IMPUTER_PATH_NB08, MRI_DATA_DIR_NB08]):
        raise ValueError("One or more essential paths for NB08 analysis could not be resolved from get_dataset_paths.")

    print(f"  Test Data Path (for analysis_dataset): {TEST_DATA_PATH_NB08}")
    print(f"  Train Data Path (for SHAP background): {TRAIN_DATA_PATH_FOR_SHAP_BG_NB08}")
    print(f"  Expected Scaler Path (reference): {EXPECTED_SCALER_PATH_NB08}")
    print(f"  Expected Imputer Path (reference): {EXPECTED_IMPUTER_PATH_NB08}")
    print(f"  MRI Data Directory: {MRI_DATA_DIR_NB08}")

    # Verify existence of data files needed
    if not TEST_DATA_PATH_NB08.is_file(): raise FileNotFoundError(f"Test data parquet not found: {TEST_DATA_PATH_NB08}")
    if not TRAIN_DATA_PATH_FOR_SHAP_BG_NB08.is_file(): raise FileNotFoundError(f"Train data for SHAP background not found: {TRAIN_DATA_PATH_FOR_SHAP_BG_NB08}")
    # Preprocessor files will be effectively checked when OASISDataset tries to load them via paths from config_for_this_model_dataset
    if not MRI_DATA_DIR_NB08.is_dir(): raise FileNotFoundError(f"MRI Data Directory not found: {MRI_DATA_DIR_NB08}")
    print("Key data paths for NB08 resolved and inputs verified.")

except (KeyError, ValueError, FileNotFoundError) as e_paths_nb08:
    print(f"CRITICAL ERROR during path setup for NB08: {e_paths_nb08}")
    # exit()
except Exception as e_general_nb08_paths:
    print(f"UNEXPECTED ERROR during path setup for NB08 ({type(e_general_nb08_paths).__name__}): {e_general_nb08_paths}")
    # exit()

## 4. Initialize Weights & Biases Run for Notebook 08 Analysis

A new W&B run is initiated for this comprehensive model analysis notebook (NB08). This run will track:
* The configuration parameters specific to this analysis execution (e.g., list of models analyzed, MC Dropout settings, SHAP parameters).
* All generated plots from uncertainty and interpretability analyses.
* Summary tables or metrics comparing different models or analysis results.
A run-specific local output directory is also created for saving plots and other files generated by this notebook.

In [None]:
# --- Initialize W&B Run for this Analysis Notebook (NB08) ---
print("\n--- Initializing Weights & Biases Run for Notebook 08 Analysis ---")

# Configuration specific to this NB08 run
nb08_run_config_log = {
    "notebook_name_code": f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}",
    "dataset_source": DATASET_IDENTIFIER,
    # MODELS_TO_ANALYZE_FETCH was populated by the dynamic fetch in the configuration cell above
    "models_to_be_analyzed_sources": MODELS_TO_ANALYZE_FETCH if 'MODELS_TO_ANALYZE_FETCH' in locals() and MODELS_TO_ANALYZE_FETCH else "Manual List or Fetch Failed",
    "num_mc_samples_configured": N_MC_SAMPLES,
    "num_detailed_analysis_samples_target": N_SAMPLES_FOR_DETAILED_ANALYSIS,
    "execution_timestamp": time.strftime("%Y-%m-%d %H:%M:%S")
}
# Merge HP_analysis which contains the boolean flags for analyses and SHAP/IG params
if 'HP_analysis' in locals() and isinstance(HP_analysis, dict):
    nb08_run_config_log.update(HP_analysis) 

nb_prefix_nb08 = NOTEBOOK_MODULE_NAME.split('_')[0] if '_' in NOTEBOOK_MODULE_NAME else "NB"
# Using DATASET_IDENTIFIER in job_specific_type
job_specific_type_nb08 = f"{nb_prefix_nb08}-ModelAnalysis-{DATASET_IDENTIFIER}" 
custom_elements_name_nb08 = [nb_prefix_nb08, DATASET_IDENTIFIER.upper(), "FullAnalysis"]

run_nb08 = initialize_wandb_run( # Assign to run_nb08
    base_project_config=base_config,
    job_group="Analysis",
    job_specific_type=job_specific_type_nb08,
    run_specific_config=nb08_run_config_log,
    custom_run_name_elements=custom_elements_name_nb08,
    notes=f"Comprehensive Analysis (Uncertainty, Interpretability) for {DATASET_IDENTIFIER.upper()} models."
)

RUN_OUTPUT_DIR_NB08 = None # Initialize for clarity
if run_nb08:
    print(f"W&B run for NB08 '{run_nb08.name}' (Job Type: '{run_nb08.job_type}') initialized. View at: {run_nb08.url}")
    # Use the locator key defined in config.json for NB08's output subfolder
    # NB08_OUTPUT_LOCATOR_KEY was defined in Cell 3 of this notebook.
    RUN_OUTPUT_DIR_NB08 = get_notebook_run_output_dir(
        PROJECT_ROOT, base_config, NB08_OUTPUT_LOCATOR_KEY, run_nb08, DATASET_IDENTIFIER
    )
    print(f"Outputs for this NB08 run will be saved locally to: {RUN_OUTPUT_DIR_NB08}")
    # Log the actual output directory to W&B config
    run_nb08.config.update({"run_outputs/local_analysis_dir": str(RUN_OUTPUT_DIR_NB08)}, allow_val_change=True)
else:
    print("W&B run initialization failed for NB08. Local outputs may go to a default fallback path.")
    # Define a fallback RUN_OUTPUT_DIR_NB08 for local-only execution without W&B
    # NB08_OUTPUT_LOCATOR_KEY might not be defined if base_config load failed, so use NOTEBOOK_MODULE_NAME
    fallback_locator = NB08_OUTPUT_LOCATOR_KEY if 'NB08_OUTPUT_LOCATOR_KEY' in locals() and NB08_OUTPUT_LOCATOR_KEY else \
                       f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}_outputs"
    RUN_OUTPUT_DIR_NB08 = get_notebook_run_output_dir(
        PROJECT_ROOT, base_config if base_config else {}, # Pass empty dict if base_config failed
        fallback_locator, 
        None, DATASET_IDENTIFIER
    )
    print(f"Using fallback local output directory for NB08: {RUN_OUTPUT_DIR_NB08}")

# Ensure RUN_OUTPUT_DIR_NB08 is a Path object for subsequent use
if not isinstance(RUN_OUTPUT_DIR_NB08, Path):
    RUN_OUTPUT_DIR_NB08 = Path(RUN_OUTPUT_DIR_NB08)

## 5. Main Analysis Loop: Iterate Through Models

This section iterates through the models specified in `MODELS_TO_ANALYZE_FETCH`. For each model:
1.  **Load Model and Original Training Configuration:** The `load_model_from_wandb_artifact` utility downloads the model artifact (e.g., the 'best' version) from the specified W&B training run and instantiates the model. The original W&B configuration from that training run (`original_model_train_config_dict`) is fetched as well.
2.  **Prepare `config_for_this_model_dataset`:** A definitive configuration dictionary is prepared for instantiating `OASISDataset`. This uses the `features`, `preprocess`, `cnn_model_params`, and `preprocessing_config` (for MRI suffix) that were logged to the W&B config of the *original training run* of the model (which themselves were sourced from the relevant Notebook 04 run). This ensures that `OASISDataset` processes data for analysis in a way that is perfectly consistent with how data was processed during that model's training.
3.  **Instantiate `analysis_dataset` and `analysis_loader`:** An `OASISDataset` instance is created from `TEST_DATA_PATH_NB08` and `config_for_this_model_dataset`. Ideally the preprocessor paths would also be derived from `original_model_train_config_dict`, provided the training run logged the specific preprocessor artifact versions it used. For simplicity in this phase, all analyzed models are assumed to share one set of preprocessors, whose paths (`EXPECTED_SCALER_PATH_NB08`, `EXPECTED_IMPUTER_PATH_NB08`) were resolved earlier. A `DataLoader` then provides batches for analysis.
4.  **Perform Selected Analyses:** Based on the boolean flags set in `HP_analysis` (e.g., `RUN_MC_DROPOUT`), the following analyses are performed:
    * **MC Dropout:** Quantifies predictive uncertainty.
    * **Permutation Feature Importance:** Assesses tabular feature importance.
    * **Integrated Gradients:** (For hybrid models) Visualizes CNN voxel attributions.
    * **SHAP Analysis:** Provides SHAP values for baseline LSTM (on tabular features) and for the fusion stage of the hybrid model.
5.  **Log Results:** All generated metrics, tables, and plots are logged to the current Notebook 08 W&B run (`run_nb08`) and saved locally to `RUN_OUTPUT_DIR_NB08`.
6.  Aggregate summary results for all analyzed models.
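
Of the analyses listed above, permutation feature importance is the easiest to illustrate in isolation: shuffle one feature column, re-evaluate the model, and record how much the error grows. The sketch below is a generic, framework-free version working on flat rows. `permutation_importance` and `mse` are hypothetical names, and the notebook's `calculate_permutation_importance` (whose internals are not shown here) may handle padded visit sequences differently.

```python
import random


def mse(y_true, y_pred):
    """Mean squared error between two equally long sequences."""
    return sum((t - p) ** 2 for t, p in zip(y_true, y_pred)) / len(y_true)


def permutation_importance(predict, rows, y_true, n_repeats=5, seed=0):
    """Average increase in MSE when each feature column is shuffled in turn."""
    rng = random.Random(seed)
    baseline = mse(y_true, predict(rows))
    importances = []
    for j in range(len(rows[0])):
        increases = []
        for _ in range(n_repeats):
            column = [row[j] for row in rows]
            rng.shuffle(column)  # break the link between feature j and the target
            permuted = [row[:j] + [v] + row[j + 1:] for row, v in zip(rows, column)]
            increases.append(mse(y_true, predict(permuted)) - baseline)
        importances.append(sum(increases) / n_repeats)
    return importances


# Toy check: the "model" only looks at feature 0, so feature 1 cannot matter.
rows = [[float(i), float(i % 3)] for i in range(20)]
targets = [2.0 * row[0] for row in rows]
scores = permutation_importance(lambda batch: [2.0 * r[0] for r in batch], rows, targets)
```

Because the toy model ignores the second feature, shuffling it leaves the predictions unchanged and its score is exactly zero, while the first feature receives a positive score. `n_repeats` plays the same role as `pfi_num_permutations` in `HP_analysis`.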

In [None]:
# --- Main Analysis Loop ---
print("\n--- Starting Main Analysis Loop for Selected Models ---")

# Ensure necessary global variables for the loop are defined
required_globals_for_loop = [
    'TEST_DATA_PATH_NB08', 'TRAIN_DATA_PATH_FOR_SHAP_BG_NB08', 
    'EXPECTED_SCALER_PATH_NB08', 'EXPECTED_IMPUTER_PATH_NB08', 'MRI_DATA_DIR_NB08',
    'RUN_OUTPUT_DIR_NB08', 'HP_analysis', 'MODELS_TO_ANALYZE_FETCH', 
    # 'run_nb08' is deliberately not required: it is None in the local-only fallback (W&B init failed).
    'base_config', 'device', 'N_MC_SAMPLES', 'N_SAMPLES_FOR_DETAILED_ANALYSIS',
    'DATASET_IDENTIFIER' # For consistent W&B logging keys/artifact names
]
if not all(var_name in locals() and locals()[var_name] is not None for var_name in required_globals_for_loop):
    print("CRITICAL ERROR: One or more global setup variables for the analysis loop are not defined or are None. "
          "Please check preceding cells (Setup, Config for NB08, Path Definitions, W&B Init for NB08).")
    # exit() # Or raise error
else:
    print(f"Analysis outputs will be saved to: {RUN_OUTPUT_DIR_NB08}")
    # Ensure output directory exists (it should have been created by get_notebook_run_output_dir)
    if not RUN_OUTPUT_DIR_NB08.exists(): RUN_OUTPUT_DIR_NB08.mkdir(parents=True, exist_ok=True)


all_models_analysis_results = [] # To store a summary of results for each model

for model_nickname, model_training_run_path in MODELS_TO_ANALYZE_FETCH.items():
    print(f"\n\n{'='*60}")
    print(f"ANALYZING MODEL: {model_nickname} (Source W&B Run: {model_training_run_path})")
    print(f"{'='*60}")
    
    current_model_results = {"model_nickname": model_nickname, "source_wandb_run_path": model_training_run_path}

    # --- 1. Load Model and its Original Training Configuration ---
    print(f"  Loading model and original training config for {model_nickname}...")
    loaded_model, original_model_train_config_dict, is_model_hybrid = load_model_from_wandb_artifact(
        run_path=model_training_run_path, 
        base_config_dict=base_config, 
        device_to_load=device
    )

    if loaded_model is None:
        print(f"  ERROR: Failed to load model for '{model_nickname}'. Skipping this model.")
        current_model_results["status_load_model"] = "Failed"
        all_models_analysis_results.append(current_model_results)
        if run_nb08: run_nb08.log({f"{model_nickname}/status/model_load": "Failed"})
        continue 
    
    current_model_results["model_type_is_hybrid"] = is_model_hybrid
    print(f"  Model '{model_nickname}' loaded successfully. Type: {'Hybrid' if is_model_hybrid else 'Baseline'}.")

    # --- 2. Prepare `config_for_this_model_dataset` ---
    # This uses the *actual configuration* that OASISDataset used during the model's training,
    # fetched from the training run's W&B config (logged by NB06/NB07, sourced from NB04).
    config_for_this_model_dataset = {
        'features': original_model_train_config_dict.get('dataset_config_used/features', {}),
        'preprocess': original_model_train_config_dict.get('dataset_config_used/preprocess', {}),
        'cnn_model_params': original_model_train_config_dict.get('dataset_config_used/cnn_model_params', 
                                                                base_config.get('cnn_model_params', {})),
        'preprocessing_config': original_model_train_config_dict.get('dataset_config_used/preprocessing_config_mri', 
                                                                     base_config.get('preprocessing_config', {}))
    }
    if not config_for_this_model_dataset.get('features') or not config_for_this_model_dataset.get('preprocess'):
        print(f"  WARNING: Critical 'features' or 'preprocess' sections missing in config_for_this_model_dataset "
              f"derived from run {model_training_run_path}. Analysis might use fallbacks or fail.")
        # Provide minimal fallbacks. Both keys already exist in the dict (possibly as empty
        # dicts), so assign explicitly: dict.setdefault() would leave the empty values in place.
        # 'cnn_model_params' and 'preprocessing_config' already fall back to base_config above.
        if not config_for_this_model_dataset.get('features'):
            config_for_this_model_dataset['features'] = {'time_varying': [], 'static': []}
        if not config_for_this_model_dataset.get('preprocess'):
            config_for_this_model_dataset['preprocess'] = {'imputation_cols': [], 'scaling_cols': []}


    # --- 3. Instantiate `analysis_dataset` (Test Set) and `analysis_loader` ---
    # For analysis, use preprocessors consistent with what this model was trained on.
    # EXPECTED_SCALER_PATH_NB08 and EXPECTED_IMPUTER_PATH_NB08 (defined globally in NB08)
    # are assumed to be the paths to the .joblib files from the "official" NB04 run
    # whose config was used by this model's training run (NB06/NB07).
    # A more advanced MLOps setup would involve NB06/07 logging the *exact artifact versions*
    # of preprocessors they used, and NB08 fetching those specific versions.
    # For now, this assumes all analyzed models used preprocessors at EXPECTED_..._PATH.
    print(f"  Instantiating analysis_dataset (test set) for {model_nickname}...")
    try:
        analysis_dataset = OASISDataset(
            data_parquet_path=TEST_DATA_PATH_NB08,    
            scaler_path=EXPECTED_SCALER_PATH_NB08,    # Path to .joblib from an official NB04 run
            imputer_path=EXPECTED_IMPUTER_PATH_NB08,  # Path to .joblib from an official NB04 run
            config=config_for_this_model_dataset,     # Config from this model's training run (NB04 via NB06/07)
            mri_data_dir=MRI_DATA_DIR_NB08 if is_model_hybrid else None,
            include_mri=is_model_hybrid
        )
        analysis_loader = DataLoader(
            analysis_dataset,
            batch_size=HP_analysis.get('analysis_batch_size', 4), 
            shuffle=False, 
            collate_fn=pad_collate_fn,
            num_workers=0 
        )
        print(f"  Analysis DataLoader created for {model_nickname} with {len(analysis_dataset)} subjects.")
    except Exception as e_ds_load_loop:
        print(f"  ERROR creating analysis dataset/loader for {model_nickname}: {e_ds_load_loop}")
        current_model_results["status_dataset_load"] = f"Error: {e_ds_load_loop}"
        all_models_analysis_results.append(current_model_results)
        if run_nb08: run_nb08.log({f"{model_nickname}/status/dataset_load": "Failed"})
        continue 

    # --- A. MC Dropout Analysis ---
    if HP_analysis.get('RUN_MC_DROPOUT', False):
        print(f"\n  --- A. MC Dropout Analysis for {model_nickname} ---")
        try:
            all_predictions_mc, all_actuals_mc = get_mc_dropout_predictions(
                loaded_model, analysis_loader, N_MC_SAMPLES, device
            )
            if all_predictions_mc and all_actuals_mc: # Ensure both lists are populated
                mc_uncertainty_stats = calculate_uncertainty_metrics(all_predictions_mc)
                if mc_uncertainty_stats:
                    # Filter out NaN values before calculating mean for robustness
                    valid_variances = [s['variance'] for s in mc_uncertainty_stats if 'variance' in s and not np.isnan(s['variance'])]
                    valid_std_devs = [s['std_dev'] for s in mc_uncertainty_stats if 'std_dev' in s and not np.isnan(s['std_dev'])]
                    avg_mc_variance = np.mean(valid_variances) if valid_variances else np.nan
                    avg_mc_std_dev = np.mean(valid_std_devs) if valid_std_devs else np.nan
                    
                    current_model_results["Avg_MC_Variance"] = avg_mc_variance
                    current_model_results["Avg_MC_Std_Dev"] = avg_mc_std_dev
                    print(f"    {model_nickname} - Avg MC Variance: {avg_mc_variance:.4e}, Avg MC Std Dev: {avg_mc_std_dev:.4f}")
                    if run_nb08:
                        run_nb08.log({
                            f"{model_nickname}/uncertainty/avg_mc_variance": avg_mc_variance,
                            f"{model_nickname}/uncertainty/avg_mc_std_dev": avg_mc_std_dev
                        })

                    # Plotting for MC Dropout
                    actuals_np = np.array(all_actuals_mc)
                    mean_preds_np = np.array([s['mean'] for s in mc_uncertainty_stats])
                    std_devs_np = np.array([s['std_dev'] for s in mc_uncertainty_stats]) # Not NaN-filtered (unlike valid_std_devs), so indices stay aligned with actuals_np

                    # 1. Actual vs. Mean Predicted CDR, Colored by Uncertainty
                    fig_mc_avp, ax_mc_avp = plt.subplots(figsize=(8, 6))
                    scatter = ax_mc_avp.scatter(actuals_np, mean_preds_np, c=std_devs_np, cmap='viridis', alpha=0.7, vmin=0) # Ensure vmin for colorbar
                    min_val_plot = min(actuals_np.min(), mean_preds_np.min()) if actuals_np.size > 0 else 0
                    max_val_plot = max(actuals_np.max(), mean_preds_np.max()) if actuals_np.size > 0 else 1
                    ax_mc_avp.plot([min_val_plot, max_val_plot], [min_val_plot, max_val_plot], 'r--', lw=2, label="Ideal")
                    fig_mc_avp.colorbar(scatter, ax=ax_mc_avp, label='Predictive Std Dev (Uncertainty)')
                    ax_mc_avp.set_title(f'MC Dropout: Actual vs. Mean Predicted CDR ({model_nickname})')
                    ax_mc_avp.set_xlabel('Actual CDR')
                    ax_mc_avp.set_ylabel('Mean Predicted CDR')
                    ax_mc_avp.legend()
                    ax_mc_avp.grid(True, linestyle='--', alpha=0.7)
                    finalize_plot(fig_mc_avp, plt, run_nb08, 
                                  f"{model_nickname}/uncertainty/plot_actual_vs_pred_by_std", 
                                  RUN_OUTPUT_DIR_NB08 / f"{model_nickname}_mc_actual_vs_pred.png")

                    # 2. Prediction Error vs. Uncertainty
                    errors_np = np.abs(actuals_np - mean_preds_np)
                    fig_mc_evu, ax_mc_evu = plt.subplots(figsize=(8, 6))
                    ax_mc_evu.scatter(std_devs_np, errors_np, alpha=0.7)
                    ax_mc_evu.set_title(f'MC Dropout: Prediction Error vs. Uncertainty ({model_nickname})')
                    ax_mc_evu.set_xlabel('Predictive Std Dev (Uncertainty)')
                    ax_mc_evu.set_ylabel('Absolute Prediction Error |Actual - Mean Pred|')
                    ax_mc_evu.grid(True, linestyle='--', alpha=0.7)
                    finalize_plot(fig_mc_evu, plt, run_nb08, 
                                  f"{model_nickname}/uncertainty/plot_error_vs_std", 
                                  RUN_OUTPUT_DIR_NB08 / f"{model_nickname}_mc_error_vs_uncertainty.png")
                    
                    # 3. Histogram of Predictive Standard Deviations
                    fig_mc_hist, ax_mc_hist = plt.subplots(figsize=(8, 6))
                    sns.histplot(std_devs_np, kde=True, bins=15, ax=ax_mc_hist, stat="density")
                    ax_mc_hist.set_title(f'MC Dropout: Distribution of Predictive Std Devs ({model_nickname})')
                    ax_mc_hist.set_xlabel('Predictive Std Dev (Uncertainty)')
                    ax_mc_hist.set_ylabel('Density')
                    finalize_plot(fig_mc_hist, plt, run_nb08, 
                                  f"{model_nickname}/uncertainty/plot_uncertainty_dist", 
                                  RUN_OUTPUT_DIR_NB08 / f"{model_nickname}_mc_uncertainty_distribution.png")
                    
                    # Display detailed per-sample results table for a few samples
                    print(f"    MC Dropout Results for first {N_SAMPLES_FOR_DETAILED_ANALYSIS} test samples ({model_nickname}):")
                    mc_results_df_list = []
                    for i_mc in range(min(N_SAMPLES_FOR_DETAILED_ANALYSIS, len(all_actuals_mc))):
                        mc_results_df_list.append({
                            "Sample_Index": i_mc, 
                            "Actual_CDR": all_actuals_mc[i_mc],
                            "MC_Mean_Pred": mc_uncertainty_stats[i_mc]['mean'],
                            "MC_Std_Dev": mc_uncertainty_stats[i_mc]['std_dev'],
                            "MC_Variance": mc_uncertainty_stats[i_mc]['variance']
                        })
                    if mc_results_df_list:
                        display(pd.DataFrame(mc_results_df_list)) # For Jupyter display
                        if run_nb08: # Log as W&B Table
                            run_nb08.log({f"{model_nickname}/uncertainty/detailed_mc_samples_table": wandb.Table(dataframe=pd.DataFrame(mc_results_df_list))})
                else: 
                    print(f"    MC Dropout produced empty or invalid uncertainty stats for {model_nickname}.")
                    current_model_results["status_mc_dropout"] = "EmptyStats"
            else: 
                print(f"    MC Dropout did not produce predictions for {model_nickname}.")
                current_model_results["status_mc_dropout"] = "NoPredictions"
        except Exception as e_mc_main:
            print(f"    ERROR during MC Dropout analysis for {model_nickname}: {e_mc_main}")
            current_model_results["status_mc_dropout"] = f"Error: {str(e_mc_main)[:100]}"
            import traceback; traceback.print_exc()
    else:
        print(f"  Skipping MC Dropout Analysis for {model_nickname} as per RUN_MC_DROPOUT flag.")


    # --- B. Permutation Feature Importance (PFI) ---
    if HP_analysis.get('RUN_PERMUTATION_IMPORTANCE', False):
        print(f"\n  --- B. Permutation Feature Importance for {model_nickname} ---")
        try:
            # Get tabular features from the analysis_dataset instance
            # model_input_features from OASISDataset includes both time-varying and static (after M/F encoding)
            tabular_features_for_pfi = analysis_dataset.model_input_features 
            if not tabular_features_for_pfi:
                print("    Skipping PFI: No tabular features found in analysis_dataset.")
            else:
                tabular_feature_indices_pfi = {name: i for i, name in enumerate(tabular_features_for_pfi)}
                
                # Define criterion if not already defined (e.g., if MC Dropout was skipped)
                if 'criterion' not in locals() or criterion is None: criterion = nn.MSELoss()
                
                print("    Calculating baseline R2 for PFI using analysis_loader (test set)...")
                model_unpack_flag_pfi = "hybrid" if is_model_hybrid else "baseline"
                baseline_eval_metrics_pfi = evaluate_model(
                    loaded_model, analysis_loader, criterion, device, 
                    model_name_for_batch_unpack=model_unpack_flag_pfi
                )
                baseline_r2_for_pfi = baseline_eval_metrics_pfi.get('r2', np.nan) # Default to NaN
                current_model_results["PFI_Baseline_R2"] = baseline_r2_for_pfi
                print(f"    Baseline R2 for PFI: {baseline_r2_for_pfi:.4f}")

                if not np.isnan(baseline_r2_for_pfi): # Proceed only if baseline R2 is valid
                    perm_importance_results = calculate_permutation_importance(
                        loaded_model, analysis_loader, 
                        feature_names_to_permute=tabular_features_for_pfi,
                        tabular_feature_indices=tabular_feature_indices_pfi,
                        metric_fn=lambda y_true, y_pred: r2_score(np.array(y_true), np.array(y_pred)), # Ensure numpy arrays
                        baseline_score=baseline_r2_for_pfi,
                        device=device, is_hybrid_model=is_model_hybrid,
                        num_permutations=HP_analysis.get('pfi_num_permutations', 5)
                    )
                    current_model_results["PFI_Importances_R2_Drop"] = perm_importance_results
                    print(f"    {model_nickname} - PFI (Drop in R2):")
                    # Print top N important features, e.g., top 10 or all if fewer than 10
                    num_pfi_to_print = min(10, len(perm_importance_results))
                    for feat, imp in sorted(perm_importance_results.items(), key=lambda item: item[1], reverse=True)[:num_pfi_to_print]:
                        print(f"      {feat}: {imp:.4f}")
                    
                    if run_nb08:
                        run_nb08.log({f"{model_nickname}/interpretability/pfi_r2_drop_dict": perm_importance_results})
                        if perm_importance_results:
                            perm_df_pfi = pd.DataFrame(
                                list(perm_importance_results.items()), 
                                columns=['Feature', 'Importance_R2_Drop']
                            ).sort_values(by='Importance_R2_Drop', ascending=False)
                            
                            # Log PFI results as a W&B Table
                            try:
                                run_nb08.log({f"{model_nickname}/interpretability/pfi_table": wandb.Table(dataframe=perm_df_pfi)})
                            except Exception as e_pfi_tbl_log: print(f"Warning: Could not log PFI table to W&B. Error: {e_pfi_tbl_log}")

                            # Plot PFI (e.g., top 15 features)
                            fig_pfi, ax_pfi = plt.subplots(figsize=(10, max(6, len(perm_df_pfi.head(15)) * 0.35)))
                            sns.barplot(x='Importance_R2_Drop', y='Feature', data=perm_df_pfi.head(15), ax=ax_pfi, palette="viridis")
                            ax_pfi.set_title(f'Permutation Feature Importance (Top 15) for {model_nickname}')
                            finalize_plot(fig_pfi, plt, run_nb08, 
                                          f"{model_nickname}/interpretability/plot_pfi_top15", 
                                          RUN_OUTPUT_DIR_NB08 / f"{model_nickname}_pfi_top15_plot.png")
                else: 
                    print("    Skipping PFI calculation: Baseline R2 is NaN.")
        except Exception as e_pfi_main:
            print(f"    ERROR during Permutation Feature Importance for {model_nickname}: {e_pfi_main}")
            current_model_results["status_pfi"] = f"Error: {str(e_pfi_main)[:100]}"
            import traceback; traceback.print_exc()
    else:
        print(f"  Skipping Permutation Feature Importance for {model_nickname} as per RUN_PERMUTATION_IMPORTANCE flag.")


    # --- C. CNN Interpretability (Integrated Gradients) ---
    # RUN_SALIENCY_MAPS flag was renamed to RUN_INTEGRATED_GRADIENTS in HP_analysis
    if is_model_hybrid and hasattr(loaded_model, 'cnn_feature_extractor') and HP_analysis.get('RUN_INTEGRATED_GRADIENTS', False):
        print(f"\n  --- C. CNN Interpretability (Integrated Gradients) for {model_nickname} ---")
        try:
            # Determine how many samples/batches to get IG for.
            # For visualization, usually one batch (or first few samples from it) is enough.
            num_ig_batches_to_process = 1 
            ig_samples_visualized_count = 0
            
            for batch_idx_ig, ig_batch_data_full in enumerate(analysis_loader):
                if batch_idx_ig >= num_ig_batches_to_process: break
                
                # Unpack: analysis_loader for hybrid yields (tab_seq, mri_seq, lengths, targets, masks)
                _, s_mri_sequences_for_ig, s_lengths_for_ig, _, _ = ig_batch_data_full
                
                # Collect the first valid MRI scan from each sequence in the current batch
                mris_to_explain_in_batch_list = []
                for i_seq in range(s_mri_sequences_for_ig.size(0)): 
                    if s_lengths_for_ig[i_seq].item() > 0: # If sequence has at least one MRI
                        mris_to_explain_in_batch_list.append(s_mri_sequences_for_ig[i_seq, 0, :, :, :, :]) # Get 0th MRI of sequence
                
                if not mris_to_explain_in_batch_list:
                    print(f"    No valid MRIs in analysis_loader batch {batch_idx_ig} for IG. Skipping this batch.")
                    continue

                mri_batch_for_ig_utility_input = torch.stack(mris_to_explain_in_batch_list).to(device)
                print(f"    Explaining MRI batch (shape: {mri_batch_for_ig_utility_input.shape}) for Integrated Gradients...")

                attributions_ig_batch = generate_integrated_gradients_cnn(
                    cnn_model_part=loaded_model.cnn_feature_extractor, 
                    mri_input_tensor_batch=mri_batch_for_ig_utility_input,
                    n_steps=HP_analysis.get('ig_n_steps', 25), # Use a reasonable default
                    captum_internal_batch_size=HP_analysis.get('ig_captum_internal_batch_size') # Can be None
                )

                if attributions_ig_batch is not None:
                    current_model_results["IG_Attributions_Shape"] = list(attributions_ig_batch.shape)
                    print(f"    Generated IG attribution maps, batch shape: {attributions_ig_batch.shape}")
                    
                    # Visualize IG for a limited number of samples from this batch
                    num_samples_to_plot_ig = min(attributions_ig_batch.shape[0], HP_analysis.get('num_detailed_samples', 1))
                    print(f"    Visualizing IG for first {num_samples_to_plot_ig} sample(s) in this batch...")

                    for sample_idx_plot_ig in range(num_samples_to_plot_ig):
                        # Assuming single channel attribution, C=1
                        if attributions_ig_batch.shape[1] == 1: 
                            attr_map_single_sample = attributions_ig_batch[sample_idx_plot_ig, 0] # Shape (D, H, W)
                            original_mri_single_sample = mri_batch_for_ig_utility_input[sample_idx_plot_ig, 0].cpu().numpy() # (D,H,W)

                            # Choose a central slice (e.g., axial) for visualization
                            slice_idx_d = attr_map_single_sample.shape[0] // 2 
                            attr_slice = attr_map_single_sample[slice_idx_d, :, :]
                            mri_slice = original_mri_single_sample[slice_idx_d, :, :]

                            fig_ig, axs_ig = plt.subplots(1, 2, figsize=(12, 5))
                            fig_ig.suptitle(f"Integrated Gradients for {model_nickname} - Explained Sample {sample_idx_plot_ig} (Batch {batch_idx_ig}), Axial Slice {slice_idx_d}", fontsize=14)
                            
                            axs_ig[0].imshow(np.rot90(mri_slice), cmap='gray')
                            axs_ig[0].set_title("Original MRI Slice")
                            axs_ig[0].axis('off')

                            im_ig = axs_ig[1].imshow(np.rot90(mri_slice), cmap='gray') # Show original as background
                            im_attr = axs_ig[1].imshow(np.rot90(attr_slice), cmap='hot', alpha=0.6, 
                                                       vmin=np.percentile(attr_slice,1), vmax=np.percentile(attr_slice,99)) # Overlay heatmap
                            axs_ig[1].set_title("IG Attribution Overlay")
                            axs_ig[1].axis('off')
                            fig_ig.colorbar(im_attr, ax=axs_ig[1], label="Attribution Strength", fraction=0.046, pad=0.04)
                            
                            finalize_plot(fig_ig, plt, run_nb08, 
                                          f"{model_nickname}/interpretability/ig_sample{ig_samples_visualized_count + sample_idx_plot_ig}_axial", 
                                          RUN_OUTPUT_DIR_NB08 / f"{model_nickname}_ig_sample{ig_samples_visualized_count + sample_idx_plot_ig}_axial_slice{slice_idx_d}.png")
                        else:
                            print(f"    Cannot plot IG for sample {sample_idx_plot_ig}, unexpected channel count: {attributions_ig_batch.shape[1]}")
                    
                    ig_samples_visualized_count += attributions_ig_batch.shape[0]
                    if run_nb08: run_nb08.log({f"{model_nickname}/interpretability/ig_batch_processed_successfully": True})
                else:
                    print(f"    Integrated Gradients generation returned None for {model_nickname}.")
                    if run_nb08: run_nb08.log({f"{model_nickname}/interpretability/ig_generation_failed": True})
                # break # Usually explain one batch for IG visualization is enough for a demo
        except Exception as e_ig_main:
            print(f"    ERROR during Integrated Gradients for {model_nickname}: {e_ig_main}")
            current_model_results["status_ig"] = f"Error: {str(e_ig_main)[:100]}"
            import traceback; traceback.print_exc()
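    elif is_model_hybrid and not HP_analysis.get('RUN_INTEGRATED_GRADIENTS', False):
        print(f"  Skipping CNN Interpretability (Integrated Gradients) for {model_nickname} as per RUN_INTEGRATED_GRADIENTS flag.")
    elif is_model_hybrid:
        # Flag is enabled, but the loaded model exposes no CNN sub-module to explain
        print(f"  Skipping CNN Interpretability (Integrated Gradients) for {model_nickname}: model has no 'cnn_feature_extractor' attribute.")
    # No IG for baseline models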


    # --- D. SHAP Analysis ---
    # SHAP for Baseline Model (Tabular LSTM)
    if not is_model_hybrid and HP_analysis.get('RUN_SHAP_BASELINE', False):
        print(f"\n  --- D. SHAP Analysis for {model_nickname} (BaselineLSTM using DeepExplainer) ---")
        try:
            # Create DataLoaders for SHAP: background from training data, instances from test data (analysis_dataset)
            # For baseline, OASISDataset needs config_for_this_model_dataset (for this baseline model) and include_mri=False
            # Preprocessor paths EXPECTED_SCALER_PATH_NB08 and EXPECTED_IMPUTER_PATH_NB08 are used.
            print("    SHAP (Baseline): Instantiating training dataset for background...")
            train_dataset_for_shap_bg_baseline = OASISDataset(
                TRAIN_DATA_PATH_FOR_SHAP_BG_NB08, 
                EXPECTED_SCALER_PATH_NB08, # Assuming these are the correct preprocessors
                EXPECTED_IMPUTER_PATH_NB08,    
                config=config_for_this_model_dataset, # Config for the current baseline model
                include_mri=False # Explicitly False for baseline
            )
            shap_background_loader_baseline = DataLoader( 
                torch.utils.data.Subset(train_dataset_for_shap_bg_baseline, list(range(min(HP_analysis.get('shap_num_background_samples', 20), len(train_dataset_for_shap_bg_baseline))))),
                batch_size=HP_analysis.get('shap_background_batch_size', 10), 
                shuffle=True, collate_fn=pad_collate_fn, num_workers=0
            )
            # analysis_dataset is already instantiated for the current model (baseline)
            shap_instances_loader_baseline = DataLoader( 
                 torch.utils.data.Subset(analysis_dataset, list(range(min(HP_analysis.get('num_detailed_samples', 5), len(analysis_dataset))))),
                 batch_size=HP_analysis.get('shap_explain_batch_size', HP_analysis.get('num_detailed_samples', 5)),
                 shuffle=False, collate_fn=pad_collate_fn, num_workers=0
            )
            current_feature_names_for_shap = analysis_dataset.model_input_features

            shap_values_lstm, explained_instances_np_lstm = explain_lstm_with_shap(
                loaded_model, shap_background_loader_baseline, shap_instances_loader_baseline, 
                device, feature_names=current_feature_names_for_shap
            )
            if shap_values_lstm is not None and explained_instances_np_lstm is not None:
                current_model_results["SHAP_Baseline_Values_Shape"] = list(shap_values_lstm.shape)
                print("    SHAP values for BaselineLSTM obtained.")
                # Plotting (Bar plot of mean absolute SHAP values)
                if shap_values_lstm.ndim == 3 and len(current_feature_names_for_shap) == shap_values_lstm.shape[2]:
                    mean_abs_shap_baseline = np.mean(np.abs(shap_values_lstm), axis=(0,1))
                    shap_summary_df_baseline = pd.DataFrame({
                        'feature': current_feature_names_for_shap, 
                        'mean_abs_shap': mean_abs_shap_baseline
                    }).sort_values(by='mean_abs_shap', ascending=False)
                    
                    fig_shap_bl, ax_shap_bl = plt.subplots(figsize=(10, max(6, len(shap_summary_df_baseline.head(15)) * 0.35))) # Height follows the (at most 15) plotted rows
                    sns.barplot(x='mean_abs_shap', y='feature', data=shap_summary_df_baseline.head(15), ax=ax_shap_bl, palette="mako")
                    ax_shap_bl.set_title(f'SHAP Mean Abs. Feature Importance for {model_nickname}')
                    finalize_plot(fig_shap_bl, plt, run_nb08, 
                                  f"{model_nickname}/interpretability/shap_baseline_mean_abs_bar", 
                                  RUN_OUTPUT_DIR_NB08 / f"{model_nickname}_shap_baseline_barplot.png")
                    
                    # Beeswarm if seq_len was 1
                    if shap_values_lstm.shape[1] == 1 and shap is not None : # Check if shap imported in NB08
                        shap_values_2d = shap_values_lstm.squeeze(axis=1)
                        explained_instances_2d = explained_instances_np_lstm.squeeze(axis=1)
                        plt.figure() # New figure for beeswarm
                        shap.summary_plot(shap_values_2d, features=explained_instances_2d, feature_names=current_feature_names_for_shap, show=False)
                        plt.title(f"SHAP Summary (Beeswarm) for {model_nickname}")
                        finalize_plot(plt.gcf(), plt, run_nb08, # Get current figure
                                      f"{model_nickname}/interpretability/shap_baseline_beeswarm",
                                      RUN_OUTPUT_DIR_NB08 / f"{model_nickname}_shap_baseline_beeswarm.png")
                else: print("    SHAP values for baseline have unexpected shape or feature name mismatch for plotting.")
                if run_nb08: run_nb08.log({f"{model_nickname}/interpretability/shap_baseline_generated": True})
            else: 
                print(f"    SHAP (Baseline) analysis did not produce results for {model_nickname}.")
                current_model_results["status_shap_baseline"] = "NoResults"
        except Exception as e_shap_baseline_main:
            print(f"    ERROR during SHAP (Baseline) for {model_nickname}: {e_shap_baseline_main}")
            current_model_results["status_shap_baseline"] = f"Error: {str(e_shap_baseline_main)[:100]}"
            import traceback; traceback.print_exc()

    # SHAP for Hybrid Model (Fusion Stage)
    if is_model_hybrid and HP_analysis.get('RUN_SHAP_HYBRID_FUSION', False):
        print(f"\n  --- D. SHAP Analysis for {model_nickname} (Hybrid Model - Fusion Stage using KernelExplainer) ---")
        try:
            # Create DataLoaders for SHAP background and instances (these need full hybrid inputs)
            train_dataset_for_hybrid_shap_bg = OASISDataset(
                TRAIN_DATA_PATH_FOR_SHAP_BG_NB08, EXPECTED_SCALER_PATH_NB08, EXPECTED_IMPUTER_PATH_NB08,
                config=config_for_this_model_dataset, # Config for the current hybrid model
                mri_data_dir=MRI_DATA_DIR_NB08, include_mri=True
            )
            hybrid_shap_background_loader = DataLoader(
                 torch.utils.data.Subset(train_dataset_for_hybrid_shap_bg, 
                                         list(range(min(HP_analysis.get('shap_num_background_samples', 20), len(train_dataset_for_hybrid_shap_bg))))),
                 batch_size=HP_analysis.get('shap_background_batch_size', 4), 
                 shuffle=False, collate_fn=pad_collate_fn, num_workers=0
            )
            hybrid_shap_instances_loader = DataLoader(
                 torch.utils.data.Subset(analysis_dataset, list(range(min(HP_analysis.get('num_detailed_samples', 5), len(analysis_dataset))))),
                 batch_size=HP_analysis.get('shap_explain_batch_size', HP_analysis.get('num_detailed_samples', 5)),
                 shuffle=False, collate_fn=pad_collate_fn, num_workers=0
            )
            
            # KernelExplainer relies on scikit-learn linear models internally, so warn up front about possible ConvergenceWarnings.
            print("    Note: shap.KernelExplainer might produce ConvergenceWarnings from scikit-learn if features are highly collinear.")

            shap_values_f, explained_df_f, mri_f_count, tab_f_count = explain_hybrid_fusion_with_shap(
                loaded_model, hybrid_shap_background_loader, hybrid_shap_instances_loader, device,
                num_background_samples_for_kmeans=HP_analysis.get('shap_kmeans_k', 5),
                num_shap_samples=HP_analysis.get('shap_kernel_nsamples', 50)
            )
            if shap_values_f is not None and explained_df_f is not None:
                current_model_results["SHAP_HybridFusion_Values_Shape"] = list(shap_values_f.shape)
                current_model_results["SHAP_HybridFusion_MRI_Feat_Count"] = mri_f_count
                current_model_results["SHAP_HybridFusion_Tab_Feat_Count"] = tab_f_count
                print("    SHAP values for hybrid fusion stage obtained.")

                # Plotting for SHAP hybrid fusion (bar plot of all fused features)
                if shap is not None: # Ensure shap module is imported in NB08
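                    fig_shap_hf, ax_shap_hf = plt.subplots(figsize=(12, max(8, explained_df_f.shape[1] * 0.25) )) # Dynamic height
                    # shap.summary_plot takes no `fig` argument; it draws on the current figure (the one just created).
                    # plot_size=None keeps the figsize set above instead of letting SHAP resize the figure.
                    shap.summary_plot(shap_values_f, features=explained_df_f, plot_type="bar", show=False, axis_color="black", color_bar=False, plot_size=None)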
                    # Manually adjust title for the figure, not the axes returned by shap.summary_plot
                    fig_shap_hf.suptitle(f"SHAP: Feature Importance at Fusion Stage - {model_nickname}", fontsize=14) # Use suptitle for figure
                    # plt.xlabel("Mean |SHAP value| (Average impact on model output magnitude at fusion)") # shap.summary_plot adds this
                    # fig_shap_hf.tight_layout(rect=[0, 0, 1, 0.96]) # Adjust for suptitle
                    finalize_plot(fig_shap_hf, plt, run_nb08,
                                  f"{model_nickname}/interpretability/shap_hybrid_fusion_bar",
                                  RUN_OUTPUT_DIR_NB08 / f"{model_nickname}_SHAP_HybridFusion_BarPlot.png")
                
                if mri_f_count > 0 and tab_f_count > 0:
                    abs_shap_mri_stream_avg = np.mean(np.sum(np.abs(shap_values_f[:, :mri_f_count]), axis=1))
                    abs_shap_tab_stream_avg = np.mean(np.sum(np.abs(shap_values_f[:, mri_f_count:(mri_f_count+tab_f_count)]), axis=1)) # Correct slicing
                    current_model_results["SHAP_Hybrid_MRI_Stream_AvgTotalAbs"] = abs_shap_mri_stream_avg
                    current_model_results["SHAP_Hybrid_Tab_Stream_AvgTotalAbs"] = abs_shap_tab_stream_avg
                    print(f"    Avg total |SHAP| for MRI stream features at fusion: {abs_shap_mri_stream_avg:.4f}")
                    print(f"    Avg total |SHAP| for Tabular stream features at fusion: {abs_shap_tab_stream_avg:.4f}")
                    if run_nb08:
                        run_nb08.log({
                            f"{model_nickname}/interpretability/shap_hybrid_fusion_mri_stream_avg_total_abs": abs_shap_mri_stream_avg,
                            f"{model_nickname}/interpretability/shap_hybrid_fusion_tab_stream_avg_total_abs": abs_shap_tab_stream_avg
                        })
                if run_nb08: run_nb08.log({f"{model_nickname}/interpretability/shap_hybrid_fusion_generated": True})
            else: 
                print(f"    SHAP (Hybrid Fusion) analysis did not produce results for {model_nickname}.")
                current_model_results["status_shap_hybrid"] = "NoResults"
        except Exception as e_shap_hybrid_main:
            print(f"    ERROR during SHAP (Hybrid Fusion) for {model_nickname}: {e_shap_hybrid_main}")
            current_model_results["status_shap_hybrid"] = f"Error: {str(e_shap_hybrid_main)[:100]}"
            import traceback; traceback.print_exc()
    
    # Skip message when the SHAP flag that applies to this model type is disabled
    elif not HP_analysis.get('RUN_SHAP_HYBRID_FUSION' if is_model_hybrid else 'RUN_SHAP_BASELINE', False):
        print(f"\n  --- D. SHAP Analysis: Skipped for {model_nickname} as per relevant RUN_SHAP flags ---")


    all_models_analysis_results.append(current_model_results)
    print(f"\n--- Finished All Analyses for: {model_nickname} ---")
# --- End of Main Analysis Loop ---

# Convert overall results to DataFrame for final summary display and logging
if all_models_analysis_results:
    nb08_summary_df = pd.DataFrame(all_models_analysis_results)
    # Ensure critical columns are present, fill with NaN if a model failed an analysis
    expected_cols = ["model_nickname", "source_wandb_run_path", "model_type_is_hybrid", 
                     "Avg_MC_Variance", "Avg_MC_Std_Dev", 
                     "PFI_Baseline_R2", "PFI_Importances (R2_Drop)",
                     "SHAP_Baseline_Values_Shape", 
                     # Key names must match those written into current_model_results in the analysis loop
                     "SHAP_HybridFusion_Values_Shape", "SHAP_Hybrid_MRI_Stream_AvgTotalAbs", "SHAP_Hybrid_Tab_Stream_AvgTotalAbs",
                     "status_mc_dropout", "status_pfi", "status_ig", "status_shap_baseline", "status_shap_hybrid"]
    for col in expected_cols:
        if col not in nb08_summary_df.columns:
            nb08_summary_df[col] = np.nan # Add missing columns with NaNs
            
    print("\n\n--- Notebook 08: Overall Analysis Summary ---")
    # display(nb08_summary_df) # Use display for richer output in Jupyter
    print(nb08_summary_df.to_string()) # Print full DataFrame to console

    if run_nb08: 
        try:
            summary_table_wandb = wandb.Table(dataframe=nb08_summary_df.fillna("N/A")) # Replace NaNs for W&B Table
            run_nb08.log({"analysis_run_summary/all_models_table": summary_table_wandb})
            print("Overall analysis summary table logged to W&B.")
        except Exception as e_log_summary_tbl:
            print(f"Error logging summary table to W&B: {e_log_summary_tbl}")
else:
    print("\nNo models were analyzed, or no results collected. Summary DataFrame not created.")

## 6. Finalize W&B Run for Notebook 08 Analysis

Complete the execution of this analysis notebook and finish its associated Weights & Biases run. This ensures all queued logs, metrics, configurations, plots, and summary tables are fully uploaded and synchronized with the W&B platform.
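
As a minimal sketch of why the explicit finish step matters, the pattern below guarantees that `finish()` is called even when the code before it raises. The helper `finishing` and the stand-in class `DemoRun` are hypothetical names introduced only for illustration; they are not part of this project or of the W&B library. The sketch assumes only that the run object exposes a `finish()` method and a dict-like `summary`, as `run_nb08` does in this notebook.

```python
from contextlib import contextmanager


class DemoRun:
    """Hypothetical stand-in for a W&B run; it only mimics `summary` and `finish()`."""

    def __init__(self):
        self.summary = {}
        self.finished = False

    def finish(self):
        self.finished = True


@contextmanager
def finishing(run):
    """Yield `run`, then call `run.finish()` even if the body raises."""
    try:
        yield run
    finally:
        if run is not None:
            run.finish()


demo_run = DemoRun()
try:
    with finishing(demo_run) as run:
        run.summary["overall_analysis_status"] = "Completed"
        raise RuntimeError("simulated failure after the summary was written")
except RuntimeError:
    pass  # the error still propagates; finish() has already been called
```

The cell below uses an explicit `try`/`except` instead, because it also reports any error raised by `finish()` itself.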

In [None]:
# --- Finish W&B Run for this Analysis Notebook (NB08) ---
# run_nb08 is the W&B run object for this notebook.

if run_nb08:
    print(f"\n--- Finishing W&B run '{run_nb08.name}' for Notebook 08 ---")
    try:
        # Record a final status in the run summary
        if not all_models_analysis_results:  # No per-model results were collected
            run_nb08.summary["overall_analysis_status"] = "NoModelsSuccessfullyAnalyzed"
        elif nb08_summary_df.empty:  # Results exist, but the summary DataFrame is empty
            run_nb08.summary["overall_analysis_status"] = "AnalysisRanButSummaryDFEmpty"
        else:
            run_nb08.summary["overall_analysis_status"] = "Completed"
        
        # Capture a printable identifier before finishing; fall back to the run ID if no name is set
        final_run_name_nb08 = getattr(run_nb08, "name", None) or getattr(run_nb08, "id", None) or "current NB08 run"
        run_nb08.finish()
        print(f"W&B run '{final_run_name_nb08}' finished successfully.")
    except Exception as e_finish_nb08:
        print(f"Error during run_nb08.finish() for Notebook 08: {e_finish_nb08}")
        print("The W&B run may not have finalized correctly on the server.")
else:
    print("\nNo active W&B run for Notebook 08 to finish (likely initialization failed or was skipped).")

print(f"\n--- Notebook {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} execution finished. ---")