In [None]:
# ✨ SECTION 0 – Runtime prep  (GPU runtime → “T4” or “A100”)
# ----------------------------------------------------------
# !pip install -q torch==2.2.1 torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
# !pip install -q transformers==4.41.2 datasets evaluate accelerate peft==0.10.0 bitsandbytes safetensors
# !pip install -q git+https://github.com/huggingface/nn_pruning.git          # pruning
# !pip install -q "optimum[onnxruntime-gpu]"                                 # quantisation
# !pip install -q git+https://github.com/VainF/Torch-Pruning.git        # ⬅ structured pruning
# GPU PyTorch: torch is pinned to 2.2.1 and all three wheels come from the CUDA 11.8 index
!pip install -q torch==2.2.1 torchvision torchaudio \
    --index-url https://download.pytorch.org/whl/cu118

# HF Transformers + Datasets + Eval + Accelerate + BitsAndBytes + Safetensors + Diffusers + PEFT pin
# (only transformers and peft carry an explicit version; the other packages are unpinned)
!pip install -q transformers==4.41.2 datasets evaluate accelerate \
    bitsandbytes safetensors diffusers peft==0.10.0

# Pruning libraries, installed straight from each repository's default branch (no tag/commit pin)
!pip install -q git+https://github.com/huggingface/nn_pruning.git
!pip install -q git+https://github.com/VainF/Torch-Pruning.git

  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [None]:
# ✨ SECTION 1 – Dataset & helper utilities
# -------------------------------------------------
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
import torch
import time

# GLUE SST-2 splits; this name is re-bound further down once tokenization is applied.
DATASET = load_dataset("glue", "sst2")
# bert-base-uncased tokenizer; the tokenized data built with it feeds every model in the notebook.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

def tokenize(batch):
    """Tokenize the ``sentence`` field of a batch.

    Every example is truncated and padded to exactly 128 tokens so that all
    rows end up with the same length.
    """
    encode_options = dict(truncation=True, padding="max_length", max_length=128)
    sentences = batch["sentence"]
    return tokenizer(sentences, **encode_options)

# tokenize & keep only what we need
DATASET = DATASET.map(tokenize, batched=True)
# The training/eval helpers read batch["labels"], hence the rename from "label".
DATASET = DATASET.rename_column("label", "labels")
# Columns that dl() selects before handing the split to a DataLoader.
cols    = ["input_ids", "attention_mask", "labels"]

def dl(split, bs=64):
    """Build a PyTorch DataLoader over one split of the tokenized dataset.

    Only the columns listed in ``cols`` are kept and returned as torch
    tensors. Shuffling is enabled for the ``"train"`` split only.
    """
    torch_view = DATASET[split].select_columns(cols).with_format("torch")
    is_train = split == "train"
    return torch.utils.data.DataLoader(torch_view, batch_size=bs, shuffle=is_train)

# metrics
import evaluate
# Module-level accuracy metric object; compute_acc() below does not read it.
metric = evaluate.load("accuracy")
def compute_acc(model, loader, device="cuda"):
    """Return the classification accuracy of ``model`` over every batch of ``loader``.

    Args:
        model: callable module taking ``input_ids`` / ``attention_mask`` keyword
            tensors and returning an object with a ``.logits`` attribute. It is
            switched to eval mode and left there.
        loader: iterable of dict-like batches holding ``input_ids``,
            ``attention_mask`` and the targets under ``"labels"`` (or ``"label"``).
        device: device the batch tensors are moved to before the forward pass;
            the model is expected to live there already.

    Returns:
        float: fraction of examples whose arg-max logit equals the target.

    Raises:
        ValueError: if ``loader`` yields no examples.

    Accuracy is accumulated with plain tensor ops rather than by re-loading an
    ``evaluate`` metric on every call: callers time this function to report
    latency, so metric loading and tensor→list conversions must not be part of
    the measured span.
    """
    model.eval()
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for batch in loader:
            # Accept either column name; the inputs are built the same way for both.
            label_key = "labels" if "labels" in batch else "label"
            targets = batch[label_key].to(device)
            inputs = {k: batch[k].to(device) for k in ("input_ids", "attention_mask")}

            preds = model(**inputs).logits.argmax(-1)

            n_correct += (preds == targets).sum().item()
            n_seen += targets.numel()

    if n_seen == 0:
        raise ValueError("compute_acc() received a loader with no examples")
    return n_correct / n_seen


def param_stats(model):
    """Count the parameters of ``model``.

    Returns:
        tuple: ``(total, nonzero)`` — the number of parameter elements and how
        many of them compare unequal to zero.
    """
    n_params = 0
    n_nonzero = 0
    for tensor in model.parameters():
        n_params += tensor.numel()
        n_nonzero += (tensor != 0).sum().item()
    return n_params, n_nonzero

def model_size_mb(model):
    """Return the storage taken by the parameters of ``model`` in MiB.

    Only ``model.parameters()`` are counted, each at its own element size.
    """
    total_bytes = 0
    for tensor in model.parameters():
        total_bytes += tensor.numel() * tensor.element_size()
    return total_bytes / (1024**2)

def train_one_epoch(model, loader, device="cuda", lr=5e-5):
    """Fine-tune ``model`` for a single pass over ``loader``.

    A fresh AdamW optimizer and a linear warmup/decay schedule (warmup over the
    first 10 % of the steps) are created on every call. Batches must provide
    ``input_ids``, ``attention_mask`` and ``labels``; the loss is cross-entropy
    on the model's logits. Nothing is returned — the model is updated in place
    and left in training mode.
    """
    model.train()
    n_steps = len(loader)
    optimizer = AdamW(model.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * n_steps),
        num_training_steps=n_steps,
    )
    criterion = torch.nn.CrossEntropyLoss()
    feature_keys = ("input_ids", "attention_mask")

    for batch in loader:
        optimizer.zero_grad(set_to_none=True)
        targets = batch["labels"].to(device)
        features = {key: batch[key].to(device) for key in feature_keys}
        batch_loss = criterion(model(**features).logits, targets)
        batch_loss.backward()
        optimizer.step()
        scheduler.step()


# Device used by every experiment below (the notebook targets a GPU runtime).
device = "cuda"
# Shared loaders: dl() shuffles the training split only.
train_loader = dl("train",  bs=128)
val_loader   = dl("validation", bs=128)

# ===============================================================
# ✨ SECTION 1.3 – Structured channel-pruning with torch_pruning
# ===============================================================
import torch_pruning as tp
from torch import nn
import gc
import pandas as pd
# ── helper: wrap the HF model so DG sees tensor args, not dict ──
class BertWrapper(nn.Module):
    """Adapter that exposes a wrapped model through two tensor arguments.

    ``forward`` passes ``input_ids`` and ``attention_mask`` to the wrapped
    module as keywords and returns only the ``.logits`` of its output. The
    wrapped module stays reachable as ``self.core``.
    """

    def __init__(self, core):
        super().__init__()
        self.core = core

    def forward(self, input_ids, attention_mask):
        outputs = self.core(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits
def prune_once(model, sparsity: float, target_out_features: int = 3072):
    """
    Structured L1 channel-pruning of the wide feed-forward Linear layers only.

    Every ``nn.Linear`` whose ``out_features`` equals ``target_out_features``
    (3072 by default) loses the output channels with the smallest L1 weight
    norm. Removal goes through a torch_pruning pruning group obtained from a
    dependency graph of the model. The original author's intent is that
    attention / residual sizes stay intact so no shape errors appear.

    Args:
        model: sequence classifier accepting ``input_ids`` / ``attention_mask``
            keywords and returning an object with ``.logits``. Pruned in place.
        sparsity: fraction of output channels to remove, ``0 <= sparsity < 1``.
        target_out_features: width identifying the layers to prune.

    Returns:
        tuple: ``(model, pruned_layers)`` — the same model object and the number
        of Linear layers that actually lost channels.

    Raises:
        ValueError: if ``sparsity`` is outside ``[0, 1)``.
    """
    # Validate first: sparsity == 1 would ask for every channel to be removed.
    if not 0.0 <= sparsity < 1.0:
        raise ValueError(f"sparsity must be in [0, 1), got {sparsity!r}")

    wrapped = BertWrapper(model)              # so DG sees tensors not dicts
    # Trace on whatever device the model lives on instead of the global default.
    model_device = next(model.parameters()).device
    example = (torch.randint(0, tokenizer.vocab_size, (1, 128)).to(model_device),
               torch.ones(1, 128, dtype=torch.long).to(model_device))
    DG = tp.DependencyGraph().build_dependency(wrapped,
                                               example_inputs=example)
    pruned_layers = 0
    for m in wrapped.modules():
        if not (isinstance(m, nn.Linear) and m.out_features == target_out_features):
            continue
        w = m.weight.detach()
        n_out = w.size(0)
        n_drop = n_out - int(n_out * (1 - sparsity))
        if n_drop == 0:
            continue                          # nothing to prune
        # Row-wise L1 norm, ascending: the first n_drop rows are the weakest channels.
        idx = torch.argsort(w.abs().sum(dim=1))[:n_drop]
        grp = DG.get_pruning_group(m,
                tp.prune_linear_out_channels, idxs=idx.tolist())
        grp.prune()
        pruned_layers += 1
    return wrapped.core, pruned_layers

train-00000-of-00001.parquet:   0%|          | 0.00/3.11M [00:00<?, ?B/s]

validation-00000-of-00001.parquet:   0%|          | 0.00/72.8k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/148k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]



config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

Downloading builder script:   0%|          | 0.00/4.20k [00:00<?, ?B/s]

In [None]:
# ✨ SECTION 2 – Baseline model evaluation
# -------------------------------------------------

# Unpruned bert-base-uncased with a fresh 2-class head, fine-tuned for one epoch.
baseline = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2).to(device)
train_one_epoch(baseline, train_loader, device)         # ⬅ fine-tune

# t_infer is the wall-clock time of one full compute_acc() pass over val_loader,
# not a per-batch or per-example latency.
t0=time.time(); base_acc = compute_acc(baseline, val_loader, device); t_infer=time.time()-t0
base_total, base_nz = param_stats(baseline)
print(f"Baseline fp32 acc={base_acc:.4f}  time={t_infer*1000:.1f}ms  params={base_total/1e6:.1f}M")

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Baseline fp32 acc=0.9209  time=1479.9ms  params=109.5M


In [None]:

# ===============================================================
# ✨ SECTION 3 – Experiment over several sparsity levels
# ===============================================================
# One result row per sparsity level from 0.0 to 0.9; at 0.0 prune_once() removes
# nothing, so that row serves as the unpruned control.
results, sparsities = [], [x/10 for x in range(10)]    # feel free to edit

for sp in sparsities:
    # NOTE(review): after the first iteration `model` still references the
    # previous network at this point, so its memory can only be released once
    # the name is re-bound below.
    torch.cuda.empty_cache(); gc.collect()
    # model = AutoModelForSequenceClassification.from_pretrained(
    #         "bert-base-uncased", num_labels=2,
    #         torch_dtype=torch.float16).to(device)
    # Every level starts again from the pretrained checkpoint.
    model = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=2).to(device)

    model, pruned_layers_count = prune_once(model, sparsity=sp)    # ➊ PRUNE
    train_one_epoch(model, train_loader, device)  # ⬅ same 1-epoch fine-tune

    # Timed span = one full validation pass, closed by an explicit CUDA sync.
    t0=time.time(); acc=compute_acc(model,val_loader, "cuda"); torch.cuda.synchronize()
    infer=time.time()-t0
    tot,nz = param_stats(model)
    sz_mb = model_size_mb(model)
    results.append(dict(sparsity=sp,
                        acc=acc,
                        latency_ms=infer*1000,
                        nonzero_params_M=nz/1e6,
                        pruned_layers = pruned_layers_count,
                        size_MB=round(sz_mb,1)))

# Collect the per-level dicts into one table and print it as Markdown.
df = pd.DataFrame(results)
print(df.to_markdown(index=False))

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are

|   sparsity |      acc |   latency_ms |   nonzero_params_M |   pruned_layers |   size_MB |
|-----------:|---------:|-------------:|-------------------:|----------------:|----------:|
|        0   | 0.922018 |      2397.34 |           109.484  |               0 |     417.6 |
|        0.1 | 0.912844 |      2358.99 |           103.803  |              12 |     396   |
|        0.2 | 0.91055  |      2283.64 |            98.1407 |              12 |     374.4 |
|        0.3 | 0.896789 |      2214.99 |            92.4784 |              12 |     352.8 |
|        0.4 | 0.905963 |      2104.46 |            86.8161 |              12 |     331.2 |
|        0.5 | 0.893349 |      2078.77 |            81.1538 |              12 |     309.6 |
|        0.6 | 0.887615 |      1975.46 |            75.473  |              12 |     287.9 |
|        0.7 | 0.875    |      2044.05 |            69.8107 |              12 |     266.3 |
|        0.8 | 0.850917 |      1823.73 |            64.1484 |              12 | 

In [None]:
# ✨ SECTION 4 – Prune-then-quantize experiment (adds size counting)
# -----------------------------------------------------------------
import copy, torch.nn.quantized.dynamic as nnqd
# Validation loader (batch size 256) used for the CPU evaluation of the INT8 models.
val_loader_cpu   = dl("validation", bs=256)
# Sparsity levels 0.0 … 0.9 for the prune-then-quantize sweep.
quant_sparsities = [x/10 for x in range(10)]
# One dict per sparsity level is appended by the loop below.
qresults         = []

def param_stats_fp32(model):
    """Return ``(total params, non-zero params, size in MiB)`` for a float model.

    Thin combination of ``param_stats`` and ``model_size_mb``.
    """
    total, nonzero = param_stats(model)
    size_mb = model_size_mb(model)
    return total, nonzero, size_mb

def param_stats_int8(qmodel):
    """Count parameters and estimate the size of a dynamically quantized model.

    Dynamic-quantized Linear modules contribute their weight at 1 byte per
    element and their bias (when present) at 4 bytes per element. Every other
    module contributes its own direct parameters at their real element size.

    Returns:
        tuple: ``(total elements, non-zero elements, size in MiB)``.
    """
    n_total = 0
    n_nonzero = 0
    n_bytes = 0
    for module in qmodel.modules():
        if isinstance(module, nnqd.Linear):
            # (tensor, bytes per element) pairs for the quantized layer
            sized = [(module.weight(), 1)]          # 1-byte/elt
            bias = module.bias()
            if bias is not None:
                sized.append((bias, 4))             # fp32 bias
        else:
            sized = [(p, p.element_size())
                     for p in module.parameters(recurse=False)]
        for tensor, elem_bytes in sized:
            count = tensor.numel()
            n_total += count
            n_nonzero += (tensor != 0).sum().item()
            n_bytes += count * elem_bytes
    return n_total, n_nonzero, n_bytes / (1024**2)

for sp in quant_sparsities:
    torch.cuda.empty_cache(); gc.collect()
    # Each level restarts from the pretrained checkpoint.
    model = AutoModelForSequenceClassification.from_pretrained(
                "bert-base-uncased", num_labels=2).to(device)
    # sp == 0 is falsy, so the control run skips pruning entirely.
    if sp:
        model, pruned_layers_count = prune_once(model, sp)
    train_one_epoch(model, train_loader, device)
    fp_tot, fp_nz, fp_mb = param_stats_fp32(model)
    fp_acc = compute_acc(model, val_loader, device)

    # Quantize a CPU copy so the fine-tuned GPU model itself is left untouched.
    qmodel = torch.quantization.quantize_dynamic(
                 copy.deepcopy(model).cpu(), {torch.nn.Linear},
                 dtype=torch.qint8)
    # int8_ms covers one whole CPU validation pass (batch size 256).
    t0=time.time(); int8_acc=compute_acc(qmodel, val_loader_cpu,"cpu")
    int8_ms=(time.time()-t0)*1000
    in_tot,in_nz,in_mb = param_stats_int8(qmodel)

    # The *_params_M columns report NON-ZERO element counts, in millions.
    qresults.append(dict(
        sparsity=sp,
        fp32_acc=round(fp_acc,4),
        int8_acc=round(int8_acc,4),
        int8_latency_ms=round(int8_ms,1),
        fp32_params_M=round(fp_nz/1e6,3),
        fp32_size_MB=round(fp_mb,1),
        int8_params_M=round(in_nz/1e6,3),
        int8_size_MB=round(in_mb,1)
    ))

print(pd.DataFrame(qresults).to_markdown(index=False))

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are

|   sparsity |   fp32_acc |   int8_acc |   int8_latency_ms |   fp32_params_M |   fp32_size_MB |   int8_params_M |   int8_size_MB |
|-----------:|-----------:|-----------:|------------------:|----------------:|---------------:|----------------:|---------------:|
|        0   |     0.9209 |     0.8601 |           35556.5 |         109.484 |          417.6 |          97.499 |          173   |
|        0.1 |     0.9048 |     0.8429 |           32841.8 |         103.803 |          396   |          93.095 |          167.5 |
|        0.2 |     0.8922 |     0.8567 |           32888   |          98.141 |          374.4 |          88.664 |          162.1 |
|        0.3 |     0.9002 |     0.8853 |           31714.7 |          92.478 |          352.8 |          84.186 |          156.7 |
|        0.4 |     0.8922 |     0.883  |           30061.6 |          86.816 |          331.2 |          79.678 |          151.3 |
|        0.5 |     0.8933 |     0.8853 |           28419.4 |          81.154 |     

In [None]:
# ===============================================================
# ✨ SECTION 5 – Knowledge-Distillation experiment
# ---------------------------------------------------------------
from transformers import DistilBertForSequenceClassification
import torch.nn.functional as F

T      = 4.0      # temperature dividing both teacher and student logits in the KD term
alpha  = 0.5      # weight of the KD term; (1 - alpha) goes to the hard-label CE term
epochs = 3        # student KD fine-tuning epochs

# ➊ train / load the teacher (full BERT, no pruning/quant yet)
teacher = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=2).to(device)
train_one_epoch(teacher, train_loader, device)          # reuse helper

# ➋ build the smaller student
student = DistilBertForSequenceClassification.from_pretrained(
            "distilbert-base-uncased", num_labels=2).to(device)

def kd_train_one_epoch(student, teacher, loader,
                       T=4.0, alpha=0.5, lr=3e-5):
    """Run one epoch of knowledge distillation from ``teacher`` into ``student``.

    The loss for each batch is
    ``alpha * KL(student/T || teacher/T) * T**2 + (1 - alpha) * CE(student, labels)``.
    A new AdamW optimizer and linear warmup/decay schedule (10 % warmup) are
    built on every call. Batches are moved to the module-level ``device``.
    The student is updated in place; nothing is returned.
    """
    student.train()
    teacher.eval()

    n_steps = len(loader)
    optimizer = AdamW(student.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * n_steps),
        num_training_steps=n_steps,
    )
    hard_label_loss = torch.nn.CrossEntropyLoss()
    feature_keys = ("input_ids", "attention_mask")

    for batch in loader:
        optimizer.zero_grad(set_to_none=True)
        targets = batch["labels"].to(device)
        features = {key: batch[key].to(device) for key in feature_keys}

        # The teacher only supplies soft targets, so no graph is recorded for it.
        with torch.no_grad():
            teacher_scaled = teacher(**features).logits / T

        student_logits = student(**features).logits
        student_scaled = student_logits / T

        # soft-target KL term, rescaled by T**2, plus hard-label cross-entropy
        soft_term = F.kl_div(
            F.log_softmax(student_scaled, dim=-1),
            F.softmax(teacher_scaled, dim=-1),
            reduction="batchmean",
        ) * (T**2)
        hard_term = hard_label_loss(student_logits, targets)
        combined = alpha * soft_term + (1 - alpha) * hard_term

        combined.backward()
        optimizer.step()
        scheduler.step()

# ➌ knowledge-distill
# Each call builds a fresh optimizer and schedule, so warmup/decay restarts every epoch.
for _ in range(epochs):
    kd_train_one_epoch(student, teacher, train_loader, T, alpha)

# ➍ evaluate & print stats
acc     = compute_acc(student, val_loader, device)
tot,nz  = param_stats(student)
# "params" in the message is the non-zero count (nz), not the total.
print(f"Student acc={acc:.4f} | params={nz/1e6:.1f} M | size={model_size_mb(student):.1f} MB")


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Student acc=0.9083 | params=67.0 M | size=255.4 MB


In [None]:
# ===============================================================
# ✨ SECTION 6 – KD student: prune 0.4 → quantize INT8
# ---------------------------------------------------------------
import copy, torch.nn.quantized.dynamic as nnqd

# Fraction of feed-forward channels removed by prune_once() below.
sparsity = 0.4
# ➊ copy KD-trained student
student_p = copy.deepcopy(student).to(device)

# ➋ structured prune
student_p, pruned_layers_cnt = prune_once(student_p, sparsity)
# one recovery epoch after pruning
train_one_epoch(student_p, train_loader, device)

# ➌ evaluate pruned FP32
fp32_acc  = compute_acc(student_p, val_loader, device)
fp_tot, fp_nz = param_stats(student_p)
fp_size   = model_size_mb(student_p)
# NOTE: the "40%" label is hard-coded and does not follow `sparsity`.
print(f"Pruned 40%  | acc {fp32_acc:.4f} | nz {fp_nz/1e6:.2f} M | "
      f"size {fp_size:.1f} MB | layers {pruned_layers_cnt}")

# ➍ dynamic INT8 quantization
# (applied to a CPU copy; student_p itself stays in FP32 on the GPU)
student_int8 = torch.quantization.quantize_dynamic(
                 copy.deepcopy(student_p).cpu(), {torch.nn.Linear},
                 dtype=torch.qint8)

# ➎ evaluate INT8
val_loader_cpu = dl("validation", bs=256)
int8_acc = compute_acc(student_int8, val_loader_cpu, device="cpu")

def int8_stats(qm):
    """Return ``(non-zero element count, size in MiB)`` for a dynamically quantized model.

    Quantized Linear weights are sized at 1 byte per element and their biases
    at 4 bytes per element; direct parameters of all other modules are sized
    at their own element size.
    """
    nonzero_count = 0
    byte_count = 0

    def _account(tensor, bytes_per_elem):
        nonlocal nonzero_count, byte_count
        nonzero_count += (tensor != 0).sum().item()
        byte_count += tensor.numel() * bytes_per_elem

    for module in qm.modules():
        if not isinstance(module, nnqd.Linear):
            for param in module.parameters(recurse=False):
                _account(param, param.element_size())
            continue
        _account(module.weight(), 1)
        bias = module.bias()
        if bias is not None:
            _account(bias, 4)

    return nonzero_count, byte_count / (1024**2)

# int8_stats() returns (non-zero element count, size in MiB).
in_nz, in_mb = int8_stats(student_int8)
print(f"INT8        | acc {int8_acc:.4f} | nz {in_nz/1e6:.2f} M | "
      f"size {in_mb:.1f} MB")




Pruned 40%  | acc 0.8979 | nz 55.62 M | size 212.2 MB | layers 6
INT8        | acc 0.8922 | nz 52.09 M | size 121.4 MB


In [None]:
# ===============================================================
# ✨ SECTION 7 – Low-Rank factorization of KD student
# ---------------------------------------------------------------
import copy, gc, torch
from torch import nn

rank_frac = 0.5          # rank = 50 % of min(in_features, out_features) for each factorized layer
device    = "cuda"

class LowRankLinear(nn.Module):
    """Linear layer expressed as the product of two thinner linear layers.

    ``A`` maps ``in_f → rank`` without a bias and ``B`` maps ``rank → out_f``
    with the optional bias; the forward pass computes ``B(A(x))``.
    """

    def __init__(self, in_f, out_f, rank, bias):
        super().__init__()
        # Down-projection, weight shape (rank, in_f)
        self.A = nn.Linear(in_features=in_f, out_features=rank, bias=False)
        # Up-projection, weight shape (out_f, rank); carries the bias if requested
        self.B = nn.Linear(in_features=rank, out_features=out_f, bias=bias)

    def forward(self, x):
        projected = self.A(x)
        return self.B(projected)

@torch.no_grad()
def _replace_linear(parent, name, linear, rank):
    """Swap ``parent.<name>`` for a rank-``rank`` SVD factorization of ``linear``.

    The weight ``W`` (out_f, in_f) is decomposed as ``U S Vh``; the top ``rank``
    components initialise ``A.weight = Vh[:rank]`` and
    ``B.weight = U[:, :rank] * S[:rank]``, and any bias is copied onto ``B``.
    The replacement is created on the same device and with the same dtype as
    the layer it replaces, so the function also works on models that are not
    on the CPU or not in float32.

    Args:
        parent: module that owns the layer as attribute ``name``.
        name: attribute name of the layer inside ``parent``.
        linear: the ``nn.Linear`` being replaced.
        rank: number of singular components to keep,
            ``1 <= rank <= min(in_features, out_features)``.

    Raises:
        ValueError: if ``rank`` is outside the valid range.
    """
    max_rank = min(linear.in_features, linear.out_features)
    if not 1 <= rank <= max_rank:
        raise ValueError(f"rank must be in [1, {max_rank}], got {rank!r}")

    src_weight = linear.weight.detach()
    W = src_weight.cpu()                              # (out_f, in_f)
    if W.dtype in (torch.float16, torch.bfloat16):
        # torch.linalg.svd is documented for float/double (and complex) inputs only
        W = W.float()
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    U_r, S_r, Vh_r = U[:, :rank], S[:rank], Vh[:rank, :]

    has_bias = linear.bias is not None
    # Match the SVD dtype before copying so float64 factors are not truncated.
    new_layer = LowRankLinear(linear.in_features,
                              linear.out_features,
                              rank, bias=has_bias).to(dtype=W.dtype)
    # A weight:  (rank, in_f)
    new_layer.A.weight.copy_(Vh_r)
    # B weight:  (out_f, rank) = U_r * S_r
    new_layer.B.weight.copy_(U_r * S_r.unsqueeze(0))
    if has_bias:
        new_layer.B.bias.copy_(linear.bias)
    # Put the replacement where the original layer lived.
    new_layer.to(device=src_weight.device, dtype=src_weight.dtype)
    setattr(parent, name, new_layer)

def factorize_module(module, rank_frac=0.25):
    """Recursively replace trainable ``nn.Linear`` children by low-rank versions.

    For each trainable Linear the target rank is
    ``max(1, int(min(in_features, out_features) * rank_frac))``; the layer is
    replaced only when that rank is strictly below the full rank. Every other
    child is descended into. ``module`` is modified in place.
    """
    for child_name, child in list(module.named_children()):
        is_trainable_linear = isinstance(child, nn.Linear) and child.weight.requires_grad
        if not is_trainable_linear:
            factorize_module(child, rank_frac)
            continue
        full_rank = min(child.in_features, child.out_features)
        target_rank = max(1, int(full_rank * rank_frac))
        if target_rank < full_rank:
            _replace_linear(module, child_name, child, target_rank)

def low_rank_student(base_model, rank_frac=0.25):
    """Return a low-rank factorized copy of ``base_model`` on the global ``device``.

    ``base_model`` itself is not modified: a deep copy is moved to the CPU,
    factorized there with ``factorize_module`` and then moved to ``device``.
    """
    cpu_clone = copy.deepcopy(base_model)
    cpu_clone = cpu_clone.cpu()
    factorize_module(cpu_clone, rank_frac)
    factorized = cpu_clone.to(device)
    return factorized

torch.cuda.empty_cache(); gc.collect()
# Factorize a copy of the KD student; `student` itself is left unchanged.
student_lr = low_rank_student(student, rank_frac)

# optional recovery fine-tune
train_one_epoch(student_lr, train_loader, device, lr=3e-5)

# evaluation
acc      = compute_acc(student_lr, val_loader, device)
tot, nz  = param_stats(student_lr)
# "params" in the message is the non-zero count.
print(f"Low-Rank {rank_frac:.2f} | acc {acc:.4f} | "
      f"params {nz/1e6:.2f} M | size {model_size_mb(student_lr):.1f} MB")






Low-Rank 0.50 | acc 0.9002 | params 56.34 M | size 214.9 MB


In [None]:
# Low-rank → INT8.
# The quantized `student_int8` cannot be the starting point here: its layers are
# dynamic-quantized modules rather than nn.Linear, so factorize_module() would
# skip them, and dynamic INT8 models are CPU inference modules that cannot be
# moved to the GPU and fine-tuned. Instead, quantize the float student that was
# already factorized and recovery-fine-tuned above (`student_lr`).
gc.collect(); torch.cuda.empty_cache()
student_lr_int8 = torch.quantization.quantize_dynamic(
                    copy.deepcopy(student_lr).cpu(), {nn.Linear},
                    dtype=torch.qint8)

# evaluation (on CPU, with the CPU validation loader)
acc             = compute_acc(student_lr_int8, val_loader_cpu, device="cpu")
tot, nz, int8_mb = param_stats_int8(student_lr_int8)
print(f"Low-Rank {rank_frac:.2f} + INT8 | acc {acc:.4f} | "
      f"params {nz/1e6:.2f} M | size {int8_mb:.1f} MB")



Low-Rank 0.50 | acc 0.9048 | params 56.34 M | size 214.9 MB


In [None]:
# ===============================================================
# ✨ SECTION 8 – Distill → prune → low-rank factorization → quantize
# ---------------------------------------------------------------
from transformers import DistilBertForSequenceClassification
import torch.nn.functional as F
import copy, gc, torch.nn.quantized.dynamic as nnqd
from torch import nn

# ---------- 1.  Knowledge-distil a student ---------------------
# T: softmax temperature, alpha: weight of the KD loss term, epochs: KD passes over the data
T, alpha, epochs = 4.0, 0.5, 3             # hyper-params

# Teacher: bert-base-uncased with a fresh 2-class head.
teacher = AutoModelForSequenceClassification.from_pretrained(
            "bert-base-uncased", num_labels=2).to(device)
train_one_epoch(teacher, train_loader, device, lr=2e-5)   # single-epoch FT

# Student: distilbert-base-uncased, trained below by distillation.
student = DistilBertForSequenceClassification.from_pretrained(
            "distilbert-base-uncased", num_labels=2).to(device)

def kd_train_one_epoch(s, t, loader, T=4.0, a=0.5, lr=3e-5):
    """Distil teacher ``t`` into student ``s`` for one pass over ``loader``.

    Per-batch loss: ``a * KL(s/T || t/T) * T**2 + (1 - a) * CE(s, labels)``.
    The optimizer (AdamW) and the linear warmup/decay schedule are re-created
    on each call; tensors are moved to the module-level ``device``.
    """
    s.train()
    t.eval()

    total_steps = len(loader)
    optimizer = AdamW(s.parameters(), lr=lr)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=int(0.1 * total_steps),
        num_training_steps=total_steps,
    )
    cross_entropy = nn.CrossEntropyLoss()

    for batch in loader:
        optimizer.zero_grad(set_to_none=True)
        targets = batch["labels"].to(device)
        features = {}
        for key in ("input_ids", "attention_mask"):
            features[key] = batch[key].to(device)

        # teacher logits are targets only — computed without gradient tracking
        with torch.no_grad():
            teacher_scaled = t(**features).logits / T
        student_logits = s(**features).logits

        distill_term = F.kl_div(
            F.log_softmax(student_logits / T, dim=-1),
            F.softmax(teacher_scaled, dim=-1),
            reduction="batchmean",
        ) * (T**2)
        label_term = cross_entropy(student_logits, targets)

        objective = a * distill_term + (1 - a) * label_term
        objective.backward()
        optimizer.step()
        scheduler.step()

# Each call re-creates optimizer and schedule, so warmup/decay restarts every epoch.
for _ in range(epochs):
    kd_train_one_epoch(student, teacher, train_loader, T, alpha)

# NOTE: this re-binds `base_acc`, which previously held the Section 2 baseline accuracy.
base_acc = compute_acc(student, val_loader, device)
print(f"[KD]    acc={base_acc:.4f} | params={param_stats(student)[1]/1e6:.2f} M "
      f"| size={model_size_mb(student):.1f} MB")

# ---------- 2.  Structured pruning (40 % channels) -------------
# Prune a copy so the distilled `student` stays available unmodified.
student_p, _ = prune_once(copy.deepcopy(student).to(device), sparsity=0.4)
train_one_epoch(student_p, train_loader, device, lr=3e-5)
pruned_acc = compute_acc(student_p, val_loader, device)
print(f"[PRUNE] acc={pruned_acc:.4f} | params={param_stats(student_p)[1]/1e6:.2f} M "
      f"| size={model_size_mb(student_p):.1f} MB")

# ---------- 3.  Low-rank factorisation (rank = 50 % min dim) ---
class LowRankLinear(nn.Module):
    """Rank-``r`` replacement for a linear layer, computed as ``B(A(x))``.

    ``A`` projects ``in_f → r`` (no bias); ``B`` projects ``r → out_f`` and
    holds the bias when ``bias`` is true.
    """

    def __init__(self, in_f, out_f, r, bias):
        super().__init__()
        self.A = nn.Linear(in_features=in_f, out_features=r, bias=False)
        self.B = nn.Linear(in_features=r, out_features=out_f, bias=bias)

    def forward(self, x):
        low_dim = self.A(x)
        return self.B(low_dim)

@torch.no_grad()
def _replace_linear(parent, name, lin, r):
    """Replace ``parent.<name>`` with a rank-``r`` SVD factorization of ``lin``.

    ``lin.weight`` is decomposed as ``U S Vh``; ``A.weight`` receives
    ``Vh[:r]`` and ``B.weight`` receives ``U[:, :r] * S[:r]``. The bias, if any,
    is copied to ``B``. The new layer is placed on the device and cast to the
    dtype of the layer it replaces.

    Args:
        parent: module owning the layer under attribute ``name``.
        name: attribute name of the layer in ``parent``.
        lin: the ``nn.Linear`` to factorize.
        r: rank to keep, ``1 <= r <= min(in_features, out_features)``.

    Raises:
        ValueError: if ``r`` is outside the valid range.
    """
    max_rank = min(lin.in_features, lin.out_features)
    if not 1 <= r <= max_rank:
        raise ValueError(f"rank must be in [1, {max_rank}], got {r!r}")

    src_weight = lin.weight.detach()
    W = src_weight.cpu()
    if W.dtype in (torch.float16, torch.bfloat16):
        # torch.linalg.svd is documented for float/double (and complex) inputs only
        W = W.float()
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)

    has_bias = lin.bias is not None
    # Match the SVD dtype before copying so float64 factors are not truncated.
    new = LowRankLinear(lin.in_features, lin.out_features, r,
                        bias=has_bias).to(dtype=W.dtype)
    new.A.weight.copy_(Vh[:r, :])
    new.B.weight.copy_(U[:, :r] * S[:r].unsqueeze(0))
    if has_bias:
        new.B.bias.copy_(lin.bias)
    # Put the replacement where the original layer lived.
    new.to(device=src_weight.device, dtype=src_weight.dtype)
    setattr(parent, name, new)

def factorize(module, frac=0.5):
    """Walk ``module`` and low-rank-factorize its trainable ``nn.Linear`` children in place.

    The rank for a layer is ``max(1, int(min(in_features, out_features) * frac))``;
    layers whose computed rank is not below their full rank are left alone.
    Non-Linear (or frozen) children are recursed into.
    """
    for attr, sub in list(module.named_children()):
        if not (isinstance(sub, nn.Linear) and sub.weight.requires_grad):
            factorize(sub, frac)
            continue
        smallest_dim = min(sub.in_features, sub.out_features)
        rank = max(1, int(smallest_dim * frac))
        if rank < smallest_dim:
            _replace_linear(module, attr, sub, rank)

torch.cuda.empty_cache(); gc.collect()
# Factorize a CPU copy of the pruned student, then move it to the GPU to fine-tune.
student_lr = copy.deepcopy(student_p).cpu()
factorize(student_lr, frac=0.5)
student_lr = student_lr.to(device)
train_one_epoch(student_lr, train_loader, device, lr=3e-5)
lr_acc = compute_acc(student_lr, val_loader, device)
# The "[LoRa]" tag labels the low-rank factorization stage of this pipeline.
print(f"[LoRa]  acc={lr_acc:.4f} | params={param_stats(student_lr)[1]/1e6:.2f} M "
      f"| size={model_size_mb(student_lr):.1f} MB")

# ---------- 4.  Dynamic INT8 quantisation ----------------------
# Quantize a CPU copy of the factorized student; student_lr itself stays in float.
student_int8 = torch.quantization.quantize_dynamic(
                  copy.deepcopy(student_lr).cpu(), {nn.Linear},
                  dtype=torch.qint8)

val_loader_cpu = dl("validation", bs=256)
int8_acc = compute_acc(student_int8, val_loader_cpu, device="cpu")

def int8_stats(qm):
    """Return ``(non-zero elements in millions, size in MiB)`` for a dynamically quantized model.

    Quantized Linear weights count 1 byte per element and their biases 4 bytes
    per element; direct parameters of every other module count at their own
    element size.
    """
    def _sized_tensors(module):
        # Yield (tensor, bytes per element) for everything this module owns directly.
        if isinstance(module, nnqd.Linear):
            yield module.weight(), 1
            bias = module.bias()
            if bias is not None:
                yield bias, 4
        else:
            for param in module.parameters(recurse=False):
                yield param, param.element_size()

    nonzero = 0
    n_bytes = 0
    for module in qm.modules():
        for tensor, width in _sized_tensors(module):
            nonzero += (tensor != 0).sum().item()
            n_bytes += tensor.numel() * width
    return nonzero/1e6, n_bytes/(1024**2)

# This cell's int8_stats() already returns the non-zero count in millions.
nzM, sizeMB = int8_stats(student_int8)
print(f"[INT8]  acc={int8_acc:.4f} | params={nzM:.2f} M | size={sizeMB:.1f} MB")


Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


[KD]    acc=0.9083 | params=66.96 M | size=255.4 MB




[PRUNE] acc=0.8922 | params=55.62 M | size=212.2 MB




[LoRa]  acc=0.8911 | params=50.67 M | size=193.3 MB
[INT8]  acc=0.8865 | params=48.08 M | size=116.7 MB
