In [1]:
from typing import Sequence, Optional

import numpy as np
import pytorch_lightning as pl
import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from transformers import BertConfig

from thegreatknowledgeheist.data import get_dataloaders
from thegreatknowledgeheist.io import load_yaml
from thegreatknowledgeheist.models import BertFactory

In [2]:
def train_model(model, dataloaders, config):
    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="val_loss",
        dirpath=f"{config['outputs_path']}/model_checkpoints",
        filename=config["task"] + "-model-{epoch:02d}-{val_accuracy:.2f}",
        save_top_k=1,
        mode="min",
    )

    trainer = Trainer(
        logger=WandbLogger(
            save_dir=f"{config['outputs_path']}/logs",
            project="experiments",
            entity="mma",
        ),
        gpus=config["gpus"],
        max_epochs=config["max_epochs"],
        callbacks=[checkpoint_callback],
    )

    trainer.fit(model, dataloaders["train"], dataloaders["val"])
    wandb.finish()


In [3]:
config_path = '/pio/scratch/1/i308362/TheGreatKnowledgeHeist/configs/train_config.yaml'


In [11]:
config = load_yaml(config_path)
dataloaders = get_dataloaders(
    dataset_name=config["task"],
    path_to_dataset=f"{config['dataset_path']}/{config['task']}",
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)


In [12]:
swag_checkpoint_path = '/pio/scratch/1/i308362/TheGreatKnowledgeHeist/out/swag-model-epoch=03-val_f1=0.74.ckpt'
teacher_config = BertConfig()
student_config = BertConfig(
    num_hidden_layers=6,
    num_attention_heads=6,
)

In [13]:
from abc import ABC, abstractmethod
from torch.optim import Adam

class BaseKD(pl.LightningModule, ABC):
    def __init__(self, config, teacher_checkpoint: str, teacher_config: BertConfig, student_config: BertConfig):
        super().__init__()
        self.lr = config["lr"]
        self.eps = config["eps"]

        factory = BertFactory()
        teacher_config.output_hidden_states = True
        student_config.output_hidden_states = True

        self.teacher = factory.create_model(config['task'], config=config, bert_config=teacher_config, checkpoint_path=teacher_checkpoint)
        self.teacher.freeze()
        self.student = factory.create_model(config['task'], config=config, bert_config=student_config, pretrained=False)
        self.student.unfreeze()

    @abstractmethod
    def logits_loss(self, student_logits, teacher_logits):
        pass

    @abstractmethod
    def layers_loss(self, student_layers, teacher_layers):
        pass

    def configure_optimizers(self):
        optimizer = Adam(self.student.parameters(), lr=self.lr, eps=self.eps)
        return optimizer

    def forward(self, **inputs):
        student_outputs = self.student(**inputs)
        teacher_outputs = self.teacher(**inputs)
        return student_outputs, teacher_outputs

    def training_step(self, batch, batch_idx):
        student_outputs, teacher_outputs = self(**batch)
        accuracy = self.student.calculate_accuracy(student_outputs['logits'], batch["labels"])
        loss = (
                student_outputs['loss']
                + self.logits_loss(student_outputs['logits'], teacher_outputs['logits'])
                + self.layers_loss(student_outputs['hidden_states'], teacher_outputs['hidden_states'])
        )
        self.log("train_loss", loss)
        self.log("train_accuracy", accuracy)
        return loss

    def validation_step(self, batch, batch_idx):
        student_outputs, teacher_outputs = self(**batch)
        accuracy = self.student.calculate_accuracy(student_outputs['logits'], batch["labels"])
        loss = (
                student_outputs['loss']
                + self.logits_loss(student_outputs['logits'], teacher_outputs['logits'])
                + self.layers_loss(student_outputs['hidden_states'], teacher_outputs['hidden_states'])
        )
        self.log("val_loss", loss)
        self.log("val_accuracy", accuracy)
        return loss

In [14]:
class BaselineKD(BaseKD):
    def __init__(self, config, teacher_checkpoint: str, teacher_config: BertConfig, student_config: BertConfig):
        super().__init__(config, teacher_checkpoint, teacher_config, student_config)

    def layers_loss(self, student_layers, teacher_layers):
        return 0

    def logits_loss(self, student_logits, teacher_logits):
        return 0

In [15]:
from torch.nn.functional import softmax, kl_div
from torch import nn

In [16]:
class KL_div(nn.Module):
    def __init__(self, temperature: float = 1):
        super(KL_div, self).__init__()

        self.T = temperature

    def forward(self, student_logits, teacher_logits):
        return kl_div(
            softmax(student_logits / self.T, dim=1),
            softmax(teacher_logits / self.T, dim=1),
            reduction='batchmean'
        )

In [17]:
class LogitsKD(BaseKD):
    def __init__(self, config, teacher_checkpoint: str, teacher_config: BertConfig, student_config: BertConfig, temperature: float = 1):
        super().__init__(config, teacher_checkpoint, teacher_config, student_config)

        self.logits_criterion = KL_div(temperature)


    def logits_loss(self, student_logits, teacher_logits):
        return self.logits_criterion(student_logits, teacher_logits)

    def layers_loss(self, student_layers, teacher_layers):
        return 0

In [None]:
baseline_model = BaselineKD(config, swag_checkpoint_path, teacher_config, student_config)
train_model(baseline_model, dataloaders, config)


This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33madrianurbanski[0m ([33mmma[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name    | Type     | Params
-------------------------------------
0 | teacher | SwagBert | 109 M 
1 | student | SwagBert | 67.0 M
-------------------------------------
67.0 M    Trainable params
109 M     Non-trainable params
176 M     Total params
705.755   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [36]:
distilled_model = LogitsKD(config, swag_checkpoint_path, teacher_config, student_config)

This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

In [37]:
train_model(distilled_model, dataloaders, config)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")

  | Name             | Type     | Params
----------------------------------------------
0 | teacher          | SwagBert | 109 M 
1 | student          | SwagBert | 67.0 M
2 | logits_criterion | KL_div   | 0     
----------------------------------------------
67.0 M    Trainable params
109 M     Non-trainable params
176 M     Total params
705.755   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

None
False


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

RuntimeError: you can only change requires_grad flags of leaf variables.

In [19]:
class LayersKD(LogitsKD):
    def __init__(
            self, config, teacher_checkpoint: str, teacher_config: BertConfig, student_config: BertConfig,
            temperature: float = 1, layers_map: Optional[Sequence[int]] = None
    ):
        super().__init__(config, teacher_checkpoint, teacher_config, student_config, temperature)

        if layers_map is None:
            num_student_layers = self.student.model.config.num_hidden_layers
            num_teacher_layers = self.teacher.model.config.num_hidden_layers
            layers_map = np.linspace(0, num_teacher_layers, num=num_student_layers, endpoint=False, dtype=int)

        self.layers_map = layers_map
        self.layers_criterion = nn.MSELoss()

    def layers_loss(self, student_layers, teacher_layers):
        loss = []
        print(student_layers[0].requires_grad)
        print(teacher_layers[0].requires_grad)
        for student_layer, teacher_layer in enumerate(self.layers_map):
            loss.append(self.layers_criterion(student_layers[student_layer], teacher_layers[teacher_layer]))
        return sum(loss)

In [20]:
distilled_model = LayersKD(config, swag_checkpoint_path, teacher_config, student_config)
train_model(distilled_model, dataloaders, config)

This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultipleChoice: ['bert.encoder.layer.10.attention.self.value.weight', 'bert.encoder.layer.10.attention.self.value.bias', 'bert.encoder.layer.10.intermediate.dense.bias', 'bert.encoder.layer.10.output.dense.weight', 'bert.encoder.layer.8.intermediate.dense.weight', 'bert.encoder.layer.7.attention.output.dense.weight', 'bert.encoder.layer.11.intermediate.dense.weight', 'bert.encoder.layer.10.attention.self.key.bias', 'bert.encoder.layer.9.attention.self.value.weight', 'bert.encoder.layer.7.attention.self.key.bias', 'bert.encoder.layer.8.attention.self.value.weight', 'bert.encoder.layer.9.attention.self.key.weight', 'bert.encoder.layer.6.attention.output.LayerNorm.bias', 'bert.encoder.layer.10.attention.output.LayerNorm.bias', 'bert.encoder.layer.11.attention.output.dense.bias', 'bert.encoder.layer.9.intermediate.dense.weight', 'bert.encoder.layer.10.output.dense.bias', 'bert.encoder.layer.11.

This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

Sanity Checking: 0it [00:00, ?it/s]

False
False


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…