In [4]:
%pip install -q transformers==4.36.0 peft==0.7.1 datasets==2.16.0 accelerate==0.25.0 bitsandbytes==0.41.3 wandb scikit-learn

import os
os.environ['WANDB_DISABLED'] = 'true'  # Disable Weights & Biases logging when wandb is not used

Note: you may need to restart the kernel to use updated packages.



[notice] A new release of pip is available: 24.2 -> 25.3
[notice] To update, run: python.exe -m pip install --upgrade pip


In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM, 
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    BitsAndBytesConfig
)
from peft import (
    LoraConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
    TaskType
)
from datasets import load_dataset
import numpy as np
from dataclasses import dataclass
from typing import Dict, List, Optional
import gc
from tqdm.auto import tqdm
import json

# Config
class Config:
    """Central settings for the distillation notebook.

    Every value is a class attribute; the cells below read them through the
    module-level ``config`` instance.
    """
    # Models
    TEACHER_MODEL = "meta-llama/Llama-2-13b-hf"  # or "NousResearch/Llama-2-13b-hf"
    STUDENT_MODEL = "mistralai/Mistral-7B-v0.1"
    
    # Dataset
    DATASET_NAME = "gsm8k"
    DATASET_CONFIG = "main"
    MAX_SAMPLES = 2000  # Cap on training examples, chosen for Kaggle
    MAX_LENGTH = 512  # Tokenizer max_length for training samples and teacher extraction
    
    # Training
    BATCH_SIZE = 2
    GRADIENT_ACCUM = 8
    LEARNING_RATE = 2e-4
    NUM_EPOCHS = 3
    WARMUP_STEPS = 100
    
    # Distillation
    ALPHA_OUTPUT = 0.5  # Output loss weight
    BETA_LATENT = 0.5   # Latent loss weight
    TEMPERATURE = 2.0  # NOTE(review): not referenced by any cell visible in this notebook
    LATENT_LAYERS = [8, 16, 24]  # Layers to match (used as indices into `hidden_states`)
    
    # LoRA
    LORA_R = 16
    LORA_ALPHA = 32
    LORA_DROPOUT = 0.05
    LORA_TARGET_MODULES = ["q_proj", "k_proj", "v_proj", "o_proj"]
    
    # Paths
    OUTPUT_DIR = "/kaggle/working/distill_output"
    LATENT_CACHE_DIR = "/kaggle/working/latent_cache"
    
    # Evaluated once, when the class body runs.
    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate the settings and make sure both working directories exist.
config = Config()
os.makedirs(config.OUTPUT_DIR, exist_ok=True)
os.makedirs(config.LATENT_CACHE_DIR, exist_ok=True)

# Report the selected device and, when CUDA is available, the name of GPU 0.
print(f"üî• Device: {config.DEVICE}")
print(f"üî• GPU: {torch.cuda.get_device_name(0) if torch.cuda.is_available() else 'N/A'}")

  from .autonotebook import tqdm as notebook_tqdm
  _torch_pytree._register_pytree_node(
The cache for model files in Transformers v4.22.0 has been updated. Migrating your old cache. This is a one-time only operation. You can interrupt this and resume the migration later on by calling `transformers.utils.move_cache()`.
0it [00:00, ?it/s]


RuntimeError: Failed to import transformers.training_args because of the following error (look up to see its traceback):
module 'torch' has no attribute 'version'

In [None]:
def prepare_prompt(question: str, answer: str = None) -> str:
    """Build the step-by-step reasoning prompt for one question.

    Args:
        question: Problem statement inserted after ``Question: ``.
        answer: Optional solution text. When truthy it is appended directly
            after the instruction line; ``None`` and ``""`` leave the prompt
            open for generation.

    Returns:
        The formatted prompt string.
    """
    # TODO: this prompt template really needs to be improved later.
    header = f"Question: {question}\n\nLet's solve this step by step:\n"
    if not answer:
        return header
    return header + f"{answer}"

class ReasoningDataset(Dataset):
    """Map-style dataset of tokenized reasoning prompts with optional cached teacher latents.

    Each sample is a dict with ``input_ids``, ``attention_mask`` and ``idx``.
    When ``latent_dir`` is set and ``latent_<idx>.pt`` exists there, the loaded
    object is attached under ``teacher_latents``.
    """

    def __init__(self, data, tokenizer, max_length=512, latent_dir=None):
        self.data, self.tokenizer = data, tokenizer
        self.max_length, self.latent_dir = max_length, latent_dir

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        record = self.data[idx]

        # Question plus (when present) reference answer, padded/truncated to max_length.
        text = prepare_prompt(record['question'], record.get('answer'))
        tokens = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding='max_length',
            return_tensors='pt',
        )

        sample = {
            'input_ids': tokens['input_ids'].squeeze(),
            'attention_mask': tokens['attention_mask'].squeeze(),
            'idx': idx,
        }

        # No cache directory configured: nothing more to attach.
        if not self.latent_dir:
            return sample

        cached_file = os.path.join(self.latent_dir, f"latent_{idx}.pt")
        if os.path.exists(cached_file):
            sample['teacher_latents'] = torch.load(cached_file)
        return sample

# Load GSM8K dataset
print("üì¶ Loading GSM8K dataset...")
dataset = load_dataset(config.DATASET_NAME, config.DATASET_CONFIG)

# Sample subset for Kaggle: the first MAX_SAMPLES training rows and the first
# 500 test rows (or the whole split when it is smaller).
train_data = dataset['train'].select(range(min(config.MAX_SAMPLES, len(dataset['train']))))
test_data = dataset['test'].select(range(min(500, len(dataset['test']))))

print(f"‚úÖ Train: {len(train_data)} | Test: {len(test_data)}")

In [None]:
def load_teacher_model():
    """Load the teacher checkpoint in 4-bit NF4 form together with its tokenizer.

    Returns:
        ``(model, tokenizer)`` for ``config.TEACHER_MODEL``. The tokenizer's
        pad token is set to its EOS token.
    """
    print("üîÑ Loading Teacher Model (4-bit)...")

    # 4-bit NF4 weights, double quantization, fp16 compute.
    quant_settings = dict(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
    )

    teacher = AutoModelForCausalLM.from_pretrained(
        config.TEACHER_MODEL,
        quantization_config=BitsAndBytesConfig(**quant_settings),
        device_map="auto",
        trust_remote_code=True,
    )

    teacher_tok = AutoTokenizer.from_pretrained(config.TEACHER_MODEL)
    teacher_tok.pad_token = teacher_tok.eos_token

    return teacher, teacher_tok

def extract_latent_states(model, tokenizer, data, output_dir, batch_size=1):
    """Run the teacher over ``data`` and cache pooled hidden states, one file per example.

    For example ``idx`` a dict ``{'layer_<k>': tensor}`` is written to
    ``<output_dir>/latent_<idx>.pt`` for every ``k`` in ``config.LATENT_LAYERS``
    that is a valid index into the model's ``hidden_states``. Each tensor is the
    mean of that hidden state over dim 1, moved to CPU.

    Examples whose cache file already exists are skipped, so an interrupted run
    can be resumed. Files are written to a temporary name and renamed into place
    so that an interruption during ``torch.save`` cannot leave a truncated file
    that a later run would mistake for a finished one.

    Args:
        model: Teacher model; called with ``output_hidden_states=True``.
        tokenizer: Tokenizer matching ``model``.
        data: Indexable collection of records with a ``'question'`` key and an
            optional ``'answer'`` key.
        output_dir: Directory receiving the cache files (created if missing).
        batch_size: Currently unused; examples are processed one at a time.
    """
    print(f"üß† Extracting latent states to {output_dir}...")
    os.makedirs(output_dir, exist_ok=True)
    model.eval()

    with torch.no_grad():
        for idx in tqdm(range(len(data)), desc="Extracting"):
            cache_path = os.path.join(output_dir, f"latent_{idx}.pt")

            # Skip if already cached
            if os.path.exists(cache_path):
                continue

            item = data[idx]
            prompt = prepare_prompt(item['question'], item.get('answer'))

            inputs = tokenizer(
                prompt,
                return_tensors='pt',
                truncation=True,
                max_length=config.MAX_LENGTH
            ).to(model.device)

            # Forward pass with hidden states
            outputs = model(
                **inputs,
                output_hidden_states=True,
                return_dict=True
            )

            # Extract specific layers; indices beyond the available hidden
            # states are silently left out of the cached dict.
            latent_states = {}
            for layer_idx in config.LATENT_LAYERS:
                if layer_idx < len(outputs.hidden_states):
                    # Average pool over sequence
                    hidden = outputs.hidden_states[layer_idx]
                    pooled = hidden.mean(dim=1).cpu()  # [batch, hidden_dim]
                    latent_states[f'layer_{layer_idx}'] = pooled

            # Save atomically: write next to the target, then rename over it.
            tmp_path = f"{cache_path}.tmp"
            torch.save(latent_states, tmp_path)
            os.replace(tmp_path, cache_path)

            # Free memory
            del outputs, inputs
            if idx % 100 == 0:
                torch.cuda.empty_cache()

    print("‚úÖ Latent extraction complete!")

# Extract latents (comment out if already done)
EXTRACT_LATENTS = True  # Set False if cache exists

if EXTRACT_LATENTS:
    # The teacher is only needed for this step; it is loaded here and released
    # below before the student is loaded in a later cell.
    teacher_model, teacher_tokenizer = load_teacher_model()
    extract_latent_states(
        teacher_model, 
        teacher_tokenizer, 
        train_data, 
        config.LATENT_CACHE_DIR
    )
    
    # Free teacher model
    del teacher_model, teacher_tokenizer
    gc.collect()
    torch.cuda.empty_cache()
    print("üóëÔ∏è  Teacher model freed from memory")


In [None]:
def setup_student_model():
    """Load the student checkpoint in fp16, wrap it with LoRA adapters and return it with its tokenizer.

    The frozen base weights stay in fp16. The trainable (adapter) parameters are
    upcast to fp32 because the training cell enables ``fp16=True`` mixed
    precision: its gradient scaler cannot unscale fp16 gradients and fails with
    "Attempting to unscale FP16 gradients" when the trainable weights are fp16.

    Returns:
        ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped
        ``config.STUDENT_MODEL`` and the tokenizer's pad token is its EOS token.
    """
    print("üéì Loading Student Model with LoRA...")

    # Load base model
    model = AutoModelForCausalLM.from_pretrained(
        config.STUDENT_MODEL,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True
    )

    tokenizer = AutoTokenizer.from_pretrained(config.STUDENT_MODEL)
    tokenizer.pad_token = tokenizer.eos_token

    # LoRA config
    lora_config = LoraConfig(
        task_type=TaskType.CAUSAL_LM,
        r=config.LORA_R,
        lora_alpha=config.LORA_ALPHA,
        lora_dropout=config.LORA_DROPOUT,
        target_modules=config.LORA_TARGET_MODULES,
        bias="none"
    )

    # Apply LoRA
    model = get_peft_model(model, lora_config)

    # Keep only the trainable parameters in fp32 (see docstring); the frozen
    # base weights are untouched, so the memory cost is limited to the adapters.
    for param in model.parameters():
        if param.requires_grad:
            param.data = param.data.to(torch.float32)

    model.print_trainable_parameters()

    return model, tokenizer

# Build the student once; the dataset, trainer, evaluation and inference cells
# below all use these two globals.
student_model, student_tokenizer = setup_student_model()

In [None]:
class DistillationTrainer(Trainer):
    """Trainer whose loss mixes the LM loss with a pooled hidden-state (latent) matching term.

    ``total = ALPHA_OUTPUT * lm_loss + BETA_LATENT * latent_loss`` where
    ``latent_loss`` is the mean, over the layers in ``config.LATENT_LAYERS``
    available on both sides, of the MSE between the student's mean-pooled
    hidden state and the cached teacher vector.

    Differences from a plain ``Trainer`` that callers should know about:

    * When no ``data_collator`` is supplied, :meth:`collate_with_latents` is
      installed. It batches the nested ``teacher_latents`` dict produced by the
      dataset and adds padding-masked ``labels``.
    * Padding positions are excluded from both the LM loss and the pooling.
    * Evaluation returns the loss only (see :meth:`prediction_step`).
    """

    def __init__(self, *args, **kwargs):
        # Trainer's positional order is (model, args, data_collator, ...), so a
        # collator passed positionally shows up as the third positional argument.
        if len(args) < 3 and kwargs.get('data_collator') is None:
            kwargs['data_collator'] = self.collate_with_latents
        super().__init__(*args, **kwargs)

    @staticmethod
    def _masked_labels(input_ids, attention_mask):
        """Copy ``input_ids`` and set padded positions to -100 so the LM loss ignores them."""
        labels = input_ids.clone()
        if attention_mask is not None:
            labels[attention_mask == 0] = -100
        return labels

    @staticmethod
    def collate_with_latents(features):
        """Batch a list of dataset samples, including the nested ``teacher_latents`` dict.

        Tensor fields are stacked and scalar fields (e.g. ``idx``) become 1-D
        tensors. ``labels`` is added when absent so that evaluation also goes
        through :meth:`compute_loss`. ``teacher_latents`` is attached only when
        every sample in the batch carries it, as a dict mapping each shared
        layer key to a ``[batch, hidden]`` tensor.
        """
        batch = {}
        for key, first_value in features[0].items():
            if key == 'teacher_latents':
                continue
            column = [sample[key] for sample in features]
            if isinstance(first_value, torch.Tensor):
                batch[key] = torch.stack(column)
            else:
                batch[key] = torch.tensor(column)

        if 'labels' not in batch and 'input_ids' in batch:
            batch['labels'] = DistillationTrainer._masked_labels(
                batch['input_ids'], batch.get('attention_mask')
            )

        if all('teacher_latents' in sample for sample in features):
            shared_keys = set(features[0]['teacher_latents'])
            for sample in features[1:]:
                shared_keys &= set(sample['teacher_latents'])
            if shared_keys:
                # Cached tensors are [1, hidden]; flatten each to one row and
                # concatenate along the batch axis.
                batch['teacher_latents'] = {
                    layer_key: torch.cat(
                        [sample['teacher_latents'][layer_key].reshape(1, -1) for sample in features],
                        dim=0,
                    )
                    for layer_key in sorted(shared_keys)
                }
        return batch

    @staticmethod
    def _match_width(student_vec, teacher_vec):
        """Bring two ``[batch, hidden]`` tensors to a common hidden width.

        Equal widths are returned unchanged. Otherwise the wider tensor is
        reduced to the narrower width with parameter-free adaptive average
        pooling. This is a stop-gap that makes the MSE well defined when teacher
        and student have different hidden sizes.
        TODO: replace with a learned projection trained alongside the adapters.
        """
        student_dim, teacher_dim = student_vec.size(-1), teacher_vec.size(-1)
        if student_dim == teacher_dim:
            return student_vec, teacher_vec

        import warnings
        warnings.warn(
            f"Teacher hidden size ({teacher_dim}) differs from student hidden size "
            f"({student_dim}); average-pooling both to {min(student_dim, teacher_dim)} "
            "features for the latent loss."
        )
        width = min(student_dim, teacher_dim)

        def shrink(vec):
            # adaptive_avg_pool1d expects [batch, channels, length]
            return F.adaptive_avg_pool1d(vec.unsqueeze(1), width).squeeze(1)

        return shrink(student_vec), shrink(teacher_vec)

    def compute_loss(self, model, inputs, return_outputs=False):
        """Return the combined LM + latent loss (and the model outputs when requested).

        Args:
            model: The student model.
            inputs: Batch dict with ``input_ids`` and ``attention_mask``;
                optionally ``labels`` and ``teacher_latents``.
            return_outputs: When True, return ``(loss, outputs)``.
        """
        input_ids = inputs['input_ids']
        attention_mask = inputs['attention_mask']

        # Never train on padding: pad == eos here and samples are padded to
        # max_length, so unmasked labels would be dominated by pad tokens.
        labels = inputs.get('labels')
        if labels is None:
            labels = self._masked_labels(input_ids, attention_mask)

        teacher_latents = inputs.get('teacher_latents')

        # Hidden states are only requested when there is something to match.
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=teacher_latents is not None,
            return_dict=True
        )

        # 1. Output loss (standard language modeling)
        loss_output = outputs.loss

        # 2. Latent distillation loss
        loss_latent = loss_output.new_zeros(())
        if teacher_latents is not None:
            student_hidden = outputs.hidden_states
            per_layer = []
            for layer_idx in config.LATENT_LAYERS:
                layer_key = f'layer_{layer_idx}'
                if layer_key not in teacher_latents or layer_idx >= len(student_hidden):
                    continue

                # Mean over real tokens only, accumulated in fp32. The teacher
                # vectors were pooled over unpadded text.
                student_h = student_hidden[layer_idx].float()
                mask = attention_mask.unsqueeze(-1).to(student_h.device)
                student_pooled = (student_h * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1)

                teacher_h = teacher_latents[layer_key].to(
                    device=student_pooled.device, dtype=torch.float32
                ).reshape(student_pooled.size(0), -1)

                student_pooled, teacher_h = self._match_width(student_pooled, teacher_h)

                # With device_map="auto" layers may live on different devices.
                per_layer.append(F.mse_loss(student_pooled, teacher_h).to(loss_output.device))

            if per_layer:
                loss_latent = torch.stack(per_layer).mean()

        # Combined loss
        total_loss = (config.ALPHA_OUTPUT * loss_output +
                      config.BETA_LATENT * loss_latent)

        return (total_loss, outputs) if return_outputs else total_loss

    def prediction_step(self, model, inputs, prediction_loss_only, ignore_keys=None):
        """Evaluate one batch, always returning the loss only.

        Only ``eval_loss`` is consumed by this notebook. Letting the base class
        collect the logits (and any hidden states) of every evaluation batch
        would hold them all in memory until the evaluation loop finishes.
        """
        return super().prediction_step(model, inputs, True, ignore_keys=ignore_keys)

In [None]:
# Training samples carry cached teacher latents; both splits are tokenized with
# the student tokenizer.
train_dataset = ReasoningDataset(
    train_data,
    student_tokenizer,
    max_length=config.MAX_LENGTH,
    latent_dir=config.LATENT_CACHE_DIR
)

test_dataset = ReasoningDataset(
    test_data,
    student_tokenizer,
    max_length=config.MAX_LENGTH,
    latent_dir=None  # No latent for test
)

# Training arguments
# NOTE(review): with at most MAX_SAMPLES=2000 examples, batch size 2 and 8
# accumulation steps there are at most 125 optimizer steps per epoch on a single
# device, i.e. about 375 over 3 epochs. That is below save_steps/eval_steps=500,
# so no intermediate evaluation or checkpoint would be triggered - confirm
# whether smaller values were intended.
training_args = TrainingArguments(
    output_dir=config.OUTPUT_DIR,
    num_train_epochs=config.NUM_EPOCHS,
    per_device_train_batch_size=config.BATCH_SIZE,
    per_device_eval_batch_size=config.BATCH_SIZE,
    gradient_accumulation_steps=config.GRADIENT_ACCUM,
    learning_rate=config.LEARNING_RATE,
    warmup_steps=config.WARMUP_STEPS,
    logging_steps=50,
    save_steps=500,
    eval_steps=500,
    evaluation_strategy="steps",
    save_total_limit=2,
    load_best_model_at_end=True,
    fp16=True,
    report_to="none",
    remove_unused_columns=False,  # keep the extra keys the dataset returns (e.g. teacher_latents)
)

# Initialize trainer
trainer = DistillationTrainer(
    model=student_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
)

print("üöÄ Training configuration ready!")

In [None]:
print("üî• Starting training...")
trainer.train()

# Save final model
trainer.save_model(f"{config.OUTPUT_DIR}/final_model")
student_tokenizer.save_pretrained(f"{config.OUTPUT_DIR}/final_model")

print("‚úÖ Training complete!")

In [None]:
def extract_final_number(text):
    """Extract the final numeric answer from a GSM8K-style solution or a model completion.

    The number following the first ``####`` marker is used when the marker is
    present; otherwise the last number appearing in ``text`` is used. Thousands
    separators are removed before conversion.

    Returns:
        The answer as a ``float``, or ``None`` when ``text`` contains no number.
    """
    import re

    number = r'-?\d[\d,]*(?:\.\d+)?'
    marked = re.search(r'####\s*\$?\s*(' + number + r')', text)
    if marked:
        candidate = marked.group(1)
    else:
        found = re.findall(number, text)
        if not found:
            return None
        candidate = found[-1]

    try:
        return float(candidate.replace(',', ''))
    except ValueError:
        return None


def evaluate_reasoning(model, tokenizer, test_data, num_samples=50):
    """Greedy-decode answers for the first ``num_samples`` test items and report accuracy.

    A prediction counts as correct when the final number extracted from the
    generated continuation equals the final number of the reference answer.
    The prompt is excluded from the comparison so that numbers quoted in the
    question cannot produce a match. If the reference contains no number at
    all, the check falls back to "reference text appears in the continuation".

    Args:
        model: Causal LM exposing ``generate``, ``eval`` and ``device``.
        tokenizer: Tokenizer matching ``model``.
        test_data: Indexable collection of records with ``'question'`` and ``'answer'``.
        num_samples: Maximum number of records to evaluate.

    Returns:
        Accuracy in ``[0, 1]``; ``0`` when nothing was evaluated.
    """
    model.eval()
    correct = 0
    total = 0

    print("üìä Evaluating reasoning accuracy...")

    with torch.no_grad():
        for idx in tqdm(range(min(num_samples, len(test_data)))):
            item = test_data[idx]
            prompt = prepare_prompt(item['question'])

            inputs = tokenizer(
                prompt,
                return_tensors='pt',
                truncation=True,
                max_length=256
            ).to(model.device)

            # Greedy decoding: no sampling temperature is passed because it has
            # no effect when do_sample=False.
            outputs = model.generate(
                **inputs,
                max_new_tokens=128,
                do_sample=False,
                pad_token_id=tokenizer.eos_token_id
            )

            # Decode only the newly generated tokens.
            prompt_len = inputs['input_ids'].shape[1]
            completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

            # Compare final answers rather than the whole reference solution.
            ground_truth = str(item['answer'])
            expected = extract_final_number(ground_truth)
            predicted = extract_final_number(completion)
            if expected is None:
                is_correct = ground_truth in completion
            else:
                is_correct = predicted is not None and abs(predicted - expected) < 1e-6

            if is_correct:
                correct += 1
            total += 1

            # Print first 3 examples
            if idx < 3:
                print(f"\n{'='*60}")
                print(f"Q: {item['question']}")
                print(f"Ground Truth: {ground_truth}")
                print(f"Generated: {completion[:200]}...")

    accuracy = correct / total if total > 0 else 0
    print(f"\n‚úÖ Accuracy: {accuracy:.2%} ({correct}/{total})")
    return accuracy

# Evaluate the student on the held-out subset (default num_samples=50).
accuracy = evaluate_reasoning(student_model, student_tokenizer, test_data)


In [None]:
# Save metrics
# Record the accuracy together with the main settings that produced it.
results = {
    'accuracy': float(accuracy),
    'config': {
        'teacher': config.TEACHER_MODEL,
        'student': config.STUDENT_MODEL,
        'lora_r': config.LORA_R,
        'alpha_output': config.ALPHA_OUTPUT,
        'beta_latent': config.BETA_LATENT
    }
}

with open(f"{config.OUTPUT_DIR}/results.json", 'w') as f:
    json.dump(results, f, indent=2)

print("üìÅ Results saved!")

# Inference example
def inference(question: str):
    """Sample one answer from the global student model for ``question``.

    Uses the module-level ``student_model`` and ``student_tokenizer``. The
    model is put in eval mode first so the result does not depend on an earlier
    cell having done so.

    Args:
        question: Problem statement to format with ``prepare_prompt``.

    Returns:
        The decoded continuation only. The prompt is removed by token count
        rather than by character count, so the result stays correct even when
        decoding does not reproduce the prompt text character for character.
    """
    prompt = prepare_prompt(question)
    inputs = student_tokenizer(prompt, return_tensors='pt').to(student_model.device)

    student_model.eval()
    with torch.no_grad():
        outputs = student_model.generate(
            **inputs,
            max_new_tokens=200,
            temperature=0.7,
            do_sample=True,
            pad_token_id=student_tokenizer.eos_token_id
        )

    # Everything before prompt_len is the echoed prompt.
    prompt_len = inputs['input_ids'].shape[1]
    return student_tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)

# Test inference
# Smoke test on one hand-written question; sampling is enabled in `inference`,
# so the printed answer can differ between runs.
test_question = "If John has 5 apples and gives 2 to Mary, how many does he have left?"
print(f"\nüß™ Test Inference:")
print(f"Q: {test_question}")
print(f"A: {inference(test_question)}")

print("\n‚ú® Pipeline complete! Model saved at:", config.OUTPUT_DIR)