In [52]:
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch import nn
from torch.optim import Adam
from transformers import BertConfig

from thegreatknowledgeheist.data import get_dataloaders
from thegreatknowledgeheist.io import load_yaml
from thegreatknowledgeheist.models import BertFactory
from thegreatknowledgeheist.models.bert import BaseBert

In [86]:
def train_model(model, dataloaders, config):

    checkpoint_callback = pl.callbacks.ModelCheckpoint(
        monitor="train_student_loss",
        dirpath=f"{config['outputs_path']}/model_checkpoints",
        filename=config["task"] + "-model-{epoch:02d}-{train_student_loss:.2f}",
        save_top_k=1,
        mode="min",
    )

    trainer = Trainer(
        logger=WandbLogger(
            save_dir=f"{config['outputs_path']}/logs",
            project="experiments",
            entity="mma",
        ),
        gpus=config["gpus"],
        max_epochs=config["max_epochs"],
        callbacks=[checkpoint_callback],
    )

    trainer.fit(model, dataloaders["train"], dataloaders["val"])
    wandb.finish()

In [54]:
config_path = '/home/maria/Documents/TheGreatKnowledgeHeist/configs/train_config.yaml'

In [55]:
config = load_yaml(config_path)
dataloaders = get_dataloaders(
    dataset_name=config["task"],
    path_to_dataset=f"{config['dataset_path']}/{config['task']}",
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

factory = BertFactory()

model = factory.create_model(config["task"], config=config)
train_model(model, dataloaders, config)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForSequenceClassification: ['cls.seq_relationship.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at

This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                          | Params
--------------------------------------------------------
0 | model | BertForSequenceClassification | 109 M 
1 | f1    | F1Score                       | 0     
--------------------------------------------------------
14.8 M    Trainable params
94.7 M    Non-trainable params
109 M     Total params
437.935   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
epoch,▁▂▃▃▃▄▅▆▆▆▇██
train_accuracy,▁██
train_f1,▁██
train_loss,█▁▁
trainer/global_step,▁▂▃▃▃▄▅▅▆▆▇██
val_accuracy,▁▇████████
val_f1,▁▇████████
val_loss,█▂▁▃▂▃▂▃▃▃

0,1
epoch,9.0
train_accuracy,1.0
train_f1,1.0
train_loss,0.04541
trainer/global_step,159.0
val_accuracy,0.872
val_f1,0.86811
val_loss,0.38977


In [56]:
bert_config = BertConfig(
    num_hidden_layers=6,
    num_attention_heads=6,
)
small_model = factory.create_model(config["task"], config=config, pretrained=False, bert_config=bert_config)

This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

In [87]:
# https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html

class Discriminator(pl.LightningModule):
    def __init__(self, config, num_classes):
        super().__init__()
        self.lr = config["lr"]
        self.eps = config["eps"]

        self.model = nn.Sequential(
            nn.Linear(num_classes, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, num_classes + 1),
        )

    def forward(self, inputs):
        return self.model(inputs)

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

class AdversarialKD(pl.LightningModule):
    def __init__(self, config, teacher_model: BaseBert, student_model: BaseBert, discriminator_model: Discriminator):
        super().__init__()
        self.lr = config["lr"]
        self.eps = config["eps"]

        self.teacher = teacher_model
        # self.teacher.freeze()
        self.student = student_model
        # self.student.unfreeze()
        self.discriminator = discriminator_model
        # self.discriminator.unfreeze()
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.binary_entropy_loss = nn.BCELoss()

    def configure_optimizers(self):
        student_optimizer = Adam(self.student.parameters(), lr=self.lr, eps=self.eps)
        # TODO: add config for discriminator
        discriminator_optimizer = Adam(self.discriminator.parameters(), lr=self.lr, eps=self.eps)
        return [student_optimizer, discriminator_optimizer]

    def forward(self, **inputs):
        student_outputs = self.student(**inputs)
        teacher_outputs = self.teacher(**inputs)
        discriminator_student_outputs = self.discriminator(student_outputs["logits"])
        discriminator_teacher_outputs = self.discriminator(teacher_outputs["logits"])
        return student_outputs, teacher_outputs, discriminator_student_outputs, discriminator_teacher_outputs

    def adverserial_loss(self, outputs, targets):
        # last logit - true/false; 0 - student, 1 - teacher
        # :, : for acronyms?
        return self.binary_entropy_loss(nn.Sigmoid()(outputs[:, -1]), targets)

    def adverserial_categories_loss(self, outputs, targets):
        return self.cross_entropy_loss(nn.Softmax()(outputs[:, :-1]), targets)

    def l1_norm(self, outputs, targets):
        return torch.norm((outputs * targets), 1, -1).mean()

    def discriminator_loss(self, discriminator_student_outputs, discriminator_teacher_outputs, targets):
        return 1/2 * (
            self.adverserial_loss(discriminator_student_outputs,
                                  torch.zeros(discriminator_student_outputs.size()[0]).to('cuda'))
            + self.adverserial_loss(discriminator_teacher_outputs,
                                  torch.ones(discriminator_teacher_outputs.size()[0]).to('cuda'))
            + self.adverserial_categories_loss(discriminator_student_outputs, targets)
            + self.adverserial_categories_loss(discriminator_teacher_outputs, targets)
        )

    def student_loss(self, supervised_loss, student_outputs, teacher_outputs, discriminator_student_outputs, discriminator_teacher_outputs, targets):
        return supervised_loss\
               + self.l1_norm(student_outputs, teacher_outputs)\
               + self.discriminator_loss(discriminator_student_outputs, discriminator_teacher_outputs, targets)

    def training_step(self, batch, batch_idx, optimizer_idx):
        student_outputs, teacher_outputs, discriminator_student_outputs, discriminator_teacher_outputs = self(**batch)
        # train student
        if optimizer_idx == 0:
            loss = self.student_loss(
                student_outputs['loss'],
                student_outputs['logits'],
                teacher_outputs['logits'],
                discriminator_student_outputs,
                discriminator_teacher_outputs,
                batch['labels'].to('cuda')
            )
            self.log("train_student_loss", loss)
        # train discriminator
        elif optimizer_idx == 1:
            loss = self.discriminator_loss(
                discriminator_student_outputs,
                discriminator_teacher_outputs,
                batch['labels'].to('cuda')
            )
            self.log("train_discriminator_loss", loss)
        else:
            raise NotImplementedError()
        return loss

In [88]:
# num_classes:
# swag - 4
# amazon - 2
# acronym - 5
discriminator = Discriminator(config=config, num_classes=2)
gan  = AdversarialKD(config=config, teacher_model=model, student_model=small_model, discriminator_model=discriminator)

In [89]:
train_model(gan, dataloaders, config)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type               | Params
-----------------------------------------------------------
0 | teacher             | AmazonPolarityBert | 109 M 
1 | student             | AmazonPolarityBert | 67.0 M
2 | discriminator       | Discriminator      | 133 K 
3 | cross_entropy_loss  | CrossEntropyLoss   | 0     
4 | binary_entropy_loss | BCELoss            | 0     
-----------------------------------------------------------
14.9 M    Trainable params
161 M     Non-trainable params
176 M     Total params
706.296   Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

  return self.cross_entropy_loss(nn.Softmax()(outputs[:, :-1]), targets)


ReferenceError: weakly-referenced object no longer exists