In [1]:
import os
import warnings
from tqdm.notebook import tqdm
import copy
from PIL import Image
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.utils.data as data
from torchvision import transforms

from networks.dan import DAN

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  warn(f"Failed to load image Python extension: {e}")


In [2]:
def warn(*args, **kwargs):
    pass
warnings.warn = warn

def load_data(image_path, emotion_path, subset):
    images = np.load(image_path)        # shape = (35393, 48, 48, 1)
    # images = images/255.0
    images = np.float32(images)
    
    if subset == 'train':
        training_emotions = np.load(training_emotion_path)    # shape = (35393, 8)
        training_emotions = np.float32(training_emotions)
        return images[:training_size], training_emotions[:training_size]
    if subset == 'test':
        test_emotions = np.load(test_emotion_path)
        test_emotions = np.float32(test_emotions)
        return images[training_size:], test_emotions[training_size:]

In [3]:
class FERPlusDataset(data.Dataset):
    def __init__(self, image_path, emotion_path, subset, transform = None):
        self.transform = transform
        assert(subset=='train' or subset=='test')
        self.images, self.emotions = load_data(image_path, emotion_path, subset)

    def __getitem__(self, index):
        image = self.images[index]
        emotion = self.emotions[index]
        if self.transform is not None:
            image = self.transform(image)
        return image, emotion
    
    def __len__(self):
        return len(self.images)

In [4]:
class AffinityLoss(nn.Module):
    def __init__(self, device, num_class=8, feat_dim=512):
        super(AffinityLoss, self).__init__()
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.device = device

        self.centers = nn.Parameter(torch.randn(self.num_class, self.feat_dim).to(device))

    def forward(self, x, labels):
        x = self.gap(x).view(x.size(0), -1)

        batch_size = x.size(0)
        distmat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_class) + \
                  torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_class, batch_size).t()
        distmat.addmm_(x, self.centers.t(), beta=1, alpha=-2)

        classes = torch.arange(self.num_class).long().to(self.device)
        labels = labels.expand(batch_size, self.num_class)
        mask = labels.eq(classes.expand(batch_size, self.num_class))

        dist = distmat * mask.float()
        dist = dist / self.centers.var(dim=0).sum()

        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size

        return loss

In [5]:
class PartitionLoss(nn.Module):
    def __init__(self, ):
        super(PartitionLoss, self).__init__()
    
    def forward(self, x):
        num_head = x.size(1)

        if num_head > 1:
            var = x.var(dim=1).mean()
            loss = torch.log(1+num_head/var)
        else:
            loss = 0
            
        return loss

In [6]:
def accuracy(output, target):
    batch_size = target.size(0)
    acc = 0
    for i in range(batch_size):
        true = target[i]
        pred = output[i]
        index_max = torch.argmax(pred)
        if true[index_max] == torch.max(true):
            acc += 1
    acc = float(acc)/batch_size
    return acc

In [7]:
batch_size=32
num_head=4  # Number of attention head.
training_size = 28317 + 3541

image_path = '../dataset/aligned_images.npy'
training_emotion_path = "../dataset/emotions_multi.npy"
test_emotion_path = '../dataset/emotions_multi.npy'

In [8]:
if torch.cuda.is_available():
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = True
print("GPU",torch.cuda.is_available())

model = DAN(num_head=num_head)
model = model.to(device)

data_transforms = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomApply([
            transforms.RandomRotation(20),
            transforms.RandomCrop(224, padding=32)
        ], p=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225]),
                                
    transforms.RandomErasing(scale=(0.02,0.25)),
    ])    
train_dataset = FERPlusDataset(image_path=image_path, 
                               emotion_path=training_emotion_path, 
                               subset="train", 
                               transform=data_transforms)    
print('Whole train set size:', train_dataset.__len__())
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    #num_workers=num_workers,
    #pin_memory = True
)

data_transforms_val = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])
])
                                   
val_dataset = FERPlusDataset(image_path=image_path, 
                             emotion_path=test_emotion_path, 
                             subset="test", 
                             transform=data_transforms_val) 
print('Validation set size:', val_dataset.__len__())
val_loader = torch.utils.data.DataLoader(
    val_dataset,
    batch_size=batch_size,
    shuffle=False,
    #num_workers=num_workers,
    #pin_memory = True
)

loss1 = AffinityLoss(device)
loss2 = PartitionLoss()
params = list(model.parameters()) + list(loss1.parameters())

GPU True
Whole train set size: 31858
Validation set size: 3535


In [9]:
def fit(model, epochs, lr, factor1=1, factor2=1, adam=False, mse=False):
    
    best_model = None
    best_optim = None
    best_acc = 0 
    
    with torch.no_grad():
        
        iter_cnt = 0
        acc = 0
        model.eval()
        for (imgs, targets) in val_loader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            
            out,feat,heads = model(imgs)
            iter_cnt+=1
            
            correctness = accuracy(out, targets)
            acc += correctness / len(val_loader)
        best_acc = acc
        best_model=copy.deepcopy(model.state_dict())
        print("copy best successfully!") 
        tqdm.write("Current best accuracy:%.4f." % (acc))
        
    if mse == True:
        loss0 = torch.nn.MSELoss().to(device) 
    else:
        loss0 = torch.nn.CrossEntropyLoss().to(device)

    if adam == True:
        optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    else:
        optimizer = torch.optim.SGD(params,lr=lr, weight_decay = 1e-4, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.5)
    
    
    for epoch in range(1, epochs + 1):
        running_loss = 0.0
        acc = 0
        iter_cnt = 0
        model.train()

        for (imgs, targets) in tqdm(train_loader):
            iter_cnt += 1
            optimizer.zero_grad()
            
            imgs = imgs.to(device)
            targets = targets.to(device)
            out,feat,heads = model(imgs)
            
            loss = loss0(out,targets) + factor1*loss1(feat,targets) + factor2*loss2(heads)  #89.3 89.4

            loss.backward()
            optimizer.step()

            running_loss += loss

            correctness = (out.argmax(dim=1) == targets.argmax(dim=1)).float().mean()
            acc += correctness / len(train_loader)

        #acc = correct_sum.float() / float(train_dataset.__len__())
        running_loss = running_loss/iter_cnt
        tqdm.write('[Epoch %d] Training accuracy: %.4f. Loss: %.4f. LR %.6f' % (epoch, acc, running_loss, optimizer.param_groups[0]['lr']))

        with torch.no_grad():
            running_loss = 0.0
            iter_cnt = 0
            acc = 0
            #bingo_cnt = 0
            #sample_cnt = 0
            #baccs = []

            model.eval()
            for (imgs, targets) in val_loader:
                imgs = imgs.to(device)
                targets = targets.to(device)

                out,feat,heads = model(imgs)
                loss = loss0(out,targets) + loss1(feat,targets) + loss2(heads)

                running_loss += loss
                iter_cnt+=1
                #_, predicts = torch.max(out, 1)
                #correct_num  = torch.eq(predicts,targets)
                #bingo_cnt += correct_num.sum().cpu()
                #sample_cnt += out.size(0)

                correctness = accuracy(out, targets)
                acc += correctness / len(val_loader)

                #baccs.append(balanced_accuracy_score(targets.cpu().numpy(),predicts.cpu().numpy()))

            running_loss = running_loss/iter_cnt   
            #scheduler.step()

            #acc = bingo_cnt.float()/float(sample_cnt)
            #acc = np.around(acc.numpy(),4)
            #bacc = np.around(np.mean(baccs),4)
            tqdm.write("[Epoch %d] Validation accuracy:%.4f. Loss:%.4f" % (epoch, acc,  running_loss))
                            
        if best_acc < acc:
            best_acc = acc
            best_model=copy.deepcopy(model.state_dict())
            #best_optim=copy.deepcopy(optimizer.state_dict())
            print("copy best successfully!")    
        tqdm.write("best_acc:" + str(best_acc))
            
    if best_model is not None:
        #optimizer.load_state_dict(best_optim)
        model.load_state_dict(best_model)
        print("load best successfully!")    

In [None]:
fit(model, epochs=10, lr=1e-1, factor1=1, factor2=1, adam=0, mse=0)
fit(model, epochs=10, lr=2e-2, factor1=1, factor2=1, adam=0, mse=0)
fit(model, epochs=10, lr=5e-3, factor1=1, factor2=1, adam=0, mse=0)
fit(model, epochs=10, lr=1e-3, factor1=1, factor2=1, adam=0, mse=0)
fit(model, epochs=10, lr=2e-4, factor1=1, factor2=1, adam=0, mse=0)

copy best successfully!
Current best accuracy:0.0121.


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 1] Training accuracy: 0.7228. Loss: 1.2406. LR 0.100000
[Epoch 1] Validation accuracy:0.7914. Loss:0.9896
copy best successfully!
best_acc:0.7914414414414411


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 2] Training accuracy: 0.7754. Loss: 0.9787. LR 0.100000
[Epoch 2] Validation accuracy:0.8083. Loss:0.8911
copy best successfully!
best_acc:0.8083333333333331


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 3] Training accuracy: 0.7889. Loss: 0.9283. LR 0.100000
[Epoch 3] Validation accuracy:0.8121. Loss:0.8773
copy best successfully!
best_acc:0.8120683183183177


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 4] Training accuracy: 0.7987. Loss: 0.9101. LR 0.100000
[Epoch 4] Validation accuracy:0.8013. Loss:0.8832
best_acc:0.8120683183183177


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 5] Training accuracy: 0.8019. Loss: 0.9058. LR 0.100000
[Epoch 5] Validation accuracy:0.8095. Loss:0.8811
best_acc:0.8120683183183177


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 6] Training accuracy: 0.8029. Loss: 0.8962. LR 0.100000
[Epoch 6] Validation accuracy:0.8203. Loss:0.8602
copy best successfully!
best_acc:0.8202702702702699


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 7] Training accuracy: 0.8056. Loss: 0.8910. LR 0.100000
[Epoch 7] Validation accuracy:0.7929. Loss:0.8875
best_acc:0.8202702702702699


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 8] Training accuracy: 0.8087. Loss: 0.8927. LR 0.100000
[Epoch 8] Validation accuracy:0.8286. Loss:0.8514
copy best successfully!
best_acc:0.8286411411411407


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 9] Training accuracy: 0.8120. Loss: 0.8858. LR 0.100000
[Epoch 9] Validation accuracy:0.8304. Loss:0.8600
copy best successfully!
best_acc:0.8304054054054049


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 10] Training accuracy: 0.8105. Loss: 0.8854. LR 0.100000
[Epoch 10] Validation accuracy:0.8213. Loss:0.8470
best_acc:0.8304054054054049
load best successfully!
copy best successfully!
Current best accuracy:0.8304.


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 1] Training accuracy: 0.8464. Loss: 0.8201. LR 0.020000
[Epoch 1] Validation accuracy:0.8585. Loss:0.7908
copy best successfully!
best_acc:0.8585210210210202


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 2] Training accuracy: 0.8580. Loss: 0.8014. LR 0.020000
[Epoch 2] Validation accuracy:0.8658. Loss:0.7835
copy best successfully!
best_acc:0.86584084084084


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 3] Training accuracy: 0.8638. Loss: 0.7899. LR 0.020000
[Epoch 3] Validation accuracy:0.8684. Loss:0.7820
copy best successfully!
best_acc:0.8684496996996988


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 4] Training accuracy: 0.8662. Loss: 0.7827. LR 0.020000
[Epoch 4] Validation accuracy:0.8673. Loss:0.7806
best_acc:0.8684496996996988


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 5] Training accuracy: 0.8690. Loss: 0.7769. LR 0.020000
[Epoch 5] Validation accuracy:0.8679. Loss:0.7816
best_acc:0.8684496996996988


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 6] Training accuracy: 0.8710. Loss: 0.7735. LR 0.020000
[Epoch 6] Validation accuracy:0.8628. Loss:0.7870
best_acc:0.8684496996996988


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 7] Training accuracy: 0.8746. Loss: 0.7674. LR 0.020000
[Epoch 7] Validation accuracy:0.8642. Loss:0.7814
best_acc:0.8684496996996988


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 8] Training accuracy: 0.8757. Loss: 0.7623. LR 0.020000
[Epoch 8] Validation accuracy:0.8639. Loss:0.7798
best_acc:0.8684496996996988


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 9] Training accuracy: 0.8758. Loss: 0.7605. LR 0.020000
[Epoch 9] Validation accuracy:0.8611. Loss:0.7867
best_acc:0.8684496996996988


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 10] Training accuracy: 0.8732. Loss: 0.7578. LR 0.020000
[Epoch 10] Validation accuracy:0.8660. Loss:0.7836
best_acc:0.8684496996996988
load best successfully!
copy best successfully!
Current best accuracy:0.8684.


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 1] Training accuracy: 0.8726. Loss: 0.7747. LR 0.005000
[Epoch 1] Validation accuracy:0.8713. Loss:0.7735
copy best successfully!
best_acc:0.871265015015014


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 2] Training accuracy: 0.8743. Loss: 0.7684. LR 0.005000
[Epoch 2] Validation accuracy:0.8704. Loss:0.7725
best_acc:0.871265015015014


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 3] Training accuracy: 0.8780. Loss: 0.7648. LR 0.005000
[Epoch 3] Validation accuracy:0.8724. Loss:0.7729
copy best successfully!
best_acc:0.872353603603603


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 4] Training accuracy: 0.8793. Loss: 0.7612. LR 0.005000
[Epoch 4] Validation accuracy:0.8746. Loss:0.7747
copy best successfully!
best_acc:0.8746058558558549


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 5] Training accuracy: 0.8822. Loss: 0.7580. LR 0.005000
[Epoch 5] Validation accuracy:0.8755. Loss:0.7708
copy best successfully!
best_acc:0.8754879879879869


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 6] Training accuracy: 0.8799. Loss: 0.7586. LR 0.005000
[Epoch 6] Validation accuracy:0.8718. Loss:0.7719
best_acc:0.8754879879879869


  0%|          | 0/996 [00:00<?, ?it/s]

In [None]:
fit(model, epochs=5, lr=1e-4, factor1=1, factor2=1, adam=0, mse=0)

In [None]:
fit(model, epochs=15, lr=1e-3, factor1=1, factor2=1, adam=0, mse=0)

In [None]:
torch.cuda.empty_cache()
model = DAN(num_head=num_head)
model = model.to(device)

fit(model, epochs=10, lr=1e-1, factor1=1, factor2=1, adam=1, mse=1)
fit(model, epochs=10, lr=2e-2, factor1=1, factor2=1, adam=1, mse=1)
fit(model, epochs=10, lr=5e-3, factor1=1, factor2=1, adam=1, mse=1)
fit(model, epochs=10, lr=1e-3, factor1=1, factor2=1, adam=1, mse=1)
fit(model, epochs=10, lr=2e-4, factor1=1, factor2=1, adam=1, mse=1)
torch.save({
            'model_state_dict': model.state_dict(),},
           os.path.join('checkpoints', "fer11.pth"))
tqdm.write('Model saved.')

copy best successfully!
Current best accuracy:0.0417.


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 1] Training accuracy: 0.3772. Loss: 0.0840. LR 0.100000
[Epoch 1] Validation accuracy:0.2676. Loss:0.2061
copy best successfully!
best_acc:0.26756756756756755


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 2] Training accuracy: 0.5225. Loss: 0.0595. LR 0.100000
[Epoch 2] Validation accuracy:0.5293. Loss:0.0574
copy best successfully!
best_acc:0.5292980480480479


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 3] Training accuracy: 0.5986. Loss: 0.0496. LR 0.100000
[Epoch 3] Validation accuracy:0.5580. Loss:0.0599
copy best successfully!
best_acc:0.5580142642642639


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 4] Training accuracy: 0.6365. Loss: 0.0441. LR 0.100000
[Epoch 4] Validation accuracy:0.6954. Loss:0.0411
copy best successfully!
best_acc:0.6954016516516518


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 5] Training accuracy: 0.6599. Loss: 0.0410. LR 0.100000
[Epoch 5] Validation accuracy:0.5776. Loss:0.0647
best_acc:0.6954016516516518


  0%|          | 0/996 [00:00<?, ?it/s]

[Epoch 6] Training accuracy: 0.6771. Loss: 0.0395. LR 0.100000
[Epoch 6] Validation accuracy:0.6676. Loss:0.0639
best_acc:0.6954016516516518


  0%|          | 0/996 [00:00<?, ?it/s]

In [None]:
torch.cuda.empty_cache()
model = DAN(num_head=num_head)
model = model.to(device)

fit(model, epochs=10, lr=1e-1, factor1=1, factor2=1, adam=0, mse=0)
fit(model, epochs=10, lr=2e-2, factor1=1, factor2=1, adam=0, mse=0)
fit(model, epochs=10, lr=5e-3, factor1=1, factor2=1, adam=0, mse=0)
fit(model, epochs=10, lr=1e-3, factor1=1, factor2=1, adam=0, mse=0)
fit(model, epochs=10, lr=2e-4, factor1=1, factor2=1, adam=0, mse=0)
torch.save({
            'model_state_dict': model.state_dict(),},
           os.path.join('checkpoints', "fer00.pth"))
tqdm.write('Model saved.')

In [None]:
torch.cuda.empty_cache()
model = DAN(num_head=num_head)
model = model.to(device)

fit(model, epochs=10, lr=1e-1, factor1=1, factor2=1, adam=1, mse=0)
fit(model, epochs=10, lr=2e-2, factor1=1, factor2=1, adam=1, mse=0)
fit(model, epochs=10, lr=5e-3, factor1=1, factor2=1, adam=1, mse=0)
fit(model, epochs=10, lr=1e-3, factor1=1, factor2=1, adam=1, mse=0)
fit(model, epochs=10, lr=2e-4, factor1=1, factor2=1, adam=1, mse=0)
torch.save({
            'model_state_dict': model.state_dict(),},
           os.path.join('checkpoints', "fer10.pth"))
tqdm.write('Model saved.')