In [18]:
%load_ext autoreload
%autoreload 2
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [2]:
print("import torch")
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

print("import lightning")
import lightning as pl
from lightning.pytorch.callbacks import LearningRateMonitor, Callback
from lightning.pytorch import Trainer, LightningModule

print("import transformers")
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2TokenizerFast, GPT2Config

print("import datasets")
from datasets import load_dataset

from collections import OrderedDict

print("import model")
import model
import knn
print("Done")

import torch
import lightning
import transformers
import datasets
import model
Done


In [5]:
def map_gpt2_to_infostill(key):
	pre, *key = teacher_key.split(".")
	if pre != "transformer":
		return None
	
	match key:
		case ['wte', 'weight']:
			return "token_embed.weight"
		case ["ln_f", "weight"]:
			return "postnorm.weight"
		case ['h', l, "attn", *_]:
			layer = l
		case _:
			return None
	
	match key[3:]:
		# ln_2 is FF layer norm so skip
		case ["ln_1", wb]:
			key = f"prenorm.{wb}"
		
		# Skip mlp layer and biases
		case ["c_attn", "weight"]:
			key = f"qkv_proj.weight"
		case ["c_proj", "weight"]:
			key = f"out_proj.weight"
		
		# Causal mask bias
		case ["bias"]:
			key = "attention.bias"
		
		case _:
			return None
	
	return f"layers.{layer}.{key}"


print("Building student")

student = model.InfoDistilleryGPT2Model(config)
student_state = OrderedDict()

print("Transfer weights")

teacher_name = "gpt2"
for teacher_key, teacher_weight in teacher.state_dict().items():
    student_key = map_gpt2_to_infostill(teacher_key)
    if student_key is None:
        continue
    if 'proj' in student_key:
        teacher_weight = teacher_weight.T
    student_state[student_key] = teacher_weight

student.load_state_dict(student_state, strict=False)
print("Done")

print("Loading teacher")

teacher = GPT2LMHeadModel.from_pretrained(teacher_name)
tokenizer = GPT2TokenizerFast.from_pretrained(teacher_name)
config = teacher.config
print("Done")

Loading teacher


Downloading (…)olve/main/vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading (…)olve/main/merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Done


In [35]:
class KnowledgeDistillationModel(pl.LightningModule):
    def __init__(self, teacher_model, student_model, memory, tokenizer, temperature=2.0):
        super().__init__()
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.memory = memory
        self.feedback = None
        self.tokenizer = tokenizer
        self.temperature = temperature
        self.distill_loss = nn.KLDivLoss(reduction="batchmean")
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, input_ids, attention_mask):
        return self.student_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            memory=self.memory,
            feedback=self.feedback
        )

    def training_step(self, batch, batch_idx):
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        print("input_ids", input_ids)
        labels = input_ids
        
        # Teacher model output
        with torch.no_grad():
            teacher_logits = self.teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

        # Student model output
        student_logits, hidden = self(input_ids, attention_mask)
        print("hidden shape", hidden.shape)
        self.feedback = hidden[1:] + [None]

        # Calculate distillation loss
        distill_loss = self.distill_loss(
            F.log_softmax(student_logits / self.temperature, dim=-1),
            F.softmax(teacher_logits / self.temperature, dim=-1)
        )
        student_loss = self.ce_loss(student_logits, labels)
        teacher_loss = self.ce_loss(teacher_logits, labels)
        
        loss = distill_loss + student_loss
        
        self.log("train_loss", {
            "combined": loss,
            "distill": distill_loss,
            "ce_teacher": loss_ce_teacher,
            "ce_student": loss_ce_student
        })
        return loss

    def configure_optimizers(self):
        optimizer = optim.Adam(self.student_model.parameters(), lr=1e-4)
        return optimizer

pl.seed_everything(42)
block_size = config.n_positions
dataset = load_dataset("ag_news", 'all')['train']

# Tokenize and collate dataset
def tokenize_function(examples):
    return tokenizer(examples["text"], truncation=True, padding="max_length", max_length=512)

tokenized_dataset = dataset.map(tokenize_function, batched=True)
dataloader = DataLoader(tokenized_dataset, batch_size=8)    

memory = knn.KNNMemory("orin", config.n_embd, config.n_embd)
print("Done")

Global seed set to 42
Found cached dataset ag_news (/home/consciouscode/.cache/huggingface/datasets/ag_news/all/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /home/consciouscode/.cache/huggingface/datasets/ag_news/all/0.0.0/bc2bcb40336ace1a0374767fc29bb0296cdaf8a6da7298436239c54d79180548/cache-7ad8216966f98598.arrow


Done


In [36]:
print("Begin training")

# Train the model
model = KnowledgeDistillationModel(teacher, student, memory, tokenizer)
trainer = pl.Trainer(accelerator="gpu", max_epochs=3, log_every_n_steps=8, callbacks=[
    LearningRateMonitor(logging_interval='step'),
    FrequentCheckpoint(save_steps=1000, output_dir="checkpoints")
])
trainer.fit(model, dataloader)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type                    | Params
----------------------------------------------------------
0 | teacher_model | GPT2LMHeadModel         | 124 M 
1 | student_model | InfoDistilleryGPT2Model | 66.9 M
2 | distill_loss  | KLDivLoss               | 0     
3 | ce_loss       | CrossEntropyLoss        | 0     
----------------------------------------------------------
191 M     Trainable params
384       Non-trainable params
191 M     Total params
765.476   Total estimated model params size (MB)


Begin training


Training: 0it [00:00, ?it/s]

Exception ignored in: <function _ConnectionBase.__del__ at 0x7fa42952f880>
Traceback (most recent call last):
  File "/home/consciouscode/anaconda3/lib/python3.10/multiprocessing/connection.py", line 132, in __del__
    self._close()
  File "/home/consciouscode/anaconda3/lib/python3.10/multiprocessing/connection.py", line 361, in _close
    _close(self._handle)
OSError: [Errno 9] Bad file descriptor
