In [None]:
import torch
# Source checkpoint: the pruned (non-quantized) saved model.
MODEL_PATH = "./drive/MyDrive/model_compression/head_pruned_student/best_model"
# Destination folder for the converted model written later in the notebook.
QUANT_OUT = "./drive/MyDrive/model_compression/head_pruned_student_quantized"

# Use the GPU whenever PyTorch reports one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

MAX_LEN = 128   # tokenizer max_length used for truncation/padding
N_RUNS = 200    # number of timed forward passes in the latency benchmark

print("Device:", DEVICE)

In [None]:
import os, time
import torch
from datasets import load_dataset
from evaluate import load
from transformers import AutoTokenizer, AutoModelForSequenceClassification, DataCollatorWithPadding, BitsAndBytesConfig
from torch.utils.data import DataLoader

# Tokenizer comes from the pruned checkpoint directory; local_files_only=True
# restricts the lookup to files already on disk.
tokenizer = AutoTokenizer.from_pretrained(
    MODEL_PATH,
    local_files_only=True,
)

# GLUE subset evaluated in this notebook.
task = "sst2"
raw = load_dataset("glue", task)
def preprocess(examples):
    """Tokenize the ``sentence`` field of a batch of examples.

    Sequences are truncated and padded to ``MAX_LEN`` tokens using the
    notebook-level ``tokenizer``. Returns the tokenizer's output unchanged.
    """
    sentences = examples["sentence"]
    return tokenizer(
        sentences,
        max_length=MAX_LEN,
        padding="max_length",
        truncation=True,
    )

# Tokenize the validation split in batches.
val_ds = raw["validation"].map(preprocess, batched=True)

# Drop the raw-text and index columns, but only the ones that are actually
# present: asking remove_columns for a missing column raises, which the
# previous "sentence present => also remove idx" shortcut could trigger.
cols_to_remove = [c for c in ["sentence", "idx"] if c in val_ds.column_names]
if cols_to_remove:
    val_ds = val_ds.remove_columns(cols_to_remove)

# The evaluation loop reads batch["labels"], so normalise the column name.
if "label" in val_ds.column_names and "labels" not in val_ds.column_names:
    val_ds = val_ds.rename_column("label", "labels")

# Return torch tensors from the dataset and batch them for evaluation.
val_ds.set_format(type="torch")
data_collator = DataCollatorWithPadding(tokenizer)
val_loader = DataLoader(val_ds, batch_size=32, collate_fn=data_collator)
metric = load("accuracy")

In [None]:
from transformers import AutoModelForSequenceClassification

# Honour the DEVICE selected in the configuration cell instead of probing CUDA
# a second time, so overriding DEVICE there actually takes effect here.
device = DEVICE

# Load every weight as 16-bit float. On CUDA the placement is delegated to
# device_map="auto"; otherwise no device map is passed.
# NOTE(review): FP16 inference on CPU may be slow or unsupported for some ops
# depending on the PyTorch build — confirm before relying on the CPU path.
model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_PATH,
    torch_dtype=torch.float16,
    device_map="auto" if device == "cuda" else None,
)

model.eval()

# Sanity check: show name, device and dtype of the first 7 parameters.
for idx, (name, param) in enumerate(model.named_parameters()):
    if idx >= 7:
        break
    print(name, param.device, param.dtype)

In [None]:
# Create the output folder up front; exist_ok=True keeps re-running this cell from failing.
os.makedirs(QUANT_OUT, exist_ok=True)
print("Saving quantized model to:", QUANT_OUT)
# NOTE(review): the model being saved was loaded with torch_dtype=torch.float16;
# no other quantization step is visible in this notebook.
model.save_pretrained(QUANT_OUT)
# The tokenizer is written into the same folder as the model files.
tokenizer.save_pretrained(QUANT_OUT)

In [None]:
def folder_size_mb(path):
    """Return the total size of all files under ``path`` in mebibytes.

    The directory tree is walked recursively and the byte sizes of the files
    found are summed. Entries whose size cannot be read (for example a
    dangling symlink, or a file removed while the walk is in progress) are
    skipped instead of aborting the whole measurement. A path that does not
    exist yields ``0.0``.

    Args:
        path: Root directory to measure.

    Returns:
        Size in MB (bytes / 1024**2) as a float.
    """
    total_bytes = 0
    for root, _dirs, files in os.walk(path):
        for fname in files:
            fpath = os.path.join(root, fname)
            try:
                total_bytes += os.path.getsize(fpath)
            except OSError:
                # Unreadable entry (broken link / vanished file): ignore it.
                continue
    return total_bytes / (1024**2)

def count_params(model):
    """Return the number of scalar elements across ``model.parameters()``."""
    n_elements = 0
    for tensor in model.parameters():
        n_elements += tensor.numel()
    return n_elements

# Count the elements of the in-memory model's parameters.
total_params = count_params(model)
# NOTE(review): the meaning of "(equivalent)" is not established by this notebook — confirm.
print(f"Parameter count: {total_params:,} (equivalent)")
# On-disk footprint of everything under QUANT_OUT (model and tokenizer files saved above).
print(f"Saved quantized model folder size: {folder_size_mb(QUANT_OUT):.2f} MB")

In [None]:
# Inputs are moved to whichever device holds the model's first parameter.
model_device = next(model.parameters()).device

all_preds, all_labels = [], []
input_keys = ("input_ids", "attention_mask")

with torch.no_grad():
    for batch in val_loader:
        # Forward only the tensors the model is called with; labels stay on CPU.
        batch_inputs = {}
        for key in input_keys:
            if key in batch:
                batch_inputs[key] = batch[key].to(model_device)

        logits = model(**batch_inputs).logits
        preds = logits.argmax(-1).cpu()

        all_preds.extend(preds.numpy())
        all_labels.extend(batch["labels"].numpy())

# Accuracy over the whole validation split.
results = metric.compute(predictions=all_preds, references=all_labels)
val_acc = results["accuracy"]
print(f"Validation Accuracy: {val_acc*100:.4f}%")

In [None]:
# NOTE(review): TEST_MODEL_PATH is not referenced anywhere in this cell.
TEST_MODEL_PATH = "./drive/MyDrive/model_compression/"

# Single fixed-length input (padded/truncated to MAX_LEN) used for every pass.
sample = tokenizer("This is a sample sentence to measure latency.", return_tensors="pt",
                   max_length=MAX_LEN, truncation=True, padding="max_length")
input_ids = sample["input_ids"].to(model_device)
attention_mask = sample["attention_mask"].to(model_device)

with torch.no_grad():
    # Warm-up passes, excluded from the measurement.
    for _ in range(10):
        _ = model(input_ids=input_ids, attention_mask=attention_mask)

    # Wait for queued GPU work so it does not leak into the timed region.
    if model_device.type == "cuda":
        torch.cuda.synchronize()

    # perf_counter is a monotonic, high-resolution clock intended for timing
    # intervals; time.time() follows the wall clock and can jump or be coarse.
    t0 = time.perf_counter()
    for _ in range(N_RUNS):
        _ = model(input_ids=input_ids, attention_mask=attention_mask)
    # Make sure all timed GPU work has finished before stopping the clock.
    if model_device.type == "cuda":
        torch.cuda.synchronize()
    t1 = time.perf_counter()

latency_ms = (t1 - t0)/N_RUNS*1000
print(f"Inference latency (batch=1, avg over {N_RUNS} runs): {latency_ms:.2f} ms")