### ResNet-18 - CIFAR100 classification

In [1]:
import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision.datasets import CIFAR100
from torch.utils.data import DataLoader
from torchvision.transforms import Compose, Resize, ToTensor

In [2]:
# CIFAR-100 training split; every image is resized to 224x224 and converted to a tensor
cifar100 = CIFAR100(
    '../data/',
    train=True,
    transform=Compose([
        Resize((224, 224)),
        ToTensor(),
    ]),
)
cifar_ldr = DataLoader(cifar100)

# take the first batch from the loader and move both tensors to device 0
x, y = (tensor.to(0) for tensor in next(iter(cifar_ldr)))

In [3]:
# build a ResNet-18 with default arguments (no pretrained weights are requested here)
resnet = resnet18()
# move the model to the GPU; binding the result to `_` keeps the notebook from echoing the module repr
_ = resnet.cuda()

In [4]:
from tacklebox.hook_management import HookManager
# a single manager instance is used for every hook registered in this notebook
hookmngr = HookManager()

In [5]:
# forward hook function signature: (module, inputs, output)
def print_shape(module, inputs, output):
    """Forward hook: print the hooked module's ``name`` followed by the shape of its output.

    ``inputs`` is accepted to satisfy the forward-hook signature but is not used.
    """
    print('%s output shape: ' % module.name, output.shape, sep='')

# conv1 of the ResNet is the first module we attach a hook to
module = resnet.conv1

In [6]:
# register print_shape as a forward hook on the module; the keyword name (myconv) becomes the
# name the module is looked up by, and hook_fn_name is the key the hook function is stored under
hookmngr.register_forward_hook(print_shape, hook_fn_name='print_shape', myconv=module)

In [7]:
# lookup the HookFunction wrapper for print_shape
print(hookmngr.name_to_hookfn['print_shape'])
# lookup the HookHandle wrapper for the handle returned from registering print_shape with myconv;
# handles are keyed as '<hook function name>[<module name>]'
print(hookmngr.name_to_hookhandle['print_shape[myconv]'])
# lookup the module named myconv
print(hookmngr.name_to_module['myconv'])

<tacklebox.hook_management.HookFunction object at 0x7f94a1058978>
<print_shape[myconv] <class 'tacklebox.hook_management.HookHandle'> registered to myconv (active)>
Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)


In [8]:
# now register the same function with another module, this time leaving it deactivated for now
hookmngr.register_forward_hook(print_shape, mylayer=resnet.layer1, activate=False)

# note that we didn't need to name the hook function again: hook_fn_name is omitted in this call

In [9]:
# lookup the HookHandle wrapper for the handle returned from registering print_shape with mylayer
# (its repr reports "inactive" because it was registered with activate=False)
print(hookmngr.name_to_hookhandle['print_shape[mylayer]'])
# lookup our new module, mylayer
print(hookmngr.name_to_module['mylayer'])
# lookup all HookHandles corresponding to the print_shape hook function
print(hookmngr.name_to_hookfn['print_shape'].handles)

<print_shape[mylayer] <class 'tacklebox.hook_management.HookHandle'> registered to mylayer (inactive)>
Sequential(
  (0): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
  (1): BasicBlock(
    (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (relu): ReLU(inplace=True)
    (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  )
)
[<print_shape[myconv] <class 'tacklebox.hook_ma

In [10]:
# let's test the hook function: only the myconv hook is active, so a single line is printed
with torch.no_grad():
    resnet(x)

myconv output shape: torch.Size([1, 64, 112, 112])


In [11]:
# activate the hooks registered to mylayer
hookmngr.activate_module_hooks_by_name('mylayer')

# both print_shape hooks (myconv and mylayer) now print during the forward pass
with torch.no_grad():
    resnet(x)

myconv output shape: torch.Size([1, 64, 112, 112])
mylayer output shape: torch.Size([1, 64, 56, 56])


In [12]:
# deactivate all hooks (myconv and mylayer alike), so this forward pass prints nothing
hookmngr.deactivate_all_hooks()

with torch.no_grad():
    resnet(x)

In [13]:
# use a Python context manager to activate the hook registered to mylayer for the enclosed
# forward pass only; it is deactivated again when the context exits
with torch.no_grad():
    with hookmngr.hook_module_context_by_name('mylayer'):
        resnet(x)
    
    resnet(x)  # hook doesn't execute once we exit the context

mylayer output shape: torch.Size([1, 64, 56, 56])


In [14]:
# now try combining the contexts -- less indentation :)
# the hook context and torch.no_grad() are added together and entered with a single `with`
with hookmngr.hook_all_context() + torch.no_grad():
    resnet(x)
    print(torch.is_grad_enabled())  # prints False: no_grad is in effect inside the combined context

myconv output shape: torch.Size([1, 64, 112, 112])
mylayer output shape: torch.Size([1, 64, 56, 56])
False


In [15]:
# dealing with multiple hook functions

# forward pre-hook function signature: (module, inputs)
def zero_input(module, inputs):
    """Forward pre-hook that replaces a module's single input with zeros.

    Args:
        module: the hooked module; its ``name`` attribute is used in the log line.
        inputs: tuple of positional inputs to the module. Exactly one input is
            expected; the tuple unpacking raises ValueError otherwise.

    Returns:
        A tensor of zeros with the same shape, dtype and device as the input.
    """
    original, = inputs
    # zeros_like (rather than ``x - x``) yields zeros even when the input holds NaN or
    # Inf values, and a distinct local name avoids shadowing the builtin ``input``
    zeroed = torch.zeros_like(original)
    print('Set %s input to zero' % module.name)
    return zeroed

def print_mean(module, inputs, output):
    """Forward hook: print the mean of the module's first input and of its output, to two decimals."""
    mean_in = inputs[0].mean().item()
    mean_out = output.mean().item()
    print('%s input mean = %.2f, output mean = %.2f' % (module.name, mean_in, mean_out))
    
# to pass modules without names, pass as args instead of as kwargs
# (conv1 and layer1 were registered earlier as myconv and mylayer, and the output printed by
# these hooks still refers to them by those names)
hookmngr.register_forward_pre_hook(zero_input, resnet.conv1, resnet.layer1,
                                   hook_fn_name='zero_input', activate=False)
hookmngr.register_forward_hook(print_mean, resnet.conv1, resnet.layer1,
                              hook_fn_name='print_mean', activate=False)

In [16]:
# activate every registered hook for this forward pass: zero_input runs before each hooked
# module, then print_shape and print_mean report on the result
with hookmngr.hook_all_context() + torch.no_grad():
    resnet(x)

Set myconv input to zero
myconv output shape: torch.Size([1, 64, 112, 112])
myconv input mean = 0.00, output mean = 0.00
Set mylayer input to zero
mylayer output shape: torch.Size([1, 64, 56, 56])
mylayer input mean = 0.00, output mean = 0.00


In [17]:
# let's only activate our forward hooks; the zero_input pre-hook stays off, so the printed
# input means are no longer zero
with hookmngr.hook_all_context(category='forward_hook') + torch.no_grad():
    resnet(x)

myconv output shape: torch.Size([1, 64, 112, 112])
myconv input mean = 0.54, output mean = 0.00
mylayer output shape: torch.Size([1, 64, 56, 56])
mylayer input mean = 0.51, output mean = 0.92


In [18]:
# now let's activate only a chosen subset of hook functions: print_shape and zero_input
# (print_mean stays inactive)
with hookmngr.hook_all_context(hook_types=[print_shape, zero_input]) + torch.no_grad():
    resnet(x)

Set myconv input to zero
myconv output shape: torch.Size([1, 64, 112, 112])
Set mylayer input to zero
mylayer output shape: torch.Size([1, 64, 56, 56])


In [19]:
# backward hooks
def print_grad_shape_new(module, grad_in, grad_out):
    """Backward hook: print the shapes of the first input gradient and first output gradient.

    Args:
        module: the hooked module; its ``name`` attribute is used in the log line.
        grad_in: gradients w.r.t. the module inputs, passed by PyTorch as a tuple
            whose entries may be None (e.g. for an input that does not require grad).
        grad_out: gradients w.r.t. the module outputs, also passed as a tuple.

    Returns:
        None, so the gradients flowing through the module are left unchanged.
    """
    def _shape_or_none(grad):
        # a missing gradient has no shape; report None instead of raising
        return None if grad is None else grad.shape

    first_in = grad_in[0]
    # grad_out arrives as a tuple, so ``grad_out.shape`` raised AttributeError; a bare
    # tensor is still tolerated for callers that invoke this function directly
    first_out = grad_out[0] if isinstance(grad_out, (tuple, list)) else grad_out

    # NOTE(review): looks like a leftover debugging marker; kept so behaviour is unchanged
    module.val = 'val'
    print('%s grad_in shape: ' % module.name, end='')
    print(_shape_or_none(first_in), end=', ')
    print('grad_out shape: ', _shape_or_none(first_out))

# attach the backward hook to conv1; activate=False is not passed for this registration
hookmngr.register_backward_hook(print_grad_shape_new, resnet.conv1, hook_fn_name='print_grad_shape_new')

In [21]:
# show every registered handle, keyed as '<hook function name>[<module name>]'
hookmngr.name_to_hookhandle

{'HookManager._forward_hook_base[myconv]': <HookManager._forward_hook_base[myconv] <class 'tacklebox.hook_management.HookHandle'> registered to myconv (inactive)>,
 'HookManager._forward_pre_hook_base[myconv]': <HookManager._forward_pre_hook_base[myconv] <class 'tacklebox.hook_management.HookHandle'> registered to myconv (inactive)>,
 'print_shape[myconv]': <print_shape[myconv] <class 'tacklebox.hook_management.HookHandle'> registered to myconv (inactive)>,
 'HookManager._forward_hook_base[mylayer]': <HookManager._forward_hook_base[mylayer] <class 'tacklebox.hook_management.HookHandle'> registered to mylayer (inactive)>,
 'HookManager._forward_pre_hook_base[mylayer]': <HookManager._forward_pre_hook_base[mylayer] <class 'tacklebox.hook_management.HookHandle'> registered to mylayer (inactive)>,
 'print_shape[mylayer]': <print_shape[mylayer] <class 'tacklebox.hook_management.HookHandle'> registered to mylayer (inactive)>,
 'zero_input[myconv]': <zero_input[myconv] <class 'tacklebox.hook_m

In [20]:
# full forward and backward pass through the ResNet (no torch.no_grad() here, since
# gradients are needed for loss.backward())
xent = nn.CrossEntropyLoss()
out = resnet(x)
loss = xent(out, y)
loss.backward()

In [61]:
# run only the first convolution on the batch (this rebinds `out`)
out = resnet.conv1(x)

In [63]:
# reduce to a scalar so that backward() can be called without an explicit gradient argument
out.sum().backward()

In [56]:
# register the same backward hook function on layer1 as well
hookmngr.register_backward_hook(print_grad_shape_new, resnet.layer1, hook_fn_name='print_grad_shape_new')

In [57]:
# inspect the handles created for print_grad_shape_new
hookmngr.name_to_hookfn['print_grad_shape_new'].handles

[<print_grad_shape_new[myconv] <class 'tacklebox.hook_management.HookHandle'> registered to myconv (active)>,
 <print_grad_shape_new[mylayer] <class 'tacklebox.hook_management.HookHandle'> registered to mylayer (active)>]

# In summary,


### XLM - Multi30k machine translation

In [85]:
import torch
from transformers import XLMTokenizer, XLMWithLMHeadModel
import spacy
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
import torch.nn as nn

In [72]:
# load the spaCy pipelines whose tokenizers are used below
# NOTE(review): 'en' / 'de' are shortcut names; newer spaCy releases expect full package
# names such as 'en_core_web_sm' -- confirm against the installed spaCy version
english = spacy.load('en')
german = spacy.load('de')

def tokenize_en(text):
    """Return the list of token strings the English spaCy tokenizer produces for ``text``."""
    tokens = english.tokenizer(text)
    return [token.text for token in tokens]
def tokenize_de(text):
    """Return the list of token strings the German spaCy tokenizer produces for ``text``."""
    tokens = german.tokenizer(text)
    return [token.text for token in tokens]

# one Field per language: tokenize with the matching spaCy tokenizer, lowercase, use a vocabulary
en_text = Field(sequential=True, use_vocab=True, tokenize=tokenize_en, lower=True)
de_text = Field(sequential=True, use_vocab=True, tokenize=tokenize_de, lower=True)

# the '.en' files are paired with en_text and the '.de' files with de_text
train, val, test = Multi30k.splits(root='../data', exts=('.en', '.de'), fields=(en_text, de_text))

# vocabularies are built from the training split only, limited by max_size and min_freq
en_text.build_vocab(train, max_size=30000, min_freq=3)
de_text.build_vocab(train, max_size=30000, min_freq=3)
vocab_en = en_text.vocab
vocab_de = de_text.vocab
# index of the '<pad>' token in the German vocabulary
pad_idx = vocab_de.stoi['<pad>']

# one iterator per split, each yielding batches of 5 examples
train_ldr, val_ldr, test_ldr = BucketIterator.splits((train, val, test),
                                                    batch_size=5)

In [66]:
# load the pretrained English-German XLM and swap in layers sized for the Multi30k vocabularies
xlm = XLMWithLMHeadModel.from_pretrained('xlm-mlm-ende-1024')
# the input embedding consumes English indices, so its padding index is taken from the
# English vocabulary rather than from the German-derived pad_idx
xlm.transformer.embeddings = nn.Embedding(len(vocab_en), xlm.config.emb_dim,
                                          padding_idx=vocab_en.stoi['<pad>'])
# the output projection produces scores over the German vocabulary
xlm.pred_layer.proj = nn.Linear(xlm.config.emb_dim, len(vocab_de), bias=True)
xlm.cuda()

XLMWithLMHeadModel(
  (transformer): XLMModel(
    (position_embeddings): Embedding(512, 1024)
    (lang_embeddings): Embedding(2, 1024)
    (embeddings): Embedding(4554, 1024, padding_idx=1)
    (layer_norm_emb): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
    (attentions): ModuleList(
      (0): MultiHeadAttention(
        (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (1): MultiHeadAttention(
        (q_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (k_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (v_lin): Linear(in_features=1024, out_features=1024, bias=True)
        (out_lin): Linear(in_features=1024, out_features=1024, bias=True)
      )
      (2): MultiHeadAttention(
        (q_l

In [151]:
xent = nn.CrossEntropyLoss()

batch = next(iter(train_ldr))
# the Fields above use torchtext's default layout (seq_len, batch), whereas XLM expects
# (batch, seq_len); transpose both tensors so attention runs within each sentence
src = batch.src.t().contiguous().to(0)
trg = batch.trg.t().contiguous().to(0)
# index 0 holds the prediction scores whether the model returns a tuple or an output object
out = xlm(src)[0]
# source and target lengths can differ, so keep only the overlapping time steps
min_idx = min(out.shape[1], trg.shape[1])
out, trg = out[:, :min_idx], trg[:, :min_idx]

# the comparison already yields a boolean mask; padded target positions are excluded from the loss
mask = trg != pad_idx
loss = xent(out[mask], trg[mask])
loss.backward()