In [None]:
!nvidia-smi

Thu Apr 22 00:26:41 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.05    Driver Version: 450.51.05    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla V100S-PCI...  Off  | 00000000:00:0A.0 Off |                    0 |
| N/A   39C    P0    41W / 250W |   2967MiB / 32510MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import sys
if 'google.colab' in sys.modules:
    !pip install -Uqq fastcore sentencepiece
    !pip install -Uqq --no-deps fastai
    !pip install -Uqq transformers datasets wandb 

In [None]:
from transformers import *
from datasets import load_dataset, concatenate_datasets, load_metric

  '"sox" backend is being deprecated. '


## Setup

In [None]:
# Pretrained checkpoint used for both the tokenizer and the classifier.
model_name = 'distilroberta-base'

# --- data ---
max_length = 512  # truncation limit handed to the tokenizer in `tokenize`
bs = 64           # per-device training batch size
val_bs = 2 * bs   # per-device evaluation batch size

# --- training ---
lr = 3e-5  # learning rate passed to TrainingArguments

## Data preprocessing

In [None]:
ds_name = 'snli'

In [None]:
train_ds = load_dataset(ds_name, split='train')
valid_ds = load_dataset(ds_name, split='validation')

Reusing dataset snli (/workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)
Reusing dataset snli (/workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c)


In [None]:
# train_ds = train_ds.select(range(100))
# valid_ds = valid_ds.select(range(100))

In [None]:
len(train_ds), len(valid_ds)

(550152, 10000)

In [None]:
train_ds.column_names

['premise', 'hypothesis', 'label']

In [None]:
train_ds[2]

{'premise': 'A person on a horse jumps over a broken down airplane.',
 'hypothesis': 'A person is outdoors, on a horse.',
 'label': 0}

In [None]:
from collections import Counter

In [None]:
Counter(train_ds['label'])

Counter({1: 182764, 2: 183187, 0: 183416, -1: 785})

In [None]:
# The label counts above include 785 training rows labelled -1; keep only rows
# whose label is one of the three real classes (0, 1, 2).
train_ds = train_ds.filter(
    lambda example: example['label'] in (0, 1, 2)
)
valid_ds = valid_ds.filter(
    lambda example: example['label'] in (0, 1, 2)
)

Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-7507e6a826d31a45.arrow
Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-e9398e3f5f289e57.arrow


In [None]:
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def tokenize(batch):
    """Encode a batch of premise/hypothesis pairs with the module-level `tokenizer`.

    Args:
        batch: mapping that provides the 'premise' and 'hypothesis' entries of
            the current batch.

    Returns:
        Whatever `tokenizer` returns for the (premise, hypothesis) pair input.
    """
    encode_options = {
        'add_special_tokens': True,
        # No padding at this stage — presumably left to the Trainer's collator
        # at batch time (TODO confirm).
        'padding': False,
        # Cut pairs that exceed the global `max_length`.
        'truncation': True,
        'max_length': max_length,
    }
    return tokenizer(batch['premise'], batch['hypothesis'], **encode_options)

In [None]:
# Tokenize both splits in batches of 100 using 4 worker processes; the raw text
# columns are dropped once they have been encoded.
train_ds = train_ds.map(
    tokenize,
    batched=True,
    batch_size=100,
    remove_columns=['premise', 'hypothesis'],
    num_proc=4,
)
valid_ds = valid_ds.map(
    tokenize,
    batched=True,
    batch_size=100,
    remove_columns=['premise', 'hypothesis'],
    num_proc=4,
)

 

Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-d77a68b46e207138.arrow


 

Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-41fdafefd47b52f4.arrow


 

Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-9fa092c864ca09c2.arrow


 

Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-ad220127775c7c99.arrow


    

Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-e7966d9516bcacee.arrow
Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-97cb843dcd0cd74a.arrow
Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-3dce592b65c30dd5.arrow
Loading cached processed dataset at /workspace/.cache/snli/plain_text/1.0.0/bb1102591c6230bd78813e229d5dd4c7fbf4fc478cec28f298761eb69e5b537c/cache-914e0bf516502153.arrow


In [None]:
# `compute_loss` further down looks the targets up under the key "labels", so the
# dataset column is renamed accordingly.
# The rename is guarded so that re-running this cell (when the column has already
# been renamed) is a no-op instead of an error.
if 'label' in train_ds.column_names:
    train_ds = train_ds.rename_column('label', 'labels')
if 'label' in valid_ds.column_names:
    valid_ds = valid_ds.rename_column('label', 'labels')

## Tracking

In [None]:
# import wandb

# WANDB_NAME = f'{ds_name}-{model_name}-hf'
# GROUP = f'{ds_name}-{model_name}-hf-{lr:.0e}'
# NOTES = f'HF finetuning {model_name} with AdamW lr={lr:.0e}'
# CONFIG = {}
# TAGS =[model_name,ds_name,'adamw']

In [None]:
import wandb

# Run metadata handed to `wandb.init` in the cell below.
WANDB_NAME = f'{ds_name}-{model_name}-alum'
# NOTE(review): GROUP and NOTES use the same 'hf' wording as the commented-out
# baseline cell above, although this run is named and tagged 'alum' — confirm
# this is intentional (e.g. to keep both runs in the same W&B group).
GROUP = f'{ds_name}-{model_name}-hf-{lr:.0e}'
NOTES = f'HF finetuning {model_name} with AdamW lr={lr:.0e}'
# No extra config values are logged for this run.
CONFIG = {}
TAGS =[model_name,ds_name,'adamw','alum']

In [None]:
%env WANDB_LOG_MODEL = false
%env WANDB_WATCH = false

env: WANDB_LOG_MODEL=false
env: WANDB_WATCH=false


In [None]:
wandb.init(reinit=True, project="vat", entity="fastai_community",
           name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG);

[34m[1mwandb[0m: Currently logged in as: [33mfastai_community[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.27 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


In [None]:
training_args = TrainingArguments(
    output_dir='test',
    # optimisation / schedule
    learning_rate=lr,
    num_train_epochs=5,
    warmup_ratio=0.2,
    fp16=True,
    seed=8,
    # batching and data loading
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=val_bs,
    group_by_length=True,
    dataloader_num_workers=4,
    # evaluation, logging and checkpointing
    evaluation_strategy='epoch',
    logging_steps=200,
    report_to='wandb',
    save_strategy='epoch',
    save_total_limit=2,
)

In [None]:
metric = load_metric('accuracy')

In [None]:
import numpy as np

In [None]:
def compute_metric(eval_preds):
    """Score evaluation output with the module-level `metric`.

    Args:
        eval_preds: pair `(predictions, labels)`; `predictions` holds one row of
            per-class scores per example (classes along axis 1).

    Returns:
        The result of `metric.compute` on the predicted class ids vs. `labels`.
    """
    scores, references = eval_preds
    # Predicted class id = position of the largest score in each row.
    predicted_ids = np.argmax(scores, axis=1)
    return metric.compute(predictions=predicted_ids, references=references)

## Regular training

In [None]:
# model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

# trainer = Trainer(
#     model,
#     training_args,
#     train_dataset=train_ds,
#     eval_dataset=valid_ds,
#     tokenizer=tokenizer,
#     # data_collator=DataCollatorWithPadding(),
#     compute_metrics=compute_metric
# )

In [None]:
# out = trainer.train()

In [None]:
# wandb.finish()

## VATrainer

In [None]:
from core import compute_adversarial_loss

In [None]:
class VATrainer(Trainer):
    """`Trainer` subclass that adds a virtual adversarial (ALUM) loss term.

    Once `self.state.epoch` reaches `start_epoch`, every training step adds
    `alpha * compute_adversarial_loss(...)` to the regular loss.

    Args:
        *args, **kwargs: forwarded unchanged to `Trainer.__init__`.
        vat_kwargs: optional dict of adversarial-training options. Recognised keys:
            - 'alpha' (default 1.): weight of the adversarial loss term.
            - 'mask_special_tokens' (default False): stored, not used yet (see TODO).
            - 'one_token_type' (default False): stored, not used yet (see TODO).
            - 'start_epoch' (default 1): value `self.state.epoch` must reach before
              the adversarial term is added.
            All remaining keys are forwarded to `compute_adversarial_loss`.
            The dict passed in is not modified.
    """

    def __init__(self, *args, vat_kwargs=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Work on a private copy: popping from the caller's dict (or from a shared
        # `{}` default) would silently empty it for any later use.
        vat_kwargs = dict(vat_kwargs) if vat_kwargs else {}
        self.adv_alpha = vat_kwargs.pop('alpha', 1.)
        self.mask_special_tokens = vat_kwargs.pop('mask_special_tokens', False)
        self.one_token_type = vat_kwargs.pop('one_token_type', False)
        self.vat_start_epoch = vat_kwargs.pop('start_epoch', 1)
        # Whatever is left is passed straight to `compute_adversarial_loss`.
        self.vat_kwargs = vat_kwargs
        # Flipped to True the first time the adversarial term is applied, so the
        # start message is printed only once.
        self._do_vat = False

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the training/eval loss, plus the adversarial term when active.

        Args:
            model: the model being trained or evaluated.
            inputs: keyword arguments for the model's forward pass; "labels" is
                popped from it when a label smoother is configured.
            return_outputs: when True, return `(loss, outputs)` instead of `loss`.

        Returns:
            `loss`, or `(loss, outputs)` if `return_outputs` is True.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        # Hidden states are only requested in training mode, where the embedding
        # output may be needed for the adversarial term.
        # NOTE: explicitly adding kwargs here, verify no conflicts may happen.
        outputs = model(**inputs, output_hidden_states=model.training, return_dict=True)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            loss = outputs.loss

        # TODO add option to use vat_start_step
        # `state.epoch` can still be None if this is called before training has
        # set it; treat that as "adversarial training not started yet".
        epoch = self.state.epoch
        if model.training and epoch is not None and epoch >= self.vat_start_epoch:
            if not self._do_vat:
                print(f'Starting virtual adversarial training at epoch {self.state.epoch}')
                self._do_vat = True
            # ALUM training procedure: the first hidden state is detached and used
            # as the embedding input of the adversarial loss.
            embed = outputs.hidden_states[0].detach()
            # TODO add option to mask special tokens or token types here; until
            # then `mask_special_tokens` / `one_token_type` have no effect and
            # both masks are always None.
            adv_loss = compute_adversarial_loss(
                model, embed, outputs.logits,
                special_tokens_mask=None, token_type_mask=None,
                **self.vat_kwargs)
            # Out-of-place add: `loss` may be the very same tensor as
            # `outputs.loss`, which an in-place `+=` would also modify.
            loss = loss + self.adv_alpha * adv_loss
        return (loss, outputs) if return_outputs else loss

## Training

In [None]:
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=3)

Some weights of the model checkpoint at distilroberta-base were not used when initializing RobertaForSequenceClassification: ['lm_head.bias', 'lm_head.dense.weight', 'lm_head.dense.bias', 'lm_head.layer_norm.weight', 'lm_head.layer_norm.bias', 'lm_head.decoder.weight', 'roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at distilroberta-base and are newly initialized: ['classifier.dense.weight'

In [None]:
# Adversarial-training settings read by VATrainer: the extra loss term is added
# once the epoch counter reaches 2 and is weighted by 0.1.
vat_kwargs = dict(start_epoch=2, alpha=0.1)

trainer = VATrainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=valid_ds,
    tokenizer=tokenizer,
    # data_collator=DataCollatorWithPadding(),
    compute_metrics=compute_metric,
    vat_kwargs=vat_kwargs,
)

In [None]:
out = trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Runtime,Samples Per Second
1,0.4051,0.352074,0.869844,5.0545,1947.167
2,0.3417,0.291743,0.895448,5.0747,1939.406
3,0.6417,0.357602,0.883763,5.0424,1951.845
4,0.5982,0.332267,0.890673,5.113,1924.895


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

IOPub message rate exceed

## Validation on adversarial data

In [None]:
# adv_ds = load_dataset('anli', split='test_r1')
# adv_ds[0]