In [None]:
# Install the required libraries (tensorflow-addons is deliberately not used).
# NOTE(review): scikit-plot is installed but not imported anywhere in the visible code.
!pip install -q tensorflow scikit-plot

# Importations
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import random
from collections import Counter
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import classification_report, confusion_matrix, f1_score
from sklearn.model_selection import train_test_split

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers, models
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint, CSVLogger

# Configuration
BASE_PATH = "/content/drive/MyDrive/MobileNetV2/data"  # one sub-directory per class
IMAGE_SIZE = (224, 224)  # (height, width) images are resized to
BATCH_SIZE = 32
EPOCHS = 50
SEED = 42

# Seed every RNG used in this notebook (Python, NumPy, TensorFlow) for reproducibility.
for seed_fn in (random.seed, np.random.seed, tf.random.set_seed):
    seed_fn(SEED)

# 1. DATA ANALYSIS
# Inventory every image under BASE_PATH: each sub-directory is one class.
print("=== ANALYSE DU D√âS√âQUILIBRE ===")

image_paths, labels = [], []

for class_dir in os.listdir(BASE_PATH):
    class_path = os.path.join(BASE_PATH, class_dir)
    if not os.path.isdir(class_path):
        continue  # stray files at the top level are ignored
    images_in_class = [
        os.path.join(class_path, img_file)
        for img_file in os.listdir(class_path)
        if img_file.lower().endswith(('.png', '.jpg', '.jpeg'))
    ]
    image_paths.extend(images_in_class)
    labels.extend([class_dir] * len(images_in_class))
    print(f"Dossier {class_dir}: {len(images_in_class)} images")

# Only directories that contributed at least one image become classes.
class_names = sorted(set(labels))
print(f"\nClasses trouv√©es: {class_names}")
print(f"Nombre total d'images: {len(image_paths)}")

# Per-class image counts and their share of the whole dataset.
class_counts = Counter(labels)
print("\nDistribution r√©elle des classes:")
for class_name in class_names:
    count = class_counts[class_name]
    print(f"{class_name}: {count} images ({count/len(labels)*100:.1f}%)")

# Visualization: raw image count per class.
# Bar positions, tick labels and bar heights all follow the sorted
# `class_names` order, so each bar is labelled with the class it measures.
counts_in_order = [class_counts[cn] for cn in class_names]
plt.figure(figsize=(10, 6))
bars = plt.bar(class_names, counts_in_order)
plt.title('Distribution r√©elle des classes')
plt.xlabel('Classe d\'√¢ge')
plt.ylabel('Nombre d\'images')
for bar, count in zip(bars, counts_in_order):
    # Offset in screen points (not data units) so the label clears the bar
    # whatever the y-axis scale is.
    plt.annotate(str(count),
                 xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                 xytext=(0, 3), textcoords='offset points',
                 ha='center', va='bottom')
plt.tight_layout()
plt.show()

# 2. BALANCING STRATEGY
print("\n=== STRAT√âGIE D'√âQUILIBRAGE ===")

# Per-class target sizes used by prepare_datasets_70_15_15().
target_counts = {
    '1-20': 6638,    # kept as is
    '21-50': 8776,   # under-sampled (from 21497 down to 8776)
    '51-100': 6638,  # over-sampled up to 6638
}

print("Cibles par classe (avant division train/val/test):")
for class_name in class_names:
    original = class_counts[class_name]
    target = target_counts[class_name]
    if target == original:
        action = "Garder telles quelles"
    elif target > original:
        action = "SUR-√©chantillonnage"
    else:
        action = "SOUS-√©chantillonnage"
    print(f"{class_name}: {original} ‚Üí {target} images ({action})")

# 3. DATA PREPARATION: per-class balancing and a 70/15/15 split
def _list_class_images(class_name):
    """Return the .png/.jpg/.jpeg paths of one class folder, sorted.

    Sorting matters: os.listdir returns entries in arbitrary, filesystem-
    dependent order, which would make the seeded random sampling below
    non-reproducible.
    """
    class_dir = os.path.join(BASE_PATH, class_name)
    return sorted(
        os.path.join(class_dir, img_file)
        for img_file in os.listdir(class_dir)
        if img_file.lower().endswith(('.png', '.jpg', '.jpeg'))
    )


def _split_70_15_15(paths):
    """Split a list of paths into (train, val, test) lists of 70% / 15% / 15%."""
    train_paths, temp_paths = train_test_split(paths, test_size=0.3, random_state=SEED)
    val_paths, test_paths = train_test_split(temp_paths, test_size=0.5, random_state=SEED)
    return train_paths, val_paths, test_paths


def prepare_datasets_70_15_15():
    """Build balanced train/val/test DataFrames with a 70/15/15 split per class.

    For every class in ``class_names``:
      1. classes with more images than ``target_counts[class]`` are randomly
         under-sampled to the target (this only drops files);
      2. the remaining *unique* files are split 70/15/15;
      3. classes with fewer images than the target are over-sampled
         (duplicates drawn with replacement) in the TRAIN split only, up to
         70% of the target.

    Over-sampling after the split guarantees that no image file appears in
    more than one split. Duplicating before the split would leak copies of the
    same image into validation and test and inflate their metrics.

    Returns:
        (train_df, val_df, test_df): DataFrames with columns ``filename``
        (full image path) and ``class`` (class directory name).

    Raises:
        ValueError: if a class folder contains no image.
    """
    class_paths = {name: _list_class_images(name) for name in class_names}

    print(f"\nImages disponibles par classe:")
    for class_name in class_names:
        print(f"{class_name}: {len(class_paths[class_name])} images")

    print(f"\nDivision globale (70% train, 15% val, 15% test):")
    rows = {'train': [], 'val': [], 'test': []}  # (filename, class) tuples per split

    for class_name in class_names:
        paths = class_paths[class_name]
        original_count = len(paths)
        if original_count == 0:
            raise ValueError(f"Aucune image trouvée pour la classe {class_name!r}")
        # Classes without an explicit target keep their natural size.
        target = target_counts.get(class_name, original_count)

        # Step 1: under-sample BEFORE splitting (cannot create duplicates).
        if original_count > target:
            paths = random.sample(paths, target)
            print(f"{class_name}: {original_count} ‚Üí {len(paths)} (SOUS-√©chantillonnage)")
        elif original_count == target:
            print(f"{class_name}: {original_count} (gard√©es telles quelles)")

        # Step 2: split the unique files.
        train_paths, val_paths, test_paths = _split_70_15_15(paths)

        # Step 3: over-sample AFTER splitting, in the training split only.
        if original_count < target:
            train_target = int(round(target * 0.7))
            extra = random.choices(train_paths, k=max(0, train_target - len(train_paths)))
            print(f"{class_name}: train {len(train_paths)} → {len(train_paths) + len(extra)} "
                  f"(SUR-échantillonnage, train uniquement)")
            train_paths = train_paths + extra

        for split_name, split_paths in (('train', train_paths), ('val', val_paths), ('test', test_paths)):
            rows[split_name].extend((path, class_name) for path in split_paths)

        # Per-class statistics
        total = len(train_paths) + len(val_paths) + len(test_paths)
        print(f"\n{class_name}:")
        print(f"  Train: {len(train_paths)} images ({len(train_paths)/total*100:.1f}%)")
        print(f"  Validation: {len(val_paths)} images ({len(val_paths)/total*100:.1f}%)")
        print(f"  Test: {len(test_paths)} images ({len(test_paths)/total*100:.1f}%)")
        print(f"  Total: {total} images")

    train_df, val_df, test_df = (
        pd.DataFrame(rows[split], columns=['filename', 'class'])
        for split in ('train', 'val', 'test')
    )

    print(f"\nTotaux globaux:")
    print(f"  Train: {len(train_df)} images")
    print(f"  Validation: {len(val_df)} images")
    print(f"  Test: {len(test_df)} images")
    print(f"  Total: {len(train_df) + len(val_df) + len(test_df)} images")

    return train_df, val_df, test_df

# Build the 70/15/15 splits.
train_df, val_df, test_df = prepare_datasets_70_15_15()

# 4. DATA GENERATORS
# Training images get random geometric/photometric augmentation; validation
# and test images are only rescaled to [0, 1].
# NOTE(review): MobileNetV2's ImageNet weights expect inputs in [-1, 1]
# (tf.keras.applications.mobilenet_v2.preprocess_input). Changing the scaling
# here must be mirrored in every inference path that divides by 255.
train_datagen = ImageDataGenerator(
    rescale=1./255,
    horizontal_flip=True,
    rotation_range=30,
    shear_range=0.2,
    zoom_range=0.2,
    width_shift_range=0.2,
    height_shift_range=0.2,
    brightness_range=[0.8, 1.2],
    fill_mode='nearest',
)
val_test_datagen = ImageDataGenerator(rescale=1./255)

# Arguments shared by the three DataFrame iterators.
common_flow_args = dict(
    x_col='filename',
    y_col='class',
    target_size=IMAGE_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='categorical',
)

# Only the training iterator shuffles (seeded); validation and test keep the
# DataFrame row order.
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df, shuffle=True, seed=SEED, **common_flow_args
)
val_generator = val_test_datagen.flow_from_dataframe(
    dataframe=val_df, shuffle=False, **common_flow_args
)
test_generator = val_test_datagen.flow_from_dataframe(
    dataframe=test_df, shuffle=False, **common_flow_args
)

print("\nG√©n√©rateurs cr√©√©s:")
print(f"  Train: {len(train_df)} images")
print(f"  Validation: {len(val_df)} images")
print(f"  Test: {len(test_df)} images")

# 5. CLASS WEIGHTS (computed on the training split only)
def compute_class_weights_for_loss(labels):
    """Return sklearn 'balanced' class weights keyed by integer class index.

    Index ``i`` is the i-th distinct label in sorted order (``np.unique``
    order). That is the same alphanumeric order flow_from_dataframe uses for
    ``class_indices`` when it infers the classes itself.

    Args:
        labels: iterable of class labels, one per training sample.

    Returns:
        dict mapping ``int`` class index -> ``float`` weight.
    """
    sorted_classes = np.unique(labels)
    weights = compute_class_weight(
        class_weight='balanced',
        classes=sorted_classes,
        y=labels
    )
    return {index: float(weight) for index, weight in enumerate(weights)}

# NOTE: these weights are printed and saved, but they are not passed to model.fit().
class_weights = compute_class_weights_for_loss(train_df['class'])
print(f"\nPoids des classes calcul√©s (sur train): {class_weights}")

# 6. FOCAL LOSS
class FocalLoss(keras.losses.Loss):
    """Focal loss for one-hot (categorical) targets, to counter class imbalance.

    Per sample: ``sum_c alpha * y_true_c * (1 - p_c)**gamma * (-y_true_c * log p_c)``
    with ``p`` clipped to ``[1e-7, 1 - 1e-7]``; the batch mean is returned.

    Args:
        gamma: focusing exponent applied to ``(1 - p)``.
        alpha: scalar factor applied to every class term.
        name: name of the loss.
        **kwargs: forwarded to ``keras.losses.Loss`` (e.g. ``reduction``).
            Accepting them, together with ``get_config``, is what allows
            Keras to rebuild the loss from its saved config when a model file
            is reloaded with ``custom_objects={'FocalLoss': FocalLoss}``.
    """
    def __init__(self, gamma=2.0, alpha=0.25, name='focal_loss', **kwargs):
        super().__init__(name=name, **kwargs)
        self.gamma = gamma
        self.alpha = alpha

    def call(self, y_true, y_pred):
        # Clip to keep log() finite.
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        cross_entropy = -y_true * tf.math.log(y_pred)
        # y_true appears in both factors; for one-hot targets y_true**2 == y_true.
        weight = self.alpha * y_true * tf.pow((1 - y_pred), self.gamma)
        focal_loss = weight * cross_entropy
        focal_loss = tf.reduce_sum(focal_loss, axis=1)
        return tf.reduce_mean(focal_loss)

    def get_config(self):
        """Serialize gamma/alpha alongside the base-class config (name, reduction)."""
        config = super().get_config()
        config.update({
            'gamma': self.gamma,
            'alpha': self.alpha
        })
        return config

# 7. MobileNetV2 MODEL
def build_mobilenetv2_model(num_classes=3, input_shape=(224, 224, 3), fine_tune_at=100):
    """Build a MobileNetV2 transfer-learning classifier (uncompiled).

    The ImageNet-pretrained backbone is called with ``training=False``, so its
    BatchNorm layers run in inference mode. Only backbone layers from index
    ``fine_tune_at`` onward are trainable. A dense head
    (GAP -> Dropout -> Dense 256 -> BN -> Dropout -> Dense 128 -> softmax)
    sits on top.

    Args:
        num_classes: number of softmax outputs.
        input_shape: ``(height, width, channels)`` of the input images. It
            should match the generators' ``target_size`` with 3 channels.
        fine_tune_at: index of the first backbone layer left trainable;
            all earlier layers are frozen.

    Returns:
        keras.Model mapping images to class probabilities.
    """
    base_model = MobileNetV2(
        weights='imagenet',
        include_top=False,
        input_shape=input_shape
    )

    # Partial fine-tuning: freeze the first `fine_tune_at` layers.
    base_model.trainable = True
    for layer in base_model.layers[:fine_tune_at]:
        layer.trainable = False

    # Classification head
    inputs = keras.Input(shape=input_shape)
    x = base_model(inputs, training=False)
    x = layers.GlobalAveragePooling2D()(x)
    x = layers.Dropout(0.5)(x)
    x = layers.Dense(256, activation='relu')(x)
    x = layers.BatchNormalization()(x)
    x = layers.Dropout(0.3)(x)
    x = layers.Dense(128, activation='relu')(x)
    outputs = layers.Dense(num_classes, activation='softmax')(x)

    return keras.Model(inputs, outputs)

# Build the network with one output per class found in the data.
model = build_mobilenetv2_model(num_classes=len(class_names))

# 8. COMPILATION
# A constant initial learning rate is used instead of a LearningRateSchedule.
# The ReduceLROnPlateau callback used during training overwrites the
# optimizer's learning rate, and Keras refuses to do that for an optimizer
# built with a schedule: training would crash the first time the plateau
# triggers.
initial_learning_rate = 1e-4

optimizer = keras.optimizers.Adam(
    learning_rate=initial_learning_rate,
    beta_1=0.9,
    beta_2=0.999
)

# Focal loss (section 6)
loss_fn = FocalLoss(gamma=2.0, alpha=0.25)

model.compile(
    optimizer=optimizer,
    loss=loss_fn,
    metrics=['accuracy', keras.metrics.Precision(name='precision'),
             keras.metrics.Recall(name='recall')]
)

model.summary()

# 9. CALLBACKS
# EarlyStopping, ReduceLROnPlateau and ModelCheckpoint all monitor validation
# loss. CSVLogger records per-epoch metrics.
callbacks = []
callbacks.append(EarlyStopping(
    monitor='val_loss', patience=10, restore_best_weights=True, verbose=1,
))
callbacks.append(ReduceLROnPlateau(
    monitor='val_loss', factor=0.5, patience=5, min_lr=1e-7, verbose=1,
))
callbacks.append(ModelCheckpoint(
    '/content/drive/MyDrive/MobileNetV2/best_model.h5',
    monitor='val_loss', mode='min', save_best_only=True, verbose=1,
))
callbacks.append(CSVLogger(
    '/content/drive/MyDrive/MobileNetV2/training_log.csv',
    separator=',', append=False,
))

# 10. TRAINING
print("\n=== D√âBUT DE L'ENTRA√éNEMENT ===")

# Batches per epoch, rounded up so the last partial batch is included
# (ceil division written as negated floor division).
train_steps = -(-len(train_df) // BATCH_SIZE)
val_steps = -(-len(val_df) // BATCH_SIZE)

print(f"Train steps par epoch: {train_steps}")
print(f"Validation steps: {val_steps}")
print(f"Train images: {len(train_df)}")
print(f"Validation images: {len(val_df)}")

# Plain fit: the class weights from section 5 are NOT passed here.
history = model.fit(
    train_generator,
    epochs=EPOCHS,
    steps_per_epoch=train_steps,
    validation_data=val_generator,
    validation_steps=val_steps,
    callbacks=callbacks,
    verbose=1,
)

# 11. VISUALIZATION of the training history (2 x 3 grid)
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Four panels: one metric each, train vs validation.
for ax, train_key, val_key, title in (
    (axes[0, 0], 'loss', 'val_loss', 'Loss'),
    (axes[0, 1], 'accuracy', 'val_accuracy', 'Accuracy'),
    (axes[0, 2], 'precision', 'val_precision', 'Precision'),
    (axes[1, 0], 'recall', 'val_recall', 'Recall'),
):
    ax.plot(history.history[train_key], label='Train')
    ax.plot(history.history[val_key], label='Validation')
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.legend()
    ax.grid(True, alpha=0.3)

# Two comparison panels with fixed colours and a dashed validation curve.
for ax, train_key, val_key, train_label, val_label, title, ylabel in (
    (axes[1, 1], 'accuracy', 'val_accuracy', 'Train Accuracy', 'Val Accuracy',
     'Train vs Validation Accuracy', 'Accuracy'),
    (axes[1, 2], 'loss', 'val_loss', 'Train Loss', 'Val Loss',
     'Train vs Validation Loss', 'Loss'),
):
    ax.plot(history.history[train_key], label=train_label, color='blue')
    ax.plot(history.history[val_key], label=val_label, color='orange', linestyle='--')
    ax.set_title(title)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# 12. TEST-SET EVALUATION
print("\n=== √âVALUATION SUR LE TEST SET ===")

# Reload the best checkpoint written by ModelCheckpoint. compile=False: only
# inference is needed here, so the custom loss does not have to be rebuilt
# from its saved config. (The resulting model is uncompiled.)
best_model = keras.models.load_model('/content/drive/MyDrive/MobileNetV2/best_model.h5',
                                     custom_objects={'FocalLoss': FocalLoss},
                                     compile=False)

# Exactly one pass over the unshuffled test iterator. len(test_generator) is
# the number of batches it holds; Keras iterators cycle forever, so stepping
# further would wrap around and count images twice.
test_steps = len(test_generator)

test_generator.reset()
y_true = []    # true class indices
y_pred = []    # predicted class indices
y_scores = []  # per-image softmax vectors

for _ in range(test_steps):
    x_batch, y_batch = next(test_generator)
    batch_pred = best_model.predict(x_batch, verbose=0)

    y_true.extend(np.argmax(y_batch, axis=1))
    y_pred.extend(np.argmax(batch_pred, axis=1))
    y_scores.extend(batch_pred)

# Metrics
test_accuracy = np.mean(np.array(y_true) == np.array(y_pred))
test_f1 = f1_score(y_true, y_pred, average='macro')

print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test F1-Score (macro): {test_f1:.4f}")

# Classification report and confusion matrix.
# Model output index -> class name (flow_from_dataframe indexing of the train split).
class_indices = train_generator.class_indices
reverse_indices = {v: k for k, v in class_indices.items()}
label_ids = list(range(len(class_indices)))
label_names = [reverse_indices[i] for i in label_ids]

print("\nRapport de classification d√©taill√©:")
# labels= pins the rows to the model's index order, so target_names stay
# aligned even if a class never occurs in y_true or y_pred.
print(classification_report(
    y_true, y_pred,
    labels=label_ids,
    target_names=label_names,
    digits=4
))

# Always K x K in model-index order, matching the tick labels below.
cm = confusion_matrix(y_true, y_pred, labels=label_ids)
plt.figure(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=label_names,
            yticklabels=label_names)
plt.title('Matrice de confusion - Test Set')
plt.ylabel('Vraie classe')
plt.xlabel('Classe pr√©dite')
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/MobileNetV2/confusion_matrix_test.png')
plt.show()

# 13. SAVING
print("\n=== SAUVEGARDE ===")
best_model.save('/content/drive/MyDrive/MobileNetV2/final_age_classifier.h5')
print("‚úì Mod√®le final sauvegard√©")

# Class name <-> index mapping (read back by predict_age()).
class_indices_df = pd.DataFrame({
    'class_name': list(class_indices.keys()),
    'class_index': list(class_indices.values()),
})
class_indices_df.to_csv('/content/drive/MyDrive/MobileNetV2/class_indices.csv', index=False)
print("‚úì Indices sauvegard√©s")

# Class weights, with the class name added next to each index.
class_weights_df = pd.DataFrame({
    'class_index': list(class_weights.keys()),
    'weight': list(class_weights.values()),
})
class_weights_df['class_name'] = [reverse_indices[idx] for idx in class_weights_df['class_index']]
class_weights_df.to_csv('/content/drive/MyDrive/MobileNetV2/class_weights.csv', index=False)
print("‚úì Poids des classes sauvegard√©s")

# The three splits (filename, class); the TTA cell reloads test_data.csv.
train_df.to_csv('/content/drive/MyDrive/MobileNetV2/train_data.csv', index=False)
val_df.to_csv('/content/drive/MyDrive/MobileNetV2/val_data.csv', index=False)
test_df.to_csv('/content/drive/MyDrive/MobileNetV2/test_data.csv', index=False)
print("‚úì Datasets sauvegard√©s")

# 14. FINAL SUMMARY
print("\n" + "="*70)
print("R√âSUM√â COMPLET - DIVISION 70/15/15")
print("="*70)

# Sections 1 and 2: one fixed-width count per class.
for section_title, per_class in (
    ("\n1. DISTRIBUTION INITIALE:", class_counts),
    ("\n2. CIBLES D'√âQUILIBRAGE (avant division):", target_counts),
):
    print(section_title)
    for class_name in class_names:
        print(f"   {class_name}: {per_class[class_name]:6d} images")

# Section 3: split sizes and their share of all images.
n_total = len(train_df) + len(val_df) + len(test_df)
print("\n3. DATASETS FINAUX (apr√®s √©quilibrage + division 70/15/15):")
print(f"   Train:      {len(train_df):6d} images ({len(train_df)/n_total*100:.1f}%)")
print(f"   Validation: {len(val_df):6d} images ({len(val_df)/n_total*100:.1f}%)")
print(f"   Test:       {len(test_df):6d} images ({len(test_df)/n_total*100:.1f}%)")
print(f"   Total:      {n_total:6d} images")

# Section 4: per-class counts in each split, relative to the class target.
print("\n4. DISTRIBUTION PAR CLASSE DANS CHAQUE SET (70/15/15):")
train_counts = Counter(train_df['class'])
val_counts = Counter(val_df['class'])
test_counts = Counter(test_df['class'])
for split_header, split_counts in (
    ("\n   TRAIN:", train_counts),
    ("\n   VALIDATION:", val_counts),
    ("\n   TEST:", test_counts),
):
    print(split_header)
    for class_name in class_names:
        count = split_counts.get(class_name, 0)
        total_class = target_counts[class_name]
        print(f"     {class_name}: {count:6d} images ({count/total_class*100:.1f}% du total classe)")

print("\n5. POIDS DES CLASSES (calcul√©s sur train):")
for idx, weight in class_weights.items():
    print(f"   {reverse_indices[idx]}: {weight:.4f}")

# Section 6: metrics of the LAST epoch run (not necessarily the best epoch).
last_epoch = {metric: values[-1] for metric, values in history.history.items()}
print("\n6. R√âSULTATS D'ENTRA√éNEMENT:")
print(f"   Final Train Loss:      {last_epoch['loss']:.4f}")
print(f"   Final Validation Loss: {last_epoch['val_loss']:.4f}")
print(f"   Final Train Accuracy:  {last_epoch['accuracy']:.4f}")
print(f"   Final Val Accuracy:    {last_epoch['val_accuracy']:.4f}")
print(f"   Final Train Precision: {last_epoch['precision']:.4f}")
print(f"   Final Val Precision:   {last_epoch['val_precision']:.4f}")
print(f"   Final Train Recall:    {last_epoch['recall']:.4f}")
print(f"   Final Val Recall:      {last_epoch['val_recall']:.4f}")

print("\n7. R√âSULTATS SUR TEST SET:")
print(f"   Test Accuracy: {test_accuracy:.4f}")
print(f"   Test F1-Score (macro): {test_f1:.4f}")

print("\n8. FICHIERS CR√â√âS:")
print("   - final_age_classifier.h5 (mod√®le final)")
print("   - best_model.h5 (meilleur mod√®le)")
print("   - training_log.csv (historique)")
print("   - class_indices.csv")
print("   - class_weights.csv")
print("   - train_data.csv")
print("   - val_data.csv")
print("   - test_data.csv")
print("   - confusion_matrix_test.png")

print("\n" + "="*70)
print("ENTRA√éNEMENT ET √âVALUATION TERMIN√âS AVEC SUCC√àS!")
print("="*70)

# 15. PREDICTION FUNCTION
import functools


@functools.lru_cache(maxsize=1)
def _load_age_predictor():
    """Load the saved classifier and its index -> class-name mapping (cached).

    Caching avoids re-reading the .h5 model and the CSV on every
    predict_age() call. Call ``_load_age_predictor.cache_clear()`` after either
    file is rewritten.
    """
    loaded_model = keras.models.load_model(
        '/content/drive/MyDrive/MobileNetV2/final_age_classifier.h5',
        custom_objects={'FocalLoss': FocalLoss},
        compile=False,  # inference only: loss/optimizer are not needed
    )
    indices_df = pd.read_csv('/content/drive/MyDrive/MobileNetV2/class_indices.csv')
    class_mapping = dict(zip(indices_df['class_index'], indices_df['class_name']))
    return loaded_model, class_mapping


def predict_age(image_path):
    """Predict the age class of one image file.

    Args:
        image_path: path to an image readable by keras' ``load_img``.

    Returns:
        dict with keys ``'classe'`` (predicted class name), ``'confiance'``
        (its softmax probability) and ``'probabilites'`` (class name ->
        probability for every class).
    """
    loaded_model, class_mapping = _load_age_predictor()

    # Same preprocessing as the validation/test generators: resize, scale to [0, 1].
    img = keras.preprocessing.image.load_img(image_path, target_size=IMAGE_SIZE)
    img_array = keras.preprocessing.image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0) / 255.0

    predictions = loaded_model.predict(img_array, verbose=0)[0]
    predicted_class = int(np.argmax(predictions))
    confidence = predictions[predicted_class]

    return {
        'classe': class_mapping[predicted_class],
        'confiance': float(confidence),
        'probabilites': {class_mapping[i]: float(p) for i, p in enumerate(predictions)}
    }

# 16. SPOT-CHECK: predict a handful of test images
print("\n=== PR√âDICTIONS SUR QUELQUES IMAGES DE TEST ===")

# Up to two test images per class, drawn with a fixed random_state.
test_samples = []
for class_name in class_names:
    class_rows = test_df[test_df['class'] == class_name]
    test_samples.extend(
        class_rows.sample(min(2, len(class_rows)), random_state=SEED).to_dict('records')
    )

for sample_no, sample in enumerate(test_samples, start=1):
    print(f"\nTest {sample_no}:")
    print(f"  Image: {os.path.basename(sample['filename'])[:40]}...")
    print(f"  Vraie classe: {sample['class']}")

    result = predict_age(sample['filename'])
    print(f"  Pr√©diction: {result['classe']}")
    print(f"  Confiance: {result['confiance']:.1%}")

    if result['classe'] != sample['class']:
        print("  ‚úó Incorrect")
        print(f"  Probabilit√©s: {result['probabilites']}")
    else:
        print("  ‚úì Correct")

In [None]:
# ===============================================================================
# TEST TIME AUGMENTATION (TTA) - AMÉLIORATION IMMÉDIATE (gain espéré ~+2-5% d'accuracy, à vérifier)
# Pas de réentraînement nécessaire !
# ===============================================================================

import numpy as np
import pandas as pd
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from sklearn.metrics import classification_report, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
import json
from tqdm import tqdm
import os

for banner_line in (
    "="*80,
    "üöÄ TEST TIME AUGMENTATION (TTA)",
    "Am√©lioration imm√©diate: +2-5% accuracy sans r√©entra√Æner!",
    "="*80,
):
    print(banner_line)

# ==================== CONFIGURATION ====================
# NOTE: BASE_PATH is rebound here to the model/artifact directory. In the
# training cell it pointed to the image data directory (".../MobileNetV2/data"),
# so training-cell code re-run after this cell would read the wrong folder.
BASE_PATH = "/content/drive/MyDrive/MobileNetV2"
IMAGE_SIZE = (224, 224)
N_AUGMENTATIONS = 10  # views averaged per image (more = slower, smoother)

# ==================== FOCAL LOSS ====================
class FocalLoss(keras.losses.Loss):
    def __init__(self, gamma=2.0, alpha=0.25, name='focal_loss', **kwargs):
        super().__init__(name=name, **kwargs)
        self.gamma = gamma
        self.alpha = alpha

    def call(self, y_true, y_pred):
        y_pred = tf.clip_by_value(y_pred, 1e-7, 1.0 - 1e-7)
        cross_entropy = -y_true * tf.math.log(y_pred)
        weight = self.alpha * y_true * tf.pow((1 - y_pred), self.gamma)
        focal_loss = weight * cross_entropy
        focal_loss = tf.reduce_sum(focal_loss, axis=1)
        return tf.reduce_mean(focal_loss)

    def get_config(self):
        config = super().get_config()
        config.update({
            'gamma': self.gamma,
            'alpha': self.alpha
        })
        return config

# ==================== LOADING ====================
print("\nüìÇ CHARGEMENT DES DONN√âES ET MOD√àLE...")

# Test split saved by the training cell (columns: filename, class).
test_df = pd.read_csv(f'{BASE_PATH}/test_data.csv')
print(f"‚úì Test set: {len(test_df)} images")

# Class distribution of the test split
class_counts = test_df['class'].value_counts()
print(f"\nDistribution du test set:")
for class_name, count in class_counts.items():
    print(f"  {class_name}: {count} images ({count/len(test_df)*100:.1f}%)")

# Load the best checkpoint. Fall back to the final model only if that fails,
# and report why, so metrics are never silently computed on another model.
print(f"\nüîÑ Chargement du mod√®le...")
try:
    model = keras.models.load_model(
        f'{BASE_PATH}/best_model.h5',
        custom_objects={'FocalLoss': FocalLoss}
    )
except Exception as best_model_error:
    print(f"⚠ best_model.h5 non chargé ({best_model_error!r}) - repli sur final_age_classifier.h5")
    try:
        model = keras.models.load_model(
            f'{BASE_PATH}/final_age_classifier.h5',
            custom_objects={'FocalLoss': FocalLoss}
        )
    except Exception as e:
        print(f"‚ùå Erreur: {e}")
        raise
    else:
        print("‚úì Mod√®le charg√©: final_age_classifier.h5")
else:
    print("‚úì Mod√®le charg√©: best_model.h5")

# Class <-> index mapping, in sorted class-name order.
# NOTE(review): derived from the classes present in test_df; it only matches
# the model's output order if every training class occurs in the test split.
class_names = sorted(test_df['class'].unique())
class_to_idx = {c: i for i, c in enumerate(class_names)}
idx_to_class = {i: c for c, i in class_to_idx.items()}

print(f"‚úì Classes d√©tect√©es: {class_names}")

# ==================== BASELINE EVALUATION (no TTA) ====================
print("\n" + "="*80)
print("üìä √âVALUATION BASELINE (SANS TTA)")
print("="*80)

print("\n√âvaluation du mod√®le original...")

# Plain rescaling only, unshuffled so batches follow test_df order.
baseline_datagen = ImageDataGenerator(rescale=1./255)
baseline_generator = baseline_datagen.flow_from_dataframe(
    dataframe=test_df,
    x_col='filename',
    y_col='class',
    target_size=IMAGE_SIZE,
    batch_size=32,
    class_mode='categorical',
    shuffle=False,
)

# One pass over all batches, collecting true and predicted class indices.
baseline_generator.reset()
y_true_baseline, y_pred_baseline = [], []
for _ in range(len(baseline_generator)):
    x_batch, y_batch = next(baseline_generator)
    batch_pred = model.predict(x_batch, verbose=0)
    y_true_baseline.extend(np.argmax(y_batch, axis=1))
    y_pred_baseline.extend(np.argmax(batch_pred, axis=1))

# Baseline metrics
baseline_accuracy = np.mean(np.array(y_true_baseline) == np.array(y_pred_baseline))
baseline_f1 = f1_score(y_true_baseline, y_pred_baseline, average='macro')

print("\nüìà R√âSULTATS BASELINE:")
print(f"  Accuracy: {baseline_accuracy:.4f} ({baseline_accuracy*100:.2f}%)")
print(f"  F1-Score (macro): {baseline_f1:.4f}")

print("\nüìã Rapport par classe (baseline):")
print(classification_report(
    y_true_baseline,
    y_pred_baseline,
    target_names=class_names,
    digits=4,
))

# ==================== TTA FUNCTIONS ====================
print("\n" + "="*80)
print("üîß CONFIGURATION TTA")
print("="*80)

def create_tta_generator():
    """Return an ImageDataGenerator that produces random TTA views.

    Pixels are rescaled by 1/255, matching the baseline generator. Each view
    gets a random rotation (up to 20 degrees), horizontal/vertical shift and
    zoom (up to 15%), a possible horizontal flip and a brightness factor in
    [0.85, 1.15]. Border pixels are filled with 'nearest'.
    """
    augmentation = {
        'rotation_range': 20,
        'width_shift_range': 0.15,
        'height_shift_range': 0.15,
        'horizontal_flip': True,
        'zoom_range': 0.15,
        'brightness_range': [0.85, 1.15],
        'fill_mode': 'nearest',
    }
    return ImageDataGenerator(rescale=1./255, **augmentation)

def predict_single_image_with_tta(model, image_path, n_augmentations=10):
    """
    Predict a single image with test-time augmentation (TTA).

    The original (un-augmented) image plus ``n_augmentations - 1`` randomly
    augmented copies from ``create_tta_generator()`` are scored together in
    a single ``model.predict`` call. Their output vectors are averaged.

    Args:
        model: Keras model
        image_path: path to the image file
        n_augmentations: total number of views averaged, including the
            original image (any value <= 1 means "original image only")

    Returns:
        prediction: index of the predicted class (argmax of the average)
        confidence: averaged score of the predicted class
        all_probs: averaged scores for every class (1-D array)
    """
    # Load at the model input size -> float array of shape (1, H, W, 3)
    img = keras.preprocessing.image.load_img(image_path, target_size=IMAGE_SIZE)
    img_array = keras.preprocessing.image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)

    # 1. Original view, rescaled like the baseline generator (1/255)
    views = [img_array / 255.0]

    # 2. Augmented views. The TTA generator applies its own 1/255 rescale,
    #    and draws an independent random transform for every sample of the
    #    repeated batch, so one `next()` yields all augmented copies.
    n_extra = n_augmentations - 1
    if n_extra > 0:
        tta_gen = create_tta_generator()
        augmented_gen = tta_gen.flow(
            np.repeat(img_array, n_extra, axis=0),
            batch_size=n_extra,
            shuffle=False,
        )
        views.append(next(augmented_gen))

    # 3. One batched forward pass instead of one predict() call per view,
    #    then average over the views.
    batch = np.concatenate(views, axis=0)
    avg_prediction = model.predict(batch, verbose=0).mean(axis=0)

    # 4. Results
    predicted_class = np.argmax(avg_prediction)
    confidence = avg_prediction[predicted_class]

    return predicted_class, confidence, avg_prediction

# Summary of the TTA settings used for the evaluation below.
print("\n".join([
    "\n‚úì Configuration TTA:",
    f"  Nombre d'augmentations par image: {N_AUGMENTATIONS}",
    "  Types d'augmentation: rotation, shift, flip, zoom, brightness",
    "  Strat√©gie: moyenne des pr√©dictions",
]))

# ==================== TTA EVALUATION ====================
print("", "=" * 80, "üéØ √âVALUATION AVEC TTA", "=" * 80, sep="\n")

print(f"\nD√©but de l'√©valuation avec TTA sur {len(test_df)} images...")
print("‚è≥ Cela peut prendre quelques minutes...")

# Per-image results, collected in test_df row order.
y_true_tta, y_pred_tta, confidences_tta, all_predictions_tta = [], [], [], []

# Walk (filename, class) pairs with a progress bar.
path_label_pairs = zip(test_df['filename'], test_df['class'])
for img_path, class_label in tqdm(path_label_pairs, total=len(test_df), desc="TTA Progress"):
    true_class = class_to_idx[class_label]

    pred_class, confidence, all_probs = predict_single_image_with_tta(
        model, img_path, n_augmentations=N_AUGMENTATIONS
    )

    y_true_tta.append(true_class)
    y_pred_tta.append(pred_class)
    confidences_tta.append(confidence)
    all_predictions_tta.append(all_probs)

# ==================== RESULTS ANALYSIS ====================
print("\n" + "="*80)
print("üìä R√âSULTATS FINAUX")
print("="*80)

# Explicit label list: per-class arrays then always have exactly one entry
# per class, indexed like `class_names`, even if some class never appears
# in y_true / y_pred (otherwise sklearn drops it and indices shift).
all_labels = list(range(len(class_names)))

# TTA metrics
tta_accuracy = np.mean(np.array(y_true_tta) == np.array(y_pred_tta))
tta_f1 = f1_score(y_true_tta, y_pred_tta, average='macro')
tta_f1_weighted = f1_score(y_true_tta, y_pred_tta, average='weighted')
tta_f1_per_class = f1_score(y_true_tta, y_pred_tta, labels=all_labels, average=None)
# Same per-class (one-vs-rest) F1 for the baseline, computed in one call.
baseline_f1_per_class = f1_score(y_true_baseline, y_pred_baseline, labels=all_labels, average=None)

# Comparison, expressed in percentage points
improvement = (tta_accuracy - baseline_accuracy) * 100
improvement_f1 = (tta_f1 - baseline_f1) * 100

print(f"\n{'='*80}")
print("üéâ COMPARAISON BASELINE vs TTA")
print(f"{'='*80}")

print(f"\nüìà ACCURACY:")
print(f"  Baseline:  {baseline_accuracy:.4f} ({baseline_accuracy*100:.2f}%)")
print(f"  Avec TTA:  {tta_accuracy:.4f} ({tta_accuracy*100:.2f}%)")
print(f"  {'üöÄ GAIN:':12} {improvement:+.2f}%")

print(f"\nüìà F1-SCORE (MACRO):")
print(f"  Baseline:  {baseline_f1:.4f}")
print(f"  Avec TTA:  {tta_f1:.4f}")
# improvement_f1 is already x100 (points), so show it like the accuracy gain.
print(f"  {'üöÄ GAIN:':12} {improvement_f1:+.2f}%")

print(f"\nüìà F1-SCORE (WEIGHTED):")
print(f"  Avec TTA:  {tta_f1_weighted:.4f}")

print(f"\nüéØ F1-SCORE PAR CLASSE (TTA):")
for i, class_name in enumerate(class_names):
    baseline_f1_class = baseline_f1_per_class[i]
    tta_f1_class = tta_f1_per_class[i]
    improvement_class = (tta_f1_class - baseline_f1_class) * 100

    print(f"  {class_name:8} | Baseline: {baseline_f1_class:.4f} | TTA: {tta_f1_class:.4f} | Gain: {improvement_class:+.2f}%")

# Detailed report
print(f"\nüìã RAPPORT DE CLASSIFICATION D√âTAILL√â (TTA):")
print(classification_report(
    y_true_tta,
    y_pred_tta,
    labels=all_labels,
    target_names=class_names,
    digits=4
))

# ==================== CONFUSION MATRICES ====================
print("\n" + "="*80)
print("üìä MATRICES DE CONFUSION")
print("="*80)

# Fixed label order so both matrices are n_classes x n_classes and match the
# tick labels, even if a class never occurs in y_true / y_pred.
cm_labels = list(range(len(class_names)))

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Baseline
cm_baseline = confusion_matrix(y_true_baseline, y_pred_baseline, labels=cm_labels)
sns.heatmap(cm_baseline, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names, yticklabels=class_names,
            ax=axes[0])
axes[0].set_title(f'Matrice de Confusion - Baseline\nAccuracy: {baseline_accuracy*100:.2f}%',
                  fontsize=14, fontweight='bold')
axes[0].set_ylabel('Vraie classe')
axes[0].set_xlabel('Classe pr√©dite')

# TTA
cm_tta = confusion_matrix(y_true_tta, y_pred_tta, labels=cm_labels)
sns.heatmap(cm_tta, annot=True, fmt='d', cmap='Greens',
            xticklabels=class_names, yticklabels=class_names,
            ax=axes[1])
# `:+` lets the sign come from the value itself; a hard-coded '+' rendered
# negative gains as "+-x.xx%".
axes[1].set_title(f'Matrice de Confusion - Avec TTA\nAccuracy: {tta_accuracy*100:.2f}% ({improvement:+.2f}%)',
                  fontsize=14, fontweight='bold')
axes[1].set_ylabel('Vraie classe')
axes[1].set_xlabel('Classe pr√©dite')

plt.tight_layout()
plt.savefig(f'{BASE_PATH}/confusion_matrix_tta_comparison.png', dpi=300, bbox_inches='tight')
print(f"‚úì Matrices sauvegard√©es: confusion_matrix_tta_comparison.png")
plt.show()

# ==================== CONFIDENCE ANALYSIS ====================
print("\n" + "="*80)
print("üîç ANALYSE DES CONFIANCES")
print("="*80)

avg_confidence = np.mean(confidences_tta)
print(f"\nConfiance moyenne des pr√©dictions: {avg_confidence:.4f} ({avg_confidence*100:.2f}%)")

# Group the TTA confidences by *predicted* class once (vectorised); the
# groups feed both the text summary and the boxplot below.
confidences_arr = np.asarray(confidences_tta)
y_pred_tta_arr = np.asarray(y_pred_tta)
confidence_by_class = [confidences_arr[y_pred_tta_arr == i] for i in range(len(class_names))]

# Mean confidence per class (classes never predicted are skipped)
print(f"\nConfiance moyenne par classe:")
for class_name, class_confidences in zip(class_names, confidence_by_class):
    if class_confidences.size > 0:
        class_confidence = np.mean(class_confidences)
        print(f"  {class_name}: {class_confidence:.4f} ({class_confidence*100:.2f}%)")

# Confidence distribution
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.hist(confidences_tta, bins=50, color='skyblue', edgecolor='black', alpha=0.7)
plt.axvline(avg_confidence, color='red', linestyle='--', linewidth=2, label=f'Moyenne: {avg_confidence:.2f}')
plt.title('Distribution des Confiances (TTA)', fontsize=14, fontweight='bold')
plt.xlabel('Confiance')
plt.ylabel('Nombre de pr√©dictions')
plt.legend()
plt.grid(True, alpha=0.3)

# Confidence per class
plt.subplot(1, 2, 2)
plt.boxplot(confidence_by_class, patch_artist=True,
            boxprops=dict(facecolor='lightgreen', alpha=0.7))
# Boxes sit at x = 1..n by default. Setting the tick labels here works on
# every matplotlib version (boxplot's `labels=` was renamed `tick_labels=`
# in 3.9 and the old name is deprecated).
plt.xticks(range(1, len(class_names) + 1), class_names)
plt.title('Confiances par Classe (TTA)', fontsize=14, fontweight='bold')
plt.ylabel('Confiance')
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f'{BASE_PATH}/confidence_analysis_tta.png', dpi=300, bbox_inches='tight')
print(f"‚úì Analyse des confiances sauvegard√©e: confidence_analysis_tta.png")
plt.show()

# ==================== SAVE RESULTS ====================
print("\n" + "="*80)
print("üíæ SAUVEGARDE DES R√âSULTATS")
print("="*80)

import json  # needed by json.dump below; local so this cell runs on its own

# Detailed per-image predictions (CSV), in test_df row order
results_df = test_df.copy()
results_df['true_class_idx'] = y_true_tta
results_df['predicted_class'] = [idx_to_class[i] for i in y_pred_tta]
results_df['confidence'] = confidences_tta
results_df['correct'] = np.array(y_true_tta) == np.array(y_pred_tta)

# One averaged-probability column per class
for i, class_name in enumerate(class_names):
    results_df[f'prob_{class_name}'] = [probs[i] for probs in all_predictions_tta]

results_df.to_csv(f'{BASE_PATH}/test_predictions_tta.csv', index=False)
print(f"‚úì Pr√©dictions d√©taill√©es: test_predictions_tta.csv")

# Per-class metrics, computed ONCE. `target_names` makes the dict keyed by
# class name; without it the keys are '0', '1', ... and looking up
# class_names[i] raises KeyError. `labels` keeps every class present.
tta_report = classification_report(
    y_true_tta,
    y_pred_tta,
    labels=list(range(len(class_names))),
    target_names=class_names,
    output_dict=True
)

# JSON summary
summary = {
    'baseline': {
        'accuracy': float(baseline_accuracy),
        'f1_macro': float(baseline_f1)
    },
    'tta': {
        'accuracy': float(tta_accuracy),
        'f1_macro': float(tta_f1),
        'f1_weighted': float(tta_f1_weighted),
        'n_augmentations': N_AUGMENTATIONS,
        'avg_confidence': float(avg_confidence)
    },
    'improvement': {
        # Both gains are in percentage points (x100)
        'accuracy_gain': float(improvement),
        'f1_gain': float(improvement_f1)
    },
    'per_class': {
        class_name: {
            'f1_score': float(tta_report[class_name]['f1-score']),
            'precision': float(tta_report[class_name]['precision']),
            'recall': float(tta_report[class_name]['recall'])
        }
        for class_name in class_names
    }
}

with open(f'{BASE_PATH}/tta_results_summary.json', 'w', encoding='utf-8') as f:
    json.dump(summary, f, indent=4)
print(f"‚úì R√©sum√© JSON: tta_results_summary.json")

# ==================== ERROR ANALYSIS ====================
print("\n" + "="*80)
print("‚ùå ANALYSE DES ERREURS")
print("="*80)

# Label / prediction arrays for both runs (row order follows test_df).
truth_base, guess_base = np.asarray(y_true_baseline), np.asarray(y_pred_baseline)
truth_tta, guess_tta = np.asarray(y_true_tta), np.asarray(y_pred_tta)

# Row indices misclassified with TTA
errors_tta = np.flatnonzero(truth_tta != guess_tta)
n_total_tta = len(y_true_tta)
print(f"\nNombre total d'erreurs avec TTA: {len(errors_tta)}/{n_total_tta} ({len(errors_tta)/n_total_tta*100:.2f}%)")

# Errors fixed vs. introduced by TTA, compared row by row with the baseline
errors_baseline_set = set(np.flatnonzero(truth_base != guess_base))
errors_tta_set = set(errors_tta)
corrected_errors = errors_baseline_set.difference(errors_tta_set)
new_errors = errors_tta_set.difference(errors_baseline_set)

print(f"\nüîß Erreurs CORRIG√âES par TTA: {len(corrected_errors)}")
print(f"‚ö†Ô∏è  Nouvelles erreurs avec TTA: {len(new_errors)}")
print(f"‚úÖ Bilan net: {len(corrected_errors) - len(new_errors)} corrections")

# Per-class error rate (share of a true class that was misclassified)
print(f"\nüìä TAUX D'ERREUR PAR CLASSE:")
print(f"{'Classe':<10} {'Baseline':<12} {'TTA':<12} {'Am√©lioration':<15}")
print("-" * 50)
for i, class_name in enumerate(class_names):
    in_class_base = truth_base == i
    in_class_tta = truth_tta == i
    n_base = int(np.count_nonzero(in_class_base))
    n_tta = int(np.count_nonzero(in_class_tta))

    # Inside class i the true label is i, so an error is any prediction != i.
    wrong_base = int(np.count_nonzero(guess_base[in_class_base] != i))
    wrong_tta = int(np.count_nonzero(guess_tta[in_class_tta] != i))

    error_rate_baseline = wrong_base / n_base * 100 if n_base > 0 else 0
    error_rate_tta = wrong_tta / n_tta * 100 if n_tta > 0 else 0
    improvement_error = error_rate_baseline - error_rate_tta

    print(f"{class_name:<10} {error_rate_baseline:>6.2f}%      {error_rate_tta:>6.2f}%      {improvement_error:>+6.2f}%")

# ==================== RECOMMENDATIONS ====================
print("\n" + "="*80)
print("üí° RECOMMANDATIONS")
print("="*80)

# Verdict on the accuracy gain (in percentage points)
if improvement >= 2.0:
    verdict_lines = (
        f"\n‚úÖ EXCELLENT! Gain de {improvement:.2f}% avec TTA",
        "   Votre mod√®le b√©n√©ficie bien de l'augmentation de donn√©es.",
    )
elif improvement >= 1.0:
    verdict_lines = (
        f"\nüëç BON! Gain de {improvement:.2f}% avec TTA",
        "   Am√©lioration notable mais potentiel pour plus.",
    )
else:
    verdict_lines = (
        f"\n‚ö†Ô∏è  GAIN LIMIT√â: {improvement:.2f}% avec TTA",
        "   Le mod√®le ne b√©n√©ficie pas beaucoup de TTA.",
    )
for line in verdict_lines:
    print(line)

print(f"\nüéØ PROCHAINES √âTAPES:")

# Suggested next steps depend on the absolute TTA accuracy
if tta_accuracy < 0.80:
    next_steps = (
        "   1. ‚ö° URGENT: Essayer l'ensemble de mod√®les (gain +3-5%)",
        f"   2. üîß Affiner avec plus d'augmentations (N={N_AUGMENTATIONS*2})",
        "   3. üéì Fine-tuning cibl√© sur classe 21-50",
    )
elif tta_accuracy < 0.85:
    next_steps = (
        "   1. üéØ Combiner TTA + Ensemble pour viser 85%+",
        "   2. üîç Analyser les erreurs restantes",
        "   3. üìä Optionnel: Ajuster les seuils de d√©cision",
    )
else:
    next_steps = (
        "   1. üéâ Excellent r√©sultat! Pr√™t pour production",
        "   2. üìà Optionnel: Ensemble pour optimisation finale",
        "   3. üöÄ D√©ploiement recommand√©",
    )
for step in next_steps:
    print(step)

# Files written by the previous sections
print(f"\nüíæ FICHIERS CR√â√âS:")
for produced_file in (
    "confusion_matrix_tta_comparison.png",
    "confidence_analysis_tta.png",
    "test_predictions_tta.csv",
    "tta_results_summary.json",
):
    print(f"   - {produced_file}")

# ==================== FINAL SUMMARY ====================
print("\n" + "="*80)
print("üéâ R√âSUM√â FINAL")
print("="*80)

# Fixed-width boxed summary. The format specs (5.2f, >4, >+4) pad each value so
# the right-hand border stays aligned; improvement is in percentage points and
# the error counts come from the error-analysis section.
print(f"""
‚ïî‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïó
‚ïë                    R√âSULTATS TTA FINAUX                        ‚ïë
‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£
‚ïë  Accuracy Baseline:     {baseline_accuracy*100:5.2f}%                               ‚ïë
‚ïë  Accuracy avec TTA:     {tta_accuracy*100:5.2f}%                               ‚ïë
‚ïë  üöÄ GAIN TOTAL:          {improvement:+5.2f}%                               ‚ïë
‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£
‚ïë  F1-Score (macro):      {tta_f1:.4f}                              ‚ïë
‚ïë  F1-Score (weighted):   {tta_f1_weighted:.4f}                              ‚ïë
‚ïë  Confiance moyenne:     {avg_confidence*100:5.2f}%                               ‚ïë
‚ï†‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ï£
‚ïë  Erreurs corrig√©es:     {len(corrected_errors):>4} images                         ‚ïë
‚ïë  Nouvelles erreurs:     {len(new_errors):>4} images                         ‚ïë
‚ïë  Bilan net:             {len(corrected_errors) - len(new_errors):>+4} corrections                    ‚ïë
‚ïö‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïê‚ïù
""")

print("‚úÖ √âVALUATION TTA TERMIN√âE AVEC SUCC√àS!\n")

# ==================== PRODUCTION PREDICTION FUNCTION ====================
# Section banner: rule, title, rule (one line each).
print("=" * 80, "üîß FONCTION DE PR√âDICTION POUR PRODUCTION", "=" * 80, sep="\n")

def predict_with_tta_production(image_path, model, n_augmentations=N_AUGMENTATIONS):
    """
    Production entry point: TTA prediction for one image file.

    Delegates to ``predict_single_image_with_tta`` and converts its output
    into a plain, JSON-friendly dict.

    Returns:
        dict with keys
          'classe'       -> predicted class name (via ``idx_to_class``)
          'confiance'    -> averaged score of that class, as a Python float
          'probabilites' -> {class name: averaged score} for every class

    Usage:
        result = predict_with_tta_production('path/to/image.jpg', model)
        print(f"Class: {result['classe']}")
        print(f"Confidence: {result['confiance']:.2%}")
    """
    best_idx, best_score, scores = predict_single_image_with_tta(
        model, image_path, n_augmentations
    )

    label = idx_to_class[best_idx]
    confidence_value = float(best_score)
    per_class = {name: float(scores[i]) for i, name in enumerate(class_names)}

    return {
        'classe': label,
        'confiance': confidence_value,
        'probabilites': per_class,
    }

# Usage snippet shown to the user (printed verbatim, not executed).
usage_example = """
# Exemple d'utilisation:
result = predict_with_tta_production('chemin/vers/image.jpg', model)
print(f"Pr√©diction: {result['classe']} (confiance: {result['confiance']:.2%})")
"""
print("\n‚úì Fonction de pr√©diction pr√™te:")
print(usage_example)

# Closing banner: blank line, rule, message, rule (one line each).
print("", "=" * 80, "üéä TERMIN√â! Votre mod√®le est maintenant plus performant avec TTA!", "=" * 80, sep="\n")