## Hook Management - part 2

### XLM - Multi30k machine translation

In [1]:
import torch
import transformers
from transformers import XLMTokenizer, XLMWithLMHeadModel
import spacy
import torchtext
from torchtext.datasets import Multi30k
from torchtext.data import Field, BucketIterator
import torch.nn as nn
from tacklebox.hook_management import HookManager

In [2]:
!python --version
# report the library versions so the recorded outputs below can be tied to an environment
print('torch %s' % torch.__version__)
print('spacy %s' % spacy.__version__)
print('transformers %s' % transformers.__version__)
print('torchtext %s' % torchtext.__version__)

Python 3.7.3
torch 1.2.0
spacy 2.1.9
transformers 2.4.1
torchtext 0.4.0


In [3]:
# spaCy pipelines for both languages; the tokenize_* helpers below use only their tokenizers
english = spacy.load('en')
german = spacy.load('de')

def tokenize_en(text):
    """Split English `text` into a list of token strings with the spaCy tokenizer."""
    tokens = []
    for tok in english.tokenizer(text):
        tokens.append(tok.text)
    return tokens
def tokenize_de(text):
    """Split German `text` into a list of token strings with the spaCy tokenizer."""
    tokens = []
    for tok in german.tokenizer(text):
        tokens.append(tok.text)
    return tokens

# torchtext fields: tokenize with spaCy, lowercase, and numericalize through a vocabulary
# NOTE(review): batch_first is left at its default, so batches appear to be laid out as
# (seq_len, batch) - confirm this matches the layout the model expects
en_text = Field(sequential=True, use_vocab=True, tokenize=tokenize_en, lower=True)
de_text = Field(sequential=True, use_vocab=True, tokenize=tokenize_de, lower=True)

# English (.en) is the source side and German (.de) the target side; data is rooted at ../data
train, val, test = Multi30k.splits(root='../data', exts=('.en', '.de'), fields=(en_text, de_text))

# vocabularies are built from the training split only (max_size=30000, min_freq=3)
en_text.build_vocab(train, max_size=30000, min_freq=3)
de_text.build_vocab(train, max_size=30000, min_freq=3)
vocab_en = en_text.vocab
vocab_de = de_text.vocab
# index of the padding token in the German (target) vocabulary
pad_idx = vocab_de.stoi['<pad>']

# one iterator per split, 5 examples per batch
train_ldr, val_ldr, test_ldr = BucketIterator.splits((train, val, test),
                                                    batch_size=5)

In [4]:
# load the pretrained English-German XLM checkpoint
xlm = XLMWithLMHeadModel.from_pretrained('xlm-mlm-ende-1024')
# replace the input embedding with a fresh table sized to the English (source) vocabulary;
# the padding index is looked up in that same vocabulary instead of being borrowed from the
# German one, so the two stay consistent even if the vocabularies' special tokens ever differ
xlm.transformer.embeddings = nn.Embedding(len(vocab_en), xlm.config.emb_dim,
                                          padding_idx=vocab_en.stoi[en_text.pad_token])
# replace the output projection with a fresh layer sized to the German (target) vocabulary
xlm.pred_layer.proj = nn.Linear(xlm.config.emb_dim, len(vocab_de), bias=True)
# move the whole model to the GPU; the return value is bound to `_` to suppress the cell's repr
_ = xlm.cuda()

In [5]:
# cross-entropy criterion used inside mt_loss
xent = nn.CrossEntropyLoss()

# take a single batch from the training iterator and move both sides to device 0
batch = next(iter(train_ldr))
src, trg = batch.src.to(0), batch.trg.to(0)

def mt_loss(out, target):
    """Cross-entropy between predictions and targets, ignoring padded target positions.

    Both arguments are first cut along dim 0 to the shorter of the two lengths, then only
    the positions where `target` differs from `pad_idx` contribute to the loss.
    """
    steps = min(out.shape[0], target.shape[0])
    logits = out[:steps]
    labels = target[:steps]
    # boolean mask selecting the non-padding target positions
    keep = (labels != pad_idx).type(torch.bool)
    return xent(logits[keep], labels[keep])

# the model call returns a 1-tuple here; unpack it and backpropagate the masked loss once
out, = xlm(src)
mt_loss(out, trg).backward()

In [6]:
# initialize the hook manager (HookManager is imported in the first cell)
hookmngr = HookManager()

In [7]:
# define named modules, including the embedding layer (id=embedding) and final attention (id=final_attn)
# the dict is later unpacked as keyword arguments (**named_modules) when registering hooks
named_modules = {
    'embedding': xlm.transformer.embeddings,
    'final_attn': xlm.transformer.attentions[-1]
}

### Filtering hooks

In [8]:
# forward pre-hook function signature: (module, inputs)
def zero_input(module, inputs):
    """Forward pre-hook that replaces every floating-point tensor input with zeros.

    Args:
        module: the hooked module; only its `name` attribute is read, for the log line.
        inputs: tuple of positional inputs passed to the module's forward call.

    Returns:
        A tuple of the same length as `inputs` in which each floating-point tensor has been
        replaced by a zero tensor of identical shape, dtype and device. Every other entry
        (integer/bool tensors, None, non-tensors) is passed through unchanged.
    """
    replaced = []
    zeroed = 0
    for item in inputs:
        # isinstance also accepts Tensor subclasses; is_floating_point covers half/double too
        if isinstance(item, torch.Tensor) and item.is_floating_point():
            # zeros_like instead of `item - item`: subtraction turns inf/NaN entries into NaN.
            # Note the replacement is a fresh tensor, not connected to `item` in autograd.
            replaced.append(torch.zeros_like(item))
            zeroed += 1
        else:
            replaced.append(item)
    print('Set %d/%d of %s inputs to zero' % (zeroed, len(inputs), module.name))
    return tuple(replaced)

def print_mean(module, inputs, outputs):
    """Forward hook that prints the mean of the first input and of the first output.

    Args:
        module: the hooked module; only its `name` attribute is read, for the log line.
        inputs: tuple of positional inputs passed to the module's forward call.
        outputs: the module's return value - either a single tensor or a tuple/list whose
            first element is a tensor.
    """
    def _mean(tensor):
        # sum / numel (rather than .mean()) so integer tensors such as token ids also work;
        # an empty tensor reports NaN instead of raising ZeroDivisionError
        count = tensor.numel()
        return tensor.sum().item() / count if count else float('nan')

    # a module returning a bare tensor must not be indexed: outputs[0] would silently
    # restrict the mean to the first slice along dim 0
    first_out = outputs if isinstance(outputs, torch.Tensor) else outputs[0]
    print('%s input mean = %.2f, output mean = %.2f' % (module.name,
                                                        _mean(inputs[0]),
                                                        _mean(first_out)))

# register both hooks on all named modules, deferring activation
# (activate=False: the hooks are registered now and switched on later through a context)
hookmngr.register_forward_pre_hook(zero_input, hook_fn_name='zero_input', activate=False, **named_modules)
hookmngr.register_forward_hook(print_mean, hook_fn_name='print_mean', activate=False, **named_modules)

In [9]:
# activate all hooks for the duration of the block; the hook context is combined with
# torch.no_grad() via `+` so that no autograd graph is built for this forward pass
with hookmngr.hook_all_context() + torch.no_grad():
    xlm(src)

Set 0/1 of embedding inputs to zero
embedding input mean = 101.78, output mean = 0.00
Set 1/2 of final_attn inputs to zero
final_attn input mean = 0.00, output mean = -0.00


In [10]:
# let's only activate our forward hooks (the zero_input pre-hooks stay inactive)
with hookmngr.hook_all_context(category='forward_hook') + torch.no_grad():
    xlm(src)

embedding input mean = 101.78, output mean = 0.00
final_attn input mean = -0.01, output mean = 0.02


In [11]:
# now let's select hooks by hook function instead: only those created from print_mean
with hookmngr.hook_all_context(hook_types=[print_mean]) + torch.no_grad():
    xlm(src)

embedding input mean = 101.78, output mean = 0.00
final_attn input mean = -0.01, output mean = 0.02


### Backward hooks

In [12]:
# backward hooks
def print_grad_shape(module, grad_in, grad_out):
    """Backward hook that prints the shape of the first input and output gradients.

    Writes one line starting with the module name; a gradient whose first entry is None is
    left out of that line.
    """
    print('%s ' % module.name, end='')
    # (label, gradient tuple, separator printed after the shape)
    for label, grads, tail in (('grad_in shape: ', grad_in, ', '),
                               ('grad_out shape: ', grad_out, '')):
        first = grads[0]
        if first is None:
            continue
        print(label, first.shape, end=tail)
    print('')

# register backward hook print_grad_shape on the final attention layer (inactive until a context enables it)
hookmngr.register_backward_hook(print_grad_shape, named_modules['final_attn'],
                               activate=False, hook_fn_name='print_grad_shape')

# Note: can pass modules as additional args if already registered

In [13]:
# activate all backward hooks; no torch.no_grad() here because the backward pass needs the graph
with hookmngr.hook_all_context(category='backward_hook'):
    out, = xlm(src)
    mt_loss(out, trg).backward()

final_attn grad_in shape:  torch.Size([19, 5, 1024]), grad_out shape:  torch.Size([19, 5, 1024])


### Using inputs, outputs and gradients all at once

In [14]:
# using intermediate outputs with gradient
def print_outputs_with_grad(module, grad_in, grad_out, inputs, outputs):
    """Backward hook that also receives the forward inputs/outputs of the module.

    Prints two lines: the dtype of each forward input next to the type of its gradient,
    then the same pairing for the forward outputs.
    """
    sections = (
        ('%s input-gradient pairs: ', inputs, grad_in),
        ('%s output-gradient pairs: ', outputs, grad_out),
    )
    for header, tensors, grads in sections:
        print(header % module.name, end='')
        # zip stops at the shorter of the two sequences
        for tensor, grad in zip(tensors, grads):
            print(tensor.dtype, type(grad), end=', ')
        print('')

# use retain_forward_cache argument to provide hook function access to forward pass data during backward pass
# (the hook function then takes inputs and outputs in addition to grad_in and grad_out)
hookmngr.register_backward_hook(print_outputs_with_grad, named_modules['final_attn'],
                               activate=False, hook_fn_name='print_outputs_with_grad',
                               retain_forward_cache=True)

In [15]:
# activate only print_outputs_with_grad, then run a forward and backward pass to trigger it
with hookmngr.hook_all_context(hook_types=[print_outputs_with_grad]):
    out, = xlm(src)
    mt_loss(out, trg).backward()

final_attn input-gradient pairs: torch.float32 <class 'torch.Tensor'>, torch.bool <class 'NoneType'>, 
final_attn output-gradient pairs: torch.float32 <class 'torch.Tensor'>, 


### Hook removal

In [16]:
# list the names of all registered hook handles, one per line
print('\n'.join(hookmngr.name_to_hookhandle.keys()))

zero_input[embedding]
zero_input[final_attn]
print_mean[embedding]
print_mean[final_attn]
print_grad_shape[final_attn]
print_outputs_with_grad[final_attn]


In [17]:
# remove print_mean from final_attn; a single handle is addressed as 'hook_fn_name[module_name]'
hookmngr.remove_hook_by_name('print_mean[final_attn]')

In [18]:
# list the remaining hook handles
print('\n'.join(hookmngr.name_to_hookhandle.keys()))

zero_input[embedding]
zero_input[final_attn]
print_mean[embedding]
print_grad_shape[final_attn]
print_outputs_with_grad[final_attn]


In [19]:
# remove zero_input from all modules (the hook function object itself is passed, not its name)
hookmngr.remove_hook_function(zero_input)

In [20]:
# list the remaining hook handles
print('\n'.join(hookmngr.name_to_hookhandle.keys()))

print_mean[embedding]
print_grad_shape[final_attn]
print_outputs_with_grad[final_attn]


In [21]:
# remove all remaining hooks from final_attn, addressing the module by its registered name
hookmngr.remove_module_by_name('final_attn')

In [22]:
# list the remaining hook handles
print('\n'.join(hookmngr.name_to_hookhandle.keys()))

print_mean[embedding]
