In [22]:
pip install transformers



In [23]:
import torch
import torch.nn as nn
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from torch.utils.data import DataLoader, TensorDataset
from transformers import AdamW

teacher_model = GPT2LMHeadModel.from_pretrained("gpt2")
teacher_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

student_model = GPT2LMHeadModel.from_pretrained("gpt2")
student_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

student_tokenizer.add_special_tokens({'pad_token': '[PAD]'})
teacher_tokenizer.add_special_tokens({'pad_token': '[PAD]'})

def distillation_loss(teacher_logits, student_logits):
    return nn.KLDivLoss()(nn.functional.log_softmax(student_logits, dim=1), nn.functional.softmax(teacher_logits, dim=1))

batch_size = 4
learning_rate = 1e-4
num_epochs = 5

input_data = ["The Industrial Revolution was a period of significant economic, technological, and social change that began in the late 18th century and continued into the 19th century. It marked a shift from agrarian and handicraft-based economies to industrial and machine-based economies. This period saw the rapid development of factories, mechanized agriculture, and the use of steam power. It had a profound impact on society, leading to urbanization, changes in labor practices, and increased production. The Industrial Revolution is often considered a turning point in history."]
labels = ["The Industrial Revolution, which started in the late 18th century, brought about significant economic, technological, and social changes. It led to the rise of industrial economies, the use of machinery, and urbanization."]

input_ids = student_tokenizer.encode(input_data[0], return_tensors="pt", max_length=50, padding="max_length", truncation=True)
label_ids = student_tokenizer.encode(labels[0], return_tensors="pt", max_length=50, padding="max_length", truncation=True)

dataset = TensorDataset(input_ids, label_ids)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

optimizer = AdamW(student_model.parameters(), lr=learning_rate)

for epoch in range(num_epochs):
    student_model.train()
    for batch in dataloader:
        optimizer.zero_grad()
        input_ids_batch, label_ids_batch = batch
        student_logits = student_model(input_ids_batch).logits
        teacher_logits = teacher_model(input_ids_batch).logits

        loss = distillation_loss(teacher_logits, student_logits)
        loss.backward()
        optimizer.step()

student_model.save_pretrained("student_model")



In [24]:
input_text = "Once upon a time, in a land far, far away, "
input_ids = student_tokenizer.encode(input_text, return_tensors="pt")

generated_text = student_model.generate(input_ids, max_length=100, num_return_sequences=1, no_repeat_ngram_size=2, top_k=50, top_p=0.95)
generated_text = student_tokenizer.decode(generated_text[0], skip_special_tokens=True)
print(generated_text)

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


Once upon a time, in a land far, far away,  I was in the middle of a great deal of trouble. I was a little bit of an idiot.
I had a lot of problems with my life. And I had problems. But I didn't have a problem with the world. It was all about the things that I wanted to do. So I just wanted it to be all right. That was the way I got it. The way. You know, I
