In [None]:
!pip install datasets transformers scikit-learn numpy matplotlib accelerate

In [None]:
import pandas as pd
import torch
from transformers import (
    AutoTokenizer,
)
from datasets import load_dataset
import torch
import torch.nn.functional as F

import lightning.pytorch as pl
from pytorch_lightning import Trainer
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
)
from datasets import load_dataset
from transformers import DataCollatorWithPadding

In [None]:
import warnings

warnings.filterwarnings("ignore")

In [None]:
# Config
MODEL_TEACHER = "google-bert/bert-base-uncased"
MODEL_STUDENT = "prajjwal1/bert-small"
DATASET_NAME = "SetFit/20_newsgroups"

# Seed the Python / NumPy / torch RNGs before any model is built. The training
# DataLoader shuffles, and the 20-way classification heads are presumably newly
# initialised on top of the pretrained encoders, so runs are otherwise not
# reproducible. The trailing ';' suppresses the echoed return value.
SEED = 42
pl.seed_everything(SEED);

In [None]:
class LitDataModule(pl.LightningDataModule):
    """Lightning data module for a Hugging Face text-classification dataset.

    The dataset is downloaded and length-filtered eagerly in ``__init__``.
    Tokenization happens in :meth:`setup` and is performed only once.

    Args:
        tokenizer: Tokenizer used both to encode ``text`` and to pad batches
            in the collator.
        dataset_name: Dataset id passed to ``datasets.load_dataset``. The
            dataset must provide ``train``/``test`` splits with ``text`` and
            ``label`` columns (both are read below).
        max_length: Maximum sequence length; longer inputs are truncated.
        batch_size: Batch size of both data loaders.
        split_min_max: ``(min, max)`` exclusive bounds on the whitespace word
            count of ``text``. Examples outside the open interval are dropped.
    """

    def __init__(
        self,
        tokenizer: AutoTokenizer,
        dataset_name: str,
        max_length: int,
        batch_size: int,
        split_min_max=(3, 256),
    ):
        super().__init__()
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.batch_size = batch_size
        self.split_min_max = split_min_max
        self.dataset_name = dataset_name
        self._prepare_data()

    def _prepare_data(self):
        """Download the dataset and keep only examples of moderate length."""
        dataset = load_dataset(self.dataset_name)
        min_words, max_words = self.split_min_max

        def _within_bounds(example):
            # Whitespace-separated words, not tokenizer tokens. Split once per
            # example instead of once per bound.
            n_words = len(example["text"].split())
            return min_words < n_words < max_words

        self.dataset = dataset.filter(_within_bounds)
        # Filled lazily by setup().
        self.train_data = None
        self.test_data = None

    def tokenize(self, example):
        """Encode ``example["text"]`` with truncation to ``max_length``.

        Works both per-example and in batched ``Dataset.map`` mode, because the
        tokenizer accepts either a single string or a list of strings.
        """
        return self.tokenizer(
            example["text"],
            truncation=True,
            max_length=self.max_length,
        )

    def setup(self, stage=None):
        """Tokenize both splits and expose them as torch tensors.

        Args:
            stage: Lightning stage name. It is ignored: both splits are
                prepared on the first call, and later calls return immediately.
        """
        if self.train_data is not None and self.test_data is not None:
            return

        # batched=True passes lists of texts to the tokenizer, which is much
        # faster than one call per example and should yield the same features.
        self.train_data = self.dataset["train"].map(self.tokenize, batched=True)
        self.test_data = self.dataset["test"].map(self.tokenize, batched=True)

        # Only these columns reach the collator. DataCollatorWithPadding pads
        # them per batch and presumably renames "label" to "labels", which the
        # LightningModule reads.
        columns = ["input_ids", "attention_mask", "label"]
        for split in (self.train_data, self.test_data):
            split.set_format(type="torch", columns=columns)

    def train_dataloader(self):
        """Shuffled training loader; batches are padded by DataCollatorWithPadding."""
        return torch.utils.data.DataLoader(
            self.train_data,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=DataCollatorWithPadding(self.tokenizer),
        )

    def test_dataloader(self):
        """Unshuffled test loader; batches are padded by DataCollatorWithPadding."""
        return torch.utils.data.DataLoader(
            self.test_data,
            batch_size=self.batch_size,
            collate_fn=DataCollatorWithPadding(self.tokenizer),
        )

In [None]:
class LitTrainingLogic(pl.LightningModule):
    """LightningModule that trains ``student`` with its own supervised loss.

    A distillation path (``kd_mode`` set) is planned. ``kl_loss`` implements the
    objective, but ``training_step`` still raises ``NotImplementedError`` for it.

    Args:
        student: Sequence-classification model being trained. Its outputs must
            expose ``.logits``, and ``.loss`` when ``labels`` are passed.
        teacher: Optional model providing distillation targets. It is never
            optimized, because ``configure_optimizers`` only uses the student's
            parameters.
        kd_mode: ``None`` for plain supervised training, ``"forward"`` for
            KL(teacher || student), or ``"reverse"`` for KL(student || teacher).
        lr: AdamW learning rate.

    Raises:
        ValueError: If ``kd_mode`` is not one of the supported values.
    """

    _KD_MODES = (None, "forward", "reverse")

    def __init__(self, student, teacher=None, kd_mode=None, lr=5e-5):
        super().__init__()
        # Fail fast on typos instead of silently falling into another KL branch.
        if kd_mode not in self._KD_MODES:
            raise ValueError(
                f"Unsupported kd_mode {kd_mode!r}; expected one of {self._KD_MODES}"
            )
        self.teacher = teacher
        self.student = student
        self.kd_mode = kd_mode
        self.lr = lr

    def forward(self, input_ids, attention_mask):
        """Run the student and return its raw model output (read via ``.logits``)."""
        return self.student(input_ids, attention_mask=attention_mask)

    def kl_loss(self, student_logits, teacher_logits, temperature=1.0):
        """KL-divergence distillation loss selected by ``self.kd_mode``.

        Args:
            student_logits: Student logits. Assumed to be (batch, num_classes);
                softmax is over the last dim and "batchmean" divides by dim 0.
            teacher_logits: Teacher logits with the same shape.
            temperature: Softening temperature (> 0). Logits are divided by it
                and the loss is scaled by ``temperature ** 2``. The default 1.0
                leaves the plain KL value unchanged.

        Returns:
            Scalar loss tensor.

        Raises:
            ValueError: If ``temperature`` is not positive, or ``kd_mode`` is not
                "forward" or "reverse".
        """
        if temperature <= 0:
            raise ValueError(f"temperature must be > 0, got {temperature!r}")
        student_logits = student_logits / temperature
        teacher_logits = teacher_logits / temperature

        # F.kl_div(input, target) computes sum(target * (log target - input)),
        # where input is log-probabilities and target is probabilities.
        if self.kd_mode == "forward":
            # KL(teacher || student)
            teacher_probs = F.softmax(teacher_logits, dim=-1)
            student_log_probs = F.log_softmax(student_logits, dim=-1)
            loss = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
        elif self.kd_mode == "reverse":
            # KL(student || teacher)
            student_probs = F.softmax(student_logits, dim=-1)
            teacher_log_probs = F.log_softmax(teacher_logits, dim=-1)
            loss = F.kl_div(teacher_log_probs, student_probs, reduction="batchmean")
        else:
            raise ValueError(
                f"kl_loss requires kd_mode 'forward' or 'reverse', got {self.kd_mode!r}"
            )
        # T^2 scaling keeps gradient magnitudes comparable across temperatures.
        return loss * temperature ** 2

    def normal_training_step(self, batch, batch_idx):
        """Supervised step on the student.

        Returns:
            ``(logits, loss)``, where ``loss`` is the loss the model computes
            from ``batch["labels"]``.
        """
        student_out = self.student(
            input_ids=batch["input_ids"],
            attention_mask=batch["attention_mask"],
            labels=batch["labels"],
        )
        return student_out.logits, student_out.loss

    def training_step(self, batch, batch_idx):
        """Return the loss to optimize for one batch and log it as ``train_loss``."""
        if self.kd_mode is not None:
            raise NotImplementedError("KL-based training not implemented yet")
        _, loss = self.normal_training_step(batch, batch_idx)
        self.log("train_loss", loss, prog_bar=True, batch_size=batch["labels"].size(0))
        return loss

    def test_step(self, batch, batch_idx):
        """Log per-batch accuracy as ``test_acc``.

        The batch size is passed explicitly so Lightning weights the epoch
        average by it.
        """
        logits = self(batch["input_ids"], batch["attention_mask"]).logits
        labels = batch["labels"]
        acc = (logits.argmax(dim=-1) == labels).float().mean()
        self.log("test_acc", acc, prog_bar=True, batch_size=labels.size(0))

    def configure_optimizers(self):
        """AdamW over the student's parameters only."""
        # self.parameters() would also include the teacher (assuming it is an
        # nn.Module, it is registered as a submodule), which must stay fixed.
        return torch.optim.AdamW(self.student.parameters(), lr=self.lr)


In [None]:
# Teacher (BERT-base) and student (bert-small), each with a 20-way
# sequence-classification head and its matching tokenizer.
large_tokenizer = AutoTokenizer.from_pretrained(MODEL_TEACHER)
large_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_TEACHER, num_labels=20
)

small_tokenizer = AutoTokenizer.from_pretrained(MODEL_STUDENT)
small_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_STUDENT, num_labels=20
)

In [None]:
# Data for the teacher: texts encoded with the teacher's tokenizer, truncated
# to 256 tokens, served in batches of 32.
lit_dm_large = LitDataModule(
    batch_size=32,
    max_length=256,
    dataset_name=DATASET_NAME,
    tokenizer=large_tokenizer,
)

In [None]:
# Passes over the training set.
# NOTE(review): this constant was declared as 4 but unused, while the Trainer was
# hard-coded to 2. It now drives the Trainer and keeps the 2 epochs that actually
# ran; raise it if 4 was intended.
MAX_EPOCHS = 2

# Fine-tune the teacher on its own. LitTrainingLogic trains whatever is passed as
# `student`, and kd_mode=None selects the plain supervised path (no distillation).
lit_model_teacher = LitTrainingLogic(
    student=large_model,
    kd_mode=None,
)

trainer = pl.Trainer(
    max_epochs=MAX_EPOCHS,
    accelerator="auto",
    log_every_n_steps=50,
)


In [None]:
# Fine-tune the teacher; the data module's setup() tokenizes both splits on first use.
trainer.fit(model=lit_model_teacher, datamodule=lit_dm_large)

In [None]:
# Reuse the data module built before training when it exists: its setup() has
# already tokenized the test split. Rebuilding would re-download, re-filter and
# re-tokenize the whole dataset for nothing. Only construct it when this cell is
# run without the earlier cells (e.g. after a kernel restart).
if "lit_dm_large" not in globals():
    lit_dm_large = LitDataModule(
        tokenizer=large_tokenizer,
        dataset_name=DATASET_NAME,
        max_length=256,
        batch_size=32,
    )

In [None]:
# Evaluate the fine-tuned teacher on the test split (logs test_acc).
trainer.test(model=lit_model_teacher, datamodule=lit_dm_large)

In [None]:
# Publish the fine-tuned teacher weights and its tokenizer to the same Hub repo.
# NOTE(review): presumably requires prior Hugging Face authentication
# (e.g. `huggingface-cli login` or the HF_TOKEN env var) — never hardcode a token here.
hub_repo_id = "bert-base-uncased-20-newsgroup"
large_model.push_to_hub(hub_repo_id)
large_tokenizer.push_to_hub(hub_repo_id)