In [5]:
# Show the GPU(s) visible to this machine, with driver and CUDA versions and current memory usage
!nvidia-smi

Sat Apr 17 10:19:39 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  GeForce RTX 208...  On   | 00000000:65:00.0 Off |                  N/A |
| 31%   25C    P8    12W / 250W |     21MiB / 11019MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [6]:
import sys
# Only runs inside Google Colab: install/upgrade the libraries this notebook needs
if 'google.colab' in sys.modules:
    !pip install -Uqq fastcore onnx onnxruntime sentencepiece seqeval rouge-score
    # --no-deps: pip installs only these two packages, not their dependencies
    !pip install -Uqq --no-deps fastai ohmeow-blurr
    !pip install -Uqq transformers datasets wandb 

In [1]:
import gc
import wandb
from fastai.text.all import *
from fastai.callback.wandb import *

In [2]:
def read_text(fn, encoding='utf-8'):
    """Return the full contents of the text file `fn`.

    The file is opened in a `with` block so the handle is always closed (the previous version leaked one
    handle per review). `encoding` defaults to UTF-8 instead of the platform locale so decoding does not
    depend on the machine the notebook runs on.
    """
    with open(fn, encoding=encoding) as f:
        return f.read()

In [3]:
# Local folder of the IMDB reviews dataset (fastai's `untar_data` downloads and extracts it if needed)
path = untar_data(URLs.IMDB)

## Setup

In [4]:
# Hugging Face checkpoint used to build the tokenizer and the sequence-classification model
model_name = 'distilbert-base-uncased'

# Intended max tokenizer sequence length; keep in sync with any `tokenizer_kwargs={'max_len': ...}` given to get_hf_objects
max_len = 512
bs = 8       # training batch size (passed as `bs` to `dataloaders`)
val_bs = 16  # validation batch size (passed as `val_bs` to `dataloaders`)

## Training

In [5]:
def _to_device(e, device):
    if hasattr(e, 'to'): return e.to(device)
    elif isinstance(e, dict):
        for _, v in e.items():
            if hasattr(v, 'to'): v.to(device)
        return {k:(v.to(device) if hasattr(v, 'to') else v) for k, v in e.items()}

In [6]:
@patch
def one_batch(self:Learner, i, b):
    "Process batch number `i`: move every element of `b` to `self.dls.device` (when one is set), hand it to `_split`, then run the batch events."
    self.iter = i
    target_device = self.dls.device
    if target_device is None:
        batch = b
    else:
        batch = tuple(_to_device(item, target_device) for item in b)
    self._split(batch)
    self._with_events(self._do_one_batch, 'batch', CancelBatchException)

In [7]:
from transformers import *

from blurr.data.all import *
from blurr.modeling.all import *

[nltk_data] Downloading package wordnet to /home/morgan/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [8]:
# Build the blurr/HF objects (arch, config, tokenizer, sequence-classification model) for `model_name`;
# the tokenizer length comes from the `max_len` config constant instead of a duplicated literal
hf_arch, hf_config, hf_tokenizer, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
    model_name, model_cls=AutoModelForSequenceClassification,
    tokenizer_cls=AutoTokenizer, tokenizer_kwargs={'max_len': max_len})

In [9]:
# Inputs are the raw review texts (tokenized by blurr's HF_TextBlock), labels are the parent folder
# name, and files whose grandparent folder is `test` form the validation set.
blocks = (HF_TextBlock(hf_arch, hf_config, hf_tokenizer, hf_model), CategoryBlock)
dblock = DataBlock(blocks=blocks,
                   # Only scan the train/ and test/ folders, the only ones GrandparentSplitter assigns to a
                   # split (as fastai's own `from_folder` does); other top-level folders (presumably the
                   # unlabelled reviews in IMDB) were listed and then dropped anyway.
                   get_items=partial(get_text_files, folders=['train', 'test']),
                   get_x=read_text,
                   get_y=parent_label,
                   splitter=GrandparentSplitter(valid_name='test'))

dls = dblock.dataloaders(path, bs=bs, val_bs=val_bs)

### VAT fine-tuning (virtual adversarial training, ALUM-style)

In [10]:
import torch.nn.functional as F
from torch import linalg as LA

def KL(input, target, reduction="sum"):
    """KL divergence between the softmax distributions of two logit tensors, computed in float32.

    Returns `F.kl_div(log_softmax(input), softmax(target))`, i.e. KL(softmax(target) || softmax(input))
    over the last dimension, reduced with `reduction` ("sum" by default).
    """
    log_probs = F.log_softmax(input.float(), dim=-1, dtype=torch.float32)
    target_probs = F.softmax(target.float(), dim=-1, dtype=torch.float32)
    return F.kl_div(log_probs, target_probs, reduction=reduction)

In [11]:
from fastai.callback.all import Hook

def hook_out(m, inp, out):
    "Hook function for fastai's `Hook`: ignore the module `m` and its inputs `inp`, keep only the output `out`."
    return out

In [12]:
def adv_project(grad, norm_type='inf', eps=1e-6):
    """Normalise `grad` along its last dimension.

    - 'l1': element-wise sign of `grad`.
    - 'l2': divide by the L2 norm of each last-dim vector (plus `eps`).
    - anything else (including 'inf'): divide by the max absolute value of each last-dim vector (plus `eps`).
    """
    if norm_type == 'l1':
        return grad.sign()
    if norm_type == 'l2':
        scale = torch.norm(grad, dim=-1, keepdim=True)
    else:
        scale = grad.abs().max(dim=-1, keepdim=True).values
    return grad / (scale + eps)

In [13]:
def compute_adversarial_loss(model:nn.Module, embed:Tensor, logits:Tensor,
                             noise_var:float=1e-5, step_size:float=1e-3, k:int=1,
                             noise_gamma:float=1e-6, norm_type:str='inf'):
    """Symmetric-KL virtual adversarial loss on perturbed embeddings (SMART/ALUM style).

    Gaussian noise is drawn with standard deviation `noise_var` (despite the name, it is passed to
    `normal_` as the std) and added to `embed`. Then `k` gradient-ascent steps of size `step_size` are
    taken on KL(softmax(logits) || softmax(adv_logits)). After each step the noise is normalised with
    `adv_project(norm_type=norm_type, eps=noise_gamma)`. The ascent stops early if the noise gradient
    becomes NaN/inf.

    `model` is called as `model(inputs_embeds=...)` and must return an object with a `.logits` attribute.

    Returns the sum of two "sum"-reduced terms:
    - KL(clean || adv), which backpropagates through the perturbed forward pass.
    - KL(adv || clean), which backpropagates through `logits`.

    `norm_type` defaults to 'inf', which is what the previous hardcoded "fro" silently fell back to,
    because `adv_project` has no "fro" branch.
    """
    noise = embed.data.new(embed.size()).normal_(0, noise_var)
    noise.requires_grad_()

    for _ in range(k):
        adv_logits = model(inputs_embeds=embed + noise).logits
        adv_loss = KL(adv_logits, logits.detach(), reduction="batchmean")
        delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)

        norm = LA.norm(delta_grad)
        # Gradient blew up: stop ascending and keep the last finite noise
        if torch.isnan(norm) or torch.isinf(norm):
            break

        noise = adv_project(noise + delta_grad * step_size, norm_type=norm_type, eps=noise_gamma)
        # Re-leaf the projected noise: each step only needs d(loss)/d(noise), and the final loss must not
        # drag the noise's construction history (and a stray .grad on the initial noise) into backward
        noise = noise.detach().requires_grad_()

    adv_logits = model(inputs_embeds=embed + noise).logits

    adv_loss_f = KL(adv_logits, logits.detach())
    adv_loss_b = KL(logits, adv_logits.detach())
    return adv_loss_f + adv_loss_b

In [14]:
class ALUMCallback(Callback):
    "Adds `alpha` times a virtual adversarial loss on the embeddings output by module `m` to `loss_grad`, from epoch `start_epoch` on (training batches only)."
    run_valid = False
    # Lower order than GradientAccumulation, so this callback's after_loss runs first and the adversarial
    # term is already part of `loss_grad` when GradientAccumulation handles it
    order = GradientAccumulation.order-1
    @delegates(compute_adversarial_loss)
    def __init__(self, m:nn.Module, alpha:float=1., start_epoch:int=1, **kwargs):
        # m: module whose forward output is captured and perturbed (e.g. the transformer's embedding layer)
        self.hook = None
        self.adv_loss_func = partial(compute_adversarial_loss, **kwargs) if kwargs else compute_adversarial_loss
        store_attr()

    def _active(self):
        "True when the adversarial term should be computed for the current epoch (alpha == 0 contributes nothing, so skip it)."
        return self.epoch >= self.start_epoch and self.alpha != 0

    def _remove_hook(self):
        "Remove the forward hook, if any, and forget it so a later fit registers a fresh one."
        if self.hook is not None: self.hook.remove()
        self.hook = None

    def _attention_mask(self):
        "Attention mask of the current batch if its first input is dict-like with an 'attention_mask' entry, else None."
        if not self.xb: return None
        x0 = self.xb[0]
        return x0.get('attention_mask') if hasattr(x0, 'get') else None

    def before_fit(self):
        # after_fit is skipped when fit raises, so drop any hook a failed previous fit left registered
        self._remove_hook()

    def before_batch(self):
        if self.hook is None and self._active():
            # NOTE(review): fastai's Hook detaches the stored output by default, so the adversarial loss does
            # not backpropagate into `m`'s parameters through `embed` — confirm this is intended
            self.hook = Hook(self.m, hook_out)
            print(f'Starting virtual adversarial training at epoch {self.epoch}')

    def after_loss(self):
        if self.hook is None or not self._active(): return
        embed, logits = self.hook.stored, self.pred
        model = self.model.hf_model
        # Bind the batch's attention mask for the perturbed forward passes so padding is masked there too;
        # assumes the clean forward pass that produced `logits` used this same mask — TODO confirm for blurr
        mask = self._attention_mask()
        if mask is not None: model = partial(model, attention_mask=mask)
        adv_loss = self.adv_loss_func(model, embed, logits)
        self.learn.loss_grad += adv_loss * self.alpha

    def after_fit(self):
        # Reset to None (not just remove) so reusing this callback for another fit re-registers the hook
        self._remove_hook()

## Run a hyperparameter sweep

In [15]:
# Authenticate with Weights & Biases (reuses stored credentials when already logged in)
wandb.login()

[34m[1mwandb[0m: Currently logged in as: [33mmorgan[0m (use `wandb login --relogin` to force relogin)


True

In [16]:
def train(epochs:int=4, lr:float=2e-5):
    """One W&B sweep trial: fine-tune a fresh classifier with the ALUM callback for `epochs` epochs at `lr`.

    ALUM hyperparameters (start_epoch, alpha, noise_var, noise_gamma, step_size) are read from the run's
    config, which the sweep agent fills in. GPU memory is released in a `finally` block, so a crashed
    trial does not starve the next one the agent launches.
    """
    run = wandb.init()

    # Only a fresh model is needed here; the tokenizer behind the global `dls` is reused
    _, _, _, hf_model = BLURR_MODEL_HELPER.get_hf_objects(
        model_name, model_cls=AutoModelForSequenceClassification,
        tokenizer_cls=AutoTokenizer, tokenizer_kwargs={'max_len': max_len})

    model = HF_BaseModelWrapper(hf_model)
    learn = None
    try:
        learn = Learner(dls,
                        model,
                        opt_func=RAdam,
                        metrics=[accuracy],
                        # NOTE(review): fastai's GradientAccumulation counts samples (n_acc), so with bs=8 an
                        # n_acc of 8 steps the optimizer every batch (no accumulation) — TODO confirm whether
                        # accumulating 8 batches (n_acc=8*bs) was intended
                        cbs=[HF_BaseModelCallback, GradientAccumulation(8)],
                        splitter=hf_splitter).to_fp16()

        cfg = run.config
        learn.add_cb(ALUMCallback(learn.model.hf_model.base_model.embeddings,
                                  start_epoch=cfg.start_epoch,
                                  alpha=cfg.alpha,
                                  noise_var=cfg.noise_var,
                                  noise_gamma=cfg.noise_gamma,
                                  step_size=cfg.step_size))

        learn.fit_one_cycle(epochs, lr, cbs=[WandbCallback(log_preds=False, log_model=False)])
    finally:
        # Drop every reference to the model (deleting only `learn` left `model`/`hf_model` alive), so the
        # memory can actually be returned before the next trial
        del learn, model, hf_model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.synchronize()
            torch.cuda.empty_cache()

In [17]:
sweep_config = {
    "name": "ALUM test sweep",
    "method": "random",
    "parameters": {
        "start_epoch": {"values": [0, 1]},
        # alpha=0.0 adds no adversarial term: the plain fine-tuning baseline
        "alpha": {"values": [0.0, 0.25, 0.5, 1, 2, 4, 8, 10, 20]},
        "noise_var": {"values": [1e-6, 1e-5, 1e-4, 1e-3]},
        "noise_gamma": {"values": [1e-7, 1e-6, 1e-5, 1e-4, 1e-3]},
        "step_size": {"values": [1e-5, 1e-4, 1e-3, 1e-2]},
    },
    # W&B expects the US spelling "maximize"; "maximise" is not a recognised goal
    "metric": {"goal": "maximize", "name": "accuracy"},
    # NOTE(review): each trial fits for only 4 epochs, and accuracy is presumably logged once per epoch, so
    # with max_iter=60, eta=3, s=2 the hyperband brackets may never be reached — TODO confirm against the
    # W&B hyperband docs (e.g. a small min_iter)
    "early_terminate": {"type": "hyperband", "s": 2, "eta": 3, "max_iter": 60},
}

In [18]:
# Register the sweep in the fastai_community/vat W&B project; the returned id is what the agent runs
sweep_id = wandb.sweep(sweep_config, project="vat", entity="fastai_community")

Create sweep with ID: 5x4o4q3e
Sweep URL: https://wandb.ai/fastai_community/vat/sweeps/5x4o4q3e


In [None]:
# Run sweep trials in this kernel: each trial calls `train()` with hyperparameters chosen by the sweep.
# NOTE(review): with `method: random` and no `count=`, the agent presumably keeps launching trials until stopped
wandb.agent(sweep_id, function=train)

[34m[1mwandb[0m: Agent Starting Run: mrhnqn9y with config:
[34m[1mwandb[0m: 	alpha: 0.5
[34m[1mwandb[0m: 	noise_gamma: 1e-07
[34m[1mwandb[0m: 	noise_var: 1e-05
[34m[1mwandb[0m: 	start_epoch: 0
[34m[1mwandb[0m: 	step_size: 0.001


Could not gather input dimensions


epoch,train_loss,valid_loss,accuracy,time


Starting virtual adversarial training at epoch 0
