In [None]:
pip install datasets

In [None]:
import time
import torch
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans
from transformers import AutoTokenizer, OPTForSequenceClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from torch import nn

# Paths and parameters
MODEL_NAME = "facebook/opt-350m"  # Hugging Face model id; change to another OPT checkpoint as needed
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Prefer the GPU when PyTorch can see one
NUM_CLUSTERS = 2000  # Clusters per layer requested from cluster_attention_heads (also its synthetic sample count)
MAX_SAMPLES = 100  # Max number of validation examples passed to preprocess_data per dataset

# Load model and tokenizer once at import time; the functions below read these globals
# (preprocess_data uses `tokenizer`, process_dataset uses `model`).
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
# NOTE(review): num_labels=2 attaches a binary classification head; the base OPT checkpoint
# presumably has no trained weights for it, so "Initial Accuracy" is likely near chance until
# the head is fine-tuned — confirm against the "newly initialized weights" load warning.
model = OPTForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2).to(DEVICE)
model.eval()  # Inference mode (disables dropout); gradients are turned off separately via torch.no_grad()


def calculate_layer_sensitivities(model, tokenized_inputs):
    """
    Calculate a per-layer sensitivity score from the model's attention probabilities.

    For each decoder layer, sensitivity is approximated by the variance of the
    attention probabilities over all heads and all (query, key) pairs of real
    (non-padding) tokens. The variance is computed per example and then averaged
    over the batch.

    Args:
        model: A Hugging Face model whose forward pass accepts
            ``output_attentions=True`` and returns an output with an
            ``attentions`` tuple (one ``[batch, heads, seq, seq]`` tensor per layer).
        tokenized_inputs: Mapping with ``"input_ids"`` and ``"attention_mask"``
            tensors of shape ``[batch, seq]``.

    Returns:
        list[float]: One sensitivity value per layer, in layer order.

    Raises:
        RuntimeError: If the model does not return attention probabilities.
    """
    # Move inputs to wherever the model's weights live, so they always match.
    device = next(model.parameters()).device
    input_ids = tokenized_inputs["input_ids"].to(device)
    attention_mask = tokenized_inputs["attention_mask"].to(device)

    # Ask the model for the attention probabilities directly. A forward hook on
    # `self_attn` that reads `output[0]` sees the attention *output* (hidden states),
    # not the attention scores this function is meant to measure.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask, output_attentions=True)

    attentions = getattr(outputs, "attentions", None)
    if not attentions:
        raise RuntimeError(
            "Model did not return attention probabilities (output_attentions=True)."
        )

    # pair_mask[b, 0, q, k] is True only when both query q and key k are real tokens,
    # so padded positions (near-zero probabilities) do not dilute the variance.
    token_mask = attention_mask.bool()
    pair_mask = (token_mask[:, :, None] & token_mask[:, None, :]).unsqueeze(1)

    reduce_dims = (1, 2, 3)  # heads, query positions, key positions
    sensitivities = []
    for layer_attention in attentions:
        probs = layer_attention.float()
        mask = pair_mask.expand_as(probs).to(probs.dtype)
        count = mask.sum(dim=reduce_dims).clamp(min=1.0)
        mean = (probs * mask).sum(dim=reduce_dims) / count
        centered = (probs - mean[:, None, None, None]) * mask
        variance = (centered ** 2).sum(dim=reduce_dims) / count
        # Average the per-example variances over the batch.
        sensitivities.append(variance.mean().item())

    return sensitivities

def calculate_kv_cache_size_final(model, num_clusters):
    """
    Return the total number of K and V projection weight elements of the clustered model.

    Each layer keeps ``num_clusters`` key centroids and ``num_clusters`` value
    centroids, each one attention head wide
    (``hidden_size // num_attention_heads``). The result is therefore
    ``num_hidden_layers * 2 * num_clusters * head_dim``.
    """
    cfg = model.config
    layer_count = cfg.num_hidden_layers
    per_head_width = cfg.hidden_size // cfg.num_attention_heads

    # Keys and values contribute equally: num_clusters * per_head_width elements each.
    per_layer = num_clusters * per_head_width + num_clusters * per_head_width
    return sum(per_layer for _ in range(layer_count))


def cluster_attention_heads(model, num_clusters, num_samples=None, random_state=42):
    """
    Run k-means on per-layer key/value data and return the cluster centres.

    NOTE(review): the data clustered here is still a synthetic placeholder
    (uniform random vectors of width ``model.config.hidden_size``). It is not
    real key/value activations or attention-head weights, and layer
    sensitivities are not used. The returned centres are therefore not yet
    meaningful for the model.

    Args:
        model: Model whose ``config`` provides ``num_hidden_layers`` and ``hidden_size``.
        num_clusters: Number of k-means clusters per layer, for keys and for values.
        num_samples: Number of synthetic vectors generated per layer. Defaults to
            the module-level ``NUM_CLUSTERS``.
        random_state: Seed for both the synthetic data and k-means, so repeated
            calls return identical centres.

    Returns:
        tuple[list[np.ndarray], list[np.ndarray]]: Key centres and value centres,
        one ``[num_clusters, hidden_size]`` array per layer.

    Raises:
        ValueError: If ``num_clusters`` < 1 or ``num_samples`` < ``num_clusters``.
    """
    if num_samples is None:
        num_samples = NUM_CLUSTERS
    if num_clusters < 1:
        raise ValueError(f"num_clusters must be >= 1, got {num_clusters}")
    if num_samples < num_clusters:
        # k-means cannot form more clusters than there are samples; fail early and clearly.
        raise ValueError(
            f"num_samples ({num_samples}) must be >= num_clusters ({num_clusters})"
        )

    # A seeded generator makes the placeholder data, and hence the centres, reproducible.
    # Seeding KMeans alone is not enough if the data itself is unseeded.
    rng = np.random.default_rng(random_state)
    hidden_size = model.config.hidden_size

    cluster_centers_keys = []
    cluster_centers_values = []
    for _ in range(model.config.num_hidden_layers):
        k_data = rng.random((num_samples, hidden_size))
        v_data = rng.random((num_samples, hidden_size))

        # NOTE: when num_samples == num_clusters, each cluster typically holds a single sample.
        kmeans_keys = KMeans(n_clusters=num_clusters, random_state=random_state).fit(k_data)
        cluster_centers_keys.append(kmeans_keys.cluster_centers_)

        kmeans_values = KMeans(n_clusters=num_clusters, random_state=random_state).fit(v_data)
        cluster_centers_values.append(kmeans_values.cluster_centers_)

    return cluster_centers_keys, cluster_centers_values


def _build_input_text(example):
    """Format one example as model input text; raise ValueError for unknown layouts."""
    if "question" in example:  # BoolQ-like datasets
        context = example.get("context", example.get("passage", ""))
        return f"Question: {example['question']} Context: {context}"
    if "ctx" in example and "endings" in example:  # HellaSwag
        # Only the first candidate ending is used.
        return f"Context: {example['ctx']} Ending: {example['endings'][0]}"
    raise ValueError("Unsupported dataset format or missing keys.")


def _extract_label(example):
    """Return the example's integer label, or None when it has no usable label."""
    for key in ("answer", "label", "gold_label"):
        if key in example:
            value = example[key]
            break
    else:
        return None
    # None / empty string (e.g. unlabeled splits) count as a missing label.
    if value is None or value == "":
        return None
    # int() maps booleans to 0/1 and numeric strings such as "2" to ints, so that
    # torch.tensor(..., dtype=torch.long) below never receives strings.
    return int(value)


def preprocess_data(dataset, max_samples=100, max_length=512, text_tokenizer=None):
    """
    Turn raw dataset examples into a tokenized batch and a label tensor.

    Args:
        dataset: Iterable of dict-like examples (BoolQ-like with ``question`` or
            HellaSwag-like with ``ctx``/``endings``).
        max_samples: Maximum number of *labeled* examples to keep. Examples
            skipped for lack of a label do not count toward this limit.
        max_length: Truncation length passed to the tokenizer.
        text_tokenizer: Tokenizer callable to use. Defaults to the module-level
            ``tokenizer``.

    Returns:
        tuple: (tokenized batch as returned by the tokenizer with
        ``return_tensors="pt"``, ``torch.long`` tensor of labels).

    Raises:
        ValueError: If an example has an unsupported layout, or if no labeled
            examples are found.
    """
    if text_tokenizer is None:
        text_tokenizer = tokenizer

    inputs = []
    labels = []
    for example in dataset:
        if len(inputs) >= max_samples:
            break

        text = _build_input_text(example)
        label = _extract_label(example)
        if label is None:
            print(f"Skipping example due to missing label: {example}")
            continue

        inputs.append(text)
        labels.append(label)

    if not inputs:
        raise ValueError("No labeled examples found; nothing to tokenize.")

    tokenized_inputs = text_tokenizer(
        inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length
    )
    return tokenized_inputs, torch.tensor(labels, dtype=torch.long)


def evaluate_model(model, tokenized_inputs, labels):
    """
    Return the classification accuracy of ``model`` on a pre-tokenized batch.

    The whole batch goes through a single gradient-free forward pass on
    ``DEVICE``. Each prediction is the argmax over the last dimension of the
    logits.
    """
    batch = {name: tokenized_inputs[name].to(DEVICE) for name in ("input_ids", "attention_mask")}

    with torch.no_grad():
        output = model(batch["input_ids"], attention_mask=batch["attention_mask"])
        predicted = output.logits.argmax(dim=-1).cpu().numpy()

    expected = labels.cpu().numpy()
    return accuracy_score(expected, predicted)


def process_dataset(dataset_name):
    """
    Run the evaluation and analysis pipeline on the validation split of one dataset.

    Steps:
        1. Load and preprocess up to ``MAX_SAMPLES`` validation examples.
        2. Report the model's accuracy.
        3. Compute and report per-layer attention sensitivities.
        4. Run the key/value clustering.

    Args:
        dataset_name: Name passed to ``datasets.load_dataset``.

    Returns:
        dict: ``initial_accuracy``, ``sensitivities``, ``cluster_centers_keys`` and
        ``cluster_centers_values``, so callers can reuse the results. Callers may
        ignore the return value.
    """
    dataset = load_dataset(dataset_name, split="validation")
    tokenized_inputs, labels = preprocess_data(dataset, max_samples=MAX_SAMPLES)

    initial_accuracy = evaluate_model(model, tokenized_inputs, labels)
    print(f"\nDataset: {dataset_name}")
    print(f"Initial Accuracy: {initial_accuracy:.4f}")

    # This costs a full extra forward pass, so report the result instead of discarding it.
    sensitivities = calculate_layer_sensitivities(model, tokenized_inputs)
    print("Layer sensitivities: " + ", ".join(f"{value:.6f}" for value in sensitivities))

    cluster_centers_keys, cluster_centers_values = cluster_attention_heads(model, NUM_CLUSTERS)

    print(f"Cluster centers (keys): {np.array(cluster_centers_keys).shape}")
    print(f"Cluster centers (values): {np.array(cluster_centers_values).shape}")

    return {
        "initial_accuracy": initial_accuracy,
        "sensitivities": sensitivities,
        "cluster_centers_keys": cluster_centers_keys,
        "cluster_centers_values": cluster_centers_values,
    }


def main():
    """Run the evaluation/analysis pipeline on every configured dataset (currently BoolQ only)."""
    dataset_names = ("boolq",)
    for name in dataset_names:
        process_dataset(name)


if __name__ == "__main__":
    main()
