### Step 1: Install required packages

In [1]:
!pip install transformers datasets accelerate bitsandbytes trl peft torch tqdm -q

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m59.1/59.1 MB[0m [31m44.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m518.9/518.9 kB[0m [31m47.2 MB/s[0m eta [36m0:00:00[0m
[?25h

In [2]:
# Restart the runtime so the freshly installed packages are picked up
# (files already on disk are kept).
# NOTE(review): exit() ends the current kernel session, so a plain "Run All"
# is expected to stop here -- continue from the next cell after the restart.
exit()

### Step 2: Import libraries

In [1]:
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
from datasets import load_dataset
from trl import GKDTrainer, GKDConfig

In [2]:
# Configuration
TEACHER_ID = "Qwen/Qwen2.5-1.5B-Instruct"  # checkpoint loaded as the teacher model
STUDENT_ID = "Qwen/Qwen2.5-0.5B"  # checkpoint loaded as the student model (and its tokenizer)
DATASET_ID = "yahma/alpaca-cleaned"  # dataset id passed to load_dataset
OUTPUT_DIR = "./distilled_qwen_0.5b_instruct"  # output_dir for checkpoints and the final saved model

### Step 3: Get an estimation of memory usage for loading the model(s)

In [14]:
# Estimate the memory needed to load the teacher checkpoint ({TEACHER_ID} is interpolated by IPython)
!accelerate estimate-memory {TEACHER_ID} --library_name transformers

Loading pretrained config for `Qwen/Qwen2.5-1.5B-Instruct` from `transformers`...
2026-01-09 14:47:04.328273: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767970024.350241   70874 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767970024.356772   70874 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767970024.373528   70874 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767970024.373563   70874 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000

In [15]:
# Estimate the memory needed to load the student checkpoint ({STUDENT_ID} is interpolated by IPython)
!accelerate estimate-memory {STUDENT_ID} --library_name transformers

Loading pretrained config for `Qwen/Qwen2.5-0.5B` from `transformers`...
2026-01-09 14:47:22.985129: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1767970043.006203   70993 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1767970043.012561   70993 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1767970043.028777   70993 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1767970043.028806   70993 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:17

### Step 4: Auto-detect dtype based on GPU capability

In [3]:
import torch
# Fail fast: everything below requires a CUDA device.
if not torch.cuda.is_available():
    raise RuntimeError("No GPU available!")

device = "cuda"
gpu_name = torch.cuda.get_device_name(0)
gpu_capability = torch.cuda.get_device_capability()[0]  # major version only

print(
    f"GPU: {gpu_name}",
    f"Using device: {device}",
    f"Compute Capability: {gpu_capability}.x",
    sep="\n",
)

# Compute capability >= 8 means Ampere or newer (RTX 30xx, A100, ...), which
# handles bf16 efficiently; older cards (T4, V100, RTX 20xx) fall back to fp16.
use_bf16 = gpu_capability >= 8
use_fp16 = not use_bf16

# Everything else follows from the precision flag.
torch_dtype = torch.bfloat16 if use_bf16 else torch.float16
attn_implementation = "flash_attention_2" if use_bf16 else "eager"
print(
    "Using bfloat16 (Ampere+ GPU detected)"
    if use_bf16
    else "Using float16 (Pre-Ampere GPU detected)"
)

GPU: NVIDIA A100-SXM4-80GB
Using device: cuda
Compute Capability: 8.x
Using bfloat16 (Ampere+ GPU detected)


### Step 5: Loading model(s) & tokenizer

In [4]:
# Show which attention implementation the auto-detection step selected.
# If it is flash_attention_2, install the package with the next cell;
# otherwise that installation can be skipped.
print(f"Attention: {attn_implementation}")

Attention: flash_attention_2


In [6]:
# install Flash Attention 2
!pip install ninja packaging wheel
!pip install flash-attn --no-build-isolation

Collecting ninja
  Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (5.1 kB)
Downloading ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (180 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/180.7 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m180.7/180.7 kB[0m [31m16.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.13.0
Collecting flash-attn
  Downloading flash_attn-2.8.3.tar.gz (8.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m8.4/8.4 MB[0m [31m145.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: flash-attn
  Building wheel for flash-attn (setup.py) ... [?25l[?25hdone
  Created wheel for flash-attn: filename=flash_attn-2.8.3-cp312-cp312-linux_x86_64.whl size=253780426 sha256=4e2f9e3

In [7]:
# Teacher model: loaded with the auto-detected dtype and attention backend,
# then switched to eval mode.
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint
# to run -- confirm it is really required for this model.
print(f"Loading Teacher: {TEACHER_ID}")
teacher_model = AutoModelForCausalLM.from_pretrained(
    TEACHER_ID,
    dtype=torch_dtype,
    attn_implementation=attn_implementation,
    device_map="auto",
    trust_remote_code=True,
)
teacher_model.eval()

# Report the size in millions of parameters.
print(
    "Teacher loaded: "
    f"{sum(p.numel() for p in teacher_model.parameters()) / 1e6:.1f}M parameters"
)

Loading Teacher: Qwen/Qwen2.5-1.5B-Instruct


generation_config.json:   0%|          | 0.00/242 [00:00<?, ?B/s]

Teacher loaded: 1543.7M parameters


In [8]:
# Student model: loaded with the auto-detected dtype. No quantization config
# is passed, so prepare_model_for_kbit_training is not needed here.
# NOTE(review): unlike the teacher, no attn_implementation is passed for the
# student -- confirm that relying on the library default is intended.
print(f"Loading Student: {STUDENT_ID}")
student_model = AutoModelForCausalLM.from_pretrained(
    STUDENT_ID,
    device_map="auto",
    dtype=torch_dtype,
    trust_remote_code=True,
)

# Report the size in millions of parameters.
print(
    "Student loaded: "
    f"{sum(p.numel() for p in student_model.parameters()) / 1e6:.1f}M parameters"
)

Loading Student: Qwen/Qwen2.5-0.5B


config.json:   0%|          | 0.00/681 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/988M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/138 [00:00<?, ?B/s]

Student loaded: 494.0M parameters


In [9]:
# Tokenizer: loaded from the student checkpoint and handed to the trainer later on.
tokenizer = AutoTokenizer.from_pretrained(STUDENT_ID)
# Use the EOS token for padding.
# NOTE(review): this overwrites whatever pad token the tokenizer already defines -- confirm that is intended.
tokenizer.pad_token = tokenizer.eos_token

tokenizer_config.json: 0.00B [00:00, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

tokenizer.json: 0.00B [00:00, ?B/s]

### Step 6: Dataset preparation

In [10]:
# Load the dataset's "train" split and hold out 10% of it as a test split
# (seed fixed at 42 so the split is the same on every run).
raw_dataset = load_dataset(DATASET_ID, split="train").train_test_split(test_size=0.1, seed=42)

def format_alpaca_for_gkd(row):
    """Convert one Alpaca-style record into a chat conversation.

    Args:
        row: mapping with an 'instruction' and an 'output' entry, plus an
            optional 'input' entry.

    Returns:
        A dict with a single "messages" key holding three turns: a fixed
        system prompt, the user turn and the assistant turn (the record's
        'output'). The user turn is the instruction alone, or the instruction
        followed by "\\n\\nInput: <input>" when 'input' is present and not
        blank.
    """
    extra = row.get('input')
    has_extra = bool(extra and extra.strip())

    instruction = row['instruction']
    # The input is appended verbatim (not stripped) after an "Input: " label.
    user_content = f"{instruction}\n\nInput: {extra}" if has_extra else instruction

    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": user_content}
    assistant_turn = {"role": "assistant", "content": row['output']}

    # The full conversation is returned; the trainer is expected to split off
    # the last (assistant) message to build the prompt.
    return {"messages": [system_turn, user_turn, assistant_turn]}

# Apply the chat formatting to every row of the train and test splits.
dataset = raw_dataset.map(format_alpaca_for_gkd)

README.md: 0.00B [00:00, ?B/s]

alpaca_data_cleaned.json:   0%|          | 0.00/44.3M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/51760 [00:00<?, ? examples/s]

Map:   0%|          | 0/46584 [00:00<?, ? examples/s]

Map:   0%|          | 0/5176 [00:00<?, ? examples/s]

### Step 7: Training configuration

In [11]:
# Training arguments for the GKD run
gkd_config = GKDConfig(
    # Output, checkpointing and logging: a checkpoint every 500 steps
    # (at most 2 kept), a log line every 10 steps, no external reporting.
    output_dir=OUTPUT_DIR,
    save_strategy="steps",
    save_steps=500,
    save_total_limit=2,
    logging_steps=10,
    report_to="none",

    # Optimisation: one epoch, 8 samples per device x 2 accumulation steps.
    num_train_epochs=1,
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    gradient_checkpointing=True,

    # Mixed precision flags come from the GPU auto-detection step.
    bf16=use_bf16,
    fp16=use_fp16,

    # GKD-specific settings.
    # NOTE(review): per the TRL GKD docs, lmbda is the share of on-policy
    # (student-generated) batches and beta interpolates the generalized JSD
    # loss -- confirm against the installed TRL version.
    lmbda=1.0,
    beta=0.5,
    temperature=0.9,
    max_new_tokens=64,
)



### Step 8: Training execution

In [12]:
print("Starting on-policy distillation training...")

# Wire student, teacher, data and tokenizer into the GKD trainer.
trainer = GKDTrainer(
    args=gkd_config,
    model=student_model,
    teacher_model=teacher_model,
    processing_class=tokenizer,
    train_dataset=dataset["train"],
    eval_dataset=dataset["test"],
    # peft_config=peft_config  (disabled)
)
trainer.train()

# Write the final model to the output directory.
trainer.save_model(OUTPUT_DIR)
print(f"Training completed. Model saved to {OUTPUT_DIR}")

  trainer = GKDTrainer(
The model is already on multiple devices. Skipping the move to device specified in `args`.
The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'bos_token_id': None, 'pad_token_id': 151643}.


Starting on-policy distillation training...


`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
10,0.1802
20,0.1649
30,0.1634
40,0.1609
50,0.1603
60,0.166
70,0.1676
80,0.1551
90,0.1691
100,0.1586


Training completed. Model saved to ./distilled_qwen_0.5b_instruct


### Step 9: Inference

In [13]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

BASE_MODEL_ID = "Qwen/Qwen2.5-0.5B"  # checkpoint whose tokenizer is used for inference
OUTPUT_DIR = "./distilled_qwen_0.5b_instruct"  # directory the distilled model is loaded from

# Test prompts in a plain-text "Instruction / Input / Answer" layout; only the
# un-commented entry is run.
# NOTE(review): the training rows were formatted as chat `messages`, whereas
# these prompts are plain text -- confirm the different prompt format is intended.
prompts = [
    #"Instruction: What is the capital of France?\n\nInput: \nAnswer:",
    #"Instruction: Write a short poem about a robot learning to love.\n\nInput: \nAnswer:",
    #"Instruction: Solve this math problem: If I have 3 apples and eat 1, how many do I have?\n\nInput: \nAnswer:",
    #"Instruction: Explain why the sky is blue in one sentence.\n\nInput: \nAnswer:"
    "Instruction: How do I make a cup of tea?\n\nInput: \nAnswer:"

]

# The base checkpoint's tokenizer is reused for the distilled model;
# generate_response reads this module-level `tokenizer` for encoding and decoding.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID)

# Helper function to run inference
def generate_response(model, prompt):
    """Sample a completion for ``prompt`` and return only the newly generated text.

    Uses the module-level ``tokenizer`` to encode the prompt and decode the
    result.

    Args:
        model: causal language model exposing ``device`` and ``generate``.
        prompt: prompt text to continue.

    Returns:
        The decoded continuation with special tokens skipped and surrounding
        whitespace stripped. The prompt is never part of the result, but any
        text in the continuation that merely repeats the prompt is kept.
    """
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Number of prompt tokens; used below to cut the prompt off the output.
    prompt_len = inputs["input_ids"].shape[1]

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,
            do_sample=True,
            temperature=0.7,
            pad_token_id=tokenizer.eos_token_id
        )

    # generate() returns the prompt ids followed by the new ids, so slicing at
    # the prompt length isolates the completion. This replaces the previous
    # str.replace(prompt, "") approach, which also deleted any echo of the
    # prompt inside the answer and relied on decode() reproducing the prompt
    # text exactly.
    completion_ids = outputs[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()


print("\nLoading the distilled model: ")
# Load the distilled model from the output directory;
# torch_dtype comes from the GPU auto-detection step earlier in the notebook.
distilled_model = AutoModelForCausalLM.from_pretrained(
    OUTPUT_DIR,
    device_map="auto",
    dtype=torch_dtype,
)

print("\nGenerating distilled model responses ")
# One sampled completion per prompt, collected in prompt order.
distilled_results = []
for p in prompts:
    print(f"Generating for: {p[:30]}...")
    distilled_results.append(generate_response(distilled_model, p))

# Show each prompt (only the part before "Input") followed by its completion.
for prompt, answer in zip(prompts, distilled_results):
    print(f"\nPROMPT: {prompt.split('Input')[0].strip()}")

    print(f"{answer}")


Loading the distilled model: 

Generating distilled model responses 
Generating for: Instruction: How do I make a c...

PROMPT: Instruction: How do I make a cup of tea?
To make a cup of tea, you will need:

1. Tea bags (or loose tea leaves)
2. A teapot
3. Water (if you have a kettle or pot)
4. A tea infuser or a tea bag

Instructions:

1. Fill a teapot or kettle with water and add a few drops of tea bags or loose tea leaves.
2. Place the teapot or kettle on the stove or burners and heat it until the water boils.
3. Once
