In [None]:
pip install datasets

In [None]:
import time
import torch
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans
from transformers import AutoTokenizer, BertForSequenceClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from torch import nn

# Experiment configuration: checkpoint name, execution device, and size limits.
MODEL_NAME = "bert-base-uncased"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
NUM_CLUSTERS = 2000  # Number of cluster centres used for the key/value estimates
MAX_SAMPLES = 100  # Upper bound on the number of dataset examples evaluated

# Load the tokenizer and the classification model once, at import time.
# NOTE(review): num_labels=4 presumably matches the four HellaSwag endings — confirm.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=4)
model = model.to(DEVICE)
model.eval()  # inference mode


def preprocess_data(dataset, dataset_name, max_samples=100, max_length=512):
    """Turn the first examples of a dataset into tokenized inputs and labels.

    Only the first answer candidate of each example is rendered into the text
    (``endings[0]`` for HellaSwag, ``sol1`` for PIQA); the label is the
    example's ``label`` field converted to ``int``.

    Args:
        dataset: Iterable of mapping-like examples.
        dataset_name: Either ``"hellaswag"`` or ``"piqa"``.
        max_samples: Maximum number of examples read from ``dataset``.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        A tuple ``(tokenized_inputs, labels)`` where ``tokenized_inputs`` is
        the output of the module-level ``tokenizer`` (PyTorch tensors, padded
        and truncated) and ``labels`` is a ``torch.long`` tensor.

    Raises:
        ValueError: If ``dataset_name`` is not supported, or if no example
            could be converted.
    """
    # Fail fast on an unknown dataset name. Raising this inside the
    # per-example try block would let the ValueError handler below swallow
    # it and misreport it as an "invalid label" for every example.
    if dataset_name not in ("hellaswag", "piqa"):
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    inputs, labels = [], []
    for i, example in enumerate(dataset):
        if i >= max_samples:
            break
        try:
            if dataset_name == "hellaswag":
                text = f"Context: {example['ctx']} Ending: {example['endings'][0]}"
            else:
                text = f"Question: {example['goal']} Choice: {example['sol1']}"
            label = int(example['label'])
        except KeyError as e:
            print(f"Skipping example due to missing key: {e}")
            continue
        except ValueError as e:
            # e.g. an empty-string label that cannot be parsed as an integer
            print(f"Skipping example due to invalid label: {e}")
            continue
        inputs.append(text)
        labels.append(label)

    if not inputs:
        raise ValueError(f"No valid examples found in dataset: {dataset_name}")

    tokenized_inputs = tokenizer(
        inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length
    )
    return tokenized_inputs, torch.tensor(labels, dtype=torch.long)


def calculate_layer_sensitivities(model, tokenized_inputs):
    """Estimate one sensitivity value per encoder layer.

    A forward hook is attached to every ``attention.self`` module of
    ``model.bert.encoder.layer``. For each layer the hook takes the module's
    output (its first element when the output is a tuple), computes the
    variance per leading index over all remaining axes, and records the mean
    of those variances.

    NOTE(review): the first output element of a Hugging Face self-attention
    module appears to be the attention context tensor rather than the raw
    attention probabilities — confirm that this is the intended signal.

    Args:
        model: Model exposing ``config.num_hidden_layers`` and
            ``bert.encoder.layer[i].attention.self``.
        tokenized_inputs: Mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors.

    Returns:
        A list with one sensitivity value per hooked layer, in the order the
        hooks fired during the forward pass.

    Raises:
        ValueError: If a hooked output is not 2-, 3- or 4-dimensional.
    """
    input_ids = tokenized_inputs["input_ids"].to(DEVICE)
    attention_mask = tokenized_inputs["attention_mask"].to(DEVICE)
    sensitivities = []

    def attention_hook(module, inputs, output):
        attention_scores = output[0] if isinstance(output, tuple) else output
        attention_scores = attention_scores.detach().cpu().numpy()

        if attention_scores.ndim in (3, 4):
            # Keep axis 0 and reduce over everything else.
            reduce_axes = tuple(range(1, attention_scores.ndim))
        elif attention_scores.ndim == 2:
            reduce_axes = (0, 1)
        else:
            raise ValueError(f"Unexpected attention scores shape: {attention_scores.shape}")

        variance = np.var(attention_scores, axis=reduce_axes)
        sensitivities.append(np.mean(variance))

    hooks = []
    try:
        for layer in range(model.config.num_hidden_layers):
            hooks.append(
                model.bert.encoder.layer[layer].attention.self.register_forward_hook(attention_hook)
            )

        with torch.no_grad():
            model(input_ids, attention_mask=attention_mask)
    finally:
        # Always detach the hooks, even if registration or the forward pass
        # fails, so they do not keep firing on later calls to the model.
        for hook in hooks:
            hook.remove()

    return sensitivities


def quantize_cluster_centers(cluster_centers, precision):
    """Convert cluster centres to a tensor of the requested precision.

    Args:
        cluster_centers: Array of cluster-centre values (array-like for the
            ``"int8"`` path).
        precision: One of ``"fp32"``, ``"fp16"`` or ``"int8"``.

    Returns:
        A tuple ``(tensor, scale)``. For ``"fp32"`` and ``"fp16"`` the tensor
        is a plain dtype cast and ``scale`` is ``None``. For ``"int8"`` the
        values are multiplied by ``scale = 127 / max(|x|)`` and rounded, so
        dividing the tensor by ``scale`` approximately recovers the input;
        ``scale`` is ``1.0`` when the input is empty or all zeros.

    Raises:
        ValueError: If ``precision`` is not one of the supported names.
    """
    if precision == "fp32":
        return torch.tensor(cluster_centers, dtype=torch.float32), None
    elif precision == "fp16":
        return torch.tensor(cluster_centers, dtype=torch.float16), None
    elif precision == "int8":
        centers = np.asarray(cluster_centers)
        # np.max raises on a zero-size array; treat empty input like all-zero
        # input so this path accepts the same shapes as the float paths.
        max_val = np.max(np.abs(centers)) if centers.size else 0.0
        scale = 127 / max_val if max_val != 0 else 1.0
        quantized = np.round(centers * scale).astype(np.int8)
        return torch.tensor(quantized), scale
    else:
        raise ValueError(f"Unsupported precision: {precision}")


def apply_quantization_to_layers(model, sensitivities, cluster_centers_keys, cluster_centers_values):
    """Choose a precision per layer from its sensitivity and report it.

    Layers are visited from least to most sensitive. A layer whose
    sensitivity is below the 33rd percentile is assigned ``"int8"``, below
    the 66th percentile ``"fp16"``, otherwise ``"fp32"``. The key and value
    cluster centres are quantized at that precision and the choice is printed.

    NOTE(review): the quantized tensors are currently discarded and ``model``
    is returned without modification, so this function only reports the
    precision plan — confirm whether the centres should be written back.

    Args:
        model: Model to (nominally) quantize; returned as-is.
        sensitivities: Sequence with one sensitivity value per layer.
        cluster_centers_keys: Cluster centres for the keys.
        cluster_centers_values: Cluster centres for the values.

    Returns:
        The ``model`` object that was passed in.
    """
    if len(sensitivities) == 0:
        return model

    # The thresholds depend only on the full list, so compute them once
    # instead of once (or twice) per layer.
    int8_cutoff = np.percentile(sensitivities, 33)
    fp16_cutoff = np.percentile(sensitivities, 66)

    for layer in np.argsort(sensitivities):
        sensitivity = sensitivities[layer]

        if sensitivity < int8_cutoff:
            precision = "int8"
        elif sensitivity < fp16_cutoff:
            precision = "fp16"
        else:
            precision = "fp32"

        # Results are intentionally not stored; see the NOTE in the docstring.
        quantize_cluster_centers(cluster_centers_keys, precision)
        quantize_cluster_centers(cluster_centers_values, precision)

        print(f"Layer {layer}: Quantized to {precision}.")

    return model


def calculate_kv_cache_size(model):
    """Return the byte size of the key and value projection weights.

    For each of the first ``model.config.num_hidden_layers`` encoder layers,
    the element counts of ``attention.self.key.weight`` and
    ``attention.self.value.weight`` are added up and multiplied by 4 bytes
    per element.

    NOTE(review): this measures projection *weights* at a fixed 4 bytes per
    element, not a runtime key/value cache — confirm this is the intended
    baseline.
    """
    layer_count = model.config.num_hidden_layers
    self_attention_modules = (
        model.bert.encoder.layer[idx].attention.self for idx in range(layer_count)
    )
    return sum(
        (attn.key.weight.numel() + attn.value.weight.numel()) * 4
        for attn in self_attention_modules
    )


def evaluate_model(model, tokenized_inputs, labels):
    """Compute classification accuracy of ``model`` on the given inputs.

    All inputs are moved to ``DEVICE`` and passed through the model in a
    single forward call with gradients disabled; the predicted class is the
    arg-max over the last logits axis.

    Args:
        model: Callable returning an object with a ``logits`` attribute.
        tokenized_inputs: Mapping with ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        labels: Tensor of reference class indices.

    Returns:
        The accuracy reported by ``sklearn.metrics.accuracy_score``.
    """
    ids = tokenized_inputs["input_ids"].to(DEVICE)
    mask = tokenized_inputs["attention_mask"].to(DEVICE)

    with torch.no_grad():
        outputs = model(ids, attention_mask=mask)
        predicted = torch.argmax(outputs.logits, dim=-1).cpu().numpy()

    reference = labels.cpu().numpy()
    return accuracy_score(reference, predicted)


def calculate_kv_cache_size_final_with_precision(model, sensitivities, num_clusters):
    """Estimate the byte size of the clustered keys and values.

    Each layer stores ``num_clusters`` key centres and ``num_clusters`` value
    centres of width ``hidden_size // num_attention_heads``. The bytes per
    element depend on the layer's sensitivity: 1 byte below the 33rd
    percentile of ``sensitivities``, 2 bytes below the 66th percentile, and
    4 bytes otherwise.

    Args:
        model: Object exposing ``config.num_hidden_layers``,
            ``config.hidden_size`` and ``config.num_attention_heads``.
        sensitivities: Sequence with one sensitivity value per layer, indexed
            by layer number.
        num_clusters: Number of cluster centres per layer and per tensor.

    Returns:
        The total size in bytes (a float when at least one layer is counted,
        ``0`` when the model has no layers).
    """
    num_layers = model.config.num_hidden_layers
    head_dim = model.config.hidden_size // model.config.num_attention_heads
    if num_layers == 0:
        return 0

    # Thresholds are loop-invariant; compute them once up front.
    int8_cutoff = np.percentile(sensitivities, 33)
    fp16_cutoff = np.percentile(sensitivities, 66)

    total_size = 0
    for layer_idx in range(num_layers):
        sensitivity = sensitivities[layer_idx]
        if sensitivity < int8_cutoff:
            precision_bits = 8
        elif sensitivity < fp16_cutoff:
            precision_bits = 16
        else:
            precision_bits = 32

        # Keys and values have the same shape, hence the same size.
        per_tensor_size = num_clusters * head_dim * (precision_bits / 8)
        total_size += per_tensor_size + per_tensor_size

    return total_size


def main():
    """Run the HellaSwag evaluation before and after the quantization step.

    Loads the HellaSwag validation split, measures baseline size, accuracy and
    wall-clock inference time, derives per-layer sensitivities, runs the
    quantization step, and prints the resulting size, accuracy and timing
    figures.

    NOTE(review): the cluster centres are random placeholders
    (``np.random.rand``) and ``apply_quantization_to_layers`` returns the
    model it was given, so the "after" accuracy and timing are measured on
    the same weights as the baseline — confirm before quoting these numbers.
    """
    print("Starting HellaSwag quantization process...")
    dataset = load_dataset("hellaswag", split="validation")
    tokenized_inputs, labels = preprocess_data(dataset, dataset_name="hellaswag", max_samples=MAX_SAMPLES)

    initial_kv_cache_size = calculate_kv_cache_size(model)
    # perf_counter is monotonic and high-resolution, unlike time.time(),
    # which can jump when the system clock is adjusted.
    start_time_before = time.perf_counter()
    initial_accuracy = evaluate_model(model, tokenized_inputs, labels)
    inference_time_before = time.perf_counter() - start_time_before

    initial_kv_cache_size_mb = initial_kv_cache_size / (1024 ** 2)
    print(f"Initial KV Cache Size: {initial_kv_cache_size_mb:.2f} MB")
    print(f"Initial Accuracy: {initial_accuracy:.4f}")

    sensitivities = calculate_layer_sensitivities(model, tokenized_inputs)
    cluster_centers_keys = np.random.rand(NUM_CLUSTERS, model.config.hidden_size)
    cluster_centers_values = np.random.rand(NUM_CLUSTERS, model.config.hidden_size)

    quantized_model = apply_quantization_to_layers(model, sensitivities, cluster_centers_keys, cluster_centers_values)
    final_kv_cache_size = calculate_kv_cache_size_final_with_precision(quantized_model, sensitivities, NUM_CLUSTERS)
    final_kv_cache_size_mb = final_kv_cache_size / (1024 ** 2)

    start_time_after = time.perf_counter()
    quantized_accuracy = evaluate_model(quantized_model, tokenized_inputs, labels)
    inference_time_after = time.perf_counter() - start_time_after

    accuracy_drop = initial_accuracy - quantized_accuracy
    # Guard the two ratios so a zero denominator reports a sentinel value
    # instead of aborting after all the measurements have been taken.
    if initial_kv_cache_size:
        reduction_percentage = ((initial_kv_cache_size - final_kv_cache_size) / initial_kv_cache_size) * 100
    else:
        reduction_percentage = 0.0
    if inference_time_after > 0:
        speed_increment = inference_time_before / inference_time_after
    else:
        speed_increment = float("inf")

    print(f"Final KV Cache Size: {final_kv_cache_size_mb:.2f} MB")
    print(f"Reduction Percentage: {reduction_percentage:.2f}%")
    print(f"Quantized Model Accuracy: {quantized_accuracy:.4f}")
    print(f"Accuracy Drop: {accuracy_drop:.4f}")
    print(f"Speed Increment: {speed_increment:.4f}x")


# Run the evaluation pipeline only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
