# 🐶 External code for the problem

In [2]:
import torch
from torch import nn
import types
from functools import partial


def extend(model, input_shape):
    if not isinstance(model, nn.Module):
        raise TypeError("model should be a nn.Module")
    if not isinstance(input_shape, tuple):
        raise TypeError("input_shape should be a tuple")

    device = next(model.parameters()).device

    weight_input_list = []
    weight_output_list = []
    weight_repeat_list = []
    bias_output_list = []
    bias_repeat_list = []

    x = torch.zeros((1,) + input_shape, device=device)
    with torch.no_grad():
        for module in model.children():
            y = module(x)
            print(y)
            if sum(p.numel() for p in module.parameters()):
                # for all layers with parameters

                # store parameters and clear bias for future calculation
                if module.weight is not None:
                    initial_weight = module.weight.data.clone()
                if module.bias is not None:
                    initial_bias = module.bias.data.clone()
                    module.bias.data = torch.zeros_like(module.bias)

                if module.weight is not None:
                    Nweight = module.weight.numel()
                    weight_input = []
                    weight_output = []
                    weight_repeat = torch.zeros(
                        Nweight, dtype=torch.long, device=device
                    )
                    Xeye = torch.eye(x.numel(), device=device).reshape(
                        (-1,) + x.shape[1:]
                    )
                    for i in range(Nweight):
                        weight = torch.zeros(Nweight, device=device)
                        weight[i] = 1.0
                        module.weight.data = weight.reshape(module.weight.shape)
                        # output of module is of dimension (j,k)
                        out = module(Xeye).reshape(x.numel(), y.numel())
                        if (out[out.abs() > 1e-5] - 1.0).abs().max() > 1e-5:
                            raise RuntimeError(
                                "the network is not written in the standard form, see https://github.com/ChenAo-Phys/pytorch-Jacobian"
                            )
                        nonzero = torch.nonzero(out > 0.5, as_tuple=False)
                        weight_input.append(nonzero[:, 0])
                        weight_output.append(nonzero[:, 1])
                        weight_repeat[i] = nonzero.shape[0]
                    weight_input_list.append(torch.cat(weight_input, dim=0))
                    weight_output_list.append(torch.cat(weight_output, dim=0))
                    weight_repeat_list.append(weight_repeat)
                    module.weight.data = initial_weight
                else:
                    weight_input_list.append(None)
                    weight_output_list.append(None)
                    weight_repeat_list.append(None)

                if module.bias is not None:
                    Nbias = module.bias.numel()
                    bias_output = []
                    bias_repeat = torch.zeros(Nbias, dtype=torch.long, device=device)
                    for i in range(Nbias):
                        bias = torch.zeros(Nbias, device=device)
                        bias[i] = 1.0
                        module.bias.data = bias.reshape(module.bias.shape)
                        out = module(x).reshape(-1)
                        if (out[out.abs() > 1e-5] - 1.0).abs().max() > 1e-5:
                            raise RuntimeError(
                                "the network is not written in the standard form, see https://github.com/ChenAo-Phys/pytorch-Jacobian"
                            )
                        nonzero = torch.nonzero(out > 0.5, as_tuple=False)
                        bias_output.append(nonzero[:, 0])
                        bias_repeat[i] = nonzero.shape[0]
                    bias_output_list.append(torch.cat(bias_output, dim=0))
                    bias_repeat_list.append(bias_repeat)
                    module.bias.data = initial_bias
                else:
                    bias_output_list.append(None)
                    bias_repeat_list.append(None)

            x = torch.zeros_like(y)

    if not hasattr(model, "_Jacobian_shape_dict"):
        model._Jacobian_shape_dict = {}
    model._Jacobian_shape_dict[input_shape] = (
        weight_input_list,
        weight_output_list,
        weight_repeat_list,
        bias_output_list,
        bias_repeat_list,
    )

    # assign jacobian method to model
    def jacobian(self, as_tuple=False):
        shape = self.input_shape
        if hasattr(self, "_Jacobian_shape_dict") and shape in self._Jacobian_shape_dict:
            (
                weight_input_list,
                weight_output_list,
                weight_repeat_list,
                bias_output_list,
                bias_repeat_list,
            ) = self._Jacobian_shape_dict[shape]
        else:
            raise RuntimeError(
                "model or specific input shape is not extended for jacobian calculation"
            )

        device = next(model.parameters()).device
        jac = []
        layer = 0
        for module in self.children():
            if sum(p.numel() for p in module.parameters()):
                weight_input = weight_input_list[layer]
                weight_output = weight_output_list[layer]
                weight_repeat = weight_repeat_list[layer]
                bias_output = bias_output_list[layer]
                bias_repeat = bias_repeat_list[layer]
                x = self.x_in[layer]
                N = x.shape[0]
                dz_dy = self.gradient[layer].reshape(N, -1)

                if weight_repeat is not None:
                    Nweight = weight_repeat.shape[0]
                    dz_dy_select = dz_dy[:, weight_output]
                    x_select = x.reshape(N, -1)[:, weight_input]
                    repeat = torch.repeat_interleave(weight_repeat)
                    dz_dW = torch.zeros(N, Nweight, device=device).index_add_(
                        1, repeat, dz_dy_select * x_select
                    )
                    if as_tuple:
                        dz_dW = dz_dW.reshape((N,) + module.weight.shape)
                    jac.append(dz_dW)
                if bias_repeat is not None:
                    Nbias = bias_repeat.shape[0]
                    dz_dy_select = dz_dy[:, bias_output]
                    repeat = torch.repeat_interleave(bias_repeat)
                    dz_db = torch.zeros(N, Nbias, device=device).index_add_(
                        1, repeat, dz_dy_select
                    )
                    if as_tuple:
                        dz_db = dz_db.reshape((N,) + module.bias.shape)
                    jac.append(dz_db)
                layer += 1

        if as_tuple:
            return tuple(jac)
        else:
            return torch.cat(jac, dim=1)

    if not hasattr(model, "jacobian"):
        model.jacobian = types.MethodType(jacobian, model)


class JacobianMode:
    def __init__(self, model):
        self.model = model
        if not isinstance(model, nn.Module):
            raise TypeError("model should be a nn.Module")

    def __enter__(self):
        model = self.model
        model.x_in = []
        model.gradient = []
        self.forward_pre_hook = []
        self.backward_hook = []

        def record_input_shape(self, input):
            model.input_shape = input[0].shape[1:]

        def record_forward(self, input, layer):
            model.x_in[layer] = input[0].detach()

        def record_backward(self, grad_input, grad_output, layer):
            model.gradient[layer] = grad_output[0]

        module0 = next(model.children())
        self.first_forward_hook = module0.register_forward_pre_hook(record_input_shape)

        layer = 0
        for module in model.children():
            if sum(p.numel() for p in module.parameters()):
                model.x_in.append(None)
                model.gradient.append(None)
                self.forward_pre_hook.append(
                    module.register_forward_pre_hook(
                        partial(record_forward, layer=layer)
                    )
                )
                self.backward_hook.append(
                    module.register_backward_hook(partial(record_backward, layer=layer))
                )
                layer += 1

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.first_forward_hook.remove()
        for hook in self.forward_pre_hook:
            hook.remove()
        for hook in self.backward_hook:
            hook.remove()

        del self.model.input_shape
        del self.model.x_in
        del self.model.gradient

In [6]:
class MLP(nn.Module):
    def __init__(self, input_shape, output_shape):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_shape[1] * input_shape[2], 64)
        self.fc3 = nn.Linear(64, output_shape[1])

    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.fc3(x)
        return x

input_shape = (1, 2, 4)
output_shape = (1, 5)
net = MLP(input_shape, output_shape)

# Example input
example_input = torch.randn(input_shape)

extend(net, (1,2,4))

x = torch.randn(2,1,2,4)
# output shape will be (1000,)

# Jacobian computed by the improved method
# On Colab CPU 0.16s, K80 GPU 0.14s
with JacobianMode(net):
    out = net(x)
    out.sum().backward()
    jac = net.jacobian()
    
print(f"The jacobian shape is {jac.shape}")

tensor([[0., 0., 0., 0., 0., 0., 0., 0.]])
tensor([[ 0.0081,  0.0952, -0.0858, -0.3452, -0.2276,  0.1357, -0.1907,  0.1563,
         -0.1449, -0.0789, -0.1447,  0.2744, -0.1838, -0.2718, -0.0068,  0.2000,
          0.2298,  0.0653, -0.2904, -0.0504,  0.0972, -0.2140,  0.0330,  0.0193,
          0.1591, -0.2106, -0.3524, -0.0752, -0.1981,  0.1219,  0.1868,  0.1009,
          0.1255, -0.1599, -0.0662,  0.0713,  0.1353, -0.2805,  0.0187, -0.2937,
         -0.3135,  0.0054,  0.0703,  0.0291, -0.0976, -0.0992, -0.1811,  0.1716,
         -0.3285,  0.3295,  0.2083, -0.3178, -0.2265,  0.1358, -0.2843, -0.2679,
          0.0066,  0.0384,  0.0302, -0.1789,  0.0421, -0.1842, -0.3064,  0.0110]])
tensor([[ 0.0690,  0.0086, -0.0056,  0.0820, -0.1050]])
The jacobian shape is torch.Size([2, 901])


In [56]:
import torch as t
import torch.nn as nn
import argparse

import warnings
warnings.filterwarnings("ignore")


from imports import *
from encoder import Encoder as encoder
from decoder import Decoder as decoder

# Taking a simple sent that will be passed into the transformer. 
sent: t.Tensor = t.randint(0, 8, (4,))
out_sent: t.Tensor = t.randint(0, 8, (4,))

print(f"Input sent: {sent}")
print(f"Output sent: {out_sent}")

class Transformer(nn.Module):
    def __init__(self, num_heads: int, sent: t.Tensor, out_sent: t.Tensor):
        super(Transformer, self).__init__()
        self.encoder = encoder(num_heads, sent)
        self.decoder = decoder(num_heads, out_sent)
        self.linear_output = nn.Linear(512, 8)  # 8 is the vocabulary size

        for params in self.parameters():
            params.requires_grad = True

    def forward(self) -> t.Tensor:
        encoder_output = self.encoder.forward()
        decoder_output = self.decoder.forward(encoder_output) 
        print(f"Shape of Decoder output: {decoder_output.shape}")
        output = self.linear_output(decoder_output)
        return output
    
# parser = argparse.ArgumentParser(
ce_loss = nn.CrossEntropyLoss()
num_heads = 2

model = Transformer(num_heads, sent, out_sent)
model.train()
predicted_logits = model.forward()

extend(model, (4,))

print(f"The predicted logits are {predicted_logits}")

ce_loss_none = nn.CrossEntropyLoss(reduction='none')

loss_contributions = ce_loss_none(predicted_logits.view(4, 8), out_sent)
print(f"Contribution of each word to the loss: {loss_contributions}")

loss = loss_contributions.sum()
print(f"Total loss: {loss}"); print()





Input sent: tensor([5, 3, 6, 0])
Output sent: tensor([4, 2, 3, 5])
torch.Size([1, 4, 1024])
Shape of Decoder output: torch.Size([1, 4, 512])


TypeError: Encoder.forward() takes 1 positional argument but 2 were given

In [64]:
x = torch.randn(4, 512)

for module in model.children():
    for m in module.children():
        print(m)
        print(m(x))


Linear(in_features=512, out_features=1024, bias=True)
tensor([[ 0.0369,  0.5054,  0.4491,  ..., -0.2467,  0.1148, -0.4064],
        [ 0.4862, -0.0840,  0.9288,  ...,  0.2381, -0.2027,  0.7542],
        [ 0.4276, -0.3685,  0.2085,  ..., -0.2436, -0.5113,  0.5940],
        [ 0.9637,  1.6171,  0.2466,  ...,  0.0250,  0.0532, -0.9362]],
       grad_fn=<AddmmBackward0>)
Linear(in_features=512, out_features=1024, bias=True)
tensor([[-0.7159,  0.1916, -0.1559,  ..., -0.6256, -0.3449,  0.2954],
        [ 0.1083,  0.3251, -0.3764,  ..., -0.5664,  1.1568,  0.1296],
        [ 0.1103,  0.5370, -0.1760,  ...,  0.1072, -0.4360,  0.0615],
        [ 0.0761, -0.7269, -0.2071,  ...,  0.7534, -1.1132, -0.6404]],
       grad_fn=<AddmmBackward0>)
Linear(in_features=512, out_features=1024, bias=True)
tensor([[ 0.2403,  0.1211,  0.6501,  ...,  0.1675,  0.4466,  0.6892],
        [-0.8572, -0.4634, -0.0071,  ..., -0.7419, -0.4293,  1.0364],
        [ 0.4096,  0.5001, -0.0366,  ..., -0.4972,  0.2767, -0.3778],


RuntimeError: mat1 and mat2 shapes cannot be multiplied (4x512 and 1024x512)

# 🗿 Custom Code for the problem

In [8]:
from imports import *
import argparse
import sys
import os

import warnings
warnings.filterwarnings("ignore")

sys.path.append("/Users/maheepchaudhary/pytorch/Projects/Transformer-from-Scratch/")

from encoder import Encoder as encoder
from decoder import Decoder as decoder

# Taking a simple sent that will be passed into the transformer. 
sent: t.Tensor = t.randint(0, 8, (4,))
out_sent: t.Tensor = t.randint(0, 8, (4,))

print(f"Input sent: {sent}")
print(f"Output sent: {out_sent}")

class Transformer(nn.Module):
    def __init__(self, num_heads: int, sent: t.Tensor, out_sent: t.Tensor):
        super(Transformer, self).__init__()
        self.encoder = encoder(num_heads, sent)
        self.decoder = decoder(num_heads, out_sent)
        self.linear_output = nn.Linear(512, 8)  # 8 is the vocabulary size

        for params in self.parameters():
            params.requires_grad = True

    def forward(self) -> t.Tensor:
        encoder_output = self.encoder.forward()
        decoder_output = self.decoder.forward(encoder_output) 
        print(f"Shape of Decoder output: {decoder_output.shape}")
        output = self.linear_output(decoder_output)
        return output
    
# parser = argparse.ArgumentParser(
ce_loss = nn.CrossEntropyLoss()
num_heads = 2

model = Transformer(num_heads, sent, out_sent)
print(model)
model.train()
predicted_logits = model.forward()

print(f"The predicted logits are {predicted_logits}")

ce_loss_none = nn.CrossEntropyLoss(reduction='none')

loss_contributions = ce_loss_none(predicted_logits.view(4, 8), out_sent)
print(f"Contribution of each word to the loss: {loss_contributions}")

loss = loss_contributions.sum()
print(f"Total loss: {loss}"); print()

# loss_contributions.backward(t.tensor([1, 1, 1, 1], dtype=t.float32), retain_graph=True)
# total  = model.linear_output.weight.grad
# model.zero_grad()

# loss_contributions.backward(t.tensor([1, 0, 0, 0], dtype=t.float32), retain_graph=True)
# first_token  = model.linear_output.weight.grad
# model.zero_grad()

# loss_contributions.backward(t.tensor([0, 1, 0, 0], dtype=t.float32), retain_graph=True)
# second_token  = model.linear_output.weight.grad
# model.zero_grad()

# loss_contributions.backward(t.tensor([0, 0, 1, 0], dtype=t.float32), retain_graph=True)
# third_token  = model.linear_output.weight.grad
# model.zero_grad()

# loss_contributions.backward(t.tensor([0, 0, 0, 1], dtype=t.float32), retain_graph=True)
# fourth_token  = model.linear_output.weight.grad
# model.zero_grad()



Input sent: tensor([0, 0, 2, 6])
Output sent: tensor([2, 5, 3, 4])
Transformer(
  (encoder): Encoder(
    (W_q): Linear(in_features=512, out_features=1024, bias=True)
    (W_k): Linear(in_features=512, out_features=1024, bias=True)
    (W_v): Linear(in_features=512, out_features=1024, bias=True)
    (W_o): Linear(in_features=1024, out_features=512, bias=True)
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (fc1): Linear(in_features=512, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=512, bias=True)
    (relu): ReLU()
  )
  (decoder): Decoder(
    (W_q): Linear(in_features=512, out_features=1024, bias=True)
    (W_k): Linear(in_features=512, out_features=1024, bias=True)
    (W_v): Linear(in_features=512, out_features=1024, bias=True)
    (W_q_m): Linear(in_features=512, out_features=1024, bias=True)
    (W_k_m): Linear(in_features=512, out_features=1024, bias=True)
    (W_v_m): Linear(in_features=512, out_features=1024, bias=T

In [9]:
print(model)

Transformer(
  (encoder): Encoder(
    (W_q): Linear(in_features=512, out_features=1024, bias=True)
    (W_k): Linear(in_features=512, out_features=1024, bias=True)
    (W_v): Linear(in_features=512, out_features=1024, bias=True)
    (W_o): Linear(in_features=1024, out_features=512, bias=True)
    (layer_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
    (fc1): Linear(in_features=512, out_features=1024, bias=True)
    (fc2): Linear(in_features=1024, out_features=512, bias=True)
    (relu): ReLU()
  )
  (decoder): Decoder(
    (W_q): Linear(in_features=512, out_features=1024, bias=True)
    (W_k): Linear(in_features=512, out_features=1024, bias=True)
    (W_v): Linear(in_features=512, out_features=1024, bias=True)
    (W_q_m): Linear(in_features=512, out_features=1024, bias=True)
    (W_k_m): Linear(in_features=512, out_features=1024, bias=True)
    (W_v_m): Linear(in_features=512, out_features=1024, bias=True)
    (W_o): Linear(in_features=1024, out_features=512, bias=Tru

In [12]:
loss.backward(retain_graph=True)
last_layer = model.linear_output.weight.grad
print(f"The shape of gradient of last layer is {last_layer}")
print()
model.zero_grad()


predicted_logits_grad = t.autograd.grad(loss, model.decoder.fc2.weight, retain_graph=True, allow_unused=True)[0]
print(f"The shape of the gradient of predicted logits is {predicted_logits_grad.shape}")
model.zero_grad()

A = t.linalg.pinv(predicted_logits_grad)
A = A.view(8, 4, 1)  # Adjusting the shape of A to (4, 8, 1)

# Adjusting the shape of last_layer to (4, 1, 512) since 4 * 1 * 512 = 2048
last_layer = last_layer.view(8, 1, 512)

C = t.matmul(A, last_layer)  # Performing the matrix multiplication

# Ensuring the resultant shape is (4, 8, 512)
C = C.view(4, 8, 512)

print(f"The shape of the resultant matrix C is {C.shape}")

The shape of gradient of last layer is tensor([[ 0.3603, -0.4221, -0.1045,  ..., -0.2058,  0.4342, -0.4643],
        [ 0.5072, -0.5805, -0.1513,  ..., -0.2974,  0.6041, -0.6573],
        [-0.3724, -0.0117,  0.5763,  ...,  0.3500, -0.5062,  0.6956],
        ...,
        [-1.0166,  0.7122, -0.0242,  ...,  0.6126, -0.7392,  1.0038],
        [ 0.5224, -0.5556, -0.1586,  ..., -0.3127,  0.6003, -0.6696],
        [ 0.7897, -0.8532, -0.2648,  ..., -0.4762,  0.9307, -1.0331]])

The shape of the gradient of predicted logits is torch.Size([512, 1024])


RuntimeError: shape '[8, 4, 1]' is invalid for input of size 524288