In [42]:
from datasets import load_dataset
from transformers import AutoFeatureExtractor  , AutoModelForImageClassification, TrainingArguments, Trainer
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from PIL import Image
import torch

In [66]:
#Loading FER-2013 Dataset:  https://huggingface.co/datasets/3una/Fer2013

dataset = load_dataset("AutumnQiu/fer2013")

sample_train = 22000
sample_test = 3000
seed = 27

dataset['train'] = dataset['train'].shuffle(seed=seed).select(range(sample_train))
dataset['test'] = dataset['test'].shuffle(seed=seed).select(range(sample_test))

test_valid_split = dataset['test'].train_test_split(test_size=0.65, seed=45)
dataset['test'] = test_valid_split['train']
dataset['validation'] = test_valid_split['test']

#Import ResNet-50 from HuggingFace
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-26")
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-26")



In [67]:
print(dataset['train'][0])

{'label': 4, 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=48x48 at 0x7FB09A83F670>}


In [68]:
# Define the transform function
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([img.convert("RGB") for img in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['label']
    return inputs

# Apply the transform to the datasets
dataset = dataset.map(transform, batched=True)

# Remove the 'image' column as it's now transformed
dataset = dataset.remove_columns(['image'])

# Set the format for PyTorch
dataset.set_format(type='torch')

Map: 100%|██████████| 22000/22000 [01:34<00:00, 233.05 examples/s]


In [69]:
from transformers import Trainer

In [70]:
def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

In [73]:
#Training Args
training_args = TrainingArguments(
    output_dir='./huggingface_fer_model/results',          # output directory
    num_train_epochs=25,              # total number of training epochs
    per_device_train_batch_size=24,  # batch size for training
    per_device_eval_batch_size=24,   # batch size for evaluation
    evaluation_strategy="epoch",     # evaluation strategy to use at the end of each epoch
    save_strategy="epoch",           # save strategy to use at the end of each epoch
    logging_dir='./huggingface_fer_model/logs',            # directory for storing logs
    logging_steps=25,
    warmup_steps=200,                 # number of warmup steps for learning rate scheduler
    report_to=[],                    # disable reporting to any integration
    learning_rate=7e-5,
    weight_decay=0.05,
)

#Trainer
trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=dataset['train'],      
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics
)



In [74]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.1498,1.601404,0.629231,0.636532,0.629231,0.628493
2,0.1051,1.830445,0.645128,0.646625,0.645128,0.642679
3,0.0828,2.128633,0.64,0.639272,0.64,0.636502
4,0.0596,2.315516,0.622051,0.631463,0.622051,0.620509
5,0.0866,2.062851,0.638462,0.637662,0.638462,0.63718
6,0.0629,2.21664,0.637949,0.633263,0.637949,0.634026
7,0.0811,2.366654,0.655385,0.659991,0.655385,0.654837
8,0.0175,2.333942,0.647179,0.652457,0.647179,0.645926
9,0.0522,2.383913,0.644103,0.642959,0.644103,0.639968
10,0.0309,2.588795,0.651282,0.648812,0.651282,0.646589


TrainOutput(global_step=22925, training_loss=0.031089846424602538, metrics={'train_runtime': 4238.1271, 'train_samples_per_second': 129.774, 'train_steps_per_second': 5.409, 'total_flos': 7.9454821146624e+18, 'train_loss': 0.031089846424602538, 'epoch': 25.0})

In [None]:
eval_results = trainer.evaluate()
print(f"Validation Loss: {eval_results['eval_loss']:.4f}")
print(f"Validation Accuracy: {eval_results['eval_accuracy']:.4f}")
print(f"Validation Precision: {eval_results['eval_precision']:.4f}")
print(f"Validation Recall: {eval_results['eval_recall']:.4f}")
print(f"Validation F1 Score: {eval_results['eval_f1']:.4f}")

Validation Loss: 2.0952
Validation Accuracy: 0.2914
Validation Precision: 0.3032
Validation Recall: 0.2914
Validation F1 Score: 0.2911


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
model.save_pretrained('./resnet26_fer2013_model')
feature_extractor.save_pretrained('./resnet26_fer2013_model')