# In [None]:
import time
import torch
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans
from transformers import AutoTokenizer, OPTForSequenceClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score

# Paths and parameters
# Model identifier passed to the tokenizer's and the model's ``from_pretrained``.
MODEL_NAME = "facebook/opt-350m"
# Use the GPU when torch reports one, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Requested number of clusters per layer; cluster_attention_heads() caps this
# at the model's number of attention heads.
NUM_CLUSTERS = 2000
# Upper bound on how many examples are read from each dataset split.
MAX_SAMPLES = 100

# Load tokenizer
# Module-level so preprocess_data() can use it; this runs at import time.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

def calculate_kv_cache_size_final(model, num_clusters):
    """
    Estimate the K/V size of the clustered model.

    The estimate counts ``head_dim`` elements per cluster for keys and the
    same for values, in every layer. The cluster count is capped at the number
    of attention heads, because ``cluster_attention_heads`` can never produce
    more clusters than there are heads; without the cap a large
    ``num_clusters`` would report a size for clusters that do not exist.

    Args:
        model: Object exposing ``config.num_hidden_layers``,
            ``config.hidden_size`` and ``config.num_attention_heads``.
        num_clusters: Requested number of clusters per layer.

    Returns:
        int: ``num_layers * 2 * min(num_heads, num_clusters) * head_dim * 2``.
    """
    config = model.config
    num_heads = config.num_attention_heads
    head_dim = config.hidden_size // num_heads  # Per-head dimension

    # Same cap that cluster_attention_heads applies before running k-means.
    effective_clusters = min(num_heads, num_clusters)

    # One key entry and one value entry of head_dim elements per cluster.
    per_layer = 2 * effective_clusters * head_dim

    # NOTE(review): the trailing factor of 2 is carried over from the original
    # formula; its meaning (e.g. bytes per fp16 element) is not documented
    # here -- confirm before comparing against calculate_kv_cache_size().
    return config.num_hidden_layers * per_layer * 2

def preprocess_data(dataset, dataset_name, max_samples=100, max_length=512):
    """
    Build tokenized model inputs and integer labels from a dataset.

    Only the first ``max_samples`` examples of ``dataset`` are inspected.
    Examples with a missing field or a label that cannot be parsed as an
    integer are reported on stdout and skipped; skipped examples still count
    toward ``max_samples``.

    Args:
        dataset: Iterable of mapping-like examples.
        dataset_name: ``"hellaswag"`` or ``"piqa"``; selects the prompt format.
        max_samples: Maximum number of examples to read from ``dataset``.
        max_length: Truncation length passed to the tokenizer.

    Returns:
        tuple: ``(tokenized_inputs, labels)`` where ``tokenized_inputs`` is the
        output of the module-level ``tokenizer`` and ``labels`` is a
        ``torch.long`` tensor.

    Raises:
        ValueError: If ``dataset_name`` is not supported, or if no example
            could be converted.
    """
    # Validate once up front. Raising this inside the per-example try block
    # would be swallowed by the ValueError handler below and misreported as an
    # invalid label for every single example.
    if dataset_name not in ("hellaswag", "piqa"):
        raise ValueError(f"Unsupported dataset: {dataset_name}")

    inputs, labels = [], []
    for i, example in enumerate(dataset):
        if i >= max_samples:
            break

        try:
            if dataset_name == "hellaswag":
                # Only the first candidate ending is used for the prompt.
                text = f"Context: {example['ctx']} Ending: {example['endings'][0]}"
            else:
                text = f"Question: {example['goal']} Choice: {example['sol1']}"
            label = int(example['label'])
        except KeyError as e:
            print(f"Skipping example due to missing key: {e}")
            continue
        except ValueError as e:
            print(f"Skipping example due to invalid label: {e}")
            continue

        inputs.append(text)
        labels.append(label)

    # inputs and labels are appended together, so checking one is enough.
    if not inputs:
        raise ValueError(f"No valid examples found in dataset: {dataset_name}")

    tokenized_inputs = tokenizer(
        inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length
    )
    return tokenized_inputs, torch.tensor(labels, dtype=torch.long)

def calculate_kv_cache_size(model):
    """
    Count the elements of the key and value projection weights.

    Walks every layer in ``model.model.decoder.layers`` and adds up
    ``numel()`` of ``self_attn.k_proj.weight`` and ``self_attn.v_proj.weight``.

    Returns:
        int: Total element count across all layers (0 when there are none).
    """
    return sum(
        block.self_attn.k_proj.weight.numel() + block.self_attn.v_proj.weight.numel()
        for block in model.model.decoder.layers
    )

def cluster_attention_heads(model, num_clusters):
    """
    Cluster the per-head key and value projection weights of every layer.

    For each layer, the ``k_proj`` and ``v_proj`` weight matrices are split
    into ``num_heads`` row blocks of ``head_dim`` rows, each block is flattened
    into one vector, and k-means is run over those vectors. Every head's block
    is then overwritten, in place, with the centroid of the cluster that head
    was assigned to, so heads in the same cluster end up sharing weights.

    Args:
        model: Model exposing ``config.num_attention_heads``,
            ``config.hidden_size`` and ``model.decoder.layers``.
        num_clusters: Requested cluster count; capped at the number of heads
            because k-means cannot form more clusters than samples.
    """
    num_heads = model.config.num_attention_heads
    head_dim = model.config.hidden_size // num_heads

    # Adjust number of clusters to be <= num_heads
    adjusted_clusters = min(num_heads, num_clusters)

    for layer_idx, layer in enumerate(model.model.decoder.layers):
        print(f"Layer {layer_idx}: Adjusting clusters to {adjusted_clusters} (num_heads={num_heads})")

        # Keys and values are clustered independently of each other.
        for proj in (layer.self_attn.k_proj, layer.self_attn.v_proj):
            weight = proj.weight.data

            # (num_heads, head_dim, in_features) -> one flat row per head.
            # Cast to float32 so half-precision weights convert to numpy.
            per_head = (
                weight.reshape(num_heads, head_dim, -1)
                .reshape(num_heads, -1)
                .float()
                .cpu()
                .numpy()
            )

            kmeans = KMeans(n_clusters=adjusted_clusters, random_state=0).fit(per_head)

            # Map each head to ITS cluster's centroid. Writing cluster_centers_
            # back directly would (a) have too few elements whenever
            # adjusted_clusters < num_heads and (b) permute heads, because
            # centroid order is unrelated to head order.
            shared = kmeans.cluster_centers_[kmeans.labels_]

            # Keep the weight on its original device and dtype rather than
            # assuming the global DEVICE / numpy's dtype.
            proj.weight.data = (
                torch.as_tensor(shared)
                .to(device=weight.device, dtype=weight.dtype)
                .reshape(weight.shape)
            )

    print("Clustering of attention heads completed.")

def evaluate_model(model, tokenized_inputs, labels):
    """
    Run one forward pass over all inputs and score the predictions.

    The whole tokenized batch is moved to ``DEVICE`` and fed to the model in a
    single call with gradients disabled. The predicted class is the argmax of
    the logits along the last dimension. Wall-clock time and throughput are
    printed.

    Returns:
        tuple: ``(accuracy, speed)`` where ``accuracy`` comes from
        ``accuracy_score`` and ``speed`` is samples per second.
    """
    ids = tokenized_inputs["input_ids"].to(DEVICE)
    mask = tokenized_inputs["attention_mask"].to(DEVICE)

    began = time.time()
    with torch.no_grad():
        outputs = model(ids, attention_mask=mask)
        # Moving to the CPU inside the timed region includes the device sync.
        predicted = outputs.logits.argmax(dim=-1).cpu().numpy()
    total_time = time.time() - began

    speed = len(labels) / total_time
    print(f"Evaluation time: {total_time:.2f} seconds, Speed: {speed:.2f} samples/second")

    accuracy = accuracy_score(labels.cpu().numpy(), predicted)
    return accuracy, speed

def process_dataset(dataset_name):
    """
    Run the full pipeline for one dataset and print the results.

    Steps: load the ``validation`` split, preprocess up to ``MAX_SAMPLES``
    examples, load a sequence-classification model sized to the labels, report
    the K/V projection weight element count, cluster the attention heads,
    report the estimated size after clustering and the reduction percentage,
    then evaluate and print accuracy.

    Loading or preprocessing failures are printed and the function returns
    early without raising.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset`` and
            ``preprocess_data``.
    """
    # Any failure here (unknown dataset, network, ...) is reported and aborts
    # this dataset only, so the caller can continue with the next one.
    try:
        dataset = load_dataset(dataset_name, split="validation")
    except Exception as e:
        print(f"Error loading dataset {dataset_name}: {e}")
        return

    try:
        start_time = time.time()
        tokenized_inputs, labels = preprocess_data(dataset, dataset_name, max_samples=MAX_SAMPLES)
        preprocess_time = time.time() - start_time
        print(f"Preprocessing time: {preprocess_time:.2f} seconds")
    except ValueError as e:
        print(f"Error preprocessing dataset {dataset_name}: {e}")
        return

    # Adjust number of labels dynamically.
    # Class ids index the classifier's outputs, so the head needs
    # (largest id + 1) outputs. Counting distinct labels undersizes the head
    # whenever the sample skips an id (e.g. {0, 1, 3} -> 3 outputs, class 3
    # unreachable). preprocess_data guarantees labels is non-empty.
    num_labels = int(labels.max().item()) + 1
    print(f"Number of labels in {dataset_name}: {num_labels}")

    # Load model with adjusted number of labels
    model = OPTForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=num_labels).to(DEVICE)

    # Calculate initial KV cache size
    initial_kv_cache_size = calculate_kv_cache_size(model)
    print(f"Initial KV cache size: {initial_kv_cache_size} elements")

    # Cluster attention heads (modifies the model's weights in place)
    cluster_attention_heads(model, NUM_CLUSTERS)

    # Calculate final KV cache size
    final_kv_cache_size = calculate_kv_cache_size_final(model, NUM_CLUSTERS)
    print(f"Final KV cache size: {final_kv_cache_size} elements")

    # Calculate KV cache reduction percentage
    kv_cache_reduction_percentage = ((initial_kv_cache_size - final_kv_cache_size) / initial_kv_cache_size) * 100
    print(f"KV cache reduction percentage: {kv_cache_reduction_percentage:.2f}%")

    print(f"Evaluating model on {dataset_name}...")
    # The throughput is already printed by evaluate_model; only accuracy is used here.
    accuracy, _ = evaluate_model(model, tokenized_inputs, labels)
    print(f"Accuracy on {dataset_name}: {accuracy:.4f}")

def main():
    """
    Process each benchmark dataset in turn and print timing information.

    Prints the wall-clock time spent on every dataset and the overall total.
    """
    overall_began = time.time()

    for name in ["hellaswag", "piqa"]:
        print(f"Processing dataset: {name}")
        began = time.time()
        process_dataset(name)
        finished = time.time()
        print(f"Total time for {name}: {finished - began:.2f} seconds")

    overall_finished = time.time()
    print(f"Overall execution time: {overall_finished - overall_began:.2f} seconds")

# Run the benchmark only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
