In [4]:
# default_exp data

In [5]:
#hide
%load_ext autoreload
%autoreload 2

In [6]:
#hide
from nbdev.showdoc import *

In [7]:
#export
from fastai.basics import *
from inspect import signature

# Learner for transformers

## Parameter groups

In [10]:
# export
def default_splitter(model):
    "Parameter groups for `model`: base model embeddings, base model encoder, then each direct child after the first"
    backbone = model.base_model
    stem = L([backbone.embeddings, backbone.encoder])
    # Every direct child except the first (presumably the base model split above - confirm)
    rest = L(model.children())[1:]
    return (stem + rest).map(params)

## TransLearner and utils

In [12]:
#export
def to_device(b, device=None):
    "Recursively put `b` on `device`. Handles `dict`s"
    # `defaults.use_cuda==False` forces CPU even when a device was requested;
    # otherwise a missing `device` falls back to `default_device()`.
    if defaults.use_cuda==False: device='cpu'
    elif device is None: device=default_device()
    def _inner(o):
        if isinstance(o,Tensor): return o.to(device, non_blocking=True)
        elif hasattr(o, "to_device"): return o.to_device(device)
        # Forward the resolved `device`: a bare `to_device(v)` would re-resolve it
        # and send nested values to the default device instead of the requested one.
        elif isinstance(o, dict): return {k:to_device(v, device) for k,v in o.items()}
        else: return o
    return apply(_inner, b)

In [13]:
#export
class TransCallback(Callback):
    "Handles the use case where the loss is returned by the HuggingFace model"
    def after_pred(self):
        "Take the loss from the model output when present, then replace `pred` with its logits"
        has_loss = 'loss' in self.pred
        if has_loss:
            # `loss_grad` keeps the model's own loss tensor; `loss` is a clone of it
            self.learn.loss_grad = self.pred.loss
            self.learn.loss = self.pred.loss.clone()
            # Targets are read from the 'labels' entry of the first input
            self.learn.yb = (self.xb[0]['labels'], )
        # Recompute the flag on every batch. It used to be set to False only, so once
        # any batch had returned a loss, later outputs without one skipped `loss_func`.
        self.learn.compute_loss = not has_loss
        self.learn.pred = self.pred.logits

In [14]:
#export
@delegates(Learner.__init__)
class TransLearner(Learner):
    "Learner for training transformers from HuggingFace"
    def __init__(self, dls, model, **kwargs):
        super().__init__(dls, model, **kwargs)
        # Parameter names of `model.forward`; batch keys outside this set are dropped before the call
        self.model_args = set(signature(model.forward).parameters.keys())
        self.add_cb(TransCallback())
        # `TransCallback` turns this off when the model output already contains a loss
        self.compute_loss = True

    def one_batch(self, i, b):
        "Move batch `b` to `self.dls.device` (when set), split it into `xb`/`yb` and run the 'batch' events"
        self.iter = i
        device = self.dls.device
        # Pass the DataLoaders' device explicitly; without it `to_device` picks the default device
        b_on_device = tuple(to_device(e, device) for e in b) if device is not None else b
        self._split(b_on_device)
        self._with_events(self._do_one_batch, 'batch', CancelBatchException)

    def _do_one_batch(self):
        "Forward the dict input through the model, compute the loss if needed, then backward and step"
        x = self.xb[0]
        # Iterate over a snapshot of the keys: deleting from a dict while iterating
        # its live key view raises `RuntimeError`.
        for k in list(x.keys()):
            if k not in self.model_args: del x[k]
        # Call the model with the dict that was just filtered
        self.pred = self.model(**x)
        self('after_pred')
        # Skipped when there are no targets or when `compute_loss` has been turned off
        if len(self.yb) and self.compute_loss:
            self.loss_grad = self.loss_func(self.pred, *self.yb)
            self.loss = self.loss_grad.clone()
        self('after_loss')
        if not self.training or not len(self.yb): return
        self('before_backward')
        self.loss_grad.backward()
        self._with_events(self.opt.step, 'step', CancelStepException)
        self.opt.zero_grad()

In [15]:
#hide
from nbdev.export import notebook2script; notebook2script()

Converted 00_data.ipynb.
Converted 01_learner.ipynb.
Converted index.ipynb.
