In [1]:
# important: update to torch 2.0 s.t. pre_hooks are available
!pip install torch torchvision -U

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting torch
  Downloading torch-2.0.0-cp39-cp39-manylinux1_x86_64.whl (619.9 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m619.9/619.9 MB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting torchvision
  Downloading torchvision-0.15.1-cp39-cp39-manylinux1_x86_64.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m96.6 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cufft-cu11==10.9.0.58
  Downloading nvidia_cufft_cu11-10.9.0.58-py3-none-manylinux1_x86_64.whl (168.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m168.4/168.4 MB[0m [31m7.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cudnn-cu11==8.5.0.96
  Downloading nvidia_cudnn_cu11-8.5.0.96-2-py3-none-manylinux1_x86_64.whl (557.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m557.1/557.1 MB[0m [31m2.9 MB/s[

In [24]:
import torch
from torch import nn
from torch.utils.data import DataLoader,TensorDataset,random_split
from torchvision import datasets, transforms

import matplotlib.pyplot as plt
import torchvision
import numpy as np
import time, os, copy, random
print("PyTorch Version: ",torch.__version__)
print("Torchvision Version: ",torchvision.__version__)

PyTorch Version:  2.0.0+cu117
Torchvision Version:  0.15.1+cu117


# Create artifical data

Set random seeds.

In [25]:
torch.manual_seed(0)
random.seed(0)
np.random.seed(0)

Set the number of batches and the batch size. For these early tests, 2 batches of size 3 should give good insights while not being to complicated.

In [26]:
n_batches = 5
batch_size = 3
n_samples = n_batches * batch_size

The network architecture, initial weights and test data is similar to this source:
https://www.kaggle.com/code/sironghuang/understanding-pytorch-hooks.

In the linked notebook, only one datapoint is evaluated. Here, this datapoint will be repeated to include the effects of using batches.

In [27]:
artifical_data = torch.empty((n_samples,2))
artifical_data[:,0] = 0.05
artifical_data[:,1] = 0.1
print(f'dataset size :{artifical_data.shape}')
print(f'single sample, size: {artifical_data[0,:].shape} | values: {artifical_data[0,:]}')

dataset size :torch.Size([15, 2])
single sample, size: torch.Size([2]) | values: tensor([0.0500, 0.1000])


In [28]:
artifical_labels = torch.empty_like(artifical_data)
artifical_labels[:,0] = 0.01
artifical_labels[:,1] = 0.99
print(f'label set size :{artifical_labels.shape}')
print(f'single label, size: {artifical_labels[0,:].shape} | values: {artifical_labels[0,:]}')

label set size :torch.Size([15, 2])
single label, size: torch.Size([2]) | values: tensor([0.0100, 0.9900])


Next, the datasets and dataloader are created from the tensors. The first 4*batch_size samples are being used as the training set and the remaining batch_size samples are the test set. The splitting of datasets is not necessary for now but will make extension easy later on.

Tensordata requires a 2D tensor, where each line represents one training sample. Targets may be 1-D or 2-D.

In [29]:
train_set = TensorDataset(artifical_data[:4*batch_size,], artifical_labels[:4*batch_size,])
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False)
print(f'Number of batches in the training set is {len(train_set)}')

Number of batches in the training set is 12


In [30]:
eval_set = TensorDataset(artifical_data[4*batch_size:,], artifical_labels[4*batch_size:,])
eval_loader = DataLoader(train_set, batch_size=batch_size, shuffle=False)
print(f'Number of batches in the evaluation set is {len(eval_set)}')

Number of batches in the evaluation set is 3


In [31]:
dataloaders = {'train':train_loader,
               'eval':eval_loader}

# Create sample model

The base model architecture and weights are taken from [here](https://www.kaggle.com/code/sironghuang/understanding-pytorch-hooks) for reference.

Here, the architecture is extended by a dropout layer.

In [32]:
class TestModel(nn.Module):
  def __init__(self, dropout_rate = 0.5):
        super().__init__()
        # self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(2,2)
        self.s1 = nn.Sigmoid()
        self.do1 = nn.Dropout(p=0.5)
        self.fc2 = nn.Linear(2,2)
        self.s2 = nn.Sigmoid()
        self.fc1.weight = torch.nn.Parameter(torch.Tensor([[0.15,0.2],[0.250,0.30]]))
        self.fc1.bias = torch.nn.Parameter(torch.Tensor([0.35]))
        self.fc2.weight = torch.nn.Parameter(torch.Tensor([[0.4,0.45],[0.5,0.55]]))
        self.fc2.bias = torch.nn.Parameter(torch.Tensor([0.6]))

  def forward(self, x):
      # x = self.flatten(x)
      x= self.fc1(x)
      x = self.s1(x)
      
      
      x = self.do1(x)
      x= self.fc2(x)
      x = self.s2(x)
      return x

In [33]:
model = TestModel()
print(model)

TestModel(
  (fc1): Linear(in_features=2, out_features=2, bias=True)
  (s1): Sigmoid()
  (do1): Dropout(p=0.5, inplace=False)
  (fc2): Linear(in_features=2, out_features=2, bias=True)
  (s2): Sigmoid()
)


In [34]:
for name, param in model.named_parameters():
    print(f"Layer: {name} | Size: {param.size()} | Values : {param[:2]} \n")

Layer: fc1.weight | Size: torch.Size([2, 2]) | Values : tensor([[0.1500, 0.2000],
        [0.2500, 0.3000]], grad_fn=<SliceBackward0>) 

Layer: fc1.bias | Size: torch.Size([1]) | Values : tensor([0.3500], grad_fn=<SliceBackward0>) 

Layer: fc2.weight | Size: torch.Size([2, 2]) | Values : tensor([[0.4000, 0.4500],
        [0.5000, 0.5500]], grad_fn=<SliceBackward0>) 

Layer: fc2.bias | Size: torch.Size([1]) | Values : tensor([0.6000], grad_fn=<SliceBackward0>) 



# Prepare optimizer and loss function

In [35]:
sgd_parameters = {
    'lr':1e-3,        # undefined
    'momentum':0,   # 0
    'dampening':0,    # 0
    'weight_decay':0  # 0
}
optimizer = torch.optim.SGD(model.parameters(), **sgd_parameters)

In [36]:
loss_fn = nn.MSELoss()

# Hooks

Create two hooks for debugging purposes:
- the forward hook will print the input and output tensor produced during the forward pass.
- the backward hook will print the gradient of the output (the gradient coming from the loss) and the gradient input (the gradient used for following calculations closer to the input layers) during the backward pass. 

In [37]:
def forward_debug_hook(module, input, output):
  print('forward hook')
  print(input)
  print(output)

def backward_debug_hook(module, grad_input, grad_output):
   print('backward hook')
   print(grad_input)
   print(grad_output)

Another hook is used to extract the unregularized gradient. The hook is supposed to catch the output gradient (the gradient coming from the loss) when reaching the logits layer (last fc layer of the model).

For this, the hook is created as a class to store the gradient for later use. 

Note that the hook will be attached to the layer the whole time and the stored gradient will simply be overwritten.



In [38]:
class Catch_Hook():
  def __init__(self, module):
    self.hook = module.register_full_backward_hook(self.hook_fn)

  def hook_fn(self, module, grad_input, grad_output):
    self.caught_grad = grad_output
    print('caught a gradient')

  def close(self):
    self.hook.remove()

affected_layer = model.fc2
catch_hook = Catch_Hook(affected_layer)

Re-insertion of the gradient is a little bit more tricky.

By using the return statement in the backward hook, the gradient can be manipulated. There are two possibilities:
1. using the full_backward_hook will insert the gradient in the return statement as the ***input gradient*** (see [here](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_hook))
2. using the full_backward_hook will insert the gradient in the return statement as the ***output gradient*** (see [here](https://pytorch.org/docs/stable/generated/torch.nn.Module.html#torch.nn.Module.register_full_backward_pre_hook))

In the logit case, we want to replace the gradient dLoss/dLogits, which is in this case excactly the output gradient of the corresponding layer. If one was to replace the input_gradient, this would essentialy replace the gradient for dLoss/dLogits*dLogits/dInputsOfLogits.

Therefore one should use the backward_pre_hook (which is making troubles atm).

Additionally, we want don't want to apply this hook during both the unregularized and the regularized run, as we don't want to replace it during the unregularized run. We will therefore register and remove the hook in the training loop.

If you want to test the hook, you can use the commented code snippet at the end. Don't forget to remove the hook before applying a new one.

In [39]:
class Insert_Hook():
  def __init__(self, module, new_grad_output):
    self.new_grad_output = new_grad_output
    # use prepend=True so that this is definetly the first hook being applied
    self.hook = module.register_full_backward__pre_hook(self.hook_fn)

  def hook_fn(self, module, grad_input, grad_output):
    print('inserted gradient')
    # simply return the previously caught grad_output
    # this will replace the current grad_output (if prehook is used)
    # if non-pre hook is used, grad_input will be replaced (not desire in our case)
    return self.new_grad_output
  
  def close(self):
    self.hook.remove()

# artifical_grad = (torch.ones([3,2]),)
# print(artifical_grad)
# insert_hook = Insert_Hook(affected_layer,artifical_grad)

# Model training

Comments on the training loops are inside the code.

The general idea for each batch is:
First run 

In [40]:
def train_model(model, dataloaders, loss_fn, optimizer, num_epochs=5):
    since = time.time()

    val_acc_history = []

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0

    n_train_batches = len(dataloaders['train'])

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        ########## train phase ##########
        phase = 'train'
        model.train()

        running_loss = 0.0
        running_corrects = 0

        for batch, (inputs, labels) in enumerate(dataloaders[phase]):
          optimizer.zero_grad()

          handlef = affected_layer.register_forward_hook(forward_debug_hook)
          handleb = affected_layer.register_full_backward_hook(backward_debug_hook)

          #++++++++ catch unregularized gradient ++++++++#
          print('*'*5 + 'unregularized run' + '*'*5)
          model.eval()
          outputs = model(inputs)
          loss = loss_fn(outputs, labels.float())
          loss.backward()
          new_grad_output = catch_hook.caught_grad
          
          model.train()
          optimizer.zero_grad()
          #++++++++ \catch unregularized gradient ++++++++#

          #++++++++ prepare insertion of unregularized gradient ++++++++#
          handleb.remove()
          insert_hook = Insert_Hook(affected_layer,new_grad_output)
          handleb = affected_layer.register_full_backward_hook(backward_debug_hook)
          #++++++++ \prepare insertion of unregularized gradient ++++++++#
          
          # Get model outputs and calculate loss
          print('*'*5 + 'forward pass' + '*'*5)
          outputs = model(inputs)
          print('outputs')
          print(outputs)

          # outputs.backward(torch.tensor([[0.7414,-0.2171],[0.7414,-0.2171],[0.7414,-0.2171]],dtype=torch.float),retain_graph=True)

          print('*'*5 + 'loss calculation' + '*'*5)
          loss = loss_fn(outputs, labels.float())
          print('loss')
          print(loss)
          # print('weights grad')
          # print(affected_layer.weight.grad)
          # print('bias grad')
          # print(affected_layer.bias.grad)

          preds = (outputs>0.5).int()
          

          # backward + optimize
          print('*'*5 + 'backward pass' + '*'*5)
          
          # print(affected_layer._backward_hooks)
          loss.backward()          
          print('weights grad')
          print(affected_layer.weight.grad)
          print('bias grad')
          print(affected_layer.bias.grad)

          insert_hook.close()
          handlef.remove()
          handleb.remove()

          # print(affected_layer._backward_hooks)
          optimizer.step()

          running_loss += loss.item() * inputs.size(0)
          running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))

        ########## eval phase ##########
        phase = 'eval'
        model.eval()

        running_loss = 0.0
        running_corrects = 0

        for batch, (inputs, labels) in enumerate(dataloaders[phase]):
          # disable gradient tracking for speedup
          with torch.set_grad_enabled(phase == 'train'):
            outputs = model(inputs)
            loss = loss_fn(outputs, labels.float())
            preds = (outputs>0.5).int()

          running_loss += loss.item() * inputs.size(0)
          running_corrects += torch.sum(preds == labels.data)

        epoch_loss = running_loss / len(dataloaders[phase].dataset)
        epoch_acc = running_corrects.double() / len(dataloaders[phase].dataset)
        print('{} Loss: {:.4f} Acc: {:.4f}'.format(phase, epoch_loss, epoch_acc))  

        val_acc_history.append(epoch_acc)
        print()


    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return val_acc_history

In [41]:
hist = train_model(model,
                   dataloaders,
                   loss_fn,
                   optimizer,
                   num_epochs=1
                   )

Epoch 0/0
----------
*****unregularized run*****
forward hook
(tensor([[0.5933, 0.5969],
        [0.5933, 0.5969],
        [0.5933, 0.5969]], grad_fn=<BackwardHookFunctionBackward>),)
tensor([[1.1059, 1.2249],
        [1.1059, 1.2249],
        [1.1059, 1.2249]], grad_fn=<AddmmBackward0>)
caught a gradient
backward hook
(tensor([[0.0121, 0.0138],
        [0.0121, 0.0138],
        [0.0121, 0.0138]]),)
(tensor([[ 0.0462, -0.0127],
        [ 0.0462, -0.0127],
        [ 0.0462, -0.0127]]),)


AttributeError: ignored

In [None]:
print(affected_layer._backward_hooks)
print(catch_hook.caught_grad)
catch_hook.close()
print(affected_layer._backward_hooks)