# Experiments: Training EfficientNet-B0 with Attention

In [1]:
import torch
from data_utils import DFirt
from classifiers.efficient_net_b0_att import EfficientNetB0Att

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

%load_ext autoreload
%autoreload 2

###### Data Preprocessing

In [2]:
# Training and validation splits of the DFirt dataset.
# Expression-swap and identity-swap samples are augmented for training
# (presumably handled inside DFirt — confirm in data_utils).
DATA_ROOT = '../DFirt/'
train = DFirt(DATA_ROOT + 'train/')
val = DFirt(DATA_ROOT + 'val/')

In [3]:
# Per-class sample counts for the training split.
train.class_distribution()

attribute_manipulation :  160000
expression_swap :  160000
face_synthesis :  160000
identity_swap :  160000
real :  160000


In [4]:
# Per-class sample counts for the validation split (unlike train, these are not equal across classes).
val.class_distribution()

attribute_manipulation :  1999
expression_swap :  4057
face_synthesis :  3998
identity_swap :  7789
real :  3213


In [5]:
# A single mini-batch size shared by both splits.
bs = 64
loader_kwargs = {'batch_size': bs}
train_loader = train.data_loader(**loader_kwargs)
val_loader = val.data_loader(**loader_kwargs)

###### Define Model
The model is a plain EfficientNet-B0 combined with an attention module; both are implemented in the model's `forward` function.

In [6]:
model = EfficientNetB0Att(version='b0', reg=False)
model.to(device)

# Only parameters that will receive gradients are counted.
n_trainable_params = sum(
    param.numel() for param in model.parameters() if param.requires_grad
)
print('Number of parameters: ', n_trainable_params)

Loaded pretrained weights for efficientnet-b0
Number of parameters:  5653633


###### Training

In [7]:
from solver import Solver

# Hyperparameters for this run.
LEARNING_RATE = 1e-4
NUM_EPOCHS = 12

# NOTE(review): reg=False mirrors the model's reg flag — presumably toggles a
# regularisation term; confirm its meaning in solver.py.
solver = Solver(optim_args={'lr': LEARNING_RATE}, reg=False)
train_history = solver.train(
    model,
    train_loader,
    val_loader,
    log_nth=1,
    num_epochs=NUM_EPOCHS,
)

START TRAIN.
[Epoch 1/12] TRAIN acc/loss: 0.860061/0.345815
[Epoch 1/12] VAL acc/loss: 0.888345/0.285121
Saving model... models/EfficientNetb0Att.model
[Epoch 2/12] TRAIN acc/loss: 0.913312/0.215248
[Epoch 2/12] VAL acc/loss: 0.892240/0.289401
Saving model... models/EfficientNetb0Att.model
[Epoch 3/12] TRAIN acc/loss: 0.934215/0.164062
[Epoch 3/12] VAL acc/loss: 0.898699/0.328680
Saving model... models/EfficientNetb0Att.model
[Epoch 4/12] TRAIN acc/loss: 0.949013/0.130210
[Epoch 4/12] VAL acc/loss: 0.904873/0.341705
Saving model... models/EfficientNetb0Att.model
[Epoch 5/12] TRAIN acc/loss: 0.958580/0.104649
[Epoch 5/12] VAL acc/loss: 0.906297/0.346621
Saving model... models/EfficientNetb0Att.model
[Epoch 6/12] TRAIN acc/loss: 0.966741/0.085711
[Epoch 6/12] VAL acc/loss: 0.906155/0.375337
[Epoch 7/12] TRAIN acc/loss: 0.972274/0.072146
[Epoch 7/12] VAL acc/loss: 0.906487/0.424374
Saving model... models/EfficientNetb0Att.model
[Epoch 8/12] TRAIN acc/loss: 0.976221/0.062448
[Epoch 8/12] V