In [None]:
import os, time
import torch
import torch.nn.utils.prune as prune
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer, DataCollatorWithPadding
from datasets import load_dataset
from evaluate import load

In [None]:
# Paths to the distilled student checkpoint and its tokenizer (Colab Google Drive mount).
STUDENT_PATH = "./drive/MyDrive/bert_sst2_student/best_model"
STUDENT_TOKENIZER = "./drive/MyDrive/bert_sst2_student/tokenizer"
tokenizer = AutoTokenizer.from_pretrained(STUDENT_TOKENIZER)
student = AutoModelForSequenceClassification.from_pretrained(STUDENT_PATH)

device = "cuda" if torch.cuda.is_available() else "cpu"
student.to(device)

# dataset
task = "sst2"
raw = load_dataset("glue", task)
max_len = 128


def preprocess(examples):
    """Tokenize a batch of SST-2 sentences, truncating each to at most `max_len` tokens."""
    # Pass max_length explicitly: with truncation=True alone the tokenizer truncates at its own
    # model limit, so `max_len` would have no effect. Padding is left to the data collator.
    return tokenizer(examples["sentence"], truncation=True, max_length=max_len)


encoded = raw.map(preprocess, batched=True)
# Keep only model inputs and the label column.
encoded = encoded.remove_columns(["sentence", "idx"])
encoded.set_format(type="torch")
train_ds = encoded["train"]
val_ds = encoded["validation"]
data_collator = DataCollatorWithPadding(tokenizer)
metric = load("accuracy")

In [None]:
def count_parameters(model):
    """Count the elements of `model`'s parameters and how many of them are non-zero.

    While a `torch.nn.utils.prune` reparametrization is attached, the module's
    parameter is the dense `<name>_orig` tensor and the mask lives in a
    `<name>_mask` buffer. Counting raw parameters would then report ~0% sparsity,
    so for such parameters the effective tensor `<name>_orig * <name>_mask` is counted.
    Parameters shared between modules are counted once, as `model.parameters()` does.

    Args:
        model: a torch.nn.Module.

    Returns:
        (total, nonzero, sparsity): total number of elements, number of non-zero
        elements, and 1 - nonzero/total (0.0 for a model without parameters).
    """
    total = 0
    nonzero = 0
    seen = set()
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if id(param) in seen:
                continue
            seen.add(id(param))
            effective = param
            if name.endswith("_orig"):
                # A pruned parameter: apply its mask so pruned entries count as zero.
                mask = getattr(module, name[: -len("_orig")] + "_mask", None)
                if mask is not None:
                    effective = param.detach() * mask
            total += effective.numel()
            nonzero += int((effective != 0).sum().item())
    sparsity = 1 - (nonzero / total) if total else 0.0
    return total, nonzero, sparsity

def compute_metrics(eval_pred):
    """Trainer metric hook: accuracy of the argmax class against the reference labels.

    `eval_pred` is unpacked into (logits, labels); the score comes from the
    notebook-level `metric` object.
    """
    logits, labels = eval_pred
    predicted_classes = logits.argmax(-1)
    scores = metric.compute(predictions=predicted_classes, references=labels)
    return {"accuracy": scores["accuracy"]}

def evaluate_latency(model, tokenizer, device, max_len=128, N=200):
    """Measure the mean batch-size-1 forward latency of `model`, in milliseconds.

    A fixed sample sentence is tokenized and padded/truncated to `max_len` tokens.
    The model runs 10 untimed warm-up passes, then `N` timed passes, and the mean
    is returned. Side effect: `model` is moved to `device` and left in eval mode.

    Args:
        model: module called as model(input_ids=..., attention_mask=...).
        tokenizer: callable returning a mapping with "input_ids" / "attention_mask" tensors.
        device: "cpu", "cuda", "cuda:N" or a torch.device.
        max_len: sequence length the sample is padded/truncated to.
        N: number of timed iterations; must be >= 1.

    Returns:
        Mean wall-clock time per forward pass, in ms.

    Raises:
        ValueError: if N < 1.
    """
    if N < 1:
        raise ValueError(f"N must be >= 1, got {N}")
    model.to(device).eval()
    # Any CUDA device (including "cuda:0" / torch.device) must be synchronized before reading
    # the clock, because kernel launches return before the work finishes.
    is_cuda = torch.device(device).type == "cuda"
    sample = tokenizer("This is a sample sentence to measure latency.", return_tensors="pt", max_length=max_len, truncation=True, padding="max_length")
    input_ids = sample['input_ids'].to(device)
    attention_mask = sample['attention_mask'].to(device)
    with torch.no_grad():
        # warmup
        for _ in range(10):
            model(input_ids=input_ids, attention_mask=attention_mask)
        if is_cuda:
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution, unlike time.time().
        start = time.perf_counter()
        for _ in range(N):
            model(input_ids=input_ids, attention_mask=attention_mask)
        if is_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start
    return elapsed / N * 1000  # ms


In [None]:
## Unstructured pruning: global L1-magnitude pruning of Linear weights ##

In [None]:
# Pool the weight of every nn.Linear in the student and prune the globally smallest-|w| fraction.
# NOTE(review): this presumably includes the classification head too — confirm that is intended.
modules_to_prune = [
    (module, 'weight')
    for module in student.modules()
    if isinstance(module, torch.nn.Linear)
]

# Re-running this cell while masks are still attached would stack a second mask on the first
# and compound the sparsity; fail loudly instead of silently over-pruning.
if any(hasattr(module, 'weight_orig') for module, _ in modules_to_prune):
    raise RuntimeError("student already carries pruning masks; reload it before re-running this cell")

amount = 0.4  # 40% sparsity
prune.global_unstructured(
    modules_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=amount,
)

# Sparsity measured directly on the pruned tensors: after pruning, `module.weight` is the
# masked weight (weight_orig * weight_mask), so zeros here are exactly the pruned entries.
linear_total = sum(module.weight.numel() for module, _ in modules_to_prune)
linear_zeros = sum(int((module.weight == 0).sum()) for module, _ in modules_to_prune)
print(f"Pruned Linear weights: {linear_zeros:,}/{linear_total:,} zero ({linear_zeros / linear_total:.2%} sparsity)")

# check sparsity over every model parameter (not only the pruned Linear weights)
total, nonzero, sparsity = count_parameters(student)
print(f"Total params: {total:,}, nonzero: {nonzero:,}, global density: {1-sparsity:.3f}, sparsity ≈ {sparsity:.2%}")

In [None]:
# Fine-tune the pruned student. While the pruning reparametrization is attached, each forward
# pass recomputes weight = weight_orig * weight_mask, so pruned entries stay zero during training.
training_args = TrainingArguments(
    output_dir="./pruned_student",
    num_train_epochs=3,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    weight_decay=0.01,
    # fp16 mixed precision requires a GPU; use fp32 when the setup cell fell back to CPU.
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

trainer = Trainer(
    model=student,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()

In [None]:
# Remove reparam (makes zeros permanent): prune.remove folds weight_orig * weight_mask into a
# plain `weight` parameter and drops the mask buffer and the forward pre-hook.
for module, param_name in modules_to_prune:
    # Only parameters that still carry a reparametrization can be removed (prune.remove raises
    # ValueError otherwise, e.g. when this cell is re-run). Any other error now surfaces
    # instead of being silently swallowed.
    if hasattr(module, f"{param_name}_orig"):
        prune.remove(module, param_name)

# report
total, nonzero, sparsity = count_parameters(student)
print("After remove():", total, nonzero, sparsity)
print("Val eval:")
res = trainer.evaluate()
print(res)
print("Latency (ms):", evaluate_latency(student, tokenizer, device))

In [None]:
# Persist the fine-tuned 40%-sparse student (masks already folded into the weights).
UNSTRUCTURED_40_DIR = "./drive/MyDrive/pruned_student_unstructured/best_model_40"
trainer.save_model(UNSTRUCTURED_40_DIR)
# Recorded result from a previous run — Latency (ms): 5.5631959438323975

In [None]:
## Structured pruning: attention-head removal ##

In [None]:
# Check if method exists
# Confirm the model exposes `prune_heads` and a `distilbert` backbone before the cells below rely on them.
for attr_name in ("prune_heads", "distilbert"):
    print(hasattr(student, attr_name))
print(student.distilbert)  # inspect the backbone's submodules


In [None]:
# Head-pruning plan: in every layer, drop heads 0 .. n_heads//2 - 1 (the first half).
# Heads are selected by index only, not by any importance score.
n_layers = student.config.num_hidden_layers
n_heads = student.config.num_attention_heads
heads_to_prune = {layer_idx: list(range(n_heads // 2)) for layer_idx in range(n_layers)}

print("Heads to prune per layer:", heads_to_prune)

In [None]:
# Start the structured experiment from the original distilled student. Without this reload, when
# the cells above have run, `student` still carries the 40% unstructured sparsity and its
# fine-tuning. The head-pruning results would then measure both techniques combined, and this
# section would silently depend on the unstructured section having been run first.
student = AutoModelForSequenceClassification.from_pretrained(STUDENT_PATH)
student.to(device)
student.distilbert.prune_heads(heads_to_prune)

In [None]:
from transformers import TrainingArguments, Trainer, DataCollatorWithPadding

# Fine-tune after head pruning so the remaining heads can compensate for the removed ones.
training_args = TrainingArguments(
    output_dir="./head_pruned_student",
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,  # same evaluation batch size as the unstructured run
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=3e-5,
    weight_decay=0.01,
    # fp16 mixed precision requires a GPU; use fp32 when running on CPU.
    fp16=torch.cuda.is_available(),
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

trainer = Trainer(
    model=student,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics
)

trainer.train()
trainer.save_model("./drive/MyDrive/head_pruned_student/best_model")

In [None]:
# Example latency measurement
latency_ms = evaluate_latency(student, tokenizer, device)
print(f"Head-pruned student latency: {latency_ms:.2f} ms")

In [None]:
# Parameter totals after head pruning (pruned heads are physically removed from the weights).
total_params, nonzero_params, sparsity = count_parameters(student)
for label, count in (("Total params", total_params), ("Non-zero params", nonzero_params)):
    print(f"{label}: {count:,}")

In [None]:
import os

def folder_size_mb(path):
    """Return the total size of the files under `path`, in MiB (1024**2 bytes).

    Directories are walked recursively. If `path` is a single file, that file's
    size is returned.

    Raises:
        FileNotFoundError: if `path` does not exist. Previously a mistyped path
            silently reported 0.00 MB.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"No such file or directory: {path!r}")
    if os.path.isfile(path):
        return os.path.getsize(path) / (1024**2)
    total_bytes = 0
    for root, _dirs, files in os.walk(path):
        for file_name in files:
            total_bytes += os.path.getsize(os.path.join(root, file_name))
    return total_bytes / (1024**2)

# On-disk footprint of the saved head-pruned checkpoint directory.
model_folder = "./drive/MyDrive/head_pruned_student/best_model"
size_mb = folder_size_mb(path=model_folder)
print("Saved model folder size: %.2f MB" % size_mb)

In [None]:
# Explicit-argument latency run: 200 timed batch-size-1 passes at sequence length 128.
latency_ms = evaluate_latency(student, tokenizer, device, N=200, max_len=128)
print("Average batch=1 latency over 200 runs: {:.2f} ms".format(latency_ms))

In [None]:
# Validation metrics of the model held by the most recently created trainer.
res = trainer.evaluate()
print(str(res))