In [31]:
from datasets import load_dataset
from transformers import AutoFeatureExtractor  , AutoModelForImageClassification, TrainingArguments, Trainer
from torch.utils.data import Dataset
from torchvision import transforms
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from PIL import Image
import torch

In [36]:
#Load  dataset
dataset = load_dataset("Piro17/dataset-affecthqnet-fer2013")

sample_train = 8000
sample_test = 1000
print(dataset)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

dataset['test'] = dataset['train'].shuffle(seed=96).select(range(sample_test))
dataset['train'] = dataset['train'].shuffle(seed=23).select(range(sample_train))

test_valid_split = dataset['test'].train_test_split(test_size=0.65, seed=45)
dataset['test'] = test_valid_split['train']
dataset['validation'] = test_valid_split['test']

#Import Moel from HuggingFace

model = "google/mobilenet_v2_1.0_224"
#model = "microsoft/resnet-26"
feature_extractor = AutoFeatureExtractor.from_pretrained(model)
model = AutoModelForImageClassification.from_pretrained(model)


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 56532
    })
})
cuda


In [37]:
print(dataset['train'][0])

{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1116x1116 at 0x2E8714BD9D0>, 'label': 5}


In [38]:
data_transforms = transforms.Compose([
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomResizedCrop(size=(48, 48), scale=(0.8, 1.0)),
    transforms.ToTensor()
])

In [39]:
# Define the transform function
def transform(example_batch):
    # Take a list of PIL images and turn them to pixel values
    inputs = feature_extractor([img.convert("RGB") for img in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['label']
    return inputs

# Apply the transform to the datasets
dataset = dataset.map(transform, batched=True)

# Remove the 'image' column as it's now transformed
dataset = dataset.remove_columns(['image'])

# Set the format for PyTorch
dataset.set_format(type='torch')

Map: 100%|██████████| 8000/8000 [07:09<00:00, 18.62 examples/s]
Map: 100%|██████████| 350/350 [00:18<00:00, 19.36 examples/s]
Map: 100%|██████████| 650/650 [00:30<00:00, 21.54 examples/s]


In [40]:
from transformers import Trainer

In [41]:
def compute_metrics(p):
    preds = p.predictions.argmax(-1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='weighted')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }

In [42]:
from transformers import EarlyStoppingCallback

In [None]:
#Training Args
training_args = TrainingArguments(
    output_dir='./huggingface_fer_model/results',          # output directory
    num_train_epochs=25,              # total number of training epochs
    per_device_train_batch_size=32,  # batch size for training
    per_device_eval_batch_size=32,   # batch size for evaluation
    evaluation_strategy="epoch",     # evaluation strategy to use at the end of each epoch
    save_strategy="epoch",           # save strategy to use at the end of each epoch
    logging_dir='./huggingface_fer_model/logs',            # directory for storing logs
    logging_steps=25,
    warmup_steps=55,                 # number of warmup steps for learning rate scheduler
    report_to=[],                    # disable reporting to any integration
    learning_rate=7e-5,
    weight_decay=0.055,
    fp16=True,                     # use mixed precision training
    load_best_model_at_end=True,     # load the best model when finished training (default metric is loss)
    metric_for_best_model="eval_loss",
    greater_is_better=False,          # lower loss is better
    save_total_limit=2,               # limit the total amount of checkpoints, delete the older checkpoints in the output_dir    
)

#Trainer
trainer = Trainer(
    model=model,                         
    args=training_args,                  
    train_dataset=dataset['train'],      
    eval_dataset=dataset['validation'],
    compute_metrics=compute_metrics,
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5) ]  # Stop training if no improvement
)



In [46]:
trainer.train()


 28%|██▊       | 929/3340 [07:20<17:01,  2.36it/s]

{'loss': 0.6082, 'grad_norm': 10.3058443069458, 'learning_rate': 6.363636363636363e-05, 'epoch': 0.2}



 28%|██▊       | 929/3340 [07:33<17:01,  2.36it/s]

{'loss': 0.6135, 'grad_norm': inf, 'learning_rate': 6.94560327198364e-05, 'epoch': 0.3}



 28%|██▊       | 929/3340 [07:49<17:01,  2.36it/s]

{'loss': 0.585, 'grad_norm': 12.707986831665039, 'learning_rate': 6.874028629856851e-05, 'epoch': 0.4}



 28%|██▊       | 929/3340 [08:03<17:01,  2.36it/s]

{'loss': 0.5954, 'grad_norm': 15.39326000213623, 'learning_rate': 6.80245398773006e-05, 'epoch': 0.5}



 28%|██▊       | 929/3340 [08:18<17:01,  2.36it/s]

{'loss': 0.5518, 'grad_norm': 11.185004234313965, 'learning_rate': 6.730879345603272e-05, 'epoch': 0.6}



 28%|██▊       | 929/3340 [08:32<17:01,  2.36it/s]

{'loss': 0.5678, 'grad_norm': 10.299907684326172, 'learning_rate': 6.659304703476481e-05, 'epoch': 0.7}



 28%|██▊       | 929/3340 [08:46<17:01,  2.36it/s]

{'loss': 0.5578, 'grad_norm': 12.885693550109863, 'learning_rate': 6.587730061349692e-05, 'epoch': 0.8}



 28%|██▊       | 929/3340 [09:03<17:01,  2.36it/s]

{'loss': 0.5479, 'grad_norm': 15.263705253601074, 'learning_rate': 6.516155419222904e-05, 'epoch': 0.9}



 28%|██▊       | 929/3340 [09:23<17:01,  2.36it/s]

{'loss': 0.5112, 'grad_norm': 12.785643577575684, 'learning_rate': 6.444580777096115e-05, 'epoch': 1.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A

[A[A                                         
                                                  
 28%|██▊       | 929/3340 [09:36<17:01,  2.36it/s]
[A

{'eval_loss': 1.3253833055496216, 'eval_accuracy': 0.5753846153846154, 'eval_precision': 0.6254800521214012, 'eval_recall': 0.5753846153846154, 'eval_f1': 0.5875837925973619, 'eval_runtime': 12.3188, 'eval_samples_per_second': 52.765, 'eval_steps_per_second': 1.705, 'epoch': 1.0}



 28%|██▊       | 929/3340 [09:50<17:01,  2.36it/s]

{'loss': 0.3692, 'grad_norm': 8.539315223693848, 'learning_rate': 6.373006134969324e-05, 'epoch': 1.1}



 28%|██▊       | 929/3340 [10:04<17:01,  2.36it/s]

{'loss': 0.3607, 'grad_norm': 8.606311798095703, 'learning_rate': 6.301431492842536e-05, 'epoch': 1.2}



 28%|██▊       | 929/3340 [10:19<17:01,  2.36it/s]

{'loss': 0.3983, 'grad_norm': 9.215599060058594, 'learning_rate': 6.229856850715745e-05, 'epoch': 1.3}



 28%|██▊       | 929/3340 [10:34<17:01,  2.36it/s]

{'loss': 0.3297, 'grad_norm': 11.732529640197754, 'learning_rate': 6.158282208588956e-05, 'epoch': 1.4}



 28%|██▊       | 929/3340 [10:51<17:01,  2.36it/s]

{'loss': 0.3516, 'grad_norm': 11.609228134155273, 'learning_rate': 6.0867075664621675e-05, 'epoch': 1.5}



 28%|██▊       | 929/3340 [11:08<17:01,  2.36it/s]

{'loss': 0.3379, 'grad_norm': 12.683138847351074, 'learning_rate': 6.015132924335377e-05, 'epoch': 1.6}



 28%|██▊       | 929/3340 [11:23<17:01,  2.36it/s]

{'loss': 0.3706, 'grad_norm': 10.414093971252441, 'learning_rate': 5.9435582822085884e-05, 'epoch': 1.7}



 28%|██▊       | 929/3340 [11:37<17:01,  2.36it/s]

{'loss': 0.3749, 'grad_norm': 8.184925079345703, 'learning_rate': 5.871983640081799e-05, 'epoch': 1.8}



 28%|██▊       | 929/3340 [11:51<17:01,  2.36it/s]

{'loss': 0.3389, 'grad_norm': 7.687682628631592, 'learning_rate': 5.80040899795501e-05, 'epoch': 1.9}



 28%|██▊       | 929/3340 [12:05<17:01,  2.36it/s]

{'loss': 0.3556, 'grad_norm': 10.316762924194336, 'learning_rate': 5.728834355828221e-05, 'epoch': 2.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                  

 28%|██▊       | 929/3340 [12:15<17:01,  2.36it/s]
[A
[A

{'eval_loss': 1.380627989768982, 'eval_accuracy': 0.5969230769230769, 'eval_precision': 0.6148171263700191, 'eval_recall': 0.5969230769230769, 'eval_f1': 0.6033757539070361, 'eval_runtime': 10.3003, 'eval_samples_per_second': 63.105, 'eval_steps_per_second': 2.039, 'epoch': 2.0}



 28%|██▊       | 929/3340 [12:29<17:01,  2.36it/s]

{'loss': 0.2012, 'grad_norm': 9.808981895446777, 'learning_rate': 5.657259713701431e-05, 'epoch': 2.1}



 28%|██▊       | 929/3340 [12:43<17:01,  2.36it/s]

{'loss': 0.1736, 'grad_norm': 5.194911956787109, 'learning_rate': 5.585685071574641e-05, 'epoch': 2.2}



 28%|██▊       | 929/3340 [12:57<17:01,  2.36it/s]

{'loss': 0.1708, 'grad_norm': 9.612290382385254, 'learning_rate': 5.514110429447852e-05, 'epoch': 2.3}



 28%|██▊       | 929/3340 [13:11<17:01,  2.36it/s]

{'loss': 0.1758, 'grad_norm': 6.721974849700928, 'learning_rate': 5.4453987730061345e-05, 'epoch': 2.4}



 28%|██▊       | 929/3340 [13:25<17:01,  2.36it/s]

{'loss': 0.178, 'grad_norm': 4.472684383392334, 'learning_rate': 5.373824130879345e-05, 'epoch': 2.5}



 28%|██▊       | 929/3340 [13:39<17:01,  2.36it/s]

{'loss': 0.1602, 'grad_norm': 10.365991592407227, 'learning_rate': 5.302249488752556e-05, 'epoch': 2.6}



 28%|██▊       | 929/3340 [13:53<17:01,  2.36it/s]

{'loss': 0.1513, 'grad_norm': 7.915031909942627, 'learning_rate': 5.2306748466257664e-05, 'epoch': 2.7}



 28%|██▊       | 929/3340 [14:07<17:01,  2.36it/s]

{'loss': 0.2209, 'grad_norm': 14.104194641113281, 'learning_rate': 5.1591002044989776e-05, 'epoch': 2.8}



 28%|██▊       | 929/3340 [14:21<17:01,  2.36it/s]

{'loss': 0.3182, 'grad_norm': 17.06435203552246, 'learning_rate': 5.087525562372187e-05, 'epoch': 2.9}



 28%|██▊       | 929/3340 [14:35<17:01,  2.36it/s]

{'loss': 0.2926, 'grad_norm': 13.53278923034668, 'learning_rate': 5.0159509202453984e-05, 'epoch': 3.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                  

 28%|██▊       | 929/3340 [14:46<17:01,  2.36it/s]
[A
[A

{'eval_loss': 1.5806471109390259, 'eval_accuracy': 0.5861538461538461, 'eval_precision': 0.6094091781789753, 'eval_recall': 0.5861538461538461, 'eval_f1': 0.5912450257519086, 'eval_runtime': 10.3103, 'eval_samples_per_second': 63.044, 'eval_steps_per_second': 2.037, 'epoch': 3.0}



 28%|██▊       | 929/3340 [15:00<17:01,  2.36it/s]

{'loss': 0.1053, 'grad_norm': 11.78385066986084, 'learning_rate': 4.944376278118609e-05, 'epoch': 3.1}



 28%|██▊       | 929/3340 [15:13<17:01,  2.36it/s]

{'loss': 0.1376, 'grad_norm': 9.089988708496094, 'learning_rate': 4.87280163599182e-05, 'epoch': 3.2}



 28%|██▊       | 929/3340 [15:30<17:01,  2.36it/s]

{'loss': 0.1191, 'grad_norm': 9.92603874206543, 'learning_rate': 4.8012269938650304e-05, 'epoch': 3.3}



 28%|██▊       | 929/3340 [15:49<17:01,  2.36it/s]

{'loss': 0.1423, 'grad_norm': 11.097743034362793, 'learning_rate': 4.729652351738241e-05, 'epoch': 3.4}



 28%|██▊       | 929/3340 [16:12<17:01,  2.36it/s]

{'loss': 0.1475, 'grad_norm': 6.840501308441162, 'learning_rate': 4.658077709611451e-05, 'epoch': 3.5}



 28%|██▊       | 929/3340 [16:30<17:01,  2.36it/s]

{'loss': 0.131, 'grad_norm': 8.6572265625, 'learning_rate': 4.5865030674846624e-05, 'epoch': 3.6}



 28%|██▊       | 929/3340 [16:48<17:01,  2.36it/s]

{'loss': 0.1446, 'grad_norm': 11.022968292236328, 'learning_rate': 4.514928425357873e-05, 'epoch': 3.7}



 28%|██▊       | 929/3340 [17:05<17:01,  2.36it/s]

{'loss': 0.147, 'grad_norm': 11.222223281860352, 'learning_rate': 4.443353783231084e-05, 'epoch': 3.8}



 28%|██▊       | 929/3340 [17:20<17:01,  2.36it/s]

{'loss': 0.1592, 'grad_norm': 14.450681686401367, 'learning_rate': 4.3717791411042936e-05, 'epoch': 3.9}



 28%|██▊       | 929/3340 [17:35<17:01,  2.36it/s]

{'loss': 0.1264, 'grad_norm': 9.700873374938965, 'learning_rate': 4.300204498977505e-05, 'epoch': 4.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   

 28%|██▊       | 929/3340 [17:46<17:01,  2.36it/s]
[A
[A

{'eval_loss': 1.7314094305038452, 'eval_accuracy': 0.5676923076923077, 'eval_precision': 0.5886708899165014, 'eval_recall': 0.5676923076923077, 'eval_f1': 0.5723840200807941, 'eval_runtime': 10.8394, 'eval_samples_per_second': 59.966, 'eval_steps_per_second': 1.937, 'epoch': 4.0}



 28%|██▊       | 929/3340 [18:00<17:01,  2.36it/s] 

{'loss': 0.0624, 'grad_norm': 7.4348883628845215, 'learning_rate': 4.228629856850715e-05, 'epoch': 4.1}



 28%|██▊       | 929/3340 [18:15<17:01,  2.36it/s] 

{'loss': 0.0677, 'grad_norm': 1.5687001943588257, 'learning_rate': 4.157055214723926e-05, 'epoch': 4.2}



 28%|██▊       | 929/3340 [18:29<17:01,  2.36it/s] 

{'loss': 0.0817, 'grad_norm': 4.914190769195557, 'learning_rate': 4.085480572597137e-05, 'epoch': 4.3}



 28%|██▊       | 929/3340 [18:44<17:01,  2.36it/s] 

{'loss': 0.0773, 'grad_norm': 13.003617286682129, 'learning_rate': 4.013905930470347e-05, 'epoch': 4.4}



 28%|██▊       | 929/3340 [18:59<17:01,  2.36it/s] 

{'loss': 0.0774, 'grad_norm': 10.087752342224121, 'learning_rate': 3.9423312883435576e-05, 'epoch': 4.5}



 28%|██▊       | 929/3340 [19:13<17:01,  2.36it/s] 

{'loss': 0.0775, 'grad_norm': 4.497786521911621, 'learning_rate': 3.870756646216769e-05, 'epoch': 4.6}



 28%|██▊       | 929/3340 [19:28<17:01,  2.36it/s] 

{'loss': 0.0805, 'grad_norm': 7.628304481506348, 'learning_rate': 3.799182004089979e-05, 'epoch': 4.7}



 28%|██▊       | 929/3340 [19:43<17:01,  2.36it/s] 

{'loss': 0.0846, 'grad_norm': 7.543399810791016, 'learning_rate': 3.72760736196319e-05, 'epoch': 4.8}



 28%|██▊       | 929/3340 [19:58<17:01,  2.36it/s] 

{'loss': 0.11, 'grad_norm': 12.618552207946777, 'learning_rate': 3.6560327198364e-05, 'epoch': 4.9}



 28%|██▊       | 929/3340 [20:15<17:01,  2.36it/s] 

{'loss': 0.0895, 'grad_norm': 5.141551971435547, 'learning_rate': 3.584458077709611e-05, 'epoch': 5.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   

 28%|██▊       | 929/3340 [20:28<17:01,  2.36it/s]
[A
[A

{'eval_loss': 1.8279938697814941, 'eval_accuracy': 0.5815384615384616, 'eval_precision': 0.6164638717947241, 'eval_recall': 0.5815384615384616, 'eval_f1': 0.5923489674666289, 'eval_runtime': 12.8969, 'eval_samples_per_second': 50.4, 'eval_steps_per_second': 1.628, 'epoch': 5.0}



 28%|██▊       | 929/3340 [20:47<17:01,  2.36it/s] 

{'loss': 0.0347, 'grad_norm': 4.228569984436035, 'learning_rate': 3.5128834355828215e-05, 'epoch': 5.1}



 28%|██▊       | 929/3340 [21:06<17:01,  2.36it/s] 

{'loss': 0.0442, 'grad_norm': 8.975517272949219, 'learning_rate': 3.4413087934560326e-05, 'epoch': 5.2}



 28%|██▊       | 929/3340 [21:24<17:01,  2.36it/s] 

{'loss': 0.0537, 'grad_norm': 4.52344274520874, 'learning_rate': 3.369734151329243e-05, 'epoch': 5.3}



 28%|██▊       | 929/3340 [21:42<17:01,  2.36it/s] 

{'loss': 0.0363, 'grad_norm': 6.655033588409424, 'learning_rate': 3.2981595092024535e-05, 'epoch': 5.4}



 28%|██▊       | 929/3340 [21:59<17:01,  2.36it/s] 

{'loss': 0.0595, 'grad_norm': 1.148385763168335, 'learning_rate': 3.2265848670756646e-05, 'epoch': 5.5}



 28%|██▊       | 929/3340 [22:15<17:01,  2.36it/s] 

{'loss': 0.0535, 'grad_norm': 8.879231452941895, 'learning_rate': 3.155010224948875e-05, 'epoch': 5.6}



 28%|██▊       | 929/3340 [22:31<17:01,  2.36it/s] 

{'loss': 0.0576, 'grad_norm': 0.8673526644706726, 'learning_rate': 3.0834355828220855e-05, 'epoch': 5.7}



 28%|██▊       | 929/3340 [22:49<17:01,  2.36it/s] 

{'loss': 0.0519, 'grad_norm': 9.794440269470215, 'learning_rate': 3.0118609406952962e-05, 'epoch': 5.8}



 28%|██▊       | 929/3340 [23:06<17:01,  2.36it/s] 

{'loss': 0.0626, 'grad_norm': 5.573687553405762, 'learning_rate': 2.940286298568507e-05, 'epoch': 5.9}



 28%|██▊       | 929/3340 [23:25<17:01,  2.36it/s] 

{'loss': 0.061, 'grad_norm': 1.7495101690292358, 'learning_rate': 2.8687116564417178e-05, 'epoch': 6.0}



[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
[A
                                                   

 28%|██▊       | 929/3340 [23:39<17:01,  2.36it/s]
[A
[A

{'eval_loss': 1.8702986240386963, 'eval_accuracy': 0.6076923076923076, 'eval_precision': 0.619034545636202, 'eval_recall': 0.6076923076923076, 'eval_f1': 0.6100188655190375, 'eval_runtime': 13.9271, 'eval_samples_per_second': 46.671, 'eval_steps_per_second': 1.508, 'epoch': 6.0}



 60%|██████    | 1500/2500 [16:46<11:11,  1.49it/s]

{'train_runtime': 1006.8732, 'train_samples_per_second': 79.454, 'train_steps_per_second': 2.483, 'train_loss': 0.23589598234494527, 'epoch': 6.0}





TrainOutput(global_step=1500, training_loss=0.23589598234494527, metrics={'train_runtime': 1006.8732, 'train_samples_per_second': 79.454, 'train_steps_per_second': 2.483, 'total_flos': 1.51998969249792e+17, 'train_loss': 0.23589598234494527, 'epoch': 6.0})

In [47]:
eval_results = trainer.evaluate()
print(f"Validation Loss: {eval_results['eval_loss']:.4f}")
print(f"Validation Accuracy: {eval_results['eval_accuracy']:.4f}")
print(f"Validation Precision: {eval_results['eval_precision']:.4f}")
print(f"Validation Recall: {eval_results['eval_recall']:.4f}")
print(f"Validation F1 Score: {eval_results['eval_f1']:.4f}")

100%|██████████| 21/21 [00:08<00:00,  2.59it/s]

Validation Loss: 1.3254
Validation Accuracy: 0.5754
Validation Precision: 0.6255
Validation Recall: 0.5754
Validation F1 Score: 0.5876





In [55]:
model.save_pretrained('./mobilenet_v2_fer2013_model')
feature_extractor.save_pretrained('./mobilenet_v2_fer2013_model')

['./mobilenet_v2_fer2013_model\\preprocessor_config.json']

In [None]:
model.save_pretrained('./resnet26_fer2013_model')
feature_extractor.save_pretrained('./resnet26_fer2013_model')
