In [1]:
import torch as t
import torch.nn as nn
import argparse

from imports import *
from encoder import Encoder as encoder
from decoder import Decoder as decoder

# Taking a simple sent that will be passed into the transformer. 
sent: t.Tensor = t.randint(0, 8, (4,))
out_sent: t.Tensor = t.randint(0, 8, (4,))

print(f"Input sent: {sent}")
print(f"Output sent: {out_sent}")

class Transformer(nn.Module):
    def __init__(self, num_heads: int, sent: t.Tensor, out_sent: t.Tensor):
        super(Transformer, self).__init__()
        self.encoder = encoder(num_heads, sent)
        self.decoder = decoder(num_heads, out_sent)
        self.linear_output = nn.Linear(512, 8)  # 8 is the vocabulary size

        for params in self.parameters():
            params.requires_grad = True

    def forward(self) -> t.Tensor:
        encoder_output = self.encoder.forward()
        decoder_output = self.decoder.forward(encoder_output) 
        print(f"Shape of Decoder output: {decoder_output.shape}")
        output = self.linear_output(decoder_output)
        return output
    
# parser = argparse.ArgumentParser(
ce_loss = nn.CrossEntropyLoss()
num_heads = 2

model = Transformer(num_heads, sent, out_sent)
model.train()
predicted_logits = model.forward()

print(f"The predicted logits are {predicted_logits}")

ce_loss_none = nn.CrossEntropyLoss(reduction='none')
loss_contributions = ce_loss_none(predicted_logits.view(4, 8), out_sent)
print(f"Contribution of each word to the loss: {loss_contributions}")

loss = loss_contributions.mean()
print(f"Total loss: {loss}"); print()

Input sent: tensor([7, 3, 5, 4])
Output sent: tensor([3, 0, 3, 3])
torch.Size([1, 4, 1024])
Shape of Decoder output: torch.Size([1, 4, 512])
The predicted logits are tensor([[[ 0.4488, -0.6622,  0.8067,  0.0869,  0.7885,  0.4761, -0.8387,
           0.5125],
         [ 0.1403, -0.7855,  1.0671,  0.4012,  1.1623,  0.5504, -0.7449,
           0.5894],
         [ 0.4802, -0.9749,  0.8247,  0.2037,  0.5942,  0.6012, -0.8240,
           0.4386],
         [ 0.4189, -0.9863,  0.8025,  0.3035,  0.6784,  0.4982, -0.7471,
           0.5425]]], grad_fn=<ViewBackward0>)
Contribution of each word to the loss: tensor([2.3401, 2.4411, 2.2056, 2.1204], grad_fn=<NllLossBackward0>)
Total loss: 2.27679443359375



  return t.tensor(final_sent, dtype = t.float32)
  return t.tensor(final_sent, dtype = t.float32)


In [2]:
loss_contributions.backward(t.tensor([1, 1, 1, 1], dtype=t.float32), retain_graph=True)
print(model.linear_output.weight.grad)
model.zero_grad()

tensor([[-0.7219,  0.1676,  0.1594,  ..., -0.8444,  0.5558, -0.7131],
        [ 0.1068, -0.0358, -0.2449,  ...,  0.1662, -0.2096,  0.0177],
        [ 0.6245, -0.2614, -1.3262,  ...,  0.9298, -1.1596,  0.1132],
        ...,
        [ 0.4396, -0.1916, -0.9582,  ...,  0.6514, -0.8291,  0.0583],
        [ 0.1146, -0.0519, -0.2563,  ...,  0.1736, -0.2203,  0.0166],
        [ 0.4222, -0.1835, -0.9495,  ...,  0.6449, -0.8153,  0.0655]])


In [3]:
loss_contributions.backward(t.tensor([1, 1, 1, 1], dtype=t.float32), retain_graph=True)
print(model.linear_output.weight.grad)
model.zero_grad()

tensor([[-0.7219,  0.1676,  0.1594,  ..., -0.8444,  0.5558, -0.7131],
        [ 0.1068, -0.0358, -0.2449,  ...,  0.1662, -0.2096,  0.0177],
        [ 0.6245, -0.2614, -1.3262,  ...,  0.9298, -1.1596,  0.1132],
        ...,
        [ 0.4396, -0.1916, -0.9582,  ...,  0.6514, -0.8291,  0.0583],
        [ 0.1146, -0.0519, -0.2563,  ...,  0.1736, -0.2203,  0.0166],
        [ 0.4222, -0.1835, -0.9495,  ...,  0.6449, -0.8153,  0.0655]])


In [4]:
loss_contributions.backward(t.tensor([1,0,0,0], dtype=t.float32), retain_graph=True)
print(model.linear_output.weight.grad)
model.zero_grad()

tensor([[ 0.0590,  0.0772, -0.2896,  ...,  0.1595, -0.2137,  0.0045],
        [ 0.0194,  0.0254, -0.0953,  ...,  0.0525, -0.0704,  0.0015],
        [ 0.0844,  0.1105, -0.4142,  ...,  0.2281, -0.3056,  0.0064],
        ...,
        [ 0.0606,  0.0794, -0.2976,  ...,  0.1639, -0.2196,  0.0046],
        [ 0.0163,  0.0213, -0.0799,  ...,  0.0440, -0.0590,  0.0012],
        [ 0.0629,  0.0823, -0.3086,  ...,  0.1700, -0.2277,  0.0047]])


In [5]:
loss_contributions.backward(t.tensor([0, 1, 0, 0], dtype=t.float32), retain_graph=True)
print(model.linear_output.weight.grad)
model.zero_grad()

tensor([[-0.9923,  0.2979,  0.9268,  ..., -1.2781,  1.1681, -0.6781],
        [ 0.0375, -0.0113, -0.0350,  ...,  0.0483, -0.0441,  0.0256],
        [ 0.2391, -0.0718, -0.2233,  ...,  0.3079, -0.2814,  0.1634],
        ...,
        [ 0.1426, -0.0428, -0.1332,  ...,  0.1837, -0.1679,  0.0975],
        [ 0.0390, -0.0117, -0.0365,  ...,  0.0503, -0.0460,  0.0267],
        [ 0.1483, -0.0445, -0.1385,  ...,  0.1910, -0.1745,  0.1013]])


In [6]:
loss_contributions.backward(t.tensor([0, 0, 1, 0], dtype=t.float32), retain_graph=True)
print(model.linear_output.weight.grad)
model.zero_grad()

tensor([[ 0.1634, -0.0780, -0.2220,  ...,  0.1530, -0.2207, -0.0233],
        [ 0.0381, -0.0182, -0.0518,  ...,  0.0357, -0.0515, -0.0054],
        [ 0.2307, -0.1101, -0.3133,  ...,  0.2159, -0.3115, -0.0329],
        ...,
        [ 0.1845, -0.0881, -0.2506,  ...,  0.1726, -0.2491, -0.0263],
        [ 0.0444, -0.0212, -0.0603,  ...,  0.0415, -0.0599, -0.0063],
        [ 0.1568, -0.0748, -0.2130,  ...,  0.1467, -0.2117, -0.0224]])


In [7]:
loss_contributions.backward(t.tensor([0, 0, 0, 1], dtype=t.float32), retain_graph=True)
print(model.linear_output.weight.grad)
model.zero_grad()

tensor([[ 0.0480, -0.1295, -0.2558,  ...,  0.1212, -0.1779, -0.0161],
        [ 0.0118, -0.0318, -0.0627,  ...,  0.0297, -0.0436, -0.0040],
        [ 0.0704, -0.1900, -0.3754,  ...,  0.1779, -0.2611, -0.0237],
        ...,
        [ 0.0519, -0.1402, -0.2769,  ...,  0.1312, -0.1926, -0.0175],
        [ 0.0149, -0.0403, -0.0797,  ...,  0.0378, -0.0554, -0.0050],
        [ 0.0543, -0.1465, -0.2894,  ...,  0.1371, -0.2013, -0.0183]])


In [43]:
import torch as t
import torch.nn as nn
import argparse

import warnings
warnings.filterwarnings("ignore")


from imports import *
from encoder import Encoder as encoder
from decoder import Decoder as decoder

# Taking a simple sent that will be passed into the transformer. 
sent: t.Tensor = t.randint(0, 8, (4,))
out_sent: t.Tensor = t.randint(0, 8, (4,))

print(f"Input sent: {sent}")
print(f"Output sent: {out_sent}")

class Transformer(nn.Module):
    def __init__(self, num_heads: int, sent: t.Tensor, out_sent: t.Tensor):
        super(Transformer, self).__init__()
        self.encoder = encoder(num_heads, sent)
        self.decoder = decoder(num_heads, out_sent)
        self.linear_output = nn.Linear(512, 8)  # 8 is the vocabulary size

        for params in self.parameters():
            params.requires_grad = True

    def forward(self) -> t.Tensor:
        encoder_output = self.encoder.forward()
        decoder_output = self.decoder.forward(encoder_output) 
        print(f"Shape of Decoder output: {decoder_output.shape}")
        output = self.linear_output(decoder_output)
        return output
    
# parser = argparse.ArgumentParser(
ce_loss = nn.CrossEntropyLoss()
num_heads = 2

model = Transformer(num_heads, sent, out_sent)
model.train()
predicted_logits = model.forward()

print(f"The predicted logits are {predicted_logits}")

ce_loss_none = nn.CrossEntropyLoss(reduction='none')

loss_contributions = ce_loss_none(predicted_logits.view(4, 8), out_sent)
print(f"Contribution of each word to the loss: {loss_contributions}")

loss = loss_contributions.sum()
print(f"Total loss: {loss}"); print()

loss_contributions.backward(t.tensor([1, 1, 1, 1], dtype=t.float32), retain_graph=True)
total  = model.linear_output.weight.grad
model.zero_grad()

loss_contributions.backward(t.tensor([1, 0, 0, 0], dtype=t.float32), retain_graph=True)
first_token  = model.linear_output.weight.grad
model.zero_grad()

loss_contributions.backward(t.tensor([0, 1, 0, 0], dtype=t.float32), retain_graph=True)
second_token  = model.linear_output.weight.grad
model.zero_grad()

loss_contributions.backward(t.tensor([0, 0, 1, 0], dtype=t.float32), retain_graph=True)
third_token  = model.linear_output.weight.grad
model.zero_grad()

loss_contributions.backward(t.tensor([0, 0, 0, 1], dtype=t.float32), retain_graph=True)
fourth_token  = model.linear_output.weight.grad
model.zero_grad()


print(f"Total gradient: {total[0,0]:.4f}")
print(f"The addition of gradient of first token: {first_token[0,0]:.4f}, second token: {second_token[0,0]:.4f}, third token: {third_token[0,0]:.4f}, fourth token: {fourth_token[0,0]:.4f} is: ")
print(first_token[0,0] + second_token[0,0] + third_token[0,0] + fourth_token[0,0])
print()

print(f"Are all the values equal of summed of gradient and total gradient: {t.allclose(total, first_token + second_token + third_token + fourth_token)}")

loss.backward(retain_graph=True)
_token  = model.linear_output.weight.grad
print(_token[0][0])
model.zero_grad()

print(f"Are all the values equal of without jacobian product-simply taking gradient and one taken by jacobian product: {t.allclose(total, _token)}")

Input sent: tensor([6, 3, 1, 0])
Output sent: tensor([5, 2, 7, 6])
torch.Size([1, 4, 1024])
Shape of Decoder output: torch.Size([1, 4, 512])
The predicted logits are tensor([[[-0.2316, -0.8765, -1.2680, -0.2764,  0.8412, -0.7650,  1.0826,
          -0.7140],
         [-0.3567, -0.9351, -1.1221, -0.3347,  0.9145, -0.6786,  0.8842,
          -0.6992],
         [ 0.0574, -0.4736, -1.4513, -0.0695,  0.6909, -0.6255,  1.2332,
          -0.8025],
         [-0.0403, -0.6667, -1.4379, -0.1408,  0.7269, -0.6249,  0.9919,
          -0.8684]]], grad_fn=<ViewBackward0>)
Contribution of each word to the loss: tensor([2.9023, 3.2083, 3.0281, 1.1244], grad_fn=<NllLossBackward0>)
Total loss: 10.263071060180664

Total gradient: -0.1308
The addition of gradient of first token: -0.0417, second token: 0.0102, third token: -0.0172, fourth token: -0.0820 is: 
tensor(-0.1308)

Are all the values equal of summed of gradient and total gradient: True
tensor(-0.1308)
Are all the values equal of without jacobian 

In [38]:
x = -0.1674/-0.0418
print(x)

4.0047846889952154
