In [3]:
from transformers import AutoTokenizer, AutoModelForCausalLM
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from transformers import AdamW

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's RNGs (torch.manual_seed covers CPU and every CUDA device) so DataLoader
# shuffling and any dropout active while the student trains are reproducible across
# Restart & Run All. Trailing ';' suppresses the Generator repr in the cell output.
SEED = 42
torch.manual_seed(SEED);

In [4]:
# teacher model
# bloom-7b1 has roughly 7.1B parameters (per the checkpoint name), so float32 weights alone
# need ~28 GB. The teacher is only used for inference, so load it in half precision when a
# GPU is available; on CPU keep float32 because many half-precision CPU kernels are slow or missing.
TEACHER_CHECKPOINT = "bigscience/bloom-7b1"
teacher_dtype = torch.float16 if device.type == "cuda" else torch.float32
teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_CHECKPOINT)
teacher_model = AutoModelForCausalLM.from_pretrained(TEACHER_CHECKPOINT, torch_dtype=teacher_dtype)

tokenizer_config.json:   0%|          | 0.00/222 [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/14.5M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/85.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/739 [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/28.9k [00:00<?, ?B/s]

Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/9.98G [00:00<?, ?B/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/4.16G [00:00<?, ?B/s]

: 

In [None]:
# student model
# Logit-level distillation compares teacher and student next-token distributions position by
# position, and the dataset below is tokenized with the *teacher's* tokenizer. The student must
# therefore share the teacher's vocabulary. gpt2-large does not: its embedding has 50257 rows
# versus BLOOM's 250880, so BLOOM token ids would index past it and the two logit tensors
# would have different widths. A small BLOOM checkpoint shares the tokenizer.
STUDENT_CHECKPOINT = "bigscience/bloom-560m"
student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_CHECKPOINT)
student_model = AutoModelForCausalLM.from_pretrained(STUDENT_CHECKPOINT)

# Fail here, not deep inside the training loop, if someone swaps in an incompatible student.
assert student_model.config.vocab_size == teacher_model.config.vocab_size, (
    "student and teacher must share a vocabulary for logit distillation"
)

In [None]:
# The teacher only supplies targets: eval() disables dropout, and freezing its parameters
# ensures no gradients are accumulated for its weights even if it is ever called outside
# torch.no_grad(). Trailing ';' keeps the (very long) module repr out of the cell output.
teacher_model.eval()
teacher_model.requires_grad_(False);

In [None]:
# WikiText stores articles one line per row, so many rows are blank (NOTE(review): verify the
# proportion on your copy). Tokenized with padding='max_length', a blank row becomes a
# sequence that is entirely padding yet still costs a full teacher + student forward pass,
# so drop such rows up front.
raw_dataset = load_dataset('wikitext', 'wikitext-103-v1', split='train')
dataset = raw_dataset.filter(lambda example: example['text'].strip() != "")

In [None]:
def preprocess_data(examples):
    """Tokenize a batch of raw rows with the teacher's tokenizer.

    Every row is truncated or padded to exactly 512 tokens. No collate_fn is given to the
    DataLoader later, so rows must already share one fixed length to be stacked.
    """
    texts = examples['text']
    return teacher_tokenizer(
        texts,
        max_length=512,
        padding='max_length',
        truncation=True,
    )


tokenized_dataset = dataset.map(
    preprocess_data,
    batched=True,
    remove_columns=["text"],
)
# Expose only the model inputs, as torch tensors.
tokenized_dataset.set_format(type='torch', columns=['input_ids', 'attention_mask'])

In [None]:
# Create DataLoader
train_dataloader = DataLoader(tokenized_dataset, batch_size=8, shuffle=True)

# Define optimizer
# transformers.AdamW is deprecated in favour of torch.optim.AdamW. The torch version has
# different defaults (weight_decay=0.01, eps=1e-8), so pin the values the transformers
# optimizer used (weight_decay=0.0, eps=1e-6) to keep the training hyperparameters unchanged.
optimizer = torch.optim.AdamW(student_model.parameters(), lr=5e-5, weight_decay=0.0, eps=1e-6)

# Knowledge Distillation Loss (KL Divergence)
def distillation_loss(y, teacher_scores, temperature, attention_mask=None):
    """Temperature-scaled KL(teacher || student), averaged over token positions.

    Both logit tensors are softened by ``temperature``. The KL divergence is summed over the
    last (vocabulary) axis, giving one value per position. The result is the mean over
    positions, multiplied by ``temperature ** 2`` so gradient magnitudes stay comparable
    across temperatures.

    Note: ``nn.KLDivLoss()`` with its default ``reduction='mean'`` also averages over the
    vocabulary axis. That shrinks the loss and its gradients by a factor of the vocabulary
    size, and PyTorch documents that reduction as not yielding the true KL divergence.

    Args:
        y: student logits; the last axis is the vocabulary.
        teacher_scores: teacher logits, same shape as ``y`` (may be half precision).
        temperature: softmax temperature (> 0).
        attention_mask: optional 0/1 tensor shaped like ``y.shape[:-1]``. Positions where it
            is 0 (padding) are excluded from the average. If every position is masked,
            the loss is 0.

    Returns:
        Scalar float32 tensor. It is 0 for empty input instead of NaN.
    """
    # Work in float32 even when the teacher runs in fp16: (log-)softmax over a large
    # vocabulary is numerically fragile in half precision.
    student_log_probs = torch.nn.functional.log_softmax(y.float() / temperature, dim=-1)
    teacher_log_probs = torch.nn.functional.log_softmax(teacher_scores.float() / temperature, dim=-1)
    # log_target=True consumes the teacher log-probs directly, which stays finite where
    # teacher probabilities underflow to 0.
    per_position_kl = torch.nn.functional.kl_div(
        student_log_probs, teacher_log_probs, reduction="none", log_target=True
    ).sum(dim=-1)

    if attention_mask is None:
        mean_kl = per_position_kl.sum() / max(per_position_kl.numel(), 1)
    else:
        weights = attention_mask.to(per_position_kl.dtype)
        mean_kl = (per_position_kl * weights).sum() / weights.sum().clamp_min(1.0)
    return mean_kl * (temperature ** 2)

In [None]:
def train(teacher_model, student_model, dataloader, optimizer, device, num_epochs=3, temperature=2.0):
    """Distil ``teacher_model`` into ``student_model`` by matching softened next-token logits.

    For each batch, the teacher runs under ``torch.no_grad()`` and the student runs with
    gradients. ``distillation_loss`` between their logits is then minimised with
    ``optimizer``. Only non-padding positions (``attention_mask == 1``) contribute to the loss.
    Prints the mean batch loss after each epoch.

    Args:
        teacher_model: causal LM providing target logits; it is put in eval mode.
        student_model: causal LM being trained; put in train mode every epoch.
        dataloader: yields dicts with ``input_ids`` and ``attention_mask`` tensors.
        optimizer: optimizer over ``student_model``'s parameters.
        device: device that both models and every batch are moved to.
        num_epochs: number of passes over ``dataloader``.
        temperature: softmax temperature forwarded to ``distillation_loss``.

    Raises:
        ValueError: if the two models have different vocabulary sizes. Logit distillation
            over shared input ids is then impossible, because the models must share a tokenizer.
    """
    teacher_vocab = teacher_model.config.vocab_size
    student_vocab = student_model.config.vocab_size
    if teacher_vocab != student_vocab:
        # Without this check the failure surfaces later as an embedding IndexError (or a
        # CUDA device-side assert) or a logits shape mismatch, both hard to diagnose.
        raise ValueError(
            f"Teacher and student vocabularies differ ({teacher_vocab} vs {student_vocab}); "
            "logit distillation requires both models to share a tokenizer."
        )

    teacher_model.to(device)
    student_model.to(device)
    # The teacher only provides targets: make sure dropout is off regardless of prior state.
    teacher_model.eval()

    for epoch in range(num_epochs):
        student_model.train()
        epoch_loss = 0.0
        num_batches = 0
        for batch in dataloader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Positions that hold real tokens; padding must not be distilled.
            keep = attention_mask.bool()
            if not keep.any():
                continue  # all-padding batch: nothing to learn, skip both forward passes

            with torch.no_grad():
                teacher_logits = teacher_model(input_ids=input_ids, attention_mask=attention_mask).logits

            student_logits = student_model(input_ids=input_ids, attention_mask=attention_mask).logits

            # Boolean indexing flattens the kept positions to (num_tokens, vocab). Padding is
            # thereby excluded, and the loss is not diluted by pad-to-512 filler.
            loss = distillation_loss(student_logits[keep], teacher_logits[keep], temperature)

            # Clear any stale gradients (e.g. from an interrupted earlier run) before backward.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            epoch_loss += loss.item()
            num_batches += 1

        # Guard against an empty dataloader (or one made only of padding).
        mean_loss = epoch_loss / num_batches if num_batches else float('nan')
        print(f"Epoch {epoch+1}/{num_epochs} - Loss: {mean_loss}")

In [None]:
# Run distillation with the default schedule (num_epochs=3, temperature=2.0).
train(
    teacher_model,
    student_model,
    dataloader=train_dataloader,
    optimizer=optimizer,
    device=device,
)