In [1]:
from datasets import load_dataset
from transformers import AutoFeatureExtractor  , AutoModelForImageClassification, TrainingArguments, Trainer
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from PIL import Image
import torch

  from .autonotebook import tqdm as notebook_tqdm





In [4]:
#Loading FER-2013 Dataset:  https://huggingface.co/datasets/3una/Fer2013

dataset = load_dataset("Piro17/dataset-affecthqnet-fer2013")

sample_train = 22000
sample_test = 3000
seed = 27
print(dataset)

dataset['train'] = dataset['train'].shuffle(seed=seed).select(range(sample_train))
dataset['test'] = dataset['train'].shuffle(seed=seed).select(range(sample_test))

test_valid_split = dataset['test'].train_test_split(test_size=0.65, seed=45)
dataset['test'] = test_valid_split['train']
dataset['validation'] = test_valid_split['test']

#Import ResNet-50 from HuggingFace
feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/resnet-26")
model = AutoModelForImageClassification.from_pretrained("microsoft/resnet-26")


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 56532
    })
})




In [5]:
print(dataset['train'][0])

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=48x48 at 0x1EAD2D66610>, 'label': 4}


In [6]:
data_transforms = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomResizedCrop(size=(48, 48), scale=(0.8, 1.0)),
    transforms.ToTensor()
])

In [7]:
# Define the transform function
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([img.convert("RGB") for img in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['label']
    return inputs

# Apply the transform to the datasets
dataset = dataset.map(transform, batched=True)

# Remove the 'image' column as it's now transformed
dataset = dataset.remove_columns(['image'])

# Set the format for PyTorch
dataset.set_format(type='torch')

Map: 100%|██████████| 22000/22000 [19:59<00:00, 18.34 examples/s]
Map: 100%|██████████| 1050/1050 [00:55<00:00, 18.86 examples/s]
Map: 100%|██████████| 1950/1950 [01:46<00:00, 18.32 examples/s]


In [8]:
from transformers import Trainer

In [9]:
def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

In [10]:
from transformers import EarlyStoppingCallback

In [13]:
#Training Args
training_args = TrainingArguments(
    output_dir='./huggingface_fer_model/results',          # output directory
    num_train_epochs=10,              # total number of training epochs
    per_device_train_batch_size=8,  # batch size for training
    per_device_eval_batch_size=8,   # batch size for evaluation
    evaluation_strategy="epoch",     # evaluation strategy to use at the end of each epoch
    save_strategy="epoch",           # save strategy to use at the end of each epoch
    logging_dir='./huggingface_fer_model/logs',            # directory for storing logs
    logging_steps=25,
    warmup_steps=55,                 # number of warmup steps for learning rate scheduler
    report_to=[],                    # disable reporting to any integration
    learning_rate=7e-5,
    weight_decay=0.055,
    fp16=True,                     # use mixed precision training
    load_best_model_at_end=True,     # load the best model when finished training (default metric is loss)
    metric_for_best_model="eval_loss",
    greater_is_better=False,          # lower loss is better
    save_total_limit=2,               # limit the total amount of checkpoints, delete the older checkpoints in the output_dir    
)

#Trainer
trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=dataset['train'],      
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5) ]  # Stop training if no improvement
)

In [14]:
trainer.train()

                                                     
  1%|          | 233/22925 [02:14<2:54:06,  2.17it/s]

{'loss': 6.0068, 'grad_norm': 25.996702194213867, 'learning_rate': 2.8e-05, 'epoch': 0.01}


                                                     
  1%|          | 233/22925 [02:17<2:54:06,  2.17it/s]

{'loss': 5.1254, 'grad_norm': 29.352420806884766, 'learning_rate': 5.981818181818181e-05, 'epoch': 0.02}


                                                     
  1%|          | 233/22925 [02:20<2:54:06,  2.17it/s]

{'loss': 3.4438, 'grad_norm': 29.727903366088867, 'learning_rate': 6.995664055383494e-05, 'epoch': 0.03}


                                                     
  1%|          | 233/22925 [02:24<2:54:06,  2.17it/s]

{'loss': 2.2065, 'grad_norm': 23.744081497192383, 'learning_rate': 6.989287666241573e-05, 'epoch': 0.04}


                                                     
  1%|          | 233/22925 [02:27<2:54:06,  2.17it/s]

{'loss': 1.703, 'grad_norm': 27.744749069213867, 'learning_rate': 6.982911277099653e-05, 'epoch': 0.05}


                                                     
  1%|          | 233/22925 [02:30<2:54:06,  2.17it/s]

{'loss': 1.5923, 'grad_norm': 26.917104721069336, 'learning_rate': 6.976534887957733e-05, 'epoch': 0.05}


                                                     
  1%|          | 233/22925 [02:34<2:54:06,  2.17it/s]

{'loss': 1.8393, 'grad_norm': 20.134244918823242, 'learning_rate': 6.970158498815812e-05, 'epoch': 0.06}


                                                     
  1%|          | 233/22925 [02:38<2:54:06,  2.17it/s]

{'loss': 1.5371, 'grad_norm': 23.514179229736328, 'learning_rate': 6.963782109673892e-05, 'epoch': 0.07}


                                                     
  1%|          | 233/22925 [02:41<2:54:06,  2.17it/s]

{'loss': 1.5466, 'grad_norm': 20.333528518676758, 'learning_rate': 6.957405720531973e-05, 'epoch': 0.08}


                                                     
  1%|          | 233/22925 [02:45<2:54:06,  2.17it/s]

{'loss': 1.3505, 'grad_norm': 16.894142150878906, 'learning_rate': 6.951029331390053e-05, 'epoch': 0.09}


                                                     
  1%|          | 233/22925 [02:48<2:54:06,  2.17it/s]

{'loss': 1.5042, 'grad_norm': 16.10357093811035, 'learning_rate': 6.944652942248132e-05, 'epoch': 0.1}


                                                     
  1%|          | 233/22925 [02:52<2:54:06,  2.17it/s]

{'loss': 1.399, 'grad_norm': 18.991901397705078, 'learning_rate': 6.938276553106212e-05, 'epoch': 0.11}


                                                     
  1%|          | 233/22925 [02:55<2:54:06,  2.17it/s]

{'loss': 1.5363, 'grad_norm': 24.633615493774414, 'learning_rate': 6.932155219529968e-05, 'epoch': 0.12}


                                                     
  1%|          | 233/22925 [02:59<2:54:06,  2.17it/s]

{'loss': 1.4731, 'grad_norm': 16.47776222229004, 'learning_rate': 6.925778830388048e-05, 'epoch': 0.13}


                                                     
  1%|          | 233/22925 [03:03<2:54:06,  2.17it/s]

{'loss': 1.3722, 'grad_norm': 22.60738754272461, 'learning_rate': 6.919402441246127e-05, 'epoch': 0.14}


                                                     
  1%|          | 233/22925 [03:06<2:54:06,  2.17it/s]

{'loss': 1.3346, 'grad_norm': 16.074195861816406, 'learning_rate': 6.913026052104207e-05, 'epoch': 0.15}


                                                     
  1%|          | 233/22925 [03:10<2:54:06,  2.17it/s]

{'loss': 1.278, 'grad_norm': 16.84820556640625, 'learning_rate': 6.906649662962288e-05, 'epoch': 0.15}


                                                     
  1%|          | 233/22925 [03:13<2:54:06,  2.17it/s]

{'loss': 1.399, 'grad_norm': 15.278069496154785, 'learning_rate': 6.900273273820368e-05, 'epoch': 0.16}


                                                     
  1%|          | 233/22925 [03:17<2:54:06,  2.17it/s]

{'loss': 1.2683, 'grad_norm': 10.560317993164062, 'learning_rate': 6.893896884678448e-05, 'epoch': 0.17}


                                                     
  1%|          | 233/22925 [03:21<2:54:06,  2.17it/s]

{'loss': 1.2326, 'grad_norm': 28.500417709350586, 'learning_rate': 6.887520495536527e-05, 'epoch': 0.18}


                                                     
  1%|          | 233/22925 [03:24<2:54:06,  2.17it/s]

{'loss': 1.3051, 'grad_norm': 22.326135635375977, 'learning_rate': 6.881144106394607e-05, 'epoch': 0.19}


                                                     
  1%|          | 233/22925 [03:28<2:54:06,  2.17it/s]

{'loss': 1.177, 'grad_norm': 20.022321701049805, 'learning_rate': 6.874767717252686e-05, 'epoch': 0.2}


                                                     
  1%|          | 233/22925 [03:31<2:54:06,  2.17it/s]

{'loss': 1.2667, 'grad_norm': 26.1958065032959, 'learning_rate': 6.868391328110766e-05, 'epoch': 0.21}


                                                     
  1%|          | 233/22925 [03:35<2:54:06,  2.17it/s]

{'loss': 1.2504, 'grad_norm': 14.167830467224121, 'learning_rate': 6.862014938968846e-05, 'epoch': 0.22}


                                                     
  1%|          | 233/22925 [03:39<2:54:06,  2.17it/s]

{'loss': 1.3252, 'grad_norm': 20.554737091064453, 'learning_rate': 6.855638549826926e-05, 'epoch': 0.23}


                                                     
  1%|          | 233/22925 [03:42<2:54:06,  2.17it/s]

{'loss': 1.2766, 'grad_norm': 14.12016487121582, 'learning_rate': 6.849262160685007e-05, 'epoch': 0.24}


                                                     
  1%|          | 233/22925 [03:46<2:54:06,  2.17it/s]

{'loss': 1.3328, 'grad_norm': 20.089366912841797, 'learning_rate': 6.842885771543085e-05, 'epoch': 0.25}


                                                     
  1%|          | 233/22925 [03:49<2:54:06,  2.17it/s]

{'loss': 1.1271, 'grad_norm': 19.0623779296875, 'learning_rate': 6.836509382401166e-05, 'epoch': 0.25}


                                                     
  1%|          | 233/22925 [03:54<2:54:06,  2.17it/s]

{'loss': 1.2469, 'grad_norm': 18.270267486572266, 'learning_rate': 6.830132993259245e-05, 'epoch': 0.26}


                                                     
  1%|          | 233/22925 [03:58<2:54:06,  2.17it/s]

{'loss': 1.1736, 'grad_norm': 19.200239181518555, 'learning_rate': 6.823756604117325e-05, 'epoch': 0.27}


                                                     
  1%|          | 233/22925 [04:02<2:54:06,  2.17it/s]

{'loss': 1.2834, 'grad_norm': 14.31234073638916, 'learning_rate': 6.817380214975405e-05, 'epoch': 0.28}


                                                     
  1%|          | 233/22925 [04:06<2:54:06,  2.17it/s]

{'loss': 1.2043, 'grad_norm': 11.422408103942871, 'learning_rate': 6.811003825833485e-05, 'epoch': 0.29}


                                                     
  1%|          | 233/22925 [04:10<2:54:06,  2.17it/s]

{'loss': 1.1163, 'grad_norm': 20.068878173828125, 'learning_rate': 6.804627436691564e-05, 'epoch': 0.3}


                                                     
  1%|          | 233/22925 [04:15<2:54:06,  2.17it/s]

{'loss': 1.29, 'grad_norm': 15.025643348693848, 'learning_rate': 6.798251047549644e-05, 'epoch': 0.31}


                                                     
  1%|          | 233/22925 [04:19<2:54:06,  2.17it/s]

{'loss': 1.3057, 'grad_norm': 15.464679718017578, 'learning_rate': 6.791874658407724e-05, 'epoch': 0.32}


                                                     
  1%|          | 233/22925 [04:23<2:54:06,  2.17it/s]

{'loss': 1.2881, 'grad_norm': 18.83913803100586, 'learning_rate': 6.785498269265803e-05, 'epoch': 0.33}


                                                     
  1%|          | 233/22925 [04:28<2:54:06,  2.17it/s]

{'loss': 1.1036, 'grad_norm': 18.2570743560791, 'learning_rate': 6.779121880123883e-05, 'epoch': 0.34}


                                                     
  1%|          | 233/22925 [04:32<2:54:06,  2.17it/s]


{'loss': 1.2705, 'grad_norm': 15.381704330444336, 'learning_rate': 6.772745490981964e-05, 'epoch': 0.35}


                                                     [A
  1%|          | 233/22925 [04:36<2:54:06,  2.17it/s]


{'loss': 1.1339, 'grad_norm': 17.630159378051758, 'learning_rate': 6.766369101840042e-05, 'epoch': 0.35}


                                                     [A
  1%|          | 233/22925 [04:41<2:54:06,  2.17it/s] 

{'loss': 1.074, 'grad_norm': 14.94446849822998, 'learning_rate': 6.759992712698123e-05, 'epoch': 0.36}


                                                     
  1%|          | 233/22925 [04:45<2:54:06,  2.17it/s] 


{'loss': 1.1332, 'grad_norm': 9.674606323242188, 'learning_rate': 6.753616323556203e-05, 'epoch': 0.37}


                                                     ][A
  1%|          | 233/22925 [04:49<2:54:06,  2.17it/s] 

{'loss': 1.2504, 'grad_norm': 11.144970893859863, 'learning_rate': 6.747239934414283e-05, 'epoch': 0.38}


                                                     
  1%|          | 233/22925 [04:54<2:54:06,  2.17it/s] 


{'loss': 1.1406, 'grad_norm': 10.059240341186523, 'learning_rate': 6.740863545272363e-05, 'epoch': 0.39}


                                                     ][A
  1%|          | 233/22925 [04:59<2:54:06,  2.17it/s] 

{'loss': 1.1968, 'grad_norm': 23.000457763671875, 'learning_rate': 6.734487156130442e-05, 'epoch': 0.4}


                                                     
  1%|          | 233/22925 [05:03<2:54:06,  2.17it/s] 


{'loss': 1.1658, 'grad_norm': 16.80907440185547, 'learning_rate': 6.728110766988522e-05, 'epoch': 0.41}


                                                     ][A
  1%|          | 233/22925 [05:08<2:54:06,  2.17it/s] 

{'loss': 1.0658, 'grad_norm': 9.34872817993164, 'learning_rate': 6.721734377846601e-05, 'epoch': 0.42}


                                                     
  1%|          | 233/22925 [05:12<2:54:06,  2.17it/s] 

{'loss': 1.2316, 'grad_norm': 16.439456939697266, 'learning_rate': 6.715357988704681e-05, 'epoch': 0.43}


                                                     
  1%|          | 233/22925 [05:17<2:54:06,  2.17it/s] 

{'loss': 1.1408, 'grad_norm': 15.894543647766113, 'learning_rate': 6.708981599562762e-05, 'epoch': 0.44}


                                                     
  1%|          | 233/22925 [05:22<2:54:06,  2.17it/s] 

{'loss': 0.9659, 'grad_norm': 13.796121597290039, 'learning_rate': 6.702605210420842e-05, 'epoch': 0.45}


                                                     
  1%|          | 233/22925 [05:27<2:54:06,  2.17it/s] 

{'loss': 1.3516, 'grad_norm': 16.30180549621582, 'learning_rate': 6.69622882127892e-05, 'epoch': 0.45}


                                                     
  1%|          | 233/22925 [05:32<2:54:06,  2.17it/s] 

{'loss': 1.0445, 'grad_norm': 17.970521926879883, 'learning_rate': 6.689852432137001e-05, 'epoch': 0.46}


                                                     
  1%|          | 233/22925 [05:37<2:54:06,  2.17it/s] 


{'loss': 1.1668, 'grad_norm': 16.254528045654297, 'learning_rate': 6.683476042995081e-05, 'epoch': 0.47}


                                                     ][A
  1%|          | 233/22925 [05:41<2:54:06,  2.17it/s] 

{'loss': 1.1402, 'grad_norm': 19.232746124267578, 'learning_rate': 6.67709965385316e-05, 'epoch': 0.48}


                                                     
  1%|          | 233/22925 [05:46<2:54:06,  2.17it/s] 

{'loss': 1.2197, 'grad_norm': 18.24677848815918, 'learning_rate': 6.67072326471124e-05, 'epoch': 0.49}


                                                     
  1%|          | 233/22925 [05:51<2:54:06,  2.17it/s] 


{'loss': 1.1871, 'grad_norm': 12.948417663574219, 'learning_rate': 6.66434687556932e-05, 'epoch': 0.5}


                                                     ][A
  1%|          | 233/22925 [05:55<2:54:06,  2.17it/s] 


{'loss': 1.182, 'grad_norm': 12.621625900268555, 'learning_rate': 6.657970486427399e-05, 'epoch': 0.51}


                                                     ][A
  1%|          | 233/22925 [06:00<2:54:06,  2.17it/s] 
  5%|▌         | 1426/27500 [03:50<1:16:52,  5.65it/s]

{'loss': 1.1731, 'grad_norm': 14.157234191894531, 'learning_rate': 6.651594097285479e-05, 'epoch': 0.52}


                                                     
  1%|          | 233/22925 [06:04<2:54:06,  2.17it/s] 

{'loss': 1.092, 'grad_norm': 12.404963493347168, 'learning_rate': 6.64521770814356e-05, 'epoch': 0.53}


                                                     
  1%|          | 233/22925 [06:09<2:54:06,  2.17it/s] 

{'loss': 1.0882, 'grad_norm': 11.737333297729492, 'learning_rate': 6.63884131900164e-05, 'epoch': 0.54}


                                                     
  1%|          | 233/22925 [06:13<2:54:06,  2.17it/s] 

{'loss': 1.245, 'grad_norm': 14.258978843688965, 'learning_rate': 6.632464929859719e-05, 'epoch': 0.55}


                                                     
  1%|          | 233/22925 [06:17<2:54:06,  2.17it/s] 

{'loss': 1.1928, 'grad_norm': 12.401409149169922, 'learning_rate': 6.626088540717799e-05, 'epoch': 0.55}


                                                     
  1%|          | 233/22925 [06:22<2:54:06,  2.17it/s] 


{'loss': 1.1187, 'grad_norm': 11.027584075927734, 'learning_rate': 6.619712151575878e-05, 'epoch': 0.56}


                                                     ][A
  1%|          | 233/22925 [06:26<2:54:06,  2.17it/s] 


{'loss': 1.2227, 'grad_norm': 12.806900024414062, 'learning_rate': 6.613335762433958e-05, 'epoch': 0.57}


                                                     ][A
  1%|          | 233/22925 [06:30<2:54:06,  2.17it/s] 

{'loss': 1.0359, 'grad_norm': 11.70371150970459, 'learning_rate': 6.606959373292038e-05, 'epoch': 0.58}


                                                     
  1%|          | 233/22925 [06:35<2:54:06,  2.17it/s] 

{'loss': 1.0359, 'grad_norm': 12.671439170837402, 'learning_rate': 6.600582984150118e-05, 'epoch': 0.59}


                                                     
  1%|          | 233/22925 [06:39<2:54:06,  2.17it/s] 

{'loss': 1.2137, 'grad_norm': 9.769308090209961, 'learning_rate': 6.594206595008198e-05, 'epoch': 0.6}


                                                     
  1%|          | 233/22925 [06:44<2:54:06,  2.17it/s] 


{'loss': 1.0982, 'grad_norm': 13.260696411132812, 'learning_rate': 6.587830205866277e-05, 'epoch': 0.61}


                                                     ][A
  1%|          | 233/22925 [06:48<2:54:06,  2.17it/s] 


{'loss': 1.0289, 'grad_norm': 9.819355964660645, 'learning_rate': 6.581453816724357e-05, 'epoch': 0.62}


                                                     ][A
  1%|          | 233/22925 [06:52<2:54:06,  2.17it/s] 


{'loss': 1.0493, 'grad_norm': 12.072371482849121, 'learning_rate': 6.575077427582436e-05, 'epoch': 0.63}


                                                     ][A
  1%|          | 233/22925 [06:57<2:54:06,  2.17it/s] 


{'loss': 0.9163, 'grad_norm': 8.879270553588867, 'learning_rate': 6.568701038440516e-05, 'epoch': 0.64}


                                                     ][A
  1%|          | 233/22925 [07:02<2:54:06,  2.17it/s] 


{'loss': 1.0949, 'grad_norm': 10.821778297424316, 'learning_rate': 6.562324649298597e-05, 'epoch': 0.65}


                                                     ][A
  1%|          | 233/22925 [07:07<2:54:06,  2.17it/s] 

{'loss': 1.0102, 'grad_norm': 16.203832626342773, 'learning_rate': 6.555948260156677e-05, 'epoch': 0.65}


                                                     
  1%|          | 233/22925 [07:12<2:54:06,  2.17it/s] 

{'loss': 1.0877, 'grad_norm': 15.98011302947998, 'learning_rate': 6.549571871014756e-05, 'epoch': 0.66}




KeyboardInterrupt: 

In [None]:
eval_results = trainer.evaluate()
print(f"Validation Loss: {eval_results['eval_loss']:.4f}")
print(f"Validation Accuracy: {eval_results['eval_accuracy']:.4f}")
print(f"Validation Precision: {eval_results['eval_precision']:.4f}")
print(f"Validation Recall: {eval_results['eval_recall']:.4f}")
print(f"Validation F1 Score: {eval_results['eval_f1']:.4f}")

Validation Loss: 2.7327
Validation Accuracy: 0.6508
Validation Precision: 0.6494
Validation Recall: 0.6508
Validation F1 Score: 0.6492


In [None]:
model.save_pretrained('./resnet26_fer2013_model')
feature_extractor.save_pretrained('./resnet26_fer2013_model')

['./resnet26_fer2013_model/preprocessor_config.json']