In [3]:
import os
import time
import torch
import psutil
import numpy as np
from tqdm import tqdm
from transformers import AutoTokenizer, OPTForCausalLM
from sklearn.cluster import KMeans
from datasets import load_dataset

# Paths and parameters
MODEL_NAME = "facebook/opt-350m"
SAVE_DIR = "./results"
TOP_LOW_SENSITIVITY_PERCENT = 100  # Focus on the least sensitive 30% layers
NUM_CLUSTERS = 37  # Number of clusters for attention heads
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
MAX_SAMPLES = 100  # Limit for dataset size during evaluation

os.makedirs(SAVE_DIR, exist_ok=True)

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = OPTForCausalLM.from_pretrained(MODEL_NAME).to(DEVICE)
model.eval()


def report_memory():
    """
    Report memory usage in both GPU and CPU contexts.
    """
    memory_report = {}

    if torch.cuda.is_available():
        memory_report["GPU Memory Allocated (MB)"] = torch.cuda.memory_allocated() / 1e6
        memory_report["GPU Memory Reserved (MB)"] = torch.cuda.memory_reserved() / 1e6
    else:
        memory_report["GPU Memory"] = "CUDA not available"

    process = psutil.Process()
    memory_info = process.memory_info()
    memory_report["CPU Memory (MB)"] = memory_info.rss / 1e6

    return memory_report


def preprocess_data(dataset, max_samples=100, max_length=512):
    inputs = []
    labels = []

    for i, example in enumerate(dataset):
        if i >= max_samples:
            break
        goal = example.get("goal", "")
        sol1 = example.get("sol1", "")
        sol2 = example.get("sol2", "")
        label = example.get("label", None)
        text = f"{goal} {sol1} {sol2}"
        inputs.append(text)
        labels.append(label)

    tokenized_inputs = tokenizer(
        inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length
    )
    return tokenized_inputs, labels


def compute_sensitivity_scores(model, tokenized_inputs):
    num_layers = model.config.num_hidden_layers
    num_heads = model.config.num_attention_heads
    sensitivity_scores = torch.zeros(num_layers, num_heads, device=DEVICE)

    input_ids = tokenized_inputs["input_ids"].to(DEVICE)
    attention_mask = tokenized_inputs["attention_mask"].to(DEVICE)

    with torch.no_grad():
        original_output = model(input_ids, attention_mask=attention_mask).logits

    for layer in tqdm(range(num_layers), desc="Computing sensitivity scores"):
        for head in range(num_heads):
            def hook_fn(module, input, output):
                if output.ndim == 4:
                    output[:, head, :, :] = 0
                elif output.ndim == 3:
                    embed_dim = output.size(-1)
                    head_dim = embed_dim // num_heads
                    start_idx = head * head_dim
                    end_idx = (head + 1) * head_dim
                    output[:, :, start_idx:end_idx] = 0
                else:
                    raise ValueError(f"Unexpected output shape: {output.shape}")
                return output

            hook_handle = model.model.decoder.layers[layer].self_attn.out_proj.register_forward_hook(hook_fn)
            with torch.no_grad():
                perturbed_output = model(input_ids, attention_mask=attention_mask).logits
            sensitivity_scores[layer, head] = torch.norm(original_output - perturbed_output, p=2)
            hook_handle.remove()

    return sensitivity_scores


def identify_low_sensitivity_layers(sensitivity_scores, percent=TOP_LOW_SENSITIVITY_PERCENT):
    layer_scores = sensitivity_scores.mean(dim=1)
    num_layers_to_keep = int(len(layer_scores) * percent / 100)
    low_sensitivity_layers = torch.argsort(layer_scores)[:num_layers_to_keep]
    return low_sensitivity_layers


def extract_attention_keys(model, tokenized_inputs, low_sensitivity_layers):
    """
    Extract attention keys for specified low-sensitivity layers.
    """
    input_ids = tokenized_inputs["input_ids"].to(DEVICE)
    attention_mask = tokenized_inputs["attention_mask"].to(DEVICE)
    keys = []

    def hook_fn(module, input, output):
        batch_size, seq_length, embed_dim = output.shape
        num_heads = model.config.num_attention_heads
        head_dim = embed_dim // num_heads
        reshaped_output = output.view(batch_size, seq_length, num_heads, head_dim)
        reshaped_output = reshaped_output.permute(0, 2, 1, 3)  # [batch_size, num_heads, seq_length, head_dim]
        keys.append(reshaped_output.detach().cpu().numpy())

    with torch.no_grad():
        for layer in low_sensitivity_layers:
            hook_handle = model.model.decoder.layers[layer].self_attn.k_proj.register_forward_hook(hook_fn)
            model(input_ids, attention_mask=attention_mask)
            hook_handle.remove()

    return keys


def cluster_attention_heads(keys, num_clusters=NUM_CLUSTERS):
    """
    Cluster attention keys using KMeans.
    """
    reshaped_keys = [key.reshape(-1, key.shape[-1]) for key in keys]  # Flatten batch and sequence dimensions
    all_keys = np.concatenate(reshaped_keys, axis=0)  # Concatenate keys across layers
    print(f"Clustering on keys with shape: {all_keys.shape}")
    kmeans = KMeans(n_clusters=num_clusters, random_state=42).fit(all_keys)
    return kmeans
def calculate_kv_cache_size_full_model(model):
    """
    Calculate the total size of K and V projection weights for all layers in the OPT model.
    """
    total_size = 0
    num_layers = model.config.num_hidden_layers  # Total number of layers in the OPT model

    for layer in range(num_layers):
        # Access the self-attention layer in the decoder
        attention_layer = model.model.decoder.layers[layer].self_attn
        
        # Calculate sizes of key and value projection weights
        k_proj_size = attention_layer.k_proj.weight.numel()  # Number of elements in K projection weights
        v_proj_size = attention_layer.v_proj.weight.numel()  # Number of elements in V projection weights
        total_size += k_proj_size + v_proj_size

    return total_size

def calculate_kv_cache_size_final(model, num_clusters):
    """
    Calculate the total size of K and V projection weights for the clustered model.
    """
    total_size = 0
    num_layers = model.config.num_hidden_layers
    head_dim = model.config.hidden_size // model.config.num_attention_heads  # Per-head dimension

    for _ in range(num_layers):
        # Each cluster contributes one set of weights for keys and values
        k_proj_size = num_clusters * head_dim
        print("num_clusters")
        print(num_clusters)
        print("head_dim")
        print(head_dim)
        v_proj_size = num_clusters * head_dim
        total_size += k_proj_size + v_proj_size

    return total_size
def calculate_kv_cache_size(model, low_sensitivity_layers):
    """
    Calculate the total size of K and V projection weights for the specified low-sensitivity layers.
    Works specifically for OPT models (e.g., OPT-350M).
    """
    total_size = 0
    for layer in low_sensitivity_layers:
        # Access the self-attention layer in the decoder
        attention_layer = model.model.decoder.layers[layer].self_attn

        # Calculate sizes of key and value projection weights
        k_proj_size = attention_layer.k_proj.weight.numel()  # Number of elements in K projection weights
        v_proj_size = attention_layer.v_proj.weight.numel()  # Number of elements in V projection weights
        total_size += k_proj_size + v_proj_size

    return total_size
def optimize_kv_cache_with_clustering(model, low_sensitivity_layers, cluster_centers):
    """
    Replace K-V projection matrices in low-sensitivity layers using cluster centers.
    Optimizes memory and computation for low-sensitivity attention heads.
    Works specifically for OPT models (e.g., OPT-250M).
    """
    for layer in low_sensitivity_layers:
        # Access the self-attention layer of the OPT model
        attention_layer = model.model.decoder.layers[layer].self_attn
        num_heads = model.config.num_attention_heads
        embed_dim = attention_layer.k_proj.weight.size(1)  # Full embedding dimension (e.g., 512 for OPT-250M)
        head_dim = embed_dim // num_heads  # Per-head dimension (e.g., 64)

        # Validate cluster centers' shape
        assert cluster_centers.shape[1] == head_dim, (
            f"Cluster centers dimension {cluster_centers.shape[1]} does not match head_dim {head_dim}"
        )

        # Initialize new projection matrices for key and value projections
        new_k_proj = torch.zeros_like(attention_layer.k_proj.weight, device=attention_layer.k_proj.weight.device)
        new_v_proj = torch.zeros_like(attention_layer.v_proj.weight, device=attention_layer.v_proj.weight.device)

        # Distribute cluster centers across heads
        cluster_assignments = np.random.randint(0, cluster_centers.shape[0], size=num_heads)  # Random assignment
        for head in range(num_heads):
            cluster_id = cluster_assignments[head]
            start_idx = head * head_dim
            end_idx = (head + 1) * head_dim

            # Assign cluster center to the slice of the projection matrix
            # Expand the cluster center to match the [head_dim, embed_dim] shape
            cluster_center = torch.tensor(cluster_centers[cluster_id], device=new_k_proj.device)
            new_k_proj[start_idx:end_idx, :] = cluster_center.unsqueeze(1).expand(head_dim, embed_dim)
            new_v_proj[start_idx:end_idx, :] = cluster_center.unsqueeze(1).expand(head_dim, embed_dim)

        # Update the model's key and value projection weights
        with torch.no_grad():
            attention_layer.k_proj.weight.copy_(new_k_proj)
            attention_layer.v_proj.weight.copy_(new_v_proj)

        print(f"Layer {layer}: Key and Value weights updated with clustered weights.")

    return model





def evaluate_model(model, tokenized_inputs, labels):
    input_ids = tokenized_inputs["input_ids"].to(DEVICE)
    attention_mask = tokenized_inputs["attention_mask"].to(DEVICE)

    start_time = time.time()
    correct = 0

    for i in tqdm(range(len(labels)), desc="Evaluating model"):
        input_id = input_ids[i].unsqueeze(0)
        attention_mask_id = attention_mask[i].unsqueeze(0)

        with torch.no_grad():
            logits = model(input_id, attention_mask=attention_mask_id).logits

        final_token_logits = logits[0, -1, :]
        option1_score = final_token_logits[tokenizer.convert_tokens_to_ids("1")].item()
        option2_score = final_token_logits[tokenizer.convert_tokens_to_ids("2")].item()

        prediction = 0 if option1_score > option2_score else 1
        if prediction == labels[i]:
            correct += 1

    end_time = time.time()
    accuracy = correct / len(labels)
    inference_time = end_time - start_time
    return accuracy, inference_time


def main():
    dataset = load_dataset("piqa", split="validation")
    tokenized_inputs, labels = preprocess_data(dataset, max_samples=MAX_SAMPLES)

    baseline_kv_cache_size = calculate_kv_cache_size_full_model(model)
    print(f"Baseline K,V Cache Size: {baseline_kv_cache_size} elements")

    sensitivity_scores = compute_sensitivity_scores(model, tokenized_inputs)
    low_sensitivity_layers = identify_low_sensitivity_layers(sensitivity_scores)

    attention_keys = extract_attention_keys(model, tokenized_inputs, low_sensitivity_layers)
    kmeans = cluster_attention_heads(attention_keys)
    baseline_accuracy, baseline_time = evaluate_model(model, tokenized_inputs, labels)
    print(f"Baseline Accuracy: {baseline_accuracy}, Baseline Time: {baseline_time}")

    clustered_model = optimize_kv_cache_with_clustering(model, low_sensitivity_layers, kmeans.cluster_centers_)

    clustered_accuracy, clustered_time = evaluate_model(clustered_model, tokenized_inputs, labels)
    clustered_memory = report_memory()

    print(f"Clustered Accuracy: {clustered_accuracy}, Clustered Time: {clustered_time}")
    print(f"Clustered Memory Usage: {clustered_memory}")
    # print(f"Memory reduction Percentage (CPU): {memory_reduction*100/clustered_memory} MB")
    print(f"speed: {baseline_time/clustered_time}x")
    clustered_kv_cache_size = calculate_kv_cache_size(clustered_model, low_sensitivity_layers)
    print(f"Clustered K,V Cache Size: {clustered_kv_cache_size} elements")

    # Calculate K,V cache reduction percentage
    kv_cache_reduction = (
        (baseline_kv_cache_size - clustered_kv_cache_size) / baseline_kv_cache_size
    ) * 100
    print(f"K,V Cache Reduction Percentage: {kv_cache_reduction:.2f}%")


if __name__ == "__main__":
    main()




Baseline K,V Cache Size: 50331648 elements


Computing sensitivity scores: 100%|██████████| 24/24 [07:55<00:00, 19.80s/it]


Clustering on keys with shape: (8064000, 64)


Evaluating model: 100%|██████████| 100/100 [00:01<00:00, 55.06it/s]


Baseline Accuracy: 0.51, Baseline Time: 1.8173658847808838
Layer 4: Key and Value weights updated with clustered weights.
Layer 14: Key and Value weights updated with clustered weights.
Layer 3: Key and Value weights updated with clustered weights.
Layer 17: Key and Value weights updated with clustered weights.
Layer 2: Key and Value weights updated with clustered weights.
Layer 15: Key and Value weights updated with clustered weights.
Layer 16: Key and Value weights updated with clustered weights.
Layer 6: Key and Value weights updated with clustered weights.
Layer 18: Key and Value weights updated with clustered weights.
Layer 11: Key and Value weights updated with clustered weights.
Layer 12: Key and Value weights updated with clustered weights.
Layer 13: Key and Value weights updated with clustered weights.
Layer 8: Key and Value weights updated with clustered weights.
Layer 7: Key and Value weights updated with clustered weights.
Layer 5: Key and Value weights updated with cluster

Evaluating model: 100%|██████████| 100/100 [00:01<00:00, 55.39it/s]

Clustered Accuracy: 0.54, Clustered Time: 1.8071975708007812
Clustered Memory Usage: {'GPU Memory Allocated (MB)': 2658.093056, 'GPU Memory Reserved (MB)': 24188.551168, 'CPU Memory (MB)': 7748.448256}
speed: 1.0056265646570104x
Clustered K,V Cache Size: 50331648 elements
K,V Cache Reduction Percentage: 0.00%



