In [1]:
import openai
import torch
import os
from dotenv import load_dotenv
from openai import OpenAI
from datasets import load_dataset
from transformers import GPT2Tokenizer
from transformers import GPT2LMHeadModel, Trainer, TrainingArguments
from torch.nn import functional as F
from torch.utils.data import DataLoader

# Run on the GPU when torch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load variables from a local .env file (if any) into the process environment.
load_dotenv()

# OpenAI client for the teacher model; credentials are read from the environment.
client = OpenAI(
    api_key=os.environ.get('OPENAI_KEY'),
    organization=os.environ.get('ORGANIZATION'),
)

# Mini-batch size handed to the DataLoader.
bs = 256

In [2]:
# GPT-2 (large) tokenizer and the SNLI training split
tokenizer = GPT2Tokenizer.from_pretrained("gpt2-large")
dataset = load_dataset("stanfordnlp/snli", split='train')

In [3]:
# Reuse the end-of-sequence token as the padding token.
tokenizer.pad_token = tokenizer.eos_token


def tokenize_function(examples):
    """Tokenize a batch of rows as one "premise hypothesis" string per row.

    Args:
        examples: mapping with parallel "premise" and "hypothesis" sequences.

    Returns:
        The tokenizer output for the joined strings, truncated and padded
        to exactly 128 tokens per row.
    """
    joined_texts = []
    for premise, hypothesis in zip(examples["premise"], examples["hypothesis"]):
        joined_texts.append(f"{premise} {hypothesis}")

    return tokenizer(
        joined_texts,
        max_length=128,
        truncation=True,
        padding="max_length",
    )


# Apply the tokenizer to the whole split, many rows per call.
tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [4]:
# teacher output

def get_gpt4_logits(premise, hypothesis):
    """Query the GPT-4o teacher for an SNLI pair and return first-token log-probs.

    The model is asked to answer with a single label word. Only the first
    generated position is read: its ``top_logprobs`` alternatives (up to 3)
    plus the token that was actually sampled. Reading just one position
    avoids the earlier problem where a dict built over every generated token
    let repeated tokens overwrite each other and mixed unrelated positions.

    Args:
        premise: premise sentence.
        hypothesis: hypothesis sentence.

    Returns:
        dict[str, float]: token text -> log-probability at the first generated
        position, or an empty dict when the API returned no log-probabilities
        (for example when the response has no content).
    """
    prompt = f"Premise: {premise}\nHypothesis: {hypothesis}\nLabel:"
    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            # Constrain the answer so the first token is the label itself.
            {
                "role": "system",
                "content": (
                    "You label natural language inference pairs. "
                    "Reply with exactly one word: entailment, neutral, or contradiction."
                ),
            },
            {"role": "user", "content": prompt}
        ],
        logprobs=True,
        top_logprobs=3,
        # Only the first token is consumed, so do not pay for more.
        max_tokens=1
    )

    # `logprobs` (or its `content`) can be missing; treat that as "no signal".
    choice_logprobs = response.choices[0].logprobs
    content = choice_logprobs.content if choice_logprobs is not None else None
    if not content:
        return {}

    first_position = content[0]
    token_logprobs = {
        alternative.token: alternative.logprob
        for alternative in (first_position.top_logprobs or [])
    }
    # Make sure the sampled token is present even if it is not among the top alternatives.
    token_logprobs.setdefault(first_position.token, first_position.logprob)
    return token_logprobs

In [5]:
# Student network: pretrained GPT-2 large, moved onto the selected device.
student_model = GPT2LMHeadModel.from_pretrained("gpt2-large").to(device)
student_model

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-35): 36 x GPT2Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2SdpaAttention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=1280, out_features=50257, bias=False)
)

In [6]:
# distillation loss (cross-entropy + kullback-leibler)
def distillation_loss(student_logits, teacher_logits, ground_truth, alpha=0.5, temperature=2.0):
    """Blend hard-label cross-entropy with a temperature-scaled KL term.

    Both logit tensors are flattened to ``(N, num_classes)`` before either
    term is computed, so the class axis is always the last one and the
    ``batchmean`` reduction averages over every scored row. Previously the KL
    term used ``dim=1`` on the unflattened tensors, which normalises over the
    wrong axis for inputs with more than two dimensions.

    Args:
        student_logits: tensor of shape ``(..., num_classes)``.
        teacher_logits: tensor (or nested list) with the same number of
            elements and the same trailing ``num_classes`` size.
        ground_truth: integer class ids, one per scored row; a tensor of any
            shape or a plain list.
        alpha: weight of the cross-entropy term; ``1 - alpha`` weights the KL term.
        temperature: softmax temperature applied to both distributions in the
            KL term. The KL term is rescaled by ``temperature ** 2``.

    Returns:
        Scalar tensor ``alpha * CE + (1 - alpha) * KL``.
    """
    num_classes = student_logits.size(-1)

    # reshape (not view) so non-contiguous inputs are accepted as well.
    student_flat = student_logits.reshape(-1, num_classes)
    teacher_flat = torch.as_tensor(
        teacher_logits, dtype=student_flat.dtype, device=student_flat.device
    ).reshape(-1, num_classes)
    # Accept lists / CPU tensors / non-long integer tensors for the targets.
    targets = torch.as_tensor(
        ground_truth, dtype=torch.long, device=student_flat.device
    ).reshape(-1)

    # Hard-label term.
    loss_ce = F.cross_entropy(student_flat, targets)

    # Soft-label term: kl_div expects log-probabilities as input and
    # probabilities as target.
    loss_kl = F.kl_div(
        F.log_softmax(student_flat / temperature, dim=-1),
        F.softmax(teacher_flat / temperature, dim=-1),
        reduction='batchmean'
    ) * (temperature ** 2)

    return alpha * loss_ce + (1 - alpha) * loss_kl

In [7]:
def custom_collate_fn(batch):
    """Collate dataset rows into one training batch.

    Each row must provide 'premise', 'hypothesis' and 'label'. The text is
    re-tokenized here as "premise hypothesis", so any token columns already
    stored on the rows are ignored.

    Args:
        batch: list of row dicts.

    Returns:
        dict with
          'input_ids' / 'attention_mask': tensors padded to the longest
              sequence in this batch (capped at 256 tokens by truncation),
          'premise' / 'hypothesis': the raw strings, in row order,
          'label': 1-D long tensor of the row labels.
    """
    premises = [item['premise'] for item in batch]
    hypotheses = [item['hypothesis'] for item in batch]
    labels = [item['label'] for item in batch]

    encoding = tokenizer(
        [f"{premise} {hypothesis}" for premise, hypothesis in zip(premises, hypotheses)],
        padding="longest",  # Pad to the longest sequence in the batch
        max_length=256,     # upper bound enforced by truncation
        truncation=True,
        return_tensors='pt'
    )

    return {
        'input_ids': encoding['input_ids'],
        'attention_mask': encoding['attention_mask'],
        'premise': premises,
        'hypothesis': hypotheses,
        # A tensor (not a list) so the loss can consume it and it can be moved to a device.
        'label': torch.tensor(labels, dtype=torch.long)
    }

# something is wrong with the training loop, nothing is ever getting sent to GPU

In [8]:
# custom trainer

class KnowledgeDistillationTrainer:
    """Distil the teacher's 3-way SNLI label distribution into the student LM.

    For every example the teacher's token -> log-prob dict (from
    ``get_gpt4_logits``) is projected onto the three label words, and the
    student's next-token logits at the last real (non-padding) position are
    restricted to the first token of the same three words. Both sides are
    therefore ``(batch, 3)`` tensors, which is what ``distillation_loss``
    needs alongside the ``(batch,)`` class ids.

    NOTE(review): the label order below follows the SNLI dataset card
    (0 = entailment, 1 = neutral, 2 = contradiction, -1 = no gold label) -
    confirm against the loaded dataset's features.
    NOTE(review): right-side padding is assumed when locating the last real
    token - confirm the tokenizer's padding side.
    """

    # Index in this tuple == class id.
    LABEL_WORDS = ("entailment", "neutral", "contradiction")
    # Log-prob used when the teacher reported nothing that looks like a label.
    MISSING_LOGPROB = -20.0
    # Shortest teacher token fragment accepted as the start of a label word.
    MIN_PREFIX_LEN = 3

    def __init__(self, student_model, tokenizer):
        self.student_model = student_model
        self.tokenizer = tokenizer
        # Vocabulary id of the first sub-word of " <label>" (leading space so the
        # word is encoded the way it appears after preceding text).
        self.label_token_ids = [tokenizer.encode(" " + word)[0] for word in self.LABEL_WORDS]
        # (premise, hypothesis) -> teacher scores; avoids re-querying the API every epoch.
        self._teacher_cache = {}

    @classmethod
    def _teacher_label_logprobs(cls, token_logprobs):
        """Project a token -> log-prob dict onto LABEL_WORDS order.

        A teacher token counts for a label when its stripped, lower-cased text
        is a prefix (>= MIN_PREFIX_LEN chars) of the label word; the best such
        log-prob wins. Labels without a match get MISSING_LOGPROB. This is a
        heuristic because the teacher's tokenization is not known here.
        """
        scores = []
        for word in cls.LABEL_WORDS:
            best = cls.MISSING_LOGPROB
            for token, logprob in token_logprobs.items():
                piece = token.strip().lower()
                if len(piece) >= cls.MIN_PREFIX_LEN and word.startswith(piece):
                    best = max(best, logprob)
            scores.append(best)
        return scores

    @staticmethod
    def _as_batch_tensor(value):
        """Return a batch column as a tensor with the batch axis first.

        Handles tensors, plain lists of numbers, and the list of per-position
        ``(batch,)`` tensors that default collation produces for list-valued
        columns (stacked along dim 1 to restore ``(batch, seq)``).
        """
        if torch.is_tensor(value):
            return value
        if isinstance(value, (list, tuple)) and len(value) > 0 and torch.is_tensor(value[0]):
            return torch.stack(list(value), dim=1)
        return torch.as_tensor(value)

    def _teacher_scores(self, premise, hypothesis):
        """Teacher label log-probs for one pair, queried once and then cached."""
        key = (premise, hypothesis)
        if key not in self._teacher_cache:
            self._teacher_cache[key] = self._teacher_label_logprobs(
                get_gpt4_logits(premise, hypothesis)
            )
        return self._teacher_cache[key]

    def train(self, train_loader, num_epochs):
        """Run ``num_epochs`` passes over ``train_loader``, printing losses.

        Args:
            train_loader: iterable of batches with keys 'premise', 'hypothesis',
                'label', 'input_ids' and 'attention_mask'.
            num_epochs: number of passes over ``train_loader``.
        """
        optimizer = torch.optim.AdamW(self.student_model.parameters())
        self.student_model.train()

        # Follow the model's own placement instead of relying on a global.
        model_device = next(self.student_model.parameters()).device
        label_ids = torch.tensor(self.label_token_ids, device=model_device)

        for epoch in range(num_epochs):
            print(f"Epoch {epoch + 1}/{num_epochs}")
            total_loss = 0.0
            num_batches = 0
            for batch in train_loader:
                labels = self._as_batch_tensor(batch['label']).long().view(-1)

                # Negative labels mark rows without a gold label; they cannot be
                # used as cross-entropy targets, so drop them before any API call.
                keep = (labels >= 0).nonzero(as_tuple=True)[0]
                if keep.numel() == 0:
                    continue
                rows = keep.tolist()

                # Teacher side: (kept_rows, 3) label log-probs.
                teacher_rows = [
                    self._teacher_scores(batch['premise'][i], batch['hypothesis'][i])
                    for i in rows
                ]
                teacher_logits = torch.tensor(
                    teacher_rows, dtype=torch.float32, device=model_device
                )

                input_ids = self._as_batch_tensor(batch['input_ids'])[keep].to(model_device)
                attention_mask = self._as_batch_tensor(batch['attention_mask'])[keep].to(model_device)
                ground_truth = labels[keep].to(model_device)

                optimizer.zero_grad()

                # Student side: logits at the last attended position, restricted
                # to the three label tokens -> (kept_rows, 3).
                outputs = self.student_model(input_ids=input_ids, attention_mask=attention_mask)
                last_pos = (attention_mask.sum(dim=1) - 1).clamp(min=0)
                row_idx = torch.arange(input_ids.size(0), device=model_device)
                student_logits = outputs.logits[row_idx, last_pos][:, label_ids]

                loss = distillation_loss(student_logits, teacher_logits, ground_truth)

                loss.backward()
                optimizer.step()

                total_loss += loss.item()
                num_batches += 1
                print(f"Batch loss: {loss.item()}")

            # Average over the batches that were actually trained on.
            avg_loss = total_loss / num_batches if num_batches else 0.0
            print(f"Average loss for epoch {epoch + 1}: {avg_loss:.4f}")

In [11]:
# create trainer
trainer = KnowledgeDistillationTrainer(student_model, tokenizer)
# Collate with custom_collate_fn so 'input_ids' / 'attention_mask' arrive as stacked
# (batch, seq) tensors. Default collation of the list-valued token columns yields a
# list of per-position tensors instead, which cannot be moved with .to(device).
train_loader = DataLoader(
    tokenized_datasets,
    batch_size=bs,
    shuffle=True,
    collate_fn=custom_collate_fn,
)

In [12]:
# Kick off distillation: five passes over the training loader.
trainer.train(train_loader=train_loader, num_epochs=5)

Epoch 1/5


KeyboardInterrupt: 