In [78]:
from argparse import ArgumentParser

import pytorch_lightning as pl
import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from thegreatknowledgeheist.data import get_dataloaders
from thegreatknowledgeheist.io import load_yaml
from thegreatknowledgeheist.models.bert import (
    AcronymIdentificationBert,
    AmazonPolarityBert,
    SwagBert,
)

In [79]:
import torch
from thegreatknowledgeheist.models.bert import BaseBert
from transformers import BertForMultipleChoice, BertConfig

class SwagBert(BaseBert):
    def __init__(self, config, pretrained: bool = True, bert_config: BertConfig = None):
        super().__init__(config)

        if pretrained:
            self.model = BertForMultipleChoice.from_pretrained(
                "bert-base-uncased", num_labels=4, output_hidden_states=True,
            )
        else:
            bert_config.output_hidden_states = True
            self.model = BertForMultipleChoice(bert_config)


    def calculate_accuracy(self, logits, labels):
        preds = torch.argmax(logits, dim=1)
        correct_preds = torch.sum(preds == labels)
        return correct_preds / len(preds)


In [80]:
GET_MODEL = {
    "amazon_polarity": AmazonPolarityBert,
    "acronym_identification": AcronymIdentificationBert,
    "swag": SwagBert,
}


def train_model(model, dataloaders, config):
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=f"{config['outputs_path']}/model_checkpoints",
        filename=config["task"] + "-model-{epoch:02d}-{val_accuracy:.2f}",
        save_top_k=1,
        mode="min",
    )

    trainer = Trainer(
        logger=WandbLogger(
            save_dir=f"{config['outputs_path']}/logs",
            project="Experiments",
            entity="mma",
        ),
#        gpus=config["gpus"],
        max_epochs=config["max_epochs"],
       callbacks=[checkpoint_callback],
        accelerator='cpu',
        devices=1
    )

    trainer.fit(model, dataloaders["train"], dataloaders["val"])
    wandb.finish()


In [81]:
config_path = '/pio/scratch/1/i308362/TheGreatKnowledgeHeist/scripts/train_config.yaml'


In [82]:
config = load_yaml(config_path)
dataloaders = get_dataloaders(
    dataset_name=config["task"],
    path_to_dataset=f"{config['dataset_path']}/{config['task']}",
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

model = GET_MODEL[config["task"]](config={"lr": config["lr"], "eps": config["eps"]})
train_model(model, dataloaders, config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultipleChoice: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight', 'cls.predictions.bias', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name  | Type                  | Params
------------------------------------------------
0 | model | BertForMultipleChoice | 109 M 
------------------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
437.932   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,1.0
val_accuracy,0.0
val_loss,1.39726


In [83]:
bert_config = BertConfig(
    num_hidden_layers=6,
    num_attention_heads=6,
)
small_model = SwagBert(config, pretrained=False, bert_config=bert_config)

In [84]:
from abc import ABC, abstractmethod
from torch.optim import Adam

class BaseKD(pl.LightningModule, ABC):
    def __init__(self, config, teacher_model: BaseBert, student_model: BaseBert):
        super().__init__()
        self.lr = config["lr"]
        self.eps = config["eps"]

        self.teacher = teacher_model
        self.teacher.freeze()
        self.student = student_model
        self.student.unfreeze()

    @abstractmethod
    def logits_loss(self, student_logits, teacher_logits):
        pass

    @abstractmethod
    def layers_loss(self, student_layers, teacher_layers):
        pass

    def configure_optimizers(self):
        optimizer = Adam(self.student.parameters(), lr=self.lr, eps=self.eps)
        return optimizer

    def forward(self, **inputs):
        student_outputs = self.student(**inputs)
        teacher_outputs = self.teacher(**inputs)
        return student_outputs, teacher_outputs

    def training_step(self, batch, batch_idx):
        student_outputs, teacher_outputs = self(**batch)
        accuracy = self.student.calculate_accuracy(student_outputs['logits'], batch["labels"])
        loss = (
                student_outputs['loss']
                + self.logits_loss(student_outputs['logits'], teacher_outputs['logits'])
                + self.layers_loss(student_outputs['hidden_states'], teacher_outputs['hidden_states'])
        )
        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        student_outputs, teacher_outputs = self(**batch)
        accuracy = self.student.calculate_accuracy(student_outputs['logits'], batch["labels"])
        loss = (
                student_outputs['loss']
                + self.logits_loss(student_outputs['logits'], teacher_outputs['logits'])
                + self.layers_loss(student_outputs['hidden_states'], teacher_outputs['hidden_states'])
        )
        self.log("val_loss", loss)
        self.log("val_accuracy", accuracy)
        return loss

In [85]:
from torch.nn.functional import softmax, kl_div

class LogitsKD(BaseKD):
    def __init__(self, config, teacher_model: BaseBert, student_model: BaseBert, temperature: float = 1):
        super().__init__(config, teacher_model, student_model)

        self._T = temperature

    def logits_loss(self, student_logits, teacher_logits):
        return kl_div(
            softmax(student_logits / self._T, dim=1),
            softmax(teacher_logits / self._T, dim=1),
            reduction='batchmean'
        )

    def layers_loss(self, student_layers, teacher_layers):
        return 0

In [86]:
distilled_model = LogitsKD(config, model, small_model)

In [87]:
train_model(distilled_model, dataloaders, config)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type     | Params
-------------------------------------
0 | teacher | SwagBert | 109 M 
1 | student | SwagBert | 67.0 M
-------------------------------------
67.0 M    Trainable params
109 M     Non-trainable params
176 M     Total params
705.755   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,1.0
val_accuracy,0.2
val_loss,-0.23985
