In [37]:
import torch as t
import torch.nn as nn
import argparse

from imports import *
from encoder import Encoder as encoder
from decoder import Decoder as decoder

# Toy data: random token ids for the source sentence (`sent`, fed to the encoder)
# and the target sentence (`out_sent`, fed to the decoder and used as labels).
# Seeding torch's global RNG makes these sentences, and any later random
# initialisation in this kernel, reproducible on Restart & Run All.
SEED = 0
VOCAB_SIZE = 8  # token ids are drawn from [0, VOCAB_SIZE)
SEQ_LEN = 4

t.manual_seed(SEED)
sent: t.Tensor = t.randint(0, VOCAB_SIZE, (SEQ_LEN,))
out_sent: t.Tensor = t.randint(0, VOCAB_SIZE, (SEQ_LEN,))

print(f"Input sent: {sent}")
print(f"Output sent: {out_sent}")

class Transformer(nn.Module):
    """Encoder-decoder transformer bound to one fixed source/target sentence pair.

    The encoder and decoder are built from the sentences themselves, so
    ``forward`` takes no inputs. It runs the encoder, feeds the encoder output
    to the decoder, and projects the decoder output to per-token logits.

    Args:
        num_heads: Number of attention heads, passed to the encoder and decoder.
        sent: Source token ids, passed to the encoder.
        out_sent: Target token ids, passed to the decoder.
        d_model: Feature size of the decoder output, i.e. the input size of the
            final projection. Defaults to 512.
        vocab_size: Number of logits produced per position. Defaults to 8.
    """

    def __init__(
        self,
        num_heads: int,
        sent: t.Tensor,
        out_sent: t.Tensor,
        d_model: int = 512,
        vocab_size: int = 8,
    ):
        super().__init__()
        self.encoder = encoder(num_heads, sent)
        self.decoder = decoder(num_heads, out_sent)
        # Maps each decoder position from d_model features to vocab_size logits.
        self.linear_output = nn.Linear(d_model, vocab_size)

        # nn.Parameter already defaults to requires_grad=True; this loop
        # additionally re-enables any parameter the encoder/decoder may have frozen.
        for param in self.parameters():
            param.requires_grad = True

    def forward(self) -> t.Tensor:
        """Run encoder -> decoder -> output projection and return the logits.

        The last dimension of the result is ``vocab_size``. The leading
        dimensions follow the decoder output; the notebook run shows
        ``(1, seq_len, d_model)`` for the decoder.
        """
        # Call the submodules (not `.forward()`) so any registered hooks run.
        encoder_output = self.encoder()
        decoder_output = self.decoder(encoder_output)
        print(f"Shape of Decoder output: {decoder_output.shape}")
        return self.linear_output(decoder_output)
    
# Mean-reduced CE loss; not used in this cell (loss_contributions.mean() below
# computes the same value from the per-token losses).
ce_loss = nn.CrossEntropyLoss()
num_heads = 2

model = Transformer(num_heads, sent, out_sent)
model.train()
# Call the module rather than `.forward()` so registered hooks also run.
predicted_logits = model()

print(f"The predicted logits are {predicted_logits}")

# Per-token loss: flatten the logits to (num_tokens, vocab) and the targets to
# (num_tokens,), deriving both sizes from the tensors instead of hard-coding them.
ce_loss_none = nn.CrossEntropyLoss(reduction='none')
loss_contributions = ce_loss_none(
    predicted_logits.reshape(-1, predicted_logits.size(-1)),
    out_sent.reshape(-1),
)
print(f"Contribution of each word to the loss: {loss_contributions}")

loss = loss_contributions.mean()
print(f"Total loss: {loss}")
print()



Input sent: tensor([6, 3, 2, 2])
Output sent: tensor([3, 6, 3, 7])
torch.Size([1, 4, 1024])
Shape of Decoder output: torch.Size([1, 4, 512])
The predicted logits are tensor([[[ 0.1240,  0.9169, -0.4552, -0.4940, -0.4583,  0.1199,  0.1451,
          -0.6991],
         [ 0.1252,  1.1188, -0.3874, -0.4276, -0.4849,  0.0709,  0.1380,
          -0.4977],
         [ 0.2904,  1.0713, -0.4033, -0.1648, -0.7008,  0.1705,  0.1184,
          -0.5556],
         [ 0.1432,  1.1053, -0.3179,  0.0422, -0.7275,  0.1366,  0.0114,
          -0.2488]]], grad_fn=<ViewBackward0>)
Contribution of each word to the loss: tensor([2.6092, 2.0559, 2.3790, 2.4859], grad_fn=<NllLossBackward0>)
Total loss: 2.382481813430786

tensor([[-0.4354, -0.1223,  0.7927,  ...,  0.8131, -1.1523,  0.5010],
        [-1.0468, -0.3158,  1.9221,  ...,  1.9767, -2.7809,  1.2188],
        [-0.2511, -0.0725,  0.4484,  ...,  0.4655, -0.6561,  0.2869],
        ...,
        [-0.4214, -0.1134,  0.7497,  ...,  0.7770, -1.1006,  0.4786],
   

  return t.tensor(final_sent, dtype = t.float32)
  return t.tensor(final_sent, dtype = t.float32)


In [None]:
# Backprop the per-token losses with unit weights (the gradient of their sum).
# retain_graph=True keeps the autograd graph alive, so this cell can be re-run and
# `loss` (which shares the graph) can still be differentiated afterwards. Without
# it, a second backward raises "Trying to backward through the graph a second time".
# NOTE: .grad accumulates, so every run adds another copy of the gradient;
# call model.zero_grad() first if you want a fresh gradient.
loss_contributions.backward(t.ones_like(loss_contributions), retain_graph=True)
print(model.linear_output.weight.grad)

In [None]:
# Backprop the per-token losses with unit weights (the gradient of their sum).
# retain_graph=True keeps the autograd graph alive, so this cell can be re-run and
# `loss` (which shares the graph) can still be differentiated afterwards. Without
# it, a second backward raises "Trying to backward through the graph a second time".
# NOTE: .grad accumulates, so every run adds another copy of the gradient;
# call model.zero_grad() first if you want a fresh gradient.
loss_contributions.backward(t.ones_like(loss_contributions), retain_graph=True)
print(model.linear_output.weight.grad)

In [None]:
# Backprop the per-token losses with unit weights (the gradient of their sum).
# retain_graph=True keeps the autograd graph alive, so this cell can be re-run and
# `loss` (which shares the graph) can still be differentiated afterwards. Without
# it, a second backward raises "Trying to backward through the graph a second time".
# NOTE: .grad accumulates, so every run adds another copy of the gradient;
# call model.zero_grad() first if you want a fresh gradient.
loss_contributions.backward(t.ones_like(loss_contributions), retain_graph=True)
print(model.linear_output.weight.grad)

In [None]:
# Backprop the per-token losses with unit weights (the gradient of their sum).
# retain_graph=True keeps the autograd graph alive, so this cell can be re-run and
# `loss` (which shares the graph) can still be differentiated afterwards. Without
# it, a second backward raises "Trying to backward through the graph a second time".
# NOTE: .grad accumulates, so every run adds another copy of the gradient;
# call model.zero_grad() first if you want a fresh gradient.
loss_contributions.backward(t.ones_like(loss_contributions), retain_graph=True)
print(model.linear_output.weight.grad)

In [25]:
# Task: analyse the gradient flowing into each component of the decoder.

print(model)

# `.grad` stays None until a backward pass has reached this parameter.
weight_grad = model.linear_output.weight.grad
if weight_grad is None:
    print("linear_output.weight.grad is None: run a backward cell first")
else:
    print(weight_grad.size())

# Gradient of the mean loss w.r.t. the output projection weight, returned
# directly instead of being accumulated into `.grad`. This needs the autograd
# graph behind `loss` to still be alive (no earlier backward may have freed it).
p = t.autograd.grad(loss, model.linear_output.weight, retain_graph=True)
print(p[0].size())

Transformer(
  (encoder): Encoder(
    (W_q): Linear(in_features=512, out_features=1024, bias=True)
    (W_k): Linear(in_features=512, out_features=1024, bias=True)
    (W_v): Linear(in_features=512, out_features=1024, bias=True)
    (W_o): Linear(in_features=1024, out_features=512, bias=True)
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (fc1): Linear(in_features=512, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=512, bias=True)
    (relu): ReLU()
  )
  (decoder): Decoder(
    (W_q): Linear(in_features=512, out_features=1024, bias=True)
    (W_k): Linear(in_features=512, out_features=1024, bias=True)
    (W_v): Linear(in_features=512, out_features=1024, bias=True)
    (W_q_m): Linear(in_features=512, out_features=1024, bias=True)
    (W_k_m): Linear(in_features=512, out_features=1024, bias=True)
    (W_v_m): Linear(in_features=512, out_features=1024, bias=True)
    (W_o): Linear(in_features=1024, out_features=512, bias=Tru