<a href="https://colab.research.google.com/github/Loki-33/distillX/blob/main/Distillation_StepByStep.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertTokenizer, BertModel

In [None]:
from transformers import BertForSequenceClassification

In [None]:
# Teacher: a DistilBERT checkpoint that is already fine-tuned on SST-2, loaded
# through the Auto class so the architecture matches the checkpoint.
# Loading a DistilBERT checkpoint into BertForSequenceClassification does not
# map its weights onto the BERT layers, which leaves the teacher effectively
# untrained and gives the student nothing meaningful to imitate.
from transformers import AutoModelForSequenceClassification

teacher_model = AutoModelForSequenceClassification.from_pretrained(
    'distilbert-base-uncased-finetuned-sst-2-english', num_labels=2
)
# Student: a much smaller BERT with a freshly initialised 2-class head.
student_model  = BertForSequenceClassification.from_pretrained('prajjwal1/bert-tiny', num_labels=2)

In [None]:
import datasets

In [None]:
# Load only the training split of the 'sst2' dataset.
dataset = datasets.load_dataset('sst2', split='train')

In [None]:
# Notebook display: show the dataset object's summary as cell output.
dataset

In [None]:
# Notebook display: inspect one raw example (row index 4).
dataset[4]

In [None]:
# Tokenizer for the 'distilbert-base-uncased' checkpoint. The token ids it
# produces are fed to both the teacher and the student in the training loop.
tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

In [None]:
# Fall back to the EOS token for padding when the tokenizer defines no pad token.
# NOTE(review): BERT-style tokenizers normally ship with a pad token, so this
# branch is presumably never taken for this checkpoint — confirm.
if tokenizer.pad_token is None:
  tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Notebook display: check which pad token is in effect.
tokenizer.pad_token

In [None]:
# Notebook display: list the dataset's columns.
dataset.column_names

In [None]:
# Drop the 'idx' column; neither the tokenization step nor collate_fn reads it.
dataset = dataset.remove_columns('idx')

In [None]:
# Tokenize the 'sentence' column in batches. No padding is requested here, so
# sequences keep their own lengths; collate_fn pads them per batch later.
tokenized_dataset = dataset.map(lambda x: tokenizer(x['sentence']), batched=True)

In [None]:
# Notebook display: inspect the first tokenized example.
tokenized_dataset[0]

In [None]:
def collate_fn(examples):
  """Collate tokenized examples into one padded batch of tensors.

  Each example is a mapping with 'input_ids', 'attention_mask' and 'label'.
  Token ids are right-padded with the tokenizer's pad id and attention masks
  with 0, both up to the longest sequence in the batch.

  Returns a dict with 'input_ids', 'labels' and 'attention_mask' tensors.
  """
  id_seqs, mask_seqs, targets = [], [], []
  for example in examples:
    id_seqs.append(torch.tensor(example['input_ids']))
    mask_seqs.append(torch.tensor(example['attention_mask']))
    targets.append(example['label'])

  pad = torch.nn.utils.rnn.pad_sequence
  batch = {}
  batch['input_ids'] = pad(id_seqs, batch_first=True, padding_value=tokenizer.pad_token_id)
  batch['labels'] = torch.tensor(targets)
  # Padded positions get mask value 0 so the models ignore them.
  batch['attention_mask'] = pad(mask_seqs, batch_first=True, padding_value=0)
  return batch

In [None]:
from torch.utils.data import DataLoader

In [None]:
# Shuffled mini-batches of 8 examples; collate_fn pads each batch to its longest sequence.
train_dataloader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True, collate_fn=collate_fn)

In [None]:
# Notebook display: pull one batch to sanity-check the collated tensors.
next(iter(train_dataloader))

In [None]:
import torch.nn.functional as F

In [None]:
# Only the student's parameters are optimised; the teacher is never updated.
optimizer = optim.Adam(student_model.parameters(), lr=1e-5)

In [None]:
# Knowledge-distillation training loop.
# The student is trained on a weighted sum of
#   * the cross-entropy against the ground-truth labels (hard targets), and
#   * the KL divergence between temperature-softened teacher and student
#     distributions (soft targets), scaled by T**2 to keep gradient magnitudes
#     comparable across temperatures.
T = 2.0      # softmax temperature applied to both teacher and student logits
alpha = 0.5  # weight of the hard-label loss; (1 - alpha) weights the KL term

# from_pretrained returns models in eval mode, so the student must be switched
# to train mode explicitly (enables dropout); the teacher stays in eval mode.
teacher_model.eval()
student_model.train()

for epoch in range(100):
  tloss = 0
  for batch in train_dataloader:
    input_ids = batch['input_ids']
    attention_mask = batch['attention_mask']
    labels = batch['labels']

    # The teacher only provides targets, so no gradients are needed for it.
    with torch.no_grad():
      teacher_logits = teacher_model(input_ids, attention_mask=attention_mask).logits
    student_logits = student_model(input_ids, attention_mask=attention_mask).logits

    soft_teacher_probs = F.softmax(teacher_logits / T, dim=-1)
    log_soft_student_probs = F.log_softmax(student_logits / T, dim=-1)
    kl_loss = F.kl_div(log_soft_student_probs, soft_teacher_probs, reduction='batchmean') * (T ** 2)

    # Labels are class ids (0/1), not token ids, so no ignore_index is passed.
    # Using the tokenizer's pad id (0 for BERT-style vocabularies) would drop
    # every label-0 example from the loss and give NaN on all-zero batches.
    ce_loss = F.cross_entropy(student_logits, labels)

    loss = alpha * ce_loss + (1 - alpha) * kl_loss
    tloss += loss.item()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
  # Reported loss is the sum of per-batch losses over the epoch.
  print(f'Epoch: {epoch+1}, Loss: {tloss:.4f}')