In [43]:
import torch as t
import torch.nn as nn
import argparse

import warnings
warnings.filterwarnings("ignore")


from imports import *
from encoder import Encoder as encoder
from decoder import Decoder as decoder

# Toy source/target token-id sequences that will be passed into the transformer.
# Seed first so the sampled ids (and, presumably, the weight init of the model built
# later in this cell) are identical on every Restart & Run All.
SEED = 0
VOCAB_SIZE = 8  # ids are drawn from [0, VOCAB_SIZE); must match the output layer width
SEQ_LEN = 4
t.manual_seed(SEED)

sent: t.Tensor = t.randint(0, VOCAB_SIZE, (SEQ_LEN,))
out_sent: t.Tensor = t.randint(0, VOCAB_SIZE, (SEQ_LEN,))

print(f"Input sent: {sent}")
print(f"Output sent: {out_sent}")

class Transformer(nn.Module):
    """Encoder-decoder model that projects decoder states onto per-token vocabulary logits.

    The encoder and decoder are constructed around the fixed ``sent`` / ``out_sent``
    tensors given here, so ``forward`` takes no inputs.

    Args:
        num_heads: number of attention heads, passed to both the encoder and the decoder.
        sent: token ids handed to the encoder at construction time.
        out_sent: token ids handed to the decoder at construction time.
        d_model: last-dimension width of the decoder output fed to the final projection
            (default 512, the previously hard-coded value).
        vocab_size: number of logits produced per position (default 8, the previously
            hard-coded vocabulary size).
    """

    def __init__(self, num_heads: int, sent: t.Tensor, out_sent: t.Tensor,
                 d_model: int = 512, vocab_size: int = 8):
        super().__init__()
        self.encoder = encoder(num_heads, sent)
        self.decoder = decoder(num_heads, out_sent)
        # Maps each d_model-wide decoder state to vocab_size logits.
        self.linear_output = nn.Linear(d_model, vocab_size)

        # Force every registered parameter to be trainable; this overrides any freezing
        # the encoder/decoder might have applied to their own parameters.
        for param in self.parameters():
            param.requires_grad = True

    def forward(self) -> t.Tensor:
        """Run the encoder, feed its output to the decoder, and project to logits.

        Returns:
            Tensor shaped like the decoder output with its last dimension replaced
            by ``vocab_size``.
        """
        encoder_output = self.encoder.forward()
        decoder_output = self.decoder.forward(encoder_output)
        print(f"Shape of Decoder output: {decoder_output.shape}")
        return self.linear_output(decoder_output)
    
# Build the model on the toy sentences and run one forward pass in training mode.
ce_loss = nn.CrossEntropyLoss()  # NOTE(review): unused in this cell; kept in case later cells reference it
num_heads = 2

model = Transformer(num_heads, sent, out_sent)
model.train()
# Call the module itself rather than .forward() so nn.Module hooks are honored.
predicted_logits = model()

print(f"The predicted logits are {predicted_logits}")

# Unreduced cross-entropy: one loss value per target token, so each token's
# contribution to the total can be inspected separately.
ce_loss_none = nn.CrossEntropyLoss(reduction='none')

# Flatten all leading dims so logits become (num_positions, vocab) and line up with the
# flat target ids, instead of hard-coding the (4, 8) shape.
flat_logits = predicted_logits.reshape(-1, predicted_logits.size(-1))
loss_contributions = ce_loss_none(flat_logits, out_sent)
print(f"Contribution of each word to the loss: {loss_contributions}")

loss = loss_contributions.sum()
print(f"Total loss: {loss}")
print()

def linear_weight_grad(grad_output: t.Tensor) -> t.Tensor:
    """Return the output layer's weight gradient for one vector-Jacobian product.

    Backpropagates ``grad_output`` through ``loss_contributions`` (i.e. the gradient of
    ``(grad_output * loss_contributions).sum()``), copies
    ``model.linear_output.weight.grad``, then clears all gradients on ``model``.

    Args:
        grad_output: per-token weights, same length as ``loss_contributions``.

    Returns:
        An independent copy of the weight gradient.
    """
    # retain_graph: the same graph is backpropagated once per call.
    loss_contributions.backward(grad_output, retain_graph=True)
    # Copy before zero_grad(): with set_to_none=False (the default before PyTorch 2.0)
    # zero_grad() zeroes .grad in place, so an uncopied reference would silently become
    # all zeros and every captured gradient would alias the same storage.
    weight_grad = model.linear_output.weight.grad.clone()
    model.zero_grad()
    return weight_grad


num_tokens = loss_contributions.shape[0]
grad_dtype = loss_contributions.dtype

# All-ones weights: gradient of the summed loss.
total = linear_weight_grad(t.ones(num_tokens, dtype=grad_dtype))
# One-hot weights (rows of the identity): the gradient contributed by each token alone.
per_token_grads = [linear_weight_grad(one_hot) for one_hot in t.eye(num_tokens, dtype=grad_dtype)]
first_token, second_token, third_token, fourth_token = per_token_grads


print(f"Total gradient: {total[0,0]:.4f}")
print(f"The addition of gradient of first token: {first_token[0,0]:.4f}, second token: {second_token[0,0]:.4f}, third token: {third_token[0,0]:.4f}, fourth token: {fourth_token[0,0]:.4f} is: ")
print(first_token[0,0] + second_token[0,0] + third_token[0,0] + fourth_token[0,0])
print()

print(f"Are all the values equal of summed of gradient and total gradient: {t.allclose(total, first_token + second_token + third_token + fourth_token)}")

# Plain backward on the summed loss (loss = loss_contributions.sum()); this should match
# the all-ones vector-Jacobian product captured in `total`.
loss.backward(retain_graph=True)
# Clone so the comparison below is not affected by zero_grad() zeroing .grad in place
# (zero_grad(set_to_none=False), the default before PyTorch 2.0).
_token = model.linear_output.weight.grad.clone()
print(_token[0][0])
model.zero_grad()

print(f"Are all the values equal of without jacobian product-simply taking gradient and one taken by jacobian product: {t.allclose(total, _token)}")

Input sent: tensor([6, 3, 1, 0])
Output sent: tensor([5, 2, 7, 6])
torch.Size([1, 4, 1024])
Shape of Decoder output: torch.Size([1, 4, 512])
The predicted logits are tensor([[[-0.2316, -0.8765, -1.2680, -0.2764,  0.8412, -0.7650,  1.0826,
          -0.7140],
         [-0.3567, -0.9351, -1.1221, -0.3347,  0.9145, -0.6786,  0.8842,
          -0.6992],
         [ 0.0574, -0.4736, -1.4513, -0.0695,  0.6909, -0.6255,  1.2332,
          -0.8025],
         [-0.0403, -0.6667, -1.4379, -0.1408,  0.7269, -0.6249,  0.9919,
          -0.8684]]], grad_fn=<ViewBackward0>)
Contribution of each word to the loss: tensor([2.9023, 3.2083, 3.0281, 1.1244], grad_fn=<NllLossBackward0>)
Total loss: 10.263071060180664

Total gradient: -0.1308
The addition of gradient of first token: -0.0417, second token: 0.0102, third token: -0.0172, fourth token: -0.0820 is: 
tensor(-0.1308)

Are all the values equal of summed of gradient and total gradient: True
tensor(-0.1308)
Are all the values equal of without jacobian 

In [38]:
# Scratch check: ratio of two gradient values typed in by hand (the result is ~4, which is
# the number of tokens).
# NOTE(review): these literals appear to come from an earlier run — they do not match the
# values printed by the cell above (total -0.1308, first token -0.0417).
ratio_numerator, ratio_denominator = -0.1674, -0.0418
x = ratio_numerator / ratio_denominator
print(x)

4.0047846889952154
