In [None]:
# CLEAN
# Remove the packages that are re-pinned below so the resolver starts from a
# known state. torch itself is not in this list; the next command pins it.
!pip uninstall -y torchvision torchaudio triton bitsandbytes transformers accelerate peft

# TORCH FAMILY (CUDA 12.1 band)
# torch / torchvision / torchaudio are pinned as a matched set from the cu121 wheel index.
!pip install --index-url https://download.pytorch.org/whl/cu121 \
  "torch==2.4.1" "torchvision==0.19.1" "torchaudio==2.4.1"

# NOTE(review): transformers and sentence-transformers are pinned again in the
# next cell, and sentence-transformers is not imported by any visible cell —
# confirm it is still needed.
!pip install "transformers==4.44.2" "sentence-transformers==5.1.1"


[0mLooking in indexes: https://download.pytorch.org/whl/cu121
Collecting torchvision==0.19.1
  Using cached https://download.pytorch.org/whl/cu121/torchvision-0.19.1%2Bcu121-cp312-cp312-linux_x86_64.whl (7.1 MB)
Collecting torchaudio==2.4.1
  Using cached https://download.pytorch.org/whl/cu121/torchaudio-2.4.1%2Bcu121-cp312-cp312-linux_x86_64.whl (3.4 MB)
Collecting triton==3.0.0 (from torch==2.4.1)
  Using cached https://download.pytorch.org/whl/triton-3.0.0-1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (209.5 MB)
Installing collected packages: triton, torchvision, torchaudio
[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
sentence-transformers 5.1.1 requires transformers<5.0.0,>=4.41.0, which is not installed.[0m[31m
[0mSuccessfully installed torchaudio-2.4.1+cu121 torchvision-0.19.1+cu121 triton-3.0.0
Collecting transformers==4.44.2
 

In [None]:
# Pin the quantization / fine-tuning stack (bitsandbytes, accelerate, peft,
# datasets) plus codecarbon for the energy measurements. triton, transformers
# and sentence-transformers repeat pins from the previous cell; pip leaves
# them alone when they are already satisfied.
!pip install \
  "triton==3.0.0" \
  "bitsandbytes==0.44.1" \
  "transformers==4.44.2" \
  "accelerate==0.34.2" \
  "peft==0.13.0" \
  "datasets==2.19.1" \
  sentencepiece safetensors huggingface_hub hf-transfer codecarbon \
  "sentence-transformers==5.1.1"


Collecting bitsandbytes==0.44.1
  Using cached bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl.metadata (3.5 kB)
Collecting accelerate==0.34.2
  Using cached accelerate-0.34.2-py3-none-any.whl.metadata (19 kB)
Collecting peft==0.13.0
  Using cached peft-0.13.0-py3-none-any.whl.metadata (13 kB)
Using cached bitsandbytes-0.44.1-py3-none-manylinux_2_24_x86_64.whl (122.4 MB)
Using cached accelerate-0.34.2-py3-none-any.whl (324 kB)
Using cached peft-0.13.0-py3-none-any.whl (322 kB)
Installing collected packages: bitsandbytes, accelerate, peft
Successfully installed accelerate-0.34.2 bitsandbytes-0.44.1 peft-0.13.0


In [None]:
import math
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import torch
from torch.nn.utils import prune
from torch.utils.data import Dataset
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    PreTrainedModel,
    PreTrainedTokenizer,
    Trainer,
    TrainingArguments,
)
import csv
import time
from codecarbon import OfflineEmissionsTracker
import json

# google.colab only exists inside a Colab runtime. Import it (including the
# `files` upload helper) inside the try so that outside Colab the notebook does
# not die on ImportError and IN_COLAB can actually become False.
try:
    from google.colab import drive, files  # noqa: F401
    IN_COLAB = True
except ImportError:
    # No Colab helpers available; code using `files` must check IN_COLAB first.
    drive = files = None
    IN_COLAB = False

In [None]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on {DEVICE} with {torch.cuda.get_device_name(0) if DEVICE.type == 'cuda' else 'CPU-only'}")

BASE_MODEL = "meta-llama/Llama-3.2-3B-Instruct"
QUAN_OUT_DIR = "llama3p2-3b-instruct-nf4"

# NOTE(review): STUDENT_MODEL is not referenced by any visible cell.
STUDENT_MODEL = "meta-llama/Llama-3.2-1B"

# Prompt variants, keyed by the names that appear in the results CSV's
# "Prompt" column (the corpus-building cell looks rows up with PROMPTS.get).
PROMPTS: Dict[str, str] = {}

# Reference function embedded in every prompt. Its docstring is delimited by
# escaped triple quotes (\"\"\" -> """) so the embedded snippet is valid
# Python; a doubled quote ("") would turn the description lines into bare,
# unparsable code.
function_code = """
from typing import List

def below_zero(operations: List[int]) -> bool:
    \"\"\"
    You're given a list of deposit and withdrawal operations
    on a bank account that starts with zero balance. Your task is to
    detect if at any point the balance of account falls below zero,
    and at that point function should return True.
    \"\"\"
    balance = 0
    for op in operations:
        balance += op
        if balance < 0:
            return True
    return False
"""

PROMPTS["complete_prompt_text"] = f"""
Please act as an expert Python software engineer. Given the python function below:
{function_code}
I would appreciate it if you could generate a complete and professional Google-style docstring.
The docstring should not include any extra commentary, strictly limited to include the docstring itself and the original function code.
CODE ONLY. Use standard Python indentation. Thank you.
Do not add explanations, notes, or text outside of the code.
Return only the function code with its docstring, without markdown fences or extra text before or after.
"""

PROMPTS["concise_prompt_text"] = f"""
Generate COMPLETE Google-style docstring for the following Python function:
{function_code}
Output the docstring with the function code. Do not include explanations, notes, or text outside the code.
Return only the function code with its docstring, without extra text.
"""

PROMPTS["ultra_concise_prompt_text"] = f"""Add Google-style docstring to function:
{function_code}
Output code only, no text."""

def ensure_tokenizer_padding(tokenizer: PreTrainedTokenizer, model: Optional[PreTrainedModel] = None):
    """Give ``tokenizer`` a pad token (reusing EOS) and pad on the right.

    Args:
        tokenizer: Tokenizer to update in place.
        model: Optional model to keep in sync. Its ``config.pad_token_id`` is
            set, and its input embeddings are grown only if the tokenizer now
            has more entries than the embedding table.

    NOTE(review): a later cell defines another ``ensure_tokenizer_padding``;
    whichever cell ran last is the one in effect.
    """
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": tokenizer.eos_token})
        if model is not None:
            # Reusing EOS adds no new id, so normally nothing needs resizing.
            # Resizing unconditionally to len(tokenizer) could *shrink* a model
            # whose embedding matrix is padded beyond the tokenizer vocabulary.
            if len(tokenizer) > model.get_input_embeddings().num_embeddings:
                model.resize_token_embeddings(len(tokenizer))
            model.config.pad_token_id = tokenizer.pad_token_id
    tokenizer.padding_side = "right"

def build_prompt(record: Dict[str, str]) -> Tuple[str, str]:
    """Turn one instruction record into a ``(prompt, target)`` pair.

    The prompt has a "### Instruction:" section and a trailing empty
    "### Response:" header. A "### Input:" section is inserted in between only
    when the record has non-blank input. Sections are separated by a blank
    line. The target is the first truthy value among "output", "response" and
    "answer", stripped (empty string if none).
    """

    def _field(*keys: str) -> str:
        # First truthy value among `keys`, whitespace-stripped; "" if none.
        for key in keys:
            value = record.get(key)
            if value:
                return value.strip()
        return ""

    instruction = _field("instruction")
    extra_input = _field("input")
    target = _field("output", "response", "answer")

    sections = ["### Instruction:\n" + instruction]
    if extra_input:
        sections.append("### Input:\n" + extra_input)
    sections.append("### Response:\n")
    return "\n\n".join(sections), target

class InstructionDataset(Dataset):
    """Tokenized supervised fine-tuning examples with prompt-masked labels.

    Each item is a dict of plain int lists:
      * ``input_ids``: [BOS] + prompt + response + EOS, truncated to ``max_length``;
      * ``attention_mask``: all ones, same length;
      * ``labels``: ``input_ids`` with the BOS/prompt positions set to ``-100``,
        so only the response tokens (and EOS) contribute to the loss.

    If the prompt alone reaches ``max_length``, every label is ``-100`` and the
    example carries no supervised tokens.
    """

    def __init__(self, rows: List[Dict[str, str]], tokenizer: PreTrainedTokenizer, max_length: int = 2048):
        self.rows = rows
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        prompt, response = build_prompt(self.rows[idx])
        tok = self.tokenizer
        prompt_ids = tok(prompt, add_special_tokens=False)["input_ids"]
        # Generation cells call tokenizer(prompt) with default special tokens,
        # which (for Llama tokenizers) prepends BOS. Do the same here so
        # training inputs start the way inference inputs do.
        if tok.bos_token_id is not None:
            prompt_ids = [tok.bos_token_id] + prompt_ids
        response_ids = tok(response + tok.eos_token, add_special_tokens=False)["input_ids"]

        full_ids = prompt_ids + response_ids
        full_labels = [-100] * len(prompt_ids) + response_ids
        # Both lists have the same length, so truncating them identically keeps
        # labels aligned with input_ids; no extra padding is ever needed.
        input_ids = full_ids[: self.max_length]
        labels = full_labels[: self.max_length]
        return {
            "input_ids": input_ids,
            "attention_mask": [1] * len(input_ids),
            "labels": labels,
        }

@dataclass
class CausalDataCollator:
    """Pad a list of ``InstructionDataset`` items into one training batch.

    ``input_ids`` and ``attention_mask`` are padded by ``tokenizer.pad``.
    ``labels`` are padded separately with ``-100`` on the tokenizer's padding
    side, so padding never contributes to the loss while every real label token
    is kept.

    Labels are deliberately NOT masked with ``labels == pad_token_id``: this
    notebook sets the pad token to EOS, so that mask would also erase the EOS
    appended to every response, and the model would never learn to stop.
    """

    tokenizer: PreTrainedTokenizer

    def __call__(self, features: List[Dict[str, List[int]]]) -> Dict[str, torch.Tensor]:
        label_seqs = [list(feature["labels"]) for feature in features]
        # Only the model-input keys go through tokenizer.pad; ragged label lists
        # (batch size > 1) are padded by hand below.
        model_inputs = [
            {key: value for key, value in feature.items() if key != "labels"}
            for feature in features
        ]
        batch = self.tokenizer.pad(model_inputs, padding=True, return_tensors="pt")

        seq_len = batch["input_ids"].shape[1]
        pad_left = self.tokenizer.padding_side == "left"
        padded_labels = []
        for seq in label_seqs:
            filler = [-100] * (seq_len - len(seq))
            padded_labels.append(filler + seq if pad_left else seq + filler)
        batch["labels"] = torch.tensor(padded_labels, dtype=torch.long)
        return batch

Running on cuda with NVIDIA A100-SXM4-80GB
Running on cuda with NVIDIA A100-SXM4-80GB


## 1. Quantize Llama 3.2 3B Instruct to 4-bit NF4


In [None]:

# Compute dtype for the 4-bit matmuls: BF16 where the GPU supports it (stable
# on A100s), FP16 otherwise.
_compute_dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    else torch.float16
)
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=False,
    bnb_4bit_compute_dtype=_compute_dtype,
)

# Load the tokenizer and the NF4-quantized model; device_map="auto" places layers.
tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    device_map="auto",
    quantization_config=bnb_config,
    trust_remote_code=False,  # Llama needs no custom modeling code
)

# If the tokenizer has no pad token, reuse EOS and record it on the model config.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    model.config.pad_token_id = tokenizer.eos_token_id

# Write tokenizer files and quantized weights into the same directory.
tokenizer.save_pretrained(QUAN_OUT_DIR)
model.save_pretrained(QUAN_OUT_DIR, safe_serialization=True)
print(f"Saved 4-bit NF4 weights to {QUAN_OUT_DIR}")


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Saved 4-bit NF4 weights to llama3p2-3b-instruct-nf4


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Saved 4-bit NF4 weights to llama3p2-3b-instruct-nf4


In [None]:
def generate_docstring_quantization(
    model_path: str,
    prompt_key: str = "complete_prompt_text",
    max_new_tokens: int = 256,
    seed: Optional[int] = None,
) -> str:
    """Sample one completion of ``PROMPTS[prompt_key]`` from the checkpoint at ``model_path``.

    The checkpoint is loaded with ``device_map="auto"`` and
    ``torch_dtype=torch.float16``. The prompt is fed as raw text (no chat
    template) and decoding uses sampling.

    Args:
        model_path: Local checkpoint directory or hub id.
        prompt_key: Key into ``PROMPTS``.
        max_new_tokens: Maximum number of tokens to generate.
        seed: If given, seeds torch's RNGs before sampling so the output is
            reproducible; ``None`` leaves sampling unseeded.

    Returns:
        The decoded completion with the prompt tokens removed. It is also printed.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_path,
        device_map="auto",
        torch_dtype=torch.float16,
        trust_remote_code=True,
    )
    if seed is not None:
        torch.manual_seed(seed)

    prompt = PROMPTS[prompt_key]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[1]
    # Pass an explicit pad id (EOS if the tokenizer has none) instead of None.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_id,
        )
    text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
    print(text)

    # Release the model before returning so repeated calls don't accumulate GPU memory.
    del model, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return text

docstring_nf4 = generate_docstring_quantization(QUAN_OUT_DIR)

Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


def below_zero(operations: List[int]) -> bool:
    """
    Detects if at any point the balance of account falls below zero.

    Args:
        operations (List[int]): A list of deposit and withdrawal operations
            on a bank account that starts with zero balance.

    Returns:
        bool: True if the balance falls below zero at any point, False otherwise.
    """
    balance = 0
    for op in operations:
        balance += op
        if balance < 0:
            return True
    return False
```


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


def below_zero(operations: List[int]) -> bool:
    """
    Detects if at any point the balance of account falls below zero.

    Args:
        operations (List[int]): A list of deposit and withdrawal operations
            on a bank account that starts with zero balance.

    Returns:
        bool: True if the balance falls below zero at any point, False otherwise.
    """
    balance = 0
    for op in operations:
        balance += op
        if balance < 0:
            return True
    return False
```


## 2. Upload CSV results and build a fine-tuning corpus


In [None]:


# Teacher outputs used as fine-tuning targets: only rows from this model/precision.
TEACHER_MODEL_NAME = "llama3.1:8b-instruct"
TEACHER_QUANTIZATION = "fp16"
CSV_FILENAME = "green_prompt_results_three_models.csv"

if IN_COLAB:
    print("Upload `green_prompt_results_three_models.csv`")
    uploaded = files.upload()
    if not uploaded:
        raise FileNotFoundError(f"No file was uploaded; expected {CSV_FILENAME}")
    # Colab may rename duplicates (e.g. "... (1).csv"), so use the saved name.
    csv_path = next(iter(uploaded.keys()))
else:
    # No upload widget outside Colab: read the CSV from the working directory.
    csv_path = CSV_FILENAME

dataset_rows: List[Dict[str, str]] = []
# newline="" is required by the csv module so line breaks inside quoted fields
# (Output values are presumably multi-line code) are parsed correctly.
with open(csv_path, "r", encoding="utf-8", newline="") as f:
    for row in csv.DictReader(f):
        if row["Model"] != TEACHER_MODEL_NAME or row["Quantization"] != TEACHER_QUANTIZATION:
            continue
        prompt_name = row["Prompt"]
        prompt_text = PROMPTS.get(prompt_name)
        # DictReader fills missing trailing fields with None; treat as empty.
        output_text = (row["Output"] or "").strip()
        if not (prompt_text and output_text):
            continue
        dataset_rows.append(
            {
                "instruction": prompt_text,
                "input": "",
                "output": output_text,
                "meta": {
                    "epoch": row["Epoch"],
                    "prompt": prompt_name,
                    "accuracy": row["Accuracy"],
                },
            }
        )

print(f"Usable rows: {len(dataset_rows)}")

dataset_path = Path("llama3_fp16_docstrings.jsonl")
dataset_path.parent.mkdir(parents=True, exist_ok=True)
with dataset_path.open("w", encoding="utf-8") as handle:
    for row in dataset_rows:
        handle.write(json.dumps(row, ensure_ascii=False) + "\n")
print("Saved dataset to", dataset_path)

Upload `green_prompt_results_three_models.csv`


Saving green_prompt_results_three_models.csv to green_prompt_results_three_models (1).csv
Usable rows: 15
Saved dataset to llama3_fp16_docstrings.jsonl
Upload `green_prompt_results_three_models.csv`


KeyboardInterrupt: 

## 3. Prune to 40% sparsity and run a short recovery fine-tune


In [None]:


def magnitude_prune_inplace_cpu(model: torch.nn.Module, sparsity: float = 0.4):
    """Zero the smallest-magnitude weights of every ``torch.nn.Linear``, on CPU.

    For each Linear layer independently, the ``int(sparsity * N)`` weights with
    the smallest absolute value are set to 0 in place (a hard mask; no pruning
    re-parametrisation is registered). Biases are untouched. ``sparsity <= 0``
    leaves layers unchanged; ``sparsity >= 1`` zeroes them entirely.

    The threshold comes from ``torch.kthvalue`` rather than ``torch.quantile``,
    which rejects very large inputs (a vocab-sized ``lm_head`` exceeds its limit).

    Args:
        model: Module to prune; it is moved to CPU first to avoid GPU OOM.
        sparsity: Fraction of each Linear weight matrix to zero.

    Returns:
        The same module object, now on CPU.

    NOTE(review): a later cell defines a function with the same name; when the
    notebook runs top to bottom, that later definition is the one in effect.
    """
    model_cpu = model.to("cpu")
    with torch.no_grad():
        for m in model_cpu.modules():
            if not isinstance(m, torch.nn.Linear) or m.weight is None:
                continue
            w = m.weight.data
            magnitudes = w.abs().float().reshape(-1)
            numel = magnitudes.numel()
            k = int(sparsity * numel)
            if k <= 0:
                continue
            if k >= numel:
                w.zero_()
                continue
            # Everything at or below the k-th smallest magnitude is removed.
            thr = torch.kthvalue(magnitudes, k).values
            w.mul_((magnitudes > thr).view_as(w))  # hard mask in place
    return model_cpu


In [None]:
import os, gc, math, torch
from pathlib import Path
from typing import List, Dict
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments, Trainer

# --- A100-friendly runtime knobs ---
# NOTE(review): PYTORCH_CUDA_ALLOC_CONF is read when PyTorch's CUDA caching
# allocator initialises. Earlier cells already placed models on the GPU, so
# setting it here probably has no effect in this kernel session — confirm, or
# move it before the first CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"  # reduce fragmentation
torch.backends.cuda.matmul.allow_tf32 = True                        # allow TF32 on A100
# "high" lets float32 matmuls use reduced-precision (e.g. TF32) internal math.
torch.set_float32_matmul_precision("high")

def ensure_tokenizer_padding(tokenizer, model=None):
    """Reuse EOS as the pad token when the tokenizer has none (safe for causal LMs).

    This redefinition shadows the earlier ``ensure_tokenizer_padding``.
    ``model`` is optional, as in that earlier version, so both call styles
    (``f(tok)`` and ``f(tok, model)``) keep working.

    Args:
        tokenizer: Tokenizer to update in place.
        model: Optional model; when given, ``model.config.pad_token_id`` is set
            to the EOS id as well.
    """
    if tokenizer.pad_token is not None:
        return
    tokenizer.pad_token = tokenizer.eos_token
    if model is not None:
        model.config.pad_token_id = tokenizer.eos_token_id

def split_records(records: List[Dict[str, str]], eval_ratio: float = 0.1):
    """Split ``records`` into ``(train, eval)`` without shuffling.

    The first ``ceil(len(records) * eval_ratio)`` records (at least one) become
    the eval split and the rest the train split. At least one record always
    stays in train, so a large ``eval_ratio`` cannot leave the Trainer with an
    empty training set.

    Returns ``(records, [])`` when ``eval_ratio <= 0`` or fewer than 2 records.
    """
    if eval_ratio <= 0 or len(records) < 2:
        return records, []
    eval_size = max(1, math.ceil(len(records) * eval_ratio))
    eval_size = min(eval_size, len(records) - 1)
    return records[eval_size:], records[:eval_size]

from tqdm.auto import tqdm
import torch

def magnitude_prune_inplace_cpu(model: torch.nn.Module, sparsity: float = 0.4, show_stats: bool = True):
    """Hard-mask small-magnitude weights of every ``torch.nn.Linear`` on CPU, with a tqdm bar.

    For a layer with N weights, the threshold is the k-th smallest ``|w|`` with
    ``k = int(sparsity * N)``. ``torch.kthvalue`` is used instead of
    ``torch.quantile`` to avoid its size limits. Weights with ``|w| >= threshold``
    are kept and the rest are zeroed in place. If ``kthvalue`` raises
    RuntimeError, the threshold is estimated from a strided sample
    (stride ``N // 5_000_000``, at least 1).

    Args:
        model: Module to prune. It is moved to CPU, unless it has no Linear
            layers, in which case it is returned untouched.
        sparsity: Fraction per layer to remove; ``<= 0`` skips every layer and
            ``k >= N`` zeroes the whole layer.
        show_stats: Print the overall zeroed fraction at the end.

    Returns:
        The (CPU) model.
    """
    layers = [
        mod for mod in model.modules()
        if isinstance(mod, torch.nn.Linear) and getattr(mod, "weight", None) is not None
    ]
    n_layers = len(layers)
    if not layers:
        print("No torch.nn.Linear layers found to prune.")
        return model

    model = model.to("cpu")

    def _prune_one(weight):
        # Prune `weight` in place; return (zeroed_count, postfix_note), or None if skipped.
        numel = weight.numel()
        if numel == 0 or sparsity <= 0.0:
            return None
        k = int(sparsity * numel)
        if k <= 0:
            return None
        if k >= numel:
            weight.zero_()
            return numel, "zeroed 100%"
        magnitudes = weight.abs().float().view(-1)
        try:
            threshold = torch.kthvalue(magnitudes, k).values
        except RuntimeError:
            # Fallback: estimate the threshold from a strided sample.
            stride = max(1, numel // 5_000_000)
            sample = magnitudes[::stride]
            threshold = torch.kthvalue(sample, max(1, int(len(sample) * sparsity))).values
        keep = weight.abs() >= threshold
        kept = int(keep.sum().item())
        weight.mul_(keep)  # hard-mask in place
        return numel - kept, f"kept {kept / numel * 100.0:.1f}%"

    n_zeroed = 0
    n_seen = 0
    bar_desc = f"CPU pruning @ {int(sparsity*100)}% sparsity"
    with torch.no_grad(), tqdm(total=n_layers, desc=bar_desc, leave=True) as bar:
        for layer_no, layer in enumerate(layers, start=1):
            weight = layer.weight.data
            n_seen += weight.numel()
            outcome = _prune_one(weight)
            if outcome is not None:
                zeroed_here, note = outcome
                n_zeroed += zeroed_here
                bar.set_postfix_str(f"layer {layer_no}/{n_layers} | {note}")
            bar.update(1)

    if show_stats and n_seen > 0:
        pct_zero = 100.0 * n_zeroed / n_seen
        print(f"[Prune] Overall zeroed: {pct_zero:.2f}% ({n_zeroed:,}/{n_seen:,})")

    return model



def prune_and_ft(
    sparsity: float = 0.4,          # 0.3–0.5 recommended
    max_length: int = 2048,
    epochs: float = 0.5,
    lr: float = 2e-5,
    grad_accum: int = 4,
    per_device_bs: int = 1,
    eval_ratio: float = 0.2,
    logging_steps: int = 5,
    output_dir: str = "llama3p2-3b-pruned-ft",
    keep_sparsity: bool = True,
):
    """Magnitude-prune ``BASE_MODEL`` and run a short recovery fine-tune.

    Steps:
      1. load the model in FP32 on CPU;
      2. zero ``sparsity`` of every Linear layer's weights with
         ``magnitude_prune_inplace_cpu``;
      3. move it to the GPU (BF16 when supported, else FP32);
      4. fine-tune on the global ``dataset_rows`` with ``Trainer``;
      5. save tokenizer and weights to ``output_dir``.

    Args:
        sparsity: Fraction of each Linear weight matrix to zero.
        max_length: Token limit per training example.
        epochs: Training epochs (fractional values allowed).
        lr: Learning rate.
        grad_accum: Gradient-accumulation steps per optimizer step.
        per_device_bs: Per-device micro-batch size.
        eval_ratio: Fraction of ``dataset_rows`` held out for evaluation.
        logging_steps: Logging (and eval, if any) interval in optimizer steps.
        output_dir: Where the pruned + fine-tuned checkpoint is written.
        keep_sparsity: Re-zero the pruned weights after every optimizer step
            and once more before saving. Without this, full fine-tuning updates
            the zeroed weights and the saved model is no longer sparse. The
            masks cost one byte per Linear weight on the training device.

    NOTE(review): with the 15-row corpus (12 train rows at eval_ratio=0.2),
    batch 1 x accumulation 4 and 0.5 epochs there are only ~2 optimizer steps,
    so logging_steps / eval_steps = 5 never fire.
    """
    from transformers import TrainerCallback  # only this function needs it

    # ---- memory clean ---
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # ---- load on CPU (fp32) ---
    tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=False)
    base_model = AutoModelForCausalLM.from_pretrained(
        BASE_MODEL,
        device_map=None,             # stay on CPU for pruning
        torch_dtype=torch.float32,   # prune in fp32 on CPU
        low_cpu_mem_usage=True,
        trust_remote_code=True,
    )
    ensure_tokenizer_padding(tokenizer, base_model)

    # ---- prune on CPU ---
    print(f"Pruning Linear layers on CPU at {int(sparsity*100)}% sparsity …")
    base_model = magnitude_prune_inplace_cpu(base_model, sparsity=sparsity)

    # ---- move to GPU in BF16 when supported ---
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()
    if torch.cuda.is_available():
        dtype = torch.bfloat16 if use_bf16 else torch.float32
        base_model = base_model.to(device=torch.device("cuda"), dtype=dtype)

    # ---- remember the pruning pattern ---
    # Captured after the move so each mask lives on the same device as its weight.
    prune_masks = []
    if keep_sparsity:
        with torch.no_grad():
            prune_masks = [
                (module.weight, module.weight == 0)
                for module in base_model.modules()
                if isinstance(module, torch.nn.Linear)
            ]

    def _reapply_prune_masks():
        # Zero every weight that pruning zeroed (no-op when masks are empty).
        with torch.no_grad():
            for weight, zero_mask in prune_masks:
                weight.masked_fill_(zero_mask, 0)

    class _KeepPrunedWeightsZero(TrainerCallback):
        """Trainer hook: restore the pruning pattern after each optimizer step."""

        def on_step_end(self, args, state, control, **kwargs):
            _reapply_prune_masks()

    # ---- data ---
    train_rows, eval_rows = split_records(dataset_rows, eval_ratio)
    train_dataset = InstructionDataset(train_rows, tokenizer, max_length)
    eval_dataset  = InstructionDataset(eval_rows,  tokenizer, max_length) if eval_rows else None
    collator = CausalDataCollator(tokenizer)

    # ---- training args ---
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=per_device_bs,
        gradient_accumulation_steps=grad_accum,
        learning_rate=lr,
        weight_decay=0.0,
        warmup_ratio=0.03,
        logging_strategy="steps",
        logging_steps=logging_steps,
        save_strategy="no",
        report_to="none",
        evaluation_strategy="steps" if eval_dataset is not None else "no",
        eval_steps=logging_steps if eval_dataset is not None else None,
        fp16=False,                 # avoid GradScaler on FP16 grads
        bf16=use_bf16,              # True on A100/Ampere+
        dataloader_pin_memory=True,
    )

    # keep memory lower
    base_model.gradient_checkpointing_enable()
    base_model.config.use_cache = False

    trainer = Trainer(
        model=base_model,
        tokenizer=tokenizer,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=collator,
        callbacks=[_KeepPrunedWeightsZero()] if prune_masks else None,
    )

    trainer.train()

    # Belt and braces: the saved weights must carry the pruning pattern.
    _reapply_prune_masks()
    prune_masks.clear()  # free the mask memory before saving

    # ---- save ---
    Path(output_dir).mkdir(parents=True, exist_ok=True)
    tokenizer.save_pretrained(output_dir)
    base_model = base_model.to("cpu")
    base_model.save_pretrained(output_dir, safe_serialization=True)
    print(f"Saved pruned + fine-tuned weights to {output_dir}")


prune_and_ft(sparsity=0.4)


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Pruning Linear layers on CPU at 40% sparsity …


CPU pruning @ 40% sparsity:   0%|          | 0/197 [00:00<?, ?it/s]

[Prune] Overall zeroed: 39.89% (1,281,644,973/3,212,574,720)


You're using a PreTrainedTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]


Step,Training Loss,Validation Loss


Saved pruned + fine-tuned weights to llama3p1-8b-pruned-ft


## 4. Docstring generator for the pruned model


In [None]:
def generate_docstring_prune(
    model_dir: str,
    prompt_key: str = "complete_prompt_text",
    max_new_tokens: int = 256,
    seed: Optional[int] = None,
) -> str:
    """Sample one completion of ``PROMPTS[prompt_key]`` from the pruned checkpoint in ``model_dir``.

    The model is loaded in FP16 with ``device_map="auto"``; the pad token falls
    back to EOS via ``ensure_tokenizer_padding``. The prompt is fed as raw text.

    NOTE(review): the recovery fine-tune formats examples with build_prompt's
    "### Instruction / ### Response" template, but this function feeds the raw
    prompt without that template — confirm the mismatch is intended.

    Args:
        model_dir: Directory written by ``prune_and_ft``.
        prompt_key: Key into ``PROMPTS``.
        max_new_tokens: Maximum number of tokens to generate.
        seed: If given, seeds torch's RNGs before sampling for reproducible
            output; ``None`` leaves sampling unseeded.

    Returns:
        The decoded completion with the prompt tokens removed. It is also printed.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_dir, use_fast=False)
    model = AutoModelForCausalLM.from_pretrained(
        model_dir,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )
    ensure_tokenizer_padding(tokenizer, model)
    if seed is not None:
        torch.manual_seed(seed)

    prompt = PROMPTS[prompt_key]
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    input_len = inputs["input_ids"].shape[1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            do_sample=True,
            max_new_tokens=max_new_tokens,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
        )
    text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True).strip()
    print(text)

    # Release the model before returning so repeated calls don't accumulate GPU memory.
    del model, outputs
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return text

docstring_pruned = generate_docstring_prune("llama3p2-3b-pruned-ft")

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]



```python
from typing import List

def below_zero(operations: List[int]) -> bool:
    """
    Detects if at any point the balance of a bank account falls below zero.

    Args:
        operations (List[int]): A list of integers representing deposit (positive) or withdrawal (negative) amounts.

    Returns:
        bool: True if the balance falls below zero at any point, False otherwise.
    """
    balance = 0
    for op in operations:
        balance += op
        if balance < 0:
            return True
    return False
```python
from typing import List

def below_zero(operations: List[int]) -> bool:
    """
    Detects if at any point the balance of a bank account falls below zero.

    Args:
        operations (List[int]): A list of integers representing deposit (positive) or withdrawal (negative) amounts.

    Returns:
        bool: True if the balance falls below zero at any point, False otherwise.
    """
    balance = 0
    for op in operations:
        balance += op
        if 

In [None]:
def get_model_size_gb(model):
    """Return the size of ``model.state_dict()`` when serialized with ``torch.save``.

    NOTE: despite the name, the value is in **megabytes** (bytes / 1024**2).
    The measurement loop stores it as ``model_size_mb``. The name is kept so
    existing calls keep working.

    The state dict is written to a private temporary directory, which is
    removed even if serialization fails. This avoids leaving (or clobbering) a
    fixed ``temp_model.pt`` in the working directory.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = os.path.join(scratch_dir, "model_state.pt")
        torch.save(model.state_dict(), scratch_path)
        return os.path.getsize(scratch_path) / (1024 * 1024)

In [None]:
RUNS_PER_MODEL = 5
PROMPT_KEY = "complete_prompt_text"
MAX_NEW_TOKENS = 256
COUNTRY_CODE = "USA"
METRICS_PATH = Path("docstring_metrics_runs.csv")

# Absolute paths to the checkpoints written earlier. The NF4 path is derived
# from QUAN_OUT_DIR so it cannot drift from where the quantization cell saved.
NF4_MODEL_DIR   = str(Path(QUAN_OUT_DIR).resolve())
# Matches prune_and_ft's default output_dir.
PRUNE_FT_DIR    = str(Path("./llama3p2-3b-pruned-ft").resolve())

# Restrict PyTorch SDPA to the math backend.
# NOTE(review): load_pair requests attn_implementation="eager", which does not
# route through SDPA, so these switches may not affect the measured models.
try:
    torch.backends.cuda.enable_flash_sdp(False)
    torch.backends.cuda.enable_mem_efficient_sdp(False)
    torch.backends.cuda.enable_math_sdp(True)
except Exception as exc:  # best-effort: some torch builds may lack these switches
    print(f"Could not configure SDPA backends ({type(exc).__name__}: {exc}); using defaults.")

def _sdp_math_ctx():
    try:
        return torch.backends.cuda.sdp_kernel(
            enable_flash=False, enable_math=True, enable_mem_efficient=False
        )
    except AttributeError:
        # older torch: no-op
        from contextlib import contextmanager
        @contextmanager
        def _noop():
            yield
        return _noop()
def load_pair(model_dir_or_id: str):
    """Load ``(tokenizer, model)`` for benchmarking, with EOS/PAD ids pinned for generation.

    The model is loaded in FP16 with eager attention and ``device_map="auto"``.
    If the tokenizer has no pad token, EOS is reused. Both EOS and PAD ids are
    asserted to be valid rows of the input embedding table and are copied into
    ``model.generation_config``.
    """
    tokenizer = AutoTokenizer.from_pretrained(
        model_dir_or_id, use_fast=False, trust_remote_code=True
    )
    model = AutoModelForCausalLM.from_pretrained(
        model_dir_or_id,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        attn_implementation="eager",  # same attention path for every variant
    )
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    vocab_rows = model.get_input_embeddings().num_embeddings
    special_ids = (("eos", tokenizer.eos_token_id), ("pad", tokenizer.pad_token_id))
    for which, token_id in special_ids:
        assert token_id is not None and 0 <= token_id < vocab_rows, f"{which}_token_id out of range"

    gen_cfg = model.generation_config
    gen_cfg.eos_token_id = tokenizer.eos_token_id
    gen_cfg.pad_token_id = tokenizer.pad_token_id
    return tokenizer, model

# Model variants to benchmark. Each "load_fn" is a zero-argument callable
# returning (tokenizer, model); names are resolved only when it is called.
VARIANTS = [
    dict(
        key="baseline_fp16",
        label="Llama 3.2 3B FP16 (baseline)",
        load_fn=lambda: load_pair(BASE_MODEL),
    ),
    dict(
        key="quantized_nf4",
        label="Llama 3.2 3B NF4 Quantized (local)",
        load_fn=lambda: load_pair(NF4_MODEL_DIR),
    ),
    dict(
        key="pruned_ft",
        label="Llama 3.2 3B Pruned + Fine-Tuned (local)",
        load_fn=lambda: load_pair(PRUNE_FT_DIR),
    ),
]

records = []

for variant in VARIANTS:
    print(f"Measuring {variant['label']}...")
    tokenizer, model = variant["load_fn"]()
    ensure_tokenizer_padding(tokenizer, model)
    model.eval()
    # get_model_size_gb returns megabytes (serialized state_dict size).
    model_size_mb = get_model_size_gb(model)

    prompt_text = PROMPTS[PROMPT_KEY]
    inputs = tokenizer(prompt_text, return_tensors="pt")
    if hasattr(model, "device"):
        inputs = inputs.to(model.device)
    input_length = inputs["input_ids"].shape[1]
    generate_kwargs = dict(
        do_sample=True,
        max_new_tokens=MAX_NEW_TOKENS,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
    )

    # Untimed warm-up so one-off initialisation cost is not billed to run 1
    # (in the saved output, the baseline's run 1 is noticeably slower than runs
    # 2, 3 and 5). Each timed run reseeds, so the warm-up's RNG use is irrelevant.
    with torch.inference_mode():
        model.generate(**inputs, **{**generate_kwargs, "max_new_tokens": 8})

    outputs = None
    for run_idx in range(1, RUNS_PER_MODEL + 1):
        # Fixed seed per run index so runs are reproducible.
        torch.manual_seed(42 + run_idx)
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(42 + run_idx)
            torch.cuda.synchronize()  # don't time leftover queued GPU work

        tracker = OfflineEmissionsTracker(
            country_iso_code=COUNTRY_CODE,
            measure_power_secs=1,
            tracking_mode="process",
            save_to_file=False,
            log_level="error",
        )
        tracker.start()
        start_time = time.perf_counter()
        try:
            with torch.inference_mode():
                outputs = model.generate(**inputs, **generate_kwargs)
            if torch.cuda.is_available():
                torch.cuda.synchronize()  # GPU work may still be in flight
            latency_s = time.perf_counter() - start_time
        finally:
            # Stop the tracker even if generate() fails, so it is not left running.
            co2_kg = tracker.stop() or 0.0

        generated_ids = outputs[0][input_length:]
        output_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        records.append(
            {
                "model_key": variant["key"],
                "model_label": variant["label"],
                "run_index": run_idx,
                "co2_kg": float(co2_kg),
                "latency_s": float(latency_s),
                "model_size_mb": float(model_size_mb),
                "output": output_text,
                "accuracy": "N/A",
            }
        )
        print(f"  Run {run_idx}: latency={latency_s:.2f}s | CO2={co2_kg:.6f} kg")

    # Drop every reference to GPU tensors, collect cycles, then release cached blocks.
    del model, tokenizer, inputs, outputs
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# Column order of the per-run metrics CSV (one row per generation run).
fieldnames = [
    "model_key", "model_label", "run_index",
    "co2_kg", "latency_s", "model_size_mb",
    "output", "accuracy",
]

# newline="" lets the csv module control line endings (outputs contain newlines).
with METRICS_PATH.open("w", newline="", encoding="utf-8") as handle:
    writer = csv.DictWriter(handle, fieldnames=fieldnames)
    writer.writeheader()
    writer.writerows(records)

print(f"Saved metrics to {METRICS_PATH}")

Measuring Llama 3.2 3B FP16 (baseline)...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  Run 1: latency=5.08s | CO2=0.000071 kg
  Run 2: latency=3.78s | CO2=0.000053 kg
  Run 3: latency=3.65s | CO2=0.000052 kg
  Run 4: latency=8.28s | CO2=0.000117 kg
  Run 5: latency=3.70s | CO2=0.000052 kg
Measuring Llama 3.2 3B NF4 Quantized (local)...


Unused kwargs: ['_load_in_4bit', '_load_in_8bit', 'quant_method']. These kwargs are not used in <class 'transformers.utils.quantization_config.BitsAndBytesConfig'>.


  Run 1: latency=6.86s | CO2=0.000086 kg
  Run 2: latency=6.81s | CO2=0.000086 kg
  Run 3: latency=5.87s | CO2=0.000075 kg
  Run 4: latency=6.81s | CO2=0.000087 kg
  Run 5: latency=14.66s | CO2=0.000185 kg
Measuring Llama 3.2 3B Pruned + Fine-Tuned (local)...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

  Run 1: latency=8.69s | CO2=0.000121 kg
  Run 2: latency=8.33s | CO2=0.000117 kg
  Run 3: latency=8.30s | CO2=0.000117 kg
  Run 4: latency=8.25s | CO2=0.000116 kg
  Run 5: latency=8.36s | CO2=0.000117 kg
Saved metrics to docstring_metrics_runs.csv
