# In [None]:
import time
import torch
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, OPTForSequenceClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score

# --- Experiment configuration ----------------------------------------------
MODEL_NAME = "facebook/opt-350m"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
NUM_CLUSTERS = 2000  # rows in each simulated key / value cluster-center matrix
MAX_SAMPLES = 100  # cap on validation examples tokenized for evaluation

# --- Tokenizer and model, loaded once at import time -----------------------
# Every function below reads these module-level globals.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): the 2-label classification head is presumably not part of the
# pretrained checkpoint (i.e. randomly initialised) — confirm before reading
# the reported accuracies as meaningful.
model = OPTForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
model = model.to(DEVICE)
model.eval()  # inference mode for all later forward passes


def preprocess_data(dataset, max_samples=100, max_length=512):
    """
    Flatten PIQA-style examples into one tokenized batch plus a label tensor.

    Each of the first ``max_samples`` rows becomes a single string
    "Goal: ... [SEP] Solution 1: ... [SEP] Solution 2: ..." paired with the
    row's ``"label"``. All strings are tokenized together by the module-level
    ``tokenizer``, padded to the longest and truncated to ``max_length`` tokens.

    NOTE(review): "[SEP]" is inserted as plain text; confirm whether the OPT
    tokenizer treats it as a special token.

    Args:
        dataset: iterable of mappings with "goal", "sol1", "sol2" and "label".
        max_samples: maximum number of rows to take from the start of ``dataset``.
        max_length: truncation length passed to the tokenizer.

    Returns:
        (tokenizer output with PyTorch tensors, ``torch.tensor`` of labels)
    """
    texts = []
    targets = []
    for position, row in enumerate(dataset):
        if position >= max_samples:
            break
        goal, first, second, answer = row["goal"], row["sol1"], row["sol2"], row["label"]
        texts.append(f"Goal: {goal} [SEP] Solution 1: {first} [SEP] Solution 2: {second}")
        targets.append(answer)

    encoded = tokenizer(
        texts,
        padding=True,
        truncation=True,
        max_length=max_length,
        return_tensors="pt",
    )
    return encoded, torch.tensor(targets)


def calculate_layer_sensitivities(model, tokenized_inputs):
    """
    Score each decoder layer by the variance of its self-attention module's output.

    A forward hook is registered on ``model.model.decoder.layers[i].self_attn``
    for every layer, then one forward pass runs without gradients. Each hook
    records the mean, over the batch, of the per-example variance of the
    module's first output. The hooks are always removed, even if the forward
    pass or a hook raises.

    NOTE(review): this was previously described as the "variance of attention
    scores", but the hook reads ``output[0]`` of the attention module. For
    Hugging Face OPT that appears to be the attention *output* activations
    (``[batch, seq_len, hidden]``), not the attention-probability matrix.
    Confirm before relying on that interpretation. Padded positions are
    included in the variance.

    Args:
        model: OPT-style model exposing ``model.config.num_hidden_layers`` and
            ``model.model.decoder.layers[i].self_attn``.
        tokenized_inputs: mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors.

    Returns:
        list[float]: one score per decoder layer, indexed by layer number.

    Raises:
        ValueError: if a hooked output has a rank other than 2, 3 or 4.
        RuntimeError: if some layer's hook never fired during the forward pass.
    """
    input_ids = tokenized_inputs["input_ids"].to(DEVICE)
    attention_mask = tokenized_inputs["attention_mask"].to(DEVICE)
    num_layers = model.config.num_hidden_layers
    sensitivities = [None] * num_layers

    def make_hook(layer_idx):
        # Bind layer_idx per hook so each score lands in its own slot,
        # independent of the order in which hooks happen to fire.
        def attention_hook(module, inputs, output):
            tensor = output[0] if isinstance(output, tuple) else output
            # Cast first: NumPy has no bfloat16, so .numpy() would fail on it.
            values = tensor.detach().float().cpu().numpy()
            if values.ndim in (3, 4):
                # Axis 0 is the batch: one variance per example over all other axes.
                variance = np.var(values, axis=tuple(range(1, values.ndim)))
            elif values.ndim == 2:
                # No batch axis: a single variance over the whole matrix.
                variance = np.var(values)
            else:
                raise ValueError(f"Unexpected attention scores shape: {values.shape}")
            sensitivities[layer_idx] = float(np.mean(variance))

        return attention_hook

    hooks = []
    try:
        for idx in range(num_layers):
            attn = model.model.decoder.layers[idx].self_attn
            hooks.append(attn.register_forward_hook(make_hook(idx)))
        with torch.no_grad():
            model(input_ids, attention_mask=attention_mask)
    finally:
        # Remove hooks even on failure so later forward passes are unaffected.
        for hook in hooks:
            hook.remove()

    missing = [idx for idx, score in enumerate(sensitivities) if score is None]
    if missing:
        raise RuntimeError(f"No attention output captured for layers: {missing}")
    return sensitivities



def quantize_cluster_centers(cluster_centers, precision):
    """
    Convert cluster centers to a tensor at the requested precision.

    Args:
        cluster_centers: array-like of numbers (NumPy array, nested list, ...).
        precision: one of ``"fp32"``, ``"fp16"`` or ``"int8"``.

    Returns:
        tuple ``(tensor, scale)``.

        - ``"fp32"`` / ``"fp16"``: a float tensor of that dtype, with ``scale`` None.
        - ``"int8"``: values are multiplied by ``scale = 127 / max(|x|)``,
          rounded, and stored as ``torch.int8``. Divide by ``scale`` to
          approximately recover the inputs. All-zero or empty input uses
          ``scale = 1.0``.

    Raises:
        ValueError: for an unsupported ``precision``.
    """
    if precision == "fp32":
        return torch.tensor(cluster_centers, dtype=torch.float32), None
    if precision == "fp16":
        return torch.tensor(cluster_centers, dtype=torch.float16), None
    if precision == "int8":
        # asarray lets plain Python lists work with the vectorised maths below.
        centers = np.asarray(cluster_centers)
        # initial=0.0 makes empty input valid; it cannot change the max of |x| >= 0.
        max_val = np.max(np.abs(centers), initial=0.0)
        scale = 127 / max_val if max_val != 0 else 1.0
        quantized = np.round(centers * scale).astype(np.int8)
        return torch.tensor(quantized), scale
    raise ValueError(f"Unsupported precision: {precision}")


def apply_quantization_to_layers(model, sensitivities, cluster_centers_keys, cluster_centers_values):
    """
    Pick a precision per layer from its sensitivity and quantize the simulated
    key/value cluster centers at that precision.

    Layers are visited from least to most sensitive. The precision rule is:

    - below the 33rd percentile of ``sensitivities``: ``"int8"``;
    - below the 66th percentile: ``"fp16"``;
    - otherwise: ``"fp32"``.

    The choice for each layer is printed.

    NOTE(review): the quantized tensors are computed but never attached to
    ``model``. The returned object is the same, unmodified model, so any
    accuracy measured on it matches the unquantized model.

    Args:
        model: the model; returned as-is.
        sensitivities: one score per layer (list or 1-D array).
        cluster_centers_keys: array-like of key cluster centers.
        cluster_centers_values: array-like of value cluster centers.

    Returns:
        The ``model`` argument, unchanged.
    """
    if len(sensitivities) == 0:
        return model

    # Thresholds are loop-invariant: compute them once instead of per layer.
    low_cut = np.percentile(sensitivities, 33)
    high_cut = np.percentile(sensitivities, 66)
    # Every layer quantizes the same centers, so each precision is done once.
    quantized_by_precision = {}

    for layer in np.argsort(sensitivities):
        sensitivity = sensitivities[layer]
        if sensitivity < low_cut:
            precision = "int8"
        elif sensitivity < high_cut:
            precision = "fp16"
        else:
            precision = "fp32"

        if precision not in quantized_by_precision:
            quantized_keys, _ = quantize_cluster_centers(cluster_centers_keys, precision)
            quantized_values, _ = quantize_cluster_centers(cluster_centers_values, precision)
            quantized_by_precision[precision] = (quantized_keys, quantized_values)

        print(f"Layer {layer}: Quantized to {precision}.")

    return model


def calculate_kv_cache_size(model, bytes_per_element=4):
    """
    Return the storage, in bytes, of every decoder layer's key and value projection weights.

    NOTE(review): despite the name, this does not measure a runtime KV cache,
    which would scale with batch size and sequence length. For each of
    ``model.config.num_hidden_layers`` layers it counts the elements of
    ``self_attn.k_proj.weight`` and ``self_attn.v_proj.weight`` (biases
    excluded) and multiplies the total by ``bytes_per_element``.

    Args:
        model: OPT-style model exposing ``model.model.decoder.layers[i].self_attn``.
        bytes_per_element: bytes assumed per weight element. The default of 4
            corresponds to fp32; use 2 for fp16 or 1 for int8.

    Returns:
        int: total size in bytes.
    """
    layers = model.model.decoder.layers
    element_count = sum(
        layers[idx].self_attn.k_proj.weight.numel() + layers[idx].self_attn.v_proj.weight.numel()
        for idx in range(model.config.num_hidden_layers)
    )
    return element_count * bytes_per_element


def evaluate_model(model, tokenized_inputs, labels):
    """
    Return the accuracy of ``model``'s argmax predictions against ``labels``.

    The whole tokenized batch is moved to ``DEVICE`` and run in a single
    forward pass with gradient tracking disabled. ``labels`` is expected to
    be a tensor, since ``.cpu()`` is called on it.
    """
    device_ids = tokenized_inputs["input_ids"].to(DEVICE)
    device_mask = tokenized_inputs["attention_mask"].to(DEVICE)

    with torch.no_grad():
        output = model(device_ids, attention_mask=device_mask)
        predicted = output.logits.argmax(dim=-1).cpu().numpy()

    expected = labels.cpu().numpy()
    return accuracy_score(expected, predicted)


def calculate_kv_cache_size_final_with_precision(model, sensitivities, num_clusters, center_dim=None):
    """
    Estimate the bytes needed to store per-layer key and value cluster centers
    at sensitivity-dependent precisions.

    Each of ``model.config.num_hidden_layers`` layers stores ``num_clusters``
    key centers and ``num_clusters`` value centers of ``center_dim`` elements.
    Element width depends on the layer's sensitivity:

    - below the 33rd percentile of ``sensitivities``: 1 byte (int8);
    - below the 66th percentile: 2 bytes (fp16);
    - otherwise: 4 bytes (fp32).

    NOTE(review): ``center_dim`` defaults to the per-head dimension
    (``hidden_size // num_attention_heads``). The centers simulated in
    ``main`` have ``hidden_size`` columns, so pass
    ``center_dim=model.config.hidden_size`` if the estimate should match
    those tensors. Any int8 scale factor is not counted.

    Args:
        model: object exposing ``config.num_hidden_layers``,
            ``config.hidden_size`` and ``config.num_attention_heads``.
        sensitivities: one score per layer, indexed by layer number.
        num_clusters: number of key (and of value) centers per layer.
        center_dim: elements per center; defaults to the per-head dimension.

    Returns:
        int: estimated size in bytes (0 when the model has no layers).
    """
    num_layers = model.config.num_hidden_layers
    head_dim = model.config.hidden_size // model.config.num_attention_heads
    if center_dim is None:
        center_dim = head_dim
    if num_layers == 0:
        return 0

    # Thresholds are loop-invariant: compute them once instead of per layer.
    low_cut = np.percentile(sensitivities, 33)
    high_cut = np.percentile(sensitivities, 66)

    total_size = 0
    for layer_idx in range(num_layers):
        sensitivity = sensitivities[layer_idx]
        if sensitivity < low_cut:
            bytes_per_element = 1  # int8
        elif sensitivity < high_cut:
            bytes_per_element = 2  # fp16
        else:
            bytes_per_element = 4  # fp32
        # Keys and values are sized identically, hence the factor of 2.
        total_size += 2 * num_clusters * center_dim * bytes_per_element

    return total_size

def main():
    """
    Run the end-to-end experiment on the PIQA validation split.

    Steps:

    1. Tokenize up to ``MAX_SAMPLES`` examples.
    2. Measure the baseline size estimate, accuracy and evaluation time.
    3. Score layer sensitivities.
    4. "Quantize" randomly simulated cluster centers per layer.
    5. Estimate the reduced size and re-evaluate.
    6. Print the results and append them to ``chai_quant_optseries.txt``.
    """
    dataset = load_dataset("piqa", split="validation")
    tokenized_inputs, labels = preprocess_data(dataset, max_samples=MAX_SAMPLES)

    # Baseline size estimate, accuracy and wall-clock evaluation time.
    initial_kv_cache_size = calculate_kv_cache_size(model)
    start = time.perf_counter()
    initial_accuracy = evaluate_model(model, tokenized_inputs, labels)
    initial_eval_seconds = time.perf_counter() - start

    print(f"Initial KV Cache Size: {initial_kv_cache_size / (1024 ** 2):.2f} MB")
    print(f"Initial Accuracy: {initial_accuracy:.4f}")
    print(f"Initial Evaluation Time: {initial_eval_seconds:.2f} s")

    sensitivities = calculate_layer_sensitivities(model, tokenized_inputs)

    # Random stand-ins for key/value cluster centers (one row per cluster).
    cluster_centers_keys = np.random.rand(NUM_CLUSTERS, model.config.hidden_size)
    cluster_centers_values = np.random.rand(NUM_CLUSTERS, model.config.hidden_size)

    quantized_model = apply_quantization_to_layers(model, sensitivities, cluster_centers_keys, cluster_centers_values)

    # NOTE(review): the initial size counts k/v projection-weight bytes, while
    # this one estimates cluster-center bytes. The reduction percentage below
    # compares different quantities; confirm that is intended.
    final_kv_cache_size = calculate_kv_cache_size_final_with_precision(quantized_model, sensitivities, NUM_CLUSTERS)

    # NOTE(review): apply_quantization_to_layers returns the model it was given
    # without attaching the quantized tensors, so this accuracy is expected to
    # equal the baseline.
    start = time.perf_counter()
    quantized_accuracy = evaluate_model(quantized_model, tokenized_inputs, labels)
    quantized_eval_seconds = time.perf_counter() - start

    memory_reduction = ((initial_kv_cache_size - final_kv_cache_size) / initial_kv_cache_size) * 100

    print(f"Final KV Cache Size: {final_kv_cache_size / (1024 ** 2):.2f} MB")
    print(f"Quantized Model Accuracy: {quantized_accuracy:.4f}")
    print(f"Memory Reduction Percentage: {memory_reduction:.2f}%")
    print(f"Quantized Evaluation Time: {quantized_eval_seconds:.2f} s")

    results = (
        f"Initial KV Cache Size: {initial_kv_cache_size / (1024 ** 2):.2f} MB\n"
        f"Final KV Cache Size: {final_kv_cache_size / (1024 ** 2):.2f} MB\n"
        f"Memory Reduction Percentage: {memory_reduction:.2f}%\n"
        f"Initial Accuracy: {initial_accuracy:.4f}\n"
        f"Quantized Model Accuracy: {quantized_accuracy:.4f}\n"
    )

    # Append mode so successive runs accumulate in one log file.
    with open("chai_quant_optseries.txt", "a", encoding="utf-8") as file:
        file.write(results)
        file.write("-" * 40 + "\n")  # separator between runs


if __name__ == "__main__":
    main()