In [None]:
# default_exp core

In [None]:
#hide
%load_ext autoreload
%autoreload 2

In [None]:
#hide
from nbdev.showdoc import *

# VAT

> wip...

In [None]:
#export
from fastai.basics import *
from fastai.test_utils import synth_learner
from fastai.callback.all import *

In [None]:
#export
class LossCallback(Callback):
    "Common parent for callbacks that compute an extra loss term and want it recorded"
    def log_loss(self, loss:torch.Tensor, log_name:str):
        "Record the Python scalar held by `loss` under `log_name` in `self.learn.log_extras`"
        # reuse the learner's dict when it already exists, otherwise start a new one
        extras = getattr(self.learn, 'log_extras', {})
        scalar = loss.detach().item()
        extras[log_name] = scalar
        self.learn.log_extras = extras

## ALUM

Adversarial training for large neural language models as presented in https://arxiv.org/abs/2004.08994.

In [None]:
#export
def hook_out(m, inp, out):
    "Hook function that ignores the module and its input and hands back `out` untouched"
    return out

In [None]:
#export
def KL(inp, targ, reduction="sum"):
    "KL(softmax(`targ`) || softmax(`inp`)) over the last dim of the two logit tensors, computed in float32"
    # `F.kl_div` expects log-probabilities as input and probabilities as target
    log_probs = F.log_softmax(inp.float(), dim=-1, dtype=torch.float32)
    target_probs = F.softmax(targ.float(), dim=-1, dtype=torch.float32)
    return F.kl_div(log_probs, target_probs, reduction=reduction)

In [None]:
#export
def SymmetrizedKL(inp, targ, reduction="sum"):
    "Sum of the two directed `KL` terms; in each term the tensor used as target is detached"
    inp_vs_targ = KL(inp, targ.detach(), reduction=reduction)
    targ_vs_inp = KL(targ, inp.detach(), reduction=reduction)
    return inp_vs_targ + targ_vs_inp

In [None]:
#export 
def adv_project(grad, norm_type='inf', eps=1e-6):
    "Rescale `grad` along its last dim: by L2 norm for 'l2', to its sign for 'l1', by max-abs value otherwise"
    if norm_type == 'l1':
        return grad.sign()
    if norm_type == 'l2':
        scale = torch.norm(grad, dim=-1, keepdim=True)
    else:
        # every other `norm_type` value is treated as the infinity norm
        scale = grad.abs().max(-1, keepdim=True).values
    # `eps` keeps the division finite when a row is all zeros
    return grad / (scale + eps)

In [None]:
#export
def compute_adversarial_loss(model:nn.Module, embed:Tensor, logits:Tensor,
                             special_tokens_mask=None, token_type_mask=None,
                             noise_var:float=1e-5, step_size:float=1e-3, k:int=1,
                             noise_gamma:float=1e-6, criterion=SymmetrizedKL):
    """
    Computes adversarial loss on iteratively refined perturbation

    `model` is called as `model(inputs_embeds=...)` and must return an object with `.logits`.
    `embed` is the clean embedding output and `logits` the clean predictions.
    `special_tokens_mask` / `token_type_mask`, when given, are multiplied into the
    perturbation (they must broadcast against `embed`); positions where the mask is 0
    are never perturbed.
    `noise_var` is passed to `normal_` as the standard deviation of the initial noise,
    `step_size` scales each gradient-ascent step, `k` is the number of refinement steps
    and `noise_gamma` is the epsilon used when normalizing the perturbation.
    Returns `criterion(adv_logits, logits)`.
    """
    # Fold whichever masks were supplied into a single multiplicative mask
    mask = None
    for m in (special_tokens_mask, token_type_mask):
        if m is not None:
            mask = m if mask is None else mask * m

    noise = embed.data.new(embed.size()).normal_(0, noise_var)
    if mask is not None:
        noise = noise * mask
    noise.requires_grad_()

    for _ in range(k):
        adv_logits = model(inputs_embeds=embed + noise).logits
        adv_loss = KL(adv_logits, logits.detach(), reduction="batchmean")
        delta_grad, = torch.autograd.grad(adv_loss, noise, only_inputs=True)

        # give up on refinement (keeping the current noise) if the gradient blew up
        norm = torch.linalg.norm(delta_grad)
        if torch.isnan(norm) or torch.isinf(norm):
            break
        # ascent step followed by max-abs normalization over the last dim
        # (the previously passed "fro" selected this same fallback branch of `adv_project`)
        noise = adv_project(noise + delta_grad * step_size, norm_type="inf", eps=noise_gamma)
        # the gradient is non-zero at masked positions too, so the mask has to be
        # re-applied after every update or it is lost after the first step
        if mask is not None:
            noise = noise * mask
        # start the next step from a fresh leaf so the graph does not grow with `k`
        noise = noise.detach().requires_grad_()

    # the perturbation is a constant as far as the returned loss is concerned
    noise = noise.detach()
    adv_logits = model(inputs_embeds=embed + noise).logits

    return criterion(adv_logits, logits)

In [None]:
#export 
class ALUMCallback(LossCallback):
    "ALUM callback for HuggingFace pretrained models"
    run_valid = False
    order = GradientAccumulation.order-1
    @delegates(compute_adversarial_loss)
    def __init__(self, m:nn.Module, alpha:float=1., start_epoch:int=0,
                 criterion=None, mask_special_tokens:bool=False,
                 one_token_type=False, **kwargs):
        "`m` is the module whose output gets perturbed; `alpha` weights the adversarial loss; `kwargs` go to `compute_adversarial_loss`"
        self.hook = None
        self.kwargs = kwargs if kwargs else {}
        # cleared after the first `TypeError` so an unsupported model is only reported once
        self._do_vat=True
        self.special_tokens_mask, self.token_type_mask = None, None
        store_attr()

    def before_fit(self):
        "Pick a default criterion if none was given and bind it into the adversarial loss function"
        if self.criterion is None:
            # `nn.MSELoss` is spelled out: a bare `MSELoss` is not a name this notebook defines
            self.criterion = nn.MSELoss() if isinstance(self.loss_func, nn.MSELoss) else SymmetrizedKL
        self.adv_loss_func = partial(compute_adversarial_loss, criterion=self.criterion, **self.kwargs)

    def before_batch(self):
        "Register the hook once `start_epoch` is reached and pull the optional masks out of the batch"
        if (self.hook is None) and (self.epoch >= self.start_epoch):
            self.hook = Hook(self.m, hook_out)
            print(f'Starting virtual adversarial training at epoch {self.epoch}')

        # NOTE(review): both branches assume `self.xb[0]` is a dict-like batch -- confirm against the dataloader
        if self.mask_special_tokens:
            self.special_tokens_mask = self.xb[0].pop('special_tokens_mask', None)
            if self.special_tokens_mask is not None:
                # invert so that positions flagged in the popped mask get a 0 multiplier
                self.special_tokens_mask = (1-self.special_tokens_mask).unsqueeze(-1)
        if self.one_token_type:
            self.token_type_mask = self.xb[0].pop('token_type_ids', None)
            if self.token_type_mask is not None:
                # this would deterministically mask tokens of type 0
                self.token_type_mask = self.token_type_mask.unsqueeze(-1)

    def after_loss(self):
        "Add `alpha` * adversarial loss to `loss_grad` and log it"
        if self.epoch >= self.start_epoch and self._do_vat:
            embed, logits = self.hook.stored, self.pred
            model = self.model.hf_model if hasattr(self.model, 'hf_model') else self.model
            try:
                adv_loss = self.adv_loss_func(model, embed, logits, self.special_tokens_mask, self.token_type_mask)
                self.log_loss(adv_loss, 'adversarial_loss')
            except TypeError:
                # e.g. the model's forward does not accept `inputs_embeds`; stop trying on later batches
                print("Your model is probably not supported, make sure model interface is compatible with HF pretrained models")
                adv_loss, self._do_vat = 0, False
            self.learn.loss_grad += adv_loss * self.alpha

    def after_fit(self):
        "Remove the hook and forget it, so that a later `fit` registers a fresh one"
        if self.hook is not None:
            self.hook.remove()
            # without this reset `before_batch` would keep the removed hook and its stale `stored` value
            self.hook = None

In [None]:
# Two bias-free linear layers: a minimal model to exercise the callback's control flow
model = nn.Sequential(nn.Linear(1,10, bias=False), nn.Linear(10,1, bias=False))
learn = synth_learner(model=model, cbs=ALUMCallback(model[0]))

In [None]:
#hide
# List the callbacks fired at each training-loop event; ALUMCallback should appear
# under before_fit, before_batch and after_loss
learn.show_training_loop()

Start Fit
   - before_fit     : [TrainEvalCallback, ALUMCallback, Recorder, ProgressCallback]
  Start Epoch Loop
     - before_epoch   : [Recorder, ProgressCallback]
    Start Train
       - before_train   : [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - before_batch   : [ALUMCallback]
         - after_pred     : []
         - after_loss     : [ALUMCallback]
         - before_backward: []
         - before_step    : []
         - after_step     : []
         - after_cancel_batch: []
         - after_batch    : [TrainEvalCallback, Recorder, ProgressCallback]
      End Batch Loop
    End Train
     - after_cancel_train: [Recorder]
     - after_train    : [Recorder, ProgressCallback]
    Start Valid
       - before_validate: [TrainEvalCallback, Recorder, ProgressCallback]
      Start Batch Loop
         - **CBs same as train batch**: []
      End Batch Loop
    End Valid
     - after_cancel_validate: [Recorder]
     - after_validate : [Recorder, Progress

In [None]:
# Smoke test only: the toy `nn.Sequential` is not called with `inputs_embeds` successfully,
# so the callback prints its "not supported" message and skips the adversarial term
learn.fit(2, 1e-3)

epoch,train_loss,valid_loss,time
0,11.263165,10.850156,00:00
1,10.396017,9.291542,00:00


Starting virtual adversarial training at epoch 0
Your model is probably not supported, make sure model interface is compatible with HF pretrained models


## SMART

[SMART: Robust and Efficient Fine-Tuning for Pre-trained Natural Language Models through Principled Regularized Optimization](https://arxiv.org/abs/1911.03437)

In [None]:
#export
def update_ema_model(ema_model:nn.Module, model:nn.Module, mom:float=0.99):
    "In-place EMA step: every `ema_model` parameter becomes `mom`*itself + (1-`mom`)*the matching `model` parameter"
    online_weight = 1-mom
    # parameters are paired purely by iteration order, so both models must share an architecture
    for ema_param, online_param in zip(ema_model.parameters(), model.parameters()):
        ema_param.data.mul_(mom).add_(online_param.data, alpha=online_weight)

### Algorithm

**Notation:**  
$ g_i(\tilde{x}_i, \bar{\theta}_s) = \frac{1}{|\mathcal{B}|}\sum_{x_i \in \mathcal{B}} \nabla_{\tilde{x}} \ell_s (f(x_i; \bar{\theta}_s), f(\tilde{x}_i; \bar{\theta}_s)) $;  
$AdamUpdate_{\mathcal B}$ - ADAM update  for optimizing;  
$\Pi_{\mathcal A}$ - projection to $\mathcal A$

**Input:** $T$: total number of iterations, $\mathcal X$: the dataset, $\theta_0$: pre-trained model parameters, $S$: total number of iterations for Bregman proximal point method, $\sigma^2$: variance for random initialization of perturbation, $T_{\bar{x}}$: number of iterations for updating $\tilde{x_i}$, $\eta$: lr for updating $\tilde{x_i}$, $\beta$: momentum parameter.

01: $\tilde{\theta_1} \leftarrow \theta_0$  
02: **for** $t = 1,...,T$ **do**  
03: $\quad$ $\bar{\theta}_1 \leftarrow \theta_{t-1}$  
04: $\quad$ **for** $s = 1,...,S$ **do**  
05: $\quad \quad$ Sample $\mathcal{B}$ from $\mathcal X$  
06: $\quad \quad$ $\tilde{x_i} \leftarrow x_i + \nu_i$ where $\nu_i \sim \mathcal{N} (0, \sigma^2)$  
07: $\quad \quad$ **for** $m = 1,...,T_{\bar{x}}$ **do**  
08: $\quad \quad \quad$ $\tilde{g_i} \leftarrow \frac{g_i(\tilde{x_i},\bar{\theta_s})}{\|g_i(\tilde{x_i},\bar{\theta_s})\|_\infty} $  
09: $\quad \quad \quad$ $\tilde{x_i} \leftarrow \Pi_{\|\tilde{x_i}-x\|_\infty \le \epsilon}(\tilde{x_i} + \eta \tilde{g_i})$  
10: $\quad \quad$ **end for**  
11: $\quad \quad$ $\bar{\theta}_{s+1} \leftarrow AdamUpdate_\mathcal{B} (\bar{\theta}_s)$  
12: $\quad$ **end for**  
13: $\quad$ $\theta_t \leftarrow \bar{\theta}_{S}$  
14: $\quad$ $\tilde{\theta}_{t+1} \leftarrow (1-\beta)\bar{\theta}_{S} + \beta\tilde{\theta}_t$  
15: **end for**

**Output:** $\theta_T$


In [None]:
#export
class SMARTCallback(LossCallback):
    """
    SMART callback for HuggingFace pretrained models.
    
    Combines smoothness-inducing adversarial training and
    momentum accelerated Bregman proximal point optimization.
    """
    run_valid = False
    order = GradientAccumulation.order-1
    @delegates(compute_adversarial_loss)
    def __init__(self, m:nn.Module, alpha:float=1., mu:float=1., start_epoch:int=0, criterion=None,
                 mask_special_tokens:bool=False, one_token_type=False, **kwargs):
        "`m` is the module whose output gets perturbed; `alpha` weights the adversarial loss and `mu` the Bregman loss; `kwargs` go to `compute_adversarial_loss`"
        self.hook = None
        self.kwargs = kwargs if kwargs else {}
        # cleared after the first `TypeError` so an unsupported model is only reported once
        self._do_vat=True
        self.mom = 0.99
        self.special_tokens_mask, self.token_type_mask = None, None
        store_attr()

    def before_fit(self):
        "Create and freeze EMA model"
        self.ema_model = deepcopy(self.model)
        self.ema_model.eval()
        self.ema_model.requires_grad_(False)
        # the EMA copy starts over on every fit, so its momentum warm-up starts over too
        self.mom = 0.99

        if self.criterion is None:
            # `nn.MSELoss` is spelled out: a bare `MSELoss` is not a name this notebook defines
            self.criterion = nn.MSELoss() if isinstance(self.loss_func, nn.MSELoss) else SymmetrizedKL
        self.adv_loss_func = partial(compute_adversarial_loss, criterion=self.criterion, **self.kwargs)

    def before_batch(self):
        "Register the hook once `start_epoch` is reached, advance the momentum schedule and pull the optional masks out of the batch"
        if (self.hook is None) and (self.epoch >= self.start_epoch):
            self.hook = Hook(self.m, hook_out)
            print(f'Starting virtual adversarial training at epoch {self.epoch}')
        # use the smaller momentum for the first 10% of training, then switch to 0.999
        if self.mom == 0.99 and self.pct_train >= 0.1:
            self.mom = 0.999

        # Same mask extraction as `ALUMCallback.before_batch`; without it the
        # `mask_special_tokens` / `one_token_type` arguments would have no effect.
        # NOTE(review): both branches assume `self.xb[0]` is a dict-like batch -- confirm against the dataloader
        if self.mask_special_tokens:
            self.special_tokens_mask = self.xb[0].pop('special_tokens_mask', None)
            if self.special_tokens_mask is not None:
                # invert so that positions flagged in the popped mask get a 0 multiplier
                self.special_tokens_mask = (1-self.special_tokens_mask).unsqueeze(-1)
        if self.one_token_type:
            self.token_type_mask = self.xb[0].pop('token_type_ids', None)
            if self.token_type_mask is not None:
                # this would deterministically mask tokens of type 0
                self.token_type_mask = self.token_type_mask.unsqueeze(-1)

    def after_loss(self):
        "Add `mu` * Bregman loss and `alpha` * adversarial loss to `loss_grad`, logging both"
        if self.epoch >= self.start_epoch and self._do_vat:
            embed, logits = self.hook.stored, self.pred
            model = self.model.hf_model if hasattr(self.model, 'hf_model') else self.model
            # "Bregman" loss
            # TODO make sure labels are not in `xb`
            with torch.no_grad():
                ema_out = self.ema_model(*self.xb)
                ema_logits = ema_out.logits if hasattr(ema_out, 'logits') else ema_out
            breg_loss = self.criterion(logits, ema_logits)
            self.log_loss(breg_loss, 'breg_loss')
            self.learn.loss_grad += breg_loss * self.mu
            # adversarial loss
            try:
                adv_loss = self.adv_loss_func(model, embed, logits, self.special_tokens_mask, self.token_type_mask)
                self.log_loss(adv_loss, 'adversarial_loss')
            except TypeError:
                # e.g. the model's forward does not accept `inputs_embeds`; stop trying on later batches
                print("Your model is probably not supported, make sure model interface is compatible with HF pretrained models")
                adv_loss, self._do_vat = 0, False
            self.learn.loss_grad += adv_loss * self.alpha

    def after_step(self):
        "Move the EMA model towards the freshly updated online model"
        update_ema_model(self.ema_model, self.model, self.mom)

    def after_fit(self):
        "Remove the hook and forget it, so that a later `fit` registers a fresh one"
        if self.hook is not None:
            self.hook.remove()
            # without this reset `before_batch` would keep the removed hook and its stale `stored` value
            self.hook = None

In [None]:
# Two bias-free linear layers: a minimal model to exercise the callback's control flow
model = nn.Sequential(nn.Linear(1,10, bias=False), nn.Linear(10,1, bias=False))
learn = synth_learner(model=model, cbs=SMARTCallback(model[0]))

In [None]:
# Smoke test only: as the printed message below shows, the toy model is reported as
# unsupported and the adversarial term is skipped
learn.fit(2)

epoch,train_loss,valid_loss,time
0,12.51172,9.756591,00:00
1,11.275883,8.746477,00:00


Starting virtual adversarial training at epoch 0
Your model is probably not supported, make sure model interface is compatible with HF pretrained models


In [None]:
#hide
# `before_batch` raises `mom` from 0.99 to 0.999 once `pct_train >= 0.1`,
# so after a complete fit it must hold the larger value
assert learn.smart.mom == 0.999

In [None]:
#export
class VATCallback(LossCallback):
    "VAT callback (draft)"
    run_valid=False
    # Idea: allow injecting adversarial noise into intermediate activations as well.
    # For the ALUM case the outputs of the embedding layer could be perturbed instead of
    # the embedding weights (which would be equivalent).
    def __init__(self, start_iter=None): #?? potentially start in the middle of training
        "`start_iter` is stored for later use; nothing reads it yet"
        self.start_iter = start_iter

    def after_loss(self):
        "Placeholder: builds a zero-noise copy of the input and prints batch diagnostics"
        #TODO: detach as appropriate
        noise = 0
        x_adv = self.x + noise #?? take care of possible multiple inputs
        logits = self.pred
        message = f'{self.train_iter:2} - Do stuff here with input of shape {self.x.shape} and logits {logits.shape} and modify loss {self.loss:.4f}'
        print(message)
        # do VAT stuff here
        # do VAT stuff here


In [None]:
#hide
# Run the draft callback for one epoch; it only prints per-batch diagnostics
model = nn.Sequential(nn.Linear(1,10, bias=False), nn.Linear(10,1, bias=False))
learn = synth_learner(model=model, cbs=VATCallback())
learn.fit(1, 1e-3)

epoch,train_loss,valid_loss,time
0,11.183575,10.959995,00:00


 0 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 16.1694
 1 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 8.1717
 2 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 4.7342
 3 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 14.8299
 4 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 15.2079
 5 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 9.9517
 6 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 13.5607
 7 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) and modify loss 6.2763
 8 - Do stuff here with input of shape torch.Size([16, 1]) and logits torch.Size([16, 1]) an

## Fin

In [None]:
#hide
# Run nbdev's notebook-to-script export; the output below lists the notebooks it converted
from nbdev.export import notebook2script; notebook2script()

Converted 00_core.ipynb.
Converted 01_utils.ipynb.
Converted index.ipynb.
