In [1]:
'''
Multi-headed self-attention layer
'''
import torch
from torch import nn
from einops.layers.torch import Rearrange 

class MultiHeadedSelfAttention(nn.Module):
    '''
    Multi-headed scaled dot-product self-attention.

    Maps a (batch, tokens, indim) tensor to a tensor of the same shape:
    the input is projected to queries, keys and values, attention is
    evaluated independently in each of the nheads heads, and the
    concatenated head outputs are projected back to indim.
    '''
    def __init__(self, indim, adim, nheads, drop):
        '''
        indim : (int) input dimension of input vector
        adim: (int) dimension of each attention head
        nheads: (int) number of heads in the multi-headed attention layer
        drop: (float 0~1) probability of dropping a node

        Implements QKV multi-headed attention layer
        output = softmax(Q*K^T/sqrt(d))*V
        scale = 1/sqrt(d), here, d=adim
        '''
        super(MultiHeadedSelfAttention, self).__init__()
        self.adim = adim
        self.nheads = nheads
        hdim = adim*nheads  # width of all heads concatenated

        # Dot products are taken between adim-sized per-head vectors, so the
        # softmax scale is 1/sqrt(adim), not 1/sqrt(adim*nheads).
        self.scale = adim**-0.5

        # One bias-free projection each for queries, keys and values. Every
        # projection emits all heads at once (hdim = adim*nheads features);
        # forward() slices that output into the individual heads.
        self.query_lyr = self.get_qkv_layer(indim, hdim)
        self.key_lyr = self.get_qkv_layer(indim, hdim)
        self.value_lyr = self.get_qkv_layer(indim, hdim)

        self.attention_scores = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(drop)

        # Projects the concatenated heads back to the input width.
        self.out_layer = nn.Sequential(nn.Linear(hdim, indim),
                                       nn.Dropout(drop))

    @staticmethod
    def get_qkv_layer(indim, hdim):
        '''Return the bias-free linear projection used for Q, K or V.'''
        return nn.Linear(indim, hdim, bias=False)

    def split_heads(self, t):
        '''(batch, tokens, nheads*adim) -> (batch, nheads, tokens, adim)'''
        bsize, ntokens, _ = t.shape
        return t.reshape(bsize, ntokens, self.nheads, self.adim).transpose(1, 2)

    def forward(self, x):
        '''
        x: tensor of shape (batch, tokens, indim)

        Returns a tensor of shape (batch, tokens, indim).
        '''
        query = self.split_heads(self.query_lyr(x))
        key = self.split_heads(self.key_lyr(x))
        value = self.split_heads(self.value_lyr(x))

        # (batch, nheads, tokens, tokens): similarity of every token pair
        dotp = torch.matmul(query, key.transpose(-1, -2))*self.scale

        scores = self.attention_scores(dotp)

        scores = self.dropout(scores)

        # (batch, nheads, tokens, adim)
        weighted = torch.matmul(scores, value)

        # Put the heads back side by side: (batch, tokens, nheads*adim)
        bsize, _, ntokens, _ = weighted.shape
        merged = weighted.transpose(1, 2).reshape(bsize, ntokens, self.nheads*self.adim)

        out = self.out_layer(merged)

        return out
        
        

ModuleNotFoundError: No module named 'einops'

In [None]:
'''Transformer Encoder Layer'''
class TransformerEncoder(nn.Module):
    '''
    Although torch has a nn.Transformer class, it includes both encoder and decoder layers (with cross attention). Since ViT requires only the encoder, we can't use nn.Transformer.
    So we implement our own Transformer encoder layer class.

    The encoder is a stack of nlayers blocks. Each block applies a
    LayerNorm + multi-headed self-attention sub-layer followed by a
    LayerNorm + MLP sub-layer, each wrapped in a residual connection.
    '''
    def __init__(self, nheads, nlayers, embed_dim, head_dim, mlp_hdim, dropout):
        '''
        nheads: (int) number of heads in the multi-headed attention layer
        nlayers: (int) number of multi-headed attention layers in the encoder
        embed_dim: (int) input dimension of input tokens
        head_dim: (int) dimension of each attention head
        mlp_hdim: (int) number of hidden dimensions in hidden layer of MLP
        dropout: (float 0~1) probability of dropping a node
        '''
        super(TransformerEncoder, self).__init__()
        self.nheads = nheads
        self.nlayers = nlayers
        self.embed_dim = embed_dim
        self.head_dim = head_dim
        self.mlp_hdim = mlp_hdim
        self.drop_prob = dropout

        self.salayers, self.fflayers = self.getlayers()

    def getlayers(self):
        '''
        Build the sub-layers of the encoder.

        Returns a pair of nn.ModuleList objects, each of length nlayers:
        the self-attention sub-layers and the feed-forward (MLP) sub-layers.
        '''
        samodules = nn.ModuleList()
        ffmodules = nn.ModuleList()

        for _ in range(self.nlayers):
            # Normalisation is applied before attention (pre-norm).
            sam = nn.Sequential(
                nn.LayerNorm(self.embed_dim),
                MultiHeadedSelfAttention(
                    self.embed_dim,
                    self.head_dim,
                    self.nheads,
                    self.drop_prob
                )
            )

            samodules.append(sam)

            # embed_dim -> mlp_hdim -> embed_dim MLP, again pre-normalised.
            ffm = nn.Sequential(
                nn.LayerNorm(self.embed_dim),
                nn.Linear(self.embed_dim, self.mlp_hdim),
                nn.GELU(),
                nn.Dropout(self.drop_prob),
                nn.Linear(self.mlp_hdim, self.embed_dim),
                nn.Dropout(self.drop_prob)
            )

            ffmodules.append(ffm)

        return samodules, ffmodules

    def forward(self, x):
        '''
        Run x through every encoder block in order.

        Both sub-layers of a block are added back onto their own input
        (residual connections), so the output has the same shape as x.
        '''
        for (sal, ffl) in zip(self.salayers, self.fflayers):
            x = x + sal(x)
            x = x + ffl(x)

        return x
            
        
        
        
    

In [None]:
'''
Vision Transformer Class
'''
class VisionTransformer(nn.Module):
    '''
    Vision Transformer: patch embedding, a learnable class token, a learnable
    positional embedding, a Transformer encoder and a linear prediction head.

    NOTE: this class body defines no forward(); in this notebook the forward
    pass is written as a module-level function in a later cell and has to be
    attached (e.g. VisionTransformer.forward = forward) before the model can
    be called.
    '''
    def __init__(self, cfg):
        '''
        cfg: mapping with the keys
            input_size: indexable as (height, width, channels)
            patch_size: (patch_height, patch_width), or a single int for
                        square patches
            embed_dim:  (int) width of every token
            salayers:   (int) number of encoder layers
            nheads:     (int) attention heads per layer
            head_dim:   (int) dimension of each attention head
            mlp_hdim:   (int) hidden width of the encoder MLPs
            drop_prob:  (float 0~1) dropout probability
            nclasses:   (int) number of output classes
        '''
        super(VisionTransformer, self).__init__()

        input_size = cfg['input_size']
        patch_size = cfg['patch_size']
        if isinstance(patch_size, int):
            # A bare int means square patches.
            patch_size = (patch_size, patch_size)
        self.patch_size = patch_size
        self.embed_dim = cfg['embed_dim']
        salayers = cfg['salayers']
        nheads = cfg['nheads']
        head_dim = cfg['head_dim']
        mlp_hdim = cfg['mlp_hdim']
        drop_prob = cfg['drop_prob']
        nclasses = cfg['nclasses']

        patch_h, patch_w = self.patch_size[0], self.patch_size[1]

        # Length of the token sequence: one token per image patch plus one
        # for the class token.
        self.num_patches = (input_size[0]//patch_h)*(input_size[1]//patch_w) + 1

        # Cut the image into non-overlapping patches, flatten each patch and
        # project it linearly to embed_dim.
        self.patch_embedding = nn.Sequential(
            Rearrange('b c (h px) (w py) -> b (h w) (px py c)', px=patch_h, py=patch_w),
            nn.Linear(patch_h*patch_w*input_size[2], self.embed_dim)
        )

        self.dropout_layer = nn.Dropout(drop_prob)

        # Similar to BERT, we add a cls token as a learnable parameter at the
        # beginning of the ViT model. This token is evolved with self attention
        # and is used to predict the class of the image at the end. Tokens from
        # all patches are IGNORED.
        self.cls_token = nn.Parameter(torch.rand(1, 1, self.embed_dim))

        # Learnable positional embedding, one row per token. num_patches
        # already counts the class token, so no extra row is added here.
        self.positional_embedding = nn.Parameter(torch.rand(1, self.num_patches, self.embed_dim))

        self.transformer = TransformerEncoder(
            nheads = nheads,
            nlayers = salayers,
            embed_dim = self.embed_dim,
            head_dim = head_dim,
            mlp_hdim = mlp_hdim,
            dropout = drop_prob
        )

        self.prediction_head = nn.Sequential(nn.LayerNorm(self.embed_dim), nn.Linear(self.embed_dim, nclasses))


In [None]:
#implementing the forward pass

def forward(self, x):
    '''
    Forward pass of the Vision Transformer.

    x: image batch in the format (batch_size, channels, height, width)

    Returns the output of the prediction head applied to the class token.

    NOTE: this is a module-level function written as a method; it is only
    used by the model once it is bound to the class
    (e.g. VisionTransformer.forward = forward).
    '''
    embed = self.patch_embedding(x)

    #repeat class token for every sample in batch and cat along patch dimension, so class is treated as a patch
    x = torch.cat((self.cls_token.expand(embed.size(0), -1, -1), embed), dim=1)

    pos = self.positional_embedding
    ntokens = x.size(1)
    if pos.size(1) != ntokens:
        #the image yields a different number of tokens than the positional
        #embedding holds, so resize the embedding along the token axis. The
        #embedding is treated as a one-channel (tokens x embed_dim) image;
        #bicubic interpolation needs a 4-D input, hence the dummy dimension.
        pos = nn.functional.interpolate(
            pos[None], #add a dummy dimension
            (ntokens, pos.size(-1)),
            mode='bicubic'
        )[0] #remove dummy dimension

    #same token count as the stored embedding: pos is added unchanged
    x = x + pos

    x = self.dropout_layer(x)

    x = self.transformer(x)

    x = x[:,0,:] #use the first token (cls token) to predict the class and ignore the rest

    pred = self.prediction_head(x)

    return pred
    
    


In [None]:
# Install darklight, the package imported as `dl` in the training cell below
# (it supplies the ImageNetManager and StudentTrainer used there).
# NOTE(review): consider `%pip` with a pinned version so the install targets
# the kernel's environment and is reproducible.
!pip install darklight


In [None]:
import darklight as dl
import torch
from vit import VisionTransformer
import vitconfigs as vcfg
 
# Model built from the "base" entry of the vitconfigs module.
net = VisionTransformer(vcfg.base)

# ImageNet data manager (size / bsize are presumably the image size and the
# batch size -- confirm against darklight).
dm = dl.ImageNetManager('/sfnvme/imagenet/', size=[224, 224], bsize=128)

# Optimiser and learning-rate scheduler classes together with the keyword
# arguments given alongside each of them.
opt_params = dict(
    optimizer=torch.optim.AdamW,
    okwargs=dict(lr=1e-4, weight_decay=0.05),
    scheduler=torch.optim.lr_scheduler.CosineAnnealingWarmRestarts,
    skwargs=dict(T_0=10, T_mult=2),
    amplevel=None,
)

trainer = dl.StudentTrainer(net, dm, None, opt_params=opt_params)
trainer.train(epochs=300, save='vitbase_{}.pth')