In [None]:
# default_exp core

# Core

> Basic helper functions

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#hide
%load_ext autoreload
%autoreload 2

In [None]:
#export
import torch
from torch import nn, einsum
import torch.nn.functional as F
import torch.autograd.profiler as profiler
from fastai.basics import *
from fastai.text.all import *

from functools import partial, reduce, wraps
from inspect import isfunction
from operator import mul
from copy import deepcopy

from torch import Tensor
from typing import Tuple

from einops import rearrange, repeat

## Helper functions

### General purpose utils

In [None]:
#export
def exists(val):
    "Return True when `val` holds a value, i.e. when it is not None (falsy values like 0 or '' count as existing)."
    if val is None:
        return False
    return True

def default(val, d):
    """
    Return `val` unless it is None, in which case fall back to `d`.

    If the fallback `d` is a plain Python function (as judged by
    `inspect.isfunction`), it is called with no arguments and its result is
    returned, so expensive defaults are only built when actually needed.
    """
    if val is not None:
        return val
    if isfunction(d):
        return d()
    return d

def expand_dim1(x):
    "Add a leading axis of size 1 to a 1-D `x`; inputs of any other rank are returned unchanged (same object)."
    return x[None] if len(x.shape) == 1 else x

def max_neg_value(tensor):
    "Return the most negative finite value representable by `tensor`'s floating-point dtype."
    dtype_info = torch.finfo(tensor.dtype)
    return -dtype_info.max

### Generative utils

In [None]:
#export
# generative helpers
# credit https://github.com/huggingface/transformers/blob/a0c62d249303a68f5336e3f9a96ecf9241d7abbe/src/transformers/generation_logits_process.py
def top_p_filter(logits, top_p=0.9, min_tokens_to_keep=1):
    """
    Nucleus (top-p) filtering along the last dimension of `logits`.

    Entries are ranked by score; every entry after the first one at which the
    cumulative softmax probability exceeds `top_p` is set to `-inf`. The
    top-scoring entry is always kept, and at least `min_tokens_to_keep` entries
    are kept per row (default 1 reproduces the previous behaviour).

    `logits` may have any number of leading dimensions, e.g. (vocab,),
    (batch, vocab) or (batch, seq, vocab).

    NOTE: `logits` is modified in place and is also returned.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True, dim=-1)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

    sorted_indices_to_remove = cum_probs > top_p
    # Shift the mask right by one so the entry that crosses the threshold is kept too
    sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
    # Always keep the highest-scoring entries (at least one)
    sorted_indices_to_remove[..., :max(min_tokens_to_keep, 1)] = False
    # Map the mask from sorted order back to the original positions. This must
    # scatter along the same (last) dim that was sorted; a hard-coded dim=1 is
    # only correct for 2-D logits.
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    logits[indices_to_remove] = float('-inf')
    return logits

def top_k_filter(logits, top_k=20):
    """
    Keep the `top_k` largest logits along the last dimension and set the rest
    to `-inf`. Entries tied with the k-th largest value are kept as well.

    `top_k` is clamped to the size of the last dimension, so a vocabulary
    smaller than `top_k` leaves `logits` unchanged instead of raising.

    NOTE: `logits` is modified in place and is also returned.
    """
    top_k = min(top_k, logits.size(-1))
    # k-th largest value per row, kept as a trailing axis for broadcasting
    kth_largest = torch.topk(logits, top_k)[0][..., -1, None]
    indices_to_remove = logits < kth_largest
    logits[indices_to_remove] = float('-inf')
    return logits

# Name -> sampling helper. Note the entries return different things: the
# 'top_k'/'top_p' filters return (in-place) filtered logits, while 'greedy'
# returns the argmax indices over the last dim.
_sampler = {
    'top_k': top_k_filter,
    'top_p': top_p_filter,
    'greedy': lambda x: x.argmax(-1),
}
# Backward-compatible alias for the originally misspelt key
_sampler['gready'] = _sampler['greedy']

### LSH-specific helpers

From [lucidrains/reformer-pytorch](https://github.com/lucidrains/reformer-pytorch/).

In [None]:
#exports
def cache_method_decorator(cache_attr, cache_namespace, reexecute = False):
    """
    Decorator factory that caches a method's return value in a dict kept on the
    instance attribute named `cache_attr`, under the key
    f'{cache_namespace}:{key_namespace}' (`key_namespace` defaults to '').

    The decorated method accepts three extra keyword arguments:
      key_namespace : optional key suffix, so one method can cache several values
      fetch         : if True, return the cached value instead of calling the
                      method (KeyError if nothing is cached under the key); when
                      `reexecute` is True the method is still called and its
                      result discarded
      set_cache     : when not fetching, store the fresh result (default True)

    Stores replace the cache dict with a new dict rather than mutating it. The
    instance must already have `cache_attr` set, otherwise getattr raises
    AttributeError.
    """
    def decorate(method):
        @wraps(method)
        def cached_method(self, *args, key_namespace=None, fetch=False, set_cache=True, **kwargs):
            suffix = str(default(key_namespace, ''))
            store = getattr(self, cache_attr)
            cache_key = f'{cache_namespace}:{suffix}'

            if not fetch:
                result = method(self, *args, **kwargs)
                if set_cache:
                    # build a new dict instead of mutating the existing one
                    setattr(self, cache_attr, {**store, cache_key: result})
                return result

            result = store[cache_key]
            if reexecute:
                # re-run only for its side effects; the cached value is returned
                method(self, *args, **kwargs)
            return result
        return cached_method
    return decorate

In [None]:
#exports
def look_one_back(x):
    """
    Concatenate each chunk along dim 2 with the chunk preceding it along dim 1.

    The "previous" chunk of position 0 wraps around to the last position, so
    dim 2 doubles in size: out[:, i] = cat(x[:, i], x[:, i-1]).
    """
    previous = torch.cat((x[:, -1:], x[:, :-1]), dim=1)
    return torch.cat((x, previous), dim=2)

In [None]:
#exports
def chunked_sum(tensor, chunks=1):
    """
    Sum `tensor` over its last dimension, processing the flattened leading
    rows in `chunks` pieces to bound peak memory of each reduction.

    Returns a tensor with the shape of `tensor` minus its last dimension.
    """
    *lead_shape, feat = tensor.shape
    rows = tensor.reshape(-1, feat)
    partial_sums = []
    for piece in rows.chunk(chunks, dim=0):
        partial_sums.append(piece.sum(dim=-1))
    return torch.cat(partial_sums, dim=0).reshape(lead_shape)

In [None]:
#exports
def sort_key_val(t1, t2, dim=-1):
    """
    Sort `t1` along `dim` and reorder `t2` (broadcast to `t1`'s shape) with the
    same permutation. Returns (sorted t1, correspondingly reordered t2).
    """
    sorted_keys, order = torch.sort(t1, dim=dim)
    return sorted_keys, t2.expand_as(t1).gather(dim, order)

In [None]:
#exports
def batched_index_select(values, indices):
    """
    Per batch element, pick rows of `values` along dim 1 at the positions given
    by the 2-D `indices`; the last dimension of `values` is carried along whole.
    """
    feat = values.shape[-1]
    gather_idx = indices[:, :, None].expand(-1, -1, feat)
    return torch.gather(values, 1, gather_idx)

## Profiling functions

Utility functions to assess model performance. They can be tried out on the test model `mod` with input `x`, defined in the next cell.

In [None]:
# Small fastai text classifier plus a random (3, 72) batch of token ids in [0, 100)
# used to try out the helpers in this section
mod = get_text_classifier(AWD_LSTM, vocab_sz=10_000, n_class=10)
x = torch.randint(low=0, high=100, size=(3, 72))

In [None]:
#export
def do_cuda_timing(f, inp, context=None, n_loops=100):
    '''
        Get timings of cuda modules. Note `self_cpu_time_total` is returned, but
        from experiments this appears to be similar/same to the total CUDA time

        f :  callable to profile, typically an nn.Module; if it has a `.cuda()`
             method it is moved to the GPU first (in place for an nn.Module)
        inp : required input to f (moved to the GPU)
        context : optional additional input into f, used for Decoder-style modules
        n_loops : number of profiled calls of f, must be >= 1 (ValueError otherwise)

        Returns (and prints) the profiled `self_cpu_time_total` divided by
        `n_loops`, converted to milliseconds and rounded to 3 decimals. Calls
        run under `torch.no_grad()` and each is followed by
        `torch.cuda.synchronize()`.
    '''
    if n_loops < 1:
        raise ValueError(f'n_loops must be >= 1, got {n_loops}')
    # Plain Python functions have no `.cuda()`; only move objects that support it
    if hasattr(f, 'cuda'): f.cuda()
    inp = inp.cuda()
    if context is not None: context = context.cuda()
    with profiler.profile(record_shapes=False, use_cuda=True) as prof:
        with profiler.record_function("model_inference"):
            with torch.no_grad():
                for _ in range(n_loops):
                    if context is None: f(inp)
                    else: f(inp, context)
                    # block until queued GPU work finishes, inside the profiled region
                    torch.cuda.synchronize()

    # profiler totals are in microseconds: /1000 -> ms, then average per call
    res = round((prof.key_averages().self_cpu_time_total / 1000) / n_loops, 3)
    print(f'{res}ms')
    return res

In [None]:
#export
def model_performance(n_loops=5, model='arto', dls=None, n_epochs=1, lr=5e-4):
    """
        DEMO CODE ONLY!
        Run training loop to measure timings. Note that the models internally
        should be changed depending on the model you would like to use.
        You should also adjust the metrics you are monitoring

        n_loops : number of independent training runs to average over (>= 1)
        model : a zero-argument callable returning a fresh model (called once per
            run, so every run starts from new weights), or a model name. Names are
            placeholders: add the matching init code in the loop below, otherwise
            NotImplementedError is raised.
        dls : fastai DataLoaders passed to `Learner`
        n_epochs, lr : passed to `fit_one_cycle` (with `wd=0.05`)

        Prints mean/std of accuracy and perplexity and returns
        (last Learner, accuracy list, perplexity list).
    """
    if n_loops < 1:
        # avoids an unbound `learn` and a division by zero below
        raise ValueError(f'n_loops must be >= 1, got {n_loops}')
    acc_ls, ppl_ls = [], []
    for _ in range(n_loops):
        # ADD YOUR MODEL(S) INIT HERE, e.g.
        #   if model == 'arto': m = artoTransformerLM(vocab_sz, 512)
        #   elif model == 'pt': m = ptTransformerLM(vocab_sz, 512)
        if callable(model):
            m = model()
        else:
            # without init code `m` would be undefined (an opaque NameError)
            raise NotImplementedError(
                f"No init code for model {model!r}: pass a callable returning a "
                "model, or add the init code in model_performance")

        learn = Learner(dls, m,
                        loss_func=CrossEntropyLossFlat(),
                        metrics=[accuracy, Perplexity()]).to_native_fp16()

        learn.fit_one_cycle(n_epochs, lr, wd=0.05)

        # NOTE(review): indices assume final_record is
        # [train_loss, valid_loss, accuracy, perplexity] — confirm with fastai version
        acc_ls.append(learn.recorder.final_record[2])
        ppl_ls.append(learn.recorder.final_record[3])
    print(f'Avg Accuracy: {round(sum(acc_ls)/len(acc_ls),3)}, std: {np.std(acc_ls)}')
    print(f'Avg Perplexity: {round(sum(ppl_ls)/len(ppl_ls),3)}, std: {np.std(ppl_ls)}')
    print()
    return learn, acc_ls, ppl_ls

In [None]:
#export
def total_params(m):
    """
    Return (number of parameters of module `m`, whether it is trainable).

    Trainability is read from the first parameter's `requires_grad`; a module
    without parameters reports (0, False).
    - Taken from fastai.callback.hook
    """
    param_list = list(m.parameters())
    n_params = sum(p.numel() for p in param_list)
    trainable = param_list[0].requires_grad if param_list else False
    return n_params, trainable

Number of params for our test model:

In [None]:
# (parameter count, trainable flag) of the demo classifier
total_params(m=mod)

(24336280, True)

In [None]:
#hide
# Convert the project's notebooks into the library's Python modules
from nbdev.export import notebook2script
notebook2script()

Converted 00_core.ipynb.
Converted 01_layers.ipynb.
Converted 02_attention.ipynb.
Converted 03_transformer.ipynb.
Converted 04_reformer.ipynb.
Converted index.ipynb.
