# Modèle d'ethnicité UTKFace avec Data Augmentation

**Data augmentation appliquée :**
- Flip horizontal
- Rotation ±15°
- Luminosité 0.8-1.2
- Zoom ±10%
- Décalage ±10%

**NON appliquées (irréalistes pour un visage, ou fausseraient la couleur de peau) :**
- Flip vertical
- Rotation 90°
- Modification de teinte

**Dataset :** jangedoo/utkface-new

## 1. Installation et chargement des données

In [None]:
# Installation de kagglehub si nécessaire
!pip install -q kagglehub

In [None]:
import os
import numpy as np
from PIL import Image
import kagglehub

# Fetch the dataset through kagglehub and keep the local path it returns
path = kagglehub.dataset_download("jangedoo/utkface-new")
print(f"Path : {path}")

# Print the top-level layout of the dataset: directories with their entry
# count, everything else as a plain file
print("\nContenu du dossier :")
for entry_name in os.listdir(path):
    entry_path = os.path.join(path, entry_name)
    if not os.path.isdir(entry_path):
        print(f"  [FILE] {entry_name}")
        continue
    entry_count = len(os.listdir(entry_path))
    print(f"  [DIR] {entry_name} ({entry_count} fichiers)")

In [None]:
# Locate the folder holding the UTKFace images among the known dataset layouts.
# The empty string stands for the dataset root itself.
possible_folders = ["UTKFace", "utkface_aligned_cropped", "crop_part1", ""]
image_folder = None
image_files = []

for folder in possible_folders:
    test_path = os.path.join(path, folder) if folder else path
    # isdir (not exists): a plain file with a candidate name must not reach listdir
    if not os.path.isdir(test_path):
        continue
    # Sorted so that any later seeded shuffle or split of this list is
    # reproducible regardless of the filesystem's listing order.
    jpg_files = sorted(f for f in os.listdir(test_path) if f.endswith(".jpg"))
    if jpg_files:
        image_folder = test_path
        # Reuse the listing instead of scanning the directory a second time
        image_files = jpg_files
        print(f"Dossier d'images trouvé : {image_folder}")
        break

# Fail loudly when none of the candidate folders contains .jpg files
if image_folder is None:
    raise FileNotFoundError("Impossible de trouver le dossier contenant les images UTKFace")

print(f"Nombre de fichiers .jpg trouvés : {len(image_files)}")

In [None]:
# Load every image together with the labels encoded in its file name.
# The parsing below expects names of the form "<age>_<gender>_<race>_...".
images = []
labels = []
skipped = 0

for file in image_files:
    parts = file.split("_")

    # Age and gender are mandatory: skip files whose name does not carry them.
    try:
        age = int(parts[0])
        gender = int(parts[1])
    except (IndexError, ValueError):
        skipped += 1
        continue

    # A missing, unparsable or out-of-range race falls back to class 4.
    try:
        race = int(parts[2])
    except (IndexError, ValueError):
        race = 4
    if not 0 <= race <= 4:
        race = 4

    # The context manager closes the file handle once the pixels are copied.
    # Unreadable or corrupt images are skipped (best effort), but only for
    # I/O and decoding errors, so Ctrl-C and programming errors still surface.
    try:
        with Image.open(os.path.join(image_folder, file)) as img:
            pixels = np.array(img.convert("RGB").resize((128, 128)))
    except (OSError, ValueError):
        skipped += 1
        continue

    images.append(pixels)
    labels.append([age, gender, race])

images = np.array(images)
labels = np.array(labels)
print(f"Images chargées : {len(images)}")
if skipped:
    print(f"Fichiers ignorés : {skipped}")

## 2. Imports et préparation des données

In [None]:
import tensorflow as tf
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, confusion_matrix
from tensorflow.keras.utils import to_categorical
from tensorflow.keras import Sequential
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Flatten, Dense, Dropout, BatchNormalization
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import matplotlib.pyplot as plt
import seaborn as sns

# Report the TensorFlow version and the GPUs TensorFlow can see
print(f"TensorFlow version : {tf.__version__}")
gpu_devices = tf.config.list_physical_devices('GPU')
print(f"GPU disponible : {gpu_devices}")

In [None]:
# Features are the images; the target is the third label column (ethnicity)
X = images
y_ethnicity = labels[:, 2]

# Hold out 20% of the samples for testing, with a fixed seed
X_train, X_test, y_train, y_test = train_test_split(
    X, y_ethnicity, test_size=0.2, random_state=42
)

# Scale pixel values from [0, 255] to [0, 1] as float32
X_train, X_test = (
    pixels.astype('float32') / 255.0 for pixels in (X_train, X_test)
)

# One-hot encode the 5 ethnicity classes
y_train_cat, y_test_cat = (
    to_categorical(targets, num_classes=5) for targets in (y_train, y_test)
)

for split_name, split_data in (("X_train", X_train), ("X_test", X_test)):
    print(f"{split_name} : {split_data.shape}")

## 3. Data Augmentation avec ImageDataGenerator

In [None]:
def _random_brightness(img, low=0.8, high=1.2):
    """Scale one image by a random brightness factor, staying inside [0, 1].

    Args:
        img: Single image array; expected to hold floats already scaled to [0, 1].
        low: Smallest brightness factor that can be drawn.
        high: Largest brightness factor that can be drawn.

    Returns:
        Array of the same shape, multiplied by one factor drawn uniformly
        from [low, high] and clipped to [0, 1].
    """
    return np.clip(img * np.random.uniform(low, high), 0.0, 1.0)


# Augmentation applied to the training subset only.
# brightness_range is deliberately NOT used: Keras implements it through an
# 8-bit PIL image, which expects pixel values in [0, 255]. The images fed to
# these generators are floats already scaled to [0, 1], so that path would
# corrupt them and return values on a different scale than the validation
# images. The same 0.8-1.2 brightness factor is applied by
# preprocessing_function instead, directly on the [0, 1] floats.
train_datagen = ImageDataGenerator(
    horizontal_flip=True,                       # Horizontal flip
    rotation_range=15,                          # Rotation within +/-15 degrees
    zoom_range=0.1,                             # Zoom within +/-10%
    width_shift_range=0.1,                      # Horizontal shift within +/-10%
    height_shift_range=0.1,                     # Vertical shift within +/-10%
    fill_mode='nearest',                        # How pixels exposed by a transform are filled
    preprocessing_function=_random_brightness,  # Brightness factor 0.8-1.2
    validation_split=0.2                        # Reserve 20% of the data for validation
)

# No augmentation for validation (the data is already rescaled); only the
# same split fraction, so both generators agree on which samples are held out
val_datagen = ImageDataGenerator(
    validation_split=0.2
)

In [None]:
# Batch size shared by both generators
BATCH_SIZE = 64

# Training subset: augmented on the fly and reshuffled
train_generator = train_datagen.flow(
    x=X_train,
    y=y_train_cat,
    batch_size=BATCH_SIZE,
    shuffle=True,
    subset='training',
)

# Validation subset: original images, fixed order
val_generator = val_datagen.flow(
    x=X_train,
    y=y_train_cat,
    batch_size=BATCH_SIZE,
    shuffle=False,
    subset='validation',
)

n_train_samples = train_generator.n
n_val_samples = val_generator.n
print(f"Échantillons d'entraînement : {n_train_samples}")
print(f"Échantillons de validation : {n_val_samples}")

## 4. Création du modèle CNN

In [None]:
def _conv_block(filters, **conv_kwargs):
    """Return the layers of one Conv2D -> BatchNorm -> MaxPool -> Dropout stage.

    Extra keyword arguments are forwarded to the Conv2D layer.
    """
    return [
        Conv2D(filters, (3, 3), activation='relu', **conv_kwargs),
        BatchNormalization(),
        MaxPooling2D((2, 2)),
        Dropout(0.25),
    ]


# Three convolutional stages with 32, 64 and 128 filters; only the first
# one declares the 128x128 RGB input shape
feature_layers = (
    _conv_block(32, input_shape=(128, 128, 3))
    + _conv_block(64)
    + _conv_block(128)
)

# Dense head ending in a 5-way softmax
classifier_layers = [
    Flatten(),
    Dense(256, activation='relu'),
    Dropout(0.5),
    Dense(5, activation='softmax'),
]

model = Sequential(feature_layers + classifier_layers)

model.compile(
    loss='categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy'],
)

model.summary()

## 5. Entraînement avec Data Augmentation à la volée

In [None]:
# Stop when the validation loss has not improved for 5 epochs and roll the
# model back to the weights of its best epoch
early_stop = EarlyStopping(
    monitor='val_loss',
    patience=5,
    restore_best_weights=True
)

# len(generator) is the number of batches in one pass over the data,
# including the final partial batch. Floor division (n // BATCH_SIZE) would
# drop that batch, leaving some validation samples never evaluated.
steps_per_epoch = len(train_generator)
validation_steps = len(val_generator)

history = model.fit(
    train_generator,
    epochs=30,
    steps_per_epoch=steps_per_epoch,
    validation_data=val_generator,
    validation_steps=validation_steps,
    callbacks=[early_stop]
)

## 6. Visualisation de l'entraînement

In [None]:
# One panel per tracked metric, each showing the training and validation curves
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

panels = (('loss', 'Loss'), ('accuracy', 'Accuracy'))
for ax, (metric_key, metric_title) in zip(axes, panels):
    ax.plot(history.history[metric_key], label='Train')
    ax.plot(history.history[f'val_{metric_key}'], label='Validation')
    ax.set_title(metric_title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(metric_title)
    ax.legend()

plt.tight_layout()
plt.show()

## 7. Évaluation du modèle

In [None]:
# Predicted class = index of the highest softmax output for each test image
y_pred = model.predict(X_test).argmax(axis=1)

loss, accuracy = model.evaluate(X_test, y_test_cat)
print(f"\nAccuracy sur le test set : {accuracy*100:.2f}%")

# Display name of each class, indexed by its integer label (0-4)
eth_labels = ['Blanc', 'Noir', 'Asiatique', 'Indien', 'Autre']
print("\nRapport de classification :")
# labels= pins the report to all 5 classes: without it, a class absent from
# both y_test and y_pred makes the number of classes differ from
# len(target_names) and the call raises. zero_division=0 reports 0 for
# classes that have no predicted samples.
print(classification_report(
    y_test,
    y_pred,
    labels=list(range(len(eth_labels))),
    target_names=eth_labels,
    zero_division=0,
))

## 8. Matrice de confusion

In [None]:
# labels= forces one row and column per class. Without it the matrix only
# covers the classes actually present in y_test / y_pred, so the tick labels
# below would be misaligned if a class were missing.
cm = confusion_matrix(y_test, y_pred, labels=list(range(len(eth_labels))))

# Heatmap with true classes on the rows and predicted classes on the columns
plt.figure(figsize=(8, 6))
sns.heatmap(
    cm,
    annot=True,
    fmt='d',
    cmap='Blues',
    xticklabels=eth_labels,
    yticklabels=eth_labels
)
plt.title('Matrice de confusion - Ethnicité (avec Data Augmentation)')
plt.xlabel('Prédit')
plt.ylabel('Réel')
plt.tight_layout()
plt.show()

## 9. Sauvegarde du modèle

In [None]:
# Persist the trained model in the native Keras format
model.save('ethnicity_model_augmented.keras')
print("Modèle sauvegardé : ethnicity_model_augmented.keras")

# Trigger a browser download when running in Google Colab. Elsewhere the
# google.colab package is not installed, so the download is skipped and the
# file simply stays on disk instead of the cell crashing.
try:
    from google.colab import files
except ImportError:
    print("Google Colab non détecté : téléchargement ignoré.")
else:
    files.download('ethnicity_model_augmented.keras')

## 10. Visualisation des augmentations (Optionnel)

In [None]:
def visualize_augmentations(image, datagen, n_examples=6):
    """Show an image next to randomly augmented versions of it.

    Args:
        image: Single image array accepted by ``imshow``. Augmented copies are
            clipped to [0, 1], so float pixels in that range are assumed.
        datagen: Object whose ``flow(batch, batch_size=1)`` returns an
            iterator yielding batches of augmented images.
        n_examples: Total number of panels, the original image included.
            Values below 1 are treated as 1.
    """
    # Size the grid from n_examples (at most 3 columns) so that any count
    # fits; the default of 6 gives the usual 2x3 layout.
    n_examples = max(1, n_examples)
    n_cols = min(3, n_examples)
    n_rows = -(-n_examples // n_cols)  # ceiling division
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(10, 3.5 * n_rows), squeeze=False
    )
    axes = axes.flatten()

    # Original image in the first panel
    axes[0].imshow(image)
    axes[0].set_title('Original')
    axes[0].axis('off')

    # The generator works on batches: add a leading batch axis of size 1
    img_array = np.expand_dims(image, axis=0)
    aug_iter = datagen.flow(img_array, batch_size=1)

    for i in range(1, n_examples):
        # Clip to [0, 1], the range imshow expects for float images;
        # augmentation (e.g. brightness) may push values outside it.
        aug_img = np.clip(next(aug_iter)[0], 0, 1)
        axes[i].imshow(aug_img)
        axes[i].set_title(f'Augmentation {i}')
        axes[i].axis('off')

    # Hide the panels left over when n_examples does not fill the grid
    for unused_ax in axes[n_examples:]:
        unused_ax.axis('off')

    plt.suptitle('Exemples de Data Augmentation')
    plt.tight_layout()
    plt.show()

# Preview the training-time augmentations on the first training image
print("Visualisation des augmentations :")
visualize_augmentations(image=X_train[0], datagen=train_datagen)