In [None]:
# Show the GPU model, driver/CUDA versions and current memory usage before training.
!nvidia-smi

Tue Apr 27 14:20:26 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.36.06    Driver Version: 450.36.06    CUDA Version: 11.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro P5000        On   | 00000000:00:05.0 Off |                  Off |
| 26%   27C    P5     7W / 180W |    865MiB / 16278MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
import sys
if 'google.colab' in sys.modules:
    !pip install -Uqq fastcore sentencepiece
    !pip install -Uqq --no-deps fastai
    !pip install -Uqq transformers datasets wandb 

In [None]:
from transformers import *
from datasets import load_dataset, concatenate_datasets, load_metric

## Setup

In [None]:
# Hugging Face checkpoint used for both the tokenizer and the masked-LM model.
model_name = 'distilroberta-base'
# data
# NOTE(review): max_length is not referenced by the cells below as written; the chunk length
# actually fed to the model is `block_size` in the preprocessing section.
max_length = 128
bs = 16  # per-device training batch size
val_bs = bs*4  # per-device evaluation batch size
# training
lr = 3e-5  # learning rate passed to TrainingArguments

## Data preprocessing

In [None]:
ds_name = 'imdb'  # dataset name passed to load_dataset and used in the W&B run names

In [None]:
# Load (or reuse the cached copy of) the dataset; it has 'train', 'test' and 'unsupervised' splits.
dataset = load_dataset(ds_name)

Reusing dataset imdb (/root/.cache/huggingface/datasets/imdb/plain_text/1.0.0/4ea52f2e58a08dbc12c2bd52d0d92b30b88c00230b4522801b3636782f625c5b)


In [None]:
# Optional: subsample for a quick debugging run.
# NOTE(review): `train_ds` / `valid_ds` are not defined anywhere in this notebook, so these lines
# would raise NameError if uncommented; subsample the splits of `dataset` instead.
# train_ds = train_ds.select(range(100))
# valid_ds = valid_ds.select(range(100))

In [None]:
# Columns available in each split.
dataset.column_names

{'train': ['label', 'text'],
 'test': ['label', 'text'],
 'unsupervised': ['label', 'text']}

In [None]:
# Inspect one example from the unlabeled split (its label is -1 in the output below).
dataset['unsupervised'][2]

{'label': -1,
 'text': "Everybody has seen 'Back To The Future,' right? Whether you LIKE that movie or not, you've seen an example of how to make a time-travel movie work. A torn-up poster for 'Back To The Future' shows up in this movie, representing, perhaps unintentionally, what the makers of 'Tangents' (aka 'Time Chasers') did to the time-travel formula. Then again, the movie claims to have been made in 1994, but it looks -- and sounds -- like it was produced at least ten years earlier, so maybe they achieved time-travel after all.<br /><br />Start with an intensely unappealing leading man. I mean, what woman doesn't love gangly, whiny, lantern-jawed, butt-chinned, mullet-men with Coke-bottle glasses? Oh, none of you? Prepare to tough it out, ladies, cuz that's what this movie gives you.<br /><br />Second, add a leading lady who -- while not entirely unattractive -- represents many '80s clichés: big hair, too much makeup, two different plaids, shoulder pads, acid-washed mom-jeans, e

In [None]:
# Tokenizer matching the pretrained checkpoint.
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [None]:
def tokenize(batch):
    """Tokenize a batch of raw reviews (used with ``Dataset.map(batched=True)``).

    No truncation is applied, so documents longer than the model limit produce a
    length warning; ``group_texts`` later re-chunks the token stream into fixed-size
    blocks. The output also carries attention and special-token masks.
    """
    texts = batch['text']
    return tokenizer(
        texts,
        return_special_tokens_mask=True,
        return_attention_mask=True,
    )

In [None]:
# Tokenize every split with 4 worker processes, dropping the raw 'label' / 'text' columns.
dataset = dataset.map(
    tokenize,
    batched=True,
    batch_size=100,
    num_proc=4,
    remove_columns=dataset['train'].column_names,
)

Token indices sequence length is longer than the specified maximum sequence length for this model (535 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (897 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (950 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (803 > 512). Running this sequence through the model will result in indexing errors








Token indices sequence length is longer than the specified maximum sequence length for this model (1289 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (540 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (699 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (525 > 512). Running this sequence through the model will result in indexing errors








Token indices sequence length is longer than the specified maximum sequence length for this model (523 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (723 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (1068 > 512). Running this sequence through the model will result in indexing errors
Token indices sequence length is longer than the specified maximum sequence length for this model (565 > 512). Running this sequence through the model will result in indexing errors








In [None]:
# Tokens per chunk produced by group_texts; reuses the config value instead of repeating 128.
block_size = max_length

In [None]:
def group_texts(examples, chunk_size=None):
    """Concatenate tokenized texts and re-split them into fixed-size blocks for MLM.

    Intended for ``Dataset.map(..., batched=True)``. Every column is concatenated
    end to end and cut into consecutive chunks of ``chunk_size`` tokens. The trailing
    remainder shorter than one chunk is dropped; padding could be used instead if the
    model supported it.

    Args:
        examples: dict mapping column name (must include ``input_ids``) to a list of
            per-example token lists.
        chunk_size: tokens per block; defaults to the notebook-level ``block_size``.

    Returns:
        dict with the same keys, each holding a list of ``chunk_size``-long lists,
        plus a ``labels`` column holding independent copies of the ``input_ids`` chunks.
    """
    from itertools import chain

    size = block_size if chunk_size is None else chunk_size
    # chain.from_iterable is linear; sum(lists, []) re-copies the accumulator on every step.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    # All columns have the same length, so input_ids determines how many full chunks fit.
    total_length = (len(concatenated["input_ids"]) // size) * size
    result = {
        k: [t[i : i + size] for i in range(0, total_length, size)]
        for k, t in concatenated.items()
    }
    # Copy each chunk so labels never alias the input_ids lists.
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result

In [None]:
# Re-chunk every tokenized split into block_size-token blocks (batched, 4 worker processes).
lm_dataset = dataset.map(group_texts, batched=True, batch_size=1000, num_proc=4)















In [None]:
# Pool all splits into one set of chunks for MLM; the sentiment labels were removed during
# tokenization, so the test reviews contribute raw text only.
lm_dataset = concatenate_datasets([lm_dataset[name] for name in ('train', 'unsupervised', 'test')])

## Tracking

In [None]:
# Alternative W&B metadata for the plain HF (non-adversarial) run; kept for reference, not executed.
# import wandb

# WANDB_NAME = f'{ds_name}-{model_name}-hf'
# GROUP = f'{ds_name}-{model_name}-hf-{lr:.0e}'
# NOTES = f'HF finetuning {model_name} with AdamW lr={lr:.0e}'
# CONFIG = {}
# TAGS =[model_name,ds_name,'adamw']

In [None]:
import wandb

# W&B metadata for the ALUM / virtual adversarial run (name and tags already say 'alum').
WANDB_NAME = f'{ds_name}-{model_name}-alum'
GROUP = f'{ds_name}-{model_name}-hf-{lr:.0e}'
# NOTES was copied from the baseline config; it now mentions the adversarial loss as well.
NOTES = f'HF finetuning {model_name} with AdamW + ALUM adversarial loss lr={lr:.0e}'
CONFIG = {}
TAGS = [model_name, ds_name, 'adamw', 'alum']

In [None]:
# Environment flags read by the transformers W&B integration: presumably disables model
# checkpoint upload and gradient/parameter watching — confirm against the integration docs.
%env WANDB_LOG_MODEL = false
%env WANDB_WATCH = false

env: WANDB_LOG_MODEL=false
env: WANDB_WATCH=false


In [None]:
# Start a W&B run in fastai_community/vat using the metadata defined above.
# NOTE(review): training_args uses report_to='none', so the Trainer presumably will not log
# to this run; switch it to 'wandb' if metrics should appear here.
wandb.init(reinit=True, project="vat", entity="fastai_community",
           name=WANDB_NAME, group=GROUP, notes=NOTES, tags=TAGS, config=CONFIG);

[34m[1mwandb[0m: Currently logged in as: [33mfastai_community[0m (use `wandb login --relogin` to force relogin)
[34m[1mwandb[0m: wandb version 0.10.27 is available!  To upgrade, please run:
[34m[1mwandb[0m:  $ pip install wandb --upgrade


## Training arguments

In [None]:
training_args = TrainingArguments(
    output_dir='test',  # f'{ds_name}-{model_name}-2',
    seed=8,
    # optimisation
    learning_rate=lr,
    num_train_epochs=2,
    lr_scheduler_type='cosine',
    warmup_ratio=0.2,
    fp16=True,
    # batching / data loading
    per_device_train_batch_size=bs,
    per_device_eval_batch_size=val_bs,
    # NOTE(review): every chunk from group_texts has the same length, so length grouping
    # presumably only adds a pass over the dataset to compute lengths.
    group_by_length=True,
    dataloader_num_workers=4,
    # Keep columns that the model's forward() does not accept (e.g. special_tokens_mask).
    remove_unused_columns=False,
    # evaluation, logging, checkpoints
    evaluation_strategy='epoch',
    logging_steps=200,
    report_to='none',  # 'wandb',
    save_strategy='epoch',
    save_total_limit=2,
)

## Regular training (baseline, no adversarial loss)

In [None]:
from transformers import AutoModelForMaskedLM

In [None]:
# Number of block_size-token chunks available for the baseline train/valid split.
N = len(lm_dataset)

In [None]:
import random

In [None]:
# Shuffle chunk indices with a local RNG seeded from training_args, so the train/valid split
# is reproducible after a kernel restart (the global `random` state is unseeded at this point).
idx = list(range(N))
random.Random(training_args.seed).shuffle(idx)

In [None]:
# 90/10 train/validation split over the shuffled indices.
split = int(N * 0.9)
train_idx, valid_idx = idx[:split], idx[split:]

In [None]:
# Baseline: masked-LM fine-tuning of a freshly loaded checkpoint with the standard Trainer.
model = AutoModelForMaskedLM.from_pretrained(model_name)

trainer = Trainer(
    model=model,
    args=training_args,
    data_collator=DataCollatorForLanguageModeling(tokenizer),
    train_dataset=lm_dataset.select(train_idx),
    eval_dataset=lm_dataset.select(valid_idx),
    # tokenizer=tokenizer,
)

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=331070498.0, style=ProgressStyle(descri…




In [None]:
# Run the baseline training loop; `out` keeps the value returned by trainer.train().
out = trainer.train()



Epoch,Training Loss,Validation Loss,Runtime,Samples Per Second
1,2.2617,2.102613,117.4889,199.704


KeyboardInterrupt: 

In [None]:
# wandb.finish()

## VATrainer: Trainer with a virtual adversarial (ALUM) loss term

In [None]:
from core import compute_adversarial_loss

In [None]:
class VATrainer(Trainer):
    """`Trainer` that adds a virtual adversarial (ALUM-style) loss term during training.

    Extra keyword argument:
        vat_kwargs (dict, optional): settings consumed by this class. The dict is copied,
            so the caller's dict is left untouched. Recognised keys:
              - 'alpha' (default 1.): weight of the adversarial loss term.
              - 'mask_special_tokens' (default False): pass a mask derived from the
                batch's `special_tokens_mask` to `compute_adversarial_loss`.
              - 'one_token_type' (default False): currently has no effect; no token-type
                mask is forwarded (see `compute_loss`).
              - 'start_epoch' (default 1): epoch from which the adversarial term is added.
            Any remaining keys are forwarded to `compute_adversarial_loss`.
    """

    def __init__(self, *args, vat_kwargs=None, **kwargs):
        super().__init__(*args, **kwargs)
        # Copy so the pops below do not mutate the caller's dict (or a shared default).
        vat_kwargs = dict(vat_kwargs or {})
        self.adv_alpha = vat_kwargs.pop('alpha', 1.)
        self.mask_special_tokens = vat_kwargs.pop('mask_special_tokens', False)
        self.one_token_type = vat_kwargs.pop('one_token_type', False)
        self.vat_start_epoch = vat_kwargs.pop('start_epoch', 1)
        self.vat_kwargs = vat_kwargs
        self._do_vat = False

    def compute_loss(self, model, inputs, return_outputs=False):
        """Compute the model loss plus, in training mode, the weighted adversarial loss.

        Args:
            model: the model being trained or evaluated.
            inputs: dict of batch tensors. `labels` (only when label smoothing is on),
                `special_tokens_mask` and `token_type_ids` are removed before the forward call.
            return_outputs: if True, return `(loss, outputs)` instead of `loss`.

        Returns:
            The loss, or `(loss, outputs)` when `return_outputs` is True.
        """
        if self.label_smoother is not None and "labels" in inputs:
            labels = inputs.pop("labels")
        else:
            labels = None
        # Masks are removed from the forward inputs; special_tokens_mask may be used below.
        # NOTE(review): DataCollatorForLanguageModeling appears to pop `special_tokens_mask`
        # from the batch itself, in which case it never reaches this method — confirm.
        special_tokens_mask = inputs.pop('special_tokens_mask', None)
        inputs.pop('token_type_ids', None)
        # Hidden states are only requested in training mode, where the embedding output is needed.
        outputs = model(**inputs, output_hidden_states=model.training, return_dict=True)
        # Save past state if it exists
        # TODO: this needs to be fixed and made cleaner later.
        if self.args.past_index >= 0:
            self._past = outputs[self.args.past_index]

        if labels is not None:
            loss = self.label_smoother(outputs, labels)
        else:
            loss = outputs.loss

        # TODO add option to use vat_start_step
        # state.epoch may still be None before the training loop sets it.
        current_epoch = self.state.epoch or 0
        if model.training and current_epoch >= self.vat_start_epoch:
            if not self._do_vat:
                print(f'Starting virtual adversarial training at epoch {self.state.epoch}')
                self._do_vat = True
            # ALUM training procedure: perturbations are applied around the (detached)
            # first hidden state, i.e. the embedding output.
            embed = outputs.hidden_states[0].detach()
            adv_special_mask = None
            if self.mask_special_tokens:
                if special_tokens_mask is not None:
                    # 1 for regular tokens, 0 for special ones; the trailing axis presumably
                    # broadcasts over the hidden dimension — confirm in compute_adversarial_loss.
                    adv_special_mask = (1 - special_tokens_mask).unsqueeze(-1)
                else:
                    print('`special_tokens_mask` not found in the inputs')
                    self.mask_special_tokens = False
            # TODO: token-type masking is not implemented, so `one_token_type` has no effect.
            token_type_mask = None

            adv_loss = compute_adversarial_loss(model, embed, outputs.logits,
                special_tokens_mask=adv_special_mask, token_type_mask=token_type_mask,
                **self.vat_kwargs)
            # Out-of-place add so `outputs.loss` is not modified behind the caller's back.
            loss = loss + self.adv_alpha * adv_loss
        return (loss, outputs) if return_outputs else loss

## Training with VATrainer

In [None]:
# Fresh pretrained weights for the VAT run (the baseline model trained above is not reused).
model = AutoModelForMaskedLM.from_pretrained(model_name)

In [None]:
# Number of chunks in the pooled dataset (same expression as in the baseline section).
N = len(lm_dataset)
N

234624

In [None]:
# Contiguous 95/5 split: the validation chunks are the last 5% of the pooled dataset.
# NOTE(review): this differs from the shuffled 90/10 split used for the baseline run, so the
# validation losses of the two runs are not directly comparable.
cut = int(N * 0.95)
train_idx = list(range(cut))
valid_idx = list(range(cut, N))

In [None]:
# MLM collator: 15% of tokens are selected for prediction.
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.15)

# VATrainer settings: adversarial term from epoch 0 with weight 5, no special-token masking.
vat_kwargs = dict(
    start_epoch=0,
    alpha=5,
    mask_special_tokens=False,
)

trainer = VATrainer(
    model=model,
    args=training_args,
    data_collator=data_collator,
    train_dataset=lm_dataset.select(train_idx),
    eval_dataset=lm_dataset.select(valid_idx),
    vat_kwargs=vat_kwargs,
)

In [None]:
# Train with the VAT-augmented loss.
out = trainer.train()

In [None]:
# Close the W&B run started by wandb.init earlier in the notebook.
wandb.finish()

In [None]:
import torch
# NOTE(review): while `model` and `trainer` are still referenced, empty_cache() can only release
# cached blocks that are no longer in use; uncomment the `del` to free most GPU memory.
# del model, trainer
torch.cuda.empty_cache()

In [None]:
# Check GPU memory usage after cleanup.
!nvidia-smi