In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from datasets import load_dataset
from transformers import AutoModelForCausalLM, AutoModelForSeq2SeqLM, AutoTokenizer
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from torch.nn.functional import normalize
import torch.nn.functional as F




# Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Teacher and student model names
teacher_model_name = "meta-llama/Llama-2-7b-hf"
student_model_name = "distilbert/distilgpt2"

# Load teacher and student models
teacher_model = AutoModelForCausalLM.from_pretrained(teacher_model_name).to(device)
student_model = AutoModelForCausalLM.from_pretrained(student_model_name).to(device)

Using device: cuda


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [2]:
class StudentWithLogitsProjection(nn.Module):
    def __init__(self, student_model, teacher_vocab_size):
        super().__init__()
        self.student_model = student_model
        self.projection = nn.Linear(student_model.config.vocab_size, teacher_vocab_size)

    def forward(self, input_ids, **kwargs):
        # Get student logits
        outputs = self.student_model(input_ids, **kwargs)
        logits = outputs.logits

        # Project logits to match teacher's vocabulary size
        projected_logits = self.projection(logits)
        return {"logits": projected_logits}


teacher_vocab_size = teacher_model.config.vocab_size
student_model = StudentWithLogitsProjection(student_model, teacher_vocab_size).to(device)


# Load tokenizers
teacher_tokenizer = AutoTokenizer.from_pretrained(teacher_model_name)
student_tokenizer = AutoTokenizer.from_pretrained(student_model_name)
# Set padding tokens for tokenizers if not already defined
if teacher_tokenizer.pad_token is None:
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token
if student_tokenizer.pad_token is None:
    student_tokenizer.pad_token = student_tokenizer.eos_token

print("Models loaded successfully!")


Models loaded successfully!


In [3]:
log_file_path = "./llm_log.txt"
with open(log_file_path, "a") as log_file:
    log_file.write(f"============= NEW RUN =============\n")

In [4]:
dataset = load_dataset("tiny_shakespeare")

def split_by_newline(example):
    text = example["text"][0]
    lines = text.split("\n")
    return {"text": lines}
    
# Apply splitting
split_dataset = dataset.map(split_by_newline, batched=True, remove_columns=["text"])
split_dataset = split_dataset.flatten()  # Flatten nested lists into individual examples

# Tokenize the dataset
def tokenize_function(examples):
    teacher_tokens = teacher_tokenizer(
        examples["text"], truncation=True, padding="max_length", max_length=128
    )
    student_tokens = student_tokenizer(
        examples["text"], truncation=True, padding="max_length", max_length=128
    )
    return {
        "teacher_input_ids": teacher_tokens["input_ids"],
        "student_input_ids": student_tokens["input_ids"],
    }

tokenized_dataset = split_dataset.map(tokenize_function, batched=True, remove_columns=["text"])
tokenized_dataset.set_format(type="torch")

# Define a custom collate function for padding
def custom_collate_fn(batch):
    teacher_input_ids = [torch.tensor(example["teacher_input_ids"]) for example in batch]
    student_input_ids = [torch.tensor(example["student_input_ids"]) for example in batch]

    # Pad sequences
    teacher_input_ids = pad_sequence(teacher_input_ids, batch_first=True, padding_value=teacher_tokenizer.pad_token_id)
    student_input_ids = pad_sequence(student_input_ids, batch_first=True, padding_value=student_tokenizer.pad_token_id)

    return {
        "teacher_input_ids": teacher_input_ids,
        "student_input_ids": student_input_ids,
    }

# Prepare DataLoader
train_loader = DataLoader(
    tokenized_dataset["train"], 
    batch_size=8, 
    shuffle=True, 
    collate_fn=custom_collate_fn
)

# Verify the batch structure
for batch in train_loader:
    print("Teacher input shape:", batch["teacher_input_ids"].shape)  # Should be [batch_size, seq_len]
    print("Student input shape:", batch["student_input_ids"].shape)  # Should be [batch_size, seq_len]
    break


Teacher input shape: torch.Size([8, 128])
Student input shape: torch.Size([8, 128])


  teacher_input_ids = [torch.tensor(example["teacher_input_ids"]) for example in batch]
  student_input_ids = [torch.tensor(example["student_input_ids"]) for example in batch]


In [5]:
optimizer = optim.Adam(student_model.parameters(), lr=1e-4)

In [6]:
def train_one_epoch(teacher, student, dataloader, optimizer, device, temperature=1.0):
    """
    Train the student model using KL Divergence on logits with a progress bar.

    Args:
        teacher (nn.Module): The teacher model.
        student (nn.Module): The student model with logits projection.
        dataloader (DataLoader): Dataloader for training.
        optimizer (torch.optim.Optimizer): Optimizer for student model.
        device (torch.device): Device to perform computation on.
        temperature (float): Temperature for distillation.

    Returns:
        float: Average training loss for the epoch.
    """
    teacher.eval()  # Teacher is frozen
    student.train()
    total_loss = 0

    # Add tqdm for progress tracking
    for batch in tqdm(dataloader, desc="Training Epoch", leave=False):
        optimizer.zero_grad()

        # Move data to device
        student_input_ids = batch["student_input_ids"].to(device)
        teacher_input_ids = batch["teacher_input_ids"].to(device)

        # Get logits from teacher and student
        with torch.no_grad():
            teacher_logits = teacher(teacher_input_ids).logits / temperature  # Scaled teacher logits
        student_logits = student(student_input_ids)["logits"] / temperature  # Scaled student logits

        # Compute KL Divergence Loss
        loss = F.kl_div(
            F.log_softmax(student_logits, dim=-1),
            F.softmax(teacher_logits, dim=-1),
            reduction="batchmean"
        )

        # Backpropagation
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    return total_loss / len(dataloader)


In [14]:
def compute_cosine_similarity_logits(student, teacher, dataloader, device, data_fraction=1.0):
    student.eval()
    teacher.eval()
    total_similarity = 0
    count = 0

    # Limit the number of batches based on data_fraction
    total_batches = int(len(dataloader) * data_fraction)
    student = student.to(device)
    teacher = teacher.to(device)

    with torch.no_grad():
        for i, batch in enumerate(tqdm(dataloader, total=total_batches, desc="Computing Cosine Similarity (Logits)")):
            if i >= total_batches:
                break

            # Move data to device
            student_input_ids = batch["student_input_ids"].to(device)
            teacher_input_ids = batch["teacher_input_ids"].to(device)

            # Get logits from teacher and student
            teacher_logits = teacher(teacher_input_ids).logits
            student_logits = student(student_input_ids)["logits"]

            # Normalize and compute cosine similarity
            teacher_norm = normalize(teacher_logits, p=2, dim=-1)
            student_norm = normalize(student_logits, p=2, dim=-1)
            similarity = (teacher_norm * student_norm).sum(dim=-1).mean().item()
            total_similarity += similarity
            count += 1

    avg_similarity = total_similarity / count
    return avg_similarity


In [15]:
# Evaluate the student model
avg_cosine_similarity = compute_cosine_similarity_logits(
    student_model, teacher_model, train_loader, device, data_fraction=0.1
)

with open(log_file_path, "a") as log_file:
    log_file.write(f"Average Cosine Similarity Before Training: {avg_cosine_similarity:.4f}\n")


  teacher_input_ids = [torch.tensor(example["teacher_input_ids"]) for example in batch]
  student_input_ids = [torch.tensor(example["student_input_ids"]) for example in batch]
Computing Cosine Similarity (Logits): 100%|██████████| 444/444 [07:19<00:00,  1.01it/s]


In [16]:
import time
import torch
from tqdm import tqdm  # Import tqdm for progress bars

def benchmark_model(model, tokenizer, device, input_length=128, runs=100):
    model = model.to(device)
    model.eval()

    # Generate random tokenized input
    input_ids = torch.randint(
        low=0, high=tokenizer.vocab_size, size=(1, input_length), device=device
    )
    attention_mask = torch.ones_like(input_ids, device=device)

    # Warm-up runs
    for _ in tqdm(range(10), desc="Warm-up runs", leave=False):
        _ = model(input_ids=input_ids, attention_mask=attention_mask)

    # Benchmark
    if device.type == 'cuda':
        torch.cuda.synchronize()
    start_time = time.time()

    for _ in tqdm(range(runs), desc="Benchmarking", leave=False):
        _ = model(input_ids=input_ids, attention_mask=attention_mask)

    if device.type == 'cuda':
        torch.cuda.synchronize()
    elapsed_time = (time.time() - start_time) / runs
    return elapsed_time * 1000  # Convert to milliseconds

In [None]:
# Benchmark function
def benchmark_with_logging(teacher_model, student_model, teacher_tokenizer, student_tokenizer, file_path="benchmark_log.txt"):
    device_cpu = torch.device("cpu")
    device_cuda = torch.device("cuda") if torch.cuda.is_available() else None

    # Log results to file
    with open(file_path, "a") as log_file:
        log_file.write("Benchmarking before training...\n")
        
        # CPU Benchmarking
        llama_time_cpu = benchmark_model(teacher_model, teacher_tokenizer, device_cpu, 10)
        gpt_time_cpu = benchmark_model(student_model, student_tokenizer, device_cpu, 10)
        log_file.write(f"LLaMA Inference Time (CPU): {llama_time_cpu:.2f} ms\n")
        log_file.write(f"DistilGPT2 Inference Time (CPU): {gpt_time_cpu:.2f} ms\n")

        # CUDA Benchmarking
        if device_cuda:
            llama_time_cuda = benchmark_model(teacher_model, teacher_tokenizer, device_cuda)
            gpt_time_cuda = benchmark_model(student_model, student_tokenizer, device_cuda)
            log_file.write(f"LLaMA Inference Time (CUDA): {llama_time_cuda:.2f} ms\n")
            log_file.write(f"DistilGPT2 Inference Time (CUDA): {gpt_time_cuda:.2f} ms\n")
        else:
            log_file.write("CUDA is not available.\n")

benchmark_with_logging(teacher_model, student_model, teacher_tokenizer, student_tokenizer, log_file_path)

Benchmarking:  45%|████▌     | 45/100 [01:20<01:38,  1.79s/it]

In [None]:
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)

# Open a log file
with open(log_file_path, "a") as log_file:
    # Train the student model
    num_epochs = 10
    for epoch in range(num_epochs):
        log_file.write(f"Epoch {epoch+1}/{num_epochs}\n")
        train_loss = train_one_epoch(teacher_model, student_model, train_loader, optimizer, device)
        scheduler.step()
        log_file.write(f"Training Loss: {train_loss:.4f}\n")
        avg_cosine_similarity = evaluate(student_model, teacher_model, train_loader, device)
        log_file.write(f"Average Cosine Similarity for Epoch {epoch+1}: {avg_cosine_similarity:.4f}\n")

In [None]:
save_path = "./distillgpt_student.pth"
torch.save(student_model.state_dict(), save_path)
print(f"DistillGPT student model saved to {save_path}")