# Notebook 04: Fit and Save Data Preprocessors (OASIS-2)

**Project Phase:** 1 (Data Processing - Preprocessing)
**Dataset:** OASIS-2 Longitudinal MRI & Clinical Data

**Purpose:**
This notebook fits the data preprocessors (imputer and scaler) using *only* the training dataset split, so that no statistics from the validation or test data leak into model training. The objectives are:

1.  **Use Training Data Artifact:** Consume the versioned `cohort_split_train_oasis2` W&B Artifact (produced by Notebook 03). This provides the exact training data split.
2.  **Fetch Feature Configuration from Producer Run:** From the input training data artifact, identify the W&B run (Notebook 03) that produced it and fetch the definitive lists of `time_varying_features` and `static_features` from that run's configuration. This ensures consistency.
3.  **Identify Preprocessing Columns:** Based on the loaded training data and the fetched feature lists, determine which specific columns require missing value imputation and which numerical columns are candidates for scaling.
4.  **Log Definitive Preprocessing Configuration to W&B:** Log the chosen imputation/scaling strategies, and the *exact lists* of columns that will be imputed and scaled, along with the final time-varying and static feature lists that `OASISDataset` should use. This W&B run's (Notebook 04's) configuration becomes the **source of truth** for data preprocessing details for `OASISDataset`.
5.  **Fit Preprocessors:** Initialize and fit `sklearn.impute.SimpleImputer` and `sklearn.preprocessing.StandardScaler` instances exclusively on the identified columns of the training data.
6.  **Save & Log Preprocessors:** Save these *fitted* preprocessor objects locally (e.g., as `.joblib` files) in a structured output directory. Log these fitted objects as versioned W&B Artifacts.
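
As a rough illustration of the leakage-free pattern these objectives describe, the sketch below fits an imputer and a scaler on a toy training frame and only *transforms* a toy validation frame. The data values and the two-column selection are made up for the example; the real column lists are determined later in this notebook.

```python
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

# Toy stand-ins for the real splits; values are illustrative only.
train = pd.DataFrame({"MMSE": [28.0, np.nan, 30.0, 26.0],
                      "nWBV": [0.70, 0.72, 0.74, 0.76]})
val = pd.DataFrame({"MMSE": [np.nan, 20.0],
                    "nWBV": [0.80, np.nan]})
cols = ["MMSE", "nWBV"]

# Statistics (medians, means, standard deviations) are learned from the training split only.
imputer = SimpleImputer(strategy="median").fit(train[cols])
scaler = StandardScaler().fit(imputer.transform(train[cols]))

# The validation split is only ever transformed, never fitted on.
val_ready = scaler.transform(imputer.transform(val[cols]))

# The learned medians come from the training rows alone.
print(dict(zip(cols, imputer.statistics_)))
```

The missing validation values are filled with the training medians, which is the behaviour the saved preprocessors are meant to reproduce in later notebooks.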

**Workflow:**
1.  **Setup:** Import libraries, configure `sys.path`, load `config.json`.
2.  **W&B Initialization:** Start a new W&B run for this preprocessor fitting task using `initialize_wandb_run`.
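3.  **Load Training Data via W&B Artifact:** Use `run.use_artifact()` to get `cohort_split_train_oasis2`, download it, and load the training parquet file it contains (e.g., `cohort_train_oasis2.parquet`; the filename is resolved from `config.json`).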
4.  **Fetch Feature Config from NB03 Producer Run:** Get the run that produced the training data artifact using `artifact.logged_by()` and extract feature lists from its config.
5.  **Determine Imputation & Scaling Columns:** Based on the fetched feature lists and analysis of the `train_df`.
6.  **Log Final Preprocessing & Feature Config to Current W&B Run:** Update this Notebook 04 W&B run's config.
7.  **Fit, Save, & Log Imputer Artifact.**
8.  **Fit, Save, & Log Scaler Artifact.**
9.  **Finalize W&B Run.**

**Input:**
* `config.json`: Main project configuration file.
* **W&B Artifact Name for Training Data Split:** e.g., `"cohort_split_train_oasis2:latest"` (produced by Notebook 03).

**Output:**
* **Local Files (in designated preprocessors output directory defined by `config.json` and `paths_utils.py`):**
    * `simple_imputer_median_oasis2.joblib` (or similar, based on strategy and dataset)
    * `standard_scaler_oasis2.joblib` (or similar)
* **W&B Run (for this Notebook 04 execution):**
    * Logged run configuration containing:
        * Name of the input training data W&B Artifact used.
        * ID of the source Notebook 03 run that defined the features.
        * **The definitive `features` (time-varying, static) and `preprocess` (imputation_cols, scaling_cols, imputation_strategy, scaling_strategy) dictionaries that all downstream `OASISDataset` instances will fetch and use from *this run's config*.**
    * Fitted `SimpleImputer` and `StandardScaler` objects logged as new W&B Artifacts named after the saved files (e.g., `simple_imputer_median_oasis2:latest`, `standard_scaler_oasis2:latest`).

In [None]:
# In: notebooks/04_Fit_Preprocessors.ipynb
# Purpose: Load the TRAINING split data, fit data scalers (StandardScaler)
#          and imputers (SimpleImputer) based ONLY on this training data,
#          and save these fitted objects for later use in the Dataset class.

In [None]:
# --- Import Libraries ---
import pandas as pd
import numpy as np
import wandb
import json
from pathlib import Path
import time
import os
import joblib
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

## 1. Setup: Project Configuration, Paths, and Utilities

This section initializes the notebook environment:
* Determines the project's root directory and adds the `src` directory to `sys.path`.
* Imports custom utilities for W&B run initialization and path resolution.
* Loads the main project configuration (`config.json`).
* Defines dataset and notebook-specific identifiers.
* **Uses `get_dataset_paths` to resolve the output directory where fitted preprocessors will be saved.** This path is derived from `config.json` for consistency. The input training data is not read from a local path; it comes from the W&B artifact downloaded later in the notebook.
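
The project-root detection described in the first bullet can be sketched as a small function. `find_project_root` is a hypothetical helper written for this illustration (the notebook inlines the same checks); it looks at the parent of the starting directory first and then at the starting directory itself, mirroring the order used in the setup cell.

```python
import tempfile
from pathlib import Path
from typing import Optional


def find_project_root(start: Path) -> Optional[Path]:
    """Return the parent of `start`, or `start` itself, if it holds 'src/' and 'config.json'."""
    for candidate in (start.parent, start):
        if (candidate / "src").is_dir() and (candidate / "config.json").is_file():
            return candidate
    return None


# Build a throwaway project layout so the sketch runs anywhere.
with tempfile.TemporaryDirectory() as tmp:
    root = Path(tmp)
    (root / "src").mkdir()
    (root / "config.json").write_text("{}")
    (root / "notebooks").mkdir()

    # Starting from notebooks/, the parent directory is recognised as the root.
    found = find_project_root(root / "notebooks")
    found_is_root = found == root

print(found_is_root)
```

Returning `None` instead of raising leaves the decision about how to fail to the caller, which in a notebook is usually to stop the cell with a clear error.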

In [None]:
# --- Project Setup, Configuration Loading, and Utility Imports ---
print("--- Initializing Project Setup & Configuration for NB04 ---")

import sys

PROJECT_ROOT = None
base_config = {}
try:
    current_notebook_path = Path.cwd() 
    potential_project_root = current_notebook_path.parent 
    if (potential_project_root / "src").is_dir() and (potential_project_root / "config.json").is_file():
        PROJECT_ROOT = potential_project_root
    else: 
        PROJECT_ROOT = current_notebook_path
    if not (PROJECT_ROOT / "src").is_dir() or not (PROJECT_ROOT / "config.json").is_file():
        raise FileNotFoundError(f"Could not find 'src' or 'config.json'. PROJECT_ROOT: {PROJECT_ROOT}")
    if str(PROJECT_ROOT) not in sys.path:
        sys.path.insert(0, str(PROJECT_ROOT))
    print(f"PROJECT_ROOT: {PROJECT_ROOT}, added to sys.path.")

    from src.wandb_utils import initialize_wandb_run
    # get_dataset_paths is used here mainly to determine where to SAVE preprocessors.
    # The INPUT training data path will come from the downloaded W&B artifact.
    from src.paths_utils import get_dataset_paths 
    print("Successfully imported custom utilities.")
except Exception as e_setup:
    print(f"CRITICAL ERROR during initial setup: {e_setup}")
    raise  # Stop here: later cells depend on PROJECT_ROOT and the imported utilities.

# --- Load Main Project Configuration ---
print("\n--- Loading Main Project Configuration ---")
try:
    if PROJECT_ROOT is None: raise ValueError("PROJECT_ROOT not set.")
    CONFIG_PATH_MAIN = PROJECT_ROOT / 'config.json'
    with open(CONFIG_PATH_MAIN, 'r', encoding='utf-8') as f:
        base_config = json.load(f)
    print(f"Main project config loaded from: {CONFIG_PATH_MAIN}")
except Exception as e_cfg:
    print(f"CRITICAL ERROR loading main config.json: {e_cfg}")
    raise  # Stop here: later cells depend on base_config.

# --- Define Dataset, Notebook Specifics ---
DATASET_IDENTIFIER = "oasis2" 
NOTEBOOK_MODULE_NAME = "04_Fit_Preprocessors"
# Key in the config.json locators for this notebook's *output preprocessor files*
NB04_PREPROCESSORS_LOCATOR_KEY = "preprocessors_subdir" # Matches the key used by get_dataset_paths

# Path where preprocessor .joblib files will be saved.
# This uses get_dataset_paths to find the preprocessors_subdir defined in config.
output_dir_for_preprocessors = None
try:
    if not base_config: raise ValueError("base_config is empty.")
    # get_dataset_paths is called here mainly to resolve the preprocessor output paths.
    # The 'stage' argument doesn't really change these particular paths; 'training' is used for clarity.
    # The utility fully resolves 'scaler_path' and 'imputer_path';
    # their parent directory is taken as the output directory for this notebook.
    pipeline_paths_for_nb04 = get_dataset_paths(
        PROJECT_ROOT, base_config, DATASET_IDENTIFIER, stage="training" 
    )
    SCALER_SAVE_PATH = pipeline_paths_for_nb04.get('scaler_path')
    IMPUTER_SAVE_PATH = pipeline_paths_for_nb04.get('imputer_path')

    if not SCALER_SAVE_PATH or not IMPUTER_SAVE_PATH:
        raise ValueError("Could not resolve scaler_path or imputer_path from paths utility.")
    
    output_dir_for_preprocessors = SCALER_SAVE_PATH.parent # e.g., .../04_Fit_Preprocessors_oasis2/
    output_dir_for_preprocessors.mkdir(parents=True, exist_ok=True)
    
    print(f"\nKey paths for Notebook 04 ({DATASET_IDENTIFIER}):")
    print(f"  Output Directory for Preprocessor .joblib files: {output_dir_for_preprocessors}")
    print(f"  Scaler will be saved as: {SCALER_SAVE_PATH.name}")
    print(f"  Imputer will be saved as: {IMPUTER_SAVE_PATH.name}")

except Exception as e_paths_nb04:
    print(f"CRITICAL ERROR during path setup for NB04: {e_paths_nb04}")
    raise  # Stop here: later cells depend on SCALER_SAVE_PATH and IMPUTER_SAVE_PATH.

## 3. Initialize Weights & Biases Run for Notebook 04

A new W&B run is initiated for this specific "Fit Preprocessors" task. This run will log:

* **Input Artifact Consumed:**
    * The name and version of the input **training data W&B Artifact** (e.g., `cohort_split_train_oasis2:latest`) that this notebook consumes. This artifact is the direct output of a previous Notebook 03 execution and contains the `cohort_train.parquet` file.
* **Source Feature Configuration (from Producer Run):**
    * The W&B Run ID and name of the specific Notebook 03 execution that **produced** the consumed training data artifact. This producer run's configuration (containing the definitive `features_prepared_in_nb03` lists) is fetched automatically via artifact lineage to ensure this notebook uses the correct feature definitions.
* **Determined Preprocessing Details:**
    * The exact lists of feature columns identified from the loaded training data for **imputation**.
    * The exact lists of feature columns identified for **scaling**.
    * The chosen strategies for imputation (e.g., `'median'`) and scaling (e.g., `'standard_scaler'`), typically sourced from `config.json`.
* **Definitive Configuration for `OASISDataset` (Critically Important):**
    * The final, definitive **`features` dictionary** (detailing `time_varying` and `static` features, reflecting the lists fetched from the NB03 producer run and considering any encoding handled by `OASISDataset`).
    * The final, definitive **`preprocess` dictionary** (detailing `imputation_cols`, `scaling_cols`, and the chosen `imputation_strategy` and `scaling_strategy`).
    * These two dictionaries, logged to **this Notebook 04 W&B run's configuration**, will serve as the **authoritative source of truth** for the `OASISDataset` class in all subsequent data loading stages (i.e., in Notebooks 05, 06, 07, and 08).
* **Fitted Preprocessor Artifacts:**
    * The *fitted* imputer object (e.g., `SimpleImputer`) saved as a versioned W&B Artifact.
    * The *fitted* scaler object (e.g., `StandardScaler`) saved as a versioned W&B Artifact.
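
To make the shape of the two definitive dictionaries concrete, here is a small sketch. The key names follow the ones logged later in this notebook; the column values are placeholders, and `check_dataset_config` is a hypothetical sanity check, not part of the project code.

```python
# Illustrative values only; the real lists are determined later in this notebook.
oasis_dataset_config = {
    "features": {
        "time_varying": ["MMSE", "nWBV"],
        "static": ["EDUC", "M/F_encoded"],
    },
    "preprocess": {
        "imputation_cols": ["EDUC", "MMSE", "nWBV"],
        "scaling_cols": ["MMSE", "nWBV", "EDUC"],
        "imputation_strategy": "median",
        "scaling_strategy": "standard_scaler",
    },
}

# Keys a downstream consumer would expect to find in each section.
REQUIRED_KEYS = {
    "features": ("time_varying", "static"),
    "preprocess": ("imputation_cols", "scaling_cols",
                   "imputation_strategy", "scaling_strategy"),
}


def check_dataset_config(cfg: dict) -> list:
    """Hypothetical check: return a list of missing keys (empty if the config is complete)."""
    problems = []
    for section, keys in REQUIRED_KEYS.items():
        for key in keys:
            if key not in cfg.get(section, {}):
                problems.append(f"missing {section}.{key}")
    return problems


problems = check_dataset_config(oasis_dataset_config)
print(problems)
```

A check of this kind could run in a downstream notebook right after fetching this run's config, before any dataset object is constructed.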


In [None]:
# --- Initialize W&B Run for THIS Notebook Execution (NB04) ---
print("\n--- Initializing Weights & Biases Run for Notebook 04 ---")

# --- Define W&B Artifact Name for Input Training Data (Output from NB03) ---
# This should match the artifact_name used in NB03 when logging cohort_split_train
INPUT_TRAIN_DATA_ARTIFACT_NAME = f"cohort_split_train_{DATASET_IDENTIFIER}" 
INPUT_TRAIN_DATA_ARTIFACT_TYPE = f"data_split_{DATASET_IDENTIFIER}" # Matches type in NB03
INPUT_TRAIN_DATA_ARTIFACT_VERSION = "latest" # Or a specific version like "v0"

nb04_run_config_log = {
    "notebook_name_code": f"{NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER}",
    "dataset_source": DATASET_IDENTIFIER,
    "input_train_data_artifact": f"{INPUT_TRAIN_DATA_ARTIFACT_NAME}:{INPUT_TRAIN_DATA_ARTIFACT_VERSION}",
    "output_dir_for_local_preprocessors": str(output_dir_for_preprocessors),
    "scaler_save_filename": SCALER_SAVE_PATH.name, # Log the intended save names
    "imputer_save_filename": IMPUTER_SAVE_PATH.name,
    "execution_timestamp": time.strftime("%Y-%m-%d %H:%M:%S"),
    # Feature lists, imputation/scaling cols, and strategies will be logged after determination
}

nb_number_prefix_nb04 = NOTEBOOK_MODULE_NAME.split('_')[0] if '_' in NOTEBOOK_MODULE_NAME else "NB"
job_specific_type_nb04 = f"{nb_number_prefix_nb04}-FitPreprocessors-{DATASET_IDENTIFIER}"
custom_elements_for_name_nb04 = [nb_number_prefix_nb04, DATASET_IDENTIFIER.upper(), "FitPreproc"]

run = initialize_wandb_run(
    base_project_config=base_config,
    job_group="DataProcessing",
    job_specific_type=job_specific_type_nb04,
    run_specific_config=nb04_run_config_log,
    custom_run_name_elements=custom_elements_for_name_nb04,
    notes=f"{DATASET_IDENTIFIER.upper()}: Fitting and saving data preprocessors based on training data artifact."
)

if run:
    print(f"W&B run '{run.name}' (Job Type: '{run.job_type}') initialized. View at: {run.url}")
else:
    print("Proceeding without W&B logging for this session (W&B run initialization failed).")
    # output_dir_for_preprocessors should still be defined for local saves.

## 4. Load Training Data Artifact & Fetch Feature Configuration from Producer Run

This notebook consumes the training data split (`cohort_split_train_oasis2`) logged as an artifact by Notebook 03.
1.  The specified W&B artifact is downloaded.
2.  The training parquet file (filename resolved from `config.json`) is loaded from the downloaded artifact directory.
3.  Crucially, we then identify the W&B Run (from Notebook 03) that *produced* this training data artifact.
4.  The configuration of that producer run is fetched to retrieve the definitive lists of `time_varying_features` and `static_features` that were selected and prepared in Notebook 03. This ensures that preprocessors in this notebook are fitted using the exact feature set intended for modeling.
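
The lineage lookup in steps 3 and 4 can be sketched as a function that only relies on the artifact exposing `logged_by()` and the returned run exposing `config`. `fetch_feature_lists` is a hypothetical helper, and the two `_Fake*` classes are stand-ins so the example runs offline; in the notebook the argument would be the object returned by `run.use_artifact(...)`.

```python
def fetch_feature_lists(artifact, config_key="features_prepared_in_nb03"):
    """Return (time_varying, static) from the config of the run that logged `artifact`."""
    producer_run = artifact.logged_by()  # may be None if no producer run is recorded
    if producer_run is None:
        raise LookupError("No producer run recorded for this artifact.")

    features = dict(producer_run.config).get(config_key, {})
    missing = [k for k in ("time_varying", "static") if k not in features]
    if missing:
        raise KeyError(f"'{config_key}' is missing subkeys: {missing}")
    return list(features["time_varying"]), list(features["static"])


# Offline stand-ins for a W&B run and artifact (assumptions for this sketch only).
class _FakeRun:
    config = {"features_prepared_in_nb03": {"time_varying": ["MMSE"], "static": ["EDUC"]}}


class _FakeArtifact:
    def logged_by(self):
        return _FakeRun()


time_varying, static = fetch_feature_lists(_FakeArtifact())
print(time_varying, static)
```

Failing loudly when the key or its subkeys are absent keeps the preprocessors from being fitted on an empty or partial feature set.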

In [None]:
# --- Load Training Data from W&B Artifact and Fetch Producer Run Config ---
print(f"\n--- Loading Training Data Artifact & Fetching NB03 Feature Config ---")
train_df = None
source_time_varying_features = []
source_static_features = []
source_nb03_run_id_for_config = "N/A" # For logging

try:
    if run is None: # If W&B init failed earlier
        raise ConnectionError("W&B run not initialized. Cannot use W&B artifacts.")

    print(f"Using input training data artifact: {INPUT_TRAIN_DATA_ARTIFACT_NAME}:{INPUT_TRAIN_DATA_ARTIFACT_VERSION}")
    # Use the artifact
    train_data_artifact = run.use_artifact(
        f"{INPUT_TRAIN_DATA_ARTIFACT_NAME}:{INPUT_TRAIN_DATA_ARTIFACT_VERSION}", 
        type=INPUT_TRAIN_DATA_ARTIFACT_TYPE
    )
    train_data_artifact_dir = Path(train_data_artifact.download())
    # Construct the path to the parquet file inside the downloaded artifact directory.
    # The filename was fixed in NB03 when artifact.add_file() was called, so it is resolved
    # the same way here: locators["train_data_fname"], falling back to
    # f"cohort_train_{DATASET_IDENTIFIER}.parquet". Reading it from artifact metadata
    # would be more robust if NB03 logs it there.
    locators = base_config.get(f"pipeline_artefact_locators_{DATASET_IDENTIFIER}", {})
    train_fname_in_artifact = locators.get("train_data_fname", f"cohort_train_{DATASET_IDENTIFIER}.parquet") # Match NB03 save
    
    TRAIN_DATA_PATH_FROM_ARTIFACT = train_data_artifact_dir / train_fname_in_artifact
    
    if not TRAIN_DATA_PATH_FROM_ARTIFACT.is_file():
        raise FileNotFoundError(f"Training data parquet file '{train_fname_in_artifact}' not found in downloaded artifact at {train_data_artifact_dir}")

    train_df = pd.read_parquet(TRAIN_DATA_PATH_FROM_ARTIFACT)
    print(f"Training data loaded successfully from artifact. Shape: {train_df.shape}")
    run.log({'fit_preprocessors/input_train_rows_from_artifact': train_df.shape[0],
             'fit_preprocessors/input_train_cols_from_artifact': train_df.shape[1]})

    # --- Fetch Feature Configuration from the NB03 Run that Produced this Artifact ---
    nb03_producer_run = train_data_artifact.logged_by()
    if nb03_producer_run:
        source_nb03_run_id_for_config = nb03_producer_run.id
        print(f"Fetching feature config from producer NB03 run: {nb03_producer_run.name} (ID: {source_nb03_run_id_for_config})")
        
        nb03_run_config = dict(nb03_producer_run.config) # Convert to dict
        # NB03 logged its feature selection under "features_prepared_in_nb03"
        fetched_features_from_nb03 = nb03_run_config.get("features_prepared_in_nb03", {})
        
        if not fetched_features_from_nb03 or \
           'time_varying' not in fetched_features_from_nb03 or \
           'static' not in fetched_features_from_nb03:
            raise ValueError("Key 'features_prepared_in_nb03' or its subkeys ('time_varying', 'static') "
                             "not found or incomplete in the config of the NB03 run that produced the input artifact.")
            
        source_time_varying_features = fetched_features_from_nb03.get("time_varying", [])
        source_static_features = fetched_features_from_nb03.get("static", [])
        
        print(f"  Successfully fetched feature lists from producer NB03 run's config.")
        print(f"    Source Time-Varying features: {source_time_varying_features}")
        print(f"    Source Static features: {source_static_features}")
        
        # Log which NB03 run's config was used
        run.config.update({
            "source_config/nb03_producer_run_id": source_nb03_run_id_for_config,
            "source_config/nb03_producer_run_name": nb03_producer_run.name,
            "source_config/initial_time_varying_features": source_time_varying_features,
            "source_config/initial_static_features": source_static_features
        }, allow_val_change=True)
    else:
        raise ConnectionError("Could not retrieve the W&B run that produced the input training data artifact. "
                              "Feature lists cannot be fetched.")

except Exception as e_load_artifact_or_config:
    print(f"CRITICAL ERROR loading training data artifact or fetching NB03 config: {e_load_artifact_or_config}")
    if run: run.finish(exit_code=1)
    # exit()

if train_df is None or train_df.empty:
    print("CRITICAL ERROR: train_df is not loaded or empty after artifact processing. Cannot continue.")
    # exit()
    # Ensure these are lists for downstream code even if loading fails and we don't exit
    if 'source_time_varying_features' not in locals(): source_time_varying_features = []
    if 'source_static_features' not in locals(): source_static_features = []

## 5. Identify Columns for Imputation & Scaling and Log Final Configuration

Based on the loaded `train_df` and the feature lists fetched from the Notebook 03 W&B run configuration (`source_time_varying_features` and `source_static_features`):

1.  **Identify Imputation Columns:** Columns from the fetched feature lists that contain missing values (`NaN`) in the current `train_df` are selected for imputation. A fixed set of columns known to have sparse missing values (e.g., `MMSE`, `SES`, `nWBV`) is also imputed proactively, even if this particular split has no `NaN` in them. The imputation strategy (e.g., 'median') is noted.
2.  **Identify Scaling Columns:** Numerical columns from the fetched feature lists that are present in `train_df` are selected for scaling (e.g., using `StandardScaler`). Categorical columns (like 'M/F', which will be encoded later) are excluded, and `Baseline_CDR` is excluded unless `config.json` enables scaling it.
3.  **Log Definitive Configuration to W&B:** The *exact* lists of `imputation_cols`, `scaling_cols`, the chosen `imputation_strategy`, `scaling_strategy`, and the definitive `features` (time-varying and static lists that `OASISDataset` should use, reflecting the original selection from NB03 and considering M/F encoding) are logged to the configuration of this **current Notebook 04 W&B run**. This logged configuration becomes the **authoritative source** for all subsequent `OASISDataset` instantiations.
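
The column selection in steps 1 and 2 can be sketched as one function. `select_preprocessing_columns` and its parameter names are hypothetical and only condense the logic of the next cell; the toy frame and its values are made up for the example.

```python
import numpy as np
import pandas as pd


def select_preprocessing_columns(df, features, always_impute=(), never_scale=()):
    """Return (imputation_cols, scaling_cols) for the feature columns present in df."""
    candidates = [f for f in features if f in df.columns]

    # Impute columns that have NaNs in this split, plus any that are always imputed.
    has_nan = [c for c in candidates if df[c].isnull().any()]
    forced = [c for c in always_impute if c in candidates]
    imputation_cols = sorted(set(has_nan) | set(forced))

    # Scale numeric candidates unless they are explicitly excluded.
    scaling_cols = [c for c in candidates
                    if pd.api.types.is_numeric_dtype(df[c]) and c not in never_scale]
    return imputation_cols, scaling_cols


toy = pd.DataFrame({
    "MMSE": [28.0, np.nan, 30.0],
    "nWBV": [0.70, 0.72, 0.74],
    "M/F": ["M", "F", "F"],
    "Baseline_CDR": [0.0, 0.5, 0.0],
})
imputation_cols, scaling_cols = select_preprocessing_columns(
    toy,
    features=["MMSE", "nWBV", "M/F", "Baseline_CDR", "SES"],  # "SES" is absent from toy
    always_impute=["nWBV"],
    never_scale=["M/F", "Baseline_CDR"],
)
print(imputation_cols, scaling_cols)
```

Here `nWBV` is imputed although it has no missing values in the toy split, which is the "proactive" case: a later split may contain `NaN` in that column, and the fitted imputer must already know its median.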

In [None]:
# --- Identify Actual Columns for Imputation & Scaling from Training Data ---
# Also, define and log the final feature and preprocessing config for OASISDataset.
print("\n--- Determining Actual Columns for Imputation & Scaling from Training Data ---")
print("   (Using feature lists fetched from Notebook 03's W&B Run config)")

# Initialize lists to store final column names for preprocessing
imputation_cols_to_fit = []
scaling_cols_to_fit = []

# These are the feature lists that OASISDataset should ultimately use.
# They are based on source_time_varying_features and source_static_features from NB03,
# filtered by actual presence in train_df (already done when NB03 created them),
# and considering M/F encoding.
final_model_time_varying_features_list = [] 
final_model_static_features_list = [] 

if 'train_df' in locals() and train_df is not None and not train_df.empty and \
   'source_time_varying_features' in locals() and \
   'source_static_features' in locals():

    available_cols_in_train_df = train_df.columns.tolist()
    candidate_model_features = [
        f for f in (source_time_varying_features + source_static_features) 
        if f in available_cols_in_train_df
    ]

    # --- Determine Imputation Columns ---
    missing_in_train_df_subset = train_df[candidate_model_features].isnull().sum()
    imputation_cols_based_on_train_nans = missing_in_train_df_subset[missing_in_train_df_subset > 0].index.tolist()
    print(f"Columns with NaNs in current train_df: {imputation_cols_based_on_train_nans}")

    # --- Define columns that are known to sometimes have sparse NaNs AND will be scaled ---
    # These should be proactively imputed even if the current train_df split has no NaNs for them.
    # Ensure these are part of the overall feature set defined by NB03.
    proactively_impute_these_if_scaled = ['MMSE', 'SES', 'nWBV', 'EDUC', 'Baseline_MMSE'] # Add any others
    
    imputation_cols_to_fit = list(set(imputation_cols_based_on_train_nans + \
                                   [col for col in proactively_impute_these_if_scaled if col in candidate_model_features]))
    imputation_cols_to_fit = sorted([col for col in imputation_cols_to_fit if col in available_cols_in_train_df]) # Final check
    
    print(f"Final columns selected for imputation: {imputation_cols_to_fit}")

    # --- Determine Scaling Columns ---
    potential_cols_for_scaling_from_source = [
        f for f in candidate_model_features 
        if pd.api.types.is_numeric_dtype(train_df[f])
    ]
    cols_to_exclude_from_scaling = ['M/F'] 
    if base_config.get("preprocessing_config",{}).get("scale_baseline_cdr", False) is False:
        if 'Baseline_CDR' not in cols_to_exclude_from_scaling:
             cols_to_exclude_from_scaling.append('Baseline_CDR')
    scaling_cols_to_fit = [col for col in potential_cols_for_scaling_from_source if col not in cols_to_exclude_from_scaling]
    print(f"Final columns selected for scaling: {scaling_cols_to_fit}")

    # --- Determine Final Feature Lists for OASISDataset ---
    final_model_time_varying_features_list = [f for f in source_time_varying_features if f in available_cols_in_train_df]
    temp_static_from_source_available = [f for f in source_static_features if f in available_cols_in_train_df]
    # Static features: drop the raw 'M/F' column and, if OASISDataset encodes it,
    # list 'M/F_encoded' in its place. Any other available static feature is kept as-is.
    encode_m_f = base_config.get('preprocessing_config', {}).get('encode_m_f_in_dataset_class', True)
    static_features_for_model = {f for f in temp_static_from_source_available if f != 'M/F'}
    if 'M/F' in temp_static_from_source_available and encode_m_f:
        static_features_for_model.add('M/F_encoded')
    final_model_static_features_list = sorted(static_features_for_model)

    print(f"\nDefinitive Feature Lists for OASISDataset Configuration (to be logged to this NB04 run):")
    print(f"  Time-Varying Features for Model: {final_model_time_varying_features_list}")
    print(f"  Static Features for Model: {final_model_static_features_list}")

    # --- Log this definitive configuration to the current W&B run (NB04 run) ---
    if run:
        imputation_strategy_logged = base_config.get('preprocessing_config',{}).get('imputation_strategy', 'median')
        scaling_strategy_logged = base_config.get('preprocessing_config',{}).get('scaling_strategy', 'standard_scaler') 
        
        config_for_oasis_dataset_to_log = {
            'preprocess': {
                'imputation_cols': imputation_cols_to_fit,
                'scaling_cols': scaling_cols_to_fit,       
                'imputation_strategy': imputation_strategy_logged,
                'scaling_strategy': scaling_strategy_logged 
            },
            'features': { 
                'time_varying': final_model_time_varying_features_list,
                'static': final_model_static_features_list 
            },
            'cnn_model_params': base_config.get('cnn_model_params', {}),
            'preprocessing_config': base_config.get('preprocessing_config', {})
        }
        run.config.update(config_for_oasis_dataset_to_log, allow_val_change=True)
        print("\nDefinitive 'features' and 'preprocess' configuration for OASISDataset logged to this W&B run's config.")
else:
    print("Skipping identification of preprocessing columns as train_df is empty or source feature lists are not defined.")
    # Ensure lists are empty if not populated, to avoid NameErrors if run continues (though it shouldn't)
    imputation_cols_to_fit = []
    scaling_cols_to_fit = []
    final_model_time_varying_features_list = []
    final_model_static_features_list = []

## 6. Fit and Save Imputer

Based on the `imputation_cols_to_fit` identified from the training data, an imputer (e.g., `SimpleImputer` with a 'median' strategy) is initialized and fitted. This fitted imputer learns the imputation values (e.g., medians) *only from the training data*. The fitted object is then saved locally as a `.joblib` file and logged as a versioned artifact to W&B. This allows the exact same imputation to be applied consistently to validation and test sets later.
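
A minimal sketch of the save-and-reuse cycle described above: fit on a toy training frame, persist with `joblib`, reload, and apply to a toy validation frame. The filename and data are placeholders chosen for the example, and W&B logging is left out.

```python
import tempfile
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer

train = pd.DataFrame({"MMSE": [28.0, np.nan, 30.0, 26.0]})
val = pd.DataFrame({"MMSE": [np.nan, 21.0]})

# The median is learned from the training rows only.
imputer = SimpleImputer(strategy="median").fit(train[["MMSE"]])

with tempfile.TemporaryDirectory() as tmp:
    path = Path(tmp) / "imputer_demo.joblib"  # hypothetical filename
    joblib.dump(imputer, path)
    reloaded = joblib.load(path)

# The reloaded object fills validation gaps with the stored training median.
val_imputed = np.asarray(reloaded.transform(val[["MMSE"]]))
print(val_imputed.ravel())
```

Because the reloaded object carries the training statistics, later notebooks never need access to the training data to impute consistently.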

In [None]:
# --- Fit and Save Imputer ---
print("\n--- Fitting and Saving Imputer ---")

imputer_fitted = None # Initialize to ensure variable is defined

# Ensure imputation_cols_to_fit is defined from the previous cell,
# train_df is loaded, and base_config is available.
# IMPUTER_SAVE_PATH and DATASET_IDENTIFIER should also be defined from setup cells.
if 'imputation_cols_to_fit' in locals() and imputation_cols_to_fit and \
   'train_df' in locals() and train_df is not None and not train_df.empty and \
   'IMPUTER_SAVE_PATH' in locals() and IMPUTER_SAVE_PATH is not None:
    
    # Determine imputation strategy from base_config (preprocessing_config section)
    imputation_strategy_to_use = base_config.get('preprocessing_config', {})\
                                           .get('imputation_strategy', 'median') # Default to 'median'
    print(f"  Using imputation strategy: '{imputation_strategy_to_use}' for columns: {imputation_cols_to_fit}")

    try:
        # Initialize SimpleImputer with the chosen strategy
        imputer_fitted = SimpleImputer(strategy=imputation_strategy_to_use)
        
        # Fit imputer ONLY on the specified columns of the training data
        # Ensure all columns in imputation_cols_to_fit actually exist in train_df 
        # (though this should be guaranteed by how imputation_cols_to_fit was created)
        valid_imputation_cols = [col for col in imputation_cols_to_fit if col in train_df.columns]
        if not valid_imputation_cols:
            print("  Warning: No valid columns found in train_df for imputation, though imputation_cols_to_fit was not empty. Skipping imputer fitting.")
        else:
            imputer_fitted.fit(train_df[valid_imputation_cols])
            print(f"  Imputer (strategy: '{imputation_strategy_to_use}') fitted successfully on training data columns: {valid_imputation_cols}")

            print(f"  Applying imputation transform to 'train_df' for columns: {valid_imputation_cols}...")
            # .transform returns a NumPy array, so reassign it back to the DataFrame columns
            train_df[valid_imputation_cols] = imputer_fitted.transform(train_df[valid_imputation_cols])
            print("  Imputation transform applied to 'train_df'.")

            # Save the fitted imputer locally
            # IMPUTER_SAVE_PATH should be fully resolved, e.g., <...>/04_Fit_Preprocessors_OASIS2/simple_imputer_median_oasis2.joblib
            IMPUTER_SAVE_PATH.parent.mkdir(parents=True, exist_ok=True) # Ensure directory exists
            joblib.dump(imputer_fitted, IMPUTER_SAVE_PATH)
            print(f"  Fitted imputer saved locally to: {IMPUTER_SAVE_PATH}")

            # Log imputer as a W&B artifact
            if run: # Check if W&B run is active
                print("  Logging imputer as W&B artifact...")
                # Consistent artifact naming including dataset and strategy
                imputer_artifact_base_name = IMPUTER_SAVE_PATH.stem # e.g., "simple_imputer_median_oasis2"
                imputer_artifact_name = imputer_artifact_base_name
                imputer_artifact_type = f"preprocessor_{DATASET_IDENTIFIER}" # e.g., "preprocessor_oasis2"
                
                imputer_description = (
                    f"SimpleImputer(strategy='{imputation_strategy_to_use}') fitted on {DATASET_IDENTIFIER.upper()} "
                    f"training data columns: {valid_imputation_cols}."
                )
                imputer_metadata = {
                    'columns_imputed': valid_imputation_cols, 
                    'imputation_strategy': imputation_strategy_to_use,
                    'dataset_identifier': DATASET_IDENTIFIER,
                    'saved_filename': IMPUTER_SAVE_PATH.name # Log the actual filename
                }
                
                imputer_wandb_artifact = wandb.Artifact(
                    imputer_artifact_name, 
                    type=imputer_artifact_type, 
                    description=imputer_description,
                    metadata=imputer_metadata
                )
                imputer_wandb_artifact.add_file(str(IMPUTER_SAVE_PATH)) # Add the .joblib file
                # Define aliases for easy retrieval, e.g., "latest" and strategy-specific
                aliases = ["latest", f"imputer_{imputation_strategy_to_use.lower()}_{time.strftime('%Y%m%d')}"]
                run.log_artifact(imputer_wandb_artifact, aliases=aliases)
                print(f"  Imputer artifact '{imputer_artifact_name}' (aliases: {aliases}) logged to W&B.")
            else:
                print("  W&B run not active. Skipping artifact logging for imputer.")

    except Exception as e_imputer:
        print(f"  CRITICAL ERROR fitting or saving imputer: {e_imputer}")
        imputer_fitted = None # Ensure it's None if fitting/saving failed
else:
    if not ('imputation_cols_to_fit' in locals() and imputation_cols_to_fit):
        print("  No columns were identified for imputation in the previous step. Skipping imputer fitting.")
    elif not ('train_df' in locals() and train_df is not None and not train_df.empty):
        print("  Training data (train_df) is empty or not defined. Skipping imputer fitting.")
    elif not ('IMPUTER_SAVE_PATH' in locals() and IMPUTER_SAVE_PATH is not None):
        print("  IMPUTER_SAVE_PATH is not defined. Skipping imputer fitting.")

## 7. Fit and Save Scaler

Similarly, a scaler (e.g., `StandardScaler`) is initialized and fitted on the `scaling_cols_to_fit` columns of the training data *only*. Because the imputer cell transforms `train_df` in place, any column that appears in both lists has already been imputed when the scaler is fitted. The fitted scaler stores the per-column mean and standard deviation of the training data. It is saved locally and logged as a versioned W&B artifact so that the same scaling can be applied to every data split.
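
As a small illustration of what the fitted scaler stores, the sketch below fits `StandardScaler` on a toy column and applies it to new values. The numbers are made up; the `NaN` entries show that scikit-learn's `StandardScaler` disregards missing values when fitting and passes them through on transform.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

# Toy training column with one missing value.
train_values = np.array([[1.0], [2.0], [3.0], [np.nan]])
scaler = StandardScaler().fit(train_values)

# mean_ and scale_ hold the per-column training mean and (population) standard deviation,
# computed over the non-missing training values.
mean, std = scaler.mean_[0], scaler.scale_[0]

# transform computes (x - mean) / std with the stored training statistics;
# a NaN input stays NaN.
new_values = np.array([[4.0], [np.nan]])
scaled = scaler.transform(new_values)
print(mean, std, scaled.ravel())
```

Since missing values survive scaling, any column that can contain `NaN` should also be covered by the imputer before the data reaches a model.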

In [None]:
# --- Fit and Save Scaler ---
print("\n--- Fitting and Saving Scaler ---")

scaler_fitted = None # Initialize

# Ensure scaling_cols_to_fit is defined and train_df is not empty
if 'scaling_cols_to_fit' in locals() and scaling_cols_to_fit and \
   'train_df' in locals() and not train_df.empty:

    # Determine scaling strategy from base_config, default to 'standard_scaler'
    # Note: StandardScaler is a class, the strategy name in config might be 'standard_scaler'
    scaling_strategy_name = base_config.get('preprocessing_config', {}).get('scaling_strategy', 'standard_scaler')
    print(f"  Using scaling strategy: '{scaling_strategy_name}' for columns: {scaling_cols_to_fit}")
    
    # Select data for scaler fitting (potentially imputed if imputation_cols overlap with scaling_cols)
    # train_df here should be the version after imputation has been applied if imputation_cols_to_fit was not empty.
    data_for_scaling = train_df[scaling_cols_to_fit]
    
    # Check if data_for_scaling still has any NaN values (imputer might not have covered all scaling_cols)
    nan_counts = data_for_scaling.isnull().sum()
    if nan_counts.any():
        print(f"  Warning: Data for scaling still contains NaN values AFTER potential imputation for columns: "
              f"{nan_counts[nan_counts > 0].index.tolist()}")
        print("  StandardScaler ignores NaNs when fitting and leaves them as NaN in transform. "
              "Consider refining imputation_cols or handling NaNs before scaling.")
        # Option: Impute remaining NaNs in data_for_scaling with mean/median just for fitting scaler,
        # but this implies imputer should have handled these columns. For now, proceed with warning.
        # data_for_scaling = data_for_scaling.fillna(data_for_scaling.median()) # Example: quick fix

    try:
        if scaling_strategy_name.lower() == 'standard_scaler':
            scaler_fitted = StandardScaler()
        # Add elif for other scalers like MinMaxScaler if you plan to use them via config
        # elif scaling_strategy_name.lower() == 'min_max_scaler':
        #     scaler_fitted = MinMaxScaler()
        else:
            print(f"  Warning: Unknown scaling strategy '{scaling_strategy_name}'. Defaulting to StandardScaler.")
            scaler_fitted = StandardScaler()
            scaling_strategy_name = "standard_scaler" # Keep the config-style name so artifact names and aliases stay consistent

        # Fit scaler ONLY on the training data's specified columns
        scaler_fitted.fit(data_for_scaling)
        print("  Scaler fitted successfully on training data.")

        # Save the fitted scaler locally using SCALER_SAVE_PATH
        joblib.dump(scaler_fitted, SCALER_SAVE_PATH) # SCALER_SAVE_PATH from Cell 2
        print(f"  Fitted scaler saved locally to: {SCALER_SAVE_PATH}")

        # Log scaler as a W&B artifact
        if run:
            print("  Logging scaler as W&B artifact...")
            scaler_artifact_name = f"scaler_{scaling_strategy_name.lower().replace('_scaler','')}_{DATASET_IDENTIFIER}"
            scaler_artifact_type = f"preprocessor_{DATASET_IDENTIFIER}"
            scaler_description = (
                f"{scaling_strategy_name} fitted on {DATASET_IDENTIFIER} training data columns: {scaling_cols_to_fit}"
            )
            scaler_wandb_artifact = wandb.Artifact(
                scaler_artifact_name, 
                type=scaler_artifact_type, 
                description=scaler_description,
                metadata={'columns_scaled': scaling_cols_to_fit, 'strategy': scaling_strategy_name}
            )
            scaler_wandb_artifact.add_file(str(SCALER_SAVE_PATH))
            run.log_artifact(scaler_wandb_artifact, aliases=["latest", f"{scaling_strategy_name.lower()}_{time.strftime('%Y%m%d')}"])
            print(f"  Scaler artifact '{scaler_artifact_name}' logged to W&B.")

    except Exception as e_scaler:
        print(f"  Error fitting or saving scaler: {e_scaler}")
        scaler_fitted = None # Ensure it's None if fitting/saving failed
else:
    print("  No columns identified for scaling, or training data is empty. Skipping scaler fitting.")

## 8. Finalize W&B Run

Finish the Weights & Biases run associated with fitting and saving the preprocessors. The saved `.joblib` files (fitted imputer and scaler) and their corresponding W&B artifacts, along with the definitive logged configuration for features and preprocessing, are now ready for use in the downstream data loading pipeline (`OASISDataset`) and model training stages.
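The sketch below illustrates, under stated assumptions, how a downstream consumer might reload saved `.joblib` preprocessors and apply them in the same order they were fitted (impute, then scale). The temporary directory, file names, and toy columns are hypothetical. In this notebook the real files live at `IMPUTER_SAVE_PATH` and `SCALER_SAVE_PATH`, and the way `OASISDataset` actually applies them is not shown here.

```python
import tempfile
from pathlib import Path

import joblib
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler

# Toy data (hypothetical values, not taken from OASIS-2)
toy_cols = ["Age", "MMSE"]
toy_train = pd.DataFrame({"Age": [70.0, 80.0, 90.0], "MMSE": [30.0, np.nan, 22.0]})
toy_test = pd.DataFrame({"Age": [80.0], "MMSE": [np.nan]})

with tempfile.TemporaryDirectory() as tmp_dir:
    imputer_path = Path(tmp_dir) / "toy_imputer.joblib"  # hypothetical file name
    scaler_path = Path(tmp_dir) / "toy_scaler.joblib"    # hypothetical file name

    # Producer side: fit on training data only, then persist both objects
    toy_imputer = SimpleImputer(strategy="median").fit(toy_train[toy_cols])
    toy_train_imputed = toy_imputer.transform(toy_train[toy_cols])
    toy_scaler = StandardScaler().fit(toy_train_imputed)
    joblib.dump(toy_imputer, imputer_path)
    joblib.dump(toy_scaler, scaler_path)

    # Consumer side: reload and apply in the same order (impute, then scale)
    loaded_imputer = joblib.load(imputer_path)
    loaded_scaler = joblib.load(scaler_path)
    toy_test_ready = loaded_scaler.transform(loaded_imputer.transform(toy_test[toy_cols]))
```

The missing `MMSE` value in the toy test row is filled with the training median and then scaled with the training statistics, so the test split contributes nothing to either fitted object.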

In [None]:
# --- Finish W&B Run for Notebook 04 ---
print(f"\n--- {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} complete. Finishing W&B run. ---")

if run: # Check if 'run' object exists and is an active run
    try:
        # Add any final summary metrics for NB04 to run.summary if applicable
        # For example, number of features imputed/scaled if not already in config.
        run.summary["num_imputation_cols_fitted"] = len(imputation_cols_to_fit) if 'imputation_cols_to_fit' in locals() else 0
        run.summary["num_scaling_cols_fitted"] = len(scaling_cols_to_fit) if 'scaling_cols_to_fit' in locals() else 0
        
        # Resolve a display name before finishing, while the run object is still live
        run_name_to_print = (
            getattr(run, 'name_synced', None)
            or getattr(run, 'name', None)
            or getattr(run, 'id', None)
            or "current run"
        )
        run.finish()
        print(f"W&B run '{run_name_to_print}' finished successfully.")
    except Exception as e_finish_run_nb04:
        print(f"Error while finalizing the W&B run for Notebook 04: {e_finish_run_nb04}")
else:
    print("No active W&B run to finish for this session.")

print(f"\n--- Notebook {NOTEBOOK_MODULE_NAME}_{DATASET_IDENTIFIER} execution finished. ---")