In [10]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [1]:
import os
from collections import OrderedDict

print("import torch")
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F

print("import lightning")
import lightning as pl
from lightning.pytorch.callbacks import LearningRateMonitor, Callback
from lightning.pytorch import Trainer, LightningModule

print("import transformers")
from transformers.models.gpt2 import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

print("import datasets")
from datasets import load_dataset

print("import model")
import model
import knn

import torch
import lightning
import transformers
import datasets
import model


In [2]:
def map_gpt2_to_infostill(key):
	match key.split('.'):
		case ['wte', x]: return f'embed.{x}'
		case ['h', i, 'ln_1', x]:
			return f"layers.{i}.prenorm.{x}"
		case ['h', i, 'attn', *x]:
			return f"layers.{i}.{'.'.join(x)}"
		case ['ln_f', x]: return f'postnorm.{x}'
		
		# h.?.ln_2, mlp
		case _: return None

class WikiText103DataModule(pl.LightningDataModule):
	"""
	Exposes the `train` and `validation` splits of an already tokenized
	dataset as DataLoaders.

	Args:
		tokenized_dataset: object indexable by split name (`'train'`,
			`'validation'`), each entry being something a DataLoader can wrap.
		batch_size: number of samples per batch for both loaders.
	"""

	def __init__(self, tokenized_dataset, batch_size=8):
		super().__init__()
		self.batch_size = batch_size
		self.tokenized_dataset = tokenized_dataset

	def _loader_for(self, split, **loader_options):
		"""Wrap one split in a DataLoader using the configured batch size."""
		split_data = self.tokenized_dataset[split]
		return DataLoader(split_data, batch_size=self.batch_size, **loader_options)

	def train_dataloader(self):
		# Only the training split is reshuffled.
		return self._loader_for('train', shuffle=True)

	def val_dataloader(self):
		return self._loader_for('validation')

class KnowledgeDistillationModule(LightningModule):
	"""
	Distills a fixed teacher language model into a trainable student.

	Every training step runs both models on the same tokens and minimizes
	`(1 - alpha) * KL + alpha * CE`, where KL compares the temperature-softened
	teacher and student distributions and CE is the student's cross-entropy
	against the hard labels. Only the student's parameters are optimized.

	Args:
		teacher: model called as `teacher(input_ids=..., attention_mask=...)`
			whose result exposes `.logits`.
		student: model called with `input_ids`, `attention_mask`, `feedback`
			and `output_hidden_states=True`, returning `(logits, hidden)`.
		learning_rate: AdamW learning rate; also the peak of the LR schedule.
		alpha: weight of the hard-label cross-entropy term (KL gets `1 - alpha`).
		temperature: divisor applied to both sets of logits before the KL term.
	"""

	def __init__(self, teacher, student, learning_rate=3e-5, alpha=0.5, temperature=2):
		super().__init__()
		self.teacher = teacher
		self.student = student
		self.learning_rate = learning_rate
		self.alpha = alpha
		self.temperature = temperature
		# Hidden state returned by the student on the previous call; fed back on the next one.
		self.feedback = None

		self.criterion = nn.KLDivLoss(reduction='batchmean')
		self.ce_loss = nn.CrossEntropyLoss()

	def forward(self, input_ids, attention_mask):
		"""
		Run teacher and student on the same batch.

		Side effect: the student's returned hidden state is stored (detached)
		in `self.feedback` and passed to the student on the next call.

		Returns:
			`(teacher_logits, student_logits)`.
		"""
		# `Trainer.fit` switches the whole LightningModule, teacher included, to
		# train mode. The teacher only provides targets, so keep it in eval mode
		# (no dropout) for stable soft labels.
		self.teacher.eval()
		with torch.no_grad():
			teacher_logits = self.teacher(
				input_ids=input_ids,
				attention_mask=attention_mask
			).logits

		student_logits, hidden = self.student(
			input_ids=input_ids,
			attention_mask=attention_mask,
			feedback=self.feedback,
			output_hidden_states=True
		)
		# Detach before carrying the state over: otherwise the next step's backward
		# pass reaches into this step's already-freed autograd graph, and the graphs
		# of all previous steps are kept alive through the chain of feedback tensors.
		self.feedback = hidden.detach()
		return teacher_logits, student_logits

	@staticmethod
	def _unpack_batch(batch):
		"""Return `(input_ids, attention_mask, labels)` from a tuple/list or dict batch."""
		if isinstance(batch, dict):
			return batch['input_ids'], batch['attention_mask'], batch['labels']
		input_ids, attention_mask, labels = batch
		return input_ids, attention_mask, labels

	def _token_cross_entropy(self, logits, labels):
		"""Cross-entropy over every position, using the last logits axis as the class axis."""
		# nn.CrossEntropyLoss treats dim 1 as the class dimension, so logits whose
		# classes live on the last axis must be flattened to (positions, classes).
		# Assumes logits are (batch, seq, vocab) and labels (batch, seq), with labels
		# already aligned to the position they are scored at - TODO confirm for the student.
		return self.ce_loss(logits.reshape(-1, logits.size(-1)), labels.reshape(-1))

	def training_step(self, batch, batch_idx):
		"""
		Compute and log the distillation loss for one batch.

		Args:
			batch: `(input_ids, attention_mask, labels)` or a dict with those keys.
			batch_idx: index of the batch within the epoch (unused).

		Returns:
			The combined loss tensor that Lightning backpropagates.
		"""
		input_ids, attention_mask, labels = self._unpack_batch(batch)
		teacher_logits, student_logits = self(input_ids, attention_mask)

		# NOTE(review): the soft-target term is not rescaled by temperature**2 as is
		# common in distillation setups; confirm this is intended.
		loss_kl = self.criterion(
			F.log_softmax(student_logits / self.temperature, dim=-1),
			F.softmax(teacher_logits / self.temperature, dim=-1),
		)
		# Teacher cross-entropy is only logged as a reference; it carries no gradient.
		loss_ce_teacher = self._token_cross_entropy(teacher_logits, labels)
		loss_ce_student = self._token_cross_entropy(student_logits, labels)

		# Equal to (KL + CE) / 2 for the default alpha = 0.5.
		loss = (1 - self.alpha) * loss_kl + self.alpha * loss_ce_student

		# `log_dict` records each component as its own scalar metric.
		self.log_dict({
			"train_loss/combined": loss,
			"train_loss/kl": loss_kl,
			"train_loss/ce_teacher": loss_ce_teacher,
			"train_loss/ce_student": loss_ce_student
		})

		return loss

	def configure_optimizers(self):
		"""
		AdamW over the student's parameters with a one-cycle LR schedule.

		The schedule peaks at `learning_rate`, spans exactly the number of
		optimizer steps the trainer expects to run, and is stepped every batch.
		"""
		optimizer = optim.AdamW(self.student.parameters(), lr=self.learning_rate)
		# `max_lr` plus a step budget are OneCycleLR arguments (ReduceLROnPlateau
		# rejects them). Taking the budget from the trainer avoids a hard-coded
		# steps-per-epoch that would not match the real dataloader length.
		scheduler = optim.lr_scheduler.OneCycleLR(
			optimizer, max_lr=self.learning_rate,
			total_steps=self.trainer.estimated_stepping_batches
		)
		return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

class FrequentCheckpoint(Callback):
	"""
	Saves a full trainer checkpoint every `save_steps` optimizer steps.

	Args:
		save_steps: interval, in `trainer.global_step` units, between checkpoints.
			Must be a positive integer.
		output_dir: directory that receives `checkpoint_step_<N>.ckpt` files;
			created on first save if it does not exist.

	Raises:
		ValueError: if `save_steps` is not positive.
	"""

	def __init__(self, save_steps: int, output_dir: str):
		super().__init__()
		if save_steps < 1:
			raise ValueError(f"save_steps must be a positive integer, got {save_steps}")
		self.save_steps = save_steps
		self.output_dir = output_dir
		# Last step a checkpoint was written for. With gradient accumulation several
		# batches share one global_step, and each would otherwise trigger a save.
		self._last_saved_step = None

	def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx):
		"""Write a checkpoint when `trainer.global_step` reaches a multiple of `save_steps`."""
		# This is the per-training-batch hook; the older `on_batch_end` name is no
		# longer part of the Callback API, so a method under that name never fires.
		global_step = trainer.global_step
		if global_step == 0 or global_step % self.save_steps != 0:
			return
		if global_step == self._last_saved_step:
			return

		os.makedirs(self.output_dir, exist_ok=True)
		ckpt_path = os.path.join(self.output_dir, f"checkpoint_step_{global_step}.ckpt")
		trainer.save_checkpoint(ckpt_path)
		self._last_saved_step = global_step
		print(f"Checkpoint saved at step {global_step}: {ckpt_path}")

In [3]:
teacher_name = "gpt2-large"

print("Loading teacher")

# Tokenizer and model are loaded from the same checkpoint name.
tokenizer = GPT2Tokenizer.from_pretrained(teacher_name)
teacher = GPT2LMHeadModel.from_pretrained(teacher_name)

# Kept under its own name: later cells read `n_embd` and `n_positions` from it.
config = teacher.config

Loading teacher


In [11]:
print("Building student")

# The student is configured from the teacher's config attributes and given a
# KNN memory sized to the teacher's embedding width.
student = model.InfoDistillery(
    memory=knn.KNNMemory("orin", config.n_embd),
    **config.__dict__
)

print("Transfer weights")

# Collect every teacher tensor whose key has a student counterpart; keys that
# `map_gpt2_to_infostill` maps to None are left out.
student_state = OrderedDict(
    (student_key, teacher_weight)
    for teacher_key, teacher_weight in teacher.state_dict().items()
    if (student_key := map_gpt2_to_infostill(teacher_key)) is not None
)

# NOTE(review): this is a strict load, so it presumably requires the mapped keys to
# cover the student's whole state dict - confirm against InfoDistillery.
student.load_state_dict(student_state)

RPC error: [create_collection], <MilvusException: (code=1, message=multiple vector fields is not supported, fields name: key, value)>, <Time:{'RPC start': '2023-04-05 14:18:03.553008', 'RPC error': '2023-04-05 14:18:03.554149'}>


Building student


In [8]:
pl.seed_everything(42)
block_size = config.n_positions
IGNORE_INDEX = -100  # default `ignore_index` of nn.CrossEntropyLoss

print("Loading WikiText103")

raw_dataset = load_dataset('wikitext', 'wikitext-103-raw-v1')

print("Tokenizing dataset")

# The GPT-2 tokenizer has no padding token, which `padding='max_length'` needs.
# Reuse EOS; padded positions are masked out through `attention_mask` and the labels.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token


def tokenize_batch(batch):
    """
    Tokenize a batch of raw text rows and attach next-token labels.

    Args:
        batch: mapping with a `"text"` list, as passed by a batched `map`.

    Returns:
        The tokenizer output (`input_ids`, `attention_mask`, each padded or
        truncated to `block_size`) plus `labels`, where position t holds the
        token at t + 1 and every position without a real next token holds
        `IGNORE_INDEX`.
    """
    encoded = tokenizer(
        batch["text"],
        truncation=True, padding='max_length',
        max_length=block_size,
    )
    # Labels are shifted here so a position-wise cross-entropy against the
    # logits needs no further shifting.
    encoded["labels"] = [
        [
            next_id if is_real else IGNORE_INDEX
            for next_id, is_real in zip(ids[1:] + [IGNORE_INDEX], mask[1:] + [0])
        ]
        for ids, mask in zip(encoded["input_ids"], encoded["attention_mask"])
    ]
    return encoded


# Iterating a DatasetDict yields split *names*, so the text has to be reached via
# `map`, which runs over the rows of every split and keeps the result indexable by
# split name ('train', 'validation', ...).
dataset = (
    raw_dataset
    # Blank rows would turn into all-padding samples with no usable label.
    .filter(lambda row: row["text"].strip() != "")
    .map(tokenize_batch, batched=True, remove_columns=["text"])
    .with_format("torch", columns=["input_ids", "attention_mask", "labels"])
)
# NOTE(review): every row is padded to `block_size`, which is wasteful for short
# lines; concatenating the text into fixed-size blocks is a common alternative.

Global seed set to 42


Loading WikiText103


Found cached dataset wikitext (/home/consciouscode/.cache/huggingface/datasets/wikitext/wikitext-103-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)


  0%|          | 0/3 [00:00<?, ?it/s]

Using pad_token, but it is not set yet.


Tokenizing dataset


In [7]:
print("Begin training")

trainer = pl.Trainer(
    # `accelerator` / `devices` replace the removed `gpus=` argument.
    accelerator="gpu", devices=1, max_epochs=3,
    # Callbacks are configured on the Trainer; `fit()` has no `callbacks` parameter.
    callbacks=[
        LearningRateMonitor(logging_interval='step'),
        FrequentCheckpoint(save_steps=1000, output_dir="checkpoints"),
        # Replaces the removed `progress_bar_refresh_rate=20` Trainer argument.
        pl.pytorch.callbacks.TQDMProgressBar(refresh_rate=20),
    ],
)
trainer.fit(
    KnowledgeDistillationModule(teacher, student),
    datamodule=WikiText103DataModule(dataset),
)

Begin training
