[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/gist/MangaBoba/0b65f8e48d9ba3acb44b2573f8de2d6b/backbone_experements.ipynb)

In [None]:
wandb_api_key = '' # Weights & Biases API key; only referenced by the commented-out wandb.login(...) in the training cell.

In [None]:
# Mount Google Drive: the dataset / model archives used below are read from
# /content/drive/MyDrive.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
#@title Loading git repos, libs, datasets and stuff
# ResNet-18 weights pretrained on MS-Celeb (presumably saved as
# /content/resnet18_msceleb.pth, the path used later — confirm)
!gdown --id 1H421M8mosIVt8KsEWQ1UuYMkQS8X1prf

# RAF-DB (aligned faces): label file and images. RafDataset later reads
# /content/EmoLabel/... and Image/aligned/...
!unzip /content/drive/MyDrive/rafdb/labels.zip -d /content
!unzip /content/drive/MyDrive/rafdb/images.zip -d /content
!unzip /content/Image/aligned.zip -d /content/Image

# fer2013 (currently disabled)
# !unzip /content/drive/MyDrive/fer2013.zip 

# Original paper implementations and some libs for Colab
!unzip /content/drive/MyDrive/RUL.zip
!git clone https://github.com/JDAI-CV/FaceX-Zoo
!git clone https://github.com/amirhfarzaneh/dacl
!git clone https://github.com/zyh-uaiaaaa/Relative-Uncertainty-Learning
!pip install timm
!pip install wandb

# Pre-exported (JIT) models
!unzip /content/drive/MyDrive/jit_models.zip

In [None]:
#@title Global imports
import sys
import os
import wandb
import time
import argparse
import pprint
from tqdm import tqdm
from datetime import timedelta
import datetime
import cv2
import numpy as np
import random

import torch
import torch.nn as nn
from torchvision import transforms
import torchvision.models as models
import torch.utils.data as data
import torch.nn.functional as F
import torch.nn.init as init
from torchvision import models
from torchvision import models as torch_models
# from torchsampler import ImbalancedDatasetSampler
from torch.utils.data import WeightedRandomSampler
from sklearn.metrics import balanced_accuracy_score

from datetime import timedelta
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn
import socket

import warnings
from torch.utils.data import sampler
from tqdm import tqdm
import argparse

from PIL import Image
import pandas as pd
import sklearn
import seaborn as sn
from pandas.core.arrays.numeric import T
from PIL import Image 
import matplotlib.pyplot as plt
import csv
import math



In [None]:
#@title Relative uncertainity learning (RUL):  Original paper implementation  

# Add Gaussian noise to an image.
def add_g(image_array, mean=0.0, var=30):
    """Return ``image_array`` corrupted by additive Gaussian noise.

    Args:
        image_array: numeric array; the result is clipped to [0, 255], so an
            8-bit image is presumably expected.
        mean: mean of the noise distribution.
        var: variance of the noise (standard deviation is ``sqrt(var)``).

    Returns:
        uint8 array with the same shape as ``image_array``.
    """
    sigma = var ** 0.5
    noise = np.random.normal(mean, sigma, image_array.shape)
    noisy = np.clip(image_array + noise, 0, 255)
    return noisy.astype(np.uint8)

# Mirror an image left-to-right.
def filp_image(image_array):
    """Return a horizontally mirrored copy of ``image_array`` (via OpenCV).

    The misspelled name is kept as-is in case other code refers to it.
    """
    horizontal_flip_code = 1  # OpenCV flipCode > 0: flip around the vertical axis
    return cv2.flip(image_array, horizontal_flip_code)

# Seed every RNG so results can be reproduced. The seed is simply fixed to 0;
# a different value might score better, but the seed should not be tuned.
def setup_seed(seed):
    """Seed Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    Also puts cuDNN into deterministic mode and disables its autotuner:
    ``cudnn.benchmark = True`` (e.g. set by the DAN training loop in this
    notebook) may pick different convolution algorithms between runs, which
    would defeat the seeding.  Previously ``benchmark`` was left untouched.

    Args:
        seed: integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Uncertainty-weighted feature mixup. Mixup partners come from a plain random
# permutation of the batch (the traditional mixup pairing), which is enough
# to get good performance.
def mixup_data(x, y, att, use_cuda=True):
    """Mix each sample with a random partner, weighted by ``att``.

    Args:
        x: features with the batch on dim 0 (the RUL models here pass (B, D)).
        y: targets with the batch on dim 0.
        att: positive per-sample weights broadcastable against ``x``
            (the RUL models pass shape (B, 1)).
        use_cuda: kept only for backward compatibility.  The permutation is
            now always placed on ``x``'s device, so CPU tensors work as well
            (before, the default ``use_cuda=True`` crashed on CPU inputs).

    Returns:
        ``(mixed_x, y_a, y_b, att1, att2)`` where
        ``mixed_x = att1 * x + att2 * x[perm]``, ``y_a = y``,
        ``y_b = y[perm]`` and ``att1 + att2 == 1``.
    """
    batch_size = x.size(0)
    # Draw the permutation from the CPU generator (as before, so seeded runs
    # still produce the same pairs) and then move it next to the data.
    index = torch.randperm(batch_size).to(x.device)
    partner_att = att[index]
    total_att = att + partner_att
    att1 = att / total_att
    att2 = partner_att / total_att
    mixed_x = att1 * x + att2 * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, att1, att2

# Mixup loss: equal-weight average of the criterion over both targets.
def mixup_criterion(y_a, y_b):
    """Return ``f(criterion, pred)`` computing the mixup loss.

    ``f`` evaluates ``criterion(pred, y_a)`` then ``criterion(pred, y_b)``
    and averages them with fixed 0.5 weights.
    """
    def _mixed_loss(criterion, pred):
        return 0.5 * criterion(pred, y_a) + 0.5 * criterion(pred, y_b)

    return _mixed_loss


class BasicBlock(nn.Module):
    """Two 3x3 conv + BN layers with a residual shortcut (ResNet-18 block).

    Args:
        in_channels: channels of the incoming tensor.
        out_channels: channels produced by both convolutions.
        stride: stride of the first convolution and of the projection.
        downsample: if truthy, the shortcut becomes a 1x1 conv + BatchNorm
            projection so it matches the main path's shape.
    """

    expansion = 1

    def __init__(self, in_channels, out_channels, stride = 1, downsample = False):
        super().__init__()

        # Main path. Attribute names are state_dict keys - keep them fixed.
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

        # Shortcut path: identity unless a projection is requested.
        shortcut = None
        if downsample:
            shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1,
                          stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            )
        self.downsample = shortcut

    def forward(self, x):
        shortcut = x if self.downsample is None else self.downsample(x)

        out = self.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return self.relu(out)
    
    
class ResNet(nn.Module):
    """Configurable ResNet (ResNet-18 with ``BasicBlock`` and ``[2, 2, 2, 2]``).

    ``forward`` returns ``(logits, pooled_features)``, where
    ``pooled_features`` is the flattened global-average-pool output.  The
    submodule attribute names (conv1, bn1, layer1..layer4, fc, ...) form the
    state_dict keys that pretrained checkpoints are loaded into, so they must
    not be renamed.

    Args:
        block: residual block class exposing an ``expansion`` attribute.
        n_blocks: four ints, number of blocks in each stage.
        channels: four ints, base width of each stage (channels[0] is also
            the stem width).
        output_dim: size of the final linear layer.
    """

    def __init__(self, block, n_blocks, channels, output_dim):
        super().__init__()

        # Running channel count; updated by get_resnet_layer as stages are built.
        self.in_channels = channels[0]

        assert len(n_blocks) == len(channels) == 4

        self.conv1 = nn.Conv2d(3, self.in_channels, kernel_size = 7, stride = 2, padding = 3, bias = False)
        self.bn1 = nn.BatchNorm2d(self.in_channels)
        self.relu = nn.ReLU(inplace = True)
        self.maxpool = nn.MaxPool2d(kernel_size = 3, stride = 2, padding = 1)

        self.layer1 = self.get_resnet_layer(block, n_blocks[0], channels[0])
        self.layer2 = self.get_resnet_layer(block, n_blocks[1], channels[1], stride = 2)
        self.layer3 = self.get_resnet_layer(block, n_blocks[2], channels[2], stride = 2)
        self.layer4 = self.get_resnet_layer(block, n_blocks[3], channels[3], stride = 2)

        self.avgpool = nn.AdaptiveAvgPool2d((1,1))
        # in_channels now holds the width of the last stage.
        self.fc = nn.Linear(self.in_channels, output_dim)

    def get_resnet_layer(self, block=BasicBlock, n_blocks=2, channels=64, stride = 1):
        """Build one stage of ``n_blocks`` blocks with base width ``channels``.

        ``n_blocks`` and ``channels`` are scalars describing a single stage.
        (The old defaults were the whole-network lists ``[2, 2, 2, 2]`` /
        ``[64, 128, 256, 512]``: the wrong type here, so any call relying on
        them failed, and they were shared mutable defaults.)

        Only the first block may change stride/width; it gets a projection
        shortcut when the incoming width differs.  Updates
        ``self.in_channels`` to the stage's output width.
        """
        out_channels = block.expansion * channels
        needs_projection = self.in_channels != out_channels

        layers = [block(self.in_channels, channels, stride, needs_projection)]
        for _ in range(1, n_blocks):
            layers.append(block(out_channels, channels))

        self.in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        h = x.view(x.shape[0], -1)
        x = self.fc(h)

        return x, h
    

"""
@author: Jun Wang 
@date: 20201019
@contact: jun21wangustc@gmail.com
"""

# based on:
# https://github.com/TreB1eN/InsightFace_Pytorch/blob/master/model.py

from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
import torch

class Flatten(Module):
    """Collapse every dimension after the batch dimension: (B, ...) -> (B, -1)."""

    def forward(self, input):
        batch = input.size(0)
        return input.view(batch, -1)

class Conv_block(Module):
    """Bias-free Conv2d -> BatchNorm2d -> per-channel PReLU."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Conv_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel,
                           groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)
        self.prelu = PReLU(out_c)

    def forward(self, x):
        return self.prelu(self.bn(self.conv(x)))

class Linear_block(Module):
    """Bias-free Conv2d followed by BatchNorm2d, with no activation."""

    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(Linear_block, self).__init__()
        self.conv = Conv2d(in_c, out_channels=out_c, kernel_size=kernel,
                           groups=groups, stride=stride, padding=padding, bias=False)
        self.bn = BatchNorm2d(out_c)

    def forward(self, x):
        return self.bn(self.conv(x))

class Depth_Wise(Module):
    """Inverted-residual unit: 1x1 expand -> depthwise conv -> 1x1 linear projection.

    ``groups`` is the width of the expanded (hidden) representation; the
    depthwise conv uses that many groups.  With ``residual=True`` the input is
    added to the projected output, so input and output shapes must agree
    (``in_c == out_c`` and a size-preserving stride).
    """

    def __init__(self, in_c, out_c, residual = False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(Depth_Wise, self).__init__()
        hidden = groups
        self.conv = Conv_block(in_c, out_c=hidden, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.conv_dw = Conv_block(hidden, hidden, groups=hidden, kernel=kernel, padding=padding, stride=stride)
        self.project = Linear_block(hidden, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        self.residual = residual

    def forward(self, x):
        out = self.project(self.conv_dw(self.conv(x)))
        return x + out if self.residual else out

class Residual(Module):
    """Stack of ``num_block`` residual Depth_Wise units at constant width ``c``."""

    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        units = [
            Depth_Wise(c, c, residual=True, kernel=kernel, padding=padding,
                       stride=stride, groups=groups)
            for _ in range(num_block)
        ]
        self.model = Sequential(*units)

    def forward(self, x):
        return self.model(x)

class MobileFaceNet(Module):
    """MobileFaceNet embedding network (InsightFace_Pytorch layout).

    Four stride-2 stages (conv1, conv_23, conv_34, conv_45) shrink the input
    by a factor of 16.  ``conv_6_dw`` is a depthwise conv whose kernel
    ``(out_h, out_w)`` is expected to cover that whole final map, e.g. (7, 7)
    for a 112x112 input, so flattening yields exactly 512 values for
    ``linear``.

    Args:
        embedding_size: length of the output embedding.
        out_h: height of the feature map entering ``conv_6_dw``.
        out_w: width of the feature map entering ``conv_6_dw``.
    """

    def __init__(self, embedding_size, out_h, out_w):
        super(MobileFaceNet, self).__init__()
        # Stem.
        self.conv1 = Conv_block(3, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1))
        self.conv2_dw = Conv_block(64, 64, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64)
        # Bottleneck stages: each is a downsampling unit then a residual stack.
        self.conv_23 = Depth_Wise(64, 64, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128)
        self.conv_3 = Residual(64, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_34 = Depth_Wise(64, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256)
        self.conv_4 = Residual(128, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        self.conv_45 = Depth_Wise(128, 128, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512)
        self.conv_5 = Residual(128, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1))
        # Embedding head: expand to 512, global depthwise conv, linear + BN.
        self.conv_6_sep = Conv_block(128, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.conv_6_dw = Linear_block(512, 512, groups=512, kernel=(out_h, out_w), stride=(1, 1), padding=(0, 0))
        self.conv_6_flatten = Flatten()
        self.linear = Linear(512, embedding_size, bias=False)
        self.bn = BatchNorm1d(embedding_size)

    def forward(self, x):
        # Convolutional trunk.
        out = self.conv2_dw(self.conv1(x))
        out = self.conv_3(self.conv_23(out))
        out = self.conv_4(self.conv_34(out))
        out = self.conv_5(self.conv_45(out))
        # Embedding head.
        out = self.conv_6_dw(self.conv_6_sep(out))
        out = self.conv_6_flatten(out)
        return self.bn(self.linear(out))

class Flatten(nn.Module):
    """Reshape a (B, ...) tensor to (B, -1).

    This redefines the identically behaving ``Flatten`` declared earlier in
    the notebook.
    """

    def forward(self, input):
        return input.view(input.shape[0], -1)

class res18feature(nn.Module):
    """Original RUL model: pretrained ResNet-18 trunk + mean / log-variance heads.

    Args:
        args: object with a ``pretrained_backbone_path`` attribute naming a
            checkpoint whose ``'state_dict'`` entry is loaded into the
            ResNet-18 with ``strict=False``.
        pretrained: currently unused (the checkpoint is always loaded).
        num_classes: currently unused.
        drop_rate: dropout probability before each head's linear layer.
        out_dim: size of the mean / log-variance embeddings.

    ``forward(x, target, phase)``:
        * ``phase == 'train'``: returns ``mixup_data(mu, target, u)`` i.e.
          ``(mixed_x, y_a, y_b, att1, att2)``, where ``u`` is the per-sample
          mean of ``exp(logvar)`` with shape (B, 1).
        * otherwise: returns the mean embedding ``mu`` of shape (B, out_dim).
    """

    def __init__(self, args, pretrained=True, num_classes=7, drop_rate=0.4, out_dim=64):
        super(res18feature, self).__init__()

        # Example checkpoint path: 'affectnet_baseline/resnet18_msceleb.pth'
        res18 = ResNet(block=BasicBlock, n_blocks=[2, 2, 2, 2], channels=[64, 128, 256, 512], output_dim=1000)
        # map_location='cpu': the fresh model lives on the CPU anyway, and this
        # keeps loading working on machines without the GPU the checkpoint may
        # have been saved from.
        msceleb_model = torch.load(args.pretrained_backbone_path, map_location='cpu')
        state_dict = msceleb_model['state_dict']
        res18.load_state_dict(state_dict, strict=False)

        self.drop_rate = drop_rate
        self.out_dim = out_dim
        # Drop avgpool + fc, keeping the spatial map (7x7 for 224x224 input,
        # overall stride 32) that the heads below are sized for.
        self.features = nn.Sequential(*list(res18.children())[:-2])

        self.mu = nn.Sequential(
            nn.BatchNorm2d(512, eps=2e-5, affine=False),
            nn.Dropout(p=self.drop_rate),
            Flatten(),
            nn.Linear(512 * 7 * 7, self.out_dim),
            nn.BatchNorm1d(self.out_dim, eps=2e-5))

        self.log_var = nn.Sequential(
            nn.BatchNorm2d(512, eps=2e-5, affine=False),
            nn.Dropout(p=self.drop_rate),
            Flatten(),
            nn.Linear(512 * 7 * 7, self.out_dim),
            nn.BatchNorm1d(self.out_dim, eps=2e-5))

    def forward(self, x, target, phase='train'):
        feats = self.features(x)
        if phase != 'train':
            return self.mu(feats)

        mu = self.mu(feats)
        logvar = self.log_var(feats)
        uncertainty = logvar.exp().mean(dim=1, keepdim=True)
        # Follow the data's device instead of hard-coding use_cuda=True,
        # which crashed for CPU tensors.
        return mixup_data(mu, target, uncertainty, use_cuda=mu.is_cuda)


In [None]:
#@title RafDB dataset


class RafDataset(data.Dataset):
    """RAF-DB basic-expression split built from the aligned-image release.

    Labels are read from '/content/EmoLabel/list_patition_label.txt'
    (space-separated "<file name> <label>" rows) and shifted to 0-based ids:
    0: Surprise, 1: Fear, 2: Disgust, 3: Happiness, 4: Sadness, 5: Anger,
    6: Neutral.  Images are read from 'Image/aligned/<stem>_aligned.jpg'
    relative to the working directory.

    Args:
        phase: 'train' selects rows whose name starts with "train"; any other
            value selects rows starting with "test".
        transform: optional callable applied to the RGB image, which is
            passed in as a numpy array.
    """

    def __init__(self, phase, transform = None):
        self.raf_path = ''
        self.phase = phase
        self.transform = transform

        # The second argument is absolute, so os.path.join ignores raf_path here.
        label_file = os.path.join(self.raf_path, '/content/EmoLabel/list_patition_label.txt')
        table = pd.read_csv(label_file, sep=' ', header=None, names=['name', 'label'])

        split_prefix = 'train' if phase == 'train' else 'test'
        self.data = table[table['name'].str.startswith(split_prefix)]

        file_names = self.data.loc[:, 'name'].values
        self.label = self.data.loc[:, 'label'].values - 1
        _, self.sample_counts = np.unique(self.label, return_counts=True)

        self.file_paths = [
            os.path.join(self.raf_path, 'Image/aligned', name.split(".")[0] + "_aligned.jpg")
            for name in file_names
        ]

    def get_labels(self):
        """Return the 0-based label array (one entry per sample)."""
        return self.label

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        image = np.array(Image.open(self.file_paths[idx]).convert('RGB'))
        label = self.label[idx]
        if self.transform is not None:
            image = self.transform(image)
        return image, label

In [None]:
#@title customized RUL model
class RUL(nn.Module):
    """Relative Uncertainty Learning heads on top of an arbitrary backbone.

    Args:
        bbone: feature extractor.  Its output must be
            (B, inp_dim, feature_size, feature_size); when ``feature_size == 1``
            a flat (B, inp_dim) output is also accepted and reshaped to
            (B, inp_dim, 1, 1).
        num_classes: currently unused.
        drop_rate: dropout probability before each head's linear layer.
        inp_dim: channels produced by ``bbone``.
        out_dim: size of the mean / log-variance embeddings.
        feature_size: spatial side length of the backbone output.

    ``forward(x, target, phase)``:
        * ``phase == 'train'``: returns ``(mixed_x, y_a, y_b, att1, att2)``
          from ``mixup_data``, weighted by the per-sample mean of
          ``exp(logvar)``.
        * otherwise: returns the mean embedding ``mu`` of shape (B, out_dim).
    """

    def __init__(self, bbone, num_classes=7, drop_rate=0.4, inp_dim = 512, out_dim=64, feature_size=7):
        super(RUL, self).__init__()

        self.drop_rate = drop_rate
        self.out_dim = out_dim
        self.features = bbone
        self.feature_size = feature_size

        flat_dim = inp_dim * feature_size * feature_size
        # Two identically shaped heads: mean embedding and log-variance.
        self.mu = nn.Sequential(
            nn.BatchNorm2d(inp_dim, eps=2e-5, affine=False),
            nn.Dropout(p=self.drop_rate),
            nn.Flatten(),
            nn.Linear(flat_dim, self.out_dim),
            nn.BatchNorm1d(self.out_dim, eps=2e-5))

        self.log_var = nn.Sequential(
            nn.BatchNorm2d(inp_dim, eps=2e-5, affine=False),
            nn.Dropout(p=self.drop_rate),
            nn.Flatten(),
            nn.Linear(flat_dim, self.out_dim),
            nn.BatchNorm1d(self.out_dim, eps=2e-5))

    def _backbone_features(self, x):
        """Run the backbone and return a 4-D (B, C, H, W) tensor for the heads."""
        x = self.features(x)
        if self.feature_size == 1 and x.dim() == 2:
            # Pooled backbones emit (B, C); BatchNorm2d needs (B, C, 1, 1).
            x = x[:, :, None, None]
        return x

    def forward(self, x, target, phase='train'):
        # The reshape is now applied in both phases; previously evaluation
        # skipped it and crashed for flat (feature_size == 1) backbones.
        feats = self._backbone_features(x)
        if phase != 'train':
            return self.mu(feats)

        mu = self.mu(feats)
        logvar = self.log_var(feats)
        uncertainty = logvar.exp().mean(dim=1, keepdim=True)
        # Follow the data's device instead of hard-coding use_cuda=True.
        return mixup_data(mu, target, uncertainty, use_cuda=mu.is_cuda)


In [None]:
#@title RUL model for jit trace
class evalRUL(nn.Module):
    """Inference-only RUL model whose ``forward(x)`` takes just the images.

    Intended for JIT tracing.  Unlike ``RUL.forward``, it needs no target or
    phase and always returns the mean embedding ``mu`` (B, out_dim).  Its
    ``features`` / ``mu`` / ``log_var`` submodules have the same layout as
    in ``RUL``, so a state_dict trained with ``RUL`` loads into it;
    ``log_var`` is unused in ``forward`` and is presumably kept only for that
    state_dict compatibility.

    Args:
        bbone: feature extractor producing (B, inp_dim, feature_size,
            feature_size), or (B, inp_dim) when ``feature_size == 1``.
        num_classes: currently unused.
        drop_rate: dropout probability before each head's linear layer.
        inp_dim: channels produced by ``bbone``.
        out_dim: embedding size.
        feature_size: spatial side of the backbone output; defaults to 7,
            the value that used to be hard-coded.
    """

    def __init__(self, bbone, num_classes=7, drop_rate=0.4, inp_dim = 512, out_dim=64, feature_size=7):
        super(evalRUL, self).__init__()

        self.drop_rate = drop_rate
        self.out_dim = out_dim
        self.features = bbone
        self.feature_size = feature_size

        flat_dim = inp_dim * feature_size * feature_size
        self.mu = nn.Sequential(
            nn.BatchNorm2d(inp_dim, eps=2e-5, affine=False),
            nn.Dropout(p=self.drop_rate),
            nn.Flatten(),
            nn.Linear(flat_dim, self.out_dim),
            nn.BatchNorm1d(self.out_dim, eps=2e-5))

        self.log_var = nn.Sequential(
            nn.BatchNorm2d(inp_dim, eps=2e-5, affine=False),
            nn.Dropout(p=self.drop_rate),
            nn.Flatten(),
            nn.Linear(flat_dim, self.out_dim),
            nn.BatchNorm1d(self.out_dim, eps=2e-5))

    def forward(self, x):
        x = self.features(x)
        if self.feature_size == 1 and x.dim() == 2:
            # Pooled backbones emit (B, C); BatchNorm2d needs (B, C, 1, 1).
            x = x[:, :, None, None]
        return self.mu(x)

In [None]:
#@title RUL train loop

def _balanced_sampler(labels):
    """Build a class-balanced ``WeightedRandomSampler`` from integer labels.

    Each sample is weighted by ``1 / count(its class)``, so every class is
    drawn equally often in expectation.  ``labels`` must be non-negative ints.
    """
    labels = np.asarray(labels, dtype=np.int64)
    class_counts = np.bincount(labels)
    sample_weights = 1.0 / class_counts[labels]
    return WeightedRandomSampler(torch.as_tensor(sample_weights, dtype=torch.double),
                                 num_samples=len(sample_weights),
                                 replacement=True)


def train_RUL(RUL,
              sampling=False,
              out_features_dim = 64,
              epochs = 60,
              batch_size = 64,
              workers = 2,
              train_tf = None,
              val_tf = None):
    """Train a RUL-style model on RAF-DB, evaluating on the test split each epoch.

    Args:
        RUL: the model to train.  ``model(imgs, labels, phase='train')`` must
            return ``(mixed_x, y_a, y_b, att1, att2)`` and
            ``model(imgs, labels, phase='test')`` the embeddings (e.g. an
            instance of the ``RUL`` class).  The parameter name shadows that
            class but is kept for keyword compatibility.
        sampling: if True, draw training samples with a class-balanced
            ``WeightedRandomSampler`` instead of plain shuffling.
        out_features_dim: embedding size; a new ``Linear(out_features_dim, 7)``
            classifier is trained alongside the model.
        epochs: number of epochs.
        batch_size: batch size for both loaders.
        workers: DataLoader worker processes.
        train_tf: transform for training images.
        val_tf: transform for test images.  If None, falls back to the global
            ``test_tf`` (the previous behaviour, which ignored ``val_tf``).

    Side effects:
        Seeds all RNGs with 0, moves model and classifier to CUDA, prints
        per-epoch metrics, and saves both state_dicts to "model_86.pth"
        whenever a new best test accuracy is >= 0.86.
    """
    setup_seed(0)
    model = RUL
    fc = nn.Linear(out_features_dim, 7)

    # Bug fix: the test split used the global ``test_tf`` even when val_tf was given.
    data_transforms_val = val_tf if val_tf is not None else test_tf

    train_dataset = RafDataset(phase='train', transform=train_tf)
    test_dataset = RafDataset(phase='test', transform=data_transforms_val)

    if sampling:
        # Weights come straight from the label array.  The old code indexed
        # Counter values (ordered by first appearance, not by label id) with
        # the label, mis-weighting classes, and it decoded + transformed every
        # training image just to read its label.
        train_sampler = _balanced_sampler(train_dataset.label)
        train_shuffle = False
    else:
        train_sampler = None
        train_shuffle = True

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size,
                                               sampler=train_sampler,
                                               shuffle=train_shuffle,
                                               num_workers=workers,
                                               pin_memory=True)

    test_loader = torch.utils.data.DataLoader(test_dataset,
                                              batch_size=batch_size,
                                              shuffle=False,
                                              num_workers=workers,
                                              pin_memory=True)

    model.cuda()
    fc.cuda()

    # The freshly initialised classifier uses a 100x larger learning rate.
    optimizer = torch.optim.AdamW([
        {'params': model.parameters()},
        {'params': fc.parameters(), 'lr': 0.01}], lr=0.0001, weight_decay=0.0001,
        amsgrad=True)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    criterion = nn.CrossEntropyLoss()

    best_acc = 0
    best_epoch = 0
    for i in range(1, epochs + 1):
        # ---- training ----
        model.train()
        running_loss = 0.0
        iter_cnt = 0
        correct_sum = 0

        for imgs, labels in train_loader:
            imgs = imgs.cuda()
            labels = labels.cuda()

            optimizer.zero_grad()
            mixed_x, y_a, y_b, _, _ = model(imgs, labels, phase='train')
            outputs = fc(mixed_x)
            loss = mixup_criterion(y_a, y_b)(criterion, outputs)
            loss.backward()
            optimizer.step()

            iter_cnt += 1
            _, predicts = torch.max(outputs, 1)
            # Accuracy is measured against the un-mixed labels.
            correct_sum += torch.eq(predicts, labels).sum().item()
            # .item(): accumulate a float, not a tensor that keeps autograd history.
            running_loss += loss.item()

        scheduler.step()

        running_loss = running_loss / max(iter_cnt, 1)
        acc = correct_sum / len(train_dataset)
        print('Epoch : %d, train_acc : %.4f, train_loss: %.4f' % (i, acc, running_loss))

        # ---- evaluation ----
        with torch.no_grad():
            model.eval()

            running_loss = 0.0
            iter_cnt = 0
            correct_sum = 0
            data_num = 0

            for imgs, labels in test_loader:
                imgs = imgs.cuda()
                labels = labels.cuda()

                outputs = fc(model(imgs, labels, phase='test'))
                loss = criterion(outputs, labels)

                iter_cnt += 1
                _, predicts = torch.max(outputs, 1)
                correct_sum += torch.eq(predicts, labels).sum().item()
                running_loss += loss.item()
                data_num += outputs.size(0)

            running_loss = running_loss / max(iter_cnt, 1)
            test_acc = correct_sum / max(data_num, 1)

            if test_acc > best_acc:
                best_acc = test_acc
                best_epoch = i
                if best_acc >= 0.86:
                    torch.save({'model_state_dict': model.state_dict(),
                                'fc_state_dict': fc.state_dict()},
                               "model_86.pth")
                    print('Model saved.')

            print('Epoch : %d, test_acc : %.4f, test_loss: %.4f' % (i, test_acc, running_loss))

    print('best acc: ', best_acc, 'best epoch: ', best_epoch)


Prepare lightweight backbones

In [None]:
# Backbone checkpoints pretrained on MS-Celeb; '' means no checkpoint yet.
backbone_pathes = dict(
    MobileFaceNet='/content/drive/MyDrive/celebp/mobilefacenet/Epoch_17.pt',
    GhostNet='/content/drive/MyDrive/celebp/GhostNet/Epoch_17.pt',
    ReXNetv1='',
    SwinT='/content/drive/MyDrive/celebp/SwinT/Epoch_17.pt',
    LiteCNN29='',
    ResNet18='/content/resnet18_msceleb.pth',
)

In [None]:
# Put the cloned FaceX-Zoo repo on the import path. The MobileFaceNet imported
# here replaces the class of the same name defined earlier in this notebook.
sys.path.append('FaceX-Zoo')
from backbone.GhostNet import GhostNet
from backbone.MobileFaceNets import MobileFaceNet
from backbone.LightCNN import LightCNN
from backbone.Swin_Transformer import SwinTransformer

In [None]:
def get_backbone_dict(model_path: str, resnet18: bool = False) -> dict:
    """Load a backbone ``state_dict`` from a training checkpoint.

    Args:
        model_path: checkpoint file containing a ``'state_dict'`` entry.
        resnet18: if True, return the checkpoint's ``state_dict`` unchanged
            (no prefix filtering).  ``model_path`` is now honoured here; an
            empty path falls back to the previously hard-coded
            '/content/resnet18_msceleb.pth'.

    Returns:
        With ``resnet18=False`` (FaceX-Zoo checkpoints), only the entries
        whose keys start with ``'backbone.'``, with that prefix stripped so
        they match the bare backbone module.

    Warns when the filtered result is empty: loading ``{}`` with
    ``strict=False`` would otherwise silently leave the model untouched.
    """
    if resnet18:
        path = model_path or '/content/resnet18_msceleb.pth'
        return torch.load(path, map_location='cpu')['state_dict']

    state_dict = torch.load(model_path, map_location='cpu')['state_dict']
    prefix = 'backbone.'
    # Match the full 'backbone.' prefix: a key like 'backbone_x' used to pass
    # the filter and then crash when split on 'backbone.'.
    backbone_state = {key[len(prefix):]: val
                      for key, val in state_dict.items()
                      if key.startswith(prefix)}
    if not backbone_state:
        warnings.warn(f"No '{prefix}*' keys found in {model_path!r}; "
                      "returning an empty state_dict.")
    return backbone_state

In [None]:
import yaml

# Constructor kwargs per FaceX-Zoo backbone (e.g. GhostNet(**bbones_cfgs['GhostNet'])).
with open("/content/FaceX-Zoo/training_mode/backbone_conf.yaml", "r", encoding="utf-8") as stream:
    try:
        bbones_cfgs = yaml.safe_load(stream)
    except yaml.YAMLError as exc:
        # Previously the error was only printed, leaving ``bbones_cfgs``
        # undefined and failing later with a confusing NameError.
        print(exc)
        raise

In [None]:
# !pip install torchsampler

Experiments with RUL backbones

In [None]:
#@title transforms config
# RafDataset hands transforms a numpy image, hence the leading ToPILImage.
# Both pipelines resize to 224x224 and normalise with ImageNet statistics.

# Training: random erasing on the normalised tensor. Other augmentations that
# were tried are kept below, disabled.
train_tf = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    # transforms.RandomApply([
    #     transforms.RandomAdjustSharpness(sharpness_factor=1),
    # ], p=0.3),
    # transforms.RandomApply([
    #     transforms.RandomRotation(20),
    #     transforms.RandomCrop(224, padding=32),
    # ], p=0.3),
    # transforms.RandomApply([
    #     transforms.ColorJitter(brightness=0.05, contrast=0.05,
    #                            saturation=0.05, hue=0.05),
    # ], p=0.3),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
    transforms.RandomErasing(scale=(0.02, 0.25)),
])

# Evaluation: deterministic resize + normalisation only.
test_tf = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

In [None]:
# MS-Celeb ResNet-18 trunk. Dropping the last two children (avgpool, fc)
# leaves the 512-channel spatial feature map as the extractor output.
bbone = ResNet(block=BasicBlock, n_blocks=[2, 2, 2, 2],
               channels=[64, 128, 256, 512], output_dim=1000)
# resnet18=True: this checkpoint's keys are presumably not 'backbone.'-prefixed
# (res18feature loads them straight into ResNet), so the default FaceX-Zoo
# filtering returned {} and strict=False silently loaded no weights.
bbone.load_state_dict(get_backbone_dict(backbone_pathes['ResNet18'], resnet18=True),
                      strict=False)
feature_extractor = nn.Sequential(*list(bbone.children())[:-2])

In [None]:
# FaceX-Zoo GhostNet built from its backbone_conf.yaml kwargs; every child
# module except the last is kept as the feature extractor.
bbone = GhostNet(**bbones_cfgs['GhostNet'])
bbone.load_state_dict(get_backbone_dict(backbone_pathes['GhostNet']))
feature_extractor = nn.Sequential(*list(bbone.children())[:-1])

In [None]:
# FaceX-Zoo MobileFaceNet. NOTE(review): the yaml values are passed
# positionally, so this relies on the key order in backbone_conf.yaml
# matching the constructor's parameter order — confirm.
# The last 4 child modules are dropped to keep only the convolutional trunk
# (assumes FaceX-Zoo's child order ends with the embedding head — TODO confirm).
bbone = MobileFaceNet(*bbones_cfgs['MobileFaceNet'].values())
bbone.load_state_dict(get_backbone_dict(backbone_pathes['MobileFaceNet']))
feature_extractor = nn.Sequential(*list(bbone.children())[:-4])

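Use RUL() to train a model and evalRUL() for JIT tracing. evalRUL.forward takes only the image batch, so an evalRUL instance cannot be passed to train_RUL.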

In [None]:
# evalRUL (forward(x) only, for JIT tracing); trained weights are restored later
# with load_rul_model. NOTE(review): build it with RUL(...) instead when it is
# going to be passed to train_RUL.
OriginalRUL = evalRUL(feature_extractor)

In [None]:
# NOTE(review): inp_dim defaults to 512 and evalRUL's linear layers assume a
# 7x7 feature map; confirm the MobileFaceNet extractor's output matches this
# at the 224x224 input size used by the transforms.
MobileRUL = evalRUL(feature_extractor,drop_rate=0.3, #0.2 
                out_dim=64)

In [None]:
# Embedding size of the RUL heads.
inner_feature_size = 64
# inp_dim=960 is assumed to be the channel count of the GhostNet extractor — TODO confirm.
GhostRUL = evalRUL(bbone=feature_extractor, inp_dim=960, out_dim=inner_feature_size)


In [None]:
# wandb.login(key=wandb_api_key)
# run = wandb.init(settings=wandb.Settings(start_method="thread"),
#                     entity='mangaboba', project='FER')
# run.name = 'teslLog'

# NOTE(review): train_RUL calls model(imgs, labels, phase=...), but
# OriginalRUL is currently an evalRUL whose forward takes only images; use a
# RUL(...) instance here for training.
train_RUL(OriginalRUL,train_tf=train_tf, val_tf=test_tf,
          out_features_dim=64)

#wandb.finish()

In [None]:
# Load a trained RUL checkpoint and append its classifier.
def load_rul_model(model, pth, out_dim=64, num_classes=7):
    """Restore ``model`` and its linear classifier from a training checkpoint.

    Args:
        model: module whose state_dict matches ``checkpoint['model_state_dict']``.
        pth: checkpoint path holding 'model_state_dict' and 'fc_state_dict'
            (the layout train_RUL saves).
        out_dim: embedding size feeding the classifier (was hard-coded 64).
        num_classes: number of classifier outputs (was hard-coded 7).

    Returns:
        ``torch.nn.Sequential(model, fc)`` mapping inputs to class logits.
    """
    # map_location='cpu' lets GPU-saved checkpoints load on any machine;
    # load_state_dict then copies the tensors onto each module's device.
    checkpoint = torch.load(pth, map_location='cpu')
    fc = torch.nn.Linear(out_dim, num_classes)
    fc.load_state_dict(checkpoint['fc_state_dict'])
    model.load_state_dict(checkpoint['model_state_dict'])
    return torch.nn.Sequential(model, fc)


In [None]:
# MobileRUL = load_rul_model(MobileRUL, '/content/drive/MyDrive/resulet_models/mobileRul_8716_4112.pth')
# Restore the ResNet-18 RUL model and its classifier from the Drive checkpoint.
ResnetRUL = load_rul_model(OriginalRUL, '/content/drive/MyDrive/resulet_models/rul_res18_0.888.pth')

In [None]:
# !cp /content/drive/MyDrive/DAN.zip -d .
# !unzip DAN.zip

In [None]:
#@title DAN model: https://github.com/yaoing/DAN

class DAN(nn.Module):
    """DAN head (https://github.com/yaoing/DAN) on top of a given backbone.

    ``num_head`` CrossAttentionHead modules each map the backbone output to a
    512-d vector.  When there is more than one head, their stack is
    log-softmax-normalised across the head dimension.  The heads are then
    summed and classified.

    Args:
        bbone: feature extractor; its output must have 512 channels, since
            the attention heads start with Conv2d(512, ...).
        num_class: number of output classes.
        num_head: number of attention heads.
        pretrained: currently unused.

    ``forward(x)`` returns ``(logits, backbone_features, heads)`` where
    ``heads`` has shape (B, num_head, 512).
    """

    def __init__(self, bbone, num_class=7,num_head=4, pretrained=True):
        super(DAN, self).__init__()

        self.features = bbone
        self.num_head = num_head
        # Heads are registered as attributes cat_head0..cat_head{n-1} (state_dict keys).
        for head_idx in range(num_head):
            setattr(self, "cat_head%d" % head_idx, CrossAttentionHead())
        self.sig = nn.Sigmoid()
        self.fc = nn.Linear(512, num_class)
        self.bn = nn.BatchNorm1d(num_class)

    def forward(self, x):
        feats = self.features(x)
        head_outputs = [getattr(self, "cat_head%d" % h)(feats)
                        for h in range(self.num_head)]

        # (num_head, B, 512) -> (B, num_head, 512)
        heads = torch.stack(head_outputs).permute([1, 0, 2])
        if heads.size(1) > 1:
            heads = F.log_softmax(heads, dim=1)

        logits = self.bn(self.fc(heads.sum(dim=1)))
        return logits, feats, heads

class CrossAttentionHead(nn.Module):
    """One DAN attention head: spatial attention followed by channel attention."""

    def __init__(self):
        super().__init__()
        self.sa = SpatialAttention()
        self.ca = ChannelAttention()
        self.init_weights()

    def init_weights(self):
        """Re-initialise the Conv2d, BatchNorm2d and Linear layers of this head.

        Conv2d: Kaiming-normal (fan_out) weights, zero bias.
        BatchNorm2d: weight 1, bias 0.
        Linear: N(0, 0.001) weights, zero bias.
        """
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                init.kaiming_normal_(module.weight, mode='fan_out')
                if module.bias is not None:
                    init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                init.constant_(module.weight, 1)
                init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                init.normal_(module.weight, std=0.001)
                if module.bias is not None:
                    init.constant_(module.bias, 0)

    def forward(self, x):
        spatially_weighted = self.sa(x)
        return self.ca(spatially_weighted)


class SpatialAttention(nn.Module):
    """DAN spatial attention over a 512-channel feature map.

    A 1x1 reduction to 256 channels feeds three parallel conv + BN branches
    (3x3, 1x3, 3x1) back to 512 channels.  Their ReLU'd sum is summed over
    channels into a (B, 1, H, W) map that rescales the input.
    """

    def __init__(self):
        super().__init__()

        def conv_bn(in_ch, out_ch, **conv_kwargs):
            return nn.Sequential(
                nn.Conv2d(in_ch, out_ch, **conv_kwargs),
                nn.BatchNorm2d(out_ch),
            )

        self.conv1x1 = conv_bn(512, 256, kernel_size=1)
        self.conv_3x3 = conv_bn(256, 512, kernel_size=3, padding=1)
        self.conv_1x3 = conv_bn(256, 512, kernel_size=(1, 3), padding=(0, 1))
        self.conv_3x1 = conv_bn(256, 512, kernel_size=(3, 1), padding=(1, 0))
        self.relu = nn.ReLU()

    def forward(self, x):
        reduced = self.conv1x1(x)
        branches = self.conv_3x3(reduced) + self.conv_1x3(reduced) + self.conv_3x1(reduced)
        attention_map = self.relu(branches).sum(dim=1, keepdim=True)
        return x * attention_map

class ChannelAttention(nn.Module):
    """DAN channel attention: gate the pooled 512-d descriptor with a sigmoid MLP.

    ``forward`` global-average-pools the input to (B, 512) and returns that
    vector multiplied element-wise by its learned gate.
    """

    def __init__(self):
        super().__init__()
        self.gap = nn.AdaptiveAvgPool2d(1)
        # Bottleneck MLP 512 -> 32 -> 512 producing per-channel gates in (0, 1).
        self.attention = nn.Sequential(
            nn.Linear(512, 32),
            nn.BatchNorm1d(32),
            nn.ReLU(inplace=True),
            nn.Linear(32, 512),
            nn.Sigmoid(),
        )

    def forward(self, sa):
        pooled = self.gap(sa).view(sa.size(0), -1)
        gate = self.attention(pooled)
        return pooled * gate

In [None]:
#@title DAN train loop
class AffinityLoss(nn.Module):
    """DAN affinity loss: pull pooled features toward learnable class centres.

    The loss is the mean squared distance of each pooled feature to its own
    class centre, divided by the summed per-dimension variance of the
    centres.

    Args:
        device: device on which the centres and class ids are created.
        num_class: number of classes / centres.
        feat_dim: feature dimension (channels of the input map).
    """

    def __init__(self, device, num_class=8, feat_dim=512):
        super(AffinityLoss, self).__init__()
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.gap = nn.AdaptiveAvgPool2d(1)
        self.device = device

        self.centers = nn.Parameter(torch.randn(self.num_class, self.feat_dim).to(device))

    def forward(self, x, labels):
        pooled = self.gap(x).view(x.size(0), -1)
        n = pooled.size(0)

        # Squared distances ||x_i||^2 + ||c_j||^2 - 2 <x_i, c_j>, shape (n, num_class).
        x_sq = torch.pow(pooled, 2).sum(dim=1, keepdim=True).expand(n, self.num_class)
        c_sq = torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_class, n).t()
        distmat = x_sq + c_sq
        distmat.addmm_(pooled, self.centers.t(), beta=1, alpha=-2)

        # Keep only each sample's distance to its own class centre.
        class_ids = torch.arange(self.num_class).long().to(self.device)
        own_class = labels.unsqueeze(1).expand(n, self.num_class).eq(
            class_ids.expand(n, self.num_class))

        dist = distmat * own_class.float() / self.centers.var(dim=0).sum()
        return dist.clamp(min=1e-12, max=1e+12).sum() / n

class PartitionLoss(nn.Module):
    """DAN partition loss; it shrinks as the attention heads spread apart.

    For ``x`` with heads on dim 1, returns
    ``log(1 + num_head / (var + 1e-6))``, where ``var`` is the variance across
    heads averaged over all other positions.  Returns the integer 0 when
    there is a single head.
    """

    def __init__(self):
        super(PartitionLoss, self).__init__()

    def forward(self, x):
        num_head = x.size(1)
        if num_head <= 1:
            return 0
        spread = x.var(dim=1).mean()
        # eps keeps the ratio finite when all heads coincide (zero variance).
        return torch.log(1 + num_head / (spread + 1e-6))

def trainin_DAN(DANmodel, batch_size = 256,
                 workers = 2, lr = 0.1, epochs = 40, save_dir = 'checkpoints'):
    """Train a DAN-style model on RAF-DB with CE + affinity + partition losses.

    Each epoch trains on the 'train' split and then evaluates on the 'test'
    split, printing accuracy, balanced accuracy and loss.

    Args:
        DANmodel: model whose forward returns ``(logits, feat, heads)``; it is
            moved to the selected device in place.
        batch_size: mini-batch size for both loaders.
        workers: number of DataLoader worker processes.
        lr: initial SGD learning rate (decayed by 10x every 10 epochs).
        epochs: number of training epochs.
        save_dir: directory for checkpoints (created on first save). A
            checkpoint is written only when test accuracy exceeds 0.89 and
            equals the best accuracy seen so far.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.is_available():
        # NOTE(review): benchmark=True (autotuned kernels) and
        # deterministic=True (reproducible kernels) pull in opposite directions.
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.enabled = True

    model = DANmodel
    model.to(device)

    data_transforms = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
                transforms.RandomRotation(20),
                transforms.RandomCrop(224, padding=32)
            ], p=0.2),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        transforms.RandomErasing(scale=(0.02,0.25)),
        ])
    # Alternative augmentation that was tried: the geometric block with p=0.3
    # plus RandomApply([ColorJitter(0.05, 0.05, 0.05, 0.025)], p=0.3).

    train_dataset = RafDataset(phase = 'train', transform = data_transforms)

    print('Whole train set size:', train_dataset.__len__())

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle = True,
                                               batch_size = batch_size,
                                               num_workers = workers,
                                               pin_memory = True)

    data_transforms_val = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])

    val_dataset = RafDataset(phase = 'test', transform = data_transforms_val)

    print('Validation set size:', val_dataset.__len__())

    val_loader = torch.utils.data.DataLoader(val_dataset,
                                               batch_size = batch_size,
                                               num_workers = workers,
                                               shuffle = False,
                                               pin_memory = True)

    criterion_cls = torch.nn.CrossEntropyLoss()
    criterion_af = AffinityLoss(device)
    criterion_pt = PartitionLoss()

    # The affinity-loss class centers are learnable, so optimise them too.
    params = list(model.parameters()) + list(criterion_af.parameters())
    optimizer = torch.optim.SGD(params,lr=lr, weight_decay = 1e-4, momentum=0.9)
    # (AdamW with weight_decay=1e-4, amsgrad=True was also tried.)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_acc = 0
    for epoch in tqdm(range(1, epochs + 1)):
        running_loss = 0.0
        correct_sum = 0
        iter_cnt = 0
        model.train()

        for (imgs, targets) in train_loader:
            iter_cnt += 1
            optimizer.zero_grad()

            imgs = imgs.to(device)
            targets = targets.to(device)

            out,feat,heads = model(imgs)

            loss = criterion_cls(out,targets) + 1* criterion_af(feat,targets) + 1*criterion_pt(heads)  #89.3 89.4

            loss.backward()
            optimizer.step()

            # .item() detaches the value; accumulating the tensor itself would
            # chain every iteration's autograd graph into running_loss.
            running_loss += loss.item()
            _, predicts = torch.max(out, 1)
            correct_num = torch.eq(predicts, targets).sum()
            correct_sum += correct_num

        acc = correct_sum.float() / float(train_dataset.__len__())
        running_loss = running_loss/iter_cnt
        tqdm.write('[Epoch %d] Training accuracy: %.4f. Loss: %.3f. LR %.6f' % (epoch, acc, running_loss,optimizer.param_groups[0]['lr']))

        with torch.no_grad():
            running_loss = 0.0
            iter_cnt = 0
            bingo_cnt = 0
            sample_cnt = 0

            ## for calculating balanced accuracy
            y_true = []
            y_pred = []

            model.eval()
            for (imgs, targets) in val_loader:
                imgs = imgs.to(device)
                targets = targets.to(device)

                out,feat,heads = model(imgs)
                loss = criterion_cls(out,targets) + criterion_af(feat,targets) + criterion_pt(heads)

                running_loss += loss.item()
                iter_cnt+=1
                _, predicts = torch.max(out, 1)
                correct_num  = torch.eq(predicts,targets)
                bingo_cnt += correct_num.sum().cpu()
                sample_cnt += out.size(0)

                y_true.append(targets.cpu().numpy())
                y_pred.append(predicts.cpu().numpy())

            running_loss = running_loss/iter_cnt
            scheduler.step()

            acc = bingo_cnt.float()/float(sample_cnt)
            acc = np.around(acc.numpy(),4)
            best_acc = max(acc,best_acc)

            y_true = np.concatenate(y_true)
            y_pred = np.concatenate(y_pred)
            balanced_acc = np.around(balanced_accuracy_score(y_true, y_pred),4)

            tqdm.write("[Epoch %d] Validation accuracy:%.4f. bacc:%.4f. Loss:%.3f" % (epoch, acc, balanced_acc, running_loss))
            tqdm.write("best_acc:" + str(best_acc))

            if acc > 0.89 and acc == best_acc:
                # The checkpoint directory is not guaranteed to exist.
                os.makedirs(save_dir, exist_ok=True)
                torch.save({'iter': epoch,
                            'model_state_dict': model.state_dict(),
                             'optimizer_state_dict': optimizer.state_dict(),},
                            os.path.join(save_dir, "rafdb_epoch"+str(epoch)+"_acc"+str(acc)+"_bacc"+str(balanced_acc)+".pth"))
                tqdm.write('Model saved.')


In [None]:
# Build DAN on a torchvision ResNet-18 whose weights come from
# backbone_pathes['ResNet18'] (strict=True: every key must match).
bbone = torch_models.resnet18()
bbone.load_state_dict(get_backbone_dict(backbone_pathes['ResNet18'],
                                        resnet18 = True), strict=True)
# Drop the last two children (avgpool and fc in torchvision's ResNet) so DAN
# receives the convolutional feature map.
feature_extractor = nn.Sequential(*list(bbone.children())[:-2])
resnet18DAN = DAN(bbone = feature_extractor)

In [None]:
# Build DAN on a MobileFaceNet backbone loaded from backbone_pathes['MobileFaceNet'].
bbone = MobileFaceNet(*bbones_cfgs['MobileFaceNet'].values())
bbone.load_state_dict(get_backbone_dict(backbone_pathes['MobileFaceNet']))
# Keep all but the last four children (presumably the embedding head) as the
# feature extractor -- TODO confirm against the MobileFaceNet definition.
feature_extractor = nn.Sequential(*list(bbone.children())[:-4])
# NOTE(review): despite its name, this rebinds resnet18DAN to the
# MobileFaceNet-based model; whichever setup cell ran last is what the
# training cell below trains.
resnet18DAN = DAN(bbone = feature_extractor)

In [None]:
# Train whichever model resnet18DAN currently refers to (set by the setup cells above).
trainin_DAN(resnet18DAN, batch_size=128)

DACL experiments
https://github.com/amirhfarzaneh/dacl

In [None]:
# DACL imports
from dacl.loss import SparseCenterLoss
from dacl.utils import Logger, AverageMeter, accuracy, calc_metrics, RandomFiveCrop
sys.path.append('dacl')
#torchvision.models.utils -> torch.hub
from dacl.models.resnet import resnet18

In [None]:
#@title DACL train loop, modified to work with any model https://github.com/amirhfarzaneh/dacl/blob/f1d5aad84650831f70ad82e41919f8d092ce0411/main.py#L311 

def train_DACL(cfg, feat_size=512, train_set = None, val_set = None):
    """Train a model with the DACL recipe (softmax + sparse center loss).

    Adapted from the reference ``main.py`` of amirhfarzaneh/dacl.

    Args:
        cfg: experiment config dict. Reads 'deterministic', 'seed',
            'batch_size', 'workers', 'model', 'pretrained', 'lr', 'momentum',
            'weight_decay', 'alpha' and 'epochs' here, plus 'lamb' and
            'save_path' through ``train``/``validate``.
        feat_size: dimensionality of the features fed to the center loss.
            It is forced to 512 when ``cfg['model'] == 'resnet18'``.
        train_set: dataset used for optimisation (shuffled).
        val_set: held-out dataset used for per-epoch validation, best-metric
            tracking and best-checkpoint selection (not shuffled).

    Side effects: sets the module-level globals ``device`` and ``best_valid``,
    which ``train`` and ``validate`` read.
    """
    global device
    if torch.cuda.is_available():
        device = torch.device('cuda')
        cudnn.benchmark = True
    else:
        device = torch.device('cpu')

    if cfg['deterministic']:
        random.seed(cfg['seed'])
        torch.manual_seed(cfg['seed'])
        cudnn.deterministic = True
        cudnn.benchmark = False

    # Loading RAF-DB
    # -----------------
    print('[>] Loading dataset '.ljust(64, '-'))

    train_loader = torch.utils.data.DataLoader(train_set,
                                               batch_size=cfg['batch_size'],
                                               shuffle=True,
                                               num_workers=cfg['workers'],
                                               pin_memory=True)

    # Validation must run on the held-out set: validate() picks the "best"
    # metrics and checkpoints from it, so feeding it train_set would select
    # models on training data.
    val_loader = torch.utils.data.DataLoader(val_set,
                                             batch_size=cfg['batch_size'],
                                             shuffle=False,
                                             num_workers=cfg['workers'],
                                             pin_memory=True)

    print('[*] Loaded dataset!')

    # Create Model
    # ------------
    print('[>] Model '.ljust(64, '-'))
    if cfg['model'] == 'resnet18':
        feat_size = 512
        if cfg['pretrained'] != '':
            model = resnet18(pretrained=cfg['pretrained'])
            model.fc = nn.Linear(feat_size, 7)
        else:
            print('[!] model is trained from scratch!')
            model = resnet18(num_classes=7, pretrained=cfg['pretrained'])
    else:
        # User-supplied nn.Module; train/validate expect it to return
        # (feat, logits, A) with feat_size-dimensional features.
        model = cfg['model']
    model = torch.nn.DataParallel(model).to(device)
    print('[*] Model initialized!')

    # define loss function (criterion) and optimizer
    # ----------------------------------------------
    criterion = {
        'softmax': nn.CrossEntropyLoss().to(device),
        'center': SparseCenterLoss(7, feat_size).to(device)
    }
    optimizer = {
        'softmax': torch.optim.SGD(model.parameters(), cfg['lr'],
                                   momentum=cfg['momentum'],
                                   weight_decay=cfg['weight_decay']),
        # The center-loss parameters get their own SGD with lr = cfg['alpha'].
        'center': torch.optim.SGD(criterion['center'].parameters(), cfg['alpha'])
    }
    # lr scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer['softmax'], step_size=20, gamma=0.1)

    # training and evaluation
    # -----------------------
    global best_valid
    best_valid = dict.fromkeys(['acc', 'rec', 'f1', 'aucpr', 'aucroc'], 0.0)

    print('[>] Begin Training '.ljust(64, '-'))
    for epoch in range(1, cfg['epochs'] + 1):

        start = time.time()

        # train for one epoch
        train(train_loader, model, criterion, optimizer, epoch, cfg)
        # validate for one epoch
        validate(val_loader, model, criterion, epoch, cfg)

        # progress
        end = time.time()
        progress = (
            f'[*] epoch time = {end - start:.2f}s | '
            f'lr = {optimizer["softmax"].param_groups[0]["lr"]}\n'
        )
        print(progress)

        # lr step
        scheduler.step()

    # best valid info
    # ---------------
    print('[>] Best Valid '.ljust(64, '-'))
    stat = (
        f'[+] acc={best_valid["acc"]:.4f}\n'
        f'[+] rec={best_valid["rec"]:.4f}\n'
        f'[+] f1={best_valid["f1"]:.4f}\n'
        f'[+] aucpr={best_valid["aucpr"]:.4f}\n'
        f'[+] aucroc={best_valid["aucroc"]:.4f}'
    )
    print(stat)


def train(train_loader, model, criterion, optimizer, epoch, cfg):
    """Run one DACL training epoch, then print and log its summary.

    Args:
        train_loader: DataLoader yielding ``(images, target)`` batches.
        model: network whose forward returns ``(feat, logits, A)``.
        criterion: dict with 'softmax' and 'center' losses; the center loss is
            called as ``criterion['center'](feat, A, target)``.
        optimizer: dict with 'softmax' (model) and 'center' (center-loss) optimizers.
        epoch: 1-based epoch index, used for logging.
        cfg: config dict; reads 'lamb' (center-loss weight) and 'epochs'.

    Relies on the module-level ``device`` (set by ``train_DACL``) and on
    ``write_log``'s TensorBoard ``writer``; also logs to wandb.
    """
    losses = {
        'softmax': AverageMeter(),
        'center': AverageMeter(),
        'total': AverageMeter()
    }
    accs = AverageMeter()
    y_pred, y_true, y_scores = [], [], []

    # switch to train mode
    model.train()

    # len(loader) is the exact number of batches (including a final partial one).
    with tqdm(total=len(train_loader)) as pbar:
        for i, (images, target) in enumerate(train_loader):

            images = images.to(device)
            target = target.to(device)

            # compute output
            feat, output, A = model(images)
            l_softmax = criterion['softmax'](output, target)
            l_center = criterion['center'](feat, A, target)
            l_total = l_softmax + cfg['lamb'] * l_center

            # measure accuracy and record loss
            acc, pred = accuracy(output, target)
            losses['softmax'].update(l_softmax.item(), images.size(0))
            losses['center'].update(l_center.item(), images.size(0))
            losses['total'].update(l_total.item(), images.size(0))
            accs.update(acc.item(), images.size(0))

            # collect for metrics
            y_pred.append(pred)
            y_true.append(target)
            y_scores.append(output.data)

            # compute grads + opt step
            optimizer['softmax'].zero_grad()
            optimizer['center'].zero_grad()
            l_total.backward()
            optimizer['softmax'].step()
            optimizer['center'].step()

            # progressbar
            pbar.set_description(f'TRAINING [{epoch:03d}/{cfg["epochs"]}]')
            pbar.set_postfix({'L': losses["total"].avg,
                              'Ls': losses["softmax"].avg,
                              'Lsc': losses["center"].avg,
                              'acc': accs.avg})
            pbar.update(1)

    metrics = calc_metrics(y_pred, y_true, y_scores)
    progress = (
        f'[-] TRAIN [{epoch:03d}/{cfg["epochs"]}] | '
        f'L={losses["total"].avg:.4f} | '
        f'Ls={losses["softmax"].avg:.4f} | '
        f'Lsc={losses["center"].avg:.4f} | '
        f'acc={accs.avg:.4f} | '
        f'rec={metrics["rec"]:.4f} | '
        f'f1={metrics["f1"]:.4f} | '
        f'aucpr={metrics["aucpr"]:.4f} | '
        f'aucroc={metrics["aucroc"]:.4f}'
    )
    print(progress)
    write_log(losses, accs.avg, metrics, epoch, tag='train')
    # One wandb.log call per epoch: every call advances wandb's step, so
    # separate calls would scatter one epoch's metrics across several steps.
    wandb.log({"epoch": epoch,
               "tL": losses["total"].avg,
               "tLs": losses["softmax"].avg,
               "tLsc": losses["center"].avg,
               "tacc": accs.avg,
               "trec": metrics["rec"],
               "tf1": metrics["f1"]})



def validate(valid_loader, model, criterion, epoch, cfg):
    """Evaluate one epoch, update ``best_valid`` and save best checkpoints.

    Args:
        valid_loader: DataLoader yielding ``(images, target)`` batches.
        model: network whose forward returns ``(feat, logits, A)``.
        criterion: dict with 'softmax' and 'center' losses (as in ``train``).
        epoch: 1-based epoch index, used for logging and checkpoints.
        cfg: config dict; reads 'lamb', 'epochs' and (via save_checkpoint) 'save_path'.

    Writes 'best_valid_acc.pth' / 'best_valid_rec.pth' when accuracy / recall
    beat the module-level ``best_valid`` record, then updates that record.
    Relies on the module-level ``device`` and ``writer``; also logs to wandb.
    """
    losses = {
        'softmax': AverageMeter(),
        'center': AverageMeter(),
        'total': AverageMeter()
    }
    accs = AverageMeter()
    y_pred, y_true, y_scores = [], [], []

    # switch to evaluate mode
    model.eval()

    # len(loader) is the exact number of batches (including a final partial one).
    with tqdm(total=len(valid_loader)) as pbar:
        with torch.no_grad():
            for i, (images, target) in enumerate(valid_loader):

                images = images.to(device)
                target = target.to(device)
                # compute output
                feat, output, A = model(images)
                l_softmax = criterion['softmax'](output, target)
                l_center = criterion['center'](feat, A, target)
                l_total = l_softmax + cfg['lamb'] * l_center

                # measure accuracy and record loss
                acc, pred = accuracy(output, target)
                losses['softmax'].update(l_softmax.item(), images.size(0))
                losses['center'].update(l_center.item(), images.size(0))
                losses['total'].update(l_total.item(), images.size(0))
                accs.update(acc.item(), images.size(0))

                # collect for metrics
                y_pred.append(pred)
                y_true.append(target)
                y_scores.append(output.data)

                # progressbar
                pbar.set_description(f'VALIDATING [{epoch:03d}/{cfg["epochs"]}]')
                pbar.update(1)

    metrics = calc_metrics(y_pred, y_true, y_scores)
    progress = (
        f'[-] VALID [{epoch:03d}/{cfg["epochs"]}] | '
        f'L={losses["total"].avg:.4f} | '
        f'Ls={losses["softmax"].avg:.4f} | '
        f'Lsc={losses["center"].avg:.4f} | '
        f'acc={accs.avg:.4f} | '
        f'rec={metrics["rec"]:.4f} | '
        f'f1={metrics["f1"]:.4f} | '
        f'aucpr={metrics["aucpr"]:.4f} | '
        f'aucroc={metrics["aucroc"]:.4f}'
    )
    print(progress)

    # save model checkpoints for best valid
    if accs.avg > best_valid['acc']:
        save_checkpoint(epoch, model, cfg, tag='best_valid_acc.pth')
    if metrics['rec'] > best_valid['rec']:
        save_checkpoint(epoch, model, cfg, tag='best_valid_rec.pth')

    best_valid['acc'] = max(best_valid['acc'], accs.avg)
    best_valid['rec'] = max(best_valid['rec'], metrics['rec'])
    best_valid['f1'] = max(best_valid['f1'], metrics['f1'])
    best_valid['aucpr'] = max(best_valid['aucpr'], metrics['aucpr'])
    best_valid['aucroc'] = max(best_valid['aucroc'], metrics['aucroc'])
    write_log(losses, accs.avg, metrics, epoch, tag='valid')
    # One wandb.log call per epoch: every call advances wandb's step, so
    # separate calls would scatter one epoch's metrics across several steps.
    wandb.log({"epoch": epoch,
               "vL": losses["total"].avg,
               "vLs": losses["softmax"].avg,
               "vLsc": losses["center"].avg,
               "vacc": accs.avg,
               "vrec": metrics["rec"],
               "vf1": metrics["f1"]})



def write_log(losses, acc, metrics, e, tag='set'):
    """Write one epoch's losses and metrics to the module-level TensorBoard ``writer``.

    Each value is logged as the scalar ``<name>/<tag>`` at global step *e*.
    *losses* maps 'softmax'/'center'/'total' to meters exposing ``.avg``;
    *metrics* must contain 'rec', 'f1', 'aucpr' and 'aucroc'.
    """
    scalars = (
        ('L_softmax', losses['softmax'].avg),
        ('L_center', losses['center'].avg),
        ('L_total', losses['total'].avg),
        ('acc', acc),
        ('rec', metrics['rec']),
        ('f1', metrics['f1']),
        ('aucpr', metrics['aucpr']),
        ('aucroc', metrics['aucroc']),
    )
    for name, value in scalars:
        writer.add_scalar(f'{name}/{tag}', value, e)


def save_checkpoint(epoch, model, cfg, tag):
    """Save ``{'epoch', 'model_state_dict'}`` to ``<cfg['save_path']>/<tag>``."""
    payload = dict(epoch=epoch, model_state_dict=model.state_dict())
    destination = os.path.join(cfg['save_path'], tag)
    torch.save(payload, destination)

In [None]:
#@title DACL experiment configuration. Use cfg['model'] = model to set your model (out_features=7). Don't forget to adjust the input resizing!

# Experiment configuration consumed by train_DACL / train / validate.
cfg = {
    # model
    'model': 'resnet18',  # your model here (an nn.Module) or 'resnet18' for the original
    # global config
    'log_dir': "logs",
    'save_path': '',  # filled in below from log_dir + the run name
    'deterministic': False,
    'seed': 0,
    'workers': 4,
    # training config
    'arch': 'resnet18',
    'lr': 0.01,
    'momentum': 0.9,
    'weight_decay': 0.0005,
    'batch_size': 64,
    'epochs': 60,
    'pretrained': "msceleb",
    # centerloss config
    'alpha': 0.5,   # lr of the separate center-loss SGD optimizer in train_DACL
    'lamb': 0.01    # weight of the center loss in the total loss
}

current_time = datetime.datetime.now().strftime('%b%d_%H-%M-%S')
logname = current_time + '_' + socket.gethostname() + (
    '_a-centerloss'
    f'_ar={cfg["arch"]}'
    f'_pt={cfg["pretrained"]}'
    f'_bs={cfg["batch_size"]}'
    f'_lr={cfg["lr"]}'
    f'_wd={cfg["weight_decay"]}'
    f'_alpha={cfg["alpha"]}'
    f'_lamb={cfg["lamb"]}'
)
cfg['save_path'] = os.path.join(cfg['log_dir'], logname)

# Create the run directory; exist_ok so re-running the cell within the same
# second does not crash.
os.makedirs(cfg['save_path'], exist_ok=True)

# RAF-DB channel statistics, kept for reference (not used):
#   mean=[0.5752, 0.4495, 0.4012], std=[0.2086, 0.1911, 0.1827]
# The ImageNet statistics below are the ones actually applied.
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
train_set = RafDataset(phase='train',
                       transform=transforms.Compose([
                           transforms.ToPILImage(),
                           transforms.Resize(256),
                           RandomFiveCrop(224),
                           transforms.RandomHorizontalFlip(),
                           transforms.ToTensor(),
                           normalize,
                       ]))

# NOTE(review): the DAN loop uses phase='test' for evaluation; confirm that
# RafDataset treats 'val' as the held-out split.
val_set = RafDataset(phase='val',
                     transform=transforms.Compose([
                         transforms.ToPILImage(),
                         transforms.Resize(256),
                         transforms.CenterCrop(224),
                         transforms.ToTensor(),
                         normalize,
                     ]))

# write_log() reads this module-level name.
writer = SummaryWriter(cfg['log_dir'])
# Route stdout through dacl's Logger (presumably tees output into log.log).
sys.stdout = Logger(os.path.join(cfg['log_dir'], 'log.log'))

Run ResNet-18 training

In [None]:
# Either run the original DACL implementation as a script:
# !python dacl/main.py

# or train through train_DACL with the original ResNet-18 backbone:
cfg['model'] = 'resnet18' # set as default

In [None]:
# from dacl.models import resnet18
# a = resnet18(pretrained='msceleb').cpu().eval()
# a.forward(torch.rand((1,3,224,224)))

In [None]:
# @title DACL MobileFaceNet implementation 

class MobileFaceNet_DACL(nn.Module):
    """MobileFaceNet backbone with DACL's attention head for expression recognition.

    ``forward(x)`` returns ``(f, out, A)``, matching the
    ``feat, output, A = model(images)`` convention of the DACL loops:

    * ``f``: ``(B, C)`` globally average-pooled backbone features (C = 512
      for the default head, as consumed by ``self.fc``).
    * ``out``: ``(B, num_classes)`` class logits.
    * ``A``: ``(B, nb_head)`` attention weights in (0, 1).

    The first attention layer is ``Linear(512 * 7 * 7, ...)``, so the truncated
    backbone must emit a ``(B, 512, 7, 7)`` feature map. This is presumably what
    MobileFaceNet produces for the 112x112 inputs used when this model is
    exported; confirm before changing the input size.
    """

    def __init__(self, num_classes=7):
        super(MobileFaceNet_DACL, self).__init__()
        self.bbone = MobileFaceNet(*bbones_cfgs['MobileFaceNet'].values())
        self.bbone.load_state_dict(get_backbone_dict(backbone_pathes['MobileFaceNet']))
        # Keep all but the last four children (presumably the embedding head).
        # NOTE: self.bbone stays registered, so the shared backbone parameters
        # appear under both 'bbone.*' and 'feature_extractor.*' in state_dict;
        # left as-is for checkpoint compatibility.
        self.feature_extractor = nn.Sequential(*list(self.bbone.children())[:-4])
        # Number of attention heads, one per pooled feature dimension.
        self.nb_head = 512

        self.attention = nn.Sequential(
            nn.Linear(512 * 7 * 7, 3584),
            nn.BatchNorm1d(3584),
            nn.ReLU(inplace=True),
            nn.Linear(3584, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 64),
            nn.BatchNorm1d(64),
            nn.Tanh(),
        )
        # Two logits per head; see forward().
        self.attention_heads = nn.Linear(64, 2 * self.nb_head)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512, num_classes)

    def forward(self, x):
        x = self.feature_extractor(x)

        x_flat = torch.flatten(x, 1)
        E = self.attention(x_flat)
        # Softmax over each head's pair of logits and keep the second
        # probability as that head's attention weight.
        A = self.attention_heads(E).reshape(-1, self.nb_head, 2).softmax(dim=-1)[:, :, 1]

        x = self.avgpool(x)
        f = torch.flatten(x, 1)
        out = self.fc(f)

        return f, out, A


MobileFaceNet model for DACL training

In [None]:
# Train the MobileFaceNet-based DACL model instead of the original ResNet-18.
cfg['model'] = MobileFaceNet_DACL()

In [None]:
import argparse
import random
import pprint
import time
import sys
import os
from tqdm import tqdm
from datetime import timedelta

from torchvision.transforms.transforms import ToPILImage

import torch
import torch.optim
import torch.nn as nn
import torch.utils.data
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Log in to Weights & Biases and open a run for this training session.
wandb.login(key=wandb_api_key)
run = wandb.init(settings=wandb.Settings(start_method="thread"),
                    entity='mangaboba', project='FER')
# NOTE(review): the run name is hard-coded and does not reflect cfg['model']
# (e.g. it is still 'OriginalModel' when training MobileFaceNet_DACL).
run.name = 'OriginalModel'
# -----------------
start = time.time()
train_DACL(cfg, 512, train_set, val_set)
end = time.time()
# -----------------
# NOTE(review): if train_DACL raises, run.finish() is never reached.
run.finish()

print('\n[*] Fini! '.ljust(64, '-'))
# timedelta already renders as H:MM:SS, so the trailing 's' is redundant.
print(f'[!] total time = {timedelta(seconds=end - start)}s')
sys.stdout.flush()

In [None]:
# load DACL original model weights example
testmodel = resnet18(pretrained=cfg['pretrained'])
testmodel.fc = nn.Linear(512, 7)
d = torch.load('/content/logs/Jun21_07-41-22_9f2bbcb9e53b_a-centerloss_ar=resnet18_pt=msceleb_bs=64_lr=0.01_wd=0.0005_alpha=0.5_lamb=0.01/best_valid_acc.pth')['model_state_dict']
# The checkpoint was saved from a DataParallel wrapper, so keys carry a
# 'module.' prefix; keep those keys and strip the prefix.
d = {key: val for key, val in d.items() if key.startswith('module')}
valid_dict = {key.split("module.",1)[1]: val for key, val in d.items()}
testmodel.load_state_dict(valid_dict)
originalDACL = testmodel

<All keys matched successfully>

Make jit models

In [None]:
import torch
import torchvision
# model = torchvision.models.resnet34(pretrained = True)
# model.eval()
def make_jit_models(model, file_prefix, inp_size, cuda_device='cuda:0'):
    """Export *model* as TorchScript traces for GPU and CPU.

    Writes ``<file_prefix>_cuda.pth`` (traced on *cuda_device*) and
    ``<file_prefix>_cpu.pth`` (traced on CPU), using a random
    ``(1, 3, inp_size, inp_size)`` example input. The model is put in eval
    mode first (required for BatchNorm layers with a batch of one) and is
    left on the CPU afterwards.

    If CUDA is not available, the GPU export is skipped with a warning
    instead of failing, and the CPU export is still produced.
    """
    sample = torch.rand(1, 3, inp_size, inp_size)
    model.eval()

    if torch.cuda.is_available():
        model.to(cuda_device)
        traced = torch.jit.trace(model, sample.to(cuda_device))
        traced.save(file_prefix + '_cuda.pth')
    else:
        warnings.warn(f'CUDA not available; skipping {file_prefix}_cuda.pth export')

    model.to('cpu')
    traced = torch.jit.trace(model, sample)
    traced.save(file_prefix + '_cpu.pth')

In [None]:
# Export the trained cfg['model'] (MobileFaceNet_DACL, 112x112 input). The
# 'MobibeDACL' prefix is misspelled but matches the path used in the validation cell.
make_jit_models(model = cfg['model'], file_prefix = 'MobibeDACL', inp_size = 112)

In [None]:
# Export the original ResNet-18 DACL model loaded above (224x224 input).
make_jit_models(model = originalDACL, file_prefix='originalDACL', inp_size=224)

Models Validation

In [None]:
#@title Validation util for jit_models 
class GrayscaleExpander:
    """Transform that repeats a grayscale image tensor to three channels.

    Lets grayscale images be evaluated by models that expect 3-channel input.
    Gets: (1, N, N) grayscale tensor.
    Returns: (3, N, N) tensor whose channels are identical copies of the input.
    """

    def __call__(self, image: torch.Tensor) -> torch.Tensor:
        # Concatenate three copies along the channel axis (dim 0).
        return torch.cat([image, image, image], dim=0)

class validator:
    """Evaluate exported TorchScript (jit) models on the RAF-DB test split.

    ``validate(model_pth, model_class, inp_size)`` prints loss, accuracy,
    recall, F1 and mean single-image inference time for one jit file. It
    also saves a row-normalised confusion matrix as ``<model_pth>.png``.

    ``model_class`` selects both the preprocessing and the expected output:
    'DACL' models return ``(feat, logits, A)``, 'RUL' models return logits.
    Requires CUDA; the constructor raises RuntimeError otherwise.
    """

    def __init__(self, batch_size = 64):
        self.batch_size = batch_size
        self.device = 'cuda:0'
        if not torch.cuda.is_available():
            raise RuntimeError('Cuda not available.')
        # Index i names label i (presumably RAF-DB's label order; confirm
        # against RafDataset). Also used as confusion-matrix axis labels.
        self.class_names = ['Surprise',
                            'Fear',
                            'Disgust',
                            'Happiness',
                            'Sadness',
                            'Anger',
                            'Neutral']

    def computeTime(self, model, inp_size, device='cuda'):
        """Return the mean wall-clock seconds of one forward pass.

        Uses a random ``(1, 3, inp_size, inp_size)`` input, runs 300 passes
        and discards the first one as warm-up.
        """
        inputs = torch.randn(1, 3, inp_size, inp_size)
        if device == 'cuda':
            inputs = inputs.cuda()

        time.sleep(1)

        time_spent = []
        for i in range(300):
            start_time = time.time()
            with torch.no_grad():
                _ = model(inputs)

            if device == 'cuda':
                torch.cuda.synchronize()  # wait for cuda to finish (cuda is asynchronous!)
            if i != 0:  # first pass is treated as warm-up
                time_spent.append(time.time() - start_time)
        return np.mean(time_spent)

    def plot_cf_matrix(self, y_true, y_pred, filename = "conf_mtx"):
        """Save a row-normalised confusion-matrix heatmap to ``<filename>.png``.

        *y_true* / *y_pred* are lists of per-batch label tensors. Each row
        shows how the samples of one actual class were predicted.
        """
        # Fixing `labels` keeps the matrix 7x7 even if a class never appears
        # in y_true or y_pred; otherwise the DataFrame shape would not match.
        labels = list(range(len(self.class_names)))
        cfm = sklearn.metrics.confusion_matrix(torch.cat(y_true).cpu().numpy(),
                                               torch.cat(y_pred).cpu().numpy(),
                                               labels=labels)
        df_cfm = pd.DataFrame(cfm, index = self.class_names,
                              columns = self.class_names)

        # Row-normalise (per actual class). DataFrame.div is used because
        # Series[:, np.newaxis] indexing is rejected by modern pandas.
        df_cfmn = df_cfm.astype('float').div(df_cfm.sum(axis=1), axis=0)

        fig, ax = plt.subplots(figsize=(15, 15))
        sn.heatmap(df_cfmn, annot=True, fmt='.2f',
                   xticklabels=self.class_names,
                   yticklabels=self.class_names,
                   ax=ax)
        ax.set_ylabel('Actual')
        ax.set_xlabel('Predicted')
        fig.savefig(filename + ".png")

    def make_dataloaders(self, inp_size, model_class):
        """Build the evaluation DataLoader(s) with *model_class*'s preprocessing.

        DACL: resize the shorter side to ``inp_size * 256 / 224``, then
        centre-crop to *inp_size*. RUL: resize directly to
        ``(inp_size, inp_size)``. Both apply ImageNet normalisation.

        Returns a list of DataLoaders (currently only the RAF-DB test split).
        Raises ValueError for an unknown *model_class*.
        """
        if model_class == 'DACL':
            transform_raf=transforms.Compose([
                          transforms.ToPILImage(),
                          transforms.Resize(int(inp_size*(256/224))),
                          transforms.CenterCrop(inp_size),
                          transforms.ToTensor(),
                          transforms.Normalize(mean=[0.485, 0.456, 0.406],
                              std=[0.229, 0.224, 0.225]),
                      ])
            # transform_fer = transforms.Compose([
            #                 transforms.Resize(int(inp_size*(256/224))),
            #                 transforms.Resize(inp_size),
            #                 transforms.ToTensor(),
            #                 GrayscaleExpander()]
            #                 )

        elif model_class == 'RUL':
            transform_raf = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((inp_size,inp_size)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                        std=[0.229, 0.224, 0.225])])
        else:
            raise ValueError(f'Unknown model_class: {model_class!r}')

        dataset_1 = RafDataset(phase = 'test', transform=transform_raf)
        # Evaluation does not need shuffling; a fixed order keeps runs reproducible.
        data_loader_1 = torch.utils.data.DataLoader(
                dataset_1,
                batch_size=self.batch_size, shuffle=False)

        return [data_loader_1]

    def validate(self, model_pth, model_class, inp_size):
        """Evaluate the jit model at *model_pth* on every loader from make_dataloaders."""
        dataloaders = self.make_dataloaders(inp_size, model_class)
        for dl in dataloaders:
            if model_class == 'DACL':
                self.validate_DACL(dl, model_class, model_pth, inp_size)
            elif model_class == 'RUL':
                self.validate_RUL(dl, model_class, model_pth, inp_size)

    def validate_DACL(self, dl, model_class, model_pth, inp_size):
        """Evaluate a DACL jit model (forward returns ``(feat, logits, A)``) on *dl*.

        Prints losses, accuracy, recall, F1 and mean inference time, and saves
        the confusion matrix as ``<model_pth>.png``.
        """
        device = self.device
        model = torch.jit.load(model_pth)
        accs = AverageMeter()
        y_pred, y_true, y_scores = [], [], []

        # switch to evaluate mode
        model.eval()
        losses = {
              'softmax': AverageMeter(),
              'center': AverageMeter(),
              'total': AverageMeter()
              }
        # NOTE(review): this SparseCenterLoss is freshly constructed, not the
        # one trained with the model (its centers are not in the jit file), so
        # Lsc and L are presumably not comparable to the training logs.
        criterion = {
            'softmax': nn.CrossEntropyLoss().to(device),
            'center': SparseCenterLoss(7, 512).to(device)
            }
        with tqdm(total=len(dl)) as pbar:
            with torch.no_grad():
                for images, target in dl:

                    images = images.to(device)
                    target = target.to(device)

                    # compute output
                    feat, output, A = model(images)
                    l_softmax = criterion['softmax'](output, target)
                    l_center = criterion['center'](feat, A, target)
                    # 0.01 mirrors the default cfg['lamb'] used in training.
                    l_total = l_softmax + (0.01 * l_center)
                    # measure accuracy and record loss
                    acc, pred = accuracy(output, target)
                    losses['softmax'].update(l_softmax.item(), images.size(0))
                    losses['center'].update(l_center.item(), images.size(0))
                    losses['total'].update(l_total.item(), images.size(0))
                    accs.update(acc.item(), images.size(0))
                    # collect for metrics
                    y_pred.append(pred)
                    y_true.append(target)
                    y_scores.append(output.data)
                    pbar.update(1)

        metrics = calc_metrics(y_pred, y_true, y_scores)
        avg_exec_time = self.computeTime(model, inp_size)
        progress = (
            f'{model_class}-{model_pth}: | '
            f'L={losses["total"].avg:.4f} | '
            f'Ls={losses["softmax"].avg:.4f} | '
            f'Lsc={losses["center"].avg:.4f} | '
            f'acc={accs.avg:.4f} | '
            f'rec={metrics["rec"]:.4f} | '
            f'f1={metrics["f1"]:.4f} | '
            f'Evaluation time={avg_exec_time:4f} |'
        )
        self.plot_cf_matrix(y_true, y_pred, model_pth)
        print(progress)

    def validate_RUL(self, dl, model_class, model_pth, inp_size):
        """Evaluate a RUL jit model (forward returns logits) on *dl*.

        Prints cross-entropy loss, accuracy, recall, F1 and mean inference
        time, and saves the confusion matrix as ``<model_pth>.png``.
        """
        device = self.device
        model = torch.jit.load(model_pth)
        accs = AverageMeter()
        y_pred, y_true, y_scores = [], [], []

        # switch to evaluate mode
        model.eval()
        losses = {
              'softmax': AverageMeter(),
              }
        criterion = {
            'softmax': nn.CrossEntropyLoss().to(device)
            }
        with tqdm(total=len(dl)) as pbar:
            pbar.set_description("Processing ")
            with torch.no_grad():
                for images, target in dl:

                    images = images.to(device)
                    target = target.to(device)

                    # compute output
                    output = model(images)
                    l_softmax = criterion['softmax'](output, target)
                    # measure accuracy and record loss
                    acc, pred = accuracy(output, target)
                    losses['softmax'].update(l_softmax.item(), images.size(0))
                    accs.update(acc.item(), images.size(0))
                    # collect for metrics
                    y_pred.append(pred)
                    y_true.append(target)
                    y_scores.append(output.data)
                    pbar.update(1)
        metrics = calc_metrics(y_pred, y_true, y_scores)
        avg_exec_time = self.computeTime(model, inp_size)
        progress = (
            f'{model_class}-{model_pth}: | '
            f'Ls={losses["softmax"].avg:.4f} | '
            f'acc={accs.avg:.4f} | '
            f'rec={metrics["rec"]:.4f} | '
            f'f1={metrics["f1"]:.4f} | '
            f'Evaluation time={avg_exec_time:4f} |'
        )
        self.plot_cf_matrix(y_true, y_pred, model_pth)
        print(progress)



In [None]:
# Each entry is (jit model path, model_class, input size) for validator.validate.
jit_location = '/content/jit_models'
toval = [( jit_location + '/ResnetRUL_cuda.pth','RUL', 224),
         ( jit_location + '/GhostRUL_cuda.pth','RUL', 112),
         ( jit_location + '/MobileRUL_cuda.pth','RUL', 112),
         ( jit_location + '/originalDACL_cuda.pth', 'DACL', 224),
         ( jit_location + '/MobibeDACL_cuda.pth', 'DACL', 112),

]
val = validator()
for m in toval:
  val.validate(*m)

Further experiments

In [None]:
#@title Some code for knowledge distillation experiments
# !pip install KD_lib

# import torch
# import torch.optim as optim
# from torchvision import datasets, transforms
# from KD_Lib.KD import VanillaKD

# # Define datasets, dataloaders, models and optimizers
# train_dataset = RafDataset(phase='train', transform=test_tf)
# test_dataset = RafDataset(phase='test', transform=test_tf)
# teacher_loader = torch.utils.data.DataLoader(train_dataset,
#                               # sampler=ImbalancedDatasetSampler(student_dataset),
#                                                batch_size = 128,
#                                                num_workers = 2,
#                                                pin_memory = True)
# student_loader = torch.utils.data.DataLoader(train_dataset,
#                               # sampler=ImbalancedDatasetSampler(student_dataset),
#                                                batch_size = 128,
#                                                num_workers = 2,
#                                                pin_memory = True)


# teacher_model = # Here should be the ResNet-18 DACL model
# student_model = # Here should be the ResNet-9 DACL model

# teacher_optimizer = optim.SGD(teacher_model.parameters(), 0.01)
# student_optimizer = optim.SGD(student_model.parameters(), 0.01)

# # Now, this is where KD_Lib comes into the picture

# distiller = VanillaKD(teacher_model, student_model, teacher_loader, student_loader,
#                       teacher_optimizer, student_optimizer)
# # distiller.train_teacher(epochs=5, plot_losses=True, save_model=True)    # Train the teacher network
# distiller.train_student(epochs=5, plot_losses=True, save_model=True)    # Train the student network
# distiller.evaluate(teacher=False)                                       # Evaluate the student network
# distiller.get_parameters()       

In [None]:
#@title Train Swin Transformer
def _swin_train_transforms():
    """Return the training-time augmentation pipeline for 224x224 RAF-DB images.

    Each of the three RandomApply groups fires independently with p=0.3.
    Images are normalized with ImageNet mean/std and then randomly erased.
    """
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomApply([
            transforms.RandomRotation(20),
            transforms.RandomCrop(224, padding=32),
        ], p=0.3),
        transforms.RandomApply([
            transforms.ColorJitter(brightness=0.05, contrast=0.05,
                                   saturation=0.05, hue=0.05),
        ], p=0.3),
        transforms.RandomApply([
            transforms.RandomAdjustSharpness(sharpness_factor=2),
        ], p=0.3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
        # RandomErasing works on tensors, so it has to come after ToTensor.
        transforms.RandomErasing(scale=(0.02, 0.125)),
    ])


def _swin_val_transforms():
    """Return the deterministic evaluation pipeline (resize + ImageNet normalization)."""
    return transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])


def _swin_train_one_epoch(model, loader, criterion, optimizer, device):
    """Train ``model`` for one full pass over ``loader``.

    Returns:
        (mean per-batch loss as a float, number of correctly classified samples)
    """
    model.train()
    running_loss = 0.0
    correct = 0
    n_batches = 0
    for imgs, targets in loader:
        n_batches += 1
        optimizer.zero_grad()
        imgs = imgs.to(device)
        targets = targets.to(device)

        out = model(imgs)
        loss = criterion(out, targets)
        loss.backward()
        optimizer.step()

        # Accumulate a Python float. Summing the loss tensor itself would keep
        # every batch's autograd graph chained together for the whole epoch.
        running_loss += loss.item()
        _, predicts = torch.max(out, 1)
        correct += torch.eq(predicts, targets).sum().item()
    return running_loss / max(n_batches, 1), correct


def _swin_evaluate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` with gradient tracking disabled.

    Returns:
        (mean per-batch loss, accuracy, balanced accuracy). Both accuracies
        are rounded to 4 decimals with ``np.around``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.eval()
    running_loss = 0.0
    n_batches = 0
    bingo_cnt = 0
    sample_cnt = 0
    y_true, y_pred = [], []
    with torch.no_grad():
        for imgs, targets in loader:
            imgs = imgs.to(device)
            targets = targets.to(device)
            out = model(imgs)
            running_loss += criterion(out, targets).item()
            n_batches += 1
            _, predicts = torch.max(out, 1)
            bingo_cnt += torch.eq(predicts, targets).sum().cpu()
            sample_cnt += out.size(0)
            y_true.append(targets.cpu().numpy())
            y_pred.append(predicts.cpu().numpy())

    if sample_cnt == 0:
        raise ValueError("validation loader produced no samples")

    # Float32 tensor division, then rounding to 4 decimals. This rounded value
    # is compared against the save threshold and embedded in checkpoint names.
    acc = np.around((bingo_cnt.float() / float(sample_cnt)).numpy(), 4)
    balanced_acc = np.around(
        balanced_accuracy_score(np.concatenate(y_true), np.concatenate(y_pred)), 4)
    return running_loss / n_batches, acc, balanced_acc


def train_swin(swin, batch_size=256,
               workers=2, lr=0.1, epochs=40,
               checkpoint_dir='/home/mangaboba/environ/fer/checkpoints',
               save_threshold=0.86):
    """Fine-tune ``swin`` on RAF-DB with SGD + StepLR, checkpointing the best epochs.

    Args:
        swin: model to train. It is moved to GPU when one is available and
            trained in place.
        batch_size: batch size for both the train and validation loaders.
        workers: number of DataLoader worker processes.
        lr: initial SGD learning rate. StepLR multiplies it by 0.1 every 10 epochs.
        epochs: number of training epochs.
        checkpoint_dir: directory that receives checkpoints; created if missing.
        save_threshold: a checkpoint is written only when the rounded
            validation accuracy exceeds this value and equals the best so far.

    Returns:
        The best rounded validation accuracy observed.
    """
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    if torch.cuda.is_available():
        # NOTE(review): benchmark=True lets cuDNN choose algorithms by timing,
        # while deterministic=True restricts it to deterministic algorithms.
        # These goals pull in opposite directions; confirm which is intended.
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.enabled = True

    model = swin
    model.to(device)

    train_dataset = RafDataset(phase='train', transform=_swin_train_transforms())
    print('Whole train set size:', len(train_dataset))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               shuffle=True,
                                               batch_size=batch_size,
                                               num_workers=workers,
                                               pin_memory=True)

    val_dataset = RafDataset(phase='test', transform=_swin_val_transforms())
    print('Validation set size:', len(val_dataset))
    val_loader = torch.utils.data.DataLoader(val_dataset,
                                             batch_size=batch_size,
                                             num_workers=workers,
                                             shuffle=False,
                                             pin_memory=True)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr,
                                weight_decay=1e-4, momentum=0.9)
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

    best_acc = 0
    for epoch in tqdm(range(1, epochs + 1)):
        train_loss, train_correct = _swin_train_one_epoch(
            model, train_loader, criterion, optimizer, device)
        train_acc = train_correct / float(len(train_dataset))
        # Logged before scheduler.step(), so the LR shown is the one used this epoch.
        tqdm.write('[Epoch %d] Training accuracy: %.4f. Loss: %.3f. LR %.6f'
                   % (epoch, train_acc, train_loss, optimizer.param_groups[0]['lr']))

        val_loss, acc, balanced_acc = _swin_evaluate(model, val_loader, criterion, device)
        scheduler.step()
        best_acc = max(acc, best_acc)

        tqdm.write("[Epoch %d] Validation accuracy:%.4f. bacc:%.4f. Loss:%.3f"
                   % (epoch, acc, balanced_acc, val_loss))
        tqdm.write("best_acc:" + str(best_acc))

        if acc > save_threshold and acc == best_acc:
            # torch.save does not create missing parent directories.
            os.makedirs(checkpoint_dir, exist_ok=True)
            # Explicit str() keeps numpy's short scalar repr in the file name;
            # an f-string would format the value via Python float instead.
            ckpt_name = ("rafdb_epoch" + '_swin' + str(epoch) + "_acc" + str(acc)
                         + "_bacc" + str(balanced_acc) + ".pth")
            torch.save({'iter': epoch,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),},
                       os.path.join(checkpoint_dir, ckpt_name))
            tqdm.write('Model saved.')

    return best_acc