In [None]:
# default_exp core

In [None]:
#hide
%load_ext autoreload
%autoreload 2

In [None]:
#hide
from nbdev.showdoc import *

# VAT

> Virtual adversarial training (VAT) callbacks for fastai. Work in progress.

In [None]:
#export
from fastai.basics import *
from fastai.test_utils import synth_learner
from fastai.callback.all import *

## ALUM

Adversarial training for large neural language models as presented in https://arxiv.org/abs/2004.08994.

In [None]:
#export
def hook_out(m, inp, out):
    "Hook function that disregards the module `m` and its input `inp` and hands back `out` untouched"
    return out

In [None]:
#export
def KL(inp, targ, reduction="sum"):
    """KL divergence between the softmax distributions of two sets of logits.

    `inp` and `targ` are raw logits; both are normalised over the last dim in
    float32, whatever their input dtype. `inp` supplies the log-probabilities and
    `targ` the probabilities passed to `F.kl_div`, i.e. the result is
    KL(softmax(targ) || softmax(inp)). `reduction` is forwarded to `F.kl_div`.
    """
    # `dtype=torch.float32` upcasts before the (log-)softmax, so no separate `.float()` is needed
    log_probs = F.log_softmax(inp, dim=-1, dtype=torch.float32)
    target_probs = F.softmax(targ, dim=-1, dtype=torch.float32)
    return F.kl_div(log_probs, target_probs, reduction=reduction)

In [None]:
#export
def SymmetrizedKL(inp, targ, reduction="sum"):
    "Sum of `KL` taken in both directions; in each term the side used as target is detached"
    inp_vs_targ = KL(inp, targ.detach(), reduction=reduction)
    targ_vs_inp = KL(targ, inp.detach(), reduction=reduction)
    return inp_vs_targ + targ_vs_inp

In [None]:
#export 
def adv_project(grad, norm_type='inf', eps=1e-6):
    "Rescale `grad` along its last dim: sign for 'l1', division by the L2 norm for 'l2', by the max abs value otherwise"
    if norm_type == 'l1':
        return grad.sign()
    if norm_type == 'l2':
        scale = torch.norm(grad, dim=-1, keepdim=True)
    else:
        # any other `norm_type` value (not only 'inf') ends up here
        scale = grad.abs().max(-1, keepdim=True)[0]
    return grad / (scale + eps)

In [None]:
#export
def compute_adversarial_loss(model:nn.Module, embed:Tensor, logits:Tensor,
                             special_tokens_mask=None, token_type_mask=None,
                             noise_var:float=1e-5, step_size:float=1e-3, k:int=1,
                             noise_gamma:float=1e-6, criterion=SymmetrizedKL):
    """Computes adversarial loss on iteratively refined perturbation

    Random noise is added to `embed`, refined with `k` gradient-ascent steps on
    `KL(adv_logits, logits)`, and the perturbed prediction is finally compared to
    `logits` with `criterion`.

    `model` must accept `inputs_embeds=` and return an object with `.logits`.
    `special_tokens_mask` / `token_type_mask`, when given, are multiplied into the
    noise (they must broadcast against `embed`), so positions where a mask is 0 stay
    unperturbed. `noise_var` is passed to `normal_` as the standard deviation of the
    initial noise, `step_size` scales the gradient step and `noise_gamma` is the
    `eps` given to `adv_project`.
    """
    def _masked(t):
        "Zero `t` wherever one of the optional masks is 0"
        if special_tokens_mask is not None: t = t*special_tokens_mask
        if token_type_mask is not None: t = t*token_type_mask
        return t

    # mask first, then mark as requiring grad: `noise` is a leaf that already respects both masks
    noise = _masked(embed.new_empty(embed.size()).normal_(0, noise_var)).requires_grad_()

    for _ in range(k):
        adv_logits = model(inputs_embeds=embed + noise).logits
        adv_loss = KL(adv_logits, logits.detach(), reduction="batchmean")
        delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)

        # stop refining (keeping the current noise) if the gradient is NaN/inf
        if not torch.isfinite(torch.linalg.norm(delta_grad)):
            break
        # "fro" is not a name `adv_project` checks for, so its fallback (max-abs) scaling is applied
        noise = adv_project(noise + delta_grad * step_size, norm_type="fro", eps=noise_gamma)
        # The step and projection re-introduce values on masked positions, so mask again.
        # Detach so each step starts from a fresh leaf instead of chaining onto a graph
        # that `autograd.grad` has already freed.
        noise = _masked(noise).detach().requires_grad_()

    # the perturbation is a constant for the outer optimisation
    adv_logits = model(inputs_embeds=embed + noise.detach()).logits

    return criterion(adv_logits, logits)

In [None]:
#export 
class ALUMCallback(Callback):
    """ALUM callback for HuggingFace pretrained models

    Hooks the output of module `m` (used as the embeddings to perturb) and, from
    `start_epoch` on, adds `alpha` times `compute_adversarial_loss` to the training
    loss. Extra keyword arguments are forwarded to `compute_adversarial_loss`.
    With `mask_special_tokens` / `one_token_type`, the `special_tokens_mask` /
    `token_type_ids` entries are popped from the first input (expected to be a dict)
    and used to mask the noise.
    """
    run_valid = False
    order = GradientAccumulation.order-1
    @delegates(compute_adversarial_loss)
    def __init__(self, m:nn.Module, alpha:float=1., start_epoch:int=0,
                 criterion=None, mask_special_tokens:bool=False, 
                 one_token_type=False, **kwargs):
        self.hook = None
        self.kwargs = kwargs
        # flipped to False once the model turns out to be incompatible
        self._do_vat=True
        self.special_tokens_mask, self.token_type_mask = None, None
        store_attr()
    
    def before_fit(self):
        "Pick a default `criterion` if none was given and bind it into the adversarial loss function"
        if self.criterion is None:
            self.criterion = nn.MSELoss() if isinstance(self.loss_func, nn.MSELoss) else SymmetrizedKL
        self.adv_loss_func = partial(compute_adversarial_loss, criterion=self.criterion, **self.kwargs)
    
    def before_batch(self):
        "Register the hook once `start_epoch` is reached and pull the optional masks out of the batch"
        # no point in hooking once VAT has been switched off
        if (self.hook is None) and self._do_vat and (self.epoch >= self.start_epoch):
            self.hook = Hook(self.m, hook_out)
            print(f'Starting virtual adversarial training at epoch {self.epoch}')

        if self.mask_special_tokens:
            self.special_tokens_mask = self.xb[0].pop('special_tokens_mask', None)
            if self.special_tokens_mask is not None:
                # invert so that 1 marks positions that may be perturbed; trailing dim added for broadcasting
                self.special_tokens_mask = (1-self.special_tokens_mask).unsqueeze(-1)
        if self.one_token_type:
            self.token_type_mask = self.xb[0].pop('token_type_ids', None)
            if self.token_type_mask is not None:
                # this would deterministically mask tokens of type 0
                self.token_type_mask = self.token_type_mask.unsqueeze(-1)

    def after_loss(self):
        "Add the `alpha`-weighted adversarial loss to `loss_grad`"
        if self.epoch >= self.start_epoch and self._do_vat:
            embed, logits = self.hook.stored, self.pred
            model = getattr(self.model, 'hf_model', self.model)
            try:    adv_loss = self.adv_loss_func(model, embed, logits, self.special_tokens_mask, self.token_type_mask)
            except TypeError:
                # e.g. the model does not accept `inputs_embeds`: report once and disable VAT from now on
                print("Your model is probably not supported, make sure model interface is compatible with HF pretrained models")
                adv_loss, self._do_vat = 0, False
            self.learn.loss_grad += adv_loss * self.alpha

    def after_fit(self):
        "Remove the hook and forget it, so a later `fit` registers a fresh one"
        if self.hook is not None:
            self.hook.remove()
            # without this reset `before_batch` would keep the removed hook and its stale `stored` value
            self.hook = None

In [None]:
# Toy two-layer linear model; its first layer is the module whose output ALUMCallback hooks
model = nn.Sequential(nn.Linear(1, 10, bias=False), nn.Linear(10, 1, bias=False))
learn = synth_learner(model=model, cbs=ALUMCallback(model[0]))

In [None]:
#hide
# Print which callbacks run at each training-loop event (ALUMCallback should show up in before_batch and after_loss)
learn.show_training_loop()

Start Fit
   - before_fit     : [TrainEvalCallback, ALUMCallback, Recorder, ProgressCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : [ALUMCallback]
         - after_pred     : []
         - after_loss     : [ALUMCallback]
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, Progress

In [None]:
# Smoke test: the toy nn.Sequential does not take `inputs_embeds`, so ALUMCallback is expected
# to print its "not supported" notice and carry on training without the adversarial term
learn.fit(2, 1e-3)

epoch,train_loss,valid_loss,time
0,12.705046,7.071889,00:00
1,11.292135,8.504992,00:00


Starting virtual adversarial training at epoch 0
Your model is probably not supported, make sure model interface is compatible with HF pretrained models


In [None]:
#export
class VATCallback(Callback):
    "Draft VAT callback: for now it only reports the batch it sees after each loss computation"
    run_valid=False
    # Idea: allow injecting adversarial noise into intermediate activations.
    # For ALUM that would mean perturbing the embedding layer's outputs rather than its weights (which would be equivalent).
    def __init__(self, start_iter=None):
        # meant for starting VAT part-way through training; nothing reads it yet
        self.start_iter = start_iter

    def after_loss(self):
        #TODO: detach as appropriate
        logits = self.pred
        noise = 0  # placeholder perturbation
        x_adv = self.x + noise  # TODO: handle models that take several inputs
        print(f'{self.train_iter:2} - Do stuff here with input of shape {self.x.shape} and logits {logits.shape} and modify loss {self.loss:.4f}')
        # the actual VAT computation goes here


In [None]:
# Same toy two-layer linear model, trained for one epoch with the draft VATCallback attached
model = nn.Sequential(nn.Linear(1, 10, bias=False), nn.Linear(10, 1, bias=False))
learn = synth_learner(model=model, cbs=VATCallback())
learn.fit(1, 1e-3)

epoch,train_loss,valid_loss,time
0,10.577516,10.631283,00:00


 0 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 14.7863
 1 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 14.5227
 2 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 8.2985
 3 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 11.5245
 4 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 10.8708
 5 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 9.9784
 6 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 8.0394
 7 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 8.8918
 8 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) an

## Fin

In [None]:
#hide
# Run nbdev's notebook export (the output below lists the notebooks it converted)
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted index.ipynb.
