# 🚀 Knowledge Distillation for Reasoning Tasks

## 📍 Environment Support:
- ✅ **Google Colab** (Recommended for free GPU)
- ✅ **Kaggle** (Pre-installed packages)
- ✅ **Local** (Windows/Linux/Mac)

---

## 🔥 Google Colab Setup:

### 1️⃣ Enable GPU:
   - Click **Runtime** → **Change runtime type**
   - Hardware accelerator: **GPU** (T4 or better)
   - Click **Save**

### 2️⃣ Mount Google Drive (Optional but Recommended):
   - Saves cache permanently (survives session restarts)
   - Auto-prompted when you run the imports and configuration cell
   - Click **Connect to Google Drive** → **Allow**

### 3️⃣ Check GPU:
   ```python
   !nvidia-smi
   ```

---

## ⚡ Quick Start:
1. Run cells in order (Shift+Enter)
2. First run: Extracts teacher cache (~30-60 min)
3. Cache saved to Google Drive for reuse
4. Subsequent runs: Much faster!

---

In [31]:
# Install packages for Google Colab
!pip install -q transformers==4.36.0 peft==0.7.1 datasets==2.16.0 accelerate==0.25.0 bitsandbytes==0.41.3 wandb scikit-learn
!pip install -q huggingface_hub[hf_xet]

import os
os.environ['WANDB_DISABLED'] = 'true'

print("✅ Packages installed successfully!")


[notice] A new release of pip is available: 24.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


✅ Packages installed successfully!





In [32]:
# 🔧 Fix tokenizer compatibility issue
%pip install -q --upgrade tokenizers==0.15.0

print("✅ Tokenizer fixed!")

Note: you may need to restart the kernel to use updated packages.
✅ Tokenizer fixed!


  error: subprocess-exited-with-error
  
  × Building wheel for tokenizers (pyproject.toml) did not run successfully.
  │ exit code: 1
  ╰─> [57 lines of output]
      Running `maturin pep517 build-wheel -i e:\python.exe --compatibility off`
      🍹 Building a mixed python/rust project
      🔗 Found pyo3 bindings
      🐍 Found CPython 3.12 at e:\python.exe
      📡 Using build options features, bindings from pyproject.toml
         Compiling autocfg v1.1.0
         Compiling proc-macro2 v1.0.69
         Compiling unicode-ident v1.0.12
         Compiling windows_x86_64_msvc v0.48.5
         Compiling cfg-if v1.0.0
         Compiling syn v1.0.109
         Compiling target-lexicon v0.12.12
         Compiling scopeguard v1.2.0
         Compiling libc v0.2.150
         Compiling crossbeam-utils v0.8.16

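The cell above prints a success message even though the log shows the `tokenizers` wheel failing to build. A minimal sketch of checking what is actually installed, using only the standard library; the helper names `installed_versions` and `mismatches` are hypothetical, and the pins are copied from the install cells:

```python
from importlib import metadata


def installed_versions(packages):
    """Return {package: installed version, or None if it is not installed}."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


def mismatches(expected, actual):
    """List the packages whose installed version differs from the pinned one."""
    return sorted(name for name, pin in expected.items() if actual.get(name) != pin)


# Pins copied from the install cells in this notebook.
PINS = {"transformers": "4.36.0", "peft": "0.7.1", "tokenizers": "0.15.0"}
bad = mismatches(PINS, installed_versions(PINS))
print("All pins satisfied" if not bad else f"Version mismatch or missing: {bad}")
```

Running this right after the install cells would show whether a failed build left an older `tokenizers` in place.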
In [33]:
# 🔍 Check GPU availability (especially important for Colab)
!nvidia-smi

import torch
print(f"\n{'='*70}")
print(f"🔥 PyTorch version: {torch.__version__}")
print(f"🔥 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"🔥 CUDA version: {torch.version.cuda}")
    print(f"🔥 GPU Device: {torch.cuda.get_device_name(0)}")
    print(f"🔥 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
else:
    print("⚠️  WARNING: GPU not detected!")
    print("   In Colab: Runtime → Change runtime type → GPU")
print(f"{'='*70}")

Fri Dec  5 16:22:58 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.92                 Driver Version: 580.92         CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                  Driver-Model | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A1000 Laptop GPU  WDDM  |   00000000:01:00.0 Off |                  N/A |
| N/A   50C    P8              5W /   60W |       0MiB /   4096MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [34]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
from datasets import load_dataset
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Optional
import gc
from tqdm.auto import tqdm
import json

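import os
import sys
from huggingface_hub import login

# Read the Hugging Face token from the environment rather than hardcoding a secret.
# HF_TOKEN must be set before this cell runs; with no token, login() prompts for one.
login(token=os.environ.get("HF_TOKEN"))

# Config - Auto-detect environment (Colab, Kaggle, or Local)
class Config:
    # Detect environment
    # The Colab kernel imports google.colab at startup, so it shows up in sys.modules.
    # (dir() inside a class body only lists class-level names, so it cannot see get_ipython.)
    IS_COLAB = 'google.colab' in sys.modules
    IS_KAGGLE = os.path.exists('/kaggle')
    IS_LOCAL = not (IS_COLAB or IS_KAGGLE)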
    
    # Models
    TEACHER_MODEL = "meta-llama/Llama-2-13b-hf"
    STUDENT_MODEL = "mistralai/Mistral-7B-v0.1"
    
    # Dataset
    DATASET_NAME = "gsm8k"
    DATASET_CONFIG = "main"
    MAX_SAMPLES = 2000
    MAX_LENGTH = 512
    
    # Training
    BATCH_SIZE = 2
    GRADIENT_ACCUM = 8
    LEARNING_RATE = 2e-4
    NUM_EPOCHS = 3
    WARMUP_STEPS = 100
    
    # Distillation
    ALPHA_OUTPUT = 0.5
    BETA_LATENT = 0.5
    TEMPERATURE = 2.0
    LATENT_LAYERS = [8, 16, 24]
    
    # LoRA
    LORA_R = 16
    LORA_ALPHA = 32
    LORA_DROPOUT = 0.05
    LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]
    
    # Environment-specific paths
    if IS_COLAB:
        # Google Colab paths
        OUTPUT_DIR = "/content/drive/MyDrive/distill_output"
        LATENT_CACHE_DIR = "/content/drive/MyDrive/latent_cache"
        USE_GDRIVE = True
    elif IS_KAGGLE:
        # Kaggle paths
        OUTPUT_DIR = "/kaggle/working/distill_output"
        LATENT_CACHE_DIR = "/kaggle/working/latent_cache"
        USE_GDRIVE = False
    else:
        # Local paths
        OUTPUT_DIR = "./distill_output"
        LATENT_CACHE_DIR = "./latent_cache"
        USE_GDRIVE = False
    
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

config = Config()

# Mount Google Drive if on Colab (to save cache permanently)
if config.IS_COLAB and config.USE_GDRIVE:
    try:
        from google.colab import drive
        drive.mount('/content/drive')
        print("✅ Google Drive mounted successfully!")
    except Exception as e:
        print(f"⚠️ Could not mount Google Drive: {e}")
        print("   Using /content/ instead (will be lost after session)")
        config.OUTPUT_DIR = "/content/distill_output"
        config.LATENT_CACHE_DIR = "/content/latent_cache"
        config.USE_GDRIVE = False

# Create directories
os.makedirs(config.OUTPUT_DIR, exist_ok=True)
os.makedirs(config.LATENT_CACHE_DIR, exist_ok=True)

print(f"{'='*70}")
print(f"🌍 Environment: {'Google Colab' if config.IS_COLAB else 'Kaggle' if config.IS_KAGGLE else 'Local'}")
print(f"🔥 Device: {config.DEVICE}")
print(f"🔥 GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")
if config.IS_COLAB:
    print(f"💾 Google Drive: {'Mounted ✅' if config.USE_GDRIVE else 'Not mounted ⚠️'}")
print(f"\n📂 Paths:")
print(f"   Output: {config.OUTPUT_DIR}")
print(f"   Cache: {config.LATENT_CACHE_DIR}")
print(f"{'='*70}")

🌍 Environment: Kaggle
🔥 Device: cuda
🔥 GPU: NVIDIA RTX A1000 Laptop GPU

📂 Paths:
   Output: /kaggle/working/distill_output
   Cache: /kaggle/working/latent_cache
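The output above reports Kaggle while the GPU is a laptop RTX A1000 under a WDDM driver, which suggests the path-based check can misfire on a local machine that happens to have a `/kaggle` folder. A hedged sketch of a stricter check follows. The function name `detect_environment` is hypothetical, and relying on the `KAGGLE_KERNEL_RUN_TYPE` environment variable is an assumption about Kaggle kernels, not something the notebook confirms:

```python
import os
import sys


def detect_environment(modules=None, environ=None):
    """Classify the runtime as 'colab', 'kaggle', or 'local'."""
    modules = sys.modules if modules is None else modules
    environ = os.environ if environ is None else environ
    # The Colab kernel imports google.colab before user code runs.
    if "google.colab" in modules:
        return "colab"
    # Assumption: Kaggle kernels export KAGGLE_KERNEL_RUN_TYPE.
    if "KAGGLE_KERNEL_RUN_TYPE" in environ:
        return "kaggle"
    return "local"


print(detect_environment())
```

Passing `modules` and `environ` explicitly keeps the function easy to test without being on either platform.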


In [35]:
def prepare_prompt(question: str, answer: Optional[str] = None) -> str:
    """Format prompt for reasoning task"""
    # TODO: this prompt template needs to be improved later
    prompt = f"Question: {question}\n\nLet's solve this step by step:\n"
    if answer:
        prompt += f"{answer}"
    return prompt

class ReasoningDataset(Dataset):
    """Custom dataset with latent cache support"""
    def __init__(self, data, tokenizer, max_length=512, latent_dir=None):
        self.data = data
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.latent_dir = latent_dir
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        item = self.data[idx]
        
        # Tokenize
        prompt = prepare_prompt(item['question'], item.get('answer'))
        encoding = self.tokenizer(
            prompt,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt'
        )
        
        result = {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'idx': idx
        }
        
        # Load cached latent if available
        if self.latent_dir:
            latent_path = os.path.join(self.latent_dir, f"latent_{idx}.pt")
            if os.path.exists(latent_path):
                result['teacher_latents'] = torch.load(latent_path)
        
        return result

# Load GSM8K dataset
print("📦 Loading GSM8K dataset...")
dataset = load_dataset(config.DATASET_NAME, config.DATASET_CONFIG)

# Sample subset for Kaggle
train_data = dataset['train'].select(range(min(config.MAX_SAMPLES, len(dataset['train']))))
test_data = dataset['test'].select(range(min(500, len(dataset['test']))))

print(f"✅ Train: {len(train_data)} | Test: {len(test_data)}")

📦 Loading GSM8K dataset...
✅ Train: 2000 | Test: 500


In [36]:
def load_teacher_model():
    """Load teacher with 4-bit quantization to save memory"""
    print("🔄 Loading Teacher Model (4-bit)...")
    
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )
    
    model = AutoModelForCausalLM.from_pretrained(
        config.TEACHER_MODEL,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer = AutoTokenizer.from_pretrained(config.TEACHER_MODEL)
    tokenizer.pad_token = tokenizer.eos_token
    
    return model, tokenizer

def check_cache_completeness(cache_dir, data_size):
    """Check how many cache files exist"""
    if not os.path.exists(cache_dir):
        return 0
    
    existing_files = [f for f in os.listdir(cache_dir) if f.startswith('latent_') and f.endswith('.pt')]
    return len(existing_files)

def extract_latent_states(model, tokenizer, data, output_dir, batch_size=1):
    """Extract and cache teacher's latent states"""
    print(f"🧠 Extracting latent states to {output_dir}...")
    
    # Check existing cache
    existing_count = check_cache_completeness(output_dir, len(data))
    print(f"   Found {existing_count}/{len(data)} existing cache files")
    
    if existing_count == len(data):
        print("✅ All cache files exist! Skipping extraction.")
        return
    
    model.eval()
    
    # Estimate time
    print(f"\n⏱️  Estimated time: ~{len(data) - existing_count} samples × 2-3 sec = {(len(data) - existing_count) * 2.5 / 60:.1f} min")
    print(f"   Progress will be saved to: {output_dir}")
    if config.USE_GDRIVE:
        print(f"   💾 Cache saved to Google Drive (permanent)")
    else:
        print(f"   ⚠️  Cache saved to {output_dir} (on Colab without Drive, lost after session ends)")
    
    import time
    start_time = time.time()
    
    with torch.no_grad():
        for idx in tqdm(range(len(data)), desc="Extracting"):
            cache_path = os.path.join(output_dir, f"latent_{idx}.pt")
            
            # Skip if already cached
            if os.path.exists(cache_path):
                continue
            
            item = data[idx]
            prompt = prepare_prompt(item['question'], item.get('answer'))
            
            inputs = tokenizer(
                prompt,
                return_tensors='pt',
                truncation=True,
                max_length=config.MAX_LENGTH
            ).to(model.device)
            
            # Forward pass with hidden states
            outputs = model(
                **inputs,
                output_hidden_states=True,
                return_dict=True
            )
            
            # Extract specific layers
            latent_states = {}
            for layer_idx in config.LATENT_LAYERS:
                if layer_idx < len(outputs.hidden_states):
                    # Average pool over sequence
                    hidden = outputs.hidden_states[layer_idx]
                    pooled = hidden.mean(dim=1).cpu()  # [batch, hidden_dim]
                    latent_states[f'layer_{layer_idx}'] = pooled
            
            # Save
            torch.save(latent_states, cache_path)
            
            # Free memory
            del outputs, inputs
            if idx % 100 == 0:
                torch.cuda.empty_cache()
    
    elapsed = time.time() - start_time
    print(f"✅ Latent extraction complete! ({elapsed/60:.1f} minutes)")

# ===== EXTRACTION LOGIC =====

# Check if cache already exists
existing_cache = check_cache_completeness(config.LATENT_CACHE_DIR, len(train_data))
EXTRACT_LATENTS = existing_cache < len(train_data)

if EXTRACT_LATENTS:
    print(f"\n{'='*70}")
    print(f"🧠 TEACHER KNOWLEDGE EXTRACTION")
    print(f"{'='*70}")
    print(f"Cache status: {existing_cache}/{len(train_data)} files exist")
    print(f"Will extract: {len(train_data) - existing_cache} samples")
    
    if config.IS_COLAB and not config.USE_GDRIVE:
        print(f"\n⚠️  WARNING: Google Drive not mounted!")
        print(f"   Cache will be lost when session ends.")
        print(f"   Recommendation: Restart and mount Drive first.")
        proceed = input("Continue anyway? (yes/no): ")
        if proceed.lower() != 'yes':
            raise Exception("Stopped by user. Please mount Google Drive and restart.")
    
    teacher_model, teacher_tokenizer = load_teacher_model()
    extract_latent_states(
        teacher_model, 
        teacher_tokenizer, 
        train_data, 
        config.LATENT_CACHE_DIR
    )
    
    # Free teacher model
    del teacher_model, teacher_tokenizer
    gc.collect()
    torch.cuda.empty_cache()
    print("🗑️  Teacher model freed from memory")
else:
    print(f"\n{'='*70}")
    print(f"✅ CACHE FOUND - Skipping teacher extraction")
    print(f"{'='*70}")
    print(f"Using existing cache: {config.LATENT_CACHE_DIR}")
    print(f"Files: {existing_cache}/{len(train_data)}")
    print(f"This saves ~{existing_cache * 2.5 / 60:.1f} minutes! 🚀")


🧠 TEACHER KNOWLEDGE EXTRACTION
Cache status: 0/2000 files exist
Will extract: 2000 samples
🔄 Loading Teacher Model (4-bit)...


Downloading shards:   0%|          | 0/3 [00:00<?, ?it/s]Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Error while downloading from https://huggingface.co/meta-llama/Llama-2-13b-hf/resolve/main/model-00001-of-00003.safetensors: HTTPSConnectionPool(host='cas-bridge.xethub.hf.co', port=443): Read timed out.
Trying to resume download...
'(ReadTimeout

KeyboardInterrupt: 
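`check_cache_completeness` only counts `latent_*.pt` files, so stale files from an earlier run with a larger `MAX_SAMPLES` could make an incomplete cache look complete. A minimal sketch of an index-based check, assuming the `latent_<idx>.pt` naming used above; the helper name `missing_cache_indices` is hypothetical:

```python
import os
import re
import tempfile

LATENT_FILE = re.compile(r"^latent_(\d+)\.pt$")


def missing_cache_indices(cache_dir, data_size):
    """Return the sorted sample indices in [0, data_size) that have no latent_<idx>.pt file."""
    if not os.path.isdir(cache_dir):
        return list(range(data_size))
    present = set()
    for name in os.listdir(cache_dir):
        match = LATENT_FILE.match(name)
        if match:
            present.add(int(match.group(1)))
    return [i for i in range(data_size) if i not in present]


# Usage with a throwaway directory: 0 and 2 are cached, 7 is a stale file
# from a hypothetical larger run, and 1 and 3 are missing.
with tempfile.TemporaryDirectory() as tmp:
    for i in (0, 2, 7):
        open(os.path.join(tmp, f"latent_{i}.pt"), "wb").close()
    print(missing_cache_indices(tmp, 4))  # [1, 3]
```

A plain file count would report 3 of 4 here and hide which samples still need extraction. This sketch does not detect files left half-written by an interrupted save.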

In [None]:
def setup_student_model():
    """Load student model with LoRA"""
    print("🎓 Loading Student Model with LoRA...")
    
    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        config.STUDENT_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )
    
    tokenizer = AutoTokenizer.from_pretrained(config.STUDENT_MODEL)
    tokenizer.pad_token = tokenizer.eos_token
    
    # LoRA config
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=config.LORA_R,
        lora_alpha=config.LORA_ALPHA,
        lora_dropout=config.LORA_DROPOUT,
        target_modules=config.LORA_TARGET_MODULES,
        bias="none"
    )
    
    # Apply LoRA
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()
    
    return model, tokenizer

student_model, student_tokenizer = setup_student_model()

🎓 Loading Student Model with LoRA...


Downloading shards:   0%|          | 0/2 [00:00<?, ?it/s]Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`
Downloading shards:   0%|          | 0/2 [08:55<?, ?it/s]



KeyboardInterrupt: 

In [None]:
class DistillationTrainer(Trainer):
    """Custom trainer with latent distillation loss"""
    
    def compute_loss(self, model, inputs, return_outputs=False):
        # Ignore padding positions in the language-modeling loss
        labels = inputs['input_ids'].masked_fill(inputs['attention_mask'] == 0, -100)
        
        # Get student outputs with hidden states
        outputs = model(
            input_ids=inputs['input_ids'],
            attention_mask=inputs['attention_mask'],
            labels=labels,
            output_hidden_states=True,
            return_dict=True
        )
        
        # 1. Output loss (standard language modeling)
        loss_output = outputs.loss
        
        # 2. Latent distillation loss
        loss_latent = 0.0
        if 'teacher_latents' in inputs:
            teacher_latents = inputs['teacher_latents']
            student_hidden = outputs.hidden_states
            # Padding mask for mean pooling: [batch, seq, 1]
            mask = inputs['attention_mask'].unsqueeze(-1)
            
            num_latent_layers = 0
            for layer_idx in config.LATENT_LAYERS:
                layer_key = f'layer_{layer_idx}'
                if layer_key in teacher_latents and layer_idx < len(student_hidden):
                    # Mean-pool the student hidden state over real (non-padding) tokens,
                    # matching the teacher cache, which was extracted without padding
                    student_h = student_hidden[layer_idx].float()
                    student_pooled = (student_h * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)  # [batch, hidden]
                    
                    # Teacher latent
                    teacher_h = teacher_latents[layer_key].to(student_pooled.device).float()
                    
                    # MSE loss
                    # NOTE: Llama-2-13B has hidden size 5120 and Mistral-7B has 4096,
                    # so these tensors need a projection to a common width first.
                    loss_latent += F.mse_loss(student_pooled, teacher_h)
                    num_latent_layers += 1
            
            if num_latent_layers > 0:
                loss_latent /= num_latent_layers
        
        # Combined loss
        total_loss = (config.ALPHA_OUTPUT * loss_output + 
                      config.BETA_LATENT * loss_latent)
        
        return (total_loss, outputs) if return_outputs else total_loss
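The objective assembled in `compute_loss` can be written out as follows. This is a sketch of the intended loss rather than a statement of what the code currently computes. The projection matrix $W$ is an assumption added here: the teacher (Llama-2-13B, hidden size 5120) and the student (Mistral-7B, hidden size 4096) have different widths, so some mapping to a common size is needed before the MSE is defined.

$$
\mathcal{L} \;=\; \alpha\,\mathcal{L}_{\mathrm{LM}} \;+\; \beta\,\frac{1}{|S|}\sum_{\ell \in S}\operatorname{MSE}\!\left(\bar{h}^{\,s}_{\ell}\,W,\;\bar{h}^{\,t}_{\ell}\right),
\qquad
\bar{h}_{\ell} \;=\; \frac{\sum_{i} m_i\, h_{\ell,i}}{\sum_{i} m_i}
$$

Here $\alpha$ is `ALPHA_OUTPUT` (0.5), $\beta$ is `BETA_LATENT` (0.5), $S$ is the subset of `LATENT_LAYERS` = {8, 16, 24} available in both models, $h_{\ell,i}$ is the hidden state of layer $\ell$ at token position $i$, and $m_i$ is the attention mask (setting every $m_i = 1$ gives a plain mean over positions). $W \in \mathbb{R}^{4096 \times 5120}$ is the hypothetical learned projection. `TEMPERATURE` does not appear because the code has no logit-matching term.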

In [None]:
train_dataset = ReasoningDataset(
    train_data,
    student_tokenizer,
    max_length=config.MAX_LENGTH,
    latent_dir=config.LATENT_CACHE_DIR
)

test_dataset = ReasoningDataset(
    test_data,
    student_tokenizer,
    max_length=config.MAX_LENGTH,
    latent_dir=None  # No latent for test
)

# Training arguments
training_args = TrainingArguments(
    output_dir=config.OUTPUT_DIR,
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.BATCH_SIZE,
    per_device_eval_batch_size=config.BATCH_SIZE,
    gradient_accumulation_steps=config.GRADIENT_ACCUM,
    learning_rate=config.LEARNING_RATE,
    warmup_steps=config.WARMUP_STEPS,
    logging_steps=50,
    save_steps=500,
    eval_steps=500,
    evaluation_strategy="steps",
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=True,
    report_to="none",
    remove_unused_columns=False,
)

# Initialize trainer
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

print("🚀 Training configuration ready!")
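A quick arithmetic check of how many optimizer updates this configuration produces helps when choosing `save_steps`, `eval_steps` and `warmup_steps`. The sketch below assumes a single GPU and that incomplete accumulation groups are dropped by floor division; that is an approximation of the Trainer's step counting, not a quote of its implementation, and `optimizer_steps` is a hypothetical helper:

```python
import math


def optimizer_steps(num_samples, batch_size, grad_accum, epochs, num_devices=1):
    """Approximate number of optimizer updates for a run (single device by default)."""
    batches_per_epoch = math.ceil(num_samples / (batch_size * num_devices))
    steps_per_epoch = max(batches_per_epoch // grad_accum, 1)
    return steps_per_epoch * epochs


# Values from Config: MAX_SAMPLES, BATCH_SIZE, GRADIENT_ACCUM, NUM_EPOCHS
total = optimizer_steps(num_samples=2000, batch_size=2, grad_accum=8, epochs=3)
print(total)  # 375
```

If this count is right, `save_steps=500` and `eval_steps=500` would never trigger within 375 updates, so no intermediate checkpoint or evaluation would be produced. The 100 warmup steps would also cover roughly 27% of training.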

In [None]:
print("🔥 Starting training...")
trainer.train()

# Save final model
trainer.save_model(f"{config.OUTPUT_DIR}/final_model")
student_tokenizer.save_pretrained(f"{config.OUTPUT_DIR}/final_model")

print("✅ Training complete!")

In [None]:
def evaluate_reasoning(model, tokenizer, test_data, num_samples=50):
    """Evaluate reasoning accuracy"""
    model.eval()
    correct = 0
    total = 0
    
    print("📊 Evaluating reasoning accuracy...")
    
    with torch.no_grad():
        for idx in tqdm(range(min(num_samples, len(test_data)))):
            item = test_data[idx]
            prompt = prepare_prompt(item['question'])
            
            inputs = tokenizer(
                prompt,
                return_tensors='pt',
                truncation=True,
                max_length=256
            ).to(model.device)
            
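            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,  # greedy decoding; temperature only applies when sampling
                pad_token_id=tokenizer.eos_token_id
            )
            
            generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
            
            # Simple accuracy check: GSM8K references end with "#### <final answer>",
            # so compare only that final answer against the generated continuation
            ground_truth = str(item['answer']).split('####')[-1].strip()
            if ground_truth in generated[len(prompt):]:
                correct += 1
            total += 1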
            
            # Print first 3 examples
            if idx < 3:
                print(f"\n{'='*60}")
                print(f"Q: {item['question']}")
                print(f"Ground Truth: {ground_truth}")
                print(f"Generated: {generated[len(prompt):][:200]}...")
    
    accuracy = correct / total if total > 0 else 0
    print(f"\n✅ Accuracy: {accuracy:.2%} ({correct}/{total})")
    return accuracy

# Evaluate
accuracy = evaluate_reasoning(student_model, student_tokenizer, test_data)
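A substring test is a loose accuracy check, since a short answer such as `3` matches almost any text containing that digit. One possible stricter variant compares the last number in the generation with the number after the `####` marker of the GSM8K reference. The helper names below are hypothetical, and taking the last number as the model's answer is an assumption about the output format:

```python
import re

# Integers or decimals, optionally negative, with optional thousands separators.
NUMBER = re.compile(r"-?\d[\d,]*(?:\.\d+)?")


def gsm8k_final_answer(solution: str) -> str:
    """Return the text after the '####' marker of a GSM8K reference solution."""
    return solution.split("####")[-1].strip().replace(",", "")


def last_number(text: str):
    """Return the last number found in the text (commas removed), or None."""
    matches = NUMBER.findall(text)
    return matches[-1].replace(",", "") if matches else None


def is_correct(generated: str, solution: str) -> bool:
    """Compare the last generated number with the reference answer numerically."""
    predicted = last_number(generated)
    if predicted is None:
        return False
    try:
        return float(predicted) == float(gsm8k_final_answer(solution))
    except ValueError:
        return False


reference = "John gives away 2 of his 5 apples.\n#### 3"
print(is_correct("5 - 2 = 3, so John has 3 apples left.", reference))
```

Inside `evaluate_reasoning`, `is_correct(generated[len(prompt):], item['answer'])` could replace the substring test.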


In [None]:
# Save metrics
results = {
    'accuracy': float(accuracy),
    'config': {
        'teacher': config.TEACHER_MODEL,
        'student': config.STUDENT_MODEL,
        'lora_r': config.LORA_R,
        'alpha_output': config.ALPHA_OUTPUT,
        'beta_latent': config.BETA_LATENT
    }
}

with open(f"{config.OUTPUT_DIR}/results.json", 'w') as f:
    json.dump(results, f, indent=2)

print("📁 Results saved!")

# Inference example
def inference(question: str):
    """Single inference"""
    prompt = prepare_prompt(question)
    inputs = student_tokenizer(prompt, return_tensors='pt').to(student_model.device)
    
    with torch.no_grad():
        outputs = student_model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
            pad_token_id=student_tokenizer.eos_token_id
        )
    
    result = student_tokenizer.decode(outputs[0], skip_special_tokens=True)
    return result[len(prompt):]

# Test inference
test_question = "If John has 5 apples and gives 2 to Mary, how many does he have left?"
print(f"\n🧪 Test Inference:")
print(f"Q: {test_question}")
print(f"A: {inference(test_question)}")

print("\n✨ Pipeline complete! Model saved at:", config.OUTPUT_DIR)

## 📥 Download Results from Google Colab

Run the cell below to download your trained model and results.

In [None]:
# 📦 Package and download model (for Colab)

if config.IS_COLAB:
    from google.colab import files
    import zipfile
    
    print("📦 Packaging model for download...")
    
    # Create zip file
    zip_path = "/content/distill_model.zip"
    
    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf:
        # Add model files
        model_dir = f"{config.OUTPUT_DIR}/final_model"
        if os.path.exists(model_dir):
            for root, dirs, files_list in os.walk(model_dir):
                for file in files_list:
                    file_path = os.path.join(root, file)
                    arcname = os.path.relpath(file_path, config.OUTPUT_DIR)
                    zipf.write(file_path, arcname)
        
        # Add results
        results_path = f"{config.OUTPUT_DIR}/results.json"
        if os.path.exists(results_path):
            zipf.write(results_path, "results.json")
    
    file_size_mb = os.path.getsize(zip_path) / (1024 * 1024)
    print(f"✅ Package created: {file_size_mb:.1f} MB")
    
    # Download
    print(f"⬇️  Downloading...")
    files.download(zip_path)
    print(f"✅ Download complete!")
    
    if config.USE_GDRIVE:
        print(f"\n💾 Model also saved to Google Drive:")
        print(f"   {config.OUTPUT_DIR}/final_model/")
        print(f"   You can access it anytime from Drive")
else:
    print("ℹ️  Not running on Colab - files saved locally")
    print(f"   Location: {config.OUTPUT_DIR}")

---

## 💡 Google Colab Tips & Best Practices

### ✅ Before Starting:
1. **Enable GPU**: Runtime → Change runtime type → GPU (T4 recommended)
2. **Mount Google Drive**: Saves cache permanently
3. **Check GPU**: Run `!nvidia-smi` to verify

### ⚡ Session Management:
- **Colab Free**: ~12 hours session limit
- **First run**: Extract cache (~30-60 min), saved to Drive
- **Subsequent runs**: Load from Drive cache (instant!)
- **Tip**: Keep tab active to prevent disconnection

### 💾 Storage Strategy:

| What | Where | Why |
|------|-------|-----|
| **Teacher Cache** | Google Drive | Reuse forever (2.5GB) |
| **Student Model** | Google Drive | Access later (~100MB LoRA) |
| **Training Logs** | /content/ | Temporary, download if needed |
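As a sanity check on the cache size in the table, the raw tensor payload can be estimated from the extraction settings. This sketch assumes the pooled latents are stored as 16-bit floats, consistent with the float16 compute dtype used for the teacher, and ignores per-file serialization overhead. The helper name `latent_cache_bytes` is hypothetical:

```python
def latent_cache_bytes(num_samples, num_layers, hidden_dim, bytes_per_value=2):
    """Raw tensor payload of the pooled-latent cache, ignoring per-file overhead."""
    return num_samples * num_layers * hidden_dim * bytes_per_value


# 2000 samples, 3 cached layers, Llama-2-13B hidden size 5120, fp16 values
size = latent_cache_bytes(num_samples=2000, num_layers=3, hidden_dim=5120)
print(f"{size / 1e6:.1f} MB")  # 61.4 MB
```

Under these assumptions the payload is well below the 2.5GB listed above. The larger figure would be plausible only if full per-token hidden states were cached instead of one mean-pooled vector per layer.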

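### 🔄 If Session Disconnects:

```python
# Cache is safe in Google Drive!
# Re-run the cells from the top:
# - Mounts Drive
# - Detects existing cache and skips extraction
# - Restarts training (pass resume_from_checkpoint to trainer.train()
#   to continue from a saved checkpoint instead)
```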

### 📊 Monitor Training:

```python
# Check GPU usage
!nvidia-smi

# Check file sizes
!du -sh /content/drive/MyDrive/latent_cache
!du -sh /content/drive/MyDrive/distill_output
```

### ⚠️ Common Issues:

**GPU not available:**
```
Runtime → Change runtime type → GPU → Save
Then Restart runtime
```

**Drive quota exceeded:**
```python
# Use /content/ instead (temporary)
config.OUTPUT_DIR = "/content/distill_output"
config.LATENT_CACHE_DIR = "/content/latent_cache"
```

**Session disconnected:**
```
- Cache is safe in Drive
- Just re-run cells
- Training restarts from step 0 unless trainer.train() is given
  resume_from_checkpoint and a checkpoint was saved before the disconnect
```
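Resuming after a disconnect depends on finding a saved checkpoint first. The sketch below assumes the Trainer's usual `checkpoint-<step>` folder naming inside `OUTPUT_DIR`; the helper name `latest_checkpoint` is hypothetical:

```python
import os
import re

CHECKPOINT_DIR = re.compile(r"^checkpoint-(\d+)$")


def latest_checkpoint(output_dir):
    """Return the path of the highest-numbered checkpoint-<step> folder, or None."""
    if not os.path.isdir(output_dir):
        return None
    steps = []
    for name in os.listdir(output_dir):
        match = CHECKPOINT_DIR.match(name)
        if match and os.path.isdir(os.path.join(output_dir, name)):
            steps.append(int(match.group(1)))
    if not steps:
        return None
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")


# In the training cell one could then write (not executed here):
#   ckpt = latest_checkpoint(config.OUTPUT_DIR)
#   trainer.train(resume_from_checkpoint=ckpt)  # ckpt=None starts from scratch
```

Steps are compared as integers so that `checkpoint-1000` ranks above `checkpoint-500`, which a plain string sort would get wrong.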

---

## 🎓 What This Notebook Does:

1. **Load Models**: Teacher (Llama-2-13B) + Student (Mistral-7B)
2. **Extract Knowledge**: Teacher's hidden states (cached)
3. **Train Student**: With LoRA + Knowledge Distillation
4. **Evaluate**: On GSM8K reasoning tasks
5. **Save**: Model to Drive + Download option

**Total Time:**
- First run: ~90 min (with extraction)
- Cached run: ~30 min (skip extraction)

---

**Happy Training on Colab! 🚀**