In [None]:
!pip install -q transformers accelerate bitsandbytes datasets peft

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AdamW
from torch.utils.data import DataLoader, Dataset
import pandas as pd
from tqdm import tqdm

In [None]:

# Checkpoints: the large teacher and the small student that learns from it.
TEACHER_ID = "Qwen/Qwen2.5-72B-Instruct"
STUDENT_ID = "Qwen/Qwen2.5-1.5B-Instruct"

# Loss hyperparameters read by distillation_loss_fn.
TEMPERATURE = 2.0  # logits are divided by this before the softmax
ALPHA = 0.5  # weight of the KL term; (1 - ALPHA) weights the cross-entropy

# Data / runtime settings.
MAX_LENGTH = 512  # samples are truncated or padded to this many tokens
BATCH_SIZE = 2
DEVICE = "cuda"  # device the student model and every batch are moved to

# 4-bit NF4 quantization with bfloat16 compute; applied to the teacher only.
bnb_config = BitsAndBytesConfig(
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    load_in_4bit=True,
)

print("Loading Teacher Model (4-bit)...")
teacher_tokenizer = AutoTokenizer.from_pretrained(TEACHER_ID)
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_ID,
    device_map="auto",
    quantization_config=bnb_config,
)
# The teacher is only queried for logits, so keep it in eval mode.
teacher_model.eval()

print("Loading Student Model...")
student_tokenizer = AutoTokenizer.from_pretrained(STUDENT_ID)
student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_ID,
    torch_dtype=torch.bfloat16,
)
student_model = student_model.to(DEVICE)
# The student is the model being optimized.
student_model.train()

def distillation_loss_fn(student_logits, teacher_logits, labels, temperature=None, alpha=None):
    """Return the blended distillation loss for a causal language model.

    The loss is ``alpha * KL(teacher || student) * T**2 + (1 - alpha) * CE``,
    where both terms are averaged over the valid (non-ignored) target tokens.

    Args:
        student_logits: Tensor of shape ``(..., seq_len, student_vocab)``.
        teacher_logits: Tensor of shape ``(..., seq_len, teacher_vocab)``.
            The vocabulary dimension may differ from the student's; only the
            leading ``min(student_vocab, teacher_vocab)`` entries are compared.
        labels: Integer tensor of shape ``(..., seq_len)`` holding the
            *unshifted* token ids. Positions equal to ``-100`` are ignored
            in both loss terms.
        temperature: Softmax temperature. Defaults to the module-level
            ``TEMPERATURE`` when ``None``.
        alpha: Weight of the KL term in ``[0, 1]``. Defaults to the
            module-level ``ALPHA`` when ``None``.

    Returns:
        A scalar float32 tensor. It is ``0`` when no target token is valid.
    """
    temperature = TEMPERATURE if temperature is None else temperature
    alpha = ALPHA if alpha is None else alpha

    # Causal LM alignment: the logits at position t predict the token at
    # position t + 1, so drop the last logit and the first label.
    # Upcast to float32 so softmax/KL are not computed in half precision.
    student_next = student_logits[..., :-1, :].float()
    teacher_next = teacher_logits[..., :-1, :].to(student_next.device).float()
    targets = labels[..., 1:].to(student_next.device)

    student_vocab = student_next.size(-1)
    student_flat = student_next.reshape(-1, student_vocab)
    teacher_flat = teacher_next.reshape(-1, teacher_next.size(-1))
    targets_flat = targets.reshape(-1)

    # Only positions with a real target contribute; clamp the count so an
    # all-ignored batch yields 0 instead of NaN.
    valid = targets_flat != -100
    num_valid = valid.sum().clamp(min=1)

    # Teacher and student checkpoints may size their output layers
    # differently; compare them on the shared leading vocabulary slice.
    shared_vocab = min(student_vocab, teacher_flat.size(-1))
    student_log_probs = F.log_softmax(student_flat[:, :shared_vocab] / temperature, dim=-1)
    teacher_log_probs = F.log_softmax(teacher_flat[:, :shared_vocab] / temperature, dim=-1)

    per_token_kl = F.kl_div(
        student_log_probs,
        teacher_log_probs,
        reduction='none',
        log_target=True,
    ).sum(dim=-1)
    # The T**2 factor keeps gradient magnitudes comparable across temperatures.
    distill_loss = per_token_kl.masked_fill(~valid, 0.0).sum() / num_valid * (temperature ** 2)

    # Hard-label cross-entropy on the full student vocabulary.
    student_loss = F.cross_entropy(
        student_flat,
        targets_flat,
        ignore_index=-100,
        reduction='sum',
    ) / num_valid

    return (alpha * distill_loss) + ((1 - alpha) * student_loss)

class GoldenDataset(Dataset):
    """Dataset that tokenizes one text column of a CSV file, row by row.

    Rows whose text cell is empty/NaN are dropped at load time so the
    tokenizer never receives a non-string value.

    Args:
        csv_path: Path of the CSV file to read.
        tokenizer: Callable invoked as ``tokenizer(text, truncation=True,
            padding='max_length', max_length=..., return_tensors="pt")`` and
            returning a mapping with ``'input_ids'`` and ``'attention_mask'``.
        text_column: Name of the column that holds the text.
        max_length: Fixed token length of every sample. Defaults to the
            module-level ``MAX_LENGTH`` when ``None``.

    Raises:
        KeyError: If ``text_column`` is not a column of the CSV file.
    """

    def __init__(self, csv_path, tokenizer, text_column='Text Chunk', max_length=None):
        frame = pd.read_csv(csv_path)
        if text_column not in frame.columns:
            raise KeyError(
                f"Column {text_column!r} not found in {csv_path}; "
                f"available columns: {list(frame.columns)}"
            )
        # Empty cells are read back as NaN (a float); drop them up front.
        self.data = frame.dropna(subset=[text_column]).reset_index(drop=True)
        self.tokenizer = tokenizer
        self.text_column = text_column
        self.max_length = MAX_LENGTH if max_length is None else max_length

    def __len__(self):
        """Return the number of usable rows."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(input_ids, attention_mask)`` for row ``idx``."""
        # str() guards against columns that pandas parsed as numbers.
        text = str(self.data[self.text_column].iloc[idx])
        encodings = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors="pt"
        )
        # Remove only the leading batch dimension; a bare squeeze() would also
        # collapse the sequence dimension when it has length 1.
        return encodings['input_ids'].squeeze(0), encodings['attention_mask'].squeeze(0)


def run_distillation(csv_file, epochs=3, lr=5e-5, output_dir="./distilled_qwen_1.5b"):
    """Distill the teacher model into the student model and save the result.

    Reads the module-level ``student_model``, ``student_tokenizer``,
    ``teacher_model``, ``BATCH_SIZE`` and ``DEVICE``.

    Args:
        csv_file: Path of the CSV file passed to ``GoldenDataset``.
        epochs: Number of passes over the dataset.
        lr: Learning rate for AdamW.
        output_dir: Directory the student model and tokenizer are saved to.

    Raises:
        ValueError: If the dataset contains no samples.
    """
    dataset = GoldenDataset(csv_file, student_tokenizer)
    if len(dataset) == 0:
        # Fail early with a clear message instead of inside the DataLoader.
        raise ValueError(f"No samples found in {csv_file}")

    loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps and
    # weight_decay are set to the values that class used by default.
    optimizer = torch.optim.AdamW(
        student_model.parameters(), lr=lr, eps=1e-6, weight_decay=0.0
    )

    teacher_model.eval()
    student_model.train()

    print(f"Starting distillation on {len(dataset)} samples...")
    for epoch in range(epochs):
        total_loss = 0.0
        for batch in tqdm(loader):
            input_ids, attention_mask = (tensor.to(DEVICE) for tensor in batch)

            # Padding positions must not act as training targets; -100 is the
            # default ignore_index of F.cross_entropy.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            # The teacher only provides targets, so no gradients are needed.
            with torch.no_grad():
                teacher_logits = teacher_model(input_ids, attention_mask=attention_mask).logits

            student_logits = student_model(input_ids, attention_mask=attention_mask).logits

            # device_map="auto" may leave the teacher output on another
            # device, and the teacher's vocabulary dimension may be larger
            # than the student's; align both before computing the loss.
            teacher_logits = teacher_logits.to(student_logits.device)
            teacher_logits = teacher_logits[..., :student_logits.size(-1)]

            optimizer.zero_grad()
            loss = distillation_loss_fn(student_logits, teacher_logits, labels)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f"Epoch {epoch+1} Complete. Avg Loss: {total_loss/len(loader):.4f}")

    student_model.save_pretrained(output_dir)
    # Save the tokenizer too so the output directory can be loaded on its own.
    student_tokenizer.save_pretrained(output_dir)
    print(f"Distilled model saved to {output_dir}")

# Start distillation on the exported dataset.
# NOTE(review): the path ends in ".csv.csv" - confirm this matches the real file name.
run_distillation("/content/upsc_dataset_langchain_20260213_095213.csv.csv")