In [None]:
# Mistral-7B to TinyLlama-1.1B Knowledge Distillation for Carbon Emission Factor Retrieval

# Install required packages with updated bitsandbytes
!pip install -q transformers accelerate peft datasets tqdm torch sentencepiece
!pip install -U bitsandbytes

# Mount Google Drive
from google.colab import drive
drive.mount('/content/drive')

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m491.2/491.2 kB[0m [31m10.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m2.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m113.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m87.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m51.0 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m1.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m211.5/211.5 MB[0m [31m9.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m56.3/56.3 MB[0m [31m35.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
# Clone the Carbon-EF repository if it doesn't exist
import os
if not os.path.exists('Carbon-EF'):
    !git clone https://github.com/Sbursu/Carbon-EF.git
%cd Carbon-EF


Cloning into 'Carbon-EF'...
remote: Enumerating objects: 659, done.[K
remote: Counting objects: 100% (29/29), done.[K
remote: Compressing objects: 100% (21/21), done.[K
remote: Total 659 (delta 21), reused 8 (delta 8), pack-reused 630 (from 1)[K
Receiving objects: 100% (659/659), 26.87 MiB | 32.91 MiB/s, done.
Resolving deltas: 100% (215/215), done.
/content/Carbon-EF


In [None]:
# Basic imports
import torch
import json
import time
import numpy as np
import random
import re
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    get_scheduler,
    set_seed,
    BitsAndBytesConfig
)
import gc
import warnings
warnings.filterwarnings("ignore")

# Set seed for reproducibility
SEED = 42
set_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

In [None]:
# First, check what files are actually in the Carbon-EF directory
print("Checking for dataset files...")
!find /content/Carbon-EF/data/processed -type f -name "*.json" | sort

# Check the content structure of one of the JSON files to understand the data format
json_files = !find . -type f -name "*.json" | head -n 1
if json_files:
    !head -n 20 {json_files[0]}

Checking for dataset files...
/content/Carbon-EF/data/processed/harmonized_global_ef_dataset_metadata.json
/content/Carbon-EF/data/processed/instructions.json
/content/Carbon-EF/data/processed/instructions_test.json
/content/Carbon-EF/data/processed/instructions_train.json
/content/Carbon-EF/data/processed/instructions_val.json
{
    "model": {
        "base_model": "mistralai/Mistral-7B-v0.1",
        "lora_rank": 64,
        "lora_alpha": 16,
        "lora_dropout": 0.05,
        "target_modules": [
            "q_proj",
            "v_proj",
            "k_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj"
        ]
    },
    "training": {
        "learning_rate": 3e-4,
        "num_epochs": 3,
        "batch_size": 8,


In [None]:
# Configuration class
class DistillationConfig:
    # Model paths
    teacher_model_path = "/content/drive/MyDrive/mistral_merged_model"  # Path to your fine-tuned Mistral-7B
    student_model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"

    # Output paths
    output_dir = "/content/tinyllama_distilled"
    checkpoint_dir = "/content/tinyllama_distilled/checkpoints"
    final_model_save_path = "/content/drive/MyDrive/tinyllama_distilled_final"

    # Dataset paths - using the files found in the repository
    data_dir = "/content/Carbon-EF/data/processed"
    train_file = "instructions_train.json"
    val_file = "instructions_val.json"
    test_file = "instructions_test.json"

    # Training parameters - Modified for memory efficiency
    batch_size = 1  # Single example per batch for memory efficiency
    gradient_accumulation_steps = 16  # Accumulate gradients across more steps
    learning_rate = 1e-6  # Reduced learning rate for stability
    weight_decay = 0.01
    num_epochs = 3
    warmup_ratio = 0.1
    max_grad_norm = 1.0  # Added gradient clipping

    # Model sequence length - reduce this to save memory
    max_length = 256  # Reduced from 512 to save memory

    # Distillation parameters
    temperature = 1.0  # Reduced from 2.0 for more stability
    alpha = 0.5  # Weight for distillation loss vs task-specific loss

    # Checkpoint parameters
    save_steps = 200  # More frequent checkpoints
    eval_steps = 1000

    # Resource management - Enhanced for memory efficiency
    load_in_4bit = True  # Use 4-bit quantization instead of 8-bit
    gradient_checkpointing = True
    offload_to_cpu = True  # Enable CPU offloading for the model

    # Memory optimization
    torch_compile = False  # Enable torch.compile() - can help with memory
    use_flash_attention = True  # Use flash attention if available

    # Resume training
    resume_from_checkpoint = True  # Enable resuming from checkpoints

    # Device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

config = DistillationConfig()

# Update the config with actual file names found in the repository
print(f"Using dataset files from: {config.data_dir}")
print(f"Training file: {os.path.join(config.data_dir, config.train_file)}")
print(f"Validation file: {os.path.join(config.data_dir, config.val_file)}")
print(f"Test file: {os.path.join(config.data_dir, config.test_file)}")

# Verify that all files exist
for file_name in [config.train_file, config.val_file, config.test_file]:
    file_path = os.path.join(config.data_dir, file_name)
    if os.path.exists(file_path):
        print(f"✓ Found {file_name}")
    else:
        print(f"✗ Could not find {file_name} - please check the path")

# First, let's check for existing training runs and ask if we want to resume
if os.path.exists(config.checkpoint_dir):
    latest_checkpoint = find_latest_checkpoint(config.checkpoint_dir)
    if latest_checkpoint and config.resume_from_checkpoint:
        print(f"Found existing checkpoint at: {latest_checkpoint}")
        resume_option = input("Do you want to resume training from this checkpoint? (yes/no): ")
        if resume_option.lower() == "yes":
            print("Will resume training from the latest checkpoint.")
        else:
            print("Will start training from scratch.")
            config.resume_from_checkpoint = False

            # Clean up previous directories
            print(f"Removing previous training directory: {config.output_dir}")
            shutil.rmtree(config.output_dir)

            # Also ask if we want to remove the final model directory
            remove_option = input("Do you also want to remove the previous final model? (yes/no): ")
            if remove_option.lower() == "yes" and os.path.exists(config.final_model_save_path):
                print(f"Removing previous final model directory: {config.final_model_save_path}")
                shutil.rmtree(config.final_model_save_path)
else:
    # No existing checkpoints found
    print("No existing checkpoints found. Starting fresh training.")

# Create necessary directories
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.checkpoint_dir, exist_ok=True)
os.makedirs(config.final_model_save_path, exist_ok=True)

Using dataset files from: /content/Carbon-EF/data/processed
Training file: /content/Carbon-EF/data/processed/instructions_train.json
Validation file: /content/Carbon-EF/data/processed/instructions_val.json
Test file: /content/Carbon-EF/data/processed/instructions_test.json
✓ Found instructions_train.json
✓ Found instructions_val.json
✓ Found instructions_test.json
No existing checkpoints found. Starting fresh training.


In [None]:
# Memory Management
def cleanup():
    """Clean up memory"""
    # Memory cleanup
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

In [None]:
# Checkpoint Management Functions
def save_checkpoint(model, optimizer, scheduler, step, epoch, config, global_step):
    """Storage-efficient checkpoint saving"""
    # Only keep latest checkpoint to save space
    # First clean up old checkpoints
    if os.path.exists(config.checkpoint_dir):
        for old_cp in os.listdir(config.checkpoint_dir):
            if old_cp.startswith("checkpoint-"):
                old_path = os.path.join(config.checkpoint_dir, old_cp)
                if os.path.isdir(old_path):
                    import shutil
                    try:
                        shutil.rmtree(old_path)
                        print(f"Removed old checkpoint: {old_path}")
                    except Exception as e:
                        print(f"Error removing old checkpoint: {e}")

    checkpoint_path = os.path.join(config.checkpoint_dir, f"checkpoint-{global_step}")
    os.makedirs(checkpoint_path, exist_ok=True)

    # Save model
    model.save_pretrained(checkpoint_path)

    # Save optimizer and scheduler
    torch.save({
        'optimizer': optimizer.state_dict(),
        'scheduler': scheduler.state_dict() if scheduler else None,
        'step': step,
        'epoch': epoch,
        'global_step': global_step,
        'config': config.__dict__
    }, os.path.join(checkpoint_path, "training_state.pt"))

    # Only create backups at major milestones and store locally (not on Drive)
    if global_step % (config.save_steps * 10) == 0:
        # Clean up any old backups first
        for old_model in os.listdir(config.output_dir):
            if old_model.startswith("model-"):
                old_path = os.path.join(config.output_dir, old_model)
                if os.path.isdir(old_path):
                    import shutil
                    try:
                        shutil.rmtree(old_path)
                        print(f"Removed old backup: {old_path}")
                    except Exception as e:
                        print(f"Error removing old backup: {e}")

        # Save new backup
        backup_path = os.path.join(config.output_dir, f"model-{global_step}")
        model.save_pretrained(backup_path)

    print(f"Checkpoint saved at step {global_step} (storage-efficient mode)")

def find_latest_checkpoint(checkpoint_dir):
    """Find the latest checkpoint in the directory"""
    if not os.path.exists(checkpoint_dir):
        return None

    checkpoints = [d for d in os.listdir(checkpoint_dir) if d.startswith("checkpoint-")]
    if not checkpoints:
        return None

    # Extract steps from checkpoint names and find the latest
    steps = [int(cp.split("-")[1]) for cp in checkpoints]
    latest_step = max(steps)
    return os.path.join(checkpoint_dir, f"checkpoint-{latest_step}")

def load_checkpoint(student_model, optimizer, scheduler, config):
    """Load checkpoint for resuming training"""
    latest_checkpoint = find_latest_checkpoint(config.checkpoint_dir)

    if latest_checkpoint and config.resume_from_checkpoint:
        print(f"Resuming from checkpoint: {latest_checkpoint}")

        # Load model
        student_model = AutoModelForCausalLM.from_pretrained(
            latest_checkpoint,
            torch_dtype=torch.float16,
            device_map="auto",
            use_cache=not config.gradient_checkpointing
        )

        # Load optimizer and scheduler state
        training_state = torch.load(os.path.join(latest_checkpoint, "training_state.pt"))
        optimizer.load_state_dict(training_state['optimizer'])
        if scheduler and training_state['scheduler']:
            scheduler.load_state_dict(training_state['scheduler'])

        step = training_state['step']
        epoch = training_state['epoch']
        global_step = training_state['global_step']

        return student_model, optimizer, scheduler, step, epoch, global_step

    return student_model, optimizer, scheduler, 0, 0, 0

In [None]:
# Evaluation Function
def evaluate(model, eval_loader, epoch, global_step, config):
    """Storage-efficient evaluation function"""
    model.eval()
    eval_loss = 0

    with torch.no_grad():
        for batch in tqdm(eval_loader, desc="Evaluating"):
            batch = {k: v.to(config.device) for k, v in batch.items()}

            outputs = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"]
            )

            eval_loss += outputs.loss.item()

    avg_eval_loss = eval_loss / len(eval_loader)
    print(f"Evaluation at epoch {epoch+1}, step {global_step}: Loss = {avg_eval_loss}")

    # Keep evaluation results in a single file to save space
    eval_file = os.path.join(config.output_dir, "eval_results.txt")

    # If file gets too large, truncate it to save space
    if os.path.exists(eval_file) and os.path.getsize(eval_file) > 1024 * 1024:  # If over 1MB
        # Keep only the last 100 lines
        with open(eval_file, 'r') as f:
            lines = f.readlines()

        with open(eval_file, 'w') as f:
            f.writelines(lines[-100:])

    # Append new results
    with open(eval_file, "a") as f:
        f.write(f"Step: {global_step}, Epoch: {epoch+1}, Loss: {avg_eval_loss}\n")

    model.train()
    return avg_eval_loss

In [None]:
# Dataset and Loss Function
class EmissionFactorDataset(Dataset):
    """Custom dataset for emission factor retrieval"""
    def __init__(self, file_path, tokenizer, max_length=512):
        self.data = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        # Load and process data
        with open(file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Check the structure of the first item to determine format
        if isinstance(data, list):
            # Analyze the first item to determine data structure
            if data and isinstance(data[0], dict):
                # Check what keys are available
                keys = set(data[0].keys())

                # Handle different data formats
                if all(k in keys for k in ["instruction", "output"]):
                    # Standard instruction-output format
                    for item in data:
                        self.data.append({
                            "instruction": item.get("instruction", ""),
                            "input": item.get("input", ""),
                            "output": item.get("output", "")
                        })
                elif all(k in keys for k in ["prompt", "completion"]):
                    # OpenAI-style format
                    for item in data:
                        # Split prompt into instruction and optional input
                        prompt = item.get("prompt", "")
                        parts = prompt.split("\n\n", 1)
                        instruction = parts[0]
                        input_text = parts[1] if len(parts) > 1 else ""

                        self.data.append({
                            "instruction": instruction,
                            "input": input_text,
                            "output": item.get("completion", "")
                        })
                elif "messages" in keys:
                    # ChatML format
                    for item in data:
                        messages = item.get("messages", [])
                        user_msgs = [m for m in messages if m.get("role") == "user"]
                        assistant_msgs = [m for m in messages if m.get("role") == "assistant"]

                        if user_msgs and assistant_msgs:
                            self.data.append({
                                "instruction": user_msgs[0].get("content", ""),
                                "input": " ".join([m.get("content", "") for m in user_msgs[1:]]),
                                "output": assistant_msgs[-1].get("content", "")
                            })
                else:
                    # Generic approach - try to extract relevant fields
                    for item in data:
                        instruction = ""
                        input_text = ""
                        output = ""

                        # Try to find appropriate keys
                        for k, v in item.items():
                            if "instruction" in k.lower() or "query" in k.lower() or "question" in k.lower():
                                instruction = v
                            elif "input" in k.lower() or "context" in k.lower():
                                input_text = v
                            elif "output" in k.lower() or "response" in k.lower() or "answer" in k.lower() or "completion" in k.lower():
                                output = v

                        if instruction and output:
                            self.data.append({
                                "instruction": instruction,
                                "input": input_text,
                                "output": output
                            })
            else:
                # Simple text list - create instruction/output pairs
                # This is a fallback for simple datasets
                for i in range(0, len(data), 2):
                    if i+1 < len(data):
                        self.data.append({
                            "instruction": "Continue the following text or answer the query:",
                            "input": data[i],
                            "output": data[i+1]
                        })

        print(f"Loaded {len(self.data)} examples from {file_path}")

        # Print a sample for verification
        if self.data:
            print("\nSample data item:")
            print(f"Instruction: {self.data[0]['instruction'][:100]}...")
            print(f"Input: {self.data[0]['input'][:100]}..." if self.data[0]['input'] else "Input: [Empty]")
            print(f"Output: {self.data[0]['output'][:100]}...")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]

        # Format the text as per model requirements
        if item["input"]:
            text = f"<s>[INST] {item['instruction']}\n{item['input']} [/INST] {item['output']}</s>"
        else:
            text = f"<s>[INST] {item['instruction']} [/INST] {item['output']}</s>"

        # Encode the text
        encodings = self.tokenizer(
            text,
            truncation=True,
            max_length=self.max_length,
            padding="max_length",
            return_tensors="pt"
        )

        # Prepare the inputs
        input_ids = encodings.input_ids[0]
        attention_mask = encodings.attention_mask[0]

        return {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
            "labels": input_ids.clone()  # For autoregressive training
        }

def distillation_loss(student_logits, teacher_logits, labels, temperature, alpha):
    """
    Combined loss function for knowledge distillation
    - alpha: weight for distillation loss vs task-specific loss
    - temperature: softmax temperature for distillation
    """
    # Handle potential NaN and Inf values
    student_logits = torch.nan_to_num(student_logits)
    teacher_logits = torch.nan_to_num(teacher_logits)

    # KL divergence for distribution matching
    teacher_probs = torch.nn.functional.softmax(teacher_logits / temperature, dim=-1)

    # For numerical stability, use log_softmax and sum instead of KL div
    log_probs_student = torch.nn.functional.log_softmax(student_logits / temperature, dim=-1)

    # Safer KL divergence implementation
    distill_loss = -(teacher_probs * log_probs_student).sum(dim=-1).mean() * (temperature ** 2)

    # Check for NaN values
    if torch.isnan(distill_loss):
        print("Warning: Distillation loss is NaN! Using fallback")
        distill_loss = torch.tensor(0.0, device=student_logits.device)

    # Task loss: cross-entropy on ground truth
    task_loss = torch.nn.functional.cross_entropy(
        student_logits.view(-1, student_logits.size(-1)),
        labels.view(-1),
        ignore_index=-100,
        reduction="mean"
    )

    # Check for NaN values
    if torch.isnan(task_loss):
        print("Warning: Task loss is NaN! Using fallback")
        task_loss = torch.tensor(1.0, device=student_logits.device)

    # Combined loss
    loss = alpha * distill_loss + (1 - alpha) * task_loss

    # Final sanity check
    if torch.isnan(loss) or torch.isinf(loss):
        print("Warning: Combined loss is NaN or Inf! Using fallback")
        loss = task_loss  # Fallback to just task loss

    return loss, task_loss, distill_loss

In [None]:
# Main Training Function
def train():
    """Main function for knowledge distillation training"""
    print("Loading tokenizers...")
    # Load student tokenizer
    student_tokenizer = AutoTokenizer.from_pretrained(config.student_model_name)
    student_tokenizer.pad_token = student_tokenizer.eos_token

    # Load teacher tokenizer (assuming same tokenizer as Mistral)
    teacher_tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    teacher_tokenizer.pad_token = teacher_tokenizer.eos_token

    print("Preparing datasets...")
    # Prepare datasets
    train_dataset = EmissionFactorDataset(
        os.path.join(config.data_dir, config.train_file),
        student_tokenizer
    )
    val_dataset = EmissionFactorDataset(
        os.path.join(config.data_dir, config.val_file),
        student_tokenizer
    )

    # Create data loaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        pin_memory=True
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        pin_memory=True
    )

    print(f"Train dataset size: {len(train_dataset)}")
    print(f"Validation dataset size: {len(val_dataset)}")

    print("Loading teacher model...")
    # Check if the teacher model path exists
    if not os.path.exists(config.teacher_model_path):
        raise FileNotFoundError(f"Teacher model path not found: {config.teacher_model_path}")

    # Load teacher model with more memory-efficient settings
    print("Loading teacher model with memory optimization...")

    # Use the most aggressive memory optimization for the teacher model
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,  # Use 4-bit quantization instead of 8-bit
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4"
    )

    teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        quantization_config=bnb_config,
        device_map="auto",
        torch_dtype=torch.float16,
        offload_folder="offload_folder",  # Enable CPU offloading for parts of the model
        offload_state_dict=True  # Offload state dict to CPU
    )
    teacher_model.eval()  # Set to evaluation mode

    print("Loading student model...")
    # Load student model
    student_model = AutoModelForCausalLM.from_pretrained(
        config.student_model_name,
        torch_dtype=torch.float16,
        device_map="auto"
    )

    if config.gradient_checkpointing:
        student_model.gradient_checkpointing_enable()

    # Prepare optimizer
    optimizer = torch.optim.AdamW(
        student_model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay
    )

    # Learning rate scheduler
    total_steps = len(train_loader) * config.num_epochs // config.gradient_accumulation_steps
    warmup_steps = int(total_steps * config.warmup_ratio)

    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=warmup_steps,
        num_training_steps=total_steps
    )

    # Try to load checkpoint
    student_model, optimizer, scheduler, step, start_epoch, global_step = load_checkpoint(
        student_model, optimizer, scheduler, config
    )

    # Training loop
    print(f"Starting training from epoch {start_epoch}, step {step}, global_step {global_step}")
    student_model.train()

    for epoch in range(start_epoch, config.num_epochs):
        epoch_loss = 0
        epoch_task_loss = 0
        epoch_distill_loss = 0

        progress_bar = tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}")

        # Skip steps that were already processed in the current epoch
        if epoch == start_epoch and step > 0:
            print(f"Skipping to step {step} in epoch {epoch}")
            for i, _ in enumerate(progress_bar):
                if i >= step:
                    break

        for i, batch in enumerate(train_loader):
            # Skip steps that were already processed
            if epoch == start_epoch and i < step:
                continue

            batch = {k: v.to(config.device) for k, v in batch.items()}

            # Forward pass through student
            student_outputs = student_model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                labels=batch["labels"],
                return_dict=True
            )

            # Get student logits
            student_logits = student_outputs.logits

            # Forward pass through teacher (no gradients needed)
            with torch.no_grad():
                teacher_outputs = teacher_model(
                    input_ids=batch["input_ids"],
                    attention_mask=batch["attention_mask"],
                    return_dict=True
                )
                teacher_logits = teacher_outputs.logits

            # Calculate combined loss
            loss, task_loss, distill_loss = distillation_loss(
                student_logits,
                teacher_logits,
                batch["labels"],
                config.temperature,
                config.alpha
            )

            # Scale loss for gradient accumulation
            loss = loss / config.gradient_accumulation_steps

            # Backward pass
            loss.backward()

            # Add gradient clipping to prevent explosion
            if (i + 1) % config.gradient_accumulation_steps == 0 or i == len(train_loader) - 1:
                # Clip gradients to prevent NaN and instability
                torch.nn.utils.clip_grad_norm_(student_model.parameters(), config.max_grad_norm)

                epoch_loss += loss.item() * config.gradient_accumulation_steps
                epoch_task_loss += task_loss.item()
                epoch_distill_loss += distill_loss.item()

            # Update weights if we've accumulated enough gradients
            if (i + 1) % config.gradient_accumulation_steps == 0 or i == len(train_loader) - 1:
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

                global_step += 1

                # Save checkpoint
                if global_step % config.save_steps == 0:
                    save_checkpoint(student_model, optimizer, scheduler, i+1, epoch, config, global_step)

                # Evaluate
                if global_step % config.eval_steps == 0:
                    evaluate(student_model, val_loader, epoch, global_step, config)

            # Update progress bar with warnings if losses are abnormal
            progress_info = {
                'loss': epoch_loss / (i + 1) if not np.isnan(epoch_loss / (i + 1)) else "NaN",
                'task_loss': epoch_task_loss / (i + 1) if not np.isnan(epoch_task_loss / (i + 1)) else "NaN",
                'distill_loss': epoch_distill_loss / (i + 1) if not np.isnan(epoch_distill_loss / (i + 1)) else "NaN",
                'lr': scheduler.get_last_lr()[0]
            }

            # Add warning flag if any loss is NaN
            if np.isnan(epoch_loss / (i + 1)) or np.isnan(epoch_task_loss / (i + 1)) or np.isnan(epoch_distill_loss / (i + 1)):
                progress_info['warning'] = "Loss is NaN!"

            progress_bar.set_postfix(progress_info)

        # Save checkpoint at the end of each epoch
        save_checkpoint(student_model, optimizer, scheduler, 0, epoch+1, config, global_step)

        print(f"Epoch {epoch+1} complete. Average loss: {epoch_loss / len(train_loader)}")

    # Save final model to Google Drive
    student_model.save_pretrained(config.final_model_save_path)
    student_tokenizer.save_pretrained(config.final_model_save_path)
    print(f"Training complete. Final model saved to Drive at: {config.final_model_save_path}")

    # Clean up local storage
    print("Cleaning up local storage to free up space...")
    if os.path.exists(config.output_dir):
        import shutil
        try:
            shutil.rmtree(config.output_dir)
            print(f"Removed local storage at {config.output_dir}")
        except Exception as e:
            print(f"Error cleaning up local storage: {e}")

In [None]:
# Benchmarking Function
def benchmark_models():
    """
    Compare teacher and student models on key metrics using the test set.
    Measures: accuracy, perplexity, exact match, latency, and emission factor accuracy.
    """
    print("\n=== Starting Model Quality Benchmarking ===")

    # Load test data
    test_file_path = os.path.join(config.data_dir, config.test_file)
    print(f"Loading test data from {test_file_path}...")

    with open(test_file_path, 'r', encoding='utf-8') as f:
        test_data = json.load(f)

    # Load tokenizers
    teacher_tokenizer = AutoTokenizer.from_pretrained(config.teacher_model_path)
    student_tokenizer = AutoTokenizer.from_pretrained(config.final_model_save_path)

    # Load models
    print("Loading teacher model...")
    teacher_config = BitsAndBytesConfig(
        load_in_8bit=True,
        bnb_8bit_compute_dtype=torch.float16
    )

    teacher_model = AutoModelForCausalLM.from_pretrained(
        config.teacher_model_path,
        quantization_config=teacher_config,
        device_map="auto",
        torch_dtype=torch.float16
    )

    print("Loading student model...")
    student_model = AutoModelForCausalLM.from_pretrained(
        config.final_model_save_path,
        device_map="auto",
        torch_dtype=torch.float16
    )

    # Prepare results dictionary
    results = {
        "perplexity": {"teacher": 0, "student": 0},
        "exact_matches": {"teacher": 0, "student": 0},
        "emission_factor_accuracy": {"teacher": 0, "student": 0},
        "latency": {"teacher": [], "student": []},
        "model_size": {"teacher": 0, "student": 0}
    }

    # Function to extract emission factors from text
    def extract_emission_factors(text):
        """Extract emission factor values with their units"""
        patterns = [
            r"(\d+\.?\d*)\s*(kg\s*CO2e?(?:/\w+)?)",
            r"(\d+\.?\d*)\s*(tons?\s*CO2e?(?:/\w+)?)",
            r"(\d+\.?\d*)\s*(g\s*CO2e?(?:/\w+)?)",
            r"(\d+\.?\d*)\s*(tCO2e?(?:/\w+)?)"
        ]

        findings = []
        for pattern in patterns:
            matches = re.findall(pattern, text, re.IGNORECASE)
            for match in matches:
                findings.append({
                    "value": float(match[0]),
                    "unit": match[1].strip()
                })
        return findings

    # Calculate perplexity
    print("Evaluating perplexity...")

    def calculate_perplexity(model, tokenizer, texts, max_samples=100):
        # Limit to max_samples to save time
        sample_texts = texts[:max_samples] if len(texts) > max_samples else texts

        nlls = []
        model.eval()
        with torch.no_grad():
            for text in tqdm(sample_texts, desc="Perplexity"):
                encodings = tokenizer(text, return_tensors="pt").to(model.device)
                outputs = model(**encodings, labels=encodings.input_ids)
                nlls.append(outputs.loss.item())

        return torch.exp(torch.tensor(nlls).mean()).item()

    # Extract sample outputs for perplexity calculation
    sample_outputs = [item["output"] for item in test_data[:100]]

    # Calculate perplexity
    results["perplexity"]["teacher"] = calculate_perplexity(
        teacher_model, teacher_tokenizer, sample_outputs
    )
    results["perplexity"]["student"] = calculate_perplexity(
        student_model, student_tokenizer, sample_outputs
    )

    print(f"Teacher perplexity: {results['perplexity']['teacher']:.2f}")
    print(f"Student perplexity: {results['perplexity']['student']:.2f}")

    # Generate predictions and measure quality metrics
    print("Generating predictions and measuring metrics...")

    teacher_predictions = []
    student_predictions = []

    # Process test samples
    for i, item in enumerate(tqdm(test_data[:100], desc="Testing")):  # Limit to 100 samples
        # Format input
        if item["input"]:
            prompt = f"<s>[INST] {item['instruction']}\n{item['input']} [/INST]"
        else:
            prompt = f"<s>[INST] {item['instruction']} [/INST]"

        # Ground truth
        ground_truth = item["output"]

        # Teacher prediction with latency measurement
        teacher_input = teacher_tokenizer(prompt, return_tensors="pt").to(teacher_model.device)

        teacher_start_time = time.time()
        with torch.no_grad():
            teacher_output = teacher_model.generate(
                teacher_input.input_ids,
                max_new_tokens=256,
                do_sample=False
            )
        teacher_end_time = time.time()

        teacher_response = teacher_tokenizer.decode(
            teacher_output[0][teacher_input.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        teacher_predictions.append(teacher_response)
        results["latency"]["teacher"].append(teacher_end_time - teacher_start_time)

        # Student prediction with latency measurement
        student_input = student_tokenizer(prompt, return_tensors="pt").to(student_model.device)

        student_start_time = time.time()
        with torch.no_grad():
            student_output = student_model.generate(
                student_input.input_ids,
                max_new_tokens=256,
                do_sample=False
            )
        student_end_time = time.time()

        student_response = student_tokenizer.decode(
            student_output[0][student_input.input_ids.shape[1]:],
            skip_special_tokens=True
        )
        student_predictions.append(student_response)
        results["latency"]["student"].append(student_end_time - student_start_time)

        # Check for exact matches
        if teacher_response.strip() == ground_truth.strip():
            results["exact_matches"]["teacher"] += 1

        if student_response.strip() == ground_truth.strip():
            results["exact_matches"]["student"] += 1

        # Check emission factor accuracy
        teacher_factors = extract_emission_factors(teacher_response)
        student_factors = extract_emission_factors(student_response)
        ground_truth_factors = extract_emission_factors(ground_truth)

        if ground_truth_factors:
            # For teacher
            if teacher_factors:
                matches = 0
                for gt_factor in ground_truth_factors:
                    for t_factor in teacher_factors:
                        # Check if value is within 5% of ground truth
                        if abs(t_factor["value"] - gt_factor["value"]) / gt_factor["value"] < 0.05:
                            matches += 1
                            break

                teacher_accuracy = matches / len(ground_truth_factors)
                results["emission_factor_accuracy"]["teacher"] += teacher_accuracy

            # For student
            if student_factors:
                matches = 0
                for gt_factor in ground_truth_factors:
                    for s_factor in student_factors:
                        # Check if value is within 5% of ground truth
                        if abs(s_factor["value"] - gt_factor["value"]) / gt_factor["value"] < 0.05:
                            matches += 1
                            break

                student_accuracy = matches / len(ground_truth_factors)
                results["emission_factor_accuracy"]["student"] += student_accuracy

    # Calculate averages
    num_samples = min(len(test_data), 100)  # Number of samples we processed

    # Compute exact match percentage
    results["exact_matches"]["teacher"] = (results["exact_matches"]["teacher"] / num_samples) * 100
    results["exact_matches"]["student"] = (results["exact_matches"]["student"] / num_samples) * 100

    # Compute average emission factor accuracy
    results["emission_factor_accuracy"]["teacher"] = (results["emission_factor_accuracy"]["teacher"] / num_samples) * 100
    results["emission_factor_accuracy"]["student"] = (results["emission_factor_accuracy"]["student"] / num_samples) * 100

    # Calculate average latency
    avg_teacher_latency = sum(results["latency"]["teacher"]) / len(results["latency"]["teacher"])
    avg_student_latency = sum(results["latency"]["student"]) / len(results["latency"]["student"])

    # Calculate model sizes
    def get_model_size_mb(model):
        """Calculate model size in MB"""
        param_size = 0
        for param in model.parameters():
            param_size += param.nelement() * param.element_size()
        return param_size / (1024 * 1024)

    results["model_size"]["teacher"] = get_model_size_mb(teacher_model)
    results["model_size"]["student"] = get_model_size_mb(student_model)

    # Print summary
    print("\n=== Benchmark Results ===")
    print(f"Model Size: Teacher = {results['model_size']['teacher']:.1f} MB, Student = {results['model_size']['student']:.1f} MB (Reduction: {(1 - results['model_size']['student']/results['model_size']['teacher'])*100:.1f}%)")
    print(f"Perplexity: Teacher = {results['perplexity']['teacher']:.2f}, Student = {results['perplexity']['student']:.2f}")
    print(f"Exact Match: Teacher = {results['exact_matches']['teacher']:.2f}%, Student = {results['exact_matches']['student']:.2f}%")
    print(f"Emission Factor Accuracy: Teacher = {results['emission_factor_accuracy']['teacher']:.2f}%, Student = {results['emission_factor_accuracy']['student']:.2f}%")
    print(f"Average Latency: Teacher = {avg_teacher_latency:.3f}s, Student = {avg_student_latency:.3f}s (Speedup: {(1 - avg_student_latency/avg_teacher_latency)*100:.1f}%)")

    # Create visualization
    metrics = [
        "Model Size (MB)",
        "Perplexity",
        "Exact Match (%)",
        "Emission Factor Accuracy (%)",
        "Latency (s)"
    ]

    teacher_values = [
        results["model_size"]["teacher"] / 1000,  # Convert to GB for better visualization
        results["perplexity"]["teacher"],
        results["exact_matches"]["teacher"],
        results["emission_factor_accuracy"]["teacher"],
        avg_teacher_latency
    ]

    student_values = [
        results["model_size"]["student"] / 1000,  # Convert to GB for better visualization
        results["perplexity"]["student"],
        results["exact_matches"]["student"],
        results["emission_factor_accuracy"]["student"],
        avg_student_latency
    ]

    # Create a bar chart
    fig, ax = plt.subplots(figsize=(12, 7))

    x = np.arange(len(metrics))
    width = 0.35

    rects1 = ax.bar(x - width/2, teacher_values, width, label='Teacher (Mistral-7B)')
    rects2 = ax.bar(x + width/2, student_values, width, label='Student (TinyLlama-1.1B)')

    ax.set_title('Model Quality Comparison')
    ax.set_xticks(x)
    ax.set_xticklabels(metrics)
    ax.legend()

    # Save the visualization
    benchmark_dir = os.path.join(config.final_model_save_path, "benchmarks")
    os.makedirs(benchmark_dir, exist_ok=True)
    plt.savefig(os.path.join(benchmark_dir, "model_quality_comparison.png"))
    plt.close()

    # Save detailed results to JSON
    with open(os.path.join(benchmark_dir, "benchmark_results.json"), "w") as f:
        # Convert to serializable format
        serializable_results = {
            "perplexity": {
                "teacher": float(results["perplexity"]["teacher"]),
                "student": float(results["perplexity"]["student"]),
                "relative_change": float((results["perplexity"]["student"] - results["perplexity"]["teacher"]) / results["perplexity"]["teacher"] * 100)
            },
            "exact_matches": {
                "teacher": float(results["exact_matches"]["teacher"]),
                "student": float(results["exact_matches"]["student"]),
                "relative_change": float((results["exact_matches"]["student"] - results["exact_matches"]["teacher"]) / results["exact_matches"]["teacher"] * 100) if results["exact_matches"]["teacher"] > 0 else 0
            },
            "emission_factor_accuracy": {
                "teacher": float(results["emission_factor_accuracy"]["teacher"]),
                "student": float(results["emission_factor_accuracy"]["student"]),
                "relative_change": float((results["emission_factor_accuracy"]["student"] - results["emission_factor_accuracy"]["teacher"]) / results["emission_factor_accuracy"]["teacher"] * 100) if results["emission_factor_accuracy"]["teacher"] > 0 else 0
            },
            "latency": {
                "teacher": float(avg_teacher_latency),
                "student": float(avg_student_latency),
                "speedup_percent": float((1 - avg_student_latency / avg_teacher_latency) * 100)
            },
            "model_size": {
                "teacher_mb": float(results["model_size"]["teacher"]),
                "student_mb": float(results["model_size"]["student"]),
                "reduction_percent": float((1 - results["model_size"]["student"] / results["model_size"]["teacher"]) * 100)
            }
        }
        json.dump(serializable_results, f, indent=2)

    # Save a sample of predictions
    with open(os.path.join(benchmark_dir, "sample_predictions.json"), "w") as f:
        samples = []
        for i in range(min(10, len(teacher_predictions))):
            sample = {
                "instruction": test_data[i]["instruction"],
                "input": test_data[i]["input"] if "input" in test_data[i] else "",
                "ground_truth": test_data[i]["output"],
                "teacher_prediction": teacher_predictions[i],
                "student_prediction": student_predictions[i]
            }
            samples.append(sample)
        json.dump(samples, f, indent=2)

    print(f"Benchmark results saved to {benchmark_dir}")
    return results

In [None]:
def quantize_for_mobile():
    """Quantize the distilled model to INT8 for mobile deployment"""
    from transformers import BitsAndBytesConfig

    print("Quantizing model for mobile deployment...")

    # Load the model with quantization config
    quantization_config = BitsAndBytesConfig(
        load_in_8bit=True,
        llm_int8_threshold=6.0
    )

    # Load from final save path
    model = AutoModelForCausalLM.from_pretrained(
        config.final_model_save_path,
        quantization_config=quantization_config,
        device_map="auto"
    )

    # Save in GGUF format if possible, otherwise just save in 8-bit
    try:
        print("Attempting to save in GGUF format for maximum compression...")
        # If gguf library is available
        try:
            from gguf import GGUFWriter
            # GGUF implementation would go here
            # Since gguf direct integration might not be available, fallback to system commands
            !pip install -q llama-cpp-python
            output_path = os.path.join(config.final_model_save_path, "model_quantized")
            os.makedirs(output_path, exist_ok=True)
            !python -m llama_cpp.quantize {config.final_model_save_path} {output_path}/model_q4_0.gguf q4_0
            print(f"GGUF model saved to {output_path}/model_q4_0.gguf")
        except ImportError:
            # Fallback to system command if library not importable directly
            !pip install -q llama-cpp-python
            output_path = os.path.join(config.final_model_save_path, "model_quantized")
            os.makedirs(output_path, exist_ok=True)
            !python -m llama_cpp.quantize {config.final_model_save_path} {output_path}/model_q4_0.gguf q4_0
            print(f"GGUF model saved to {output_path}/model_q4_0.gguf")
    except Exception as e:
        print(f"GGUF conversion failed: {e}")
        print("Falling back to standard 8-bit quantized model...")
        # Directly save back to Drive folder
        model.save_pretrained(config.final_model_save_path)

    # Add a marker file to indicate this is quantized
    with open(os.path.join(config.final_model_save_path, "QUANTIZED.txt"), "w") as f:
        f.write("This model was quantized to INT8 for mobile deployment")

    print(f"Quantized model saved successfully")

In [None]:
# Main Execution
if __name__ == "__main__":
    try:
        start_time = time.time()

        # Step 1: Train the model
        print("=== Starting Distillation Training ===")
        train()

        # Step 2: Quantize the model
        print("\n=== Quantizing Model for Mobile Deployment ===")
        quantize_for_mobile()

        # Step 3: Benchmark the models
        print("\n=== Benchmarking Model Quality ===")
        benchmark_results = benchmark_models()

        print(f"\nTotal execution time: {(time.time() - start_time) / 60:.2f} minutes")
        print("Complete process finished successfully!")

    except KeyboardInterrupt:
        print("Process interrupted.")
        save_option = input("Do you want to save an emergency checkpoint to Drive? (yes/no): ")
        if save_option.lower() == "yes":
            print("Saving emergency checkpoint...")
            # Get current model, optimizer, scheduler and state
            if 'student_model' in locals() and 'optimizer' in locals():
                emergency_path = os.path.join(config.final_model_save_path, "emergency_backup")
                os.makedirs(emergency_path, exist_ok=True)

                try:
                    # Save the model
                    student_model.save_pretrained(emergency_path)

                    # Save optimizer state and other training info
                    torch.save({
                        'optimizer': optimizer.state_dict() if 'optimizer' in locals() else None,
                        'scheduler': scheduler.state_dict() if 'scheduler' in locals() else None,
                        'epoch': epoch if 'epoch' in locals() else 0,
                        'step': i if 'i' in locals() else 0,
                        'global_step': global_step if 'global_step' in locals() else 0
                    }, os.path.join(emergency_path, "emergency_training_state.pt"))

                    print(f"Emergency checkpoint saved to {emergency_path}")
                except Exception as e:
                    print(f"Error saving emergency checkpoint: {e}")
            else:
                print("No model available to save yet.")
    except Exception as e:
        print(f"Error during execution: {e}")
        import traceback
        traceback.print_exc()

=== Starting Distillation Training ===
Loading tokenizers...


tokenizer_config.json:   0%|          | 0.00/776 [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/414 [00:00<?, ?B/s]

Preparing datasets...
Loaded 3884 examples from /content/Carbon-EF/data/processed/instructions_train.json

Sample data item:
Instruction: What is the emission factor for Biomass Energy in Germany?...
Input: [Empty]
Output: The emission factor for Biomass Energy in Germany is 87.7 kg CO2e/kWh. This data is sourced from GRE...
Loaded 972 examples from /content/Carbon-EF/data/processed/instructions_val.json

Sample data item:
Instruction: I need to know the emission factor for Cotton Farming in EU?...
Input: [Empty]
Output: The emission factor for Cotton Farming in EU is 10.66 kg CO2e/kWh. This data is sourced from GREET....
Train dataset size: 3884
Validation dataset size: 972
Loading teacher model...
Loading teacher model with memory optimization...


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading student model...


config.json:   0%|          | 0.00/560 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/4.40G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/129 [00:00<?, ?B/s]

Starting training from epoch 0, step 0, global_step 0


Epoch 1:   0%|          | 0/3884 [00:00<?, ?it/s]

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Checkpoint saved at step 200 (storage-efficient mode)
Removed old checkpoint: /content/tinyllama_distilled/checkpoints/checkpoint-200
Checkpoint saved at step 243 (storage-efficient mode)
Epoch 1 complete. Average loss: 0.6500265512358393


Epoch 2:   0%|          | 0/3884 [00:00<?, ?it/s]

Removed old checkpoint: /content/tinyllama_distilled/checkpoints/checkpoint-243
Checkpoint saved at step 400 (storage-efficient mode)
Removed old checkpoint: /content/tinyllama_distilled/checkpoints/checkpoint-400
Checkpoint saved at step 486 (storage-efficient mode)
Epoch 2 complete. Average loss: 0.6491053038105047


Epoch 3:   0%|          | 0/3884 [00:00<?, ?it/s]

Removed old checkpoint: /content/tinyllama_distilled/checkpoints/checkpoint-486
Checkpoint saved at step 600 (storage-efficient mode)
Removed old checkpoint: /content/tinyllama_distilled/checkpoints/checkpoint-600
Checkpoint saved at step 729 (storage-efficient mode)
Epoch 3 complete. Average loss: 0.6491053038105047
Training complete. Final model saved to Drive at: /content/drive/MyDrive/tinyllama_distilled_final
Cleaning up local storage to free up space...
Removed local storage at /content/tinyllama_distilled

=== Quantizing Model for Mobile Deployment ===
Quantizing model for mobile deployment...
Attempting to save in GGUF format for maximum compression...
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m67.3/67.3 MB[0m [31m28.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Installing backend dependencies ... [?25l[?25hdone
  Preparing metadata (pyproject.toml)

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Loading student model...
Evaluating perplexity...


Perplexity:   0%|          | 0/100 [00:00<?, ?it/s]

Perplexity:   0%|          | 0/100 [00:00<?, ?it/s]

Teacher perplexity: 4.53
Student perplexity: nan
Generating predictions and measuring metrics...


Testing:   0%|          | 0/100 [00:00<?, ?it/s]

The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask is not set and cannot be inferred from input because pad token is same as eos token. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
The attention mask and the pad


=== Benchmark Results ===
Model Size: Teacher = 7156.5 MB, Student = 2098.2 MB (Reduction: 70.7%)
Perplexity: Teacher = 4.53, Student = nan
Exact Match: Teacher = 0.00%, Student = 0.00%
Emission Factor Accuracy: Teacher = 13.33%, Student = 0.00%
Average Latency: Teacher = 24.816s, Student = 7.547s (Speedup: 69.6%)
Benchmark results saved to /content/drive/MyDrive/tinyllama_distilled_final/benchmarks

Total execution time: 108.76 minutes
Complete process finished successfully!
