In [None]:
cascades = [
    'haarcascade_frontalface_default.xml',
    'haarcascade_frontalface_alt.xml',
    'haarcascade_frontalface_alt_2.xml',
    'haarcascade_frontalface_alt_tree.xml',
    'haarcascade_profileface.xml',
]

In [109]:
import torch
import fastai.vision.all as vision

files_path = '../data/03/fer_2013'

In [110]:
def compute_results(learner):
    learner.show_results()
    learner.lr_find()
    interp = vision.Interpretation.from_learner(learner)
    interp.plot_top_losses(9,largest=True, figsize=(15,10))
    classif_interp = vision.ClassificationInterpretation.from_learner(learner)
    classif_interp.print_classification_report()
    classif_interp.plot_confusion_matrix(title=f"Confusion maxtrix of {learner.arch.__name__} model")
    classif_interp.most_confused()

In [121]:
tfms = vision.aug_transforms (
    mult=3.0,
    flip_vert=True,
    max_rotate=15.0,
    size=48
)

batch_sizes = [2**n for n in range(4, 8)]

metrics=[
    vision.F1Score(average='macro')
]

archs = [
    # vision.models.alexnet,
    vision.models.resnet18,
    vision.models.densenet121,
    vision.models.vgg11_bn,
    vision.models.vgg19_bn,
]


In [118]:
import warnings

warnings.filterwarnings('ignore')

In [123]:
import tensorboard
from fastai.callback.tensorboard import TensorBoardCallback
import os

for batch_size in batch_sizes:
    dls = vision.ImageDataLoaders.from_folder(
        path=files_path, 
        train='train',
        valid='val',
        bs=batch_size, 
        shuffle=True, 
        device=torch.device('cuda'),
        batch_tfms=tfms
    )
    
    for arch in archs:
        name = arch.__name__
        callbacks = [
            vision.EarlyStoppingCallback(patience=5),
            vision.ReduceLROnPlateau(patience=3, factor=10),
            TensorBoardCallback(
                os.path.join('../models/logs/fastai', str(batch_size), name)
            )
        ]
        
        learn = vision.vision_learner(
            dls=dls, 
            arch=arch, 
            metrics=metrics,
            cbs=callbacks
        )

        learn.fine_tune(100)
        compute_results(learn)
        

epoch,train_loss,valid_loss,f1_score,time
0,1.956758,1.860542,0.200743,01:58


epoch,train_loss,valid_loss,f1_score,time
0,1.76627,1.595381,0.277877,02:58
1,1.716,1.550157,0.290457,02:57
2,1.70322,1.465479,0.342384,02:58
3,1.67121,1.419104,0.374371,03:00
4,1.625391,1.399836,0.367479,02:59
5,1.638608,1.354146,0.377823,02:58
6,1.59269,1.336679,0.393967,02:58
7,1.60065,1.313328,0.417209,02:57
8,1.598849,1.326231,0.401785,02:58
9,1.574238,1.319836,0.413658,02:58


Epoch 25: reducing lr to 9.654086509011709e-05
Epoch 44: reducing lr to 8.909246520845526e-05
No improvement since epoch 41: early stopping


              precision    recall  f1-score   support

       angry       0.51      0.51      0.51       495
     disgust       0.44      0.50      0.47        54
        fear       0.37      0.49      0.42       512
       happy       0.86      0.80      0.83       898
     neutral       0.50      0.65      0.57       619
         sad       0.56      0.37      0.44       607
    surprise       0.72      0.59      0.65       400

    accuracy                           0.59      3585
   macro avg       0.57      0.56      0.55      3585
weighted avg       0.61      0.59      0.59      3585



Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /home/gio/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth


  0%|          | 0.00/30.8M [00:00<?, ?B/s]

epoch,train_loss,valid_loss,f1_score,time
0,1.902436,1.682388,0.219917,02:57


epoch,train_loss,valid_loss,f1_score,time
0,1.738731,1.527391,0.296464,04:01
1,1.706684,1.436099,0.319817,04:01
2,1.653369,1.386673,0.365042,04:02
3,1.664597,1.349671,0.368436,04:02
4,1.613784,1.316862,0.415474,04:02
5,1.635805,1.29625,0.402302,04:02
6,1.571789,1.285732,0.415359,04:03
7,1.60088,1.262906,0.430121,04:03
8,1.5378,1.240769,0.451752,04:03
9,1.567674,1.244863,0.425625,04:13


Epoch 31: reducing lr to 9.979882918113223e-05
No improvement since epoch 28: early stopping


              precision    recall  f1-score   support

       angry       0.57      0.49      0.53       495
     disgust       0.61      0.35      0.45        54
        fear       0.44      0.48      0.46       512
       happy       0.86      0.82      0.84       898
     neutral       0.54      0.69      0.60       619
         sad       0.55      0.40      0.46       607
    surprise       0.64      0.77      0.70       400

    accuracy                           0.62      3585
   macro avg       0.60      0.57      0.58      3585
weighted avg       0.62      0.62      0.62      3585



Downloading: "https://download.pytorch.org/models/vgg11_bn-6002323d.pth" to /home/gio/.cache/torch/hub/checkpoints/vgg11_bn-6002323d.pth


  0%|          | 0.00/507M [00:00<?, ?B/s]

epoch,train_loss,valid_loss,f1_score,time
0,1.87136,1.884452,0.203212,02:14


epoch,train_loss,valid_loss,f1_score,time
0,1.734226,1.576118,0.269511,03:10
1,1.721468,1.4763,0.330455,03:10
2,1.632073,1.406846,0.358662,03:12
3,1.629164,1.357571,0.388117,03:12
4,1.623753,1.347763,0.393772,03:10
5,1.572054,1.320599,0.392817,03:09
6,1.583494,1.264513,0.420101,03:12
7,1.577013,1.262307,0.431441,03:07
8,1.567351,1.255428,0.443758,03:07
9,1.541765,1.238322,0.441206,03:06


Epoch 15: reducing lr to 6.417881042874624e-05
Epoch 24: reducing lr to 9.463984638184743e-05
Epoch 49: reducing lr to 8.117565764452141e-05
Epoch 56: reducing lr to 6.75702347137672e-05
No improvement since epoch 53: early stopping


              precision    recall  f1-score   support

       angry       0.52      0.62      0.57       495
     disgust       0.81      0.56      0.66        54
        fear       0.57      0.42      0.48       512
       happy       0.89      0.84      0.86       898
     neutral       0.56      0.70      0.62       619
         sad       0.56      0.48      0.52       607
    surprise       0.73      0.80      0.76       400

    accuracy                           0.66      3585
   macro avg       0.66      0.63      0.64      3585
weighted avg       0.66      0.66      0.65      3585



epoch,train_loss,valid_loss,f1_score,time
0,1.911581,1.957153,0.174738,03:34


epoch,train_loss,valid_loss,f1_score,time
0,1.685115,1.444768,0.348379,05:23
1,1.624899,1.323371,0.390207,05:23
2,1.533502,1.250582,0.423206,05:23
3,1.487569,1.209563,0.463919,05:23
4,1.467098,1.174321,0.455726,05:23
5,1.510434,1.140699,0.505904,05:23
6,1.429639,1.141588,0.481779,05:23
7,1.431215,1.296907,0.45507,05:23
8,1.426191,1.146351,0.48571,05:23
9,1.395074,1.142406,0.543662,05:23


Epoch 8: reducing lr to 3.64866975194948e-05
Epoch 18: reducing lr to 7.626732989672503e-05
Epoch 22: reducing lr to 8.972422792449991e-05
Epoch 26: reducing lr to 9.804153982669419e-05
No improvement since epoch 23: early stopping


              precision    recall  f1-score   support

       angry       0.71      0.34      0.46       495
     disgust       0.52      0.56      0.54        54
        fear       0.44      0.47      0.45       512
       happy       0.84      0.85      0.85       898
     neutral       0.52      0.76      0.62       619
         sad       0.52      0.47      0.50       607
    surprise       0.76      0.69      0.72       400

    accuracy                           0.62      3585
   macro avg       0.61      0.59      0.59      3585
weighted avg       0.64      0.62      0.62      3585



epoch,train_loss,valid_loss,f1_score,time
0,1.924232,1.894932,0.187245,01:14


epoch,train_loss,valid_loss,f1_score,time
0,1.780242,1.626969,0.260056,01:45
1,1.732693,1.550464,0.294589,01:45
2,1.68542,1.497894,0.325797,01:45
3,1.668207,1.469334,0.323146,01:45
4,1.6496,1.437677,0.344543,01:45
5,1.62663,1.375756,0.368612,01:45
6,1.611394,1.345007,0.39901,01:45
7,1.615154,1.341002,0.396916,01:45
8,1.596358,1.326316,0.382125,01:45
9,1.5784,1.315925,0.405099,01:45


Epoch 44: reducing lr to 8.90932459338421e-05
No improvement since epoch 41: early stopping


              precision    recall  f1-score   support

       angry       0.51      0.50      0.51       495
     disgust       0.54      0.52      0.53        54
        fear       0.43      0.40      0.42       512
       happy       0.83      0.81      0.82       898
     neutral       0.50      0.69      0.58       619
         sad       0.50      0.35      0.41       607
    surprise       0.67      0.73      0.70       400

    accuracy                           0.60      3585
   macro avg       0.57      0.57      0.57      3585
weighted avg       0.60      0.60      0.59      3585



epoch,train_loss,valid_loss,f1_score,time
0,1.904042,1.663809,0.227082,02:07


epoch,train_loss,valid_loss,f1_score,time
0,1.74938,1.51967,0.287028,02:43
1,1.702221,1.453331,0.329698,02:43
2,1.683025,1.400605,0.342218,02:43
3,1.669888,1.355028,0.37394,02:43
4,1.627991,1.350265,0.370869,02:43
5,1.588914,1.321826,0.388774,02:43
6,1.581489,1.265912,0.426979,02:43
7,1.576908,1.274066,0.414908,02:43
8,1.545372,1.263009,0.429667,02:43
9,1.517117,1.214526,0.477662,02:43


Epoch 22: reducing lr to 8.9722669077482e-05
Epoch 39: reducing lr to 9.504957806130462e-05
Epoch 48: reducing lr to 8.289899173532382e-05
Epoch 57: reducing lr to 6.54535806647523e-05
No improvement since epoch 54: early stopping


              precision    recall  f1-score   support

       angry       0.52      0.57      0.55       495
     disgust       0.74      0.54      0.62        54
        fear       0.51      0.46      0.48       512
       happy       0.88      0.80      0.84       898
     neutral       0.52      0.75      0.62       619
         sad       0.58      0.46      0.51       607
    surprise       0.79      0.69      0.74       400

    accuracy                           0.64      3585
   macro avg       0.65      0.61      0.62      3585
weighted avg       0.65      0.64      0.64      3585



epoch,train_loss,valid_loss,f1_score,time
0,1.931851,1.821692,0.228462,01:14


epoch,train_loss,valid_loss,f1_score,time
0,1.753491,1.554158,0.289371,01:47
1,1.716914,1.498202,0.314558,01:47
2,1.681077,1.431915,0.347443,01:47
3,1.63195,1.382521,0.367193,01:47
4,1.631184,1.358621,0.367974,01:47
5,1.572748,1.305392,0.38933,01:47
6,1.57224,1.311403,0.392935,01:47
7,1.557988,1.283972,0.412971,01:47
8,1.561852,1.233973,0.447757,01:47
9,1.546387,1.238843,0.422228,01:47


Epoch 21: reducing lr to 8.676175221039784e-05
Epoch 34: reducing lr to 9.874696553437278e-05
Epoch 45: reducing lr to 8.765534284672913e-05
No improvement since epoch 42: early stopping


              precision    recall  f1-score   support

       angry       0.63      0.47      0.54       495
     disgust       0.70      0.56      0.62        54
        fear       0.46      0.53      0.49       512
       happy       0.86      0.85      0.85       898
     neutral       0.60      0.63      0.61       619
         sad       0.51      0.58      0.54       607
    surprise       0.74      0.64      0.69       400

    accuracy                           0.64      3585
   macro avg       0.64      0.61      0.62      3585
weighted avg       0.65      0.64      0.64      3585



epoch,train_loss,valid_loss,f1_score,time
0,1.897004,1.769396,0.221211,02:33


epoch,train_loss,valid_loss,f1_score,time
0,1.696486,1.527222,0.301657,03:54
1,1.620386,1.372917,0.365655,03:51
2,1.54985,1.274825,0.403068,03:55
