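## Hook Management - part 1

This notebook walks through tacklebox's `HookManager`: registering a forward hook on named modules of a ResNet-18, looking hooks up by name, and switching them on and off either directly or with context managers.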

### ResNet-18 - CIFAR100 classification

In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor

In [2]:
# Resize images to 224x224 before converting them to tensors.
# download=True fetches the dataset on first use so the notebook also runs on a fresh machine.
cifar100 = CIFAR100('../data/', train=True, download=True,
                    transform=Compose([Resize((224, 224)), ToTensor()]))
cifar_ldr = DataLoader(cifar100)  # DataLoader defaults: batch_size=1, no shuffling

# use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# get a single (images, labels) batch and move it to the chosen device
x, y = next(iter(cifar_ldr))
x, y = x.to(device), y.to(device)

In [3]:
# Randomly initialised ResNet-18 with one output per CIFAR-100 class, placed on the same
# device as the input batch `x` so the forward passes below work on CPU or GPU.
# Assigning the result of .to() (which returns the module) avoids echoing the model repr.
resnet = resnet18(num_classes=100).to(x.device)

In [4]:
from tacklebox.hook_management import HookManager
# a single manager that keeps name -> object registries for hook functions, hook handles
# and modules (exposed as name_to_hookfn / name_to_hookhandle / name_to_module)
hookmngr = HookManager()

### Hook definition

In [5]:
def print_shape(module, inputs, outputs):
    """Forward hook that prints ``<module name> output shape: <shape>``.

    Forward hook function signature: (module, inputs, outputs).

    Args:
        module: the hooked module. Its ``name`` attribute is used as the line
            prefix; tacklebox's HookManager appears to set it to the name the
            module was registered under (e.g. ``myconv``).
        inputs: the module's forward inputs (unused).
        outputs: an iterable holding exactly one output tensor. HookManager
            appears to pass outputs as a tuple even for a single tensor --
            TODO confirm against tacklebox.

    Raises:
        ValueError: if ``outputs`` does not unpack to exactly one element.
    """
    [single_output] = outputs
    line_head = '%s output shape: ' % module.name
    print(line_head + str(single_output.shape))

# the ResNet's first convolution layer -- the first module print_shape is attached to
module = resnet.conv1

### Hook registration and lookup

In [6]:
# register print_shape as a forward hook on `module`: hook_fn_name names the hook function
# 'print_shape', and the keyword argument myconv=... names the module 'myconv' for later lookup
hookmngr.register_forward_hook(print_shape, hook_fn_name='print_shape', myconv=module)

In [7]:
# look up the HookFunction wrapper for print_shape
print(hookmngr.name_to_hookfn['print_shape'])
# look up the HookHandle wrapper for the handle returned from registering print_shape with myconv;
# handles are keyed as '<hook function name>[<module name>]'
print(hookmngr.name_to_hookhandle['print_shape[myconv]'])
# look up the module named myconv
print(hookmngr.name_to_module['myconv'])

<tacklebox.hook_management.HookFunction object at 0x7fd722754828>
<print_shape[myconv] <class 'tacklebox.hook_management.HookHandle'> registered to myconv (active)>
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)


In [8]:
# now register the same hook function on a second module, layer1, naming it mylayer;
# activate=False registers the hook but leaves it inactive for now
hookmngr.register_forward_hook(print_shape, mylayer=resnet.layer1, activate=False)

# note that we didn't need to pass hook_fn_name again: print_shape was already named above

In [9]:
# look up the HookHandle wrapper for the handle returned from registering print_shape with mylayer
print(hookmngr.name_to_hookhandle['print_shape[mylayer]'])
# look up our new module, mylayer
print(hookmngr.name_to_module['mylayer'])
# look up every HookHandle created for the print_shape hook function
print(hookmngr.name_to_hookfn['print_shape'].handles)

<print_shape[mylayer] <class 'tacklebox.hook_management.HookHandle'> registered to mylayer (inactive)>
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
[<print_shape[myconv] <class 'tacklebox.hook_ma

In [11]:
# let's test the hook function: only the hook on myconv is active (mylayer's was
# registered with activate=False), so only myconv prints its output shape
with torch.no_grad():
    resnet(x)

myconv output shape: torch.Size([1, 64, 112, 112])


### Hook activation

In [12]:
# activate the hooks registered to the module named mylayer; myconv's hook is still
# active, so both hooks print during this forward pass
hookmngr.activate_module_hooks_by_name('mylayer')

with torch.no_grad():
    resnet(x)

myconv output shape: torch.Size([1, 64, 112, 112])
mylayer output shape: torch.Size([1, 64, 56, 56])


In [13]:
# deactivate all registered hooks (both myconv's and mylayer's), so this forward pass prints nothing
hookmngr.deactivate_all_hooks()

with torch.no_grad():
    resnet(x)

### Using hook contexts

In [14]:
# use a Python context manager to activate the hook registered to mylayer only for the
# duration of the `with` block; it is deactivated again when the block exits
with torch.no_grad():
    with hookmngr.hook_module_context_by_name('mylayer'):
        resnet(x)
    
    resnet(x)  # the hook doesn't execute once we exit the context

mylayer output shape: torch.Size([1, 64, 56, 56])


In [15]:
# now try combining contexts with `+` -- one `with` statement instead of nesting (less indentation :)
# hook_all_context() activates every registered hook inside the block
with hookmngr.hook_all_context() + torch.no_grad():
    resnet(x)
    print(torch.is_grad_enabled())  # False: torch.no_grad() is also in effect inside the combined context

myconv output shape: torch.Size([1, 64, 112, 112])
mylayer output shape: torch.Size([1, 64, 56, 56])
False
