In [2]:
from datasets import load_dataset
from transformers import AutoFeatureExtractor  , AutoModelForImageClassification, TrainingArguments, Trainer
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from PIL import Image
import torch

  from .autonotebook import tqdm as notebook_tqdm
2025-02-27 10:28:04.702975: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-02-27 10:28:04.714531: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-27 10:28:04.825665: I external/local_xla/xla/tsl/cuda/cudart_stub.cc:32] Could not find cuda drivers on your machine, GPU will not be used.
2025-02-27 10:28:04.923844: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1740670085.002973  868825 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has

In [11]:
#Loading FER-2013 Dataset:  https://huggingface.co/datasets/3una/Fer2013

dataset = load_dataset("AutumnQiu/fer2013")

sample_train = 22000
sample_test = 3000
seed = 27

dataset['train'] = dataset['train'].shuffle(seed=seed).select(range(sample_train))
dataset['test'] = dataset['test'].shuffle(seed=seed).select(range(sample_test))

test_valid_split = dataset['test'].train_test_split(test_size=0.65, seed=45)
dataset['test'] = test_valid_split['train']
dataset['validation'] = test_valid_split['test']

#Import ResNet-50 from HuggingFace
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-26")
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-26")



In [12]:
print(dataset['train'][0])

{'label': 4, 'image': <PIL.PngImagePlugin.PngImageFile image mode=RGB size=48x48 at 0x7FBEE056D960>}


In [None]:
data_transforms = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomResizedCrop(size=(48, 48), scale=(0.8, 1.0)),
    transforms.ToTensor()
])

In [14]:
# Define the transform function
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([img.convert("RGB") for img in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['label']
    return inputs

# Apply the transform to the datasets
dataset = dataset.map(transform, batched=True)

# Remove the 'image' column as it's now transformed
dataset = dataset.remove_columns(['image'])

# Set the format for PyTorch
dataset.set_format(type='torch')

In [15]:
from transformers import Trainer

In [16]:
def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

In [21]:
from transformers import EarlyStoppingCallback

In [None]:
#Training Args
training_args = TrainingArguments(
    output_dir='./huggingface_fer_model/results',          # output directory
    num_train_epochs=25,              # total number of training epochs
    per_device_train_batch_size=24,  # batch size for training
    per_device_eval_batch_size=24,   # batch size for evaluation
    evaluation_strategy="epoch",     # evaluation strategy to use at the end of each epoch
    save_strategy="epoch",           # save strategy to use at the end of each epoch
    logging_dir='./huggingface_fer_model/logs',            # directory for storing logs
    logging_steps=25,
    warmup_steps=3000,                 # number of warmup steps for learning rate scheduler
    report_to=[],                    # disable reporting to any integration
    learning_rate=7e-5,
    weight_decay=0.055,
    fp16=True,                     # use mixed precision training
    load_best_model_at_end=True,     # load the best model when finished training (default metric is loss)
    metric_for_best_model="eval_loss",
    greater_is_better=False,          # lower loss is better
    save_total_limit=2,               # limit the total amount of checkpoints, delete the older checkpoints in the output_dir    
)

#Trainer
trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=dataset['train'],      
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5) ]  # Stop training if no improvement
)



In [23]:
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,F1
1,0.0038,2.76528,0.659487,0.65865,0.659487,0.657852
2,0.0197,2.807828,0.641538,0.639493,0.641538,0.636706
3,0.0744,2.901085,0.626154,0.632058,0.626154,0.622925
4,0.1418,2.677907,0.632308,0.631166,0.632308,0.630835
5,0.0862,2.789686,0.639487,0.645813,0.639487,0.640869
6,0.0777,2.876967,0.637949,0.641636,0.637949,0.636517


KeyboardInterrupt: 

In [None]:
eval_results = trainer.evaluate()
print(f"Validation Loss: {eval_results['eval_loss']:.4f}")
print(f"Validation Accuracy: {eval_results['eval_accuracy']:.4f}")
print(f"Validation Precision: {eval_results['eval_precision']:.4f}")
print(f"Validation Recall: {eval_results['eval_recall']:.4f}")
print(f"Validation F1 Score: {eval_results['eval_f1']:.4f}")

Validation Loss: 2.0952
Validation Accuracy: 0.2914
Validation Precision: 0.3032
Validation Recall: 0.2914
Validation F1 Score: 0.2911


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [None]:
model.save_pretrained('./resnet26_fer2013_model')
feature_extractor.save_pretrained('./resnet26_fer2013_model')