In [None]:
# Show which GPU this runtime was assigned (the output below reports a Tesla T4 with ~15 GB).
!nvidia-smi

Thu Apr 15 14:53:23 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.67       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |
| N/A   66C    P8    11W /  70W |      0MiB / 15109MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import sys
if 'google.colab' in sys.modules:
    !pip install -Uqq fastcore onnx onnxruntime sentencepiece seqeval rouge-score
    !pip install -Uqq --no-deps fastai ohmeow-blurr
    !pip install -Uqq transformers datasets wandb 

In [None]:
from fastai.text.all import *
from fastai.callback.wandb import *

In [None]:
def read_text(fn, encoding='utf-8'):
    "Return the whole contents of text file `fn`, decoded with `encoding`, closing the file afterwards."
    # `with` closes the handle deterministically. The previous `open(fn).read()` left
    # closing to garbage collection and decoded with the platform's locale encoding.
    with open(fn, encoding=encoding) as f:
        return f.read()

In [None]:
# Full IMDB review dataset: folders of .txt files, split by the train/test grandparent folder
# and labelled by the parent folder (see the DataBlock below).
path = untar_data(URLs.IMDB)
# Smaller alternative; it appears to be consumed via the commented-out `texts.csv` cells below — confirm.
# path = untar_data(URLs.IMDB_SAMPLE)

## Setup

In [None]:
# Hugging Face checkpoint to fine-tune.
model_name = 'distilbert-base-uncased'

# Maximum sequence length (tokens), then train and validation batch sizes.
max_len, bs, val_bs = 512, 4, 16

## Tracking

In [None]:
# !wandb login

In [None]:
import wandb

WANDB_NAME = f'imdb-{model_name}-alum'
GROUP = f'IMDB-{model_name}-alum'
# Describe the run that is actually performed: ALUM adversarial fine-tuning, not plain fine-tuning.
NOTES = f'ALUM adversarial finetuning {model_name} with RAdam lr=2e-5'
# Record the Setup-cell hyperparameters so runs can be compared and filtered in W&B.
CONFIG = dict(model_name=model_name, max_len=max_len, bs=bs, val_bs=val_bs)
TAGS = [model_name, 'imdb', 'radam', 'alum']

In [None]:
# Open a W&B run in the shared "vat" project; reinit=True allows re-running this cell in the same kernel.
wandb.init(
    project="vat",
    entity="fastai_community",
    name=WANDB_NAME,
    group=GROUP,
    notes=NOTES,
    tags=TAGS,
    config=CONFIG,
    reinit=True,
);

[34m[1mwandb[0m: Currently logged in as: [33mfastai_community[0m (use `wandb login --relogin` to force relogin)


## Training

In [None]:
def _to_device(e, device):
    if hasattr(e, 'to'): return e.to(device)
    elif isinstance(e, dict):
        for _, v in e.items():
            if hasattr(v, 'to'): v.to(device)
        return {k:(v.to(device) if hasattr(v, 'to') else v) for k, v in e.items()}

In [None]:
@patch
def one_batch(self:Learner, i, b):
    "Override of `Learner.one_batch` that moves each element of batch `b` to the dataloaders' device via `_to_device`."
    self.iter = i
    if self.dls.device is None:
        batch = b
    else:
        batch = tuple(_to_device(item, self.dls.device) for item in b)
    self._split(batch)
    self._with_events(self._do_one_batch, 'batch', CancelBatchException)

In [None]:
from transformers import *

from blurr.data.all import *
from blurr.modeling.all import *

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [None]:
# Load the architecture name, config, tokenizer and sequence-classification model for `model_name`.
# The tokenizer length comes from the `max_len` setting instead of a duplicated literal;
# `model_max_length` is the current name of the deprecated `max_len` tokenizer argument.
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    model_name,
    model_cls=AutoModelForSequenceClassification,
    tokenizer_cls=AutoTokenizer,
    tokenizer_kwargs={'model_max_length': max_len},
)

In [None]:
blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), CategoryBlock)

# Items are the .txt files under `path`; the label is the parent folder name, and files whose
# grandparent folder is `test` form the validation set.
dblock = DataBlock(
    blocks=blocks,
    get_items=get_text_files,
    get_x=read_text,
    get_y=parent_label,
    splitter=GrandparentSplitter(valid_name='test'),
)

dls = dblock.dataloaders(path, bs=bs, val_bs=val_bs)

In [None]:
# texts = pd.read_csv(path/'texts.csv')

In [None]:
# blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), CategoryBlock)
# dblock = DataBlock(blocks=blocks,
#                    get_x = ColReader('text'),
#                    get_y = ColReader('label'),
#                    splitter = ColSplitter()
#                    )

# dls = dblock.dataloaders(texts, bs=bs, val_bs=val_bs)

In [None]:
# model = HF_BaseModelWrapper(hf_model)
# learn = Learner(dls,
#                 model,
#                 opt_func=RAdam,
#                 metrics=[accuracy],
#                 cbs=[HF_BaseModelCallback],
#                 splitter=hf_splitter).to_fp16()

# learn.blurr_summary()

In [None]:
# learn.show_training_loop()

### Virtual adversarial training (VAT / ALUM) fine-tuning

In [None]:
import torch.nn.functional as F
from torch import linalg as LA

def KL(input, target, reduction="sum"):
    "KL(softmax(target) || softmax(input)) over the last dim, computed in float32; `reduction` is passed to `F.kl_div`."
    # dtype=float32 casts the logits before the softmax, so half-precision inputs are handled stably.
    log_q = F.log_softmax(input, dim=-1, dtype=torch.float32)
    p = F.softmax(target, dim=-1, dtype=torch.float32)
    return F.kl_div(log_q, p, reduction=reduction)

In [None]:
from fastai.callback.all import Hook

def hook_out(m, inp, out):
    "Hook function that keeps only the module's output; the module and its inputs are ignored."
    del m, inp  # unused: the (module, input, output) signature is fixed by the hook protocol
    return out

In [None]:
def adv_project(grad, norm_type='inf', eps=1e-6):
    """Normalize `grad` along its last dimension.

    'l2' -> divide by the L2 norm; 'l1' -> element-wise sign; any other value (including
    'inf') -> divide by the max absolute value. `eps` guards the divisions against zero.
    """
    if norm_type == 'l1':
        return grad.sign()
    if norm_type == 'l2':
        scale = torch.norm(grad, dim=-1, keepdim=True)
    else:
        scale = grad.abs().max(-1, keepdim=True).values
    return grad / (scale + eps)

In [None]:
def compute_adversarial_loss(model:nn.Module, embed:Tensor, logits:Tensor,
                             noise_var:float=1e-5, step_size:float=1e-3, k:int=1,
                             noise_gamma:float=1e-6, norm_type:str='inf'):
    """Symmetric-KL adversarial loss on perturbed input embeddings (ALUM/SMART style).

    1. Start from Gaussian noise added to `embed`.
    2. Take `k` gradient-ascent steps on KL(clean || adversarial), re-normalizing the noise
       with `adv_project` after each step.
    3. Return KL(adv, clean) + KL(clean, adv), with reduction="sum", at the final
       perturbation.

    Args:
        model: callable accepting `inputs_embeds=` and returning an object with `.logits`.
        embed: the embedding output for the current batch that the noise is added to.
        logits: clean logits for the same batch. Gradients flow through them only in the
            second (backward) KL term.
        noise_var: standard deviation (despite the name) of the initial Gaussian noise.
        step_size: multiplier on the noise gradient at each ascent step.
        k: number of ascent steps.
        noise_gamma: `eps` passed to `adv_project`.
        norm_type: normalization passed to `adv_project` ('l2', 'l1', otherwise max-abs).
            Defaults to 'inf'. The previous hard-coded "fro" also resolved to max-abs.

    Returns:
        Scalar tensor: sum of the forward and backward KL terms.
    """
    noise = torch.empty_like(embed).normal_(0, noise_var)
    noise.requires_grad_()

    for _ in range(k):
        adv_logits = model(inputs_embeds=embed + noise).logits
        adv_loss = KL(adv_logits, logits.detach(), reduction="batchmean")
        delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)

        # Stop ascending on a non-finite gradient; the last finite noise is kept.
        norm = LA.norm(delta_grad)
        if torch.isnan(norm) or torch.isinf(norm):
            break

        noise = adv_project(noise + delta_grad * step_size, norm_type=norm_type, eps=noise_gamma)
        # Restart the graph at the new noise: the next step only needs d(loss)/d(noise), and
        # keeping the history would grow the graph and back-propagate into the old noise leaf.
        noise = noise.detach().requires_grad_()

    # The noise is a fixed perturbation for the final loss; gradients go to the model only.
    adv_logits = model(inputs_embeds=embed + noise.detach()).logits

    adv_loss_f = KL(adv_logits, logits.detach())
    adv_loss_b = KL(logits, adv_logits.detach())
    return adv_loss_f + adv_loss_b

In [None]:
class ALUMCallback(Callback):
    """ALUM callback (draft): adds an embedding-space adversarial loss to the training loss.

    From epoch `start_epoch` onward (0-based), a forward hook on `m` captures that module's
    output. After each training loss is computed, `compute_adversarial_loss` on the
    captured output, weighted by `alpha`, is added to `loss_grad`.
    """
    run_valid = False
    # Presumably ordered just before GradientAccumulation so that its loss rescaling also
    # covers the adversarial term — confirm.
    order = GradientAccumulation.order-1

    @delegates(compute_adversarial_loss)
    def __init__(self, m:nn.Module, alpha:float=1., start_epoch:int=1, **kwargs):
        "`m`: module whose output is perturbed; `alpha`: adversarial-loss weight; `kwargs` go to `compute_adversarial_loss`."
        self.hook = None
        self.adv_loss_func = partial(compute_adversarial_loss, **kwargs) if kwargs else compute_adversarial_loss
        store_attr()

    def before_batch(self):
        if (self.hook is None) and (self.epoch >= self.start_epoch):
            # NOTE(review): fastai's Hook detaches stored outputs by default, so the embedding
            # layer itself gets no adversarial gradient — confirm that is intended.
            self.hook = Hook(self.m, hook_out)
            print(f'Starting virtual adversarial training at epoch {self.epoch}')

    def _attention_mask(self):
        "Return `attention_mask` from the first model input if it is dict-like and has one, else None."
        x = self.xb[0] if len(self.xb) else None
        return x.get('attention_mask') if hasattr(x, 'get') else None

    def after_loss(self):
        if self.epoch < self.start_epoch: return
        embed, logits = self.hook.stored, self.pred
        model = self.model.hf_model
        # Assumes the clean forward received the batch's attention_mask — TODO confirm. Reuse it
        # so the adversarial passes don't attend to padding and the KL compares like with like.
        mask = self._attention_mask()
        if mask is not None: model = partial(model, attention_mask=mask)
        adv_loss = self.adv_loss_func(model, embed, logits)
        # Out-of-place add: avoids mutating a tensor that autograd may still reference.
        self.learn.loss_grad = self.loss_grad + adv_loss * self.alpha

    def after_fit(self):
        if self.hook is not None:
            self.hook.remove()
            # Reset so a later fit re-registers the hook instead of reusing the stale `stored`
            # output of the previous fit's last batch.
            self.hook = None

In [None]:
model = HF_BaseModelWrapper(hf_model)
# NOTE(review): fastai's GradientAccumulation(n_acc) counts *samples*, not batches, so with
# bs=4 this steps the optimizer every 2 batches (effective batch size 8) — confirm intended.
learn = Learner(
    dls,
    model,
    opt_func=RAdam,
    metrics=[accuracy],
    cbs=[HF_BaseModelCallback, GradientAccumulation(8)],
    splitter=hf_splitter,
).to_fp16()

In [None]:
# Perturb the output of the transformer's embedding module; epoch 0 trains without the adversarial term.
learn.add_cb(ALUMCallback(m=learn.model.hf_model.base_model.embeddings, start_epoch=1));

In [None]:
# 4 epochs, one-cycle schedule peaking at lr 2e-5; metrics are streamed to the W&B run opened above.
learn.fit_one_cycle(4, 2e-5, cbs=WandbCallback(log_preds=False, log_model=False))
# Mark the W&B run as finished; otherwise it stays open until the kernel exits.
# Re-run the `wandb.init` cell before training again.
wandb.finish()

Could not gather input dimensions


epoch,train_loss,valid_loss,accuracy,time
0,0.228204,0.24411,0.9008,10:00
1,0.170476,0.175666,0.93008,20:48
2,0.104206,0.156094,0.94176,20:51
3,0.063551,0.150507,0.94376,20:47


Exception ignored in: <finalize object at 0x7feb41aa8800; dead>
Traceback (most recent call last):
  File "/usr/lib/python3.7/weakref.py", line 572, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
  File "/usr/lib/python3.7/tempfile.py", line 936, in _cleanup
    _rmtree(name)
  File "/usr/lib/python3.7/shutil.py", line 485, in rmtree
    onerror(os.lstat, path, sys.exc_info())
  File "/usr/lib/python3.7/shutil.py", line 483, in rmtree
    orig_st = os.lstat(path)
FileNotFoundError: [Errno 2] No such file or directory: '/tmp/tmpdpmakwvl'
Exception ignored in: <finalize object at 0x7feb41aa8720; dead>
Traceback (most recent call last):
  File "/usr/lib/python3.7/weakref.py", line 572, in __call__
    return info.func(*info.args, **(info.kwargs or {}))
  File "/usr/lib/python3.7/tempfile.py", line 936, in _cleanup
    _rmtree(name)
  File "/usr/lib/python3.7/shutil.py", line 485, in rmtree
    onerror(os.lstat, path, sys.exc_info())
  File "/usr/lib/python3.7/shutil.p

Starting virtual adversarial training at epoch 1
