<a href="https://colab.research.google.com/github/TheBlueHawk/CS4NLP_Project2022/blob/main/mctaco_finetuning_alice.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Colab setup: install the notebook's third-party dependencies.
# NOTE(review): only pytorch-lightning is pinned; pinning the others (and using %pip)
# would make re-runs reproducible.
!pip install datasets
!pip install transformers[sentencepiece]
!pip install sentencepiece # necessary for DeBERTa-v3
!pip install pytorch-lightning==1.5.10
!pip install wandb
!pip install rich
!pip install torchmetrics
!pip install smart-pytorch 

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.3.2-py3-none-any.whl (362 kB)
[K     |████████████████████████████████| 362 kB 5.2 MB/s 
Collecting xxhash
  Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[K     |████████████████████████████████| 212 kB 66.3 MB/s 
Collecting huggingface-hub<1.0.0,>=0.1.0
  Downloading huggingface_hub-0.8.1-py3-none-any.whl (101 kB)
[K     |████████████████████████████████| 101 kB 11.1 MB/s 
Collecting fsspec[http]>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 73.5 MB/s 
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[K     |████████████████████████████████|

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting wandb
  Downloading wandb-0.12.19-py2.py3-none-any.whl (1.8 MB)
[K     |████████████████████████████████| 1.8 MB 5.3 MB/s 
Collecting sentry-sdk>=1.0.0
  Downloading sentry_sdk-1.6.0-py2.py3-none-any.whl (145 kB)
[K     |████████████████████████████████| 145 kB 71.1 MB/s 
[?25hCollecting GitPython>=1.0.0
  Downloading GitPython-3.1.27-py3-none-any.whl (181 kB)
[K     |████████████████████████████████| 181 kB 69.1 MB/s 
[?25hCollecting docker-pycreds>=0.4.0
  Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)
Collecting pathtools
  Downloading pathtools-0.1.2.tar.gz (11 kB)
Collecting setproctitle
  Downloading setproctitle-1.2.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29 kB)
Collecting shortuuid>=0.5.0
  Downloading shortuuid-1.0.9-py3-none-any.whl (9.4 kB)
Collecting gitdb<5,>=4.0.1
  Downloading gitdb-

In [None]:
# Log in to Weights & Biases; the WandbLogger created in the training cell logs there.
import wandb
wandb.login()

<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize


In [None]:
import pytorch_lightning as pl
from transformers import AutoTokenizer, AutoModelForSequenceClassification
# Seed the random number generators (via Lightning) so the model set-up below is reproducible.
pl.seed_everything(42)

# Experiment hyperparameters; the whole dict is also logged as run hyperparameters.
params = dict(
    # Other checkpoints tried: 'microsoft/deberta-v3-base', 'microsoft/mdeberta-v3-base',
    # 'bert-base-uncased'.
    # NOTE(review): ALICEClassificationModel reaches into `model.roberta`, so a
    # non-RoBERTa checkpoint needs changes there as well.
    pretrained_model_name='roberta-base',
    batch_size=32,
    sequence_length=128,
    max_epochs=20,
    alice_loss_weight=1.0,
)

# Tokenizer and sequence-classification model for the chosen checkpoint.
tokenizer = AutoTokenizer.from_pretrained(params['pretrained_model_name'])
architecture = AutoModelForSequenceClassification.from_pretrained(params['pretrained_model_name'])

In [None]:
import torch 
from torch.utils.data import Dataset
from datasets import load_dataset
from transformers import BertTokenizer

class MCTACODataset(Dataset):
    """MC-TACO examples encoded as (sentence + question, candidate answer) pairs.

    Each item is ``(input_ids, attention_mask, label)``:
    - ``input_ids`` and ``attention_mask`` are 1-D tensors of exactly ``sequence_length`` entries.
    - ``label`` is a scalar tensor built from the example's ``label`` field.

    Special tokens and the padding id come from ``tokenizer`` rather than being
    hard-coded. This way the layout matches the loaded checkpoint; for RoBERTa a pair
    is ``<s> A </s> </s> B </s>``.
    """

    def __init__(self, split: str, tokenizer, sequence_length: int):
        """Load ``split`` of the Hugging Face ``mc_taco`` dataset.

        Args:
            split: dataset split name, e.g. 'validation' or 'test'.
            tokenizer: Hugging Face-style tokenizer. It must provide ``tokenize``,
                ``convert_tokens_to_ids``, ``build_inputs_with_special_tokens``,
                ``num_special_tokens_to_add`` and ``pad_token_id``.
            sequence_length: fixed length each encoded example is truncated/padded to.
        """
        self.dataset = load_dataset("mc_taco")[split]
        self.tokenizer = tokenizer
        self.sequence_length = sequence_length

    def __len__(self):
        return len(self.dataset)

    def truncate_pair(self, tokens_a, tokens_b, max_length):
        """Truncate two token lists *in place* until their combined length is <= max_length.

        Each iteration pops the last token of the currently longer list (``tokens_b``
        on ties), so both sides keep as much context as possible.
        """
        while len(tokens_a) + len(tokens_b) > max_length:
            if len(tokens_a) > len(tokens_b):
                tokens_a.pop()
            else:
                tokens_b.pop()

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask, label)`` for example ``idx``."""
        item = self.dataset[idx]
        tokenizer = self.tokenizer
        sequence = tokenizer.tokenize(item['sentence'] + " " + item['question'])
        answer = tokenizer.tokenize(item['answer'])
        label = item['label']

        # Reserve room for every special token added around the pair (4 for RoBERTa).
        # An empty answer is handled by the same rule: only `sequence` gets truncated.
        budget = max(self.sequence_length - tokenizer.num_special_tokens_to_add(pair=True), 0)
        self.truncate_pair(sequence, answer, budget)

        # Let the tokenizer insert its own special tokens. RoBERTa separates a pair with
        # two separate '</s>' tokens; a single '</s></s>' string is not a vocabulary entry.
        input_ids = tokenizer.build_inputs_with_special_tokens(
            tokenizer.convert_tokens_to_ids(sequence),
            tokenizer.convert_tokens_to_ids(answer),
        )
        input_mask = [1] * len(input_ids)

        # Pad with the tokenizer's own pad id. A hard-coded 0 is not the pad id for
        # every tokenizer. Fall back to 0 only when the tokenizer defines none.
        pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
        num_pad = self.sequence_length - len(input_ids)
        input_ids = input_ids + [pad_id] * num_pad
        input_mask = input_mask + [0] * num_pad
        return torch.tensor(input_ids), torch.tensor(input_mask), torch.tensor(label)
        
# Sanity check: encode one validation example as (input_ids, attention_mask, label).
dataset = MCTACODataset(split='validation', tokenizer=tokenizer, sequence_length=params['sequence_length'])
print(dataset[10])

In [None]:
from torch.utils.data import DataLoader

class MCTACODatamodule(pl.LightningDataModule):
    """LightningDataModule serving tokenized MC-TACO batches.

    The 'validation' split is used for training and the 'test' split for validation.
    NOTE(review): presumably because MC-TACO provides no train split; confirm.
    """

    def __init__(
        self,
        tokenizer,
        batch_size: int,
        sequence_length: int
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.batch_size = batch_size
        self.sequence_length = sequence_length
        # Both datasets are created in setup().
        self.dataset_train = None
        self.dataset_valid = None

    def _load_split(self, split_name):
        """Build an MCTACODataset for ``split_name`` with this module's tokenizer and length."""
        return MCTACODataset(
            split=split_name,
            tokenizer=self.tokenizer,
            sequence_length=self.sequence_length
        )

    def _make_loader(self, dataset, shuffle: bool) -> DataLoader:
        """Wrap ``dataset`` in a DataLoader using the configured batch size."""
        return DataLoader(dataset=dataset, batch_size=self.batch_size, shuffle=shuffle)

    def setup(self, stage = None):
        # `stage` is ignored: both splits are (re)loaded on every call.
        self.dataset_train = self._load_split('validation')
        self.dataset_valid = self._load_split('test')

    def train_dataloader(self) -> DataLoader:
        return self._make_loader(self.dataset_train, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return self._make_loader(self.dataset_valid, shuffle=False)

# Build the datamodule and load both splits eagerly (setup() is also invoked by Lightning during fit).
datamodule = MCTACODatamodule(tokenizer, batch_size = params['batch_size'], sequence_length = params['sequence_length']) 
datamodule.setup()

In [None]:
from ast import Call
from typing import Union, Callable

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from itertools import count 

def exists(val):
    """Return True unless ``val`` is None (falsy values such as 0 or '' still exist)."""
    return not (val is None)


def default(val, d):
    """Return ``val``, or the fallback ``d`` when ``val`` is None."""
    return val if exists(val) else d

def inf_norm(x):
    """Infinity norm (max absolute value) of ``x`` over its last dimension, keeping that dimension."""
    return x.norm(p=float('inf'), dim=-1, keepdim=True)

In [None]:
class ALICELoss(nn.Module):
    """ALICE-style adversarial regularizer computed on input embeddings.

    Two independent random perturbations of ``embed`` are refined by ``num_steps``
    steps of projected gradient ascent:

    * ``noise_1`` ascends the *gold* loss between perturbed logits and one-hot labels;
    * ``noise_2`` ascends the *virtual* loss between perturbed logits and the clean
      logits ``state``.

    The method returns ``gold + alpha * virtual``, evaluated at the final
    perturbations, with its graph intact so it can be back-propagated into the model.

    Args:
        eval_fn: maps an embedding tensor (same shape as ``embed``) to logits.
        virtual_loss_fn: ``f(perturbed_logits, target)`` for the virtual loss during the
            ascent steps; also the fallback for every other loss function.
        labels: integer class labels of the batch (one-hot encoded in ``forward``).
        gold_loss_fn: gold loss during the ascent steps (default: ``virtual_loss_fn``).
        virtual_loss_last_fn: virtual loss at the final step (default: ``virtual_loss_fn``).
        gold_loss_last_fn: gold loss at the final step. It defaults to ``gold_loss_fn``,
            so a custom gold loss is applied consistently across all steps.
        norm_fn: norm used to project each noise step (default: ``inf_norm``).
        alpha: weight of the virtual loss in the returned sum.
        num_steps: number of ascent steps before the final evaluation.
        step_size: ascent step size.
        epsilon: added to the norm to avoid division by zero.
        noise_var: multiplier applied to the initial standard-normal noise.
    """

    def __init__(
        self,
        eval_fn: Callable,
        virtual_loss_fn: Callable,
        labels: Tensor,
        gold_loss_fn: Callable = None, 
        virtual_loss_last_fn: Callable = None,
        gold_loss_last_fn: Callable = None,
        norm_fn: Callable = inf_norm, 
        alpha: float = 1,
        num_steps: int = 1,
        step_size: float = 1e-3, 
        epsilon: float = 1e-6,
        noise_var: float = 1e-5
    ) -> None:
        super().__init__()
        self.eval_fn = eval_fn
        self.virtual_loss_fn = virtual_loss_fn
        self.labels = labels
        self.gold_loss_fn = default(gold_loss_fn, virtual_loss_fn)
        self.virtual_loss_last_fn = default(virtual_loss_last_fn, virtual_loss_fn)
        # Fall back to the (possibly custom) gold loss, not to the virtual loss.
        self.gold_loss_last_fn = default(gold_loss_last_fn, self.gold_loss_fn)
        self.norm_fn = norm_fn
        self.alpha = alpha
        self.num_steps = num_steps
        self.step_size = step_size
        self.epsilon = epsilon
        self.noise_var = noise_var

    def forward(self, embed: Tensor, state: Tensor) -> Tensor:
        """Return ``gold_loss + alpha * virtual_loss`` at the adversarial perturbations.

        Args:
            embed: clean input embeddings accepted by ``eval_fn``.
            state: clean logits for ``embed`` (last dim = number of classes). The returned
                virtual loss keeps its graph, while the ascent steps use a detached copy.
        """
        noise_1 = torch.randn_like(embed, requires_grad=True) * self.noise_var
        noise_2 = torch.randn_like(embed, requires_grad=True) * self.noise_var
        # Fix the one-hot width to the number of logits. If the width were inferred from
        # the labels, a batch containing only label 0 would yield (b, 1) targets for (b, n) logits.
        one_hot_labels = F.one_hot(self.labels, num_classes=state.size(-1)).float()
        # NOTE(review): if the gold loss is kl_loss, these one-hot targets are passed
        # through log_softmax, i.e. treated as logits rather than as probabilities.

        def ascend(loss: Tensor, noise: Tensor) -> Tensor:
            """One projected gradient-ascent step on ``noise``; returns the updated noise."""
            # ∂loss/∂noise
            noise_gradient, = torch.autograd.grad(loss, noise)
            # Move the noise along the gradient to change the output as much as possible,
            # then normalize the step by its norm.
            step = noise + self.step_size * noise_gradient
            step = step / (self.norm_fn(step) + self.epsilon)
            # Fresh leaf so the next step differentiates w.r.t. this noise only.
            return step.detach().requires_grad_()

        # Indefinite loop with counter; returns at i == num_steps.
        for i in count():
            # Evaluate the model on both perturbed copies of the embeddings.
            state_perturbed_1 = self.eval_fn(embed + noise_1)
            state_perturbed_2 = self.eval_fn(embed + noise_2)

            if i == self.num_steps:
                # Final step: keep the graph through `state` so gradients reach the model.
                gold_loss = self.gold_loss_last_fn(state_perturbed_1, one_hot_labels)
                virtual_loss = self.virtual_loss_last_fn(state_perturbed_2, state)
                return gold_loss + self.alpha * virtual_loss

            # Ascent step: compare against a detached `state`.
            gold_loss = self.gold_loss_fn(state_perturbed_1, one_hot_labels)
            virtual_loss = self.virtual_loss_fn(state_perturbed_2, state.detach())

            # Rebind the noises to the updated tensors. A helper that only rebinds its own
            # local variable would leave them unchanged.
            noise_1 = ascend(gold_loss, noise_1)
            noise_2 = ascend(virtual_loss, noise_2)

In [None]:
from smart_pytorch import SMARTLoss
import torch.nn as nn
import torch.nn.functional as F

def kl_loss(s_p, s):
    """Symmetric KL divergence between two batches of logits, summed over the batch.

    Args:
        s_p: perturbed logits, (b, n).
        s: reference logits, (b, n).

    Returns:
        KL(softmax(s) || softmax(s_p)) + KL(softmax(s_p) || softmax(s)), reduced with 'sum'.
    """
    log_probs_perturbed = F.log_softmax(s_p, dim=1)
    log_probs_reference = F.log_softmax(s, dim=1)
    # F.kl_div(input, target) measures KL(target || input); both arguments are log-probs here.
    kl_ref_to_perturbed = F.kl_div(log_probs_perturbed, log_probs_reference, reduction = 'sum', log_target=True)
    kl_perturbed_to_ref = F.kl_div(log_probs_reference, log_probs_perturbed, reduction = 'sum', log_target=True)
    return kl_ref_to_perturbed + kl_perturbed_to_ref

class ALICEClassificationModel(nn.Module):
    """Sequence classifier whose training loss adds an ALICE adversarial term.

    Shapes: b = batch_size, s = sequence_length, d = hidden_size, n = num_labels.

    ``model`` must expose ``model.roberta`` (callable encoder with an ``.embeddings``
    submodule) and ``model.classifier`` (head mapping encoder output to logits).
    """

    def __init__(self, model, weight):
        """
        Args:
            model: wrapped RoBERTa-style classification model (see class docstring).
            weight: multiplier of the ALICE loss added to the cross-entropy loss.
        """
        super().__init__()
        self.model = model
        self.weight = weight

    def forward(self, input_ids, attention_mask, labels):
        """Return ``(logits, loss)``.

        Args:
            input_ids: (b, s) token ids.
            attention_mask: (b, s) mask passed to the encoder.
            labels: (b,) integer class labels.

        The ALICE term is added only when the embeddings require gradients, so it is
        skipped e.g. under ``torch.no_grad()``.
        """
        embed = self.model.roberta.embeddings(input_ids)  # (b, s, d)

        def evaluate(embed):
            # Encoder on (possibly perturbed) embeddings, then the classification head.
            hidden = self.model.roberta(inputs_embeds=embed, attention_mask=attention_mask)[0]  # (b, s, d)
            # NOTE(review): the head receives the full (b, s, d) sequence, so this assumes
            # it pools internally (as RoBERTa's classification head does); confirm for other heads.
            return self.model.classifier(hidden)  # (b, n)

        state = evaluate(embed)
        # Use the logit width instead of a hard-coded 2 so num_labels != 2 also works.
        loss = F.cross_entropy(state.view(-1, state.size(-1)), labels.view(-1))
        if embed.requires_grad:
            alice_loss_fn = ALICELoss(eval_fn=evaluate, virtual_loss_fn=kl_loss, labels=labels)
            # Out-of-place add: leaves the cross-entropy output tensor untouched.
            loss = loss + self.weight * alice_loss_fn(embed, state)
        return state, loss
           
# Wrap the pretrained model so that its training loss includes the weighted ALICE term.
alice_architecture = ALICEClassificationModel(architecture, weight=params['alice_loss_weight'])

In [None]:
import torch.nn as nn 
from transformers import Adafactor
from torchmetrics import MetricCollection, Accuracy, F1Score

class TextClassificationModel(pl.LightningModule):
    """LightningModule training a model that returns ``(logits, loss)``.

    ``model`` is called as ``model(input_ids=..., attention_mask=..., labels=...)``.
    Predicted labels are the argmax over dim 1 of the returned logits. Accuracy and
    F1 are tracked separately for training (``train_`` prefix) and validation
    (``val_`` prefix).
    """

    def __init__(
        self,
        model: nn.Module
    ):
        super().__init__()
        self.model = model
        metrics = MetricCollection([Accuracy(), F1Score()])
        self.train_metrics = metrics.clone(prefix='train_')
        self.valid_metrics = metrics.clone(prefix='val_')

    def configure_optimizers(self):
        # No explicit learning rate is passed: relies on Adafactor's default
        # step-size schedule together with warmup_init.
        optimizer = Adafactor(self.model.parameters(), warmup_init=True)
        return optimizer

    def _shared_step(self, batch, metrics, loss_name: str):
        """Run the model on ``batch``, update ``metrics`` and log loss/metrics; return the loss."""
        input_ids, attention_masks, labels = batch
        outputs, loss = self.model(input_ids=input_ids, attention_mask=attention_masks, labels=labels)
        labels_pred = torch.argmax(outputs, dim=1)
        # torchmetrics expects (preds, target); the order matters for non-symmetric metrics.
        metric_values = metrics(labels_pred, labels)
        self.log(loss_name, loss, on_step=True)
        self.log_dict(metric_values, on_step=True, on_epoch=True)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, self.train_metrics, "train_loss")

    def validation_step(self, batch, batch_idx):
        return self._shared_step(batch, self.valid_metrics, "valid_loss")

# Lightning wrapper (optimizer, metrics, logging) around the ALICE-regularised classifier.
model = TextClassificationModel(alice_architecture)

In [None]:
# Log to the 'cs4nlp' W&B project under the 'nextmachina' entity.
logger = pl.loggers.wandb.WandbLogger(project = 'cs4nlp', entity='nextmachina')
# Rich console progress bar and model summary.
cb_progress_bar = pl.callbacks.RichProgressBar()
cb_model_summary = pl.callbacks.RichModelSummary()
# Train on all visible GPUs (gpus=-1). NOTE(review): presumably needs a GPU runtime; confirm on CPU-only.
trainer = pl.Trainer(logger=logger, callbacks=[cb_progress_bar, cb_model_summary], max_epochs=params['max_epochs'], gpus=-1)
# Record the hyperparameter dict with the run.
trainer.logger.log_hyperparams(params)
trainer.fit(model=model, datamodule=datamodule)
# Close the W&B run.
wandb.finish() 

In [None]:
# Scratch: a random batch of 32 labels in {0, 1}, flattened to shape (32,).
target = torch.randint(high=2, size=(1,32)).view(-1)
target

In [None]:
# Scratch: one-hot encode the labels above.
# NOTE(review): without num_classes, F.one_hot infers the width from the largest label
# (per PyTorch docs), so an all-zero batch would give width 1 instead of 2.
F.one_hot(target)