In [45]:
import torch
from torch import nn
import types
from functools import partial


def extend(model, input_shape):
    if not isinstance(model, nn.Module):
        raise TypeError("model should be a nn.Module")
    if not isinstance(input_shape, tuple):
        raise TypeError("input_shape should be a tuple")

    device = next(model.parameters()).device

    weight_input_list = []
    weight_output_list = []
    weight_repeat_list = []
    bias_output_list = []
    bias_repeat_list = []

    x = torch.zeros((1,) + input_shape, device=device)
    with torch.no_grad():
        for module in model.children():
            y = module(x)
            if sum(p.numel() for p in module.parameters()):
                # for all layers with parameters

                # store parameters and clear bias for future calculation
                if module.weight is not None:
                    initial_weight = module.weight.data.clone()
                if module.bias is not None:
                    initial_bias = module.bias.data.clone()
                    module.bias.data = torch.zeros_like(module.bias)

                if module.weight is not None:
                    Nweight = module.weight.numel()
                    weight_input = []
                    weight_output = []
                    weight_repeat = torch.zeros(
                        Nweight, dtype=torch.long, device=device
                    )
                    Xeye = torch.eye(x.numel(), device=device).reshape(
                        (-1,) + x.shape[1:]
                    )
                    for i in range(Nweight):
                        weight = torch.zeros(Nweight, device=device)
                        weight[i] = 1.0
                        module.weight.data = weight.reshape(module.weight.shape)
                        # output of module is of dimension (j,k)
                        out = module(Xeye).reshape(x.numel(), y.numel())
                        if (out[out.abs() > 1e-5] - 1.0).abs().max() > 1e-5:
                            raise RuntimeError(
                                "the network is not written in the standard form, see https://github.com/ChenAo-Phys/pytorch-Jacobian"
                            )
                        nonzero = torch.nonzero(out > 0.5, as_tuple=False)
                        weight_input.append(nonzero[:, 0])
                        weight_output.append(nonzero[:, 1])
                        weight_repeat[i] = nonzero.shape[0]
                    weight_input_list.append(torch.cat(weight_input, dim=0))
                    weight_output_list.append(torch.cat(weight_output, dim=0))
                    weight_repeat_list.append(weight_repeat)
                    module.weight.data = initial_weight
                else:
                    weight_input_list.append(None)
                    weight_output_list.append(None)
                    weight_repeat_list.append(None)

                if module.bias is not None:
                    Nbias = module.bias.numel()
                    bias_output = []
                    bias_repeat = torch.zeros(Nbias, dtype=torch.long, device=device)
                    for i in range(Nbias):
                        bias = torch.zeros(Nbias, device=device)
                        bias[i] = 1.0
                        module.bias.data = bias.reshape(module.bias.shape)
                        out = module(x).reshape(-1)
                        if (out[out.abs() > 1e-5] - 1.0).abs().max() > 1e-5:
                            raise RuntimeError(
                                "the network is not written in the standard form, see https://github.com/ChenAo-Phys/pytorch-Jacobian"
                            )
                        nonzero = torch.nonzero(out > 0.5, as_tuple=False)
                        bias_output.append(nonzero[:, 0])
                        bias_repeat[i] = nonzero.shape[0]
                    bias_output_list.append(torch.cat(bias_output, dim=0))
                    bias_repeat_list.append(bias_repeat)
                    module.bias.data = initial_bias
                else:
                    bias_output_list.append(None)
                    bias_repeat_list.append(None)

            x = torch.zeros_like(y)

    if not hasattr(model, "_Jacobian_shape_dict"):
        model._Jacobian_shape_dict = {}
    model._Jacobian_shape_dict[input_shape] = (
        weight_input_list,
        weight_output_list,
        weight_repeat_list,
        bias_output_list,
        bias_repeat_list,
    )

    # assign jacobian method to model
    def jacobian(self, as_tuple=False):
        shape = self.input_shape
        if hasattr(self, "_Jacobian_shape_dict") and shape in self._Jacobian_shape_dict:
            (
                weight_input_list,
                weight_output_list,
                weight_repeat_list,
                bias_output_list,
                bias_repeat_list,
            ) = self._Jacobian_shape_dict[shape]
        else:
            raise RuntimeError(
                "model or specific input shape is not extended for jacobian calculation"
            )

        device = next(model.parameters()).device
        jac = []
        layer = 0
        for module in self.children():
            if sum(p.numel() for p in module.parameters()):
                weight_input = weight_input_list[layer]
                weight_output = weight_output_list[layer]
                weight_repeat = weight_repeat_list[layer]
                bias_output = bias_output_list[layer]
                bias_repeat = bias_repeat_list[layer]
                x = self.x_in[layer]
                N = x.shape[0]
                dz_dy = self.gradient[layer].reshape(N, -1)

                if weight_repeat is not None:
                    Nweight = weight_repeat.shape[0]
                    dz_dy_select = dz_dy[:, weight_output]
                    x_select = x.reshape(N, -1)[:, weight_input]
                    repeat = torch.repeat_interleave(weight_repeat)
                    dz_dW = torch.zeros(N, Nweight, device=device).index_add_(
                        1, repeat, dz_dy_select * x_select
                    )
                    if as_tuple:
                        dz_dW = dz_dW.reshape((N,) + module.weight.shape)
                    jac.append(dz_dW)
                if bias_repeat is not None:
                    Nbias = bias_repeat.shape[0]
                    dz_dy_select = dz_dy[:, bias_output]
                    repeat = torch.repeat_interleave(bias_repeat)
                    dz_db = torch.zeros(N, Nbias, device=device).index_add_(
                        1, repeat, dz_dy_select
                    )
                    if as_tuple:
                        dz_db = dz_db.reshape((N,) + module.bias.shape)
                    jac.append(dz_db)
                layer += 1

        if as_tuple:
            return tuple(jac)
        else:
            return torch.cat(jac, dim=1)

    if not hasattr(model, "jacobian"):
        model.jacobian = types.MethodType(jacobian, model)


class JacobianMode:
    def __init__(self, model):
        self.model = model
        if not isinstance(model, nn.Module):
            raise TypeError("model should be a nn.Module")

    def __enter__(self):
        model = self.model
        model.x_in = []
        model.gradient = []
        self.forward_pre_hook = []
        self.backward_hook = []

        def record_input_shape(self, input):
            model.input_shape = input[0].shape[1:]

        def record_forward(self, input, layer):
            model.x_in[layer] = input[0].detach()

        def record_backward(self, grad_input, grad_output, layer):
            model.gradient[layer] = grad_output[0]

        module0 = next(model.children())
        self.first_forward_hook = module0.register_forward_pre_hook(record_input_shape)

        layer = 0
        for module in model.children():
            if sum(p.numel() for p in module.parameters()):
                model.x_in.append(None)
                model.gradient.append(None)
                self.forward_pre_hook.append(
                    module.register_forward_pre_hook(
                        partial(record_forward, layer=layer)
                    )
                )
                self.backward_hook.append(
                    module.register_backward_hook(partial(record_backward, layer=layer))
                )
                layer += 1

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.first_forward_hook.remove()
        for hook in self.forward_pre_hook:
            hook.remove()
        for hook in self.backward_hook:
            hook.remove()

        del self.model.input_shape
        del self.model.x_in
        del self.model.gradient

In [44]:
import torch as t
import torch.nn as nn
import argparse

import warnings
warnings.filterwarnings("ignore")


from imports import *
from encoder import Encoder as encoder
from decoder import Decoder as decoder

# Taking a simple sent that will be passed into the transformer. 
sent: t.Tensor = t.randint(0, 8, (4,))
out_sent: t.Tensor = t.randint(0, 8, (4,))

print(f"Input sent: {sent}")
print(f"Output sent: {out_sent}")

class Transformer(nn.Module):
    def __init__(self, num_heads: int, sent: t.Tensor, out_sent: t.Tensor):
        super(Transformer, self).__init__()
        self.encoder = encoder(num_heads, sent)
        self.decoder = decoder(num_heads, out_sent)
        self.linear_output = nn.Linear(512, 8)  # 8 is the vocabulary size

        for params in self.parameters():
            params.requires_grad = True

    def forward(self) -> t.Tensor:
        encoder_output = self.encoder.forward()
        decoder_output = self.decoder.forward(encoder_output) 
        print(f"Shape of Decoder output: {decoder_output.shape}")
        output = self.linear_output(decoder_output)
        return output
    
# parser = argparse.ArgumentParser(
ce_loss = nn.CrossEntropyLoss()
num_heads = 2

model = Transformer(num_heads, sent, out_sent)
model.train()
predicted_logits = model.forward()

print(f"The predicted logits are {predicted_logits}")

ce_loss_none = nn.CrossEntropyLoss(reduction='none')

loss_contributions = ce_loss_none(predicted_logits.view(4, 8), out_sent)
print(f"Contribution of each word to the loss: {loss_contributions}")

loss = loss_contributions.sum()
print(f"Total loss: {loss}"); print()




Input sent: tensor([3, 5, 2, 3])
Output sent: tensor([3, 4, 2, 0])
torch.Size([1, 4, 1024])
Shape of Decoder output: torch.Size([1, 4, 512])
The predicted logits are tensor([[[-1.1023, -0.9961,  0.2055, -0.4318,  0.3488,  0.3634, -0.1508,
          -0.0252],
         [-1.2135, -0.9771,  0.2404, -0.4032,  0.3361,  0.4185, -0.1972,
          -0.0302],
         [-1.2262, -0.9320,  0.3724, -0.4092,  0.1326,  0.3861, -0.4339,
           0.0442],
         [-1.3338, -0.8611,  0.4725, -0.3433,  0.2549,  0.2428, -0.6383,
           0.1921]]], grad_fn=<ViewBackward0>)
Contribution of each word to the loss: tensor([2.4155, 1.6550, 1.5889, 3.3223], grad_fn=<NllLossBackward0>)
Total loss: 8.98176097869873



In [43]:
import torch as t
import torch.nn as nn
import argparse

import warnings
warnings.filterwarnings("ignore")


from imports import *
from encoder import Encoder as encoder
from decoder import Decoder as decoder

# Taking a simple sent that will be passed into the transformer. 
sent: t.Tensor = t.randint(0, 8, (4,))
out_sent: t.Tensor = t.randint(0, 8, (4,))

print(f"Input sent: {sent}")
print(f"Output sent: {out_sent}")

class Transformer(nn.Module):
    def __init__(self, num_heads: int, sent: t.Tensor, out_sent: t.Tensor):
        super(Transformer, self).__init__()
        self.encoder = encoder(num_heads, sent)
        self.decoder = decoder(num_heads, out_sent)
        self.linear_output = nn.Linear(512, 8)  # 8 is the vocabulary size

        for params in self.parameters():
            params.requires_grad = True

    def forward(self) -> t.Tensor:
        encoder_output = self.encoder.forward()
        decoder_output = self.decoder.forward(encoder_output) 
        print(f"Shape of Decoder output: {decoder_output.shape}")
        output = self.linear_output(decoder_output)
        return output
    
# parser = argparse.ArgumentParser(
ce_loss = nn.CrossEntropyLoss()
num_heads = 2

model = Transformer(num_heads, sent, out_sent)
model.train()
predicted_logits = model.forward()

print(f"The predicted logits are {predicted_logits}")

ce_loss_none = nn.CrossEntropyLoss(reduction='none')

loss_contributions = ce_loss_none(predicted_logits.view(4, 8), out_sent)
print(f"Contribution of each word to the loss: {loss_contributions}")

loss = loss_contributions.sum()
print(f"Total loss: {loss}"); print()

loss_contributions.backward(t.tensor([1, 1, 1, 1], dtype=t.float32), retain_graph=True)
total  = model.linear_output.weight.grad
model.zero_grad()

loss_contributions.backward(t.tensor([1, 0, 0, 0], dtype=t.float32), retain_graph=True)
first_token  = model.linear_output.weight.grad
model.zero_grad()

loss_contributions.backward(t.tensor([0, 1, 0, 0], dtype=t.float32), retain_graph=True)
second_token  = model.linear_output.weight.grad
model.zero_grad()

loss_contributions.backward(t.tensor([0, 0, 1, 0], dtype=t.float32), retain_graph=True)
third_token  = model.linear_output.weight.grad
model.zero_grad()

loss_contributions.backward(t.tensor([0, 0, 0, 1], dtype=t.float32), retain_graph=True)
fourth_token  = model.linear_output.weight.grad
model.zero_grad()


print(f"Total gradient: {total[0,0]:.4f}")
print(f"The addition of gradient of first token: {first_token[0,0]:.4f}, second token: {second_token[0,0]:.4f}, third token: {third_token[0,0]:.4f}, fourth token: {fourth_token[0,0]:.4f} is: ")
print(first_token[0,0] + second_token[0,0] + third_token[0,0] + fourth_token[0,0])
print()

print(f"Are all the values equal of summed of gradient and total gradient: {t.allclose(total, first_token + second_token + third_token + fourth_token)}")

loss.backward(retain_graph=True)
_token  = model.linear_output.weight.grad
print(_token[0][0])
model.zero_grad()

print(f"Are all the values equal of without jacobian product-simply taking gradient and one taken by jacobian product: {t.allclose(total, _token)}")

Input sent: tensor([6, 3, 1, 0])
Output sent: tensor([5, 2, 7, 6])
torch.Size([1, 4, 1024])
Shape of Decoder output: torch.Size([1, 4, 512])
The predicted logits are tensor([[[-0.2316, -0.8765, -1.2680, -0.2764,  0.8412, -0.7650,  1.0826,
          -0.7140],
         [-0.3567, -0.9351, -1.1221, -0.3347,  0.9145, -0.6786,  0.8842,
          -0.6992],
         [ 0.0574, -0.4736, -1.4513, -0.0695,  0.6909, -0.6255,  1.2332,
          -0.8025],
         [-0.0403, -0.6667, -1.4379, -0.1408,  0.7269, -0.6249,  0.9919,
          -0.8684]]], grad_fn=<ViewBackward0>)
Contribution of each word to the loss: tensor([2.9023, 3.2083, 3.0281, 1.1244], grad_fn=<NllLossBackward0>)
Total loss: 10.263071060180664

Total gradient: -0.1308
The addition of gradient of first token: -0.0417, second token: 0.0102, third token: -0.0172, fourth token: -0.0820 is: 
tensor(-0.1308)

Are all the values equal of summed of gradient and total gradient: True
tensor(-0.1308)
Are all the values equal of without jacobian 

In [38]:
x = -0.1674/-0.0418
print(x)

4.0047846889952154
