# 🤖 Entraînement RandomForest - Classificateur de Triage

**Dataset :** 79 cas équilibrés (20 ROUGE, 19 JAUNE, 20 VERT, 20 GRIS)

**Features :** 776 dimensions
- 768 : Embeddings CamemBERT-bio (symptômes)
- 6 : Constantes vitales (FC, FR, SpO2, TA systolique, TA diastolique, Temp)
- 2 : Patient (âge, sexe)

## 1️⃣ Imports

In [None]:
import pickle
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    classification_report, 
    confusion_matrix, 
    accuracy_score,
    f1_score
)



## 2️⃣ Charger le dataset

In [None]:
# Load the complete pickled triage dataset (79 cases per the notebook header).
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only open dataset files that come from a trusted source.
data_path = Path('../data/triage_dataset_complete.pkl')
with data_path.open('rb') as dataset_file:
    data = pickle.load(dataset_file)

# The feature matrix and the labels are stored under the 'X' and 'y' keys.
X, y = data['X'], data['y']

# Quick sanity report on what was loaded.
print("‚úÖ Dataset charg√© !")
print(f"   X shape : {X.shape}")
print(f"   y shape : {y.shape}")
print(f"   Features : {X.shape[1]} dimensions")
print(f"   Cas : {X.shape[0]}")

## 3️⃣ Distribution des classes

In [None]:
from collections import Counter

# Display order of the four triage levels, shared by the table and the chart.
triage_levels = ['ROUGE', 'JAUNE', 'VERT', 'GRIS']
separator = "=" * 60

counts = Counter(y)
total_cases = len(y)

# Text summary: count, share of the dataset and a small gauge for each class.
print("üìä DISTRIBUTION DES CLASSES")
print(separator)
for level in triage_levels:
    n_cases = counts[level]
    share = (n_cases / total_cases) * 100
    gauge = '‚ñà' * (n_cases // 2)  # one glyph per two cases
    print(f"   {level:6s} : {n_cases:2d} ({share:5.1f}%) {gauge}")
print(separator)

# Bar chart of the same counts, one colour per triage level.
colors = {'ROUGE': 'red', 'JAUNE': 'yellow', 'VERT': 'green', 'GRIS': 'gray'}
values = [counts[level] for level in triage_levels]
bar_colors = [colors[level] for level in triage_levels]

plt.figure(figsize=(10, 6))
plt.bar(triage_levels, values, color=bar_colors, alpha=0.7, edgecolor='black')
plt.title('Distribution des Classes', fontsize=16, fontweight='bold')
plt.ylabel('Nombre de Cas')
plt.xlabel('Niveau de Gravit√©')
plt.grid(axis='y', alpha=0.3)

# Write each count just above its bar.
for x_pos, height in enumerate(values):
    plt.text(x_pos, height + 0.5, str(height), ha='center', fontweight='bold')

plt.tight_layout()
plt.show()

## 4️⃣ Split Train/Test

In [None]:
# Stratified hold-out split: 25% of the cases go to the test set, and passing
# stratify=y keeps the class distribution of y in both parts.
split_options = {'test_size': 0.25, 'random_state': 42, 'stratify': y}
X_train, X_test, y_train, y_test = train_test_split(X, y, **split_options)

rule = "=" * 60
train_share = len(X_train) / len(X) * 100
test_share = len(X_test) / len(X) * 100

print("üìä SPLIT TRAIN/TEST")
print(rule)
print(f"   Train : {len(X_train)} cas ({train_share:.0f}%)")
print(f"   Test  : {len(X_test)} cas ({test_share:.0f}%)")
print(rule)

# Per-class counts of each part, to check the stratification by eye.
train_counts = Counter(y_train)
test_counts = Counter(y_test)
for split_name, split_counts in (("TRAIN", train_counts), ("TEST", test_counts)):
    print(f"\n   Distribution {split_name} :")
    for level in ['ROUGE', 'JAUNE', 'VERT', 'GRIS']:
        print(f"      {level} : {split_counts[level]} cas")

## 5️⃣ Entraîner RandomForest

In [None]:
print("üéì Entra√Ænement du mod√®le...\n")

# Hyperparameters of the forest, gathered in one place.
forest_params = dict(
    n_estimators=100,      # 100 trees
    max_depth=10,          # maximum depth of each tree
    min_samples_split=2,
    min_samples_leaf=1,
    random_state=42,
    n_jobs=-1,             # use every available CPU
)

# Build the classifier, then fit it on the training part of the split.
clf = RandomForestClassifier(**forest_params)
clf.fit(X_train, y_train)

print("‚úÖ Mod√®le entra√Æn√© !")
print(f"   Nombre d'arbres : {clf.n_estimators}")
print(f"   Profondeur max : {clf.max_depth}")

## 6️⃣ Prédictions

In [None]:
# Predict the class of every held-out test case.
y_pred = clf.predict(X_test)

# Headline metrics: plain accuracy and the weighted-average F1 score.
accuracy = accuracy_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred, average='weighted')

rule = "=" * 60
accuracy_pct = accuracy * 100

print("üìä PERFORMANCES SUR TEST SET")
print(rule)
print(f"   Accuracy : {accuracy_pct:.1f}%")
print(f"   F1-Score : {f1:.3f}")
print(rule)

## 7️⃣ Rapport de classification

In [None]:
print("\nüìã RAPPORT DE CLASSIFICATION D√âTAILL√â")
print("="*60)

# Pass `labels` explicitly so every row name is tied to its class instead of
# relying on the implicit sorted order of the labels found in y_test/y_pred.
# Without it, a class missing from both arrays makes the 4-entry
# `target_names` mismatch the detected classes and the call raises ValueError.
# The order is kept alphabetical, so the printed report is unchanged when all
# four classes are present.
report_labels = ['GRIS', 'JAUNE', 'ROUGE', 'VERT']
print(classification_report(
    y_test,
    y_pred,
    labels=report_labels,
    target_names=report_labels,
    zero_division=0  # report 0 instead of warning when a class is never predicted
))

## 8️⃣ Matrice de confusion

In [None]:
# Class order used for both axes of the matrix.
matrix_labels = ['ROUGE', 'JAUNE', 'VERT', 'GRIS']
cm = confusion_matrix(y_test, y_pred, labels=matrix_labels)

# Text version: the same matrix as a labelled table
# (rows = true class, columns = predicted class).
print("\nüî¢ MATRICE DE CONFUSION")
print("="*60)
cm_df = pd.DataFrame(cm, index=matrix_labels, columns=matrix_labels)
print(cm_df)
print("\n(Lignes = Vrai, Colonnes = Pr√©dit)")

# Graphical version: annotated heatmap of the raw counts.
heatmap_options = dict(
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=matrix_labels,
    yticklabels=matrix_labels,
    cbar_kws={'label': 'Nombre de cas'},
)
plt.figure(figsize=(10, 8))
sns.heatmap(cm, **heatmap_options)
plt.title('Matrice de Confusion', fontsize=16, fontweight='bold', pad=20)
plt.ylabel('Classe R√©elle', fontweight='bold')
plt.xlabel('Classe Pr√©dite', fontweight='bold')
plt.tight_layout()
plt.show()

## 9️⃣ Validation croisée

In [None]:
# 5-fold cross-validated accuracy, computed on the full dataset (X, y).
rule = "=" * 60
print("üîÑ VALIDATION CROIS√âE (5-fold)")
print(rule)

cv_scores = cross_val_score(clf, X, y, cv=5, scoring='accuracy')

# One line per fold, numbered from 1.
print("   Scores par fold :")
for fold_number, fold_accuracy in enumerate(cv_scores, start=1):
    print(f"      Fold {fold_number} : {fold_accuracy*100:.1f}%")

# Mean and standard deviation across the folds, as percentages.
mean_pct = cv_scores.mean() * 100
std_pct = cv_scores.std() * 100
print(f"\n   Moyenne : {mean_pct:.1f}% (¬± {std_pct:.1f}%)")
print(rule)

## 🔟 Importance des features

In [None]:
# Per-feature importances of the fitted forest; the 10 largest are listed below.
feature_importance = clf.feature_importances_

# Readable names for columns 768-775 (six vital signs, then age and sex).
# Any other column index is reported as a symptom-embedding dimension.
vitals_indices = {
    768: 'FC (fr√©quence cardiaque)',
    769: 'FR (fr√©quence respiratoire)',
    770: 'SpO2 (saturation)',
    771: 'TA systolique',
    772: 'TA diastolique',
    773: 'Temp√©rature',
    774: '√Çge',
    775: 'Sexe'
}

print("\nüîù TOP 10 FEATURES LES PLUS IMPORTANTES")
print("="*60)

# Column indices sorted by decreasing importance, truncated to the first ten.
descending_order = np.argsort(feature_importance)[::-1]
indices = descending_order[:10]

for rank, feature_idx in enumerate(indices, start=1):
    name = vitals_indices.get(feature_idx, f"Embedding sympt√¥me #{feature_idx}")
    print(f"   {rank:2d}. {name:30s} : {feature_importance[feature_idx]:.4f}")

print("="*60)

## 1️⃣1️⃣ Test sur cas spécifiques

In [None]:
print("\nüß™ TEST SUR CAS SP√âCIFIQUES")
print("="*60)

# Draw up to 5 distinct test cases. A local RandomState seeded with 42 produces
# the same draw as np.random.seed(42) + np.random.choice, but does not reseed
# NumPy's global generator as a side effect for the rest of the notebook.
rng = np.random.RandomState(42)
test_indices = rng.choice(len(X_test), size=min(5, len(X_test)), replace=False)

# Positional view of the true labels: with a pandas Series, y_test[idx] would
# be a label-based lookup rather than "the idx-th test case".
# NOTE(review): the container type of y is not visible here; np.asarray keeps
# the lookup positional whether y_test is an ndarray or a Series.
y_test_values = np.asarray(y_test)

for i, idx in enumerate(test_indices, 1):
    # Keep a 2-D, single-row slice so the classifier receives a batch of one.
    x_case = X_test[idx:idx+1]
    y_true = y_test_values[idx]
    y_pred_case = clf.predict(x_case)[0]
    proba = clf.predict_proba(x_case)[0]

    # Map each class name to its predicted probability.
    proba_dict = dict(zip(clf.classes_, proba))

    print(f"\n   Cas #{i}")
    print(f"      Vrai label : {y_true}")
    print(f"      Pr√©diction : {y_pred_case} {'‚úÖ' if y_pred_case == y_true else '‚ùå'}")
    print(f"      Confiance :")
    for label in ['ROUGE', 'JAUNE', 'VERT', 'GRIS']:
        # Classes unknown to the model are shown with probability 0.
        prob = proba_dict.get(label, 0)
        bar = '‚ñà' * int(prob * 20)  # 20 glyphs correspond to 100%
        print(f"         {label:6s} : {prob*100:5.1f}% {bar}")

print("\n" + "="*60)

## 1️⃣2️⃣ Sauvegarder le modèle

In [None]:
# Make sure the output directory exists (no error if it is already there).
models_dir = Path('../models')
models_dir.mkdir(exist_ok=True)

# Serialize the fitted classifier with pickle.
model_path = models_dir / 'random_forest_triage.pkl'
with model_path.open('wb') as model_file:
    pickle.dump(clf, model_file)

print(f"‚úÖ Mod√®le sauvegard√© : {model_path}")

# Size of the written file, converted from bytes to kilobytes.
size_kb = model_path.stat().st_size / 1024
print(f"\nüì¶ Taille du fichier : {size_kb:.1f} KB")

## ✅ Résumé & Conclusion

In [None]:
def _print_banner(title, width=70):
    """Print a blank line, then *title* framed by two rules of '=' characters."""
    rule = "=" * width
    print("\n" + rule)
    print(title)
    print(rule)


# Final recap of the dataset, the model metrics and the produced artefact.
_print_banner("üìä R√âSUM√â FINAL")

# Dataset section.
print("\n‚úÖ DATASET")
print(f"   Total : {len(X)} cas")
print(f"   Features : {X.shape[1]} dimensions")
print("   Classes : 4 (ROUGE, JAUNE, VERT, GRIS)")
print("   √âquilibr√© : Oui")

# Model section: reuses the accuracy and F1 computed on the test set.
print("\n‚úÖ MOD√àLE")
print("   Algorithme : RandomForest")
print(f"   Arbres : {clf.n_estimators}")
print(f"   Accuracy : {accuracy*100:.1f}%")
print(f"   F1-Score : {f1:.3f}")

# Output files section.
print("\n‚úÖ FICHIERS CR√â√âS")
print("   Mod√®le : models/random_forest_triage.pkl")

# Suggested follow-up work.
print("\nüí° PROCHAINES √âTAPES")
for next_step in (
    "   1. Int√©grer dans Streamlit",
    "   2. Tester avec nouveaux patients",
    "   3. Am√©liorer avec plus de donn√©es",
):
    print(next_step)

_print_banner("üéâ ENTRA√éNEMENT TERMIN√â AVEC SUCC√àS !")