In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from transformers import AdamW

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# teacher model
teacher_tokenizer = AutoTokenizer.from_pretrained("bigscience/bloom-7b1")
teacher_model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-7b1")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
# student model
student_tokenizer = AutoTokenizer.from_pretrained("openai-community/gpt2-large")
student_model = AutoModelForCausalLM.from_pretrained("openai-community/gpt2-large")

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/666 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/3.25G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [4]:
teacher_model.eval()

BloomForCausalLM(
  (transformer): BloomModel(
    (word_embeddings): Embedding(250880, 4096)
    (word_embeddings_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
    (h): ModuleList(
      (0-29): 30 x BloomBlock(
        (input_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (self_attention): BloomAttention(
          (query_key_value): Linear(in_features=4096, out_features=12288, bias=True)
          (dense): Linear(in_features=4096, out_features=4096, bias=True)
          (attention_dropout): Dropout(p=0.0, inplace=False)
        )
        (post_attention_layernorm): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
        (mlp): BloomMLP(
          (dense_h_to_4h): Linear(in_features=4096, out_features=16384, bias=True)
          (gelu_impl): BloomGelu()
          (dense_4h_to_h): Linear(in_features=16384, out_features=4096, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)
  )


In [5]:
dataset = load_dataset("openai/gsm8k", "socratic", split='train')

Downloading readme:   0%|          | 0.00/7.94k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/2.68M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/487k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/7473 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1319 [00:00<?, ? examples/s]

In [7]:
def preprocess_data(examples):
    return teacher_tokenizer(examples['question'], truncation=True, padding='max_length', max_length=512)

tokenized_dataset = dataset.map(preprocess_data, batched=True, remove_columns=["question"])
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

Map:   0%|          | 0/7473 [00:00<?, ? examples/s]

In [8]:
# Create DataLoader
train_dataloader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True)

# Define optimizer
optimizer = AdamW(student_model.parameters(), lr=5e-5)

# Knowledge Distillation Loss (KL Divergence)
def distillation_loss(y, teacher_scores, temperature):
    return torch.nn.KLDivLoss()(torch.nn.functional.log_softmax(y / temperature, dim=-1),
                                torch.nn.functional.softmax(teacher_scores / temperature, dim=-1)) * (temperature ** 2)



In [9]:
def train(teacher_model, student_model, dataloader, optimizer, device, num_epochs=3, temperature=2.0):
    teacher_model.to(device)
    student_model.to(device)
    
    for epoch in range(num_epochs):
        student_model.train()
        epoch_loss = 0.0
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            
            with torch.no_grad():
                teacher_outputs = teacher_model(input_ids=input_ids, attention_mask=attention_mask)
                teacher_logits = teacher_outputs.logits
            
            student_outputs = student_model(input_ids=input_ids, attention_mask=attention_mask)
            student_logits = student_outputs.logits
            
            loss = distillation_loss(student_logits, teacher_logits, temperature)
            loss.backward()
            
            optimizer.step()
            optimizer.zero_grad()
            
            epoch_loss += loss.item()
        
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {epoch_loss/len(dataloader)}")

In [10]:
train(teacher_model, student_model, train_dataloader, optimizer, device)

KeyboardInterrupt: 