### **Import necessary libraries**
- *torch*: For model training and inference.
- *transformers*: For loading pre-trained models, tokenizers, and training utilities.
- *datasets*: For loading and processing datasets.
- *evaluate*: For metrics like code_eval (Pass@k) and BLEU.
- *wandb*: For logging experiments (metrics, hyperparameters).
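- *Other utils*: `time` for latency measurement, `pathlib`/`os` for file handling and environment variables, `json` for pretty-printing samples, `pandas` for result tables, and `kaggle_secrets` for reading the W&B API key.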

In [None]:
import torch
from torch.utils.data import DataLoader
import torch.nn.functional as F
from torch.optim import AdamW
from transformers import AutoTokenizer, AutoModelForCausalLM, GPT2Config, GPT2LMHeadModel
from datasets import load_dataset, Dataset
import evaluate
import wandb
import time
from pathlib import Path
import os
import pandas as pd
import json
from kaggle_secrets import UserSecretsClient
import warnings

# Suppress unnecessary warnings
warnings.filterwarnings("ignore")

# Point XLA at the CUDA installation. NOTE(review): presumably this quiets XLA/libdevice
# warnings from JAX/TF builds preinstalled on Kaggle images — confirm it is still needed.
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/usr/lib/cuda"
# The `code_eval` metric refuses to execute model-generated code unless this is "1".
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
# Disable Hugging Face tokenizers' internal parallelism (avoids fork warnings/deadlocks).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# CUDA debugging aids, opt-in. CUDA_LAUNCH_BLOCKING=1 makes every kernel launch synchronous:
# stack traces for device-side asserts become accurate, but training and generation slow down
# substantially. TORCH_USE_CUDA_DSA is a PyTorch build-time option, so setting it at runtime
# likely has no effect on prebuilt wheels. Flip CUDA_DEBUG to True only while debugging.
CUDA_DEBUG = False
if CUDA_DEBUG:
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
    os.environ["TORCH_USE_CUDA_DSA"] = "1"

In [None]:
# Export the W&B API key stored in Kaggle Secrets so wandb.init() can authenticate
# without an interactive login. Failure is reported but deliberately not fatal.
try:
    user_secrets = UserSecretsClient()
    wandb_api_key = user_secrets.get_secret("WANDB_API_KEY")
    os.environ["WANDB_API_KEY"] = wandb_api_key
except Exception as secret_error:
    print(f"Error setting W&B API key: {secret_error}")
    print("Please ensure WANDB_API_KEY is added in Kaggle Secrets (Add-ons > Secrets)")
else:
    print("W&B API key set successfully from Kaggle Secrets")

### **Load and inspect datasets**
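- Load the MBPP (Mostly Basic Python Problems) dataset from Hugging Face for training (train split: 374 examples).
- Print dataset info and one sample from each dataset for verification.
- HumanEval (164 hand-written coding problems with unit tests and canonical solutions) is used for final evaluation; it also serves as the validation set for the loss curves during training.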

In [None]:
# MBPP train split for distillation; HumanEval test split for validation and final evaluation.
train_dataset = load_dataset("google-research-datasets/mbpp", split="train")
eval_dataset = load_dataset("openai/openai_humaneval", split="test")

# Sanity check: dataset summary followed by its first example, for each split in turn.
# MBPP entries carry task_id, text (problem), code (solution) and test_list (tests).
for split_name, split_data in (("Train", train_dataset), ("Eval", eval_dataset)):
    print(f"{split_name} Dataset Info:", split_data)
    print(json.dumps(split_data[0], indent=4))

### **Preprocess functions for tokenization**

In [None]:
def _format_tests(tests):
    """Render a test field as prompt text: a list of test strings is joined one per line,
    a single string (e.g. HumanEval's `test`) is used as-is."""
    if isinstance(tests, str):
        return tests
    return "\n".join(tests)


def _format_prompt(problem, tests):
    """Build the shared instruction-style prompt used for training and validation."""
    return f"Instruction:\n{problem}\n\nTest list:\n{_format_tests(tests)}\n\nResponse:\n"


def _tokenize_prompt_and_target(prompts, targets, tokenizer, max_length=512):
    """
    Tokenize each prompt+target pair as ONE causal-LM sequence whose labels are its own input_ids.

    A decoder-only model shifts labels against logits internally, so labels must be
    position-aligned with input_ids. Tokenizing the prompt and the solution separately would
    pair prompt token i with solution token i: the model would be trained to predict the
    solution from unrelated prompt positions and would never see the solution as input.

    Returns:
        dict: input_ids, attention_mask and labels, each padded/truncated to max_length.
    """
    # Append EOS so the model learns where a solution ends (lets generation stop on its own).
    eos = tokenizer.eos_token or ""
    texts = [f"{prompt}{target}{eos}" for prompt, target in zip(prompts, targets)]
    model_inputs = tokenizer(texts, max_length=max_length, truncation=True, padding="max_length", return_tensors="pt")
    # NOTE: prompt and padding positions stay in the labels (padding is EOS when the tokenizer
    # had no pad token). Masking them with -100 would be preferable, but the training loop
    # currently asserts 0 <= labels < vocab_size, so that check must be relaxed first.
    model_inputs["labels"] = model_inputs["input_ids"].clone()
    return model_inputs


def preprocess_training_data_function(examples, tokenizer):
    """
    Preprocess training data by formatting prompts and tokenizing.

    Args:
        examples (dict): Batch of dataset examples with 'text' (problem), 'test_list' (tests), 'code' (target).
        tokenizer (AutoTokenizer): Tokenizer for encoding text to tokens.

    Returns:
        dict: Tokenized inputs (input_ids, attention_mask) and labels aligned with input_ids.

    Details:
        - Formats each example as "Instruction: [problem] Test list: [tests] Response: [code]<eos>".
        - Tokenizes with max_length=512 (student context limit), truncation/padding.
        - Labels are a copy of input_ids, the standard causal-LM objective.
        - Uses PyTorch tensors for direct model input.
    """
    prompts = [_format_prompt(problem, tests) for problem, tests in zip(examples["text"], examples["test_list"])]
    return _tokenize_prompt_and_target(prompts, examples["code"], tokenizer)


def preprocess_evaluation_data_function(examples, tokenizer):
    """
    Preprocess evaluation data exactly like training data, for consistency.

    Args:
        examples (dict): Batch with 'prompt' (problem), 'test' (tests), 'canonical_solution' (targets).
        tokenizer (AutoTokenizer): Tokenizer for encoding.

    Returns:
        dict: Tokenized inputs and labels aligned with input_ids.

    Details:
        - Same prompt template and tokenization as training.
        - Used for validation loss computation during training.
    """
    prompts = [_format_prompt(problem, tests) for problem, tests in zip(examples["prompt"], examples["test"])]
    return _tokenize_prompt_and_target(prompts, examples["canonical_solution"], tokenizer)

### **Load tokenizer and tokenize datasets**
- Use StarCoder2-3B tokenizer (code-specialized).
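- If the tokenizer has no pad token, reuse eos_token so batches can be padded; padding goes on the left, which batched generation with decoder-only models requires.
- Apply the preprocessing functions with batched `map` for efficiency and drop the original text columns.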

In [None]:
checkpoint = "bigcode/starcoder2-3b"
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    use_fast=True,
    padding_side="left",  # keep prompts flush against the tokens generate() appends
)

# Reuse EOS as the padding token when the checkpoint defines none.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

In [None]:
# Display the tokenizer's repr as the cell output, as a quick sanity check of its configuration.
tokenizer

In [None]:
# Tokenize both splits in batches; the raw columns are dropped so only model inputs remain.
tokenized_train = train_dataset.map(
    preprocess_training_data_function,
    fn_kwargs={"tokenizer": tokenizer},
    batched=True,
    remove_columns=train_dataset.column_names,
)
tokenized_eval = eval_dataset.map(
    preprocess_evaluation_data_function,
    fn_kwargs={"tokenizer": tokenizer},
    batched=True,
    remove_columns=eval_dataset.column_names,
)

In [None]:
# Display both tokenized datasets (features and row counts) as the cell output.
tokenized_train, tokenized_eval

### **Load teacher model**
- StarCoder2-3B as teacher (pre-trained on code).
- Use bfloat16 for memory efficiency and automatic device placement (`device_map="auto"`, GPU when available).
- Switch to eval mode (disables dropout); during distillation the teacher only supplies soft targets.

In [None]:
# Teacher: StarCoder2-3B weights in bfloat16, placed on available devices via device_map="auto".
teacher_model = AutoModelForCausalLM.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
# Inference only: eval mode turns off dropout so the soft targets are deterministic.
teacher_model.eval()

### **Define student configurations**
- GPT-2 architecture at medium size (~202M parameters).
- Each config sets depth (`n_layer`), width (`n_embd`) and attention heads (`n_head`); add entries to the list to compare variants.
- Print parameter counts for reference.

In [None]:
# Student architectures to distill into: a display name plus a GPT-2 config per entry.
student_configs = [
    {
        "name": "medium-model",
        "config": GPT2Config(
            # Must equal the teacher's logit dimension: the KL term compares the two distributions.
            vocab_size=tokenizer.vocab_size,
            n_positions=512,  # matches the 512-token max_length used when tokenizing training data
            n_embd=1024,
            n_layer=12,
            n_head=16,
        )
    },
]

# Report parameter counts. Each model is instantiated (on the CPU) only to be counted, so it
# is released right away instead of keeping a full model alive in a global variable.
for cfg in student_configs:
    model = GPT2LMHeadModel(cfg["config"])
    params = sum(p.numel() for p in model.parameters())
    del model
    print(f"{cfg['name']} params: {params / 1e6:.1f}M")

### **Custom collate function for dataloader**

In [None]:
def collate_fn(batch):
    """
    Turn a list of tokenized training examples into one batch of stacked tensors.

    Args:
        batch (list[dict]): Examples holding 'input_ids', 'labels' and 'attention_mask' sequences.

    Returns:
        dict: The same three keys, each a torch.long tensor of shape (len(batch), seq_len).

    Details:
        - torch.stack requires every example's sequences to have the same length.
        - Token ids and masks are stored as long integers, as embedding lookups expect.
    """
    fields = ("input_ids", "labels", "attention_mask")
    return {
        field: torch.stack([torch.tensor(example[field], dtype=torch.long) for example in batch])
        for field in fields
    }

### **Training loop with distillation**
- *DataLoaders*: Train shuffled, eval not.
- *Hyperparams*: temperature softens both models' logits; alpha weights the KL-divergence (distillation) term and 1 − alpha weights the cross-entropy term.
- *Loop over hyperparams and students*: Train with KL div + CE, log to W&B, save model.

In [None]:
batch_size = 4
train_loader = DataLoader(tokenized_train, batch_size=batch_size, shuffle=True, collate_fn=collate_fn)
eval_loader = DataLoader(tokenized_eval, batch_size=batch_size, shuffle=False, collate_fn=collate_fn)

hyperparams = [
    {"temperature": 2.0, "alpha": 0.7},
]

epochs = 7
accumulation_steps = 4
learning_rate = 5e-5


def _distillation_loss(student_logits, teacher_logits, attention_mask, temperature):
    """
    Temperature-scaled KL(teacher || student), averaged over non-padding token positions.

    Args:
        student_logits (torch.Tensor): Student logits, presumably (batch, seq, vocab).
        teacher_logits (torch.Tensor): Teacher logits for the same batch; the vocabulary
            axis must match the student's.
        attention_mask (torch.Tensor): (batch, seq) mask, 1 for real tokens, 0 for padding.
        temperature (float): Softening temperature applied to both sets of logits.

    Returns:
        torch.Tensor: Scalar loss.

    Details:
        - `F.kl_div(..., reduction="batchmean")` on 3-D logits divides the sum over every
          position (padding included) by the batch size only, so it grows with sequence
          length and dwarfs the per-token cross-entropy it is mixed with. Averaging over
          real tokens keeps both terms on the same scale.
        - Multiplying by temperature**2 keeps gradient magnitudes comparable across
          temperatures (standard Hinton-style distillation).
    """
    soft_teacher = F.softmax(teacher_logits / temperature, dim=-1)
    log_soft_student = F.log_softmax(student_logits / temperature, dim=-1)
    # Sum the element-wise KL terms over the vocabulary -> one value per (batch, position).
    per_token_kl = F.kl_div(log_soft_student, soft_teacher, reduction="none").sum(dim=-1)
    mask = attention_mask.to(per_token_kl.dtype)
    return (per_token_kl * mask).sum() / mask.sum().clamp(min=1.0) * (temperature ** 2)


def _weights_size_mb(directory):
    """
    Total size in MB of the model weight files in `directory`, or None if there are none.

    `save_pretrained` writes `model.safetensors` (sharded for large models) by default in
    recent transformers releases and `pytorch_model.bin` in older ones, so every
    `.safetensors` / `.bin` file is counted.
    """
    weight_files = [path for path in Path(directory).iterdir() if path.suffix in {".safetensors", ".bin"}]
    if not weight_files:
        return None
    return sum(path.stat().st_size for path in weight_files) / (1024 * 1024)


for hyperparam in hyperparams:
    temperature = hyperparam["temperature"]
    alpha = hyperparam["alpha"]

    for student_cfg in student_configs:
        run_name = f"distilled_{student_cfg['name']}_t{float(temperature)}_a{float(alpha)}"
        wandb.init(
            project="code-distillation",
            name=run_name,
            config={
                "temperature": temperature,
                "alpha": alpha,
                "batch_size": batch_size,
                "accumulation_steps": accumulation_steps,
                "epochs": epochs,
                "learning_rate": learning_rate,
            },
            allow_val_change=True,
        )

        student_model = GPT2LMHeadModel(student_cfg["config"]).to("cuda")
        if student_cfg["name"] == "large-model":
            student_model.gradient_checkpointing_enable()  # Memory saving for large models

        # Also log hyperparameters and model size as metrics (as before). The parameter count
        # comes from the student itself instead of a second throwaway model.
        params = sum(p.numel() for p in student_model.parameters()) / 1e6
        wandb.log({"temperature": temperature, "alpha": alpha, "batch_size": batch_size})
        wandb.log({"model_params_M": params})

        optimizer = AdamW(student_model.parameters(), lr=learning_rate)
        scaler = torch.amp.GradScaler()
        num_batches = len(train_loader)

        for epoch in range(epochs):
            epoch_start_time = time.time()
            student_model.train()
            optimizer.zero_grad()
            total_loss = total_distillation_loss = total_ce_loss = 0.0

            for step, batch in enumerate(train_loader):
                inputs = batch["input_ids"].to("cuda")
                labels = batch["labels"].to("cuda")
                attention_mask = batch["attention_mask"].to("cuda")

                # Shape and value assertions for debugging
                assert inputs.shape[0] <= batch_size and inputs.shape[1] == 512, f"Input shape mismatch: {inputs.shape}"
                assert labels.shape[0] <= batch_size and labels.shape[1] == 512, f"Labels shape mismatch: {labels.shape}"
                assert attention_mask.shape[0] <= batch_size and attention_mask.shape[1] == 512, f"Attention mask mismatch: {attention_mask.shape}"
                assert torch.all(0 <= inputs) and torch.all(inputs < tokenizer.vocab_size), "Invalid input_ids"
                assert torch.all(0 <= labels) and torch.all(labels < tokenizer.vocab_size), "Invalid labels"

                with torch.amp.autocast(device_type="cuda"):
                    student_outputs = student_model(inputs, attention_mask=attention_mask, labels=labels)
                    with torch.no_grad():
                        # The teacher only provides soft targets; omitting `labels` skips an
                        # unused cross-entropy computation over its full vocabulary.
                        teacher_logits = teacher_model(inputs, attention_mask=attention_mask).logits
                    distillation_loss = _distillation_loss(student_outputs.logits, teacher_logits, attention_mask, temperature)
                    ce_loss = student_outputs.loss
                    loss = alpha * distillation_loss + (1 - alpha) * ce_loss

                # Average gradients over the micro-batches of the current accumulation window.
                # The last window of an epoch is shorter when num_batches is not a multiple
                # of accumulation_steps.
                window_start = (step // accumulation_steps) * accumulation_steps
                window_size = min(accumulation_steps, num_batches - window_start)
                scaler.scale(loss / window_size).backward()

                # Step at the end of every window, including the short final one, so leftover
                # gradients are applied instead of leaking into the next epoch's first step.
                if (step + 1) % accumulation_steps == 0 or step + 1 == num_batches:
                    scaler.unscale_(optimizer)
                    torch.nn.utils.clip_grad_norm_(student_model.parameters(), max_norm=1.0)  # Prevent exploding gradients
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad()

                total_loss += loss.item()
                total_distillation_loss += distillation_loss.item()
                total_ce_loss += ce_loss.item()

            avg_loss = total_loss / num_batches
            avg_distillation_loss = total_distillation_loss / num_batches
            avg_ce_loss = total_ce_loss / num_batches
            epoch_time = time.time() - epoch_start_time

            wandb.log({
                "epoch": epoch,
                "train_loss": avg_loss,
                "train_distillation_loss": avg_distillation_loss,
                "train_ce_loss": avg_ce_loss,
                "epoch_time_s": epoch_time
            })
            print(f"{run_name} Epoch {epoch}: Train Loss {avg_loss:.4f} | Distillation Loss {avg_distillation_loss:.4f} | CE Loss {avg_ce_loss:.4f} | Time {epoch_time:.2f}s")

            # Validation: Compute losses on eval set
            student_model.eval()
            val_loss = val_distillation_loss = val_ce_loss = 0.0
            num_val_batches = len(eval_loader)

            with torch.no_grad():
                for batch in eval_loader:
                    inputs = batch["input_ids"].to("cuda")
                    labels = batch["labels"].to("cuda")
                    attention_mask = batch["attention_mask"].to("cuda")

                    with torch.amp.autocast(device_type="cuda"):
                        student_outputs = student_model(inputs, attention_mask=attention_mask, labels=labels)
                        teacher_logits = teacher_model(inputs, attention_mask=attention_mask).logits
                        v_distillation_loss = _distillation_loss(student_outputs.logits, teacher_logits, attention_mask, temperature)
                        v_ce_loss = student_outputs.loss
                        v_loss = alpha * v_distillation_loss + (1 - alpha) * v_ce_loss

                    val_loss += v_loss.item()
                    val_distillation_loss += v_distillation_loss.item()
                    val_ce_loss += v_ce_loss.item()

            avg_val_loss = val_loss / num_val_batches
            avg_val_distillation_loss = val_distillation_loss / num_val_batches
            avg_val_ce_loss = val_ce_loss / num_val_batches

            wandb.log({
                "val_loss": avg_val_loss,
                "val_distillation_loss": avg_val_distillation_loss,
                "val_ce_loss": avg_val_ce_loss
            })
            print(f"{run_name} Epoch {epoch}: Val Loss {avg_val_loss:.4f} | Val Distillation Loss {avg_val_distillation_loss:.4f} | Val CE Loss {avg_val_ce_loss:.4f}")

        # Save model/tokenizer and log the on-disk weight size
        save_dir = f"/kaggle/working/{run_name}"
        os.makedirs(save_dir, exist_ok=True)
        student_model.save_pretrained(save_dir)
        tokenizer.save_pretrained(save_dir)

        file_size_mb = _weights_size_mb(save_dir)
        if file_size_mb is not None:
            wandb.log({f"{student_cfg['name']}_file_size_mb": file_size_mb})
            print(f"{run_name} File Size: {file_size_mb:.2f} MB")

        wandb.finish()

        # Release GPU memory before the next student (and before evaluation reloads from disk).
        del student_model, optimizer, scaler
        torch.cuda.empty_cache()

### **Evaluation setup**
- Init W&B for eval logging.
- *Load metrics*: code_eval for Pass@k (code correctness), BLEU for similarity.
- Extract problems/references from eval_dataset.

In [None]:
wandb.init(project="code-distillation", name="evaluation", allow_val_change=True)
humaneval_metric = evaluate.load("code_eval")
bleu_metric = evaluate.load("bleu")

debug_mode = True  # Set to True for quick testing (evaluates only the first 20 problems)
num_problems = min(20, len(eval_dataset)) if debug_mode else len(eval_dataset)

eval_subset = eval_dataset.select(range(num_problems))
problems = [ex["prompt"] for ex in eval_subset]
# HumanEval's `test` field only *defines* `check(candidate)`. code_eval executes
# candidate + "\n" + reference, so the check must be invoked explicitly on the entry point.
# Otherwise no assertion ever runs, and any program that merely executes cleanly "passes".
references = [f"{ex['test']}\ncheck({ex['entry_point']})\n" for ex in eval_subset]
solution_references = [ex["canonical_solution"] for ex in eval_subset]

eval_prompts_dataset = Dataset.from_dict({"prompt": problems})

### **Tokenize evaluation prompts and collate**

In [None]:
def tokenize_prompts(examples, tokenizer, max_length=256):
    """
    Tokenize evaluation prompts for generation.

    Args:
        examples (dict): Batch with a 'prompt' key holding a list of prompt strings.
        tokenizer (AutoTokenizer): For encoding; its padding side decides where padding goes.
        max_length (int): Fixed length every prompt is padded/truncated to. Defaults to 256.

    Returns:
        dict: Tokenized input_ids and attention_mask as PyTorch tensors, padded to max_length.

    Details:
        - Prompts longer than max_length are truncated on the tokenizer's truncation side
          (the right end by default), which can cut off the end of a long HumanEval
          docstring. Raise max_length if that matters, keeping max_length + max_new_tokens
          within the context window of the model being evaluated.
    """
    return tokenizer(
        examples["prompt"],
        max_length=max_length,
        truncation=True,
        padding="max_length",
        return_tensors="pt",
    )

# Tokenize the raw HumanEval prompts (no instruction template) for generation,
# keeping only the model inputs.
tokenized_eval_prompts = eval_prompts_dataset.map(
    tokenize_prompts,
    fn_kwargs={"tokenizer": tokenizer},
    batched=True,
    remove_columns=["prompt"],
)

In [None]:
# Display the tokenized prompt dataset (features and row count) as the cell output.
tokenized_eval_prompts

In [None]:
def eval_collate_fn(batch):
    """
    Stack tokenized prompts into batched tensors; generation needs no labels.

    Args:
        batch (list[dict]): Examples holding 'input_ids' and 'attention_mask' sequences
            of equal length.

    Returns:
        dict: 'input_ids' and 'attention_mask' as torch.long tensors of shape (len(batch), seq_len).
    """
    def _stack(field):
        # One long tensor per example, stacked along a new leading batch axis.
        return torch.stack([torch.tensor(example[field], dtype=torch.long) for example in batch])

    return {"input_ids": _stack("input_ids"), "attention_mask": _stack("attention_mask")}

# Generation loader over the tokenized HumanEval prompts. The order must stay fixed
# (no shuffling) so generations line up with `references` and `solution_references`.
# NOTE: this rebinds `eval_loader`, replacing the validation loader used during training.
eval_batch_size = 16
eval_loader = DataLoader(tokenized_eval_prompts, batch_size=eval_batch_size, collate_fn=eval_collate_fn, shuffle=False)

### **Generation function**

In [None]:
# Metrics table under construction: one list per column, one entry per evaluated model.
results = {
    column: []
    for column in (
        "Model", "Params (M)", "Pass@1", "Pass@10", "BLEU",
        "Avg Latency (s)", "File Size (MB)",
    )
}

# Pin the pad id to EOS so generate() does not warn about a missing pad_token_id.
# NOTE(review): if the tokenizer already had a distinct pad token, this overrides it — confirm.
tokenizer.pad_token_id = tokenizer.eos_token_id

# Batched generation helper shared by the teacher and student evaluations.
def generate_completions(model, tokenizer, data_loader, num_samples=1, max_new_tokens=512, do_sample=False, top_p=None, device="cuda"):
    """
    Generate code completions in batches.

    Args:
        model (nn.Module): Model exposing a Hugging Face-style `generate` method.
        tokenizer (AutoTokenizer): Supplies pad/eos ids and decodes outputs.
        data_loader (DataLoader): Yields dicts with 'input_ids' and 'attention_mask' tensors.
            It must iterate in a fixed order so the samples from each pass line up per prompt.
        num_samples (int): Passes over the loader (1 for Pass@1, 10 for Pass@10).
        max_new_tokens (int): Max tokens to generate per prompt.
        do_sample (bool): Sample instead of greedy decoding (for diversity in Pass@10).
        top_p (float | None): Nucleus-sampling threshold; only used when do_sample is True.
        device (str | torch.device): Device batches are moved to. Defaults to "cuda".

    Returns:
        list[str] when num_samples == 1 (one string per prompt), otherwise list[list[str]]
        holding num_samples strings per prompt, in data_loader order.

    Details:
        - Each string is the decoded full `generate()` output with special tokens (including
          padding) removed; for decoder-only models that output begins with the prompt.
        - Greedy decoding when do_sample=False, plain sampling otherwise (never beam search).
        - Generations are kept in local lists rather than a global to avoid memory issues.
    """
    model.eval()
    # One inner list per sampling pass, each holding one decoded string per prompt.
    passes = []
    with torch.no_grad():
        for _ in range(num_samples):
            pass_generations = []
            for batch in data_loader:
                inputs = batch["input_ids"].to(device)
                attention_mask = batch["attention_mask"].to(device)
                outputs = model.generate(
                    inputs,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    do_sample=do_sample,
                    top_p=top_p if do_sample else None,
                    pad_token_id=tokenizer.pad_token_id,  # Use explicit pad_token_id
                    eos_token_id=tokenizer.eos_token_id,
                    num_beams=1,  # no beam search: greedy or plain sampling per do_sample
                )
                pass_generations.extend(tokenizer.decode(output, skip_special_tokens=True) for output in outputs)
            passes.append(pass_generations)

    if num_samples == 1:
        return passes[0]
    # Transpose pass-major -> prompt-major: entry j holds every sample for prompt j.
    # zip(*[]) yields nothing, so num_samples == 0 returns [] instead of raising.
    return [list(samples) for samples in zip(*passes)]

### **Evaluate teacher**

In [None]:
# Evaluate teacher model
teacher_model.eval()
teacher_generations = generate_completions(teacher_model, tokenizer, eval_loader, num_samples=1, max_new_tokens=512)
teacher_generations_10 = generate_completions(teacher_model, tokenizer, eval_loader, num_samples=10, max_new_tokens=512, do_sample=True, top_p=0.9)

# Functional correctness (Pass@k by executing the tests) and textual similarity (BLEU).
# NOTE(review): generations include the prompt text while solution_references hold only the
# function body, so BLEU is computed against mismatched spans — consider stripping the prompt.
teacher_pass1, _ = humaneval_metric.compute(predictions=[[g] for g in teacher_generations], references=references, k=[1])
teacher_pass10, _ = humaneval_metric.compute(predictions=teacher_generations_10, references=references, k=[10])
teacher_bleu = bleu_metric.compute(predictions=teacher_generations, references=solution_references)["bleu"]
teacher_params = sum(p.numel() for p in teacher_model.parameters()) / 1e6

# Latency: mean wall-clock time of one generation for a fixed prompt over 20 timed runs.
# The prompt is tokenized once outside the timed region. One untimed warm-up run absorbs
# one-off CUDA setup, and the device is synchronized so queued kernels are included.
latency_iters = 20
test_input = tokenizer("def add(a, b):", return_tensors="pt", max_length=512, truncation=True, padding="max_length").to("cuda")


def _teacher_generate_once():
    """Run one latency-benchmark generation with the teacher (output discarded)."""
    with torch.no_grad():
        teacher_model.generate(
            test_input["input_ids"],
            attention_mask=test_input["attention_mask"],
            max_new_tokens=512,
            pad_token_id=tokenizer.pad_token_id,  # Explicitly set to avoid warning
            eos_token_id=tokenizer.eos_token_id,
            num_beams=1
        )


_teacher_generate_once()  # warm-up, not timed
torch.cuda.synchronize()
start = time.perf_counter()
for _ in range(latency_iters):
    _teacher_generate_once()
torch.cuda.synchronize()
teacher_latency = (time.perf_counter() - start) / latency_iters

# File size: save_pretrained writes (possibly sharded) *.safetensors by default, or *.bin in
# older transformers versions, so every weight file is summed. 0 if none are found.
teacher_save_dir = "/kaggle/working/teacher_model"
teacher_model.save_pretrained(teacher_save_dir)
teacher_weight_files = [p for p in Path(teacher_save_dir).iterdir() if p.suffix in {".safetensors", ".bin"}]
teacher_file_size_mb = sum(p.stat().st_size for p in teacher_weight_files) / (1024 * 1024)

results["Model"].append("Teacher")
results["Params (M)"].append(teacher_params)
results["Pass@1"].append(teacher_pass1["pass@1"])
results["Pass@10"].append(teacher_pass10["pass@10"])
results["BLEU"].append(teacher_bleu)
results["Avg Latency (s)"].append(teacher_latency)
results["File Size (MB)"].append(teacher_file_size_mb)

wandb.log({
    "teacher_pass@1": teacher_pass1["pass@1"],
    "teacher_pass@10": teacher_pass10["pass@10"],
    "teacher_bleu": teacher_bleu,
    "teacher_params_M": teacher_params,
    "teacher_latency_s": teacher_latency,
    "teacher_file_size_mb": teacher_file_size_mb
})

### **Evaluate student**

In [None]:
# Evaluate student model
for hyperparam in hyperparams:
    temperature = hyperparam["temperature"]
    alpha = hyperparam["alpha"]
    for student_cfg in student_configs:
        model_tag = f"{student_cfg['name']}_t{float(temperature)}_a{float(alpha)}"
        # Same absolute directory the training loop saved to; a bare relative name would only
        # resolve when the notebook's working directory happens to be /kaggle/working.
        save_dir = f"/kaggle/working/distilled_{model_tag}"
        student_model = None
        try:
            student_model = GPT2LMHeadModel.from_pretrained(save_dir).to("cuda")
            student_model.eval()

            # Reduced budget: with the 256-token eval prompts, 256 new tokens fill the
            # student's 512 positions exactly.
            student_generations = generate_completions(student_model, tokenizer, eval_loader, num_samples=1, max_new_tokens=256)
            student_generations_10 = generate_completions(student_model, tokenizer, eval_loader, num_samples=10, max_new_tokens=256, do_sample=True, top_p=0.9)

            student_pass1, _ = humaneval_metric.compute(predictions=[[g] for g in student_generations], references=references, k=[1])
            student_pass10, _ = humaneval_metric.compute(predictions=student_generations_10, references=references, k=[10])
            student_bleu = bleu_metric.compute(predictions=student_generations, references=solution_references)["bleu"]
            student_params = sum(p.numel() for p in student_model.parameters()) / 1e6

            # Latency: 20 timed runs after one untimed warm-up; the prompt is tokenized once
            # and the device is synchronized so queued kernels are included in the timing.
            latency_iters = 20
            test_input = tokenizer("def add(a, b):", return_tensors="pt", max_length=256, truncation=True, padding="max_length").to("cuda")  # Reduced

            def _student_generate_once():
                """Run one latency-benchmark generation with the current student (output discarded)."""
                with torch.no_grad():
                    student_model.generate(
                        test_input["input_ids"],
                        attention_mask=test_input["attention_mask"],
                        max_new_tokens=256,
                        pad_token_id=tokenizer.pad_token_id,
                        eos_token_id=tokenizer.eos_token_id,
                        num_beams=1
                    )

            _student_generate_once()  # warm-up, not timed
            torch.cuda.synchronize()
            start = time.perf_counter()
            for _ in range(latency_iters):
                _student_generate_once()
            torch.cuda.synchronize()
            student_latency = (time.perf_counter() - start) / latency_iters

            # save_pretrained writes *.safetensors by default (or *.bin in older versions).
            student_weight_files = [p for p in Path(save_dir).iterdir() if p.suffix in {".safetensors", ".bin"}]
            student_file_size_mb = sum(p.stat().st_size for p in student_weight_files) / (1024 * 1024)

            results["Model"].append(f"{student_cfg['name'].capitalize()}_t{float(temperature)}_a{float(alpha)}")
            results["Params (M)"].append(student_params)
            results["Pass@1"].append(student_pass1["pass@1"])
            results["Pass@10"].append(student_pass10["pass@10"])
            results["BLEU"].append(student_bleu)
            results["Avg Latency (s)"].append(student_latency)
            results["File Size (MB)"].append(student_file_size_mb)

            wandb.log({
                f"{model_tag}_pass@1": student_pass1["pass@1"],
                f"{model_tag}_pass@10": student_pass10["pass@10"],
                f"{model_tag}_bleu": student_bleu,
                f"{student_cfg['name']}_params_M": student_params,
                f"{student_cfg['name']}_latency_s": student_latency,
                f"{student_cfg['name']}_file_size_mb": student_file_size_mb
            })
        except Exception as e:
            # Best-effort: one failing student should not abort evaluation of the others.
            print(f"Error evaluating {model_tag}: {e}")
        finally:
            # Release GPU memory whether or not this student's evaluation succeeded.
            del student_model
            torch.cuda.empty_cache()

In [None]:
# Assemble per-model metrics into one table, show it, and persist it to W&B and CSV.
df = pd.DataFrame(results)
print(df.to_markdown(index=False))  # to_markdown needs the optional `tabulate` package
wandb.log({"evaluation_results": wandb.Table(dataframe=df)})
# index=False drops the meaningless RangeIndex column, matching the markdown table above.
df.to_csv("/kaggle/working/evaluation_results.csv", index=False)
wandb.finish()