In [None]:
# default_exp learner

In [None]:
#default_cls_lvl 3

In [None]:
#hide
%load_ext autoreload
%autoreload 2

In [None]:
#hide
from nbdev.showdoc import *

In [None]:
#export
from fastai.basics import *
from fastai.text.all import TensorText
from inspect import signature
from fasthugs.data import TransformersTextBlock

from transformers import AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, BatchEncoding
from transformers.modeling_outputs import QuestionAnsweringModelOutput

# Learner for transformers

## Parameter groups

TODOs:
- [x] exclude modules w/o params
- [ ] add layerwise splitter for Transformers

In [None]:
#skip
#hide
# for n, m in model.base_model.named_children(): print(n)

In [None]:
# export
def default_splitter(model):
    "Build parameter groups for `model`: one per child of `model.base_model`, followed by one per remaining child of `model`"
    # Every direct child of the base model becomes a group.
    # NOTE(review): these are not filtered for emptiness, unlike the modules below — confirm that is intended.
    body_modules = L(model.base_model.children())
    # The first child of `model` is skipped (presumably it is the base model itself — verify);
    # of the others, only modules for which `params` returns something truthy are kept.
    head_modules = L(child for child in list(model.children())[1:] if params(child))
    return (body_modules + head_modules).map(params)

In [None]:
def layerwise_splitter(model):
    "Placeholder for a layer-wise splitter: always raises `NotImplementedError`, whatever `model` is"
    message = 'use default_splitter for now'
    raise NotImplementedError(message)

## TransLearner and utils

In [None]:
#export
@typedispatch
def show_results(x: TensorText, y, samples, outs, ctxs=None, max_n=10, trunc_at=150, **kwargs):
    "Display text `samples` with their `outs` as a dataframe, shortening texts via `.truncate(trunc_at)` unless `trunc_at` is None"
    if ctxs is None:
        ctxs = get_empty_df(min(len(samples), max_n))
    # A tuple in the first slot means the input itself holds several items: splice them into the row.
    has_tuple_input = isinstance(samples[0][0], tuple)
    if has_tuple_input:
        samples = L((*row[0], *row[1:]) for row in samples)
    if trunc_at is not None:
        if has_tuple_input:
            # After flattening, only the first two entries of each row are truncated.
            samples = L((row[0].truncate(trunc_at), row[1].truncate(trunc_at), *row[2:]) for row in samples)
        else:
            samples = L((row[0].truncate(trunc_at), *row[1:]) for row in samples)
    # Hand the prepared rows to the `object` implementation, then render what it returns.
    ctxs = show_results[object](x, y, samples, outs, ctxs=ctxs, max_n=max_n, **kwargs)
    display_df(pd.DataFrame(ctxs))
    return ctxs

In [None]:
#export
def to_device(b, device=None):
    "Move `b` to `device` by running a per-item mover through `apply`; `BatchEncoding`s are moved as a whole"
    # When CUDA is switched off in `defaults`, the requested `device` is ignored in favour of the CPU.
    if defaults.use_cuda==False:
        device = 'cpu'
    elif device is None:
        device = default_device()

    def _move(item):
        "Move one `Tensor` or `BatchEncoding` to `device`; return anything else unchanged"
        if isinstance(item, Tensor):
            return item.to(device, non_blocking=True)
        if isinstance(item, BatchEncoding):
            return item.to(device)
        return item

    return apply(_move, b)

In [None]:
#cuda
# Without an explicit device, a tensor inside a dict is expected to end up on `cuda:0`.
expected_device = torch.device('cuda:0')
batch = {'a': tensor([1,2,3])}
moved_batch = to_device(batch)
assert moved_batch['a'].device == expected_device

In [None]:
#export
class TransCallback(Callback):
    "Handles HuggingFace model inputs and outputs"
    def __init__(self, model):
        # Labels taken out of the input batch, held until `after_loss` hands them back as targets.
        self.labels = tuple()
        # Ordered map of `model.forward` argument name -> default value. `before_batch` unpacks the
        # batch positionally in this order, so `*args`/`**kwargs` catch-alls are excluded: their
        # default is `Parameter.empty` and passing it would add a bogus positional argument.
        self.model_args = {name: p.default
                           for name, p in signature(model.forward).parameters.items()
                           if p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)}

    def before_batch(self):
        "Replace the dict-like input batch with a tuple of values ordered like the arguments of `model.forward`"
        inputs = self.xb[0]
        # Always reassign, so labels from an earlier batch that never reached `after_loss`
        # (e.g. it was cancelled) cannot leak into a batch that carries no labels.
        self.labels = (inputs['labels'], ) if 'labels' in inputs.keys() else tuple()
        # Arguments missing from the batch fall back to the default declared by `forward`.
        self.learn.xb = tuple(inputs.get(name, default) for name, default in self.model_args.items())

    def after_pred(self):
        "Adopt the loss returned by the model, if any, and reduce the model output to its logits"
        output = self.pred
        if 'loss' in output:
            self.learn.loss_grad = output.loss
            self.learn.loss = output.loss.clone()
        # Question-answering outputs carry two sets of logits; everything else exposes `.logits`.
        if isinstance(output, QuestionAnsweringModelOutput):
            self.learn.pred = (output.start_logits, output.end_logits)
        else: self.learn.pred = output.logits

    def after_loss(self):
        "Expose the labels captured in `before_batch` as `learn.yb`, then clear them"
        if len(self.labels):
            self.learn.yb = self.labels
            self.labels = tuple()

In [None]:
#export
@delegates(Learner.__init__)
class TransLearner(Learner):
    "`Learner` for HuggingFace transformers: falls back to `default_splitter` and registers a `TransCallback`"
    def __init__(self, dls, model, **kwargs):
        # A missing `splitter` and an explicit `splitter=None` both select `default_splitter`.
        if kwargs.get('splitter') is None:
            kwargs['splitter'] = default_splitter
        super().__init__(dls, model, **kwargs)
        # The callback is built from `model` so it can read the signature of `model.forward`.
        self.add_cb(TransCallback(model))

In [None]:
#export
@patch
def _set_device(self:TransLearner, b):
    "Move batch `b` to the device that holds the model's parameters"
    first_param = next(self.model.parameters(), None)
    # A model without parameters has no device of its own: use the dataloaders' device instead.
    if first_param is None: return to_device(b, getattr(self.dls, 'device', default_device()))
    # Read the device from the parameter itself. `torch.cuda.current_device()` names the wrong GPU
    # whenever the model lives on a CUDA device other than the current one.
    return to_device(b, first_param.device)

### Using TransLearner for sequence classification

In [None]:
#slow
# Fetch the IMDB sample and read its review table.
path = untar_data(URLs.IMDB_SAMPLE)
texts = pd.read_csv(path/'texts.csv')

# Settings for this example run.
model_name = 'distilbert-base-uncased'
max_len = 128  # NOTE(review): not referenced anywhere else in this notebook — confirm whether it should be passed on
bs, val_bs = 8, 16
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Inputs come from the 'text' column through the pretrained tokenizer, targets from the 'label' column.
dblock = DataBlock(
    blocks=[TransformersTextBlock(tokenizer=tokenizer), CategoryBlock()],
    get_x=ItemGetter('text'),
    get_y=ItemGetter('label'),
    splitter=ColSplitter(),
)
dls = dblock.dataloaders(texts, bs=bs, val_bs=val_bs)

In [None]:
#slow
# Load the pretrained checkpoint for sequence classification and train it for two epochs.
model = AutoModelForSequenceClassification.from_pretrained(model_name)
learn = TransLearner(dls, model, metrics=accuracy)
learn.fit(n_epoch=2, lr=2e-5)

Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertForSequenceClassification: ['vocab_layer_norm.weight', 'vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_projector.bias', 'vocab_projector.weight', 'vocab_transform.weight']
- This IS expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing DistilBertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.weight', 'pre_classifier.bias', 'classifier

epoch,train_loss,valid_loss,accuracy,time
0,0.52515,0.571722,0.73,00:22
1,0.311822,0.347995,0.88,00:22


## Fin

In [None]:
#hide
from nbdev.export import notebook2script
notebook2script()

Converted 00_data.ipynb.
Converted 01_learner.ipynb.
Converted 10_examples.classification-imdb.ipynb.
Converted 11_examples.mlm-imdb.ipynb.
Converted 12_examples.glue-benchmark.ipynb.
Converted 12a_examples.glue-benchmark-sweeps.ipynb.
Converted index.ipynb.
