# Training on FER+

## Setup

### Installation of packages

In [None]:
# Install the notebook's dependencies quietly (-q) into the running kernel.
# NOTE(review): tensorflow-addons, tensorflow-hub and tensorflow-datasets are not
# imported by any cell in this notebook; consider dropping them from this line.
%pip install -q tensorflow tensorflow-addons tensorflow-hub tensorflow-datasets matplotlib seaborn scikit-learn pandas numpy

### Imports

In [None]:
import pandas as pd
import numpy as np
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.applications import MobileNetV2
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Dense, Dropout, GlobalAveragePooling2D, Input
import tensorflow as tf
from tensorflow.keras.callbacks import ReduceLROnPlateau, ModelCheckpoint
from tensorflow.keras import mixed_precision
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import seaborn as sns

### Mixed precision and device check

In [None]:
# Keras 'mixed_float16' policy: layers compute in float16 while their
# variables are kept in float32.
mixed_precision.set_global_policy('mixed_float16')

# Report whether TensorFlow can see a GPU. This cell only reports; it does
# not pin any computation to a particular device.
physical_devices = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(physical_devices))
print("Using GPU" if physical_devices else "Using CPU")

### Plotting functions

In [None]:
def _plot_cm_heatmap(matrix, fmt, title, tick_labels):
    """Draw one confusion-matrix heatmap in its own figure."""
    plt.figure(figsize=(10, 8))
    sns.heatmap(matrix, annot=True, fmt=fmt, cmap='Blues',
                xticklabels=tick_labels,
                yticklabels=tick_labels)
    plt.title(title)
    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.show()


def plot_confusion_matrix(model, dataset, num_classes, class_names):
    """Plot normalized and raw confusion matrices of ``model`` on ``dataset``.

    Args:
        model: Keras model whose ``predict`` returns per-class scores.
        dataset: Iterable of ``(x, y)`` batches with one-hot ``y``.
        num_classes: Number of classes. The matrix is always
            ``num_classes x num_classes``, even if some class never occurs
            in the labels or the predictions.
        class_names: Tick labels; ``range(num_classes)`` is used if falsy.

    Returns:
        ``(cm, cm_norm)``: the raw count matrix and its row-normalized
        version. Rows for classes with no true samples are ``nan`` in
        ``cm_norm``.
    """
    y_pred = []
    y_true = []

    for x, y in dataset:
        # verbose=0: otherwise every batch prints its own progress bar.
        pred = model.predict(x, verbose=0)
        y_pred.extend(np.argmax(pred, axis=1))
        y_true.extend(np.argmax(np.asarray(y), axis=1))

    # Passing explicit labels keeps the matrix square and aligned with
    # class_names even when a class is absent from this dataset.
    labels = list(range(num_classes))
    cm = confusion_matrix(y_true, y_pred, labels=labels)

    # Row-normalize. Empty rows become nan explicitly, instead of through a
    # 0/0 RuntimeWarning.
    row_sums = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(cm.astype('float'), row_sums,
                        out=np.full(cm.shape, np.nan),
                        where=row_sums != 0)

    tick_labels = class_names if class_names else range(num_classes)
    _plot_cm_heatmap(cm_norm, '.2f', 'Normalized Confusion Matrix', tick_labels)
    _plot_cm_heatmap(cm, 'd', 'Confusion Matrix (Raw Counts)', tick_labels)

    return cm, cm_norm

def plot_training_history(history):
    """Plot train vs. validation accuracy and loss curves side by side.

    Args:
        history: Object returned by Keras ``model.fit``. Its ``history``
            dict must hold the 'accuracy', 'val_accuracy', 'loss' and
            'val_loss' series.
    """
    curves = history.history
    # (train key, validation key, subplot title, y-axis label)
    panels = (
        ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy'),
        ('loss', 'val_loss', 'Model loss', 'Loss'),
    )

    plt.figure(figsize=(12, 4))
    for position, (train_key, val_key, title, y_label) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key])
        plt.plot(curves[val_key])
        plt.title(title)
        plt.ylabel(y_label)
        plt.xlabel('Epoch')
        plt.legend(['Train', 'Validation'], loc='upper left')

    plt.show()

def visualize_predictions(model, dataset, class_names, num_classes, num_samples=5):
    """Plot up to ``num_samples`` examples with their true and predicted labels.

    Each example gets one row of a ``num_samples x 3`` subplot grid:
      1. the image (channel 0, grayscale colormap),
      2. a bar chart of the model's per-class scores,
      3. a text panel with the top score and a correct/incorrect mark.

    Args:
        model: Keras model mapping an image batch to per-class scores.
        dataset: Iterable of ``(images, labels)`` batches with one-hot labels.
        class_names: Display names indexed by class id.
        num_classes: Length of each score vector (number of bars).
        num_samples: Number of examples to show.
    """
    plt.figure(figsize=(15, num_samples*3))

    count = 0
    for images, labels in dataset:
        take = min(len(images), num_samples - count)
        if take <= 0:
            break

        # Predict the needed slice of the batch in a single call. A separate
        # ``predict`` per image pays the call overhead every time and, by
        # default, prints a progress bar each time.
        preds = np.asarray(model.predict(images[:take], verbose=0))

        for i in range(take):
            image = np.asarray(images[i])
            true_label = int(np.argmax(np.asarray(labels[i])))
            scores = preds[i]
            pred_label = int(np.argmax(scores))
            first_cell = count * 3

            # Image panel. Images from preprocess_image have three identical
            # channels (grayscale_to_rgb), so channel 0 is enough to display.
            plt.subplot(num_samples, 3, first_cell + 1)
            plt.imshow(image[:, :, 0], cmap='gray')
            plt.title(f"True: {class_names[true_label]}")
            plt.axis('off')

            # Score panel.
            plt.subplot(num_samples, 3, first_cell + 2)
            plt.bar(range(num_classes), scores)
            plt.xticks(range(num_classes), class_names, rotation=45)
            plt.title(f"Prediction: {class_names[pred_label]}")

            # Summary panel.
            plt.subplot(num_samples, 3, first_cell + 3)
            plt.text(0.5, 0.5, f"Confidence: {scores[pred_label]:.4f}\nCorrect: {'✓' if true_label == pred_label else '✗'}",
                     ha='center', fontsize=12)
            plt.axis('off')

            count += 1

        if count >= num_samples:
            break

    plt.tight_layout()
    plt.show()


## Data preparation

### Load data

In [None]:
data = pd.read_csv('../dataset/fer2013++.csv')

# The 'Usage' column carries the official split:
#   Training -> fitting, PublicTest -> validation, PrivateTest -> final report.
train_df = data.loc[data['Usage'] == 'Training']
val_df = data.loc[data['Usage'] == 'PublicTest']
test_df = data.loc[data['Usage'] == 'PrivateTest']

for split_label, split_frame in (("Train", train_df), ("Validation", val_df), ("Test", test_df)):
    print(f"{split_label} samples: {len(split_frame)}")

### Process data

In [None]:
num_classes = 8
# NOTE(review): this order must match the integer codes in the CSV's
# 'emotion' column; confirm against the dataset's label map.
class_names = ['Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral', 'Contempt']

# Inverse-frequency class weights. They are computed from the TRAINING split
# only, so validation/test label counts cannot influence training.
class_counts = train_df['emotion'].value_counts().to_dict()
total_samples = len(train_df)
# Build one entry per class id 0..num_classes-1, so the mapping is complete
# even if a class is missing from the training split. Such a class gets a
# neutral weight of 1.0 instead of a missing key or a division by zero.
class_weights = {
    class_id: (total_samples / (num_classes * class_counts[class_id])
               if class_counts.get(class_id, 0) > 0 else 1.0)
    for class_id in range(num_classes)
}
print("Class weights (inverse frequency):", class_weights)


def _split_to_arrays(df):
    """Parse one split's rows into image arrays and one-hot targets.

    Each 'pixels' string holds 48*48 space-separated grayscale values.
    They are parsed as float32, which uses half the memory of numpy's float64
    default. They stay on the raw 0-255 scale because
    ``mobilenet_v2.preprocess_input`` rescales them later in the input
    pipeline.

    Returns:
        ``(images, labels)``: ``images`` has shape ``(N, 48, 48, 1)`` and
        ``labels`` has shape ``(N, num_classes)``.
    """
    images = np.array(
        [np.fromstring(pixel_sequence, dtype=np.float32, sep=' ')
         for pixel_sequence in df['pixels']],
        dtype=np.float32,
    ).reshape((-1, 48, 48, 1))
    labels = to_categorical(df['emotion'], num_classes=num_classes)
    return images, labels


trainX, trainY = _split_to_arrays(train_df)
valX, valY = _split_to_arrays(val_df)
testX, testY = _split_to_arrays(test_df)

print(f"Train shape: {trainX.shape}, {trainY.shape}")
print(f"Validation shape: {valX.shape}, {valY.shape}")
print(f"Test shape: {testX.shape}, {testY.shape}")

In [None]:
def augment(x, y):
    """Randomly flip one training image and jitter its brightness and contrast.

    ``x`` is expected on the raw 0-255 pixel scale; the brightness delta of 25
    is in those units. The jitter can push values below 0 or above 255, so
    the result is clipped back to [0, 255]. This keeps the later
    ``preprocess_input`` mapping inside [-1, 1]. ``y`` is returned unchanged.
    """
    x = tf.image.random_flip_left_right(x)
    x = tf.image.random_brightness(x, 25)
    x = tf.image.random_contrast(x, 0.8, 1.2)
    x = tf.clip_by_value(x, 0.0, 255.0)
    return x, y

def preprocess_image(x, y):
    """Turn a single-channel image into a 224x224x3 MobileNetV2 input.

    Steps:
      1. resize to 224x224,
      2. replicate the single channel into RGB,
      3. apply ``mobilenet_v2.preprocess_input``, which maps 0-255 pixels
         to [-1, 1].

    ``y`` is returned untouched.
    """
    resized = tf.image.resize(x, (224, 224))
    rgb = tf.image.grayscale_to_rgb(resized)
    scaled = tf.keras.applications.mobilenet_v2.preprocess_input(rgb)
    return scaled, y

batch_size = 128


def _make_dataset(images, labels, *, training):
    """Build a batched, prefetched tf.data pipeline over ``(images, labels)``.

    Training pipelines are shuffled over the whole split and augmented before
    preprocessing. Evaluation pipelines are only preprocessed.
    """
    pipeline = tf.data.Dataset.from_tensor_slices((images, labels))
    if training:
        pipeline = pipeline.shuffle(len(images))
        pipeline = pipeline.map(augment, num_parallel_calls=tf.data.AUTOTUNE)
    pipeline = pipeline.map(preprocess_image, num_parallel_calls=tf.data.AUTOTUNE)
    return pipeline.batch(batch_size=batch_size).prefetch(tf.data.AUTOTUNE)


train_dataset = _make_dataset(trainX, trainY, training=True)
val_dataset = _make_dataset(valX, valY, training=False)
test_dataset = _make_dataset(testX, testY, training=False)

In [None]:
# 224x224 matches the resize in preprocess_image. There are 3 channels
# because grayscale_to_rgb replicates the single input channel.
# include_top=False drops the ImageNet classifier so a new head can be attached.
input_shape = (224, 224, 3)
base_model = MobileNetV2(
    input_shape=input_shape,
    include_top=False,
    weights='imagenet',
)

## Model definition

In [None]:
# Freeze the whole pretrained backbone; only the new head trains in stage 1.
for layer in base_model.layers:
    layer.trainable = False

x = base_model.output
x = GlobalAveragePooling2D()(x)
x = Dropout(0.3)(x)
x = Dense(192, activation='relu')(x)
x = Dropout(0.3)(x)
# With the global 'mixed_float16' policy every layer computes in float16 by
# default. The final softmax is forced to float32 so the output probabilities
# and the cross-entropy computed from them are numerically stable, as the
# Keras mixed-precision guide recommends.
outputs = Dense(num_classes, activation='softmax', dtype='float32')(x)

## Training

### Train the classification head (frozen base model)

In [None]:
# Stage 1: train only the new head; the backbone layers are frozen above.
model = Model(inputs=base_model.input, outputs=outputs)
model.compile(
    optimizer=tf.keras.optimizers.Adam(1e-3),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Multiply the learning rate by 0.2 after 5 epochs without val_loss
# improvement, never going below 1e-5.
reduce_lr = ReduceLROnPlateau(
    monitor='val_loss',
    factor=0.2,
    patience=5,
    min_lr=1e-5,
)

history = model.fit(
    train_dataset,
    epochs=15,
    validation_data=val_dataset,
    callbacks=[reduce_lr],
    class_weight=class_weights,
)

print("Training history for base model:")
plot_training_history(history)

### Fine-tune

In [None]:
# Stage 2: unfreeze the top 20 backbone layers for fine-tuning.
# BatchNormalization layers stay frozen. A BN layer with trainable=False runs
# in inference mode, so its ImageNet moving mean/variance is not overwritten
# by statistics from the fine-tuning batches.
for layer in base_model.layers[-20:]:
    if not isinstance(layer, tf.keras.layers.BatchNormalization):
        layer.trainable = True

# Re-compile so the new trainable flags take effect. A 10x lower learning
# rate than stage 1 keeps the pretrained weights from moving too far.
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss='categorical_crossentropy', metrics=['accuracy'])

reduce_lr = ReduceLROnPlateau(monitor='val_loss', factor=0.2, patience=3, min_lr=0.000001)

# Save the epoch with the highest validation accuracy.
checkpoint_finetune = ModelCheckpoint(
    'best_finetune_06.keras',
    monitor='val_accuracy',
    save_best_only=True,
    mode='max',
    verbose=1
)

# The checkpoint must be in the callbacks list, otherwise
# best_finetune_06.keras is never written.
history_finetune = model.fit(train_dataset,
                             validation_data=val_dataset,
                             epochs=30,
                             class_weight=class_weights,
                             callbacks=[reduce_lr, checkpoint_finetune])

# Weights from the final epoch; the best epoch is in best_finetune_06.keras.
model.save('finetune_06.keras')

print("Training history for finetune model:")
plot_training_history(history_finetune)

## Evaluation

In [None]:
# Final evaluation on the held-out test split (PrivateTest).
print("Evaluating the model on test dataset:")

# Reload the fine-tuned model that was saved to disk after fine-tuning.
finetune = tf.keras.models.load_model('finetune_06.keras')
test_results = finetune.evaluate(test_dataset)
test_loss_finetune, test_accuracy_finetune = test_results
print(
    f"Fine-tuned model test loss: {test_loss_finetune}, "
    f"Fine-tuned model test accuracy: {test_accuracy_finetune}"
)

In [None]:
# Confusion matrix for the fine-tuned model on the test set.
# NOTE: the unpacking below requires plot_confusion_matrix to return
# (cm, cm_norm).
print("Confusion Matrix for Fine-tuned Model (Test Set):")
cm, cm_norm = plot_confusion_matrix(
    finetune,
    test_dataset,
    num_classes=num_classes,
    class_names=class_names,
)

### Visualize predictions

In [None]:
print("\nVisualizing predictions for the model:")
visualize_predictions(model, val_dataset, class_names, num_classes=num_classes, num_samples=10)

# Per-class accuracy is the diagonal of the row-normalized confusion matrix.
# This is each class's recall: the fraction of its true samples predicted
# correctly. A class with no true samples yields nan.
per_class_acc = np.diag(cm_norm)
print("\nPer-class accuracy:")
for class_name, acc in zip(class_names, per_class_acc):
    print(f"{class_name}: {acc:.4f}")