In [2]:
from argparse import ArgumentParser

import pytorch_lightning as pl
import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from thegreatknowledgeheist.data import get_dataloaders
from thegreatknowledgeheist.io import load_yaml
from thegreatknowledgeheist.models.bert import (
    AcronymIdentificationBert,
    AmazonPolarityBert,
    SwagBert,
)

In [3]:
# Registry: maps the `task` value from the training YAML to the model class
# trained for that task.
GET_MODEL = dict(
    amazon_polarity=AmazonPolarityBert,
    acronym_identification=AcronymIdentificationBert,
    swag=SwagBert,
)


def train_model(model, dataloaders, config):
    """Fit ``model`` on the train/val dataloaders with a Lightning ``Trainer``.

    Args:
        model: the ``pl.LightningModule`` to train.
        dataloaders: mapping with ``"train"`` and ``"val"`` dataloaders.
        config: run settings. Always read: ``"max_epochs"``. Optional keys, whose
            defaults reproduce the previous hard-coded behaviour:
            ``"accelerator"`` (``"cpu"``), ``"devices"`` (``1``),
            ``"save_checkpoints"`` (``False``) and ``"use_wandb"`` (``False``).
            ``"outputs_path"`` and ``"task"`` are read only when checkpointing
            or W&B logging is enabled.
    """
    use_wandb = config.get("use_wandb", False)

    callbacks = []
    if config.get("save_checkpoints", False):
        # NOTE: the model must log "val_loss", otherwise Lightning reports the
        # monitored key as missing.
        callbacks.append(
            pl.callbacks.ModelCheckpoint(
                monitor="val_loss",
                dirpath=f"{config['outputs_path']}/model_checkpoints",
                filename=config["task"] + "-model-{epoch:02d}-{val_accuracy:.2f}",
                save_top_k=1,
                mode="min",
            )
        )

    # `True` keeps Lightning's default logger, same as passing no logger at all.
    logger = True
    if use_wandb:
        logger = WandbLogger(
            save_dir=f"{config['outputs_path']}/logs",
            project="Experiments",
            entity="mma",
        )

    trainer = Trainer(
        logger=logger,
        max_epochs=config["max_epochs"],
        callbacks=callbacks,
        accelerator=config.get("accelerator", "cpu"),
        devices=config.get("devices", 1),
    )

    try:
        trainer.fit(model, dataloaders["train"], dataloaders["val"])
    finally:
        # Close the W&B run even if training fails, so the next run starts clean.
        if use_wandb:
            wandb.finish()


In [4]:
import os

# Path to the training YAML. Set TRAIN_CONFIG_PATH to point elsewhere instead of
# editing this machine-specific absolute default.
config_path = os.environ.get(
    "TRAIN_CONFIG_PATH",
    "/pio/scratch/1/i308362/TheGreatKnowledgeHeist/scripts/train_config.yaml",
)


In [5]:
config = load_yaml(config_path)

# Fail fast, before building dataloaders, if the YAML names a task with no model.
if config["task"] not in GET_MODEL:
    raise KeyError(
        f"Unknown task {config['task']!r}; expected one of {sorted(GET_MODEL)}"
    )

dataloaders = get_dataloaders(
    dataset_name=config["task"],
    path_to_dataset=f"{config['dataset_path']}/{config['task']}",
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

# `model` is trained here and later passed to KDLogits as the teacher.
model = GET_MODEL[config["task"]](config={"lr": config["lr"], "eps": config["eps"]})
train_model(model, dataloaders, config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultipleChoice: ['cls.seq_relationship.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [6]:
import torch
from thegreatknowledgeheist.models.bert import BaseBert
from transformers import BertForMultipleChoice, BertConfig

class SwagBert(BaseBert):
    """BERT multiple-choice model for SWAG that can also be built from a custom config.

    NOTE(review): this redefinition shadows the ``SwagBert`` imported from
    ``thegreatknowledgeheist.models.bert``; ``GET_MODEL["swag"]`` was built
    earlier and still refers to the imported class.
    """

    def __init__(self, config, pretrained: bool = True, bert_config: "BertConfig | None" = None):
        """Create the wrapped ``BertForMultipleChoice``.

        Args:
            config: hyperparameter mapping forwarded to ``BaseBert``.
            pretrained: if True, load ``bert-base-uncased`` weights; otherwise
                build a randomly initialised model from ``bert_config``.
            bert_config: architecture for the non-pretrained model. When omitted,
                a default ``BertConfig()`` is used instead of crashing on ``None``.
        """
        super().__init__(config)

        if pretrained:
            # NOTE(review): the multiple-choice head scores each choice with one
            # output, so num_labels=4 likely has no effect; kept for parity.
            self.model = BertForMultipleChoice.from_pretrained(
                "bert-base-uncased", num_labels=4
            )
        else:
            if bert_config is None:
                bert_config = BertConfig()
            self.model = BertForMultipleChoice(bert_config)

    def calculate_accuracy(self, logits, labels):
        """Return the fraction of rows whose argmax over dim 1 equals ``labels``.

        Returns a 0-dim float tensor (NaN for an empty batch).
        """
        preds = torch.argmax(logits, dim=1)
        return (preds == labels).float().mean()


In [7]:
# Student architecture: a randomly initialised BERT with 6 encoder layers and
# 6 attention heads; every other dimension keeps the BertConfig default.
bert_config = BertConfig(num_attention_heads=6, num_hidden_layers=6)
small_model = SwagBert(
    config,
    pretrained=False,
    bert_config=bert_config,
)

In [8]:
from torch.optim import Adam
from torch.nn.functional import softmax, kl_div

class KDLogits(pl.LightningModule):
    """Logit-based knowledge distillation of a frozen teacher into a student.

    The loss is ``KL(teacher || student)`` between temperature-softened class
    distributions, scaled by ``T**2`` so gradient magnitudes stay comparable
    across temperatures (identical to the unscaled loss at the default T=1).
    Only the student is optimised.
    """

    def __init__(self, config, teacher_model: BaseBert, student_model: BaseBert, temperature: float = 1):
        """
        Args:
            config: mapping providing ``"lr"`` and ``"eps"`` for Adam.
            teacher_model: trained model whose logits are the soft targets; frozen here.
            student_model: model being trained; must provide ``calculate_accuracy``.
            temperature: softmax temperature T applied to both sets of logits.
        """
        super().__init__()
        self.lr = config["lr"]
        self.eps = config["eps"]
        self._T = temperature
        self._loss = kl_div

        self.teacher = teacher_model
        self.teacher.freeze()
        self.student = student_model
        self.student.unfreeze()

    def train(self, mode: bool = True):
        """Switch train/eval mode while always keeping the teacher in eval mode.

        Lightning calls ``train()`` on the whole module before training, which
        would otherwise re-enable the teacher's dropout after ``freeze()``.
        """
        super().train(mode)
        self.teacher.eval()
        return self

    def configure_optimizers(self):
        return Adam(self.student.parameters(), lr=self.lr, eps=self.eps)

    def _distillation_loss(self, student_logits, teacher_logits):
        """Temperature-scaled KL divergence; ``kl_div`` expects log-probabilities as input."""
        student_log_probs = torch.nn.functional.log_softmax(student_logits / self._T, dim=1)
        teacher_probs = softmax(teacher_logits / self._T, dim=1)
        return self._loss(student_log_probs, teacher_probs, reduction="batchmean") * self._T ** 2

    def _shared_step(self, batch, stage):
        """Compute and log loss/accuracy under ``{stage}_loss`` / ``{stage}_accuracy``."""
        # assumes index 1 of the model output is the logits (index 0 being the
        # loss when labels are passed) — TODO confirm against BaseBert.forward
        with torch.no_grad():
            teacher_logits = self.teacher(**batch)[1]
        student_logits = self.student(**batch)[1]
        accuracy = self.student.calculate_accuracy(student_logits, batch["labels"])
        loss = self._distillation_loss(student_logits, teacher_logits)
        self.log(f"{stage}_loss", loss)
        self.log(f"{stage}_accuracy", accuracy)
        return loss

    def training_step(self, batch, batch_idx):
        return self._shared_step(batch, "train")

    def validation_step(self, batch, batch_idx):
        # Logged as val_* (previously mislabelled train_*), which also lets a
        # ModelCheckpoint monitoring "val_loss" find its metric.
        return self._shared_step(batch, "val")

In [9]:
# Distil the trained teacher (`model`) into the small student. The softmax
# temperature can be set via an optional "temperature" config key (default 1).
distilled_model = KDLogits(
    config,
    teacher_model=model,
    student_model=small_model,
    temperature=config.get("temperature", 1),
)

In [10]:
# Train the student with the same trainer settings used for the teacher.
train_model(model=distilled_model, dataloaders=dataloaders, config=config)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type     | Params
-------------------------------------
0 | teacher | SwagBert | 109 M 
1 | student | SwagBert | 67.0 M
-------------------------------------
67.0 M    Trainable params
109 M     Non-trainable params
176 M     Total params
705.755   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]