# Sprint 6: Análisis de Interpretabilidad (SHAP) y Ética (Fairness)

Este notebook implementa las recomendaciones técnicas para el Sprint 6, enfocándose en:
1.  **Interpretabilidad Eficiente:** Uso de SHAP con sampling del dataset de entrenamiento.
2.  **Análisis de Ética:** Evaluación de sesgos utilizando `fairlearn` (FNR Parity).

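**Nota:** Para la demostración se trabaja con muestras del dataset (p. ej., 100 filas de entrenamiento como *background* de SHAP), pero el código está preparado para escalar. La partición train/test se vuelve a generar en este notebook (`random_state=42`), por lo que puede no coincidir con la usada durante el entrenamiento; en un escenario real conviene cargar la partición guardada.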

In [None]:
import pandas as pd
import numpy as np
import shap
import matplotlib.pyplot as plt
import seaborn as sns
from pycaret.classification import load_model, predict_model, get_config
from fairlearn.metrics import MetricFrame
from sklearn.metrics import recall_score, accuracy_score, confusion_matrix
import os
import sys

# Make the project root (parent of the notebook's working directory) importable.
# Guarded so that re-running this cell does not append duplicate entries.
PROJECT_ROOT = os.path.abspath('..')
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

## 1. Cargar Modelo y Datos
Cargamos el pipeline final y el dataset de prueba.

In [None]:
# Name of the label column in the processed dataset.
TARGET = 'CVDINFR4'

# Load the finalized PyCaret pipeline.
try:
    pipeline = load_model('../models/best_pipeline')
    print("Pipeline loaded successfully.")
except Exception as e:
    # Best-effort: keep the notebook runnable in dev environments where the
    # model does not exist yet; later cells check for None.
    print(f"Error loading pipeline: {e}")
    pipeline = None

# Load processed data and rebuild a train/test split.
# NOTE(review): the split is recreated here (random_state=42) and is not
# guaranteed to match the one used during training, so some "test" rows may
# have been seen by the model. Prefer loading the split saved at training time.
try:
    from sklearn.model_selection import train_test_split

    data = pd.read_parquet('../data/02_intermediate/processed_data.parquet')
    train_data, test_data = train_test_split(data, test_size=0.2, random_state=42)

    X_train = train_data.drop(columns=[TARGET])
    y_train = train_data[TARGET]
    X_test = test_data.drop(columns=[TARGET])
    y_test = test_data[TARGET]

    print(f"Data loaded. Train shape: {X_train.shape}, Test shape: {X_test.shape}")
except Exception as e:
    print(f"Error loading data: {e}")
    # Reset every split variable (including y_train) so no cell sees stale data.
    X_train, y_train, X_test, y_test = None, None, None, None

## 2. Interpretabilidad con SHAP (Optimizado)
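Calculamos los valores SHAP con `TreeExplainer` (rápido para modelos basados en árboles). Si el modelo no es compatible, se recurre a `KernelExplainer`, que usa una muestra del set de entrenamiento como datos de referencia (*background*) y explica solo un subconjunto del test para mantener un coste razonable.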

In [None]:
if pipeline is not None and X_train is not None:
    # Pull the bare estimator out of the PyCaret pipeline: it is usually the
    # step named 'trained_model'; otherwise fall back to the last step.
    try:
        model = pipeline.named_steps['trained_model']
    except (AttributeError, KeyError):
        model = pipeline[-1]

    # NOTE(review): `model` is fed the raw X_train / X_test columns below. If the
    # pipeline contains preprocessing steps that change the feature space, the
    # data should first go through `pipeline[:-1].transform(...)` — confirm.

    # Small background sample; used by KernelExplainer to keep it tractable.
    background_sample = X_train.sample(n=min(100, len(X_train)), random_state=42)

    try:
        explainer = shap.TreeExplainer(model)
        X_shap = X_test  # rows whose SHAP values are computed
        shap_values = explainer.shap_values(X_shap)
        print("SHAP values calculated.")
    except Exception as e:
        print(f"TreeExplainer failed (model might not be tree-based or compatible): {e}")
        print("Attempting KernelExplainer (slower)...")
        if hasattr(model, 'predict_proba'):
            # Explain the positive-class probability (assumes labels 0/1, so
            # column 1 is the positive class); a probability is a far more
            # informative target than a hard class label.
            def predict_fn(rows):
                return model.predict_proba(rows)[:, 1]
        else:
            predict_fn = model.predict
        explainer = shap.KernelExplainer(predict_fn, background_sample)
        # KernelExplainer is expensive: explain only a small test sample, and
        # remember which rows those are so plots use matching data.
        X_shap = X_test.sample(n=min(50, len(X_test)), random_state=42)
        shap_values = explainer.shap_values(X_shap)

    # Reduce to the positive class regardless of the layout shap returned:
    # a list of per-class arrays (older shap), a 3-D array
    # (n_samples, n_features, n_classes) (newer shap) or already 2-D.
    if isinstance(shap_values, list):
        shap_values = shap_values[1]
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        shap_values = shap_values[..., 1]

    # Summary plot over exactly the rows that were explained.
    plt.figure()
    shap.summary_plot(shap_values, X_shap, show=False)
    plt.show()

## 3. Análisis de Casos (Waterfall)
Identificamos casos específicos: TP, TN, FP, FN.

In [None]:
if pipeline is not None and X_test is not None:
    # Hard predictions from the full pipeline.
    y_pred = pipeline.predict(X_test)

    results_df = X_test.copy()
    results_df['Actual'] = y_test
    results_df['Predicted'] = y_pred

    # Confusion-matrix buckets; assumes labels are encoded 0/1 with 1 = positive.
    actual = results_df['Actual']
    predicted = results_df['Predicted']
    tp_idx = results_df[(actual == 1) & (predicted == 1)].index
    tn_idx = results_df[(actual == 0) & (predicted == 0)].index
    fp_idx = results_df[(actual == 0) & (predicted == 1)].index
    fn_idx = results_df[(actual == 1) & (predicted == 0)].index

    print(f"Found: {len(tp_idx)} TP, {len(tn_idx)} TN, {len(fp_idx)} FP, {len(fn_idx)} FN")

    # Rows covered by `shap_values`. The SHAP cell may have explained only a
    # sample of X_test; if it exposes that sample as X_shap we use it,
    # otherwise we assume the whole of X_test was explained.
    try:
        explained_rows = X_shap
    except NameError:
        explained_rows = X_test

    # Positive-class SHAP matrix, whatever layout shap produced: list of
    # per-class arrays, 3-D (n_samples, n_features, n_classes) array, or 2-D.
    sv_matrix = np.asarray(shap_values[1] if isinstance(shap_values, list) else shap_values)
    if sv_matrix.ndim == 3:
        sv_matrix = sv_matrix[..., 1]
    # expected_value may be a scalar, a list or an ndarray with one entry per class.
    base_flat = np.ravel(explainer.expected_value)
    base_value = base_flat[1] if base_flat.size > 1 else base_flat[0]

    def plot_waterfall(index, title):
        """Draw a SHAP waterfall plot for the first explained row in `index`.

        Args:
            index: pandas Index of X_test row labels in one confusion-matrix bucket.
            title: label printed above the plot.
        """
        candidates = index[index.isin(explained_rows.index)]
        if len(candidates) == 0:
            print(f"--- {title}: no explained instance available ---")
            return
        idx = candidates[0]
        # shap_values rows follow the order of the explained frame, so we need
        # the position of the label inside that frame (not inside X_test).
        pos = explained_rows.index.get_loc(idx)
        print(f"--- {title} (Index: {idx}) ---")
        shap.waterfall_plot(
            shap.Explanation(
                values=sv_matrix[pos],
                base_values=base_value,
                data=explained_rows.iloc[pos],
                feature_names=explained_rows.columns,
            ),
            show=False,
        )
        plt.show()

    if len(sv_matrix) != len(explained_rows):
        print("SHAP values do not match the explained rows; skipping waterfall plots.")
    else:
        # One example per bucket; FN first since missed positives are critical.
        plot_waterfall(fn_idx, "False Negative (Critical)")
        plot_waterfall(fp_idx, "False Positive")
        plot_waterfall(tp_idx, "True Positive")
        plot_waterfall(tn_idx, "True Negative")

## 4. Ética y Sesgos (Fairlearn)
Evaluamos la paridad de la Tasa de Falsos Negativos (FNR = 1 − Recall) entre grupos protegidos (sexo).

In [None]:
if pipeline is not None and X_test is not None and 'SEXVAR' in X_test.columns:
    # NOTE(review): 'SEXVAR' is assumed to be the sex column; its stored codes
    # are used directly as group labels (BRFSS usually codes 1=Male, 2=Female —
    # confirm against the actual data schema).
    sensitive_feature = X_test['SEXVAR']

    # Recompute predictions so this cell does not depend on cell execution order.
    y_pred = pipeline.predict(X_test)

    def _false_negative_rate(y_t, y_p, pos_label=1):
        """Return FN / (TP + FN): share of actual positives not predicted positive.

        Returns NaN when the group has no actual positives, where FNR is
        undefined (sklearn's recall_score would report 0 there, which reads as
        a misleading FNR of 1).
        """
        y_t = np.asarray(y_t)
        y_p = np.asarray(y_p)
        positives = y_t == pos_label
        if not positives.any():
            return np.nan
        return float(np.mean(y_p[positives] != pos_label))

    def _recall(y_t, y_p):
        """Recall (sensitivity) = 1 - FNR; NaN when the group has no positives."""
        return 1.0 - _false_negative_rate(y_t, y_p)

    # Recall per group. Low recall = high FNR.
    metric_frame = MetricFrame(
        metrics=_recall,
        y_true=y_test,
        y_pred=y_pred,
        sensitive_features=sensitive_feature,
    )

    print("Recall per group:")
    print(metric_frame.by_group)

    metric_frame.by_group.plot(kind='bar', title='Recall by Sex')
    plt.ylabel('Recall')
    plt.show()

    fnr_frame = MetricFrame(
        metrics=_false_negative_rate,
        y_true=y_test,
        y_pred=y_pred,
        sensitive_features=sensitive_feature,
    )
    print("\nFalse Negative Rate (FNR) per group:")
    print(fnr_frame.by_group)

    # FNR parity gap: largest minus smallest group FNR (0 means perfect parity).
    fnr_gap = fnr_frame.difference(method='between_groups')
    print(f"FNR parity gap (max - min across groups): {fnr_gap:.4f}")
else:
    print("Skipping Fairness analysis: Pipeline, Data or 'SEXVAR' column missing.")