<a href="https://colab.research.google.com/github/ajayrfhp/LearningDeepLearning/blob/main/pytorch_hooks.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

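- After `e.backward()` runs, `.grad` is populated only for leaf tensors (tensors created directly by the user); gradients of intermediate (non-leaf) tensors are not kept. Hooks on tensors (`Tensor.register_hook`) let us access or modify those intermediate gradients, and `Tensor.retain_grad()` stores them in `.grad`.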
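- Hooks on modules (e.g. `nn.Module.register_forward_hook`) let us easily read the outputs of intermediate layers. They run only when the module is invoked as `module(x)`; calling `module.forward(x)` directly bypasses them.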
- [Video on hooks](https://www.youtube.com/watch?v=syLFCVYua6Q&ab_channel=ElliotWaite)

In [None]:
def example_without_hook():
  """Show that backward() only fills ``.grad`` for leaf tensors.

  ``d`` is created directly by the user, so it is a leaf and ``d.grad``
  receives de/dd (= c = 6).  ``c = a * b`` is produced by an operation, so it
  is a non-leaf: its gradient is not retained and ``c.grad`` stays ``None``
  (PyTorch emits a UserWarning when a non-leaf's ``.grad`` is read).
  """
  a, b = torch.tensor(2.0, requires_grad=True), torch.tensor(3.0, requires_grad=True)
  c = a * b                                  # non-leaf: result of an op
  d = torch.tensor(4.0, requires_grad=True)  # leaf: created directly

  (c * d).backward()                         # e = c * d, then backpropagate
  print(d.grad, c.grad)

example_without_hook()

tensor(6.) None


  print(d.grad, c.grad)


In [None]:
def example_with_hook():
  """Inspect the gradient of an intermediate (non-leaf) tensor ``c = a * b``.

  Two mechanisms are shown:
    * ``retain_grad()`` keeps the gradient in ``c.grad`` after backward();
    * ``register_hook`` runs a callback with the gradient as it is computed.
  The callback returns ``None``, so the gradient flowing backwards is left
  unchanged.
  """
  def report(grad):
    print("gradient of c is", grad)

  a = torch.tensor(2.0, requires_grad=True)
  b = torch.tensor(3.0, requires_grad=True)
  c = a * b
  c.retain_grad()          # populate c.grad even though c is a non-leaf
  c.register_hook(report)  # receives de/dc (= d = 4) during backward()
  d = torch.tensor(4.0, requires_grad=True)

  (c * d).backward()       # e = c * d, then backpropagate
  print(d.grad, c.grad)

example_with_hook()

gradient of c is tensor(4.)
tensor(6.) tensor(4.)


In [None]:
def example_module_with_hook():
  """Demonstrate a forward hook registered on an ``nn.Module``.

  Forward hooks are dispatched by ``nn.Module.__call__``. Calling
  ``module.forward(...)`` directly bypasses them, so the module is invoked as
  ``sum_net(a, b, c)`` here. The hook prints what it observes. The handle is
  removed afterwards so the hook does not outlive the demo.
  """
  class SumNet(nn.Module):
    """Toy module whose forward pass returns the sum of its three inputs."""
    def __init__(self):
      super().__init__()

    def forward(self, a, b, c):
      d = a + b + c
      return d

  def forward_hook(module, inputs, output):
    # Runs right after SumNet.forward; `inputs` is the tuple of positional args.
    print("forward hook on", type(module).__name__, "inputs:", inputs, "output:", output)
    # Returning None keeps the module's output; a non-None value would replace it.
    return None

  sum_net = SumNet()
  a = torch.tensor(1.0, requires_grad=True)
  b = torch.tensor(2.0, requires_grad=True)
  c = torch.tensor(3.0, requires_grad=True)
  handle = sum_net.register_forward_hook(forward_hook)
  try:
    # Call the module itself (not .forward) so the registered hook fires.
    print(sum_net(a, b, c))
  finally:
    handle.remove()

example_module_with_hook()

6


In [None]:
def example_get_layer_output_with_hook():
  """Capture intermediate layer outputs of an ``nn.Sequential`` with forward hooks.

  A forward hook is registered on each ``LazyLinear`` layer. Each hook stores
  its layer's output in a local ``activation`` dict under a short name. After
  one forward pass the captured shapes are printed and the hooks are removed.

  Returns:
    dict mapping 'f1', 'f2', 'f3' to the detached outputs of the first,
    second and third ``LazyLinear`` layers.
  """
  class Net(nn.Module):
    def __init__(self):
      super().__init__()
      # LazyLinear infers in_features from the first input it receives.
      self.model = nn.Sequential(
          nn.LazyLinear(10),
          nn.ReLU(),
          nn.LazyLinear(20),
          nn.ReLU(),
          nn.LazyLinear(30),
      )

    def forward(self, x):
      return self.model(x)

  # Local dict captured by the hook closures; avoids leaking module-level state.
  activation = {}

  def get_activation(name):
    def hook(module, inputs, output):
      # detach() so the stored tensor does not keep the autograd graph alive.
      activation[name] = output.detach()
    return hook

  net = Net()
  # Indices 0, 2, 4 of the Sequential are the LazyLinear layers.
  handles = [
      net.model[idx].register_forward_hook(get_activation(name))
      for idx, name in ((0, 'f1'), (2, 'f2'), (4, 'f3'))
  ]
  try:
    x = torch.randn((1, 5))
    # Call the module (not .forward) so hooks are dispatched via nn.Module.__call__.
    net(x)
  finally:
    for handle in handles:
      handle.remove()

  print(activation['f1'].shape)
  print(activation['f2'].shape)
  print(activation['f3'].shape)
  return activation

example_get_layer_output_with_hook();

torch.Size([1, 10])
torch.Size([1, 20])
torch.Size([1, 30])
