In [None]:
!pip install -U transformers==4.39.3 peft==0.9.0 datasets nltk rouge_score bitsandbytes accelerate
!pip install --upgrade transformers
!pip install bert-score sentence-transformers

#!git clone https://github.com/facebookresearch/SpinQuant.git
#!pip install git+https://github.com/Dao-AILab/fast-hadamard-transform.git
import sys
# sys.path.append("/content/SpinQuant")
import math
import itertools
#from fast_hadamard_transform import hadamard_transform
#from utils.hadamard_utils import random_hadamard_matrix
#from train_utils.main import prepare_model
import os, json, shutil, time, torch, gc, requests
import pandas as pd
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from datasets import Dataset, concatenate_datasets, load_dataset
from torch.utils.data import TensorDataset, DataLoader, random_split
from transformers import (AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig, get_linear_schedule_with_warmup, get_cosine_schedule_with_warmup)
from torch.optim import AdamW
from peft import LoraConfig, get_peft_model, TaskType, prepare_model_for_kbit_training
from torch import amp as torch_amp
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from huggingface_hub import snapshot_download
from torch.nn.utils import prune
from copy import deepcopy
from rouge_score import rouge_scorer
from sklearn.metrics import f1_score
from tqdm import tqdm
from google.colab import runtime
from bert_score import score as bertscore_score
from sentence_transformers import SentenceTransformer, util
#from train_utils.main import spinquantize
%matplotlib inline

In [None]:
# Model + Tokenizer setup
# Read the Hugging Face access token from the environment rather than hardcoding it in the
# notebook. An unset or empty HF_TOKEN becomes None, so the Hugging Face libraries are not
# handed an empty string as a credential.
hf_token = os.environ.get("HF_TOKEN") or None
# Llama
model_id = "meta-llama/Llama-3.2-1B-Instruct"
model_dir = "/content/Llama-3.2-1B-Instruct"
# model_id = "meta-llama/Llama-3.2-3B-Instruct"
# model_dir = "/content/Llama-3.2-3B-Instruct"
# Gemma
# model_id = "google/gemma-3-1b-it"
# model_dir = "/content/gemma-3-1b-it"
# Mistral
# model_id = "mistralai/Mistral-7B-Instruct-v0.3"
# model_dir = "/content/Mistral-7B-Instruct"

# Download the model repository into model_dir; everything below loads from that local copy.
snapshot_download(repo_id=model_id, token=hf_token, local_dir=model_dir)

# # Patch config for Llama
# config_path = os.path.join(model_dir, "config.json")
# with open(config_path, "r") as f:
#     config = json.load(f)
# config["rope_scaling"] = {"type": "dynamic", "factor": 1.1}
# with open(config_path, "w") as f:
#     json.dump(config, f, indent=2)

# Load in bfloat16 on the CPU; the training/evaluation helpers move models to the GPU.
# NOTE(review): trust_remote_code=True allows code shipped with the checkpoint to run;
# confirm it is actually required for the selected model.
base_model = AutoModelForCausalLM.from_pretrained(
    model_dir,
    token=hf_token,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    device_map="cpu"
)

# Optional: force the standard ("eager") attention implementation. It is strongly recommended
# to train Gemma3 models with `eager` attention. Newer transformers releases may enable
# FlashAttention2 automatically for supported models, which changes numerical behaviour and
# can produce inf/NaN values where the slower kernels didn't. Uncomment to disable it:
#base_model.config._attn_implementation = "eager"
tokenizer = AutoTokenizer.from_pretrained(model_dir, token=hf_token)
tokenizer.padding_side = "left"  # Pad on the left so generated tokens directly follow the prompt
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # Reuse eos if no pad_token

In [None]:
# Alpaca dataset: load "tatsu-lab/alpaca" with `datasets.load_dataset`; its "train" split is
# converted to prompt/response pairs further down.
alpaca = load_dataset("tatsu-lab/alpaca")

def format_alpaca(example):
    """Turn one Alpaca record into the notebook's prompt/response layout.

    The prompt always carries an "### Instruction:" section and ends with an
    open "### Response:" section; an "### Input:" section is inserted between
    them only when the record's ``input`` field is non-empty.

    Args:
        example: Mapping with ``instruction``, ``input`` and ``output`` keys.

    Returns:
        Dict with the rendered ``prompt`` and the reference ``response``.
    """
    sections = [f"### Instruction:\n{example['instruction']}"]
    if example["input"]:
        sections.append(f"### Input:\n{example['input']}")
    sections.append("### Response:\n")
    return {"prompt": "\n\n".join(sections), "response": example["output"]}

# Keep the first 5000 rows (or fewer if the split is smaller) and format only those,
# instead of formatting the whole split and discarding most of the result afterwards.
formatted_alpaca = (
    alpaca["train"]
    .select(range(min(5000, len(alpaca["train"]))))
    .map(format_alpaca)
)

# Dolly dataset
dolly = load_dataset("databricks/databricks-dolly-15k")

def format_dolly(example):
    """Turn one Dolly record into the notebook's prompt/response layout.

    The record's ``context`` plays the role of the optional "### Input:"
    section; it is omitted from the prompt when empty.

    Args:
        example: Mapping with ``instruction``, ``context`` and ``response`` keys.

    Returns:
        Dict with the rendered ``prompt`` and the reference ``response``.
    """
    sections = [f"### Instruction:\n{example['instruction']}"]
    if example["context"]:
        sections.append(f"### Input:\n{example['context']}")
    sections.append("### Response:\n")
    return {"prompt": "\n\n".join(sections), "response": example["response"]}

# Keep the first 5000 rows (or fewer if the split is smaller) and format only those,
# instead of formatting the whole split and discarding most of the result afterwards.
formatted_dolly = (
    dolly["train"]
    .select(range(min(5000, len(dolly["train"]))))
    .map(format_dolly)
)

# Agent dataset
agent = load_dataset("THUDM/AgentInstruct")

def format_agent(example):
    """Build a prompt/response pair from the final turns of a conversation.

    Only the last message sent by the user side ("human"/"user") and the last
    message sent by the assistant side ("gpt"/"assistant") are used.

    Args:
        example: Mapping whose ``conversations`` entry is a sequence of turns,
            each with ``from`` and ``value`` keys.

    Returns:
        Dict with ``prompt`` and ``response``, or ``None`` when either side
        has no message at all.
    """
    missing = object()  # sentinel: distinguishes "no turn seen" from any real value
    last_user = last_assistant = missing
    for turn in example["conversations"]:
        speaker = turn["from"]
        if speaker in ("human", "user"):
            last_user = turn["value"]
        if speaker in ("gpt", "assistant"):
            last_assistant = turn["value"]

    if last_user is missing or last_assistant is missing:
        return None

    return {
        "prompt": f"### Instruction:\n{last_user}\n\n### Response:\n",
        "response": last_assistant,
    }

def _has_user_and_assistant_turn(example):
    """Return True when a conversation has at least one user and one assistant turn."""
    speakers = {turn["from"] for turn in example["conversations"]}
    return bool(speakers & {"human", "user"}) and bool(speakers & {"gpt", "assistant"})

# Drop unusable conversations *before* formatting. `Dataset.filter` hands each row to its
# predicate as a dict, so the previous `x is not None` check after `map` could never remove
# anything, and `map` was being given `None` results for conversations lacking a turn.
formatted_agent = [
    ds.filter(_has_user_and_assistant_turn).map(format_agent)
    for ds in agent.values()
]
merged_agent = concatenate_datasets(formatted_agent)
# print("AgentInstruct merged size:", len(merged_agent))
# print(merged_agent[0])

def build_dataloader(dataset, tokenizer, max_len=512, batch_size=2):
    """Tokenize prompt+response pairs and return ``(train_loader, val_loader)``.

    Each example is encoded as ``prompt + response`` and padded/truncated to
    ``max_len`` tokens. Labels are a copy of the input ids in which every
    padding position and every prompt position is replaced by -100, so only
    response tokens carry a label. The masking works for either padding side:
    the prompt span is located relative to the first non-padding token of each
    row instead of assuming it starts at index 0.

    Args:
        dataset: Object whose ``"prompt"`` and ``"response"`` entries are
            equally long sequences of strings.
        tokenizer: Tokenizer-like callable returning ``"input_ids"`` and
            ``"attention_mask"``.
        max_len: Fixed token length of every encoded example.
        batch_size: Batch size of both loaders.

    Returns:
        Two ``DataLoader`` objects yielding ``(input_ids, labels)`` batches: a
        shuffled loader over a random 98% of the examples, and an unshuffled
        loader over the remaining ones.
    """
    prompts = list(dataset["prompt"])
    responses = list(dataset["response"])
    full_texts = [p + r for p, r in zip(prompts, responses)]

    # Tokenize
    encodings = tokenizer(
        full_texts,
        padding="max_length",
        truncation=True,
        max_length=max_len,
        return_tensors="pt",
        add_special_tokens=False
    )

    input_ids = encodings["input_ids"]
    attention_mask = encodings["attention_mask"]

    labels = input_ids.clone()
    # Padding never contributes to the loss, whichever side it is on.
    labels[attention_mask == 0] = -100

    # Token length of every prompt on its own (one batched call, no padding).
    prompt_lens = [
        len(ids)
        for ids in tokenizer(
            prompts,
            truncation=True,
            max_length=max_len,
            add_special_tokens=False
        )["input_ids"]
    ]

    # Mask prompt part in labels. argmax returns the first position holding the maximum,
    # i.e. the first real token of the row (0 when there is no left padding).
    first_real = attention_mask.argmax(dim=1).tolist()
    for row, (start, prompt_len) in enumerate(zip(first_real, prompt_lens)):
        labels[row, start:start + prompt_len] = -100

    # Split train/val
    ds = TensorDataset(input_ids, labels)
    train_len = int(0.98 * len(ds))
    train_ds, val_ds = random_split(ds, [train_len, len(ds) - train_len])

    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True),
        DataLoader(val_ds, batch_size=batch_size)
    )

# Build the (train, validation) loader pair for each corpus, in the order
# Alpaca -> Dolly -> AgentInstruct, all with the shared tokenizer.
(train_alpaca, val_alpaca), (train_dolly, val_dolly), (train_agent, val_agent) = [
    build_dataloader(corpus, tokenizer)
    for corpus in (formatted_alpaca, formatted_dolly, merged_agent)
]

In [None]:
def stochastic_linear(x, weight, bias=None, noise_std=0.01):
    """Apply a linear map to an input perturbed by zero-mean Gaussian noise.

    Args:
        x: Input tensor.
        weight: Weight passed to ``F.linear``.
        bias: Optional bias passed to ``F.linear``.
        noise_std: Factor applied to the standard-normal noise added to ``x``.

    Returns:
        ``F.linear(x + noise, weight, bias)``.
    """
    perturbed = x + torch.randn_like(x) * noise_std
    return F.linear(perturbed, weight, bias)

def quantize_tensor_uniform(tensor, bits=8):
    """Symmetric uniform quantization with one scale per slice along the last dim.

    The scale maps the largest absolute value of each slice onto the largest
    positive code ``2**(bits-1) - 1``; it is floored at 1e-8 so all-zero slices
    do not divide by zero. Codes are rounded and clamped to the signed range.

    Args:
        tensor: Tensor to quantize.
        bits: Bit width of the signed integer grid.

    Returns:
        ``(codes, scale)``; ``codes`` keeps the dtype of ``tensor`` and
        ``scale`` has a trailing dimension of size 1.
    """
    code_max = 2 ** (bits - 1) - 1
    code_min = -code_max - 1
    peak = tensor.abs().max(dim=-1, keepdim=True).values
    scale = (peak / code_max).clamp(min=1e-8)
    codes = (tensor / scale).round().clamp(code_min, code_max)
    return codes, scale

def dequantize_tensor(q, scale):
    """Map quantized codes back to real values by multiplying with their scale.

    Args:
        q: Codes as returned by ``quantize_tensor_uniform``.
        scale: Matching (broadcastable) scale.

    Returns:
        ``q * scale``.
    """
    restored = q * scale
    return restored

def quantize_tensor_nf4(tensor, block_size=64):
    """Quantize to a 16-entry codebook with one absmax scale per block.

    The last dimension is split into blocks of ``block_size`` values (zero
    padded if needed). Each block is divided by its largest absolute value and
    every element is replaced by the index of the nearest codebook entry.

    Args:
        tensor: Tensor to quantize; quantization runs along the last dimension.
        block_size: Number of consecutive values sharing one scale.

    Returns:
        ``(indices, scales)``: ``indices`` is a ``torch.long`` tensor shaped
        like ``tensor`` with values in ``[0, 15]``; ``scales`` has shape
        ``(rows, n_blocks, 1)`` where ``rows`` is the product of all leading
        dimensions.
    """
    nf4 = torch.tensor([
        -1.0000, -0.6962, -0.5251, -0.3949,
        -0.2844, -0.1848, -0.0911,  0.0000,
         0.0911,  0.1848,  0.2844,  0.3949,
         0.5251,  0.6962,  0.8682,  1.0000
    ], device=tensor.device, dtype=torch.bfloat16)

    shape = tensor.shape
    # reshape (not view) so permuted / non-contiguous inputs are accepted too
    flat = tensor.reshape(-1, shape[-1])
    feat_dim = flat.size(-1)

    # pad to multiple of block size
    pad = (block_size - (feat_dim % block_size)) % block_size
    if pad > 0:
        flat = F.pad(flat, (0, pad))

    n_blocks = flat.size(-1) // block_size
    blocks = flat.reshape(flat.size(0), n_blocks, block_size)

    # per-block scales
    max_vals = blocks.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    normed = blocks / max_vals

    # Streaming nearest-neighbour search: visit the 16 codebook entries one at a time instead
    # of materialising a (..., 16) distance tensor. Indices start at 0 (a valid code) so that
    # elements which never compare as closer (NaN inputs) cannot keep uninitialised values.
    indices = torch.zeros(normed.shape, dtype=torch.long, device=normed.device)
    best_dist = torch.full_like(normed, float("inf"))

    for i, c in enumerate(nf4):
        dist = (normed - c).abs()
        closer = dist < best_dist
        indices.masked_fill_(closer, i)
        best_dist = torch.minimum(best_dist, dist)

    # remove padding
    indices = indices.reshape(flat.size(0), -1)[..., :feat_dim]

    return indices.reshape(shape), max_vals


def dequantize_tensor_nf4(q_indices, scales, block_size=64):
    """Invert ``quantize_tensor_nf4``: look up codebook values and rescale per block.

    Args:
        q_indices: Integer tensor of codebook indices.
        scales: Per-block scales; viewed as ``(rows, n_blocks, 1)``.
        block_size: Block length used when the tensor was quantized.

    Returns:
        Tensor shaped like ``q_indices`` holding ``codebook[index] * scale``.
    """
    codebook = torch.tensor([
        -1.0000, -0.6962, -0.5251, -0.3949,
        -0.2844, -0.1848, -0.0911,  0.0000,
         0.0911,  0.1848,  0.2844,  0.3949,
         0.5251,  0.6962,  0.8682,  1.0000
    ], device=q_indices.device, dtype=torch.bfloat16)

    original_shape = q_indices.shape
    width = original_shape[-1]
    rows = q_indices.view(-1, width)

    # Pad each row up to a whole number of blocks; the padding is trimmed again below.
    extra = -width % block_size
    if extra > 0:
        rows = F.pad(rows, (0, extra))

    blocks_per_row = rows.size(-1) // block_size
    index_blocks = rows.view(rows.size(0), blocks_per_row, block_size)
    block_scales = scales.view(index_blocks.size(0), blocks_per_row, 1)
    dequantized = codebook[index_blocks] * block_scales

    # remove padding
    trimmed = dequantized.view(rows.size(0), -1)[..., :width]
    return trimmed.view(original_shape)

class QuantLinear(nn.Module):
    """Linear layer that can simulate quantized inference.

    Until ``quantize_weights`` has been called the layer behaves like a plain
    linear layer on ``weight``/``bias``. Afterwards ``forward`` quantizes the
    activations to 8 bits, uses the dequantized stored weight codes, and
    quantizes the output to 8 bits as well; every quantize step is followed by
    a dequantize, so all arithmetic still runs in floating point.

    ``method`` selects the weight format: ``"int8"`` (8-bit uniform),
    ``"nf4"`` (block-wise codebook); any other value is treated as 4-bit
    uniform.
    """

    def __init__(self, in_features, out_features, bias=True, method="int8"):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.method = method.lower()

        # Register the floating-point weight with the canonical name "weight"
        self.weight = nn.Parameter(torch.empty(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.empty(out_features))
        else:
            self.register_parameter("bias", None)

        # Buffers for quantized weights (after training)
        self.register_buffer("weight_q", None)
        self.register_buffer("weight_scale", None)

        self.reset_parameters()

    def reset_parameters(self):
        """Initialise weight and bias with the same scheme as ``nn.Linear``."""
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def _quantize(self):
        """Quantize ``self.weight`` according to ``self.method``; returns ``(codes, scale)``."""
        if self.method == "nf4":
            return quantize_tensor_nf4(self.weight)
        bits = 8 if self.method == "int8" else 4
        return quantize_tensor_uniform(self.weight, bits=bits)

    def _dequantize(self, q, scale):
        """Dequantize ``(q, scale)`` with the routine matching ``self.method``."""
        if self.method == "nf4":
            return dequantize_tensor_nf4(q, scale)
        return dequantize_tensor(q, scale)

    def pre_quantize_weights(self):
        """Snap ``weight`` in place onto the quantization grid (quantize, then dequantize)."""
        with torch.no_grad():
            q, scale = self._quantize()
            self.weight.copy_(self._dequantize(q, scale))

    def quantize_weights(self):
        """Store quantized codes and scales; ``forward`` uses them from now on."""
        with torch.no_grad():
            self.weight_q, self.weight_scale = self._quantize()

            # if "weight" in self._parameters:
            #     del self._parameters["weight"]
            # self.register_buffer("weight_bfp16_placeholder", torch.tensor(0.0))

    def forward(self, x):
        # Match the bias dtype to the activations on BOTH paths; the quantized path used to
        # pass the bias through uncast, which fails when its dtype differs from x.
        b = self.bias
        if b is not None and b.dtype != x.dtype:
            b = b.to(x.dtype)

        if self.weight_q is None:
            w = self.weight
            if w.dtype != x.dtype:
                w = w.to(x.dtype)
            return F.linear(x, w, b)

        # x_q, x_dq = activation tensors; w_dq = weight tensor
        x_q, x_scale = quantize_tensor_uniform(x, bits=8)
        x_dq = dequantize_tensor(x_q, x_scale)
        w_dq = self._dequantize(self.weight_q, self.weight_scale)

        out = F.linear(x_dq.to(x.dtype), w_dq.to(x.dtype), b)

        out_q, out_scale = quantize_tensor_uniform(out, bits=8)
        return dequantize_tensor(out_q, out_scale)

def patch_model_with_quant(model, method="int8", pre_quantize=False, auto_quantize=False):
    """Replace every ``nn.Linear`` in ``model`` by ``QuantLinear``, optionally quantizing.

    Replacement layers take over the original layer's weights, bias, dtype,
    device and ``requires_grad`` flags, so frozen layers stay frozen and a
    bfloat16 model stays bfloat16. Existing ``QuantLinear`` layers are left
    alone, so the function can be called repeatedly on the same model.

    Args:
        model: Module patched in place.
        method: Quantization method given to new ``QuantLinear`` layers.
        pre_quantize: Call ``pre_quantize_weights`` on every ``QuantLinear``.
        auto_quantize: Call ``quantize_weights`` on every ``QuantLinear`` that
            has no stored quantized weights yet.

    Returns:
        The same ``model`` object.
    """
    method = method.lower()
    # Pass 1: structure patch (replace plain Linear only)
    for name, module in model.named_children():
        # Skip if already quantized
        if isinstance(module, QuantLinear):
            continue

        if isinstance(module, nn.Linear):
            has_bias = module.bias is not None
            qlinear = QuantLinear(
                module.in_features,
                module.out_features,
                bias=has_bias,
                method=method
            )
            # Keep the source layer's precision and placement instead of the float32/CPU
            # defaults of a freshly constructed layer.
            qlinear.to(device=module.weight.device, dtype=module.weight.dtype)
            # Copy weights/bias
            with torch.no_grad():
                qlinear.weight.copy_(module.weight)
                if has_bias:
                    qlinear.bias.copy_(module.bias)
            # Preserve trainability so frozen base weights are not silently unfrozen.
            qlinear.weight.requires_grad_(module.weight.requires_grad)
            if has_bias:
                qlinear.bias.requires_grad_(module.bias.requires_grad)
            setattr(model, name, qlinear)
        else:
            # Recurse for the structural patch only: the quantization passes below walk the
            # whole tree via model.modules(), so running them per level would repeat the work.
            patch_model_with_quant(module, method)

    # pre-quantize
    if pre_quantize:
        for m in model.modules():
            if isinstance(m, QuantLinear) and getattr(m, "weight", None) is not None:
                m.pre_quantize_weights()

    # auto-quantize
    if auto_quantize:
        for m in model.modules():
            if not isinstance(m, QuantLinear):
                continue
            # Already holds quantized weights -> nothing to do
            if getattr(m, "weight_q", None) is not None:
                continue
            if getattr(m, "weight", None) is not None:
                m.quantize_weights()

    return model

def patch_model(model, config):
    """Apply the approximations selected in ``config`` to ``model``.

    Switches (truthy to enable):
        ``prune``: L1-unstructured pruning of every ``mlp.gate_proj`` weight.
        ``stochastic``: replace every ``mlp.gate_proj`` by a linear layer whose
            forward adds Gaussian noise to its input (``stochastic_linear``).
        ``stoc_mem_mask``: wrap every ``self_attn.forward`` so that its output
            tensor (or the first element of an output tuple) is multiplied by
            a random keep mask.
        ``lora``: wrap the model with PEFT LoRA adapters.

    Optional tuning keys (defaults are the values that used to be hard-coded):
        ``prune_amount`` (0.05), ``noise_std`` (0.01), ``mask_drop_p`` (0.05),
        ``lora_r`` (8), ``lora_alpha`` (16), ``lora_dropout`` (0.05),
        ``lora_target_modules`` (["q_proj", "v_proj"]).

    Args:
        model: Module to patch; layers are modified in place.
        config: Mapping with the keys described above.

    Returns:
        The patched model. With ``lora`` enabled this is the object returned
        by ``get_peft_model``, otherwise it is ``model`` itself.
    """
    prune_amount = config.get("prune_amount", 0.05)
    noise_std = config.get("noise_std", 0.01)
    mask_drop_p = config.get("mask_drop_p", 0.05)

    for layer in model.modules():
        if hasattr(layer, "mlp"):
            gate = layer.mlp.gate_proj

            if config.get("prune"):
                prune.l1_unstructured(gate, name="weight", amount=prune_amount)

            if config.get("stochastic"):
                new = nn.Linear(gate.in_features, gate.out_features, bias=(gate.bias is not None)).to(gate.weight.device, dtype=gate.weight.dtype)
                new.weight.data.copy_(gate.weight.data)
                if gate.bias is not None:
                    new.bias.data.copy_(gate.bias.data)
                # Default arguments bind this layer's own parameters and noise level; a plain
                # closure would see whatever `new` refers to after the loop has moved on.
                new.forward = lambda x, w=new.weight, b=new.bias, s=noise_std: stochastic_linear(x, w, b, noise_std=s)
                layer.mlp.gate_proj = new

        if config.get("stoc_mem_mask") and hasattr(layer, "self_attn"):
            orig = layer.self_attn.forward

            # orig/drop_p are bound as keyword defaults for the same late-binding reason.
            def wrapped_forward(*args, orig=orig, drop_p=mask_drop_p, **kw):
                out = orig(*args, **kw)

                def mask_fn(t):
                    # Only floating-point tensors are masked; anything else passes through.
                    if not torch.is_floating_point(t):
                        return t
                    # Keep an element when its uniform draw exceeds drop_p. The boolean mask
                    # zeroes dropped elements; kept values are not rescaled.
                    mask = (torch.rand_like(t) > drop_p)
                    return t * mask

                if isinstance(out, tuple):
                    if isinstance(out[0], torch.Tensor) and out[0].numel() > 0:
                        return (mask_fn(out[0]),) + out[1:]
                    return out
                elif isinstance(out, torch.Tensor) and out.numel() > 0:
                    return mask_fn(out)
                return out

            layer.self_attn.forward = wrapped_forward

    if config.get("lora"):
        lora_cfg = LoraConfig(
            r=config.get("lora_r", 8),
            lora_alpha=config.get("lora_alpha", 16),
            lora_dropout=config.get("lora_dropout", 0.05),
            bias="none",
            target_modules=list(config.get("lora_target_modules", ["q_proj", "v_proj"])),
            task_type=TaskType.CAUSAL_LM)
        model = get_peft_model(model, lora_cfg)
        model.print_trainable_parameters()

    return model

def evaluate_metrics(model, tokenizer, val_dl, save_path=None, variant_name=""):
    """Generate responses for the prompts in ``val_dl`` and score them.

    Batches of ``val_dl`` are ``(input_ids, labels)`` pairs in which
    ``input_ids`` encode the prompt *and* the reference response. Prompting the
    model with them would leak the reference into every prediction, so each
    example is decoded, cut after its last "### Response:" marker and
    re-tokenized; only tokens generated after that prompt are scored.

    Reported values: BLEU, ROUGE-L and a token-set F1 per example (averaged),
    BERTScore, SBERT cosine similarity, generation time per batch, total
    wall-clock time, peak GPU memory during generation, and an estimated model
    size. The size estimate counts int8 weights as 1 byte and int4/nf4 weights
    as half a byte per element, regardless of how the codes are stored.

    Args:
        model: Causal LM with a ``generate`` method; moved to CUDA and put in
            eval mode.
        tokenizer: Matching tokenizer; switched to left padding, and given a
            pad token (eos) if it has none.
        val_dl: Iterable of ``(input_ids, labels)`` batches; label value -100
            marks ignored positions.
        save_path: Optional path prefix for a JSONL file with the predictions.
        variant_name: Label used in the progress bar and the output file name.

    Returns:
        Dict mapping metric names to values.
    """
    device = torch.device("cuda")
    model.to(device)
    model.eval()

    response_marker = "### Response:"

    def tensor_bytes(t):
        # element_size() is defined for every dtype; torch.finfo() raises on integer
        # and boolean tensors.
        return t.numel() * t.element_size()

    def get_model_size(model):
        nominal_weight_bytes = {"int8": 1, "int4": 0.5, "nf4": 0.5}
        total_bytes = 0
        for module in model.modules():

            # Quantized linear layers
            if isinstance(module, QuantLinear):
                # Quantized weight tensor, counted at its nominal bit width
                if module.weight_q is not None:
                    per_elem = nominal_weight_bytes.get(
                        getattr(module, "method", None), module.weight.element_size()
                    )
                    total_bytes += module.weight_q.numel() * per_elem
                else:
                    total_bytes += tensor_bytes(module.weight)

                # Bias keeps its own dtype
                if module.bias is not None:
                    total_bytes += tensor_bytes(module.bias)

                # Scale parameters
                if module.weight_scale is not None:
                    total_bytes += tensor_bytes(module.weight_scale)

            # Non-quantized layers
            else:
                for _, param in module.named_parameters(recurse=False):
                    total_bytes += tensor_bytes(param)
                for _, buf in module.named_buffers(recurse=False):
                    total_bytes += tensor_bytes(buf)

        return round(total_bytes / (1024 * 1024), 2)

    def extract_response(text: str) -> str:
        return text.split(response_marker)[-1].strip() if response_marker in text else text.strip()

    def prompt_only(text: str) -> str:
        # Keep everything up to and including the last marker; text without a marker is
        # returned unchanged.
        head, marker, _ = text.rpartition(response_marker)
        return f"{head}{marker}\n" if marker else text

    def encode_prompts(batch_input_ids):
        texts = [
            prompt_only(t)
            for t in tokenizer.batch_decode(batch_input_ids, skip_special_tokens=True)
        ]
        enc = tokenizer(texts, return_tensors="pt", padding=True, add_special_tokens=False)
        return texts, enc["input_ids"].to(device), enc["attention_mask"].to(device)

    model_size_mb = get_model_size(model)
    prompts, preds, refs = [], [], []
    bleu_scores, rouge_l_scores, f1_scores = [], [], []
    inference_times = []
    scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
    smooth_fn = SmoothingFunction().method4

    tokenizer.padding_side = "left"
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Stop on the tokenizer's eos token and, when the vocabulary defines it, on "<|eot_id|>".
    eos_ids = []
    if tokenizer.eos_token_id is not None:
        eos_ids.append(tokenizer.eos_token_id)
    eot_id = tokenizer.convert_tokens_to_ids("<|eot_id|>")
    if isinstance(eot_id, int) and eot_id != tokenizer.unk_token_id and eot_id not in eos_ids:
        eos_ids.append(eot_id)

    # Warmup for stable CUDA kernel caching
    with torch.no_grad():
        for input_ids, _ in itertools.islice(val_dl, 1):
            _, warm_ids, warm_mask = encode_prompts(input_ids)
            _ = model.generate(
                input_ids=warm_ids,
                attention_mask=warm_mask,
                max_new_tokens=8,
                pad_token_id=tokenizer.pad_token_id
            )
            torch.cuda.synchronize()

    # Reset GPU memory tracking
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats(device)
    torch.cuda.synchronize()

    start_time = time.time()

    with torch.no_grad():
        for input_ids, labels in tqdm(val_dl, desc=f"üîç Evaluating {variant_name}"):
            prompt_texts, prompt_ids, attention_mask = encode_prompts(input_ids)

            # Measure inference time per batch
            torch.cuda.synchronize()
            t0 = time.time()
            gen_ids = model.generate(
                input_ids=prompt_ids,
                attention_mask=attention_mask,
                max_new_tokens=128,
                do_sample=True,
                top_p=0.9,
                temperature=0.6,
                eos_token_id=eos_ids or None,
                pad_token_id=tokenizer.pad_token_id
            )
            torch.cuda.synchronize()
            t1 = time.time()
            inference_times.append(t1 - t0)

            # Decode only the newly generated tokens, and the references from the labels
            new_tokens = gen_ids[:, prompt_ids.shape[1]:]
            pred_texts = [t.strip() for t in tokenizer.batch_decode(new_tokens, skip_special_tokens=True)]
            ref_labels = labels.clone()
            ref_labels[ref_labels == -100] = tokenizer.pad_token_id
            ref_texts = tokenizer.batch_decode(ref_labels, skip_special_tokens=True)

            for prompt, pred, ref in zip(prompt_texts, pred_texts, ref_texts):
                pred_resp = extract_response(pred)
                ref_resp = extract_response(ref)

                prompts.append(prompt)
                preds.append(pred_resp)
                refs.append(ref_resp)

                # Accuracy metrics (n-gram based)
                pred_tokens = tokenizer.tokenize(pred_resp)
                ref_tokens = tokenizer.tokenize(ref_resp)

                bleu = sentence_bleu([ref_tokens], pred_tokens, smoothing_function=smooth_fn)
                rouge_l = scorer.score(ref_resp, pred_resp)["rougeL"].fmeasure

                common = set(pred_tokens) & set(ref_tokens)
                prec = len(common) / len(pred_tokens) if pred_tokens else 0.0
                rec  = len(common) / len(ref_tokens)  if ref_tokens  else 0.0
                f1   = 2 * prec * rec / (prec + rec) if (prec + rec) > 0 else 0.0

                bleu_scores.append(bleu)
                rouge_l_scores.append(rouge_l)
                f1_scores.append(f1)

    # Read the peak now, before the scoring models are loaded, so that it reflects the
    # evaluated model's generation only.
    torch.cuda.synchronize()
    peak_alloc_mb = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    peak_resvd_mb = torch.cuda.max_memory_reserved(device) / (1024 ** 2)
    print(f"Peak GPU memory | Allocated: {peak_alloc_mb:.2f} MB | Reserved: {peak_resvd_mb:.2f} MB")

    # SBERT semantic similarity between each prediction and its own reference
    sbert_sim_mean = 0.0
    if len(preds) > 0:
        sbert_model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2").to(device)
        # Encode all at once for efficiency
        emb_preds = sbert_model.encode(preds, convert_to_tensor=True, batch_size=32)
        emb_refs  = sbert_model.encode(refs, convert_to_tensor=True, batch_size=32)
        # Row-wise cosine similarity; avoids building the full N x N matrix for its diagonal.
        pair_sims = F.cosine_similarity(emb_preds, emb_refs, dim=-1)
        sbert_sim_mean = float(pair_sims.mean().item())

    # BERTScore
    bert_p = bert_r = bert_f1 = 0.0
    if len(preds) > 0:

        device_str = "cuda" if torch.cuda.is_available() else "cpu"
        P, R, F1 = bertscore_score(preds, refs, lang="en", model_type="microsoft/deberta-base-mnli", device=device_str)
        bert_p = float(P.mean().item())
        bert_r = float(R.mean().item())
        bert_f1 = float(F1.mean().item())

    torch.cuda.empty_cache()

    # Debug samples
    if len(refs) > 0:
        print("\n DEBUG: Sample Outputs")
        for i in range(min(2, len(refs))):
            print(f"\nüìú Prompt {i+1}:\n{prompts[i]}")
            print(f"üî∏ Prediction {i+1}:\n{preds[i]}")

    # Save predictions
    file_size_kb = None
    if save_path:
        out_dir = os.path.dirname(save_path)
        if out_dir:  # a bare file prefix has no directory to create
            os.makedirs(out_dir, exist_ok=True)
        output_file = f"{save_path}_{variant_name.replace(' ', '_')}.jsonl"
        with open(output_file, "w", encoding="utf-8") as f:
            for p, r, y in zip(prompts, refs, preds):
                record = {"prompt": p, "reference": r, "prediction": y}
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
        file_size_kb = os.path.getsize(output_file) / 1024
        print(f"üìù Saved predictions to {output_file} ({file_size_kb:.2f} KB)")

    total_eval_time = time.time() - start_time

    return {
        "BLEU": round(np.mean(bleu_scores), 4) if bleu_scores else 0.0,
        "ROUGE-L": round(np.mean(rouge_l_scores), 4) if rouge_l_scores else 0.0,
        "F1": round(np.mean(f1_scores), 4) if f1_scores else 0.0,
        "BERTScore_P": round(bert_p, 4),
        "BERTScore_R": round(bert_r, 4),
        "BERTScore_F1": round(bert_f1, 4),
        "SBERT_Similarity": round(sbert_sim_mean, 4),
        "Avg Inference Time (s)": round(np.mean(inference_times), 2) if inference_times else None,
        "Total Eval Time (s)": round(total_eval_time, 2),
        "GPU Memory Allocated (MB)": round(peak_alloc_mb, 2),
        "GPU Memory Reserved (MB)": round(peak_resvd_mb, 2),
        "Model Size (MB)": model_size_mb,
        "Output Size (KB)": round(file_size_kb, 2) if file_size_kb else None
    }

def train_model(model, train_dl, epochs=3, accumulation_steps=2, lr=1e-6, quant_method=None):
    """Fine-tune ``model`` on ``train_dl`` with next-token cross-entropy.

    Logits at position ``t`` are scored against the label at position
    ``t + 1``: the model is called without ``labels``, so the shift is applied
    here. Positions labelled -100 are ignored.

    Args:
        model: Causal LM whose output exposes ``.logits``; moved to CUDA and
            put in train mode.
        train_dl: DataLoader yielding ``(input_ids, labels)`` batches.
        epochs: Number of passes over ``train_dl``.
        accumulation_steps: Micro-batches accumulated per optimizer step.
        lr: AdamW learning rate.
        quant_method: Accepted for call-site symmetry; not used here.

    Returns:
        One dict per epoch with keys "Epoch", "Loss" (token-weighted mean
        loss) and "Perplexity". Both values are NaN for an epoch in which no
        batch contributed.
    """
    # NOTE(review): no attention mask is passed to the model, so padded positions are
    # attended to; supplying one would need the pad id, which this signature lacks.
    model.to("cuda")
    model.train()
    # Only optimize trainable params (important for LoRA)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable_params, lr=lr)
    num_training_steps = epochs * len(train_dl) // accumulation_steps
    scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=int(0.2 * num_training_steps), num_training_steps=num_training_steps)
    loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
    history = []

    for epoch in range(epochs):
        print(f"\nüöÄ Epoch {epoch + 1}/{epochs}")
        total_loss = 0.0
        total_tokens = 0
        optimizer.zero_grad()
        progress_bar = tqdm(enumerate(train_dl), total=len(train_dl), desc=f"Epoch {epoch+1}")

        for step, (xb, yb) in progress_bar:
            xb, yb = xb.cuda(), yb.cuda()

            # Next-token targets: drop the first label so position t lines up with token t+1.
            targets = yb[:, 1:]
            valid_tokens = (targets != -100).sum().item()
            if valid_tokens == 0:
                # Nothing to learn from this batch; skip it before paying for a forward pass.
                continue

            # The default CUDA autocast dtype is float16; bfloat16 is requested explicitly
            # because no gradient scaling is used in this loop.
            with torch_amp.autocast('cuda', dtype=torch.bfloat16):
                logits = model(xb).logits[:, :-1, :]
                loss = loss_fn(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))

            if torch.isnan(loss) or torch.isinf(loss):
                print("NaN/Inf loss detected ‚Äî skipping batch.")
                optimizer.zero_grad()
                continue

            total_loss += loss.item() * valid_tokens
            total_tokens += valid_tokens
            # Scale so that accumulated gradients average over the micro-batches.
            (loss / accumulation_steps).backward()

            if (step + 1) % accumulation_steps == 0 or (step + 1) == len(train_dl):
                torch.nn.utils.clip_grad_norm_(trainable_params, max_norm=0.5)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()

            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

        # Guard against an epoch in which every batch was skipped.
        avg_loss = total_loss / total_tokens if total_tokens else float("nan")
        perplexity = np.exp(avg_loss)
        print(f"‚úÖ Epoch {epoch + 1} Avg Loss: {avg_loss:.4f} | Perplexity: {perplexity:.2f}")
        history.append({"Epoch": epoch + 1, "Loss": avg_loss, "Perplexity": perplexity})

    return history

def report(model, train_dl, val_dl, tokenizer, epochs=3, variant_name="", quant_method=None, save_dir="outputs", offload_after_eval=True):
    """Train one model variant, evaluate it and plot its training curve.

    With ``quant_method`` set, linear layers are replaced by ``QuantLinear``
    and pre-quantized before training, and their weights are quantized once
    more after training so evaluation runs on the quantized path.

    Args:
        model: Model variant to train and evaluate.
        train_dl: Training DataLoader passed to ``train_model``.
        val_dl: Validation DataLoader passed to ``evaluate_metrics``.
        tokenizer: Tokenizer passed to ``evaluate_metrics``.
        epochs: Number of training epochs.
        variant_name: Label used in logs, plot title and output file name.
        quant_method: Optional quantization method ("int8", "int4", "nf4").
        save_dir: Directory that receives the prediction file.
        offload_after_eval: Move the model back to the CPU and release cached
            GPU memory once evaluation is done, so successive variants do not
            pile up on the GPU. Pass ``False`` to leave the model on the GPU.

    Returns:
        ``(df, eval_metrics)``: per-epoch training history as a DataFrame and
        the metrics dict, both tagged with ``variant_name``.
    """
    os.makedirs(save_dir, exist_ok=True)
    # Pre-quantize before training
    if quant_method:
        model = patch_model_with_quant(model, method=quant_method, pre_quantize=True)

    # Train on GPU
    train_history = train_model(model, train_dl, epochs=epochs, quant_method=quant_method)

    # Quantize weights once after training for inference
    if quant_method:
        model = patch_model_with_quant(model, method=quant_method, auto_quantize=True)

    # Inference on GPU
    eval_metrics = evaluate_metrics(model, tokenizer, val_dl, save_path=os.path.join(save_dir, "preds"), variant_name=variant_name)
    eval_metrics["Variant"] = variant_name

    if offload_after_eval:
        # The caller usually keeps a reference to the model; leaving it on the GPU is what
        # makes a long sequence of variants run out of device memory.
        model.to("cpu")
        gc.collect()
        torch.cuda.empty_cache()

    df = pd.DataFrame(train_history)
    df["Variant"] = variant_name
    print("\nüìä Final Evaluation Metrics:", eval_metrics)
    plt.figure(figsize=(10, 5))
    plt.plot(df["Epoch"], df["Loss"], label="Loss", marker="o")
    plt.plot(df["Epoch"], df["Perplexity"], label="Perplexity", marker="s")
    plt.xlabel("Epoch")
    plt.title(f"Training Curve ({variant_name})")
    plt.legend()
    plt.show()

    return df, eval_metrics

In [None]:
# train_dl, val_dl = train_alpaca, val_alpaca
# name = "Alpaca"
#train_dl, val_dl = train_dolly, val_dolly
#name = "Dolly"
#train_dl, val_dl = train_agent, val_agent
#name = "Agent"
# calib_dl = val_dl
# train_dl, val_dl = train_alpaca, val_agent
# name = "Alpaca-Agent"
# Active experiment: train on Alpaca, validate on Dolly. `name` labels the run in
# log messages and variant names.
train_dl = train_alpaca
val_dl = val_dolly
name = "Alpaca-Dolly"
# train_dl, val_dl = train_agent, val_dolly
# name = "Agent-Dolly"


In [None]:
# Sanity check: print the name and dtype of the first model parameter.
# The loop variable must not be called `name`: that global holds the dataset label
# interpolated into every variant name further down, and this loop would overwrite it.
for param_name, param in base_model.named_parameters():
    print(param_name, param.dtype)
    break


model.embed_tokens.weight torch.bfloat16


In [None]:
# Baseline: evaluate the unmodified model. (A stray `model_lora = deepcopy(base_model)` was
# removed here; that copy was never used and the LoRA cell creates its own.)
print("\n  Evaluating Base Model")
metrics = evaluate_metrics(base_model, tokenizer, val_dl, save_path="outputs/preds", variant_name="Base model without approximation")
metrics["Variant"] = "Base Model"
print(metrics)
# evaluate_metrics moves the model to CUDA. Put the shared base model back on the CPU so the
# deepcopy() at the top of each variant cell does not allocate another full copy on the GPU.
base_model.to("cpu")
torch.cuda.empty_cache()

In [None]:
# LoRA variant: attach adapters to a fresh copy of the base model, then train and evaluate.
print(f"\n Training LoRA model on {name} dataset")
model_lora = patch_model(deepcopy(base_model), config={"lora": True})
df_lora, lora_eval = report(
    model_lora,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"LoRA model ({name})",
)

In [None]:
# Pruned variant: prune a fresh copy of the base model, then train and evaluate.
print(f"\n Training Pruned model on {name} dataset")
model_pruned = patch_model(deepcopy(base_model), config={"prune": True})
df_pruned, pruned_eval = report(
    model_pruned,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Pruned model ({name})",
)

In [None]:
# Stochastic variant: noisy gate projections on a fresh copy of the base model.
print(f"\n Training Stochastic Perturbed model on {name} dataset")
model_stochastic = patch_model(deepcopy(base_model), config={"stochastic": True})
df_stochastic, stochastic_eval = report(
    model_stochastic,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Stochastic Perturbed model ({name})",
)

In [None]:
# Memory-masked variant: random masking of attention outputs on a fresh copy of the base model.
print(f"\n Training Memory Masked model on {name} dataset")
model_memmasked = patch_model(deepcopy(base_model), config={"stoc_mem_mask": True})
df_memmasked, memmasked_eval = report(
    model_memmasked,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Memory Masked model ({name})",
)

In [None]:

# Int8 variant: report() handles the quantization of a fresh copy of the base model.
print(f"\n Training Int8 Quantized model on {name} dataset")
model_int8 = deepcopy(base_model)
df_int8, int8_eval = report(
    model_int8,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Int8 Quantized model ({name})",
    quant_method="int8",
)


 Training Int8 Quantized model on Alpaca-Dolly dataset

üöÄ Epoch 1/1


Epoch 1:   0%|          | 1/2450 [00:01<48:11,  1.18s/it, loss=11.9856]


OutOfMemoryError: CUDA out of memory. Tried to allocate 96.00 MiB. GPU 0 has a total capacity of 39.56 GiB of which 16.88 MiB is free. Process 137386 has 39.53 GiB memory in use. Of the allocated memory 38.51 GiB is allocated by PyTorch, and 528.74 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
# Int4 variant: report() handles the quantization of a fresh copy of the base model.
print(f"\n Training Int4 Quantized model on {name} dataset")
model_int4 = deepcopy(base_model)
df_int4, int4_eval = report(
    model_int4,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Int4 Quantized model ({name})",
    quant_method="int4",
)

In [None]:
# Pruned + LoRA variant: prune first, then attach adapters, then train and evaluate.
print(f"\n Training Pruned_LoRA model on {name} dataset")
model_pruned_lora = patch_model(deepcopy(base_model), config={"prune": True})
model_pruned_lora = patch_model(model_pruned_lora, config={"lora": True})
df_pruned_lora, pruned_lora_eval = report(
    model_pruned_lora,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Pruned_LoRA model ({name})",
)

In [None]:
# Stochastic + LoRA variant: noisy gate projections first, then adapters.
model_stochastic_lora = deepcopy(base_model)
print(f"\n Training Stochastic Perturbation and Lora model on {name} dataset")
model_stochastic_lora = patch_model(model_stochastic_lora, config={"stochastic": True})
model_stochastic_lora = patch_model(model_stochastic_lora, config={"lora": True})
# Results get their own names; they used to be stored as `df_lora_nf4, nf4_lora_eval`
# (a copy-paste leftover), where other variant cells overwrite them.
df_stochastic_lora, stochastic_lora_eval = report(model_stochastic_lora, train_dl, val_dl, tokenizer, epochs=1, variant_name=f"Stochastic Perturbation and LoRA model ({name})")

In [None]:
# Memory-masked + LoRA variant: attention-output masking first, then adapters.
model_memapprox_lora = deepcopy(base_model)
print(f"\n Training Memory Masked and Lora model on {name} dataset")
model_memapprox_lora = patch_model(model_memapprox_lora, config={"stoc_mem_mask": True})
model_memapprox_lora = patch_model(model_memapprox_lora, config={"lora": True})
# Results get their own names; they used to be stored as `df_lora_nf4, nf4_lora_eval`
# (a copy-paste leftover), where other variant cells overwrite them.
df_memmasked_lora, memmasked_lora_eval = report(model_memapprox_lora, train_dl, val_dl, tokenizer, epochs=1, variant_name=f"Memory Masked and LoRA model ({name})")

In [None]:
# Int8 + LoRA variant: one epoch of plain LoRA training, then the int8 report() run
# (which trains for another epoch on the quantized layers before evaluating).
print(f"\n Training Int8 Quantized and Lora model on {name} dataset")
model_int8_lora = patch_model(deepcopy(base_model), config={"lora": True})
train_model(model_int8_lora, train_dl, epochs=1, quant_method=None)
df_lora_int8, int8_lora_eval = report(
    model_int8_lora,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Int8 Quantized LoRA model ({name})",
    quant_method="int8",
)

In [None]:
# Int4 + LoRA variant: one epoch of plain LoRA training, then the int4 report() run
# (which trains for another epoch on the quantized layers before evaluating).
print(f"\n Training Int4 Quantized and Lora model on {name} dataset")
model_int4_lora = patch_model(deepcopy(base_model), config={"lora": True})
train_model(model_int4_lora, train_dl, epochs=1, quant_method=None)
df_lora_int4, int4_lora_eval = report(
    model_int4_lora,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Int4 Quantized LoRA model ({name})",
    quant_method="int4",
)

In [None]:
# Int8 + pruned + LoRA variant, built in stages with one training epoch after each:
# prune -> train -> attach LoRA -> train -> int8 report() run.
print(f"\n Training Int8 Pruned LoRA model on {name} dataset")
model_int8_pruned_lora = patch_model(deepcopy(base_model), config={"prune": True})
train_model(model_int8_pruned_lora, train_dl, epochs=1, quant_method=None)
model_int8_pruned_lora = patch_model(model_int8_pruned_lora, config={"lora": True})
train_model(model_int8_pruned_lora, train_dl, epochs=1, quant_method=None)
df_int8_pruned_lora, int8_pruned_lora_eval = report(
    model_int8_pruned_lora,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Int8 Pruned LoRA model ({name})",
    quant_method="int8",
)

In [None]:
# Int4 + pruned + LoRA variant, built in stages with one training epoch after each:
# prune -> train -> attach LoRA -> train -> int4 report() run.
print(f"\n Training Int4 Pruned LoRA model on {name} dataset")
model_int4_pruned_lora = patch_model(deepcopy(base_model), config={"prune": True})
train_model(model_int4_pruned_lora, train_dl, epochs=1, quant_method=None)
model_int4_pruned_lora = patch_model(model_int4_pruned_lora, config={"lora": True})
train_model(model_int4_pruned_lora, train_dl, epochs=1, quant_method=None)
df_int4_pruned_lora, int4_pruned_lora_eval = report(
    model_int4_pruned_lora,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"Int4 Pruned LoRA model ({name})",
    quant_method="int4",
)

In [None]:
# NF4 variant: report() handles the quantization of a fresh copy of the base model.
print(f"\n Training NF4 Quantized model on {name} dataset")
model_nf4 = deepcopy(base_model)
df_nf4, nf4_eval = report(
    model_nf4,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"NF4 Quantized model ({name})",
    quant_method="nf4",
)

In [None]:
# NF4 + LoRA variant: one epoch of plain LoRA training, then the nf4 report() run
# (which trains for another epoch on the quantized layers before evaluating).
print(f"\n Training NF4 Quantized and Lora model on {name} dataset")
model_nf4_lora = patch_model(deepcopy(base_model), config={"lora": True})
train_model(model_nf4_lora, train_dl, epochs=1, quant_method=None)
df_lora_nf4, nf4_lora_eval = report(
    model_nf4_lora,
    train_dl,
    val_dl,
    tokenizer,
    epochs=1,
    variant_name=f"NF4 Quantized and LoRA model ({name})",
    quant_method="nf4",
)

In [None]:
# Last cell: call google.colab's runtime.unassign() once every experiment above has run.
# NOTE(review): presumably this releases the Colab runtime (and everything held in memory);
# confirm that all results needed afterwards have been written to disk first.
runtime.unassign()