<h1><center>IST 597 Foundations of Deep Learning</center></h1>

---

<h2><center>Transformers</center><h2>
<h3><center>Neisarg Dave</center><h3>

Resources:
+ https://jalammar.github.io/illustrated-transformer/
+ https://github.com/jessevig/bertviz

Torch API:
+ [Tranformer Encoder](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoder.html?highlight=transformerencoder#torch.nn.TransformerEncoder)
+ [Transformer Encoder Layer](https://pytorch.org/docs/stable/generated/torch.nn.TransformerEncoderLayer.html?highlight=transformerencoder#torch.nn.TransformerEncoderLayer)
+ [Transformer Decoder](https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoder.html?highlight=transformer+decoder#torch.nn.TransformerDecoder)
+ [Transformer Decoder Layer](https://pytorch.org/docs/stable/generated/torch.nn.TransformerDecoderLayer.html?highlight=transformer+decoder+layer#torch.nn.TransformerDecoderLayer)
+ [Tranformer](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html?highlight=transformer#torch.nn.Transformer)
+ [Multihead Attention](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html)

In [2]:
import torch
import torch.nn as nn
import math

In [59]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()
        pe = torch.zeros(max_len, d_model, dtype = torch.float)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2, dtype = torch.float) * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        print(x.size(1))
        x = x + self.pe[:, :x.size(1), :]
        return x

In [56]:
class TransformerLanguageModel(nn.Module):
    def __init__(self, in_dim, state_dim, vocab_size, num_layers, num_heads):
        super(TransformerLanguageModel, self).__init__()
        self.vocab_size = vocab_size
        self.in_dim = in_dim
        self.state_dim = state_dim
        self.num_layers = num_layers
        self.n_heads = num_heads
        print("layers :", self.num_layers, "  num_heads : ", self.n_heads)
        self.embedding = torch.nn.Embedding(self.vocab_size, in_dim, padding_idx=0)
        self.pos_embedding = PositionalEncoding(in_dim, dropout=0.1, max_len = 30)
        self.transformer = torch.nn.Transformer(d_model=in_dim, nhead=self.n_heads, num_encoder_layers=self.num_layers, 
                                num_decoder_layers=self.num_layers, dim_feedforward=state_dim, dropout=0, batch_first = True)
        self.transformer.encoder.enable_nested_tensor = False
        self.symbol_classifier = nn.Sequential(nn.Linear(in_dim, self.vocab_size))
                                        

    def get_tgt_mask(self, size, device):
        mask = torch.tril(torch.ones([size, size], device = device) == 1) # Lower triangular matrix
        mask = mask.float()
        mask = mask.masked_fill(mask == 0, float('-inf')) 
        mask = mask.masked_fill(mask == 1, float(0.0))
            
        # EX for size=5:
        # [[0., -inf, -inf, -inf, -inf],
        #  [0.,   0., -inf, -inf, -inf],
        #  [0.,   0.,   0., -inf, -inf],
        #  [0.,   0.,   0.,   0., -inf],
        #  [0.,   0.,   0.,   0.,   0.]]
        return mask

    def forward(self, exp1, exp2, len_1, len_2):
        inp1 = self.embedding(exp1)
        inp2 = self.embedding(exp2)

        inp1 = self.pos_embedding(inp1.transpose(0,1)).transpose(0,1)
        inp2 = self.pos_embedding(inp2.transpose(0,1)).transpose(0,1)

        src_key_mask = torch.arange(inp1.shape[1]).cuda()
        src_key_mask = src_key_mask[None] <= (len_1-1)[:, None]

        tgt_key_mask = torch.arange(inp2.shape[1]).cuda()
        tgt_key_mask = tgt_key_mask[None] <= (len_2-1)[:, None]

        memory_key_mask = src_key_mask
        tgt_mask = self.get_tgt_mask(inp2.shape[1], inp2.device)


        out = self.transformer(inp1, inp2, src_key_padding_mask = ~src_key_mask, tgt_key_padding_mask = ~tgt_key_mask, memory_key_padding_mask = ~memory_key_mask, tgt_mask = tgt_mask)
                                        
        out = self.symbol_classifier(out)
        return out


### Positional Encoding

In [86]:
pe_layer = PositionalEncoding(d_model=4, max_len=16)
pe = pe_layer(torch.zeros(1, 16, 4))
print(pe.shape)

16
torch.Size([1, 16, 4])


In [87]:
print(pe)


tensor([[[ 0.0000,  1.0000,  0.0000,  1.0000],
         [ 0.8415,  0.5403,  0.0100,  0.9999],
         [ 0.9093, -0.4161,  0.0200,  0.9998],
         [ 0.1411, -0.9900,  0.0300,  0.9996],
         [-0.7568, -0.6536,  0.0400,  0.9992],
         [-0.9589,  0.2837,  0.0500,  0.9988],
         [-0.2794,  0.9602,  0.0600,  0.9982],
         [ 0.6570,  0.7539,  0.0699,  0.9976],
         [ 0.9894, -0.1455,  0.0799,  0.9968],
         [ 0.4121, -0.9111,  0.0899,  0.9960],
         [-0.5440, -0.8391,  0.0998,  0.9950],
         [-1.0000,  0.0044,  0.1098,  0.9940],
         [-0.5366,  0.8439,  0.1197,  0.9928],
         [ 0.4202,  0.9074,  0.1296,  0.9916],
         [ 0.9906,  0.1367,  0.1395,  0.9902],
         [ 0.6503, -0.7597,  0.1494,  0.9888]]])
