<a href="https://colab.research.google.com/github/3odat/LLM-Inf/blob/main/Online_Distillation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**✅ Required Libraries for Online Knowledge Distillation**

In [None]:
pip install -q -U transformers datasets accelerate bitsandbytes peft flash-attn --no-build-isolation

**Step 1: Log in to Hugging Face to access the gated Llama models**

In [None]:
from huggingface_hub import notebook_login

# Show the interactive Hugging Face login widget in the cell output.
notebook_login()


VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

**Step 2: Import Essentials**

In [None]:
# Tensor library and functional ops (softmax / KL divergence for the distillation loss).
import torch
import torch.nn.functional as F
# Hugging Face model, tokenizer and training utilities.
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    Trainer,
    TrainingArguments,
    default_data_collator,
    BitsAndBytesConfig
)
# Loader used for the JSON training file.
from datasets import load_dataset
print("✔️ Libraries imported")

✔️ Libraries imported


**Step 3: Load Models Safely**

In [None]:
def _load_model_and_tokenizer(checkpoint):
    """Load a causal LM in bfloat16 together with its tokenizer.

    The tokenizer's pad token is set to its EOS token so that padded batches
    can be built.

    Args:
        checkpoint: Hugging Face Hub id (or local path) of the model.

    Returns:
        A `(model, tokenizer)` tuple.
    """
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        #attn_implementation="flash_attention_2"
    )
    tokenizer = AutoTokenizer.from_pretrained(model.name_or_path)
    tokenizer.pad_token = tokenizer.eos_token
    return model, tokenizer


# Teacher Model (8B)
teacher, teacher_tokenizer = _load_model_and_tokenizer("meta-llama/Llama-3.1-8B-Instruct")

# Student Model (1B)
student, student_tokenizer = _load_model_and_tokenizer("meta-llama/Llama-3.2-1B")

print(f"✅ Models loaded | Teacher: {teacher.num_parameters()/1e9:.1f}B | Student: {student.num_parameters()/1e9:.1f}B")

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

✅ Models loaded | Teacher: 8.0B | Student: 1.2B


**Step 4: Load & Verify Dataset**

In [None]:
# Read the local JSON file; the "json" builder exposes it as the "train" split.
dataset = load_dataset("json", data_files="sampled_data.json")["train"]

# Both prompt fields must exist before tokenization can run.
assert {"scene", "task"}.issubset(dataset.column_names), "❌ Invalid dataset format!"
print(f"📊 Dataset loaded | Samples: {len(dataset)}")
# Preview the first record, clipping every field to 50 characters.
print("Sample:", {field: text[:50] + "..." for field, text in dataset[0].items()})

📊 Dataset loaded | Samples: 500
Sample: {'scene': '[phone_3,bicycle_4,door_11]...', 'task': '[A] Move up 20 cm. And Determine if the bottle is ...'}


**Step 5: Tokenization with Error Prevention**

In [None]:
# Read the system prompt that is prepended to every training example.
# The encoding is pinned so the prompt decodes identically on every platform.
with open("system_prompt.txt", "r", encoding="utf-8") as file:
    system_prompt = file.read()


def tokenize_fn(batch):
    """Render each (scene, task) pair as a chat prompt and tokenize it.

    Args:
        batch: Mapping with parallel "scene" and "task" sequences, as supplied
            by `Dataset.map(..., batched=True)`.

    Returns:
        The tokenizer output (`input_ids` and `attention_mask`), padded or
        truncated to 512 tokens per example.
    """
    formatted = [
        teacher_tokenizer.apply_chat_template(
            [{"role": "system", "content": system_prompt},
             {"role": "user", "content": f"Scene: {s}\nTask: {t}"}],
            tokenize=False,
            # Finish the prompt with the assistant header so the last position
            # predicts the first response token instead of a new turn header.
            add_generation_prompt=True,
        )
        for s, t in zip(batch["scene"], batch["task"])
    ]

    tokens = teacher_tokenizer(
        formatted,
        padding="max_length",
        truncation=True,
        max_length=512,
        # The rendered template already starts with BOS; adding special tokens
        # again would duplicate it at the front of every sequence.
        add_special_tokens=False,
        return_tensors="pt",
        return_attention_mask=True
    )
    return tokens

# Tokenize in batches of 16, drop the raw text columns, and return torch tensors.
train_dataset = (
    dataset
    .map(tokenize_fn, batched=True, batch_size=16)
    .remove_columns(["scene", "task"])
    .with_format("torch")
)

print("🔡 Tokenized sample input_ids:", train_dataset[0]["input_ids"][:10])

Map:   0%|          | 0/500 [00:00<?, ? examples/s]

🔡 Tokenized sample input_ids: tensor([128000, 128000, 128006,   9125, 128007,    271,  38766,   1303,  33025,
          2696])


**Step 6: Online Distillation Trainer**

In [None]:
class SafeDistiller(Trainer):
    """Trainer that distils a frozen teacher into the student via token-level KL.

    Args:
        teacher: Causal LM whose logits are the distillation target. It is put
            in eval mode and its parameters are frozen.
        temperature: Softmax temperature applied to both sets of logits; the
            loss is rescaled by `temperature ** 2`. The default of 1.0 gives
            plain KL divergence.
        log_every: Print sample generations every this many optimizer steps
            (0 or None disables the printout).
        **kwargs: Forwarded unchanged to `Trainer.__init__`.
    """

    def __init__(self, teacher, *, temperature=1.0, log_every=50, **kwargs):
        super().__init__(**kwargs)
        self.teacher = teacher
        self.teacher.eval()
        # The teacher is only a target; make sure it can never accumulate grads.
        self.teacher.requires_grad_(False)
        self.temperature = temperature
        self.log_every = log_every
        self._last_logged_step = None

    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        """Return KL(teacher || student) averaged over non-padding tokens.

        Args:
            model: The student model being trained.
            inputs: Batch dict passed to both models; `attention_mask`, when
                present, selects the positions that contribute to the loss.
            return_outputs: If True, also return the student's model output.
            **kwargs: Extra arguments passed by `Trainer`; ignored here.

        Returns:
            `loss`, or `(loss, student_out)` when `return_outputs` is True.
        """
        # Student forward
        student_out = model(**inputs)

        # Teacher forward (no grad)
        with torch.no_grad():
            teacher_out = self.teacher(**inputs)

        student_logits = student_out.logits
        teacher_logits = teacher_out.logits.to(student_logits.device)

        # Keep only real tokens and flatten to (num_tokens, vocab). With 2-D
        # inputs, "batchmean" then averages per token instead of summing over
        # the whole (padded) sequence and dividing by the batch size only.
        attention_mask = inputs.get("attention_mask")
        if attention_mask is None:
            student_logits = student_logits.reshape(-1, student_logits.size(-1))
            teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
        else:
            keep = attention_mask.to(student_logits.device).bool()
            student_logits = student_logits[keep]
            teacher_logits = teacher_logits[keep]

        if student_logits.size(0) == 0:
            # Fully masked batch: zero loss that is still attached to the graph.
            loss = student_out.logits.sum() * 0.0
        else:
            temperature = self.temperature
            # Softmax over the full vocabulary is done in float32 for stability.
            loss = F.kl_div(
                F.log_softmax(student_logits.float() / temperature, dim=-1),
                F.log_softmax(teacher_logits.float() / temperature, dim=-1),
                reduction="batchmean",
                log_target=True,
            ) * (temperature ** 2)

        # Progress logging: once per optimizer step, not once per micro-batch
        # (global_step does not change during gradient accumulation).
        step = self.state.global_step
        if (
            self.log_every
            and model.training
            and step % self.log_every == 0
            and step != self._last_logged_step
        ):
            self._last_logged_step = step
            first_mask = None if attention_mask is None else attention_mask[0]
            self._log_progress(inputs["input_ids"][0], first_mask)

        return (loss, student_out) if return_outputs else loss

    def _log_progress(self, input_ids, attention_mask=None):
        """Print teacher and student continuations for one prompt.

        Args:
            input_ids: 1-D tensor of token ids for a single example.
            attention_mask: Optional 1-D mask for `input_ids`; masked-out
                (padding) positions are removed before generation.
        """
        if attention_mask is not None:
            # Generating from a padded prompt would continue after the padding.
            input_ids = input_ids[attention_mask.to(input_ids.device).bool()]
        prompt = input_ids.unsqueeze(0)
        prompt_len = prompt.shape[1]
        input_text = teacher_tokenizer.decode(input_ids, skip_special_tokens=True)

        # Generate in eval mode, then restore the previous mode before backward.
        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                teacher_prompt = prompt.to(self.teacher.device)
                teacher_ids = self.teacher.generate(
                    teacher_prompt,
                    attention_mask=torch.ones_like(teacher_prompt),
                    max_new_tokens=50,
                    pad_token_id=teacher_tokenizer.pad_token_id,
                )[0]
                student_prompt = prompt.to(self.model.device)
                student_ids = self.model.generate(
                    student_prompt,
                    attention_mask=torch.ones_like(student_prompt),
                    max_new_tokens=50,
                    pad_token_id=student_tokenizer.pad_token_id,
                )[0]
        finally:
            self.model.train(was_training)

        # Decode only the newly generated tokens; decoding the full sequence
        # and clipping to 100 characters would just echo the prompt.
        teacher_text = teacher_tokenizer.decode(teacher_ids[prompt_len:], skip_special_tokens=True)
        student_text = student_tokenizer.decode(student_ids[prompt_len:], skip_special_tokens=True)

        print(f"\n🔥 Step {self.state.global_step}")
        print(f"Input: {input_text[:100]}...")
        print(f"Teacher: {teacher_text[:100]}...")
        print(f"Student: {student_text[:100]}...\n")
        print(f"💻 Memory: {torch.cuda.memory_allocated()//1024**3}GB / {torch.cuda.memory_reserved()//1024**3}GB")

**Step 7: Configure Training**

In [None]:
training_args = TrainingArguments(
    output_dir="./llama-1b-minispec",
    gradient_checkpointing=True,     # Recompute activations during backward pass instead of storing them
    per_device_train_batch_size=16,  # Per-device micro-batch size
    gradient_accumulation_steps=2,   # Effective batch size = 16 * 2 per device
    num_train_epochs=3,
    learning_rate=3e-5,
    weight_decay=0.01,
    bf16=True,                       # Match the bfloat16 dtype both models are loaded in
    logging_steps=10,
    save_strategy="epoch",
    remove_unused_columns=False      # Leave the tokenized columns for the custom loss untouched
)

# The distillation trainer needs the teacher in addition to the usual Trainer arguments.
trainer = SafeDistiller(
    teacher=teacher,
    model=student,
    args=training_args,
    train_dataset=train_dataset,
    data_collator=default_data_collator,
)
print("⚙️ Training configured | Batch size:", training_args.per_device_train_batch_size)

⚙️ Training configured | Batch size: 32


**Step 8: Run Training**

In [None]:
# Run the distillation loop on the configured `trainer`; the two status
# messages bracket the long-running call.
print("🚀 Starting training...")
trainer.train()
print("🎉 Training complete!")

🚀 Starting training...


OutOfMemoryError: CUDA out of memory. Tried to allocate 7.83 GiB. GPU 0 has a total capacity of 79.33 GiB of which 5.53 GiB is free. Including non-PyTorch memory, this process has 73.16 GiB memory in use. Of the allocated memory 70.66 GiB is allocated by PyTorch, and 1.99 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

**Step 9: Save & Test Model**

In [None]:
# Save model
trainer.save_model("./llama-1b-minispec-final")
student_tokenizer.save_pretrained("./llama-1b-minispec-final")

# Inference test
def generate_code(prompt, max_new_tokens=100):
    """Generate a completion for `prompt` with the distilled student model.

    Args:
        prompt: Prompt text. If it already starts with the tokenizer's BOS
            token (e.g. a rendered chat template), no second BOS is added.
        max_new_tokens: Maximum number of tokens to generate.

    Returns:
        The decoded sequence (prompt plus completion) with special tokens removed.
    """
    bos_token = student_tokenizer.bos_token
    already_has_bos = bool(bos_token) and prompt.startswith(bos_token)
    inputs = student_tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=not already_has_bos,
    ).to(student.device)  # follow the model's device instead of assuming "cuda"

    # Inference only: eval mode and no autograd bookkeeping.
    student.eval()
    with torch.no_grad():
        outputs = student.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            pad_token_id=student_tokenizer.pad_token_id,
        )
    return student_tokenizer.decode(outputs[0], skip_special_tokens=True)

# Build the test prompt with the same chat template and "Scene/Task" layout
# used for the training data, rather than Llama-2 style [INST]/<<SYS>> markup.
test_messages = [
    {"role": "system", "content": "Generate MiniSpec code using:\n"
                                  "- tu(angle): Turn counterclockwise\n"
                                  "- tc(angle): Turn clockwise"},
    {"role": "user", "content": "Scene: [computer_5,door_9]\n"
                                "Task: [A] Approach door then turn 45° clockwise"},
]
test_prompt = teacher_tokenizer.apply_chat_template(
    test_messages,
    tokenize=False,
    add_generation_prompt=True,
)
print(generate_code(test_prompt))

✔️ Trainer ready
