In [1]:
from typing import Sequence, Optional

import pytorch_lightning as pl
import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

from thegreatknowledgeheist.data import get_dataloaders
from thegreatknowledgeheist.io import load_yaml
from thegreatknowledgeheist.models import BertFactory
from thegreatknowledgeheist.models.bert import BaseBert
from transformers import BertConfig

In [2]:
def train_model(model, dataloaders, config):
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=f"{config['outputs_path']}/model_checkpoints",
        filename=config["task"] + "-model-{epoch:02d}-{val_accuracy:.2f}",
        save_top_k=1,
        mode="min",
    )

    trainer = Trainer(
        logger=WandbLogger(
            save_dir=f"{config['outputs_path']}/logs",
            project="experiments",
            entity="mma",
        ),
        gpus=config["gpus"],
        max_epochs=config["max_epochs"],
        callbacks=[checkpoint_callback],
    )

    trainer.fit(model, dataloaders["train"], dataloaders["val"])
    wandb.finish()


In [3]:
config_path = '/pio/scratch/1/i308362/TheGreatKnowledgeHeist/configs/train_config.yaml'


In [4]:
config = load_yaml(config_path)
dataloaders = get_dataloaders(
    dataset_name=config["task"],
    path_to_dataset=f"{config['dataset_path']}/{config['task']}",
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

factory = BertFactory()

bert_config = BertConfig(
    output_hidden_states=True
)
model = factory.create_model(config["task"], config=config, bert_config=bert_config)
train_model(model, dataloaders, config)


Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultipleChoice: ['cls.predictions.decoder.weight', 'cls.seq_relationship.weight', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly

This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madrianurbanski[0m ([33mmma[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name  | Type                  | Params
------------------------------------------------
0 | model | BertForMultipleChoice | 109 M 
1 | f1    | F1Score               | 0     
------------------------------------------------
14.8 M    Trainable params
94.7 M    Non-trainable params
109 M     Total params
437.932   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_f1,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.4
val_f1,0.26667
val_loss,1.39359


In [5]:
bert_config = BertConfig(
    num_hidden_layers=6,
    num_attention_heads=6,
    output_hidden_states=True,
)
small_model = factory.create_model(config["task"], config=config, pretrained=False, bert_config=bert_config)

This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

In [6]:
from abc import ABC, abstractmethod
from torch.optim import Adam

class BaseKD(pl.LightningModule, ABC):
    def __init__(self, config, teacher_model: BaseBert, student_model: BaseBert):
        super().__init__()
        self.lr = config["lr"]
        self.eps = config["eps"]

        self.teacher = teacher_model
        self.teacher.freeze()
        self.student = student_model
        self.student.unfreeze()

    @abstractmethod
    def logits_loss(self, student_logits, teacher_logits):
        pass

    @abstractmethod
    def layers_loss(self, student_layers, teacher_layers):
        pass

    def configure_optimizers(self):
        optimizer = Adam(self.student.parameters(), lr=self.lr, eps=self.eps)
        return optimizer

    def forward(self, **inputs):
        student_outputs = self.student(**inputs)
        teacher_outputs = self.teacher(**inputs)
        return student_outputs, teacher_outputs

    def training_step(self, batch, batch_idx):
        student_outputs, teacher_outputs = self(**batch)
        accuracy = self.student.calculate_accuracy(student_outputs['logits'], batch["labels"])
        loss = (
                student_outputs['loss']
                + self.logits_loss(student_outputs['logits'], teacher_outputs['logits'])
                + self.layers_loss(student_outputs['hidden_states'], teacher_outputs['hidden_states'])
        )
        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        student_outputs, teacher_outputs = self(**batch)
        accuracy = self.student.calculate_accuracy(student_outputs['logits'], batch["labels"])
        loss = (
                student_outputs['loss']
                + self.logits_loss(student_outputs['logits'], teacher_outputs['logits'])
                + self.layers_loss(student_outputs['hidden_states'], teacher_outputs['hidden_states'])
        )
        self.log("val_loss", loss)
        self.log("val_accuracy", accuracy)
        return loss

In [7]:
from torch.nn.functional import softmax, kl_div

class LogitsKD(BaseKD):
    def __init__(self, config, teacher_model: BaseBert, student_model: BaseBert, temperature: float = 1):
        super().__init__(config, teacher_model, student_model)

        self._T = temperature

    def logits_loss(self, student_logits, teacher_logits):
        return kl_div(
            softmax(student_logits / self._T, dim=1),
            softmax(teacher_logits / self._T, dim=1),
            reduction='batchmean'
        )

    def layers_loss(self, student_layers, teacher_layers):
        return 0

In [8]:
distilled_model = LogitsKD(config, model, small_model)

In [9]:
train_model(distilled_model, dataloaders, config)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name    | Type     | Params
-------------------------------------
0 | teacher | SwagBert | 109 M 
1 | student | SwagBert | 67.0 M
-------------------------------------
67.0 M    Trainable params
109 M     Non-trainable params
176 M     Total params
705.755   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁
trainer/global_step,▁
val_accuracy,▁
val_loss,▁

0,1
epoch,0.0
trainer/global_step,0.0
val_accuracy,0.4
val_loss,-0.24831
