## Multi-head attention transformer
### Encoder and Decoder
### (With masking)

Pytorch's implementation (in built)

########################################################################


### Training part of the transformer


In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len=512, dropout=0):
        super(PositionalEncoding, self).__init__()
        self.encoding = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1).float()
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * -(torch.log(torch.tensor(10000.0)) / d_model))
        self.encoding[:, 0::2] = torch.sin(position * div_term)
        self.encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = self.encoding.unsqueeze(0)

    def forward(self, x):
        print(x.shape)
        return self.encoding[:, :x.size(1)].detach()



class TransformerModel1(nn.Module):
    def __init__(self, src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff, dropout = 0):

        super(TransformerModel1, self).__init__()

        self.src_embedding = nn.Embedding(src_vocab_size, d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, d_model)

        self.positional_encoding = PositionalEncoding(d_model, dropout=0, max_len=max_seq_len)

        self.transformer = nn.Transformer(
            d_model=d_model,
            nhead=num_heads,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            dropout=dropout,
            dim_feedforward=d_ff,
        )

        self.fc = nn.Linear(d_model, tgt_vocab_size)



    def generate_mask(self, src, tgt):

        src_mask = None
        seq_length = tgt.size(0)
        
        nopeak_mask = (torch.triu(torch.ones(seq_length, seq_length), diagonal=1)).bool()

        return src_mask, nopeak_mask

    def forward(self, src, tgt):

        src_mask, tgt_mask = self.generate_mask(src, tgt)

        print("Tgt mask shape = ", tgt_mask.shape)

        src = self.src_embedding(src) + self.positional_encoding(src)
        tgt = self.tgt_embedding(tgt) + self.positional_encoding(tgt)


        output = self.transformer(src, tgt, src_mask = src_mask, tgt_mask = tgt_mask, tgt_is_causal = False)
        output = self.fc(output)
        
        return output
    

In [15]:
torch.manual_seed(0)

src_vocab_size = 20
tgt_vocab_size = 20
d_model = 16
num_heads = 4
num_encoder_layers = 1
num_decoder_layers = 1
d_ff = 20
max_seq_len = 5
dropout = 0

transformer = TransformerModel1(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_encoder_layers, num_decoder_layers, max_seq_len, d_ff)

# Generate random sample data
# src_data = torch.randint(1, src_vocab_size, (max_seq_len , 3))  # (seq_length, batch_size,)
# tgt_data = torch.randint(1, tgt_vocab_size, ( max_seq_len, 3))  # (seq_length, batch_size)

src_data = torch.tensor([[0, 2, 4], [1, 0, 7], [2, 2, 0], [3, 5, 6], [6, 1, 9]])
tgt_data = torch.tensor([[1, 7, 9], [3, 4, 1], [5, 2, 8], [8, 0, 3], [4, 5, 9]])  # Target sequence

# src_data = torch.tensor([[2], [1], [5], [4]])
# tgt_data = torch.tensor([[1], [16], [5], [3], [9]]) 


state_dict = transformer.state_dict()

In [16]:
import copy

state_dict1 = copy.deepcopy(state_dict)

In [17]:
src_data.shape, tgt_data.shape

(torch.Size([5, 3]), torch.Size([5, 3]))

In [18]:
tgt_data.view(-1)
tgt_data.shape

torch.Size([5, 3])

In [19]:
src_data.shape, tgt_data.shape

(torch.Size([5, 3]), torch.Size([5, 3]))

In [21]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(1):

    optimizer.zero_grad()

    print("FWD PASS START\n")
    print("src_data shape = ", src_data.shape)
    print("tgt_data shape = ", tgt_data[:-1, :].shape)
    
    output = transformer(src_data, tgt_data[:-1, :])
    print("FWD PASS END\n")

    print("output shape = ",output.shape)

    loss = criterion(output.contiguous().view(-1, tgt_vocab_size), tgt_data[1:, :].contiguous().view(-1))


    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")

FWD PASS START

src_data shape =  torch.Size([5, 3])
tgt_data shape =  torch.Size([4, 3])
Tgt mask shape =  torch.Size([4, 4])
torch.Size([5, 3])
torch.Size([4, 3])
torch.Size([5, 3, 16]) torch.Size([4, 3, 16])
MASK =  None
query =  tensor([[[-1.1257, -0.1525, -0.2507,  0.5660,  0.8488,  1.6919, -0.3161,
          -1.1151,  0.3224, -0.2634,  0.3501,  1.3082,  0.1197,  2.2378,
           1.1167,  0.7528],
         [ 0.2278,  0.5718, -0.1816,  1.1989,  0.5396,  1.1073,  0.6725,
           1.4406, -0.0922,  1.7925, -0.2866,  1.0526,  0.5240,  3.3021,
          -1.4687, -0.5868],
         [ 0.3401,  0.5039,  1.7018,  2.0966, -1.2796,  3.5474, -0.4100,
           1.3337, -1.6092,  0.4502, -0.4734,  0.5002, -1.0651,  2.1150,
          -0.1399,  1.8057]],

        [[-1.3528, -0.6960,  0.5668,  1.7936,  0.5989, -0.5552, -0.3413,
           2.8531,  0.7503,  0.4146, -0.1735,  1.1834,  1.3895,  2.5862,
           0.9462,  0.1562],
         [-0.2843, -0.6122,  0.0603,  0.5164,  0.9486,  1.6869, -

In [22]:
src_data, tgt_data[:-1, :]

(tensor([[0, 2, 4],
         [1, 0, 7],
         [2, 2, 0],
         [3, 5, 6],
         [6, 1, 9]]),
 tensor([[1, 7, 9],
         [3, 4, 1],
         [5, 2, 8],
         [8, 0, 3]]))