In [None]:
import numpy as np
import os, json, time, math
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torch.utils.data as data
import torchvision.models as models
import torchvision.transforms as transforms
from einops import rearrange

In [None]:
# global configuration: dataset split, input geometry and model hyper-parameters
training_size = 28317 + 3541    # index at which the sample array is split into train / test
image_size = 48                 # side length of the (square) network input
H = W = image_size              # image height, image width
h = w = 3                       # number of grid rows, number of grid columns
C = 3                           # number of channels
D = 128                         # token dimension
K = 8                           # number of output classes
L = 12                          # VTA depth
N = 3                           # number of tokens generated except Tcls
heads = 8                       # heads of MHSA
dim_head = 64                   # dimension of the projected key vector
dropout = 0.2                   # dropout rate

In [None]:
# device setup

# The GPU restriction has to be in place *before* the first CUDA query: the CUDA
# runtime reads CUDA_VISIBLE_DEVICES once, when it is initialised, so assigning it
# after torch.cuda.is_available() has no effect. setdefault keeps a value that was
# already exported by the user or a job scheduler.
os.environ.setdefault('CUDA_VISIBLE_DEVICES', '0,1')
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if torch.cuda.device_count() > 0:
    print("Using %d GPU(s)" % torch.cuda.device_count())
else:
    print("Using CPU")
# Let cuDNN auto-tune its convolution algorithms.
cudnn.benchmark = True
cudnn.enabled = True

In [None]:
# data loading & data processing

def load_data(image_path, emotion_path, subset, train_size=None):
    """Load the image / label arrays from disk and return one split of them.

    Args:
        image_path: path of a ``.npy`` file holding the images; pixel values are
            divided by 255.0 (presumably 8-bit intensities, e.g. shape
            (35393, 48, 48, 1) -- confirm against the dataset).
        emotion_path: path of a ``.npy`` file holding one label vector per image
            (e.g. shape (35393, 8)).
        subset: ``'train'`` selects the first ``train_size`` samples,
            ``'test'`` selects the remaining ones.
        train_size: number of leading samples that form the training split.
            Defaults to the notebook-level ``training_size`` constant.

    Returns:
        ``(images, emotions)`` for the requested split, both as float32 arrays.

    Raises:
        ValueError: if ``subset`` is not ``'train'`` / ``'test'`` or if the two
            files do not contain the same number of samples.
    """
    # Fail fast, before the (large) arrays are read, instead of silently returning None.
    if subset not in ('train', 'test'):
        raise ValueError("subset must be 'train' or 'test', got {!r}".format(subset))
    if train_size is None:
        train_size = training_size

    images = np.load(image_path)
    images = images/255.0
    images = np.float32(images)

    emotions = np.load(emotion_path)
    emotions = np.float32(emotions)

    if len(images) != len(emotions):
        raise ValueError(
            "images and emotions differ in length: {} vs {}".format(len(images), len(emotions))
        )

    if subset == 'train':
        return images[:train_size], emotions[:train_size]
    return images[train_size:], emotions[train_size:]

class FERPlusDataset(data.Dataset):
    """Map-style dataset over one split ('train' or 'test') of the arrays returned by load_data.

    Every item is a pair ``(image, emotion)``: the image is a tensor of shape
    (C, H, W) obtained from a stored (48, 48, 1) sample, the emotion is the
    matching label vector exactly as load_data returned it.
    """

    def __init__(self, image_path, emotion_path, subset):
        assert subset in ('train', 'test')
        arrays = load_data(image_path, emotion_path, subset)
        self.images, self.emotions = arrays

    def __len__(self):
        return len(self.images)

    def __getitem__(self, index):
        image = self.resize(self.images[index])
        return image, self.emotions[index]

    def resize(self, img):
        """Convert one stored sample into the (C, H, W) tensor the network expects."""
        tensor = torch.tensor(img).reshape(1, 48, 48)       # (48, 48, 1) -> (1, 48, 48)
        if image_size != 48:
            tensor = transforms.Resize([H, W])(tensor)      # (1, H, W)
        # replicate the single channel C times
        return tensor.repeat(C, 1, 1)                       # (C, H, W)

In [None]:
# Visual Transformer architecture

class PreNorm(nn.Module):
    """Wrapper that applies LayerNorm over the last `dim` features and then calls `fn`."""

    def __init__(self, dim, fn):
        super(PreNorm, self).__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        # extra keyword arguments are handed to the wrapped callable untouched
        normalized = self.norm(x)
        return self.fn(normalized, **kwargs)

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over a (batch, tokens, dim) input.

    Queries, keys and values come from a single bias-free linear map. The
    concatenated heads are projected back to `dim`, except in the one case where
    a single head already has width `dim` (then the output is passed through as is).
    """

    def __init__(self, dim, heads, dim_head):
        super(Attention, self).__init__()
        inner_dim = heads * dim_head
        self.heads = heads
        self.scale = dim_head**-0.5
        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, 3 * inner_dim, bias=False)
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
            )

    def forward(self, x):
        # split the joint projection into q / k / v and move the head axis forward
        q, k, v = (
            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        )
        weights = self.attend(torch.matmul(q, k.transpose(-1, -2)) * self.scale)
        mixed = torch.matmul(weights, v)
        # merge the heads back into the feature axis
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class FeedForward(nn.Module):
    """Position-wise MLP: Linear(dim -> mlp_dim), GELU, Linear(mlp_dim -> dim)."""

    def __init__(self, dim, mlp_dim):
        super(FeedForward, self).__init__()
        expand = nn.Linear(dim, mlp_dim)
        project = nn.Linear(mlp_dim, dim)
        self.net = nn.Sequential(expand, nn.GELU(), project)

    def forward(self, x):
        hidden = self.net(x)
        return hidden
    
class Transformer(nn.Module):
    """Stack of `depth` pre-norm encoder layers.

    Each layer is an attention sub-block followed by a feed-forward sub-block;
    both are wrapped in PreNorm and added back onto their input (residual).
    """

    def __init__(self, dim, depth, heads, dim_head, mlp_dim):
        super(Transformer, self).__init__()
        self.layers = nn.ModuleList([
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads, dim_head)),
                PreNorm(dim, FeedForward(dim, mlp_dim)),
            ])
            for _ in range(depth)
        ])

    def forward(self, x):
        for attention_block, feedforward_block in self.layers:
            x = x + attention_block(x)      # residual around attention
            x = x + feedforward_block(x)    # residual around the MLP
        return x

In [None]:
# FER-VT architecture

class FER_VT(nn.Module):
    """FER-VT: grid-wise attention -> ResNet-18 backbone -> token-based visual transformer.

    The class reads the notebook-level constants H, W, h, w, C, D, K, L, N, heads,
    dim_head and dropout. ``forward`` maps a batch of images of shape (B, C, H, W)
    to class scores of shape (B, K); the head ends in a softmax, so every row
    sums to 1.

    Review notes on behaviour that differs from the previous revision:
      * Backbone features are taken by running the ResNet stages explicitly. The
        former forward hooks stored ``output.detach()``, which cut the autograd
        graph (backbone and grid-attention modules never received a gradient),
        and all nn.DataParallel replicas wrote into one shared dict.
      * Transformer / head Linear layers use Xavier initialisation. With all-zero
        weights the attention and MLP weights receive an exactly-zero gradient
        and can never leave zero.
      * The grid attention softmax runs over the grid axis (dim 0) rather than the
        batch axis (dim 1).
      * ``expander`` is a (non-persistent) buffer and no tensor is pinned to a
        global device, so the module follows ``.to()`` and is replicated correctly.
        The state_dict layout is unchanged.
    """

    def __init__(self):
        super(FER_VT, self).__init__()
        if H % h != 0 or W % w != 0:
            raise ValueError(
                "image size ({}, {}) must be divisible by the grid ({}, {})".format(H, W, h, w)
            )
        # Buffer instead of a plain tensor attribute: plain attributes are neither moved
        # by .to() nor replicated per device by nn.DataParallel. persistent=False keeps
        # it out of the state_dict, exactly as before.
        self.register_buffer('expander', torch.ones((int(H/h), int(W/w))), persistent=False)

        # Local feature network, applied to every grid cell (1x1 convolutions).
        self.LFN = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Conv2d(in_channels=64, out_channels=3, kernel_size=1, stride=1, bias=False),
            nn.BatchNorm2d(3),
            nn.LeakyReLU(negative_slope=0.1)
        )
        # Feature transforms for the raw image (FT1) and the attention-weighted image (FT2).
        self.FT1 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU()
        )
        self.FT2 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1, stride=1, bias=False),
            nn.BatchNorm2d(3),
            nn.ReLU()
        )
        # Residual fusion network: lower (two conv blocks) and upper (conv + sigmoid) part.
        self.RFN_lower = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1, stride=1, bias=True),
            nn.BatchNorm2d(3),
            nn.PReLU(),
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1, stride=1, bias=True),
            nn.BatchNorm2d(3),
            nn.PReLU()
        )
        self.RFN_upper = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1, stride=1, bias=False),
            nn.Sigmoid()
        )

        # Kept for backward compatibility with get_activation(); the model itself no
        # longer registers hooks (see _backbone_features).
        self.activation = {}
        self.backbone = models.resnet18()

        self.PTF1 = nn.Sequential(  # (32, 512, 2, 2) -> (32, 512, 2, 2)
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=1, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.1)
        )
        self.PTF2 = nn.Sequential(  # (32, 256, 3, 3) -> (32, 512, 2, 2)
            nn.Conv2d(
                in_channels = 256,
                out_channels = 512,
                kernel_size = 2,
                padding = 0,
                stride = 1,         # 48: 1; 222: 2
                bias = False
            ),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.1)
        )
        self.PTF3 = nn.Sequential(  # (32, 128, 6, 6) -> (32, 512, 2, 2)
            nn.Conv2d(
                in_channels = 128,
                out_channels = 512,
                kernel_size = 3,    # 48: 3; 222: 4
                padding = 0,        # 48: 0; 222: 1
                stride = 3,         # 48: 3; 222: 4
                bias = False
            ),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(negative_slope=0.1)
        )
        # Shared embedding that turns one flattened pyramid feature into a D-dim token.
        self.token_embed = nn.Sequential(
            nn.LayerNorm(512*2*2),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(512*2*2, D)
        )
        # Class token and positional embedding. They are created on the default device
        # and moved together with the rest of the module by .to().
        self.Tcls = nn.Parameter(torch.randn((1, 1, D)))
        self.Epos = nn.Parameter(torch.randn((1, 1, (N+1)*D)))
        self.transformer = Transformer(dim=D, depth=L, heads=heads, dim_head=dim_head, mlp_dim=2*D)
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(D),
            nn.Linear(D, K),
            nn.Softmax(dim=1)
        )

        # model weight initialization: generic scheme first, then the two overrides
        self.apply(self.init_weights)
        self.RFN_lower.apply(self.init_weights_RFN)
        self.RFN_upper.apply(self.init_weights_RFN)
        for ViT_module in [self.transformer, self.mlp_head]:
            for name, m in ViT_module.named_modules():
                self.init_weights_ViT(m, name)

    def init_weights(self, m):
        """Default init: Kaiming-uniform conv / linear weights, zero biases, unit BatchNorm scale."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_uniform_(m.weight.data)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight.data, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0)

    def init_weights_RFN(self, m):
        """Init for the residual fusion network: Xavier-uniform with gain 0.1, zero biases."""
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.xavier_uniform_(m.weight.data, gain=0.1)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0)
        elif isinstance(m, nn.BatchNorm2d):
            nn.init.constant_(m.weight.data, 1)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0)

    def init_weights_ViT(self, m, name):
        """Init for the Linear layers of the transformer and the classification head.

        Args:
            m: module to initialise; anything that is not nn.Linear is left untouched.
            name: name of ``m`` as reported by ``named_modules()``; a name containing
                'qkv' marks the fused query/key/value projection.
        """
        if isinstance(m, nn.Linear):
            if 'qkv' in name:
                # The fused projection stacks three matrices along the output axis;
                # compute the Xavier bound for a single one of them.
                val = math.sqrt(6. / float(m.weight.shape[0] // 3 + m.weight.shape[1]))
                nn.init.uniform_(m.weight, -val, val)
            else:
                nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias.data, 0)

    def get_activation(self, name):
        """Return a forward hook that stores a detached copy of a module's output.

        The output is written to ``self.activation[name]``. The model does not
        register this hook itself; it is kept as an inspection helper.
        """
        def hook(model, input, output):
            self.activation[name] = output.detach()
        return hook

    def _backbone_features(self, x):
        """Run the ResNet stem and stages and return the layer4, layer3 and layer2 outputs.

        The classifier part of the backbone (avgpool / fc) is not evaluated.
        """
        net = self.backbone
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        x = net.layer1(x)
        out3 = net.layer2(x)
        out2 = net.layer3(out3)
        out1 = net.layer4(out2)
        return out1, out2, out3

    def forward(self, x):   # x.shape = (B, C, H, W)
        B = x.shape[0]
        grid_h, grid_w = int(H/h), int(W/w)

        # 1. Grid-Wise Attention
        # 1.1. Local Feature Extraction
        # LFN is called once per grid cell (row-major order), so its BatchNorm layers
        # see each cell as a separate batch.
        grids = []          # grids: hw x (B, C, H/h, W/w)
        for i in range(h):
            for j in range(w):
                crop = x[:, :, i*grid_h:(i+1)*grid_h, j*grid_w:(j+1)*grid_w]
                grids.append(self.LFN(crop))

        # 1.2. Grid-Wise Attention Calculation
        # NOTE(review): scores are divided by d_k itself, not sqrt(d_k) -- kept as is.
        d_k = W/w
        query = torch.stack(grids, dim=0)               # query.shape = (hw, B, C, H/h, W/w)
        key = query.transpose(-1, -2)                   # key.shape = (hw, B, C, W/w, H/h)
        scores = torch.matmul(query, key)/d_k           # scores.shape = (hw, B, C, H/h, H/h)
        # Normalise across the grid axis so that the cells compete with each other.
        # NOTE(review): the previous dim=1 normalised across the samples of a batch,
        # which made every prediction depend on the other images in its batch.
        attn = torch.softmax(scores, dim=0)             # attn.shape = (hw, B, C, H/h, H/h)
        # mean over the two spatial axes (what AdaptiveAvgPool2d(1) computes)
        I_hat = attn.mean(dim=(-2, -1), keepdim=True)   # I_hat.shape = (hw, B, C, 1, 1)
        pattn = self.expander * I_hat                   # pattn.shape = (hw, B, C, H/h, W/w)
        # tile the per-cell maps back into a full-resolution map
        rows = [torch.cat([pattn[i*w + j] for j in range(w)], dim=-1) for i in range(h)]
        I_tilde = torch.cat(rows, dim=-2)               # I_tilde.shape = (B, C, H, W)
        I_prime_tilde = I_tilde * x                     # I_prime_tilde.shape = (B, C, H, W)

        # 1.3. Residual Feature Fusion
        RFN_lower_in = self.FT1(x) + self.FT2(I_prime_tilde)
        RFN_lower_out = self.RFN_lower(RFN_lower_in)
        I_bar = self.RFN_upper(RFN_lower_in + RFN_lower_out)    # I_bar.shape = (B, C, H, W)

        # 2. Backbone Network (ResNet)
        # Shapes below are for a batch of 32 images of size 48x48.
        L1_double_prime, L2_double_prime, L3_double_prime = self._backbone_features(I_bar)
        # L3_double_prime.shape = (32, 128, 6, 6)
        # L2_double_prime.shape = (32, 256, 3, 3)
        # L1_double_prime.shape = (32, 512, 2, 2)

        # 3. Visual Transformer Attention
        # 3.1. Visual Token Generation
        # 3.1.1. Pyramid Feature Extraction
        L1_prime = self.PTF1(L1_double_prime)   # L1_prime.shape = (32, 512, 2, 2)
        L2_prime = self.PTF2(L2_double_prime)   # L2_prime.shape = (32, 512, 2, 2)
        L3_prime = self.PTF3(L3_double_prime)   # L3_prime.shape = (32, 512, 2, 2)

        # 3.1.2. Visual Token Embedding
        C1H1W1 = L1_prime.shape[1] * L1_prime.shape[2] * L1_prime.shape[3]  # C1H1W1 = 2048
        L1 = torch.reshape(L1_prime, (B, 1, C1H1W1))    # L1.shape = (32, 1, 2048)
        L2 = torch.reshape(L2_prime, (B, 1, C1H1W1))    # L2.shape = (32, 1, 2048)
        L3 = torch.reshape(L3_prime, (B, 1, C1H1W1))    # L3.shape = (32, 1, 2048)
        T1 = self.token_embed(L1)   # T1.shape = (32, 1, 128)
        T2 = self.token_embed(L2)   # T2.shape = (32, 1, 128)
        T3 = self.token_embed(L3)   # T3.shape = (32, 1, 128)

        # 3.2. Token-based Visual Transformer
        tokens = torch.cat([self.Tcls.expand(B, -1, -1), T1, T2, T3], dim=1)
        Z0 = tokens + self.Epos.reshape(1, N+1, D)  # positional embedding broadcasts over the batch
        ZL = self.transformer(Z0)               # Zi.shape = (32, 4, 128), i = 0, 1, ..., L
        pred_scores = self.mlp_head(ZL[:,0,:])  # pred_scores.shape = (32, 8)

        return pred_scores

In [None]:
def accuracy(output, target):
    """Fraction of samples whose top-scoring class is one of the top-valued target classes.

    A prediction counts as correct when the target value at the predicted index
    equals the maximum of that target row, so ties in the target are all accepted.

    Args:
        output: predicted scores, expected shape (batch, classes).
        target: reference scores, expected shape (batch, classes).

    Returns:
        The accuracy as a Python float in [0, 1]; 0.0 for an empty batch.
    """
    batch_size = target.size(0)
    if batch_size == 0:
        # nothing to score; avoids a division by zero
        return 0.0
    # Computed for the whole batch at once: one host synchronisation instead of one per sample.
    predicted = torch.argmax(output, dim=1, keepdim=True)           # (batch, 1)
    target_at_pred = torch.gather(target, 1, predicted).squeeze(1)  # (batch,)
    target_best = torch.max(target, dim=1).values                   # (batch,)
    hits = (target_at_pred == target_best).sum().item()
    return hits / batch_size

class AverageMeter(object):
    """Tracks the most recent value together with a count-weighted running mean."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear the current value, the accumulated sum, the count and the mean."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record `val`, weighted as `n` observations, and refresh the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

In [None]:
def train(train_loader, model, loss_fn, optimizer, epoch):
    """Train `model` for one pass over `train_loader`, printing a live progress line.

    Args:
        train_loader: iterable yielding ``(images, emotions)`` batches.
        model: network to optimise; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred_scores, target)`` returning a scalar loss.
        optimizer: optimizer that holds the parameters of `model`.
        epoch: zero-based index of the current epoch (only used for the log line).

    The log line also reads the notebook-level ``epochs`` and ``device`` globals,
    which therefore have to be defined before this function is called.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    model.train()
    end = time.time()

    for i, (images, emotions) in enumerate(train_loader):
        # Tensors take part in autograd directly; the deprecated Variable wrapper
        # (and a second .to(device)) is not needed.
        inputs = images.to(device)
        targets = emotions.to(device)
        n_samples = inputs.size(0)

        pred_scores = model(inputs)
        loss = loss_fn(pred_scores, targets)

        # metrics are computed on a detached view so that they never touch the graph
        acc = accuracy(pred_scores.detach(), targets)
        losses.update(loss.item(), n_samples)
        accuracies.update(acc, n_samples)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        batch_time.update(time.time() - end)
        end = time.time()

        # '\r' + end='' rewrites the same console line for every batch
        print('\r',
              'Training [Epoch: {}/{} ({}/{})]: '
              'Time {:.2f}s ({:.2f}s) '
              'Loss {:.6f} ({:.6f}) '
              'Accuracy {:.4f} ({:.4f})'
              .format(epoch+1, epochs, i+1, len(train_loader),
                      batch_time.val, batch_time.avg,
                      losses.val, losses.avg,
                      accuracies.val, accuracies.avg),
              end='')


def validate(val_loader, model, loss_fn):
    """Evaluate `model` on `val_loader` and return the sample-weighted mean accuracy.

    Args:
        val_loader: iterable yielding ``(images, emotions)`` batches.
        model: network to evaluate; it is switched to evaluation mode.
        loss_fn: callable ``loss_fn(pred_scores, target)`` returning a scalar loss.

    Returns:
        The accuracy averaged over all samples seen, as a float.

    Reads the notebook-level ``device`` global.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    accuracies = AverageMeter()

    model.eval()
    end = time.time()
    print()

    # Detaching the inputs does not stop autograd from recording the forward pass
    # (the parameters still require grad); no_grad() does, which saves the memory
    # of the whole activation graph during evaluation.
    with torch.no_grad():
        for i, (images, emotions) in enumerate(val_loader):
            inputs = images.to(device)
            targets = emotions.to(device)
            n_samples = inputs.size(0)

            pred_scores = model(inputs)

            loss = loss_fn(pred_scores, targets)
            acc = accuracy(pred_scores, targets)
            losses.update(loss.item(), n_samples)
            accuracies.update(acc, n_samples)

            batch_time.update(time.time() - end)
            end = time.time()

            # '\r' + end='' rewrites the same console line for every batch
            print('\r',
                  'Test [Batch: {}/{}]: '
                  'Time {:.2f}s ({:.2f}s) '
                  'Loss {:.6f} ({:.6f}) '
                  'Accuracy {:.4f} ({:.4f})'
                  .format(i+1, len(val_loader),
                          batch_time.val, batch_time.avg,
                          losses.val, losses.avg,
                          accuracies.val, accuracies.avg),
                  end='')

    print(' *** Test Accuracy {:.4f} ***'.format(accuracies.avg))
    return accuracies.avg

In [None]:
# batch sizes and dataset locations
batch_size_train = 32
batch_size_test = 32
image_path = '../dataset/FERPlus/images.npy'
emotion_path = '../dataset/FERPlus/emotions_multi.npy'

# training split: shuffled
train_loader = torch.utils.data.DataLoader(
    dataset=FERPlusDataset(image_path, emotion_path, subset='train'),
    batch_size=batch_size_train,
    shuffle=True,
)

# held-out split: fixed order
val_loader = torch.utils.data.DataLoader(
    dataset=FERPlusDataset(image_path, emotion_path, subset='test'),
    batch_size=batch_size_test,
    shuffle=False,
)

In [None]:
# Build the network; wrap it in DataParallel only when more than one GPU is visible,
# then move it to the selected device.
model = FER_VT()
model = nn.DataParallel(model) if torch.cuda.device_count() > 1 else model
model = model.to(device)

In [None]:
# optimisation hyper-parameters
learning_rate = 1e-3
weight_decay = 0.1

# mean-squared error between the predicted scores and the label vectors
loss_fn = nn.MSELoss(reduction='mean').to(device)

# Adam over every model parameter
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [None]:
epochs = 40
accs = []   # test accuracy after each epoch, in epoch order
for epoch in range(epochs):
    train(train_loader=train_loader, model=model, loss_fn=loss_fn, optimizer=optimizer, epoch=epoch)
    acc = validate(val_loader=val_loader, model=model, loss_fn=loss_fn)
    accs.append(acc)

In [None]:
# Store the per-epoch accuracies as JSON; the output directory is created on demand.
os.makedirs('./results/', exist_ok=True)
with open('./results/FER-VT.json', 'w', encoding='utf-8') as f:
    json.dump(accs, f, ensure_ascii=False)