# Relaxed Attention for Transformer Models

paper -> https://arxiv.org/pdf/2209.09735.pdf

### Notes

- abstract: The powerful modeling capabilities of all-attention-based transformer architec-
tures often cause overfitting

- instead of swin, I implemented vanilla ViT with relaxed attention
- followed the training receipe at the paper, additionally lr scheduler added

In [1]:
import torch
import torch.nn as nn
import torchvision
import torch.nn.functional as F
!pip install einops
import einops
import torchvision
import torchvision.transforms as transforms
import cv2
import numpy as np
from torchvision import datasets
from torch.utils.data.sampler import SubsetRandomSampler
import os
from pathlib import Path
import torch
from PIL import Image
from tqdm import tqdm
from torchmetrics import F1Score,Accuracy
import matplotlib.pyplot as plt
import copy
import pytorch_lightning as pl
import pandas as pd
import seaborn as sn
from IPython.core.display import display
from torch.utils.data import random_split, DataLoader


Collecting einops
  Downloading einops-0.4.1-py3-none-any.whl (28 kB)
Installing collected packages: einops
Successfully installed einops-0.4.1
[0m

In [29]:
valid_per = 0.20
batch_size = 1
epochs = 20
beta1 = 0.9
beta2 = 0.990
weight_decay = 0.3
embed_size =768
num_heads = 8
patch_size = 12
in_channels = 3
num_encoders = 8
num_class = 12
relaxation_coeff = 0.03
device = "cuda"
batch_size = 8
learning_rate = 0.000125
epochs = 30
label_smooth = 0.1

In [3]:
class Project(nn.Module):
    def __init__(self,patch_size:int,in_channels:int,embed_size:int,batch_size:int):
        super().__init__()
        self.patch_size = patch_size
        self.batch_size = batch_size 
        self.embed_size = embed_size # embed size is the size of linearly projected patch of image
        self.in_channels = in_channels # channel size of image, 1 for grayscale, 3 for colored image

        self.linear = nn.Linear(self.in_channels*self.patch_size**2,self.embed_size)
        self.class_token = nn.Parameter(torch.randn(self.batch_size,1,embed_size))
        self.position_embed = nn.Parameter(torch.randn(self.patch_size**2+1,self.embed_size))

    def forward(self,x:int): # num_batch x in_channel x width x height -> num_bahch x  num_patch x embed_size
        out = einops.rearrange(x,"b c (h px) (w py) -> b (h w) (c px py)",px = self.patch_size, py = self.patch_size)
        out = self.linear(out)
        out = torch.cat([out,self.class_token],dim = 1)
        out = out + self.position_embed
        return out



In [4]:
class DotProductAttention(nn.Module):
    def __init__(self,relaxation_coeff,patch_size):
        super().__init__()
        self.softmax = nn.Softmax(dim = 1)
        self.relaxation = Relaxation(relaxation_coeff,patch_size)
    def forward(self,query,key,value): 
        relaxation = self.relaxation(self.softmax(torch.matmul(query,key)/key.size(dim = 2)**(1/2)))
        sdp = torch.matmul(relaxation,value)
        return sdp

In [5]:
class Relaxation(nn.Module):
    def __init__(self,relaxation_coeff:float,patch_size:int):
        super().__init__()
        self.relaxation_coeff = relaxation_coeff
        self.T = (torch.ones((patch_size**2+1,patch_size**2+1))/patch_size).to("cuda")

    def forward(self,embed:torch.Tensor):
        return ((1-self.relaxation_coeff)*embed + self.relaxation_coeff*self.T)


In [30]:
class MultiHeadAttention(nn.Module):
    def __init__(self,embed_size,num_heads,relaxation_coeff,patch_size,dropout = 0.2):
        super().__init__()
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.DotProductAttention = DotProductAttention(relaxation_coeff,patch_size)
       

        self.key = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)
        self.query = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)
        self.value = nn.Linear(self.embed_size,self.num_heads* self.embed_size,bias = False)
        
        self.linear = nn.Linear(self.num_heads*self.embed_size,embed_size,bias = False)
        self.layer_norm = nn.LayerNorm(self.embed_size, eps=1e-6)
        self.dropout = nn.Dropout(dropout)
    def forward(self,embed):
        batch_size = embed.size(0)
        query = self.query(embed)
        key = einops.rearrange(self.query(embed),"b n e ->b e n")
        value = self.value(embed)
        sdp = self.DotProductAttention(query,key,value)
        return self.linear(self.dropout(sdp))
        


In [31]:
class EncoderBlock(nn.Module):
    def __init__(self,embed_size,num_heads,relaxation_coeff,patch_size,dropout = 0.2):
        super().__init__()
        self.embed_size = embed_size
        self.num_heads = num_heads
        self.dropout =dropout
        self.mha = MultiHeadAttention(768,8,relaxation_coeff,patch_size)
        self.Linear1 = nn.Linear(self.embed_size,self.embed_size*4,bias=False)
        self.Linear2 = nn.Linear(self.embed_size*4,self.embed_size,bias = False)
        self.gelu = nn.GELU()
        self.layer_norm1 = nn.LayerNorm(self.embed_size,eps = 1e-6)
        self.layer_norm2 = nn.LayerNorm(self.embed_size,eps = 1e-6)
        self.dropout = nn.Dropout(dropout)
    def forward(self,embed:torch.Tensor):
        embed = embed + self.mha(self.layer_norm1(embed))
        embed = embed+ self.Linear2(self.dropout(self.gelu(self.Linear1(self.dropout(self.layer_norm2(embed))))))
        return embed


In [32]:
class Transformer(nn.Module):
    def __init__(self,embed_size,num_heads,patch_size,in_channels,batch_size,num_encoders,num_class,device,relaxation_coeff):
        super().__init__()
        self.device = device
        self.num_heads = num_heads
        self.embed_size = embed_size
        self.patch_size =patch_size
        self.in_channels = in_channels
        self.batch_size = batch_size
        self.num_encoders = num_encoders
        self.num_class = num_class
        self.proj = Project(self.patch_size,self.in_channels,self.embed_size,self.batch_size)

        self.tiny_block = [EncoderBlock(self.embed_size,self.num_heads,relaxation_coeff,patch_size) for i in range(self.num_encoders)]
        self.block_seq = nn.Sequential(*self.tiny_block)
        #self.block_seq = EncoderBlock(self.embed_size,self.num_heads,relaxation_coeff,patch_size)
        self.linear1 = nn.Linear(self.embed_size,self.num_class*4)
        self.linear2 = nn.Linear(self.num_class*4,self.num_class)
        self.layernorm1 = nn.LayerNorm(num_class*4)
        self.layernorm2 = nn.LayerNorm(num_class)
        self.logsoftmax = nn.LogSoftmax(dim = 0)
    def num_of_parameters(self,):

        return sum(p.numel() for p in self.parameters())
    
    def forward(self,img):
        out = self.proj(img)
        out = self.block_seq(out)
        out = self.linear1(torch.squeeze(torch.index_select(out,1,torch.tensor([self.patch_size**2]).to(self.device))))
        out = self.layernorm1(out)
        out = self.linear2(out)
        out = self.layernorm2(out)
        out = self.logsoftmax(out)
        return out

In [33]:
class Custom_Dataset():

    def __init__(self, directory):
        self.path = Path(directory)
        Path.ls = lambda x: list(x.iterdir())
        try:
            files = os.listdir(directory)
        except:
            print("wrong path")
        self.x = [torch.tensor(np.transpose(np.array(Image.open(img).resize((144,144)))[:, :, :3], (2, 0, 1))).type(
            torch.FloatTensor) for img in (self.path/files[0]).ls()]
        self.x = torch.stack(self.x)/255
        self.y = torch.tensor([0]*len((self.path/files[0]).ls()))
        
        for i in range(len(files)-1):
            try:
                self.x2 = [torch.tensor(np.transpose(np.array(Image.open(img).resize((144,144)))[:, :, :3], (2, 0, 1))).type(
                torch.FloatTensor) for img in (self.path/files[i+1]).ls()]
            except:
                return 
            self.x2 = torch.stack(self.x2)/255
            self.x = torch.cat((self.x, self.x2), 0)
            self.y = torch.cat((self.y, torch.tensor(
                [i+1]*len((self.path/files[i+1]).ls()))))
        
    def __len__(self):
        return len(self.x)
    
    def __getitem__(self, i):
        return self.x[i], self.y[i]
    

In [11]:
dataset = Custom_Dataset("../input/animal-image-classification-dataset/Animal Image Dataset")
data_train, data_val = random_split(dataset, [len(dataset)-3000, 3000])

In [34]:
trainloader =DataLoader(data_train, batch_size=8,drop_last=True)
validloader = DataLoader(data_val, batch_size=8,drop_last=True)

In [35]:
model = Transformer(embed_size,num_heads,patch_size,in_channels,batch_size,num_encoders,num_class,device,relaxation_coeff).to(device)

In [36]:
f1 = F1Score().to(device)

In [37]:
class LabelSmoothing_NLLL(nn.Module):
    """NLL loss with label smoothing.
    """
    def __init__(self, smoothing=0.0):
        """Constructor for the LabelSmoothing module.
        :param smoothing: label smoothing factor
        """
        super(LabelSmoothing_NLLL, self).__init__()
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing

    def forward(self, x, target):
        logprobs = torch.nn.functional.log_softmax(x, dim=-1)
        nll_loss = -logprobs.gather(dim=-1, index=target.unsqueeze(1))
        nll_loss = nll_loss.squeeze(1)
        smooth_loss = -logprobs.mean(dim=-1)
        loss = self.confidence * nll_loss + self.smoothing * smooth_loss
        return loss.mean()

In [38]:
criterion = LabelSmoothing_NLLL(smoothing=label_smooth)

In [39]:
optimizer = torch.optim.Adam(model.parameters(),lr = learning_rate,betas = (beta1,beta2),weight_decay=weight_decay)

In [40]:
device = "cuda"
valid_loss_min = np.Inf
for i in range(epochs):
    print("Epoch - {} Started".format(i+1))

    train_loss = 0.0
    valid_loss = 0.0
    train_score = 0.0
    val_score = 0.0
    model.train()
    for data, target in trainloader:
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()*data.size(0)
        train_score = train_score +f1(output,target)
    
    model.eval()
    
    for data, target in validloader:
        data, target = data.to(device), target.to(device)
        with torch.no_grad():
            output = model(data)
        loss = criterion(output, target)
        valid_loss += loss.item()*data.size(0)
        val_score = val_score + f1(output,target)

    train_loss = train_loss/len(trainloader.sampler)
    valid_loss = valid_loss/len(validloader.sampler)
    train_score = batch_size*train_score/len(trainloader.sampler)
    val_score = batch_size*val_score/len(validloader.sampler)

    print(f"F1 Score for train: {train_score}, F1 Score for validation: {val_score} ")
    print('Epoch: {} \tTraining Loss: {:.6f} \tValidation Loss: {:.6f}'.format(
        i+1, train_loss, valid_loss))

    if valid_loss <= valid_loss_min:
        print('Validation loss decreased ({:.6f} --> {:.6f}).  Saving model ...'.format(
            valid_loss_min,
            valid_loss))
        best_model_wts = copy.deepcopy(model.state_dict())

        valid_loss_min = valid_loss
torch.save(best_model_wts, 'model.pt')

Epoch - 1 Started
F1 Score for train: 0.20975783467292786, F1 Score for validation: 0.24799999594688416 
Epoch: 1 	Training Loss: 2.103555 	Validation Loss: 2.089716
Validation loss decreased (inf --> 2.089716).  Saving model ...
Epoch - 2 Started
F1 Score for train: 0.1905270665884018, F1 Score for validation: 0.21833333373069763 
Epoch: 2 	Training Loss: 2.090135 	Validation Loss: 2.088041
Validation loss decreased (2.089716 --> 2.088041).  Saving model ...
Epoch - 3 Started
F1 Score for train: 0.20477208495140076, F1 Score for validation: 0.21466666460037231 
Epoch: 3 	Training Loss: 2.087483 	Validation Loss: 2.086394
Validation loss decreased (2.088041 --> 2.086394).  Saving model ...
Epoch - 4 Started
F1 Score for train: 0.2004985809326172, F1 Score for validation: 0.21466666460037231 
Epoch: 4 	Training Loss: 2.084855 	Validation Loss: 2.084636
Validation loss decreased (2.086394 --> 2.084636).  Saving model ...
Epoch - 5 Started
F1 Score for train: 0.18660968542099, F1 Score fo

KeyboardInterrupt: 