In [297]:
import torch as tc
from torch import nn
import torch.nn.functional as F

In [298]:
def num_parameters(model):
    
    n = sum([tc.numel(p) for p in model.parameters() if p.requires_grad])/1e6
    return str(n) + 'M'

In [299]:
x = tc.randint(0, 10, (3, 32, 32), dtype = tc.float32)

In [300]:
class  MultiHeadAttention(nn.Module):
    
    def __init__(self, d_model, nhead, dropout = 0.1, activation = 'relu', mask = None, device = 'cpu'):
        super(MultiHeadAttention, self).__init__()
        self.d_model    = d_model
        self.nhead      = nhead 
        self.dropout    = dropout 
        self.activation = activation
        self.mask       = mask
        self.device     = device
        #get k, q, v from x by linear layers
        self.fc0         = nn.Linear(d_model, d_model).to(self.device)
        self.fc1         = nn.Linear(d_model, d_model).to(self.device)
        self.fc2         = nn.Linear(d_model, d_model).to(self.device)
        self.fc_cat      = nn.Linear(nhead * d_model, d_model).to(self.device)
        
        
        pass 
    def forward(self, k, q, v):
        k = self.fc0(k)
        k = nn.Dropout(self.dropout)(k)
        k = self.activation(k)
        
        q = self.fc1(q)
        q = nn.Dropout(self.dropout)(q)
        q = self.activation(q)
        
        v = self.fc2(v)
        v = nn.Dropout(self.dropout)(v)
        v = self.activation(v)
        
        p = self.scaled_dot_product_attention(k, q, v)
        
        if self.nhead > 1:
            for i in range(1, self.nhead):
                
                k = self.fc0(k)
                k = nn.Dropout(self.dropout)(k)
                k = self.activation(k)
                
                q = self.fc1(q)
                q = nn.Dropout(self.dropout)(q)
                q = self.activation(q)
                
                v = self.fc2(v)
                v = nn.Dropout(self.dropout)(v)
                v = self.activation(v)
                
                p = tc.cat([p, self.scaled_dot_product_attention(k, q, v)], dim = 2)
                
        p = self.fc_cat(p)
        
        return p.to(self.device) 
    
    def scaled_dot_product_attention(self, k, q, v):
        
        scores = tc.matmul(k, q.transpose(-2, -1)) / tc.sqrt(tc.tensor(self.d_model))
        if self.mask is not None:
            scores = scores.masked_fill(self.mask == 0, -1e9)
        attention_weights = F.softmax(scores, dim=-1)
        output = tc.matmul(attention_weights, v)
        return output.to(self.device)
    

In [301]:
class Encoder(nn.Module):
    
    def __init__(self, d_model, nhead, dim_feedforward, dropout = 0.1, activation = 'relu', eps = 1e6, device = 'cpu'):
        super(Encoder, self).__init__()
        #self.layers for MultiHeadAttention part
        self.eps         = eps
        self.dropout     = dropout 
        self.activation  = activation
        self.d_model     = d_model
        self.nhead       = nhead
        self.device      = device
        
        #activation function
        if activation   == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'leakyrelu':
            self.activation = nn.LeakyReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'silu':
            self.activation = nn.SiLU()
        elif activation == 'selu':
            self.activation = nn.SELU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            print(f'{activation}  activation function is not valid. Please select one of the following options: 1)relu 2)leakyrelu 3)gelu 4)sigmoid 5)silu 6)selu 7)tanh')
        #Multi Head Attention
        self.multi_head_attention = MultiHeadAttention(self.d_model, self.nhead, self.dropout, self.activation, mask = None, device = self.device)
        
        #self.layers for Norm part
        self.norm0 = nn.LayerNorm(d_model, eps = eps).to(self.device)
        self.norm1 = nn.LayerNorm(d_model, eps = eps).to(self.device)
        
        #self.layers for FeedForward part
        self.ff4 = nn.Linear(d_model, dim_feedforward).to(self.device)
        self.ff  = nn.Linear(dim_feedforward,     d_model).to(self.device)
    
    def forward(self, x):
        #positional encoding
        x = x + self.positional_encoding(x.shape[-2])
        x = x.to(self.device)
        
        #Multi Head Attention
        y = self.multi_head_attention(x, x, x)
        y = nn.Dropout(self.dropout)(y)
        y = self.activation(y)
        
        #add and normalize
        z = x + y 
        z = self.norm0(z)
        
        #Feed Forward
        y = self.ff4(z)
        y = nn.Dropout(self.dropout)(y)
        y = self.activation(y)
        
        y = self.ff(y)
        
        #add and normalize
        z = z + y
        y = self.norm1(z)
        
        return y
    
    def positional_encoding(self, l, n = 100):
        
        position = tc.zeros(l, self.d_model)
        
        for k in range(l):
            for i in range(int(self.d_model/2)):
                
                position[k][2 * i]     = tc.sin(tc.tensor(k / (n) ** (2 * i / self.d_model)))
                position[k][2 * i + 1] = tc.cos(tc.tensor(k / (n) ** (2 * i / self.d_model)))
                    
        return position.to(self.device)

In [302]:
mask = tc.zeros(2, 4, 4)
mask[1, 2:] = 1

In [303]:
num_parameters(nn.TransformerEncoderLayer(512, 16, 2048, 0.1, 'relu'))

'3.152384M'

In [304]:
num_parameters(Encoder(512, 16, 2048, 0, 'relu',1e-6, 'cuda'))

'7.084544M'

In [305]:
Encoder(32, 8, 2048, 0.1, 'relu',1e-6, 'cpu')(x).shape

torch.Size([3, 32, 32])

In [306]:
Encoder(512, 4, 2048)

Encoder(
  (activation): ReLU()
  (multi_head_attention): MultiHeadAttention(
    (activation): ReLU()
    (fc0): Linear(in_features=512, out_features=512, bias=True)
    (fc1): Linear(in_features=512, out_features=512, bias=True)
    (fc2): Linear(in_features=512, out_features=512, bias=True)
    (fc_cat): Linear(in_features=2048, out_features=512, bias=True)
  )
  (norm0): LayerNorm((512,), eps=1000000.0, elementwise_affine=True)
  (norm1): LayerNorm((512,), eps=1000000.0, elementwise_affine=True)
  (ff4): Linear(in_features=512, out_features=2048, bias=True)
  (ff): Linear(in_features=2048, out_features=512, bias=True)
)

In [307]:
class Decoder(nn.Module):
    
    def __init__(self, d_model, nhead, dim_feedforward, dropout = 0.1, activation = 'relu', eps = 1e6, mask = None, device = 'cpu'):
        super(Decoder, self).__init__()
        self.mask        = mask
        self.eps         = eps
        self.dropout     = dropout 
        self.activation  = activation
        self.d_model     = d_model
        self.nhead       = nhead
        self.device      = device
        
        #activation function
        if activation   == 'relu':
            self.activation = nn.ReLU()
        elif activation == 'leakyrelu':
            self.activation = nn.LeakyReLU()
        elif activation == 'gelu':
            self.activation = nn.GELU()
        elif activation == 'sigmoid':
            self.activation = nn.Sigmoid()
        elif activation == 'silu':
            self.activation = nn.SiLU()
        elif activation == 'selu':
            self.activation = nn.SELU()
        elif activation == 'tanh':
            self.activation = nn.Tanh()
        else:
            print(f'{activation}  activation function is not valid. Please select one of the following options: 1)relu 2)leakyrelu 3)gelu 4)sigmoid 5)silu 6)selu 7)tanh')
        
        #self.layers for MaskedMultiHeadAttention part
        self.masked_multi_head_attention = MultiHeadAttention(self.d_model, self.nhead, self.dropout, self.activation, mask = self.mask, device = self.device)
        
        #self.layers for MultiHeadAttention part
        self.multi_head_attention = MultiHeadAttention(self.d_model, self.nhead, self.dropout, self.activation, mask = None, device = self.device)
        
        #self.layers for Norm part
        self.norm0 = nn.LayerNorm(d_model, eps = eps, device = device).to(self.device)
        self.norm1 = nn.LayerNorm(d_model, eps = eps, device = device).to(self.device)
        self.norm2 = nn.LayerNorm(d_model, eps = eps, device = device).to(self.device)
        
        #self.layers for FeedForward part
        self.ff4 = nn.Linear(d_model, dim_feedforward, device = device).to(self.device)
        self.ff  = nn.Linear(dim_feedforward,     d_model, device = device).to(self.device)
        
        #last layer
        self.fc  = nn.Linear(d_model, d_model, device = device).to(self.device)
    
    def forward(self, target, encoder_out):
        #positional encoding
        target = target + self.positional_encoding(target.shape[-2])
        target = target.to(self.device)
        
        #Masked Multi Head Attention
        y = self.masked_multi_head_attention(target, target, target)
        y = nn.Dropout(self.dropout)(y)
        y = self.activation(y)
        
        #add and normalize
        z = target + y 
        z = self.norm0(z)
        
        #Multi Head Attention
        y = self.multi_head_attention(encoder_out, encoder_out, z)
        y = nn.Dropout(self.dropout)(y)
        y = self.activation(y)
        
        #add and normalize 
        z = z + y 
        z = self.norm0(z)
        
        #Feed Forward
        y = self.ff4(z)
        y = nn.Dropout(self.dropout)(y)
        y = self.activation(y)
        
        y = self.ff(y)
        
        #add and normalize
        z = z + y
        y = self.norm1(z)
        
        #last layer
        y = self.fc(y)
        y = nn.Softmax(dim = -1)(y)
        
        return y
    
    def positional_encoding(self, l, n = 100):
        
        position = tc.zeros(l, self.d_model)
        
        for k in range(l):
            for i in range(int(self.d_model/2)):
                
                position[k][2 * i]     = tc.sin(tc.tensor(k / (n) ** (2 * i / self.d_model)))
                position[k][2 * i + 1] = tc.cos(tc.tensor(k / (n) ** (2 * i / self.d_model)))
                    
        return position.to(self.device)

In [308]:
num_parameters(nn.TransformerDecoderLayer(512, 8, 2048, 0, 'relu'))

'4.204032M'

In [309]:
num_parameters(Decoder(512, 8, 2048, 0, 'relu', 1e-6, None, 'cpu'))

'8.136704M'

In [310]:
Decoder(32, 4, 2048, 0.1, 'relu', 1e-6, None, 'cpu')(x, x).shape

torch.Size([3, 32, 32])

In [311]:
class Transformer(nn.Module):
    
    def __init__(self, d_model , nhead, num_encoder_layers, num_decoder_layers, dim_feedforward, dropout = 0.1, activation = 'relu', eps = 1e6, mask = None,  device = 'cpu'):
        super(Transformer, self).__init__()
        self.num_encoder_layers = num_encoder_layers 
        self.num_decoder_layers = num_decoder_layers 
        self.dim_feedforward    = dim_feedforward 
        self.nhead              = nhead 
        self.d_model            = d_model 
        self.dropout            = dropout 
        self.activation         = activation 
        self.eps                = eps 
        self.mask               = mask 
        self.device             = device
        #init encoders
        self.encoder_list = []
        for i in range(self.num_encoder_layers):
            self.encoder_list.append(Encoder(self.d_model, self.nhead, self.dim_feedforward, self.dropout, self.activation, self.eps, device = self.device))
        self.encoder_layer = nn.Sequential(*self.encoder_list)
        #init decoders
        self.decoder_list = []
        for i in range(self.num_decoder_layers):
            self.decoder_list.append(Decoder(self.d_model, self.nhead, self.dim_feedforward, self.dropout, self.activation, self.eps, self.mask, device = self.device))
        self.decoder_layer = nn.Sequential(*self.decoder_list)
        
        pass
    def forward(self, x, target):
        #Encoder
        if self.num_encoder_layers == 0:
            y = x 
        else:
            y = self.encoder(x)
        
        #Decoder
        if self.num_decoder_layers == 0:
            None 
        else:
            y = self.decoder(y, target)
        
        return y
    def encoder(self, x):
        y = self.encoder_layer[0](x)
        if self.num_encoder_layers > 1:
            for i in range(1, self.num_encoder_layers):
                y = self.encoder_layer[i](y) 
        return y
    
    def decoder(self, decoder_in, target):
        z = self.decoder_layer[0](target, decoder_in)
        if self.num_decoder_layers > 1:
            for i in range(1, self.num_decoder_layers):
                z = self.decoder_layer[i](z, decoder_in)
        return z

In [312]:
num_parameters(nn.Transformer(512, 8, 4, 4, 2048, 0, 'relu'))

'29.427712M'

In [313]:
num_parameters(Transformer(512, 8, 4, 4, 2048, 0, 'relu', 1e-6, None, 'cpu'))

'52.496384M'

In [314]:
Transformer(32, 8, 10, 0, 2048, 0, 'relu', 1e-6, None, 'cuda')(x.cuda(), x.cuda()).shape

torch.Size([3, 32, 32])

In [315]:
Transformer(512, 8, 2, 4, 2048, 0, 'relu', 1e-6, None, 'cpu')

Transformer(
  (encoder_layer): Sequential(
    (0): Encoder(
      (activation): ReLU()
      (multi_head_attention): MultiHeadAttention(
        (activation): ReLU()
        (fc0): Linear(in_features=512, out_features=512, bias=True)
        (fc1): Linear(in_features=512, out_features=512, bias=True)
        (fc2): Linear(in_features=512, out_features=512, bias=True)
        (fc_cat): Linear(in_features=4096, out_features=512, bias=True)
      )
      (norm0): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
      (norm1): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
      (ff4): Linear(in_features=512, out_features=2048, bias=True)
      (ff): Linear(in_features=2048, out_features=512, bias=True)
    )
    (1): Encoder(
      (activation): ReLU()
      (multi_head_attention): MultiHeadAttention(
        (activation): ReLU()
        (fc0): Linear(in_features=512, out_features=512, bias=True)
        (fc1): Linear(in_features=512, out_features=512, bias=True)
        (fc

In [316]:
nn.Transformer(512, 8, 4, 4, 2048, 0, 'relu')

Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0, inplace=False)
        (dropout2): Dropout(p=0, inplace=False)
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-3): 4 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_fea