# Phase 6: Responsible AI - Fairness, Transparency, and Monitoring

## Objectives (Module 4)
- Analyze model fairness
- Ensure decisions are transparent (SHAP, Grad-CAM)
- Set up monitoring (drift detection)
- Compute ROI and business KPIs

**⚠️ IMPORTANT**: This notebook is MANDATORY for the project

In [None]:
# Imports
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
from sklearn.metrics import confusion_matrix

import sys
sys.path.append('..')
from src.evaluation.metrics import compute_all_metrics
from src.utils.config import load_config

config = load_config('../configs/config.yaml')

## 1. Fairness Analysis

### Context
In medical diagnosis, algorithmic bias can cause:
- Unequal care across groups
- Diagnostic errors for certain populations
- Loss of trust in the system

### 1.1 Identifying Potential Biases

In [None]:
# Load predictions and metadata
# TODO: load predictions_df with the following columns:
# - patient_id
# - hospital
# - true_pn_stage
# - predicted_pn_stage
# - confidence

predictions_df = pd.read_csv('../results/predictions/test_predictions.csv')

print("Distribution by hospital:")
print(predictions_df['hospital'].value_counts())

### 1.2 Fairness Metrics

#### Demographic Parity

In [None]:
# Share of positive predictions per hospital
demographic_parity = {}

for hospital in predictions_df['hospital'].unique():
    mask = predictions_df['hospital'] == hospital

    # Share of patients predicted to have metastases (predicted pN > 0)
    positive_rate = (predictions_df.loc[mask, 'predicted_pn_stage'] > 0).mean()
    demographic_parity[hospital] = positive_rate

    print(f"Hospital {hospital}: {positive_rate:.2%} positive predictions")

# Visualize
fig = px.bar(x=list(demographic_parity.keys()),
             y=list(demographic_parity.values()),
             title='Demographic Parity by Hospital',
             labels={'x': 'Hospital', 'y': 'Positive Prediction Rate'})
fig.show()
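
The per-hospital rates computed above can be condensed into a single disparity figure. The sketch below computes the demographic parity difference (largest minus smallest positive rate) and ratio on made-up data. The helper name `demographic_parity_gap` and the toy values are illustrative assumptions, not part of the project code.

```python
import pandas as pd

def demographic_parity_gap(df, group_col="hospital", pred_col="predicted_pn_stage"):
    """Return per-group positive rates, their max difference, and their min/max ratio."""
    # Positive prediction = predicted pN stage above 0, as in the cell above
    rates = (df[pred_col] > 0).groupby(df[group_col]).mean()
    difference = rates.max() - rates.min()
    # The ratio is undefined when no group receives any positive prediction
    ratio = rates.min() / rates.max() if rates.max() > 0 else float("nan")
    return rates, difference, ratio

# Toy data (hypothetical values, for illustration only)
toy = pd.DataFrame({
    "hospital": ["A", "A", "A", "A", "B", "B", "B", "B"],
    "predicted_pn_stage": [0, 1, 2, 0, 0, 0, 1, 0],
})
rates, dp_difference, dp_ratio = demographic_parity_gap(toy)
print(rates)
print(f"Demographic parity difference: {dp_difference:.2f}")
print(f"Demographic parity ratio: {dp_ratio:.2f}")
```

A difference near 0 (ratio near 1) means hospitals receive positive predictions at similar rates. A gap is not proof of bias on its own, since it can also reflect a different case mix between hospitals.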

#### Equalized Odds

In [None]:
# TPR and FPR per hospital
equalized_odds = []

for hospital in predictions_df['hospital'].unique():
    mask = predictions_df['hospital'] == hospital

    # Binarize: 0 = pN0, 1 = pN1+
    y_true_binary = (predictions_df.loc[mask, 'true_pn_stage'] > 0).astype(int)
    y_pred_binary = (predictions_df.loc[mask, 'predicted_pn_stage'] > 0).astype(int)

    # Confusion matrix; labels=[0, 1] forces a 2x2 shape even when a hospital
    # contains a single class (otherwise indexing cm[1, 1] raises IndexError)
    cm = confusion_matrix(y_true_binary, y_pred_binary, labels=[0, 1])
    tn, fp, fn, tp = cm.ravel()

    # Compute TPR and FPR; NaN marks a rate that is undefined because the
    # hospital has no positive (or no negative) cases
    tpr = tp / (tp + fn) if (tp + fn) > 0 else np.nan
    fpr = fp / (fp + tn) if (fp + tn) > 0 else np.nan

    equalized_odds.append({
        'hospital': hospital,
        'TPR': tpr,
        'FPR': fpr
    })

    print(f"Hospital {hospital}: TPR={tpr:.3f}, FPR={fpr:.3f}")

# Visualize
df_eo = pd.DataFrame(equalized_odds)
fig = px.scatter(df_eo, x='FPR', y='TPR', text='hospital',
                 title='Equalized Odds by Hospital',
                 labels={'TPR': 'True Positive Rate', 'FPR': 'False Positive Rate'})
fig.show()
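
Equalized odds holds when TPR and FPR are the same across groups, so the spread of each rate is a natural summary of the scatter plot. The sketch below measures the largest between-hospital gap for both rates. `equalized_odds_gaps` and the toy rates are hypothetical, and the table is assumed to have the same columns as `df_eo`.

```python
import pandas as pd

def equalized_odds_gaps(df_rates):
    """Largest between-group difference in TPR and in FPR (NaN entries are skipped)."""
    tpr_gap = df_rates["TPR"].max() - df_rates["TPR"].min()
    fpr_gap = df_rates["FPR"].max() - df_rates["FPR"].min()
    return tpr_gap, fpr_gap

# Hypothetical per-hospital rates, same columns as df_eo
toy_rates = pd.DataFrame({
    "hospital": ["A", "B", "C"],
    "TPR": [0.95, 0.90, 0.80],
    "FPR": [0.05, 0.10, 0.08],
})
tpr_gap, fpr_gap = equalized_odds_gaps(toy_rates)
print(f"TPR gap: {tpr_gap:.2f}, FPR gap: {fpr_gap:.2f}")
```

In a diagnostic setting the TPR gap usually deserves the closer look, because it shows which hospital's patients are most exposed to missed metastases.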

### 1.3 Correction Strategies

#### Post-processing: Per-Hospital Calibration

In [None]:
def calibrate_by_hospital(predictions_df):
    """
    Adjust the decision threshold for each hospital
    so that performance is comparable across hospitals.
    """
    calibrated_predictions = predictions_df.copy()

    for hospital in predictions_df['hospital'].unique():
        mask = predictions_df['hospital'] == hospital

        # TODO: optimize the threshold for this hospital
        # Goal: maximize F1 or minimize false negatives
        # NOTE: optimize_threshold_for_hospital and apply_calibrated_threshold
        # are not defined in this notebook and still have to be implemented
        optimal_threshold = optimize_threshold_for_hospital(predictions_df[mask])

        # Apply the calibrated threshold
        calibrated_predictions.loc[mask, 'calibrated_prediction'] = \
            apply_calibrated_threshold(predictions_df[mask], optimal_threshold)

    return calibrated_predictions

# Apply the calibration
# calibrated_df = calibrate_by_hospital(predictions_df)
# Compare performance before/after
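
The two helpers called in `calibrate_by_hospital` are left as a TODO. The sketch below is one possible way to fill them in, not the project's implementation. It assumes a hypothetical `prob_positive` column holding the predicted probability of pN > 0 (the notebook does not say what `confidence` contains), and it grid-searches the threshold that maximizes F1.

```python
import numpy as np
import pandas as pd

def optimize_threshold_for_hospital(df, prob_col="prob_positive",
                                    true_col="true_pn_stage"):
    """Grid-search the probability threshold that maximizes F1 on this subset."""
    y_true = (df[true_col] > 0).to_numpy()
    probs = df[prob_col].to_numpy()
    best_threshold, best_f1 = 0.5, -1.0
    for threshold in np.linspace(0.05, 0.95, 19):
        y_pred = probs >= threshold
        tp = np.sum(y_pred & y_true)
        fp = np.sum(y_pred & ~y_true)
        fn = np.sum(~y_pred & y_true)
        denom = 2 * tp + fp + fn
        f1 = 2 * tp / denom if denom > 0 else 0.0
        # Strict ">" keeps the lowest threshold among ties, which favors recall
        if f1 > best_f1:
            best_threshold, best_f1 = threshold, f1
    return float(best_threshold)

def apply_calibrated_threshold(df, threshold, prob_col="prob_positive"):
    """Binary prediction (1 = pN1+) obtained by thresholding the probability."""
    return (df[prob_col] >= threshold).astype(int)

# Toy data for one hospital (hypothetical values)
toy = pd.DataFrame({
    "true_pn_stage": [0, 0, 0, 1, 2, 1],
    "prob_positive": [0.12, 0.22, 0.42, 0.37, 0.82, 0.91],
})
threshold = optimize_threshold_for_hospital(toy)
toy["calibrated_prediction"] = apply_calibrated_threshold(toy, threshold)
print(f"Selected threshold: {threshold:.2f}")
print(toy)
```

Thresholds tuned and evaluated on the same patients will look better than they are, so the search would normally run on a validation split and the before/after comparison on held-out data.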

## 2. Transparency and Explainability

### 2.1 SHAP (Aggregation Level)

In [None]:
# Install if needed: uv sync --group interpretability
import shap

# Load the aggregation model (XGBoost)
# TODO: load your XGBoost model
# xgb_model = load_model('../models/final/xgboost_aggregator.pkl')

# Load the patient-level features
# X_patient_features = pd.read_csv('../data/processed/patient_features.csv')

# Create the explainer
# explainer = shap.TreeExplainer(xgb_model)
# shap_values = explainer.shap_values(X_patient_features)

# Visualize: global importance as a bar chart, then the per-sample summary plot
# shap.summary_plot(shap_values, X_patient_features, plot_type="bar")
# shap.summary_plot(shap_values, X_patient_features)

### 2.2 Grad-CAM (Niveau Patch)

In [None]:
from src.evaluation.interpretability import GradCAM
from src.visualization.heatmaps import plot_gradcam_heatmap

# Load the CNN model
# model = load_model('../models/final/resnet50_best.pth')

# Select the target layer
# target_layer = model.layer4[-1]  # Last convolutional block

# Initialize Grad-CAM
# gradcam = GradCAM(model, target_layer)

# Select representative cases
# TODO: load 10-20 patches (TP, TN, FP, FN)

# For each patch
# for patch, label in representative_patches:
#     cam = gradcam.generate_cam(patch, target_class=1)
#     fig = plot_gradcam_heatmap(patch, cam)
#     fig.show()
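
The project's `GradCAM` class lives in `src.evaluation.interpretability` and its internals are not shown in this notebook. As a reminder of what a Grad-CAM heatmap is, the sketch below reproduces the standard combination step in NumPy: average the gradients over space to get one weight per channel, take the weighted sum of the feature maps, then apply ReLU. The function name is hypothetical, and random arrays stand in for the activations and gradients that a real model would supply through hooks on the target layer.

```python
import numpy as np

def gradcam_from_activations(activations, gradients):
    """
    Combine feature maps and gradients into a Grad-CAM heatmap.

    activations, gradients: arrays of shape (channels, height, width) taken
    from the target convolutional layer for one image and one target class.
    """
    # Channel weights: global-average-pool the gradients over space
    weights = gradients.mean(axis=(1, 2))
    # Weighted sum of feature maps, then ReLU to keep positive evidence only
    cam = np.maximum(np.tensordot(weights, activations, axes=1), 0.0)
    # Scale to [0, 1] for display (an all-zero map is returned unchanged)
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam

# Stand-in tensors (random values, for illustration only)
rng = np.random.default_rng(0)
activations = rng.random((8, 7, 7))
gradients = rng.normal(size=(8, 7, 7))
cam = gradcam_from_activations(activations, gradients)
print(cam.shape)
```

The resulting map has the spatial size of the target layer (7x7 here) and is typically upsampled to the patch size before being overlaid on the image.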

### 2.3 Feature Importance

In [None]:
# Feature importance of the aggregation model
# feature_importance = xgb_model.feature_importances_
# feature_names = ['tumor_percentage', 'mean_prob', 'max_prob', 'std_prob', ...]

# Visualize
# fig = px.bar(x=feature_names, y=feature_importance,
#              title='Feature Importance for pN Prediction',
#              labels={'x': 'Feature', 'y': 'Importance'})
# fig.show()

## 3. Monitoring and Drift Detection

### 3.1 Feature Drift

In [None]:
from scipy.stats import ks_2samp

def detect_feature_drift(X_train, X_prod, threshold=0.05):
    """
    Detect feature drift between training and production data.

    A feature is flagged when the two-sample Kolmogorov-Smirnov test
    yields a p-value below `threshold` (the significance level).
    """
    drift_detected = {}

    for col in X_train.columns:
        # Kolmogorov-Smirnov test
        statistic, p_value = ks_2samp(X_train[col], X_prod[col])
        drift_detected[col] = {
            'p_value': p_value,
            'drift': p_value < threshold
        }

    return drift_detected

# Simulate production data
# X_train = pd.read_csv('../data/processed/train_features.csv')
# X_prod = pd.read_csv('../data/processed/production_features.csv')

# Detect drift
# drift_results = detect_feature_drift(X_train, X_prod)

# Print the features showing drift
# for feature, result in drift_results.items():
#     if result['drift']:
#         print(f"⚠️ DRIFT detected for {feature}: p-value={result['p_value']:.4f}")
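
Because the production file does not exist yet, the same Kolmogorov-Smirnov check can be tried on synthetic data. The sketch below draws two standardized features, shifts one of them in the "production" sample, and tests each column. The data, the variable names, and the Bonferroni correction are additions for illustration. The correction divides the significance level by the number of features to limit false alarms when many columns are tested.

```python
import numpy as np
import pandas as pd
from scipy.stats import ks_2samp

rng = np.random.default_rng(42)

# Synthetic "training" features (standardized values, hypothetical)
X_train_demo = pd.DataFrame({
    "mean_prob": rng.normal(0.0, 1.0, 2000),
    "max_prob": rng.normal(0.0, 1.0, 2000),
})
# Synthetic "production" features: max_prob is shifted by one standard deviation
X_prod_demo = pd.DataFrame({
    "mean_prob": rng.normal(0.0, 1.0, 2000),
    "max_prob": rng.normal(1.0, 1.0, 2000),
})

alpha = 0.05
# Bonferroni correction: split the significance level across the tested features
alpha_per_feature = alpha / X_train_demo.shape[1]

report = {}
for col in X_train_demo.columns:
    result = ks_2samp(X_train_demo[col], X_prod_demo[col])
    report[col] = {
        "statistic": result.statistic,
        "p_value": result.pvalue,
        "drift": result.pvalue < alpha_per_feature,
    }

print(pd.DataFrame(report).T)
```

With large samples the KS test flags even tiny shifts, so the test statistic (the maximum distance between the two empirical distributions) is worth reporting next to the p-value.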

### 3.2 Performance Monitoring

In [None]:
def monitor_performance_over_time(predictions_history):
    """
    Track model performance over time.

    Each batch is expected to provide 'y_true', 'y_pred' and 'timestamp'.
    """
    performance_metrics = []

    for batch in predictions_history:
        metrics = compute_all_metrics(
            batch['y_true'],
            batch['y_pred']
        )
        metrics['timestamp'] = batch['timestamp']
        performance_metrics.append(metrics)

    df_perf = pd.DataFrame(performance_metrics)

    # Plot the metrics over time
    fig = px.line(df_perf, x='timestamp', y=['accuracy', 'recall', 'precision'],
                  title='Performance Over Time')
    fig.show()

    return df_perf

# TODO: implement with your own data
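
A performance table like `df_perf` becomes actionable once it is compared against a baseline. The sketch below flags batches whose recall falls more than 5% below the baseline, which is the retraining trigger listed in the deployment recommendations. Reading that 5% as a relative drop, using the first batch as the baseline, and the function name and toy history are all assumptions.

```python
import pandas as pd

def flag_degradation(df_perf, metric="recall", baseline=None, max_relative_drop=0.05):
    """Mark batches whose metric falls more than max_relative_drop below the baseline."""
    out = df_perf.sort_values("timestamp").reset_index(drop=True)
    # Default baseline: the earliest batch (an assumption; a validation score
    # fixed at deployment time is another option)
    if baseline is None:
        baseline = out[metric].iloc[0]
    out["relative_drop"] = (baseline - out[metric]) / baseline
    out["alert"] = out["relative_drop"] > max_relative_drop
    return out

# Toy weekly history (hypothetical values)
history = pd.DataFrame({
    "timestamp": pd.to_datetime(["2024-01-01", "2024-01-08", "2024-01-15", "2024-01-22"]),
    "recall": [0.95, 0.94, 0.91, 0.88],
})
monitored = flag_degradation(history)
print(monitored[["timestamp", "recall", "relative_drop", "alert"]])
```

Only the last batch triggers an alert here: recall 0.88 is about 7.4% below the 0.95 baseline, while 0.91 is about 4.2% below it.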

## 4. ROI and Business KPIs

### 4.1 ROI Calculation

In [None]:
# Assumptions (adapt to your own context)
fn_rate_manual = 0.05  # 5% false negatives with manual diagnosis
fn_rate_ai = 0.01      # 1% with your AI system
cost_per_fn = 50000    # Cost of a delayed treatment (€)
n_patients_year = 10000  # Number of patients per year

# System costs
development_cost = 100000  # Development (€)
deployment_cost = 20000    # Deployment (€)
maintenance_cost_year = 10000  # Annual maintenance (€)

# Benefits
fn_avoided = (fn_rate_manual - fn_rate_ai) * n_patients_year
cost_avoided_year = fn_avoided * cost_per_fn

# ROI over 3 years
total_cost = development_cost + deployment_cost + (maintenance_cost_year * 3)
total_benefit = cost_avoided_year * 3

roi = ((total_benefit - total_cost) / total_cost) * 100

print("=== ROI Analysis ===")
print(f"False negatives avoided per year: {fn_avoided:.0f} patients")
print(f"Costs avoided per year: {cost_avoided_year:,.0f} €")
print(f"Total system cost (3 years): {total_cost:,.0f} €")
print(f"Total benefit (3 years): {total_benefit:,.0f} €")
print(f"\n🎯 3-year ROI: {roi:.1f}%")
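
With the assumptions above, the 3-year ROI comes out at 39,900%, a figure driven almost entirely by `cost_per_fn` and the two false-negative rates. The sketch below wraps the same formula in a function so that alternative scenarios can be compared, and adds a payback period. The function name, the payback definition, and the "conservative" cost of 5,000 € are illustrative assumptions.

```python
def roi_summary(fn_rate_manual, fn_rate_ai, cost_per_fn, n_patients_year,
                development_cost, deployment_cost, maintenance_cost_year, years=3):
    """Return the ROI (%) over `years` and the payback period in months."""
    cost_avoided_year = (fn_rate_manual - fn_rate_ai) * n_patients_year * cost_per_fn
    total_cost = development_cost + deployment_cost + maintenance_cost_year * years
    total_benefit = cost_avoided_year * years
    roi_pct = (total_benefit - total_cost) / total_cost * 100
    # Payback: months until net monthly savings cover the upfront costs
    net_saving_month = (cost_avoided_year - maintenance_cost_year) / 12
    upfront = development_cost + deployment_cost
    payback_months = upfront / net_saving_month if net_saving_month > 0 else float("inf")
    return roi_pct, payback_months

# Baseline assumptions from the cell above, then a hypothetical lower cost per FN
for label, cost in [("baseline", 50_000), ("conservative", 5_000)]:
    roi_pct, payback = roi_summary(0.05, 0.01, cost, 10_000, 100_000, 20_000, 10_000)
    print(f"{label}: ROI = {roi_pct:,.0f}%, payback = {payback:.2f} months")
```

Dividing the cost per false negative by ten still leaves an ROI of 3,900%, which shows how strongly the result depends on the assumed false-negative rates and why they need to be justified in the report.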

### 4.2 Business KPIs

In [None]:
# Define the KPIs
kpi_data = [
    {
        'KPI': 'Detection rate',
        'Technical Metric': 'Recall',
        'Value': 0.95,
        'Target': 0.95,
        'Impact': 'At most 5% of metastases missed'
    },
    {
        'KPI': 'Diagnostic precision',
        'Technical Metric': 'Precision',
        'Value': 0.92,
        'Target': 0.90,
        'Impact': 'At most 8% of positive predictions are false positives'
    },
    {
        'KPI': 'Processing time',
        'Technical Metric': 'Mean time',
        'Value': 3.5,  # minutes
        'Target': 5.0,
        'Impact': 'Faster diagnosis'
    },
]

df_kpi = pd.DataFrame(kpi_data)
print(df_kpi.to_string(index=False))
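
Whether a KPI meets its target depends on its direction: recall and precision should sit at or above the target, while processing time should sit at or below it. The sketch below adds that check to a small table using the values from the cell above. The column names and the `higher_is_better` flag are assumptions introduced for the example.

```python
import pandas as pd

kpis = pd.DataFrame({
    "KPI": ["Detection rate", "Diagnostic precision", "Processing time (min)"],
    "value": [0.95, 0.92, 3.5],
    "target": [0.95, 0.90, 5.0],
    # Assumption: lower is better for time, higher is better for recall/precision
    "higher_is_better": [True, True, False],
})

meets_if_higher = kpis["value"] >= kpis["target"]
meets_if_lower = kpis["value"] <= kpis["target"]
# Pick the comparison that matches each KPI's direction
kpis["target_met"] = meets_if_higher.where(kpis["higher_is_better"], meets_if_lower)
print(kpis.to_string(index=False))
```

All three KPIs meet their target with these values. Without the direction flag, a naive `value >= target` test would wrongly report the processing time as a miss.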

## 5. Conclusions and Recommendations

### 5.1 Summary of Analyses

**Fairness**:
- TODO: Summarize the biases identified
- TODO: Correction strategies applied
- TODO: Results after correction

**Transparency**:
- TODO: Most important features identified
- TODO: Medical validation of decisions
- TODO: Documented case studies

**Monitoring**:
- TODO: Drift detection strategy defined
- TODO: Alert thresholds established
- TODO: Retraining plan

**Business**:
- TODO: ROI computed and justified
- TODO: KPIs defined and measured
- TODO: Clinical impact quantified

### 5.2 Deployment Recommendations

1. **Clinical Workflow**:
   - Use the system as a pre-screening tool
   - Require human validation for critical cases
   - Use double reading for ambiguous cases

2. **Continuous Monitoring**:
   - Check performance weekly
   - Run drift detection monthly
   - Retrain if performance degrades by more than 5%

3. **Fairness**:
   - Audit per-hospital performance quarterly
   - Recalibrate if disparities are detected
   - Document every adjustment

4. **Transparency**:
   - Provide Grad-CAM explanations to pathologists
   - Document error cases
   - Maintain a decision log