In [1]:
import pytorch_lightning as pl
import torch
import wandb
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch import nn
from torch.optim import Adam
from transformers import BertConfig

from thegreatknowledgeheist.data import get_dataloaders
from thegreatknowledgeheist.io import load_yaml
from thegreatknowledgeheist.models import BertFactory

In [2]:
def train_model(model, dataloaders, config):

    # checkpoint_callback = pl.callbacks.ModelCheckpoint(
    #     monitor="val_loss",
    #     dirpath=f"{config['outputs_path']}/model_checkpoints",
    #     filename=config["task"] + "-model-{epoch:02d}-{val_loss:.2f}",
    #     save_top_k=1,
    #     mode="min",
    # )

    trainer = Trainer(
        logger=WandbLogger(
            save_dir=f"{config['outputs_path']}/logs",
            project="experiments",
            entity="mma",
        ),
        gpus=config["gpus"],
        max_epochs=config["max_epochs"],
        # callbacks=[checkpoint_callback],
    )

    trainer.fit(model, dataloaders["train"], dataloaders["val"])
    wandb.finish()

In [3]:
config_path = '/home/maria/Documents/TheGreatKnowledgeHeist/configs/train_config.yaml'

In [4]:
config = load_yaml(config_path)
dataloaders = get_dataloaders(
    dataset_name=config["task"],
    path_to_dataset=f"{config['dataset_path']}/{config['task']}",
    batch_size=config["batch_size"],
    num_workers=config["num_workers"],
)

In [5]:
# https://pytorch-lightning.readthedocs.io/en/stable/notebooks/lightning_examples/basic-gan.html

class Discriminator(pl.LightningModule):
    def __init__(self, config, num_classes):
        super().__init__()
        self.lr = config["lr"]
        self.eps = config["eps"]

        self.model = nn.Sequential(
            nn.Linear(num_classes, 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, num_classes + 1),
        )

    def forward(self, inputs):
        return self.model(inputs)

    def training_step(self, batch, batch_idx):
        pass

    def validation_step(self, batch, batch_idx):
        pass

class AdversarialKD(pl.LightningModule):
    def __init__(self, config, teacher_model_path: str, discriminator_model: Discriminator):
        super().__init__()
        factory = BertFactory()
        self.lr = config["lr"]
        self.eps = config["eps"]
        self.teacher = factory.create_model(config["task"], config=config, checkpoint_path=teacher_model_path)
        self.teacher.freeze()
        self.student = factory.create_model(config["task"], config=config, bert_config=BertConfig(num_hidden_layers=4, num_attention_heads=4), pretrained=False)
        self.student.unfreeze()
        self.discriminator = discriminator_model
        self.discriminator.unfreeze()
        self.cross_entropy_loss = nn.CrossEntropyLoss()
        self.binary_entropy_loss = nn.BCELoss()

    def configure_optimizers(self):
        student_optimizer = Adam(self.student.parameters(), lr=self.lr, eps=self.eps)
        # TODO: add config for discriminator
        discriminator_optimizer = Adam(self.discriminator.parameters(), lr=self.lr, eps=self.eps)
        return [student_optimizer, discriminator_optimizer]

    def forward(self, **inputs):
        student_outputs = self.student(**inputs)
        teacher_outputs = self.teacher(**inputs)
        discriminator_student_outputs = self.discriminator(student_outputs["logits"])
        discriminator_teacher_outputs = self.discriminator(teacher_outputs["logits"])
        return student_outputs, teacher_outputs, discriminator_student_outputs, discriminator_teacher_outputs

    def adverserial_loss(self, outputs, targets):
        # last logit - true/false; 0 - student, 1 - teacher
        # :, : for acronyms?
        return self.binary_entropy_loss(nn.Sigmoid()(outputs[:, -1]), targets)

    def adverserial_categories_loss(self, outputs, targets):
        return self.cross_entropy_loss(nn.Softmax()(outputs[:, :-1]), targets)

    def l1_norm(self, outputs, targets):
        return torch.norm((outputs * targets), 1, -1).mean()

    def discriminator_loss(self, discriminator_student_outputs, discriminator_teacher_outputs, targets):
        return 1/2 * (
            self.adverserial_loss(discriminator_student_outputs,
                                  torch.zeros(discriminator_student_outputs.size()[0]).to(self.device))
            + self.adverserial_loss(discriminator_teacher_outputs,
                                  torch.ones(discriminator_teacher_outputs.size()[0]).to(self.device))
            + self.adverserial_categories_loss(discriminator_student_outputs, targets)
            + self.adverserial_categories_loss(discriminator_teacher_outputs, targets)
        )

    def student_loss(self, supervised_loss, student_outputs, teacher_outputs, discriminator_student_outputs, discriminator_teacher_outputs, targets):
        return supervised_loss\
               + self.l1_norm(student_outputs, teacher_outputs)\
               + self.discriminator_loss(discriminator_student_outputs, discriminator_teacher_outputs, targets)

    def training_step(self, batch, batch_idx, optimizer_idx):
        student_outputs, teacher_outputs, discriminator_student_outputs, discriminator_teacher_outputs = self(**batch)
        # train student
        if optimizer_idx == 0:
            loss = self.student_loss(
                student_outputs['loss'],
                student_outputs['logits'],
                teacher_outputs['logits'],
                discriminator_student_outputs,
                discriminator_teacher_outputs,
                batch['labels']
            )
            self.log("train_student_loss", loss)
        # train discriminator
        elif optimizer_idx == 1:
            loss = self.discriminator_loss(
                discriminator_student_outputs,
                discriminator_teacher_outputs,
                batch['labels']
            )
            self.log("train_discriminator_loss", loss)
        else:
            raise NotImplementedError()
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.student(**batch)
        loss, logits = outputs[:2]
        self.log("val_loss", loss, on_step=False, on_epoch=True)

In [6]:
# num_classes:
# swag - 4
# amazon - 2
# acronym - 5
wandb.finish()
discriminator = Discriminator(config=config, num_classes=2)
gan = AdversarialKD(config=config, teacher_model_path='/home/maria/Documents/TheGreatKnowledgeHeist/out/model_checkpoints/amazon_polarity-model-epoch=08-val_f1=0.93.ckpt', discriminator_model=discriminator)

This layer will be frozen: bert.embeddings.word_embeddings.weight
This layer will be frozen: bert.embeddings.position_embeddings.weight
This layer will be frozen: bert.embeddings.token_type_embeddings.weight
This layer will be frozen: bert.embeddings.LayerNorm.weight
This layer will be frozen: bert.embeddings.LayerNorm.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.query.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.query.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.key.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.key.bias
This layer will be frozen: bert.encoder.layer.0.attention.self.value.weight
This layer will be frozen: bert.encoder.layer.0.attention.self.value.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.weight
This layer will be frozen: bert.encoder.layer.0.attention.output.dense.bias
This layer will be frozen: bert.encoder.layer.0.attention.output.LayerNorm

In [7]:
train_model(gan, dataloaders, config)

[34m[1mwandb[0m: Currently logged in as: [33mmaria_wyrzykowska[0m ([33mmma[0m). Use [1m`wandb login --relogin`[0m to force relogin


GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type               | Params
-----------------------------------------------------------
0 | teacher             | AmazonPolarityBert | 109 M 
1 | student             | AmazonPolarityBert | 52.8 M
2 | discriminator       | Discriminator      | 133 K 
3 | cross_entropy_loss  | CrossEntropyLoss   | 0     
4 | binary_entropy_loss | BCELoss            | 0     
-----------------------------------------------------------
52.9 M    Trainable params
109 M     Non-trainable params
162 M     Total params
649.593   Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


Training: 0it [00:00, ?it/s]

  return self.cross_entropy_loss(nn.Softmax()(outputs[:, :-1]), targets)


RuntimeError: CUDA out of memory. Tried to allocate 48.00 MiB (GPU 0; 3.82 GiB total capacity; 2.18 GiB already allocated; 77.75 MiB free; 2.28 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
torch.cuda.empty_cache()