In [2]:
from datasets import load_dataset
from transformers import AutoFeatureExtractor  , AutoModelForImageClassification, TrainingArguments, Trainer, MobileNetV2ForImageClassification
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from PIL import Image
import torch
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [39]:
#Load  dataset
dataset = load_dataset("Piro17/dataset-affecthqnet-fer2013")
#dataset = load_dataset("AutumnQiu/fer2013")

sample_train = 40000
sample_test = int(sample_train / 8)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dataset['test'] = dataset['train'].shuffle(seed=96).select(range(sample_test))
dataset['train'] = dataset['train'].shuffle(seed=23).select(range(sample_train))

test_valid_split = dataset['test'].train_test_split(test_size=0.65, seed=45)
dataset['test'] = test_valid_split['train']
dataset['validation'] = test_valid_split['test']


#Import Model from HuggingFace
model = "google/mobilenet_v2_1.0_224"
#model = "microsoft/resnet-26"
#model = "microsoft/resnet-50"
feature_extractor = AutoFeatureExtractor.from_pretrained(model)
model = MobileNetV2ForImageClassification.from_pretrained(model)

label2id = {
    "angry": 0,
    "disgust": 1,
    "fear": 2,
    "happy": 3,
    "sad": 4,
    "surprise": 5,
    "neutral": 6,
}

id2label = {v: k for k, v in label2id.items()}

#Fix up the id2label and label2id configurations (Ensure it is consistent with the dataset)
model.config.label2id = label2id
model.config.id2label = id2label

feature_extractor.label2id = label2id
feature_extractor.id2label = id2label

model.to(device)

MobileNetV2ForImageClassification(
  (mobilenet_v2): MobileNetV2Model(
    (conv_stem): MobileNetV2Stem(
      (first_conv): MobileNetV2ConvLayer(
        (convolution): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (conv_3x3): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), groups=32, bias=False)
        (normalization): BatchNorm2d(32, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
        (activation): ReLU6()
      )
      (reduce_1x1): MobileNetV2ConvLayer(
        (convolution): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1), bias=False)
        (normalization): BatchNorm2d(16, eps=0.001, momentum=0.997, affine=True, track_running_stats=True)
      )
    )
    (layer): ModuleList(
      (0): MobileNetV2InvertedResidual(
        (expand_1x1): MobileNe

In [40]:
print("Model label2id:", model.config.label2id)
print("Model id2label:", model.config.id2label)

Model label2id: {'angry': 0, 'disgust': 1, 'fear': 2, 'happy': 3, 'sad': 4, 'surprise': 5, 'neutral': 6}
Model id2label: {0: 'angry', 1: 'disgust', 2: 'fear', 3: 'happy', 4: 'sad', 5: 'surprise', 6: 'neutral'}


In [41]:
# Define the transform function
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([img.convert("RGB") for img in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['label']
    return inputs

# Apply the transform to the datasets
dataset['train'] = dataset['train'].map(transform, batched=True)
dataset['validation'] = dataset['validation'].map(transform, batched=True)

# Remove the 'image' column as it's now transformed
dataset['train'] = dataset['train'].remove_columns(['image'])
dataset['validation'] = dataset['validation'].remove_columns(['image'])

# Set the format for PyTorch
dataset.set_format(type='torch')


In [42]:
print(dataset['validation'])

Dataset({
    features: ['label', 'pixel_values', 'labels'],
    num_rows: 3250
})


In [43]:
print(torch.cuda.is_available())

True


In [44]:
from transformers import Trainer

In [45]:
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)

    # Convert numeric labels to string labels
    predicted_labels = [id2label[p] for p in predictions]
    true_labels = [id2label[l] for l in labels]

    accuracy = accuracy_score(true_labels, predicted_labels)
    f1 = f1_score(true_labels, predicted_labels, average='weighted')  # Use 'weighted' for multi-class
    precision = precision_score(true_labels, predicted_labels, average='weighted')
    recall = recall_score(true_labels, predicted_labels, average='weighted')
    return {
        'accuracy': accuracy,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [46]:
from transformers import EarlyStoppingCallback

In [49]:
#Training Args
warmup = int(sample_train*0.1)
training_args = TrainingArguments(
    output_dir='./huggingface_fer_model/results',
    num_train_epochs=25,
    per_device_train_batch_size=27,
    per_device_eval_batch_size=27,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir='./huggingface_fer_model/logs',
    logging_steps=25,
    warmup_steps= warmup,
    report_to=[],
    learning_rate=6e-5,
    weight_decay=0.075,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,
    fp16 = True,
    optim= 'adamw_torch_fused',
)

#Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'],
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)]
)

In [50]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.1695,0.48592,0.861846,0.861842,0.863826,0.861846
2,0.1526,0.486559,0.872,0.872106,0.874935,0.872
3,0.1083,0.68552,0.850462,0.852145,0.862026,0.850462
4,0.0875,0.645629,0.884308,0.884355,0.886129,0.884308


TrainOutput(global_step=5928, training_loss=0.15162198367188012, metrics={'train_runtime': 1461.9971, 'train_samples_per_second': 683.996, 'train_steps_per_second': 25.342, 'total_flos': 5.0666323083264e+17, 'train_loss': 0.15162198367188012, 'epoch': 4.0})

In [51]:
eval_results = trainer.evaluate()
print(f"Validation Loss: {eval_results['eval_loss']:.4f}")
print(f"Validation Accuracy: {eval_results['eval_accuracy']:.4f}")
print(f"Validation Precision: {eval_results['eval_precision']:.4f}")
print(f"Validation Recall: {eval_results['eval_recall']:.4f}")
print(f"Validation F1 Score: {eval_results['eval_f1']:.4f}")

Validation Loss: 0.4859
Validation Accuracy: 0.8618
Validation Precision: 0.8638
Validation Recall: 0.8618
Validation F1 Score: 0.8618


In [55]:
output = './mobilenet_v2_affectnethq-fer2013_model_fixed_labels'
model.save_pretrained(output)
feature_extractor.save_pretrained(output)

['./mobilenet_v2_affectnethq-fer2013_model_fixed_labels/preprocessor_config.json']