In [1]:
import torch
from torch import nn

References:
- [半小时学会 PyTorch Hook](https://cloud.tencent.com/developer/article/1475430)
- [How to Use PyTorch Hooks](https://medium.com/the-dl/how-to-use-pytorch-hooks-5041d777f904)

### Hook for Tensors

Keep the gradients of non-leaf (intermediate) tensors with `.retain_grad()`

In [2]:
# Leaf tensors created with requires_grad=True get their gradient stored in `.grad`.
x = torch.arange(0, 4, dtype=torch.float, requires_grad=True)
y = torch.arange(1, 5, dtype=torch.float, requires_grad=True)
z = torch.arange(2, 6, dtype=torch.float, requires_grad=True)

# `w` and `o` are intermediate results; retain_grad() asks autograd to keep
# their gradients so they can be read from `.grad` after backward().
w = x + y
w.retain_grad()

o = w @ z
o.retain_grad()
o.backward()

# Map display names to tensors explicitly instead of resolving names with eval().
named_tensors = {'x': x, 'y': y, 'z': z, 'w': w, 'o': o}

for name, tensor in named_tensors.items():
    print(f'{name}.requires_grad:', tensor.requires_grad)

for name, tensor in named_tensors.items():
    print(f'{name}.grad:', tensor.grad)

x.requires_grad: True
y.requires_grad: True
z.requires_grad: True
w.requires_grad: True
o.requires_grad: True
x.grad: tensor([2., 3., 4., 5.])
y.grad: tensor([2., 3., 4., 5.])
z.grad: tensor([1., 3., 5., 7.])
w.grad: tensor([2., 3., 4., 5.])
o.grad: tensor(1.)


Print gradients through `.register_hook()`

In [3]:
def hook_print_grad(grad):
    """Tensor hook: print the gradient it receives.

    Returns None, so the gradient continues through backward unchanged.
    """
    print(grad)


x = torch.arange(0, 4, dtype=torch.float, requires_grad=True)
y = torch.arange(1, 5, dtype=torch.float, requires_grad=True)
z = torch.arange(2, 6, dtype=torch.float, requires_grad=True)

# Attach the printing hook to both intermediate tensors before running backward.
w = x + y
handle_w = w.register_hook(hook_print_grad)

o = w @ z
handle_o = o.register_hook(hook_print_grad)

o.backward()

# Detach the hooks once the demonstration is done.
for handle in (handle_w, handle_o):
    handle.remove()

tensor(1.)
tensor([2., 3., 4., 5.])


In [4]:
x = torch.arange(0, 4, dtype=torch.float, requires_grad=True)
y = torch.arange(1, 5, dtype=torch.float, requires_grad=True)
z = torch.arange(2, 6, dtype=torch.float, requires_grad=True)

w = x + y
# A hook that returns a tensor replaces the gradient flowing through `w`.
# The hook parameter is named `grad` so it does not shadow the leaf tensor `x`.
handle_w = w.register_hook(lambda grad: grad * 2)
o = w @ z
o.backward()

# x and y sit upstream of `w`, so they receive the doubled gradient;
# z's gradient does not pass through `w` and is unaffected by the hook.
# Tensors are looked up through an explicit mapping instead of eval().
named_tensors = {'x': x, 'y': y, 'z': z}
for name, tensor in named_tensors.items():
    print(f'{name}.grad:', tensor.grad)

handle_w.remove()

x.grad: tensor([ 4.,  6.,  8., 10.])
y.grad: tensor([ 4.,  6.,  8., 10.])
z.grad: tensor([1., 3., 5., 7.])


### Hooks for `nn.Module`

In [5]:
class NN(nn.Module):
    """Small feed-forward network: Linear(3, 4) -> ReLU -> Linear(4, 1)."""

    def __init__(self) -> None:
        super().__init__()
        self.l1 = nn.Linear(3, 4)
        self.act1 = nn.ReLU()
        self.l2 = nn.Linear(4, 1)

    def forward(self, X):
        """Run X through l1, act1 and l2 in turn and return the result."""
        hidden = self.l1(X)
        activated = self.act1(hidden)
        return self.l2(activated)

def hook_forward_fn(module, input, output):
    """Forward hook: print the module, then its input and its output."""
    print(module)
    for label, value in (('- input:', input), ('- output:', output)):
        print(label, value)

def hook_backward_fn(module, grad_input, grad_output):
    """Backward hook: print the module, then grad_output followed by grad_input."""
    print(module)
    # grad_output is printed first, ahead of grad_input.
    for label, value in (('- grad_output:', grad_output), ('- grad_input:', grad_input)):
        print(label, value)
    
model = NN()

# Attach the forward and the full-backward hook to every direct child of the model.
for child in model.children():
    child.register_forward_hook(hook_forward_fn)
    child.register_full_backward_hook(hook_backward_fn)

# The input requires grad so that gradients are propagated all the way back to it.
x = torch.ones(3, requires_grad=True)

print('<===================Forward Process====================>')
o = model(x)

print('<===================Backward Process===================>')
o.backward()

Linear(in_features=3, out_features=4, bias=True)
- input: (tensor([1., 1., 1.], grad_fn=<BackwardHookFunctionBackward>),)
- output: tensor([-0.3731, -0.3684,  0.9563,  0.1577], grad_fn=<AddBackward0>)
ReLU()
- input: (tensor([-0.3731, -0.3684,  0.9563,  0.1577],
       grad_fn=<BackwardHookFunctionBackward>),)
- output: tensor([0.0000, 0.0000, 0.9563, 0.1577], grad_fn=<ReluBackward0>)
Linear(in_features=4, out_features=1, bias=True)
- input: (tensor([0.0000, 0.0000, 0.9563, 0.1577],
       grad_fn=<BackwardHookFunctionBackward>),)
- output: tensor([-0.7612], grad_fn=<AddBackward0>)
Linear(in_features=4, out_features=1, bias=True)
- grad_output: (tensor([1.]),)
- grad_input: (tensor([ 0.2874,  0.4944, -0.3627, -0.2082]),)
ReLU()
- grad_output: (tensor([ 0.2874,  0.4944, -0.3627, -0.2082]),)
- grad_input: (tensor([ 0.0000,  0.0000, -0.3627, -0.2082]),)
Linear(in_features=3, out_features=4, bias=True)
- grad_output: (tensor([ 0.0000,  0.0000, -0.3627, -0.2082]),)
- grad_input: (tensor([-0

### Applications

#### Verbose Printer

In [6]:
class VerboseExecution(nn.Module):
    """Wrap ``model`` and print the output shape of each direct child on every forward pass.

    A forward hook is registered on every module returned by
    ``model.named_children()``; each hook prints ``"<child name>: <output shape>"``.

    Args:
        model: the module to wrap. Hooks are attached to its direct children.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

        # Register a hook for each layer
        for name, layer in self.model.named_children():
            layer.register_forward_hook(self._make_shape_printer(name))

    @staticmethod
    def _make_shape_printer(name: str):
        """Build a forward hook that prints ``name`` and the layer's output shape.

        The name is bound in the closure, so nothing has to be written onto the
        wrapped layer to remember it.
        """
        def print_shape(_module, _inputs, output):
            print(f"{name}: {output.shape}")

        return print_shape

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the wrapped model on ``x``; the registered hooks print as it executes."""
        return self.model(x)
    
from torchvision.models import resnet50

# A batch of 10 all-ones inputs shaped (3, 224, 224).
dummy_input = torch.ones((10, 3, 224, 224))
verbose_resnet = VerboseExecution(resnet50())

# The model output is discarded; the interesting part is what the hooks print.
_ = verbose_resnet(dummy_input)

conv1: torch.Size([10, 64, 112, 112])
bn1: torch.Size([10, 64, 112, 112])
relu: torch.Size([10, 64, 112, 112])
maxpool: torch.Size([10, 64, 56, 56])
layer1: torch.Size([10, 256, 56, 56])
layer2: torch.Size([10, 512, 28, 28])
layer3: torch.Size([10, 1024, 14, 14])
layer4: torch.Size([10, 2048, 7, 7])
avgpool: torch.Size([10, 2048, 1, 1])
fc: torch.Size([10, 1000])


#### Feature Extractor

In [7]:
from typing import Dict, Iterable, Callable

class FeatureExtractor(nn.Module):
    """Wrap ``model`` and capture the outputs of selected submodules via forward hooks.

    Args:
        model: the module to wrap.
        layers: names of the submodules to capture, as they appear in
            ``model.named_modules()``. Any iterable is accepted, including
            one-shot iterators such as generators.

    Raises:
        KeyError: if a requested name is not a submodule of ``model``.
    """

    def __init__(self, model: nn.Module, layers: Iterable[str]):
        super().__init__()
        self.model = model
        # Materialise once: `layers` is read twice below, and a one-shot
        # iterator would be empty on the second pass, leaving no hooks registered.
        self.layers = list(layers)
        self._features = {layer: torch.empty(0) for layer in self.layers}

        # Build the name -> module lookup a single time rather than per layer.
        modules_by_name = dict(self.model.named_modules())
        for layer_id in self.layers:
            try:
                layer = modules_by_name[layer_id]
            except KeyError:
                raise KeyError(f"model has no submodule named {layer_id!r}") from None
            layer.register_forward_hook(self.save_outputs_hook(layer_id))

    def save_outputs_hook(self, layer_id: str) -> Callable:
        """Return a forward hook that stores the layer output under ``layer_id``."""
        def fn(_, __, output):
            self._features[layer_id] = output
        return fn

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the wrapped model on ``x`` and return the captured outputs by layer name.

        A shallow copy of the internal store is returned, so the result of one
        call is not overwritten by a later call.
        """
        _ = self.model(x)
        return dict(self._features)
    

# Capture the outputs of two ResNet-50 submodules for the dummy batch.
resnet_features = FeatureExtractor(resnet50(), layers=["layer4", "avgpool"])
features = resnet_features(dummy_input)

# Show only the shape of each captured feature map.
feature_shapes = {}
for layer_name, feature in features.items():
    feature_shapes[layer_name] = feature.shape
print(feature_shapes)

{'layer4': torch.Size([10, 2048, 7, 7]), 'avgpool': torch.Size([10, 2048, 1, 1])}


#### Gradient Clipping

In [8]:
def gradient_clipper(model: nn.Module, val: float) -> nn.Module:
    """Clip every trainable parameter gradient of ``model`` element-wise to ``[-val, val]``.

    A tensor hook is registered on each parameter that requires grad. The hook
    returns a new clamped tensor instead of clamping in place, because a tensor
    hook must not modify its argument; the returned tensor is used as the gradient.

    Args:
        model: module whose parameters receive the clipping hook.
        val: clipping bound. Presumably non-negative; it is not validated here.

    Returns:
        The same ``model`` instance that was passed in.
    """
    def clip(grad: torch.Tensor) -> torch.Tensor:
        return grad.clamp(-val, val)

    for parameter in model.parameters():
        # register_hook raises on tensors that do not require grad, so frozen
        # parameters are skipped; they never receive a gradient anyway.
        if parameter.requires_grad:
            parameter.register_hook(clip)
    return model