In [1]:
# Imports
import pandas as pd
import numpy as np
import tensorflow as tf
from tensorflow.keras import layers, models
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.preprocessing.image import ImageDataGenerator
import os
import glob
from sklearn.metrics import confusion_matrix, classification_report
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
# Run configuration
IMG_SIZE = (224, 224)
BATCH_SIZE = 32
EPOCHS = 30
CHECKPOINT_DIR = './checkpoints'
os.makedirs(CHECKPOINT_DIR, exist_ok=True)

# Source race label -> coarser target class. 'East Asian' and 'Indian' both
# collapse into 'Asian', so five source labels become four target classes.
race_mapping = {
    'Black': 'African',
    'East Asian': 'Asian',
    'Indian': 'Asian',
    'Middle Eastern': 'Middle Eastern',
    'White': 'Western'
}
# Rows whose race is not one of the mapping's keys are dropped.
desired_classes = list(race_mapping)

# Read each label table, keep only the desired races, then remap them.
train_df = (
    pd.read_csv('./images/FairFace/train_labels.csv')
    .loc[lambda frame: frame['race'].isin(desired_classes)]
    .assign(race=lambda frame: frame['race'].map(race_mapping))
)
val_df = (
    pd.read_csv('./images/FairFace/val_labels.csv')
    .loc[lambda frame: frame['race'].isin(desired_classes)]
    .assign(race=lambda frame: frame['race'].map(race_mapping))
)

# Integer-encode the remapped classes. The encoder is fitted on the training
# split only and reused for validation; codes and file paths are stored as
# strings.
le = LabelEncoder()
le.fit(train_df['race'])
train_df = train_df.assign(
    label=le.transform(train_df['race']).astype(str),
    file=train_df['file'].astype(str),
)
val_df = val_df.assign(
    label=le.transform(val_df['race']).astype(str),
    file=val_df['file'].astype(str),
)

# Validation images are only rescaled; training images are additionally
# augmented with random rotation, zoom and horizontal flips.
val_datagen = ImageDataGenerator(rescale=1./255)
train_datagen = ImageDataGenerator(
    horizontal_flip=True,
    zoom_range=0.15,
    rotation_range=20,
    rescale=1./255
)

In [3]:
# Create generators
# Everything except the dataframe, the datagen and the shuffle flag is
# identical for both splits, so it is declared once.
_shared_flow_args = dict(
    directory='./images/FairFace/',
    x_col='file',
    y_col='label',
    target_size=IMG_SIZE,
    batch_size=BATCH_SIZE,
    class_mode='sparse',
)

# Training batches are shuffled.
train_generator = train_datagen.flow_from_dataframe(
    dataframe=train_df, shuffle=True, **_shared_flow_args
)

# Validation batches keep dataframe order (shuffle=False).
val_generator = val_datagen.flow_from_dataframe(
    dataframe=val_df, shuffle=False, **_shared_flow_args
)


Found 62582 validated image filenames belonging to 4 classes.
Found 7916 validated image filenames belonging to 4 classes.


In [4]:
# Build the model
def build_model(input_shape, num_classes):
    """Return an uncompiled CNN classifier.

    Three Conv2D(3x3, relu) + MaxPooling2D stages with 32, 64 and 128 filters,
    followed by Flatten, Dense(128, relu) and a softmax output layer.

    Args:
        input_shape: (height, width, channels) of the input images.
        num_classes: number of units in the softmax output layer.

    Returns:
        A `Sequential` model; the caller is responsible for compiling it.
    """
    return models.Sequential([
        layers.Input(shape=input_shape),
        layers.Conv2D(32, (3, 3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(64, (3, 3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Conv2D(128, (3, 3), activation='relu'),
        layers.MaxPooling2D(),
        layers.Flatten(),
        layers.Dense(128, activation='relu'),
        layers.Dense(num_classes, activation='softmax')
    ])


# The input size is derived from IMG_SIZE (the same constant the generators
# use as target_size) so the two cannot drift apart; 3 = colour channels.
model = build_model(
    input_shape=(*IMG_SIZE, 3),
    num_classes=len(le.classes_),
)

# Compile the model
# The sparse loss matches the integer class indices the generators yield
# (class_mode='sparse').
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

In [5]:
# Custom callback to save weights every `save_freq` epochs
class CustomWeightsCheckpoint(tf.keras.callbacks.Callback):
    """Save the model's weights every `save_freq` completed epochs.

    Files are written into `save_path` as `weights_epoch_NN.weights.h5`,
    where NN is the 1-based epoch number. Keras 3 `save_weights` only accepts
    filenames ending in `.weights.h5`; a plain `.h5` name raises ValueError.

    Args:
        save_freq: positive number of epochs between saves.
        save_path: directory that receives the weight files.

    Raises:
        ValueError: if `save_freq` is smaller than 1.
    """

    def __init__(self, save_freq, save_path):
        super().__init__()
        # Fail at construction instead of with a ZeroDivisionError at the end
        # of the first epoch.
        if save_freq < 1:
            raise ValueError(f'save_freq must be >= 1, got {save_freq!r}')
        self.save_freq = save_freq
        self.save_path = save_path

    def on_epoch_end(self, epoch, logs=None):
        """Write a weights file when the number of finished epochs is a multiple of `save_freq`."""
        # `epoch` is 0-based; file names use the 1-based count.
        completed_epochs = epoch + 1
        if completed_epochs % self.save_freq != 0:
            return
        filename = os.path.join(
            self.save_path,
            f'weights_epoch_{completed_epochs:02d}.weights.h5',
        )
        self.model.save_weights(filename)
        print(f'\n✅ Saved weights at: {filename}')

checkpoint_cb = CustomWeightsCheckpoint(
    save_freq=10,
    save_path=CHECKPOINT_DIR
)


def _checkpoint_epoch(path):
    """Return the epoch number embedded in a checkpoint file name, or None.

    The number is the text between the last underscore and the first
    following dot, e.g. 'weights_epoch_10.h5' -> 10.
    """
    token = os.path.basename(path).split('_')[-1].split('.')[0]
    return int(token) if token.isdigit() else None


# Load latest weights if available
# "Latest" means the highest epoch number in the file name, not the newest
# modification time: mtimes change when checkpoints are copied or restored.
# The glob matches both 'weights_epoch_NN.h5' and 'weights_epoch_NN.weights.h5';
# names without a numeric epoch are ignored.
weights_files = sorted(
    (
        candidate
        for candidate in glob.glob(os.path.join(CHECKPOINT_DIR, 'weights_epoch_*.h5'))
        if _checkpoint_epoch(candidate) is not None
    ),
    key=_checkpoint_epoch,
)
initial_epoch = 0
if weights_files:
    latest_weights = weights_files[-1]
    print(f'🔁 Resuming from: {latest_weights}')
    # NOTE(review): only layer weights are restored here; optimizer state is
    # presumably re-initialised on resume - confirm this is acceptable.
    model.load_weights(latest_weights)
    initial_epoch = _checkpoint_epoch(latest_weights)


In [6]:
# Train the model
# Training starts at `initial_epoch` and runs up to `EPOCHS`. The returned
# History object is kept so the per-epoch metrics in `history.history` can be
# inspected or plotted afterwards.
history = model.fit(
    train_generator,
    validation_data=val_generator,
    epochs=EPOCHS,
    initial_epoch=initial_epoch,
    callbacks=[checkpoint_cb]
)

  self._warn_if_super_not_called()


Epoch 1/30
[1m1956/1956[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2588s[0m 1s/step - accuracy: 0.4550 - loss: 1.2293 - val_accuracy: 0.5606 - val_loss: 1.0361
Epoch 2/30
[1m1956/1956[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m2776s[0m 1s/step - accuracy: 0.5395 - loss: 1.0610 - val_accuracy: 0.5855 - val_loss: 0.9814
Epoch 3/30
[1m1867/1956[0m [32m━━━━━━━━━━━━━━━━━━━[0m[37m━[0m [1m1:35[0m 1s/step - accuracy: 0.5747 - loss: 0.9917

KeyboardInterrupt: 

In [None]:
# Predict the classes
val_preds = model.predict(val_generator)
predicted_labels = np.argmax(val_preds, axis=1)
# The validation generator was created with shuffle=False, so `.classes` is
# in the same order as the predictions.
true_labels = val_generator.classes

# Human-readable class names, ordered by the generator's class index.
# The generator's class keys are the stringified LabelEncoder codes
# ('0', '1', ...), so each is mapped back through `le` instead of being used
# as a tick label directly.
encoded_classes = sorted(val_generator.class_indices, key=val_generator.class_indices.get)
class_names = [str(name) for name in le.inverse_transform([int(code) for code in encoded_classes])]

# Confusion Matrix
# `labels` pins the matrix to one row/column per class, so it always lines up
# with `class_names` even if a class never occurs in this split.
cm = confusion_matrix(true_labels, predicted_labels, labels=list(range(len(class_names))))

# Plot Confusion Matrix
fig, ax = plt.subplots(figsize=(8, 6))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=class_names,
            yticklabels=class_names,
            ax=ax)
ax.set_xlabel('Predicted Labels')
ax.set_ylabel('True Labels')
ax.set_title('Confusion Matrix')
plt.show()

In [None]:
# Classification Report
# The generator's class keys are the stringified LabelEncoder codes, so they
# are mapped back to the race names (ordered by class index) for the report.
report_codes = sorted(val_generator.class_indices, key=val_generator.class_indices.get)
report_names = [str(name) for name in le.inverse_transform([int(code) for code in report_codes])]
# Passing `labels` explicitly keeps `target_names` aligned even if a class is
# absent from both the true and the predicted labels.
print(classification_report(
    true_labels,
    predicted_labels,
    labels=list(range(len(report_names))),
    target_names=report_names,
))

In [None]:
# Write the final full model to disk under a fixed file name
final_model_path = 'ethnicity_classifier_model_final.h5'
model.save(final_model_path)