## Hook Management - part 1

### Setup: ResNet-18 with CIFAR-100 input

In [1]:
import platform

import torch
import torch.nn as nn
import torchvision
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR100
from torchvision.models import resnet18
from torchvision.transforms import Compose, Resize, ToTensor

In [2]:
# Report the Python version of the interpreter running this kernel; a shell
# `!python --version` reports whichever `python` is first on PATH, which can differ.
print('Python %s' % platform.python_version())
print('torch %s' % torch.__version__)
print('torchvision %s' % torchvision.__version__)

Python 3.7.3
torch 1.2.0
torchvision 0.4.0a0


In [3]:
# Load the CIFAR-100 training split, resized to 224x224 for ResNet-18.
# download=True lets a fresh "Restart & Run All" fetch the data; when the files
# are already in ../data/ they are only integrity-checked, not re-downloaded.
cifar100 = CIFAR100('../data/', train=True, download=True,
                    transform=Compose([Resize((224, 224)), ToTensor()]))
# DataLoader defaults: batch_size=1 and no shuffling, so this yields the first image.
cifar_ldr = DataLoader(cifar100)

# get a batch of data (a single image/label pair) and move it to CUDA device 0,
# the default device that `resnet.cuda()` also targets
x, y = next(iter(cifar_ldr))
x, y = x.to(0), y.to(0)

In [4]:
# Build an (untrained) ResNet-18 on the GPU. The forward passes below are only
# for inspecting hook behavior, so switch to eval mode: in train mode every
# resnet(x) call would update the BatchNorm running statistics as a hidden side effect.
resnet = resnet18().cuda().eval()

In [5]:
# import tacklebox and initialize the hook manager, which keeps name-based lookups
# of hook functions, modules and hook handles (see the name_to_* lookups below)
from tacklebox.hook_management import HookManager
hookmngr = HookManager()

### Hook definition

In [6]:
# forward hook function signature: (module, inputs, outputs)

# define function print_shape that prints the shape of the first tensor in outputs
def print_shape(module, inputs, outputs):
    """Forward hook that prints the shape of a module's (first) output tensor.

    Accepts ``outputs`` either as a tuple/list of tensors (the tacklebox
    wrapper appears to pass it this way, judging by the full batch shapes it
    prints) or as a bare tensor (what plain ``nn.Module.register_forward_hook``
    passes for single-output modules such as ``Conv2d``).

    Args:
        module: the hooked module. Its ``name`` attribute is used as the label;
            presumably tacklebox sets it to the registration name (e.g.
            ``myconv``). Falls back to the module's class name when absent.
        inputs: the module's positional inputs (unused).
        outputs: a tensor, or a tuple/list whose first element is a tensor.
    """
    # Star-unpacking a bare tensor would iterate over its first (batch)
    # dimension and silently report the shape of a single sample instead.
    if isinstance(outputs, (tuple, list)):
        output = outputs[0]
    else:
        output = outputs
    # Plain nn.Module instances have no `name` attribute; use the class name then.
    label = getattr(module, 'name', type(module).__name__)
    print('%s output shape: ' % label, output.shape)

### Hook registration and lookup

In [7]:
# register print_shape as a forward hook on resnet.conv1, naming the module myconv
# and the hook function 'print_shape' so both can be looked up by name later;
# without activate=False the new hook starts out active
hookmngr.register_forward_hook(print_shape, hook_fn_name='print_shape', myconv=resnet.conv1)

In [8]:
# look up the HookFunction wrapper the manager created for print_shape
hookmngr.name_to_hookfn['print_shape']

<tacklebox.hook_management.HookFunction at 0x7f6f9eada2e8>

In [9]:
# look up the module registered under the name myconv (resnet.conv1)
hookmngr.name_to_module['myconv']

Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)

In [10]:
# look up the HookHandle binding print_shape to myconv; handle keys follow the
# '<hook function name>[<module name>]' pattern
hookmngr.name_to_hookhandle['print_shape[myconv]']

<print_shape[myconv] <class 'tacklebox.hook_management.HookHandle'> registered to myconv (active)>

In [11]:
# now register the same function with resnet.layer1, naming the module mylayer
# and leaving this new hook deactivated (activate=False)
hookmngr.register_forward_hook(print_shape, mylayer=resnet.layer1, activate=False)

# note that we didn't need to name the hook function again: the manager already
# knows print_shape as 'print_shape' (the new handle is 'print_shape[mylayer]')

In [12]:
# look up our new module, mylayer (resnet.layer1)
hookmngr.name_to_module['mylayer']

Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)

In [13]:
# look up the HookHandle for print_shape registered to mylayer; it is inactive
# because it was registered with activate=False
hookmngr.name_to_hookhandle['print_shape[mylayer]']

<print_shape[mylayer] <class 'tacklebox.hook_management.HookHandle'> registered to mylayer (inactive)>

In [14]:
# list all HookHandles created for the print_shape hook function (one per module)
hookmngr.name_to_hookfn['print_shape'].handles

[<print_shape[myconv] <class 'tacklebox.hook_management.HookHandle'> registered to myconv (active)>,
 <print_shape[mylayer] <class 'tacklebox.hook_management.HookHandle'> registered to mylayer (inactive)>]

In [15]:
# let's test the hook function: only the active hook (on myconv) should print.
# no_grad because only the forward pass is needed, not gradients
with torch.no_grad():
    resnet(x)

myconv output shape:  torch.Size([1, 64, 112, 112])


### Hook activation

In [16]:
# activate the hooks registered to mylayer; myconv's hook is still active as well,
# so both print during the forward pass below
hookmngr.activate_module_hooks_by_name('mylayer')

with torch.no_grad():
    resnet(x)

myconv output shape:  torch.Size([1, 64, 112, 112])
mylayer output shape:  torch.Size([1, 64, 56, 56])


In [17]:
# deactivate all hooks; the forward pass below therefore prints nothing
hookmngr.deactivate_all_hooks()

with torch.no_grad():
    resnet(x)

### Using hook contexts

In [18]:
with torch.no_grad():
    # use a Python context to activate the hook registered to mylayer, then
    # deactivate it after the forward pass (myconv's hook stays off from the previous cell)
    with hookmngr.hook_module_context_by_name('mylayer'):
        resnet(x)

    resnet(x)  # the hook doesn't execute once we exit the context

mylayer output shape:  torch.Size([1, 64, 56, 56])


In [19]:
# use a context to activate all hooks, combining it with torch.no_grad() via `+`;
# the printed False shows gradient tracking is disabled inside the combined context
with hookmngr.hook_all_context() + torch.no_grad():
    resnet(x)
    print(torch.is_grad_enabled())

myconv output shape:  torch.Size([1, 64, 112, 112])
mylayer output shape:  torch.Size([1, 64, 56, 56])
False
