In [None]:
!nvidia-smi

Mon Apr 19 17:59:32 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 5000     Off  | 00000000:1C:00.0 Off |                  Off |
| 35%   29C    P8     5W / 230W |      0MiB / 16125MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import sys
# Install dependencies only when running on Google Colab; other environments are expected to
# have them installed already.
if 'google.colab' in sys.modules:
    !pip install -Uqq fastcore onnx onnxruntime sentencepiece seqeval rouge-score
    # --no-deps: install fastai and blurr without letting pip resolve their dependency trees
    # (presumably to avoid swapping out the preinstalled torch build — TODO confirm).
    !pip install -Uqq --no-deps fastai ohmeow-blurr
    !pip install -Uqq transformers datasets wandb 

In [None]:
from fastai.text.all import *
from fastai.callback.wandb import *

In [None]:
from transformers import *
from datasets import load_dataset, concatenate_datasets

from blurr.data.all import *
from blurr.modeling.all import *

  '"sox" backend is being deprecated. '
[nltk_data] Downloading package wordnet to /home/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


## Data preprocessing

In [None]:
ds_name = 'snli'

In [None]:
train_ds = load_dataset(ds_name, split='train')
valid_ds = load_dataset(ds_name, split='validation')

Reusing dataset snli (/home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Reusing dataset snli (/home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)


In [None]:
len(train_ds), len(valid_ds)

(550152, 10000)

In [None]:
train_ds.column_names

['premise', 'hypothesis', 'label']

In [None]:
train_ds[2]

{'premise': 'A person on a horse jumps over a broken down airplane.',
 'hypothesis': 'A person is outdoors, on a horse.',
 'label': 0}

In [None]:
from collections import Counter

In [None]:
Counter(train_ds['label'])

Counter({1: 182764, 2: 183187, 0: 183416, -1: 785})

In [None]:
# The Counter above shows some pairs carry label -1; keep only the three real classes.
def _has_gold_label(sample):
    return sample['label'] in (0, 1, 2)

train_ds, valid_ds = (split.filter(_has_gold_label) for split in (train_ds, valid_ds))

Loading cached processed dataset at /home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-18cfe39918caca0a.arrow
Loading cached processed dataset at /home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-45271a826cbdfaba.arrow


## Setup

In [None]:
model_name = 'distilbert-base-uncased'
# data
max_len = 512
bs = 32
val_bs = bs*2
# training
lr = 2e-5

## Tracking

In [None]:
import wandb

WANDB_NAME = f'{ds_name}-{model_name}-alum'
GROUP = f'{ds_name}-{model_name}-alum-{lr:.0e}'
# This run uses ALUM adversarial finetuning (see TAGS and the ALUM section), so the run notes
# must not describe it as plain finetuning.
NOTES = f'ALUM finetuning {model_name} with RAdam lr={lr:.0e}'
# Record the setup hyperparameters so runs in the W&B project can be compared and filtered.
CONFIG = dict(model_name=model_name, dataset=ds_name, max_len=max_len,
              bs=bs, val_bs=val_bs, lr=lr)
TAGS = [model_name, ds_name, 'radam', 'alum']

In [None]:
wandb.init(reinit=True, project="vat", entity="fastai_community",
           name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG);

[34m[1mwandb[0m: Currently logged in as: [33mfastai_community[0m (use `wandb login --relogin` to force relogin)


## Training

In [None]:
def _to_device(e, device):
    if hasattr(e, 'to'): return e.to(device)
    elif isinstance(e, dict):
        for _, v in e.items():
            if hasattr(v, 'to'): v.to(device)
        return {k:(v.to(device) if hasattr(v, 'to') else v) for k, v in e.items()}

In [None]:
@patch
def one_batch(self:Learner, i, b):
    "Process batch `b` at iteration `i`, moving each of its elements to the dataloaders' device with `_to_device` (presumably because the default transfer does not move the tokenizer outputs — TODO confirm)."
    self.iter = i
    device = self.dls.device
    if device is None:
        batch = b
    else:
        batch = tuple(_to_device(elem, device) for elem in b)
    self._split(batch)
    self._with_events(self._do_one_batch, 'batch', CancelBatchException)

In [None]:
# Build the HF architecture name, config, tokenizer and model for a 3-way NLI classifier.
# The tokenizer limit comes from the `max_len` setup constant instead of a second hard-coded 512;
# `model_max_length` is the current tokenizer kwarg (`max_len` is its deprecated alias).
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    model_name,
    model_cls=AutoModelForSequenceClassification,
    tokenizer_cls=AutoTokenizer,
    config_kwargs={'num_labels': 3},
    tokenizer_kwargs={'model_max_length': max_len})

In [None]:
def get_x(sample):
    "Return the `(premise, hypothesis)` text pair of `sample`, used as the model input."
    premise, hypothesis = sample['premise'], sample['hypothesis']
    return (premise, hypothesis)

In [None]:
# Train rows come first and validation rows follow, so each split is a contiguous index range.
n_train, n_valid = len(train_ds), len(valid_ds)
ds = concatenate_datasets([train_ds, valid_ds])
train_idx = list(range(n_train))
valid_idx = list(range(n_train, n_train + n_valid))

In [None]:
# use number of chars as proxy to number of tokens for simplicity
lens = ds.map(lambda s: {'len': len(s['premise'])+len(s['hypothesis'])}, remove_columns=ds.column_names, num_proc=4)

Loading cached processed dataset at /home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-5ca5f39f0f347987.arrow
Loading cached processed dataset at /home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-b303872196dbefa3.arrow
Loading cached processed dataset at /home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-22c40595756fbadc.arrow
Loading cached processed dataset at /home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-f8c72ce8c6f484ef.arrow


In [None]:
train_lens = lens.select(train_idx)['len']
valid_lens = lens.select(valid_idx)['len']

In [None]:
# SNLI label ids: 0 = entailment, 1 = neutral, 2 = contradiction.
# NOTE(review): fastai's CategoryMap appears to build its vocab via `L(vocab)`, which keeps only
# a dict's keys, so the previous `{0:'entailment', ...}` form silently became [0, 1, 2].
# Pass that list explicitly so the code says what it does — verify against the installed fastai.
blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model),
          CategoryBlock(vocab=[0, 1, 2]))
dblock = DataBlock(blocks=blocks,
                   get_x=get_x,
                   get_y=ItemGetter('label'),
                   # Reuse the validation positions computed when `ds` was concatenated.
                   splitter=IndexSplitter(valid_idx))
# dblock.summary(train_ds)

In [None]:
%%time
dls = dblock.dataloaders(ds, bs=bs, val_bs=val_bs, dl_kwargs=[{'res':train_lens}, {'val_res':valid_lens}], num_workers=4)

CPU times: user 1min 12s, sys: 1.68 s, total: 1min 14s
Wall time: 1min 14s


In [None]:
# b = dls.one_batch()

In [None]:
# Wrap the HF model for fastai; HF_BaseModelCallback and hf_splitter are blurr's companions
# for this wrapper (presumably handling the model call and parameter groups — TODO confirm).
model = HF_BaseModelWrapper(hf_model)
learn = Learner(dls,
                model,
                opt_func=RAdam,
                metrics=[accuracy],
                cbs=[HF_BaseModelCallback],
                splitter=hf_splitter).to_fp16()  # train with fastai mixed precision

# learn.blurr_summary()

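### ALUM finetuning

Fine-tune with ALUM adversarial training. `ALUMCallback` is attached to the model's embedding layer (`hf_model.base_model.embeddings`). With `start_epoch=2` and `alpha=0.5`, the first two epochs are ordinary finetuning, and adversarial training starts at epoch 2, as the training log below confirms.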

In [None]:
# !pip install git+git://github.com/aikindergarten/vat.git --no-deps -q

In [None]:
from vat.core import ALUMCallback

In [None]:
learn.add_cb(ALUMCallback(learn.model.hf_model.base_model.embeddings, start_epoch=2, alpha=0.5));

In [None]:
learn.fit_one_cycle(5, lr, cbs=WandbCallback(log_preds=False, log_model=False))

Could not gather input dimensions


epoch,train_loss,valid_loss,accuracy,time
0,0.426816,0.364193,0.861309,23:10
1,0.348046,0.301381,0.888437,23:13
2,0.388748,0.355832,0.886507,43:17
3,0.351712,0.333714,0.890063,43:38
4,0.3385,0.328792,0.8924,43:44


Starting virtual adversarial training at epoch 2


In [None]:
learn.validate()

(#2) [0.32879212498664856,0.8923999071121216]

In [None]:
# Load the test split of the same dataset used for training, so changing `ds_name` applies here too.
test_ds = load_dataset(ds_name, split='test')
test_ds[0]

Reusing dataset snli (/home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)


{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.',
 'hypothesis': 'The church has cracks in the ceiling.',
 'label': 1}

In [None]:
# Keep only pairs with one of the three gold labels, as was done for train/valid.
test_ds = test_ds.filter(lambda s: s['label'] in (0, 1, 2))
test_dl = dls.test_dl(test_ds, with_labels=True)
# Evaluate on the held-out test split; the result is [valid_loss, accuracy].
learn.validate(dl=test_dl)

Loading cached processed dataset at /home/.cache/huggingface/datasets/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-ea8100bd89f36a77.arrow


(#2) [0.33451008796691895,0.8907777070999146]

In [None]:
wandb.finish()

VBox(children=(Label(value=' 0.07MB of 0.07MB uploaded (0.00MB deduped)\r'), FloatProgress(value=1.0, max=1.0)…

0,1
epoch,5.0
train_loss,0.3385
raw_loss,0.33039
wd_0,0.0
sqr_mom_0,0.99
lr_0,0.0
mom_0,0.95
eps_0,1e-05
beta_0,0.0
wd_1,0.0


0,1
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
train_loss,█▅▄▃▃▂▂▂▂▂▂▁▁▁▁▁▂▂▂▂▂▂▂▂▁▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁
raw_loss,█▅▅▃▃▄▃▃▃▄▄▁▁▃▁▂▄▃▃▃▂▂▂▃▂▂▃▃▂▃▃▃▂▂▂▂▂▃▃▁
wd_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
sqr_mom_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
lr_0,▁▂▂▃▄▅▆▇▇██████▇▇▇▇▇▆▆▅▅▅▄▄▄▃▃▃▂▂▂▂▁▁▁▁▁
mom_0,██▇▆▅▄▃▂▂▁▁▁▁▁▁▂▂▂▂▂▃▃▄▄▄▅▅▅▆▆▆▇▇▇▇█████
eps_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
beta_0,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
wd_1,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁


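## Validation on adversarial data

Evaluate the SNLI-trained model on the round-1 test split of ANLI (Adversarial NLI). ANLI examples use the same `premise`/`hypothesis`/`label` fields (see the sample below), so `dls.test_dl` can reuse the training pipeline unchanged.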

In [None]:
adv_ds = load_dataset('anli', split='test_r1')
adv_ds[0]

Reusing dataset anli (/home/.cache/huggingface/datasets/anli/plain_text/0.1.0/43fa2c99c10bf8478f1fa0860f7b122c6b277c4c41306255b7641257cf4e3299)


{'hypothesis': 'The first Ernest Jones store was opened on the continent of Europe.',
 'label': 0,
 'premise': 'Ernest Jones is a British jeweller and watchmaker. Established in 1949, its first store was opened in Oxford Street, London. Ernest Jones specialises in diamonds and watches, stocking brands such as Gucci and Emporio Armani. Ernest Jones is part of the Signet Jewelers group.',
 'reason': "The first store was opened in London, which is in Europe. It may have been difficult for the system because continents weren't mentioned.",
 'uid': '4aae63a8-fcf7-406c-a2f3-50c31c5934a9'}

In [None]:
test_dl = dls.test_dl(adv_ds, with_labels=True)

In [None]:
learn.validate(dl=test_dl)

(#2) [1.4400129318237305,0.30300000309944153]