In [1]:
import time
import torch
from transformers import AutoTokenizer, OPTForCausalLM
from datasets import load_dataset
from sklearn.cluster import KMeans
import numpy as np
from tqdm import tqdm
import os

In [3]:
from transformers import AutoTokenizer, BertForSequenceClassification

# Hugging Face checkpoint identifier used for both the tokenizer and the model.
MODEL_NAME = "bert-base-uncased"  # Change the model name
# Number of clusters for head clustering. Not referenced in the visible code;
# presumably consumed by the clustering helper defined elsewhere -- TODO confirm.
NUM_CLUSTERS = 12  # Adjust clusters based on num_heads
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Load model and tokenizer (module-level side effect: the model is moved to DEVICE)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME).to(DEVICE)
# Switch to evaluation mode; the functions below only run inference on the model.
model.eval()

def extract_attention_keys(model, tokenized_inputs, low_sensitivity_layers):
    """
    Capture the key-projection outputs of selected BERT encoder layers.

    A forward hook is attached to ``layer.attention.self.key`` of every encoder
    layer whose index is in ``low_sensitivity_layers``. A single forward pass is
    then run and each hooked output is split into attention heads.

    Args:
        model: Model exposing ``bert.encoder.layer`` and
            ``config.num_attention_heads`` and callable as
            ``model(input_ids, attention_mask=...)``.
        tokenized_inputs: Mapping holding ``"input_ids"`` and
            ``"attention_mask"`` tensors.
        low_sensitivity_layers: Container of encoder layer indices to capture.

    Returns:
        A list of numpy arrays, one per hooked layer in the order the layers
        run during the forward pass, each shaped
        ``[batch_size, num_heads, seq_length, head_dim]``. The list is empty
        (and no forward pass is run) when no layer is selected.
    """
    # Send the inputs to wherever the model actually lives, so the call cannot
    # fail on a device mismatch if the model was moved after loading.
    first_param = next(model.parameters(), None)
    input_ids = tokenized_inputs["input_ids"]
    attention_mask = tokenized_inputs["attention_mask"]
    if first_param is not None:
        input_ids = input_ids.to(first_param.device)
        attention_mask = attention_mask.to(first_param.device)

    num_heads = model.config.num_attention_heads
    keys = []

    def capture_key(module, inputs, output):
        """Split a 3D key projection into heads and store it as a numpy array."""
        batch_size, seq_length, embed_dim = output.shape
        head_dim = embed_dim // num_heads
        per_head = output.view(batch_size, seq_length, num_heads, head_dim)
        per_head = per_head.permute(0, 2, 1, 3)  # [batch_size, num_heads, seq_length, head_dim]
        keys.append(per_head.detach().cpu().numpy())

    # Register every hook up front: one forward pass then yields all captures,
    # instead of re-running the whole model once per selected layer.
    handles = [
        layer.attention.self.key.register_forward_hook(capture_key)
        for layer_idx, layer in enumerate(model.bert.encoder.layer)
        if layer_idx in low_sensitivity_layers
    ]
    if not handles:
        return keys

    try:
        with torch.no_grad():
            model(input_ids, attention_mask=attention_mask)
    finally:
        # Always detach the hooks, even if the forward pass fails (e.g. CUDA
        # out of memory), so they do not leak into later calls on the model.
        for handle in handles:
            handle.remove()

    return keys


def optimize_kv_cache_with_clustering(model, low_sensitivity_layers, cluster_centers):
    """
    Share key/value projection rows within clusters for selected BERT layers.

    For every encoder layer whose index is in ``low_sensitivity_layers``, each
    row of the key projection weight is assigned to its nearest cluster center
    (Euclidean distance). All key rows in a cluster are then overwritten with
    their mean, and the value rows at the same indices are overwritten with the
    mean of those value rows.

    The weights keep their original shape. The previous implementation replaced
    them with ``(num_clusters, in_features)`` matrices, which no longer match
    the layer's bias or the downstream reshape into heads, so any later forward
    pass failed. It also called ``.numpy()`` on a tensor that requires grad,
    which raises before any clustering happens.

    Args:
        model: Model exposing ``bert.encoder.layer``; it is modified in place.
        low_sensitivity_layers: Container of encoder layer indices to modify.
        cluster_centers: Array-like of shape ``(num_clusters, row_dim)`` where
            ``row_dim`` equals the last dimension of the key weight.

    Returns:
        The same ``model`` object, with the selected layers updated.

    Raises:
        ValueError: If the feature dimension of ``cluster_centers`` does not
            match the last dimension of the key weight.
    """
    centers = np.asarray(cluster_centers)
    num_clusters = centers.shape[0]

    for layer_idx, layer in enumerate(model.bert.encoder.layer):
        if layer_idx not in low_sensitivity_layers:
            continue

        attention_layer = layer.attention.self
        key_weight = attention_layer.key.weight
        value_weight = attention_layer.value.weight

        # NOTE: the printed labels are kept from the original output, but the
        # values are the key weight's row count and last-dimension size, not
        # the configured number of attention heads / per-head size.
        num_rows = key_weight.size(0)
        row_dim = key_weight.size(-1)

        print(f"Layer {layer_idx}: num_heads = {num_rows}, head_dim = {row_dim}")
        print(f"Cluster centers shape = {centers.shape}")

        # Explicit exception instead of ``assert``, which is stripped under -O.
        if centers.shape[1] != row_dim:
            raise ValueError(
                f"Cluster centers feature dimension {centers.shape[1]} does not match head_dim {row_dim}"
            )

        # detach() is required: numpy() refuses tensors that require grad.
        key_rows = key_weight.detach().view(num_rows, row_dim).cpu().numpy()
        # distances[c, r] = ||center_c - key_row_r||; pick the closest center per row.
        distances = np.linalg.norm(centers[:, None, :] - key_rows[None, :, :], axis=2)
        assignments = np.argmin(distances, axis=0)

        with torch.no_grad():
            row_cluster = torch.as_tensor(assignments, device=key_weight.device)
            for cluster_id in range(num_clusters):
                members = row_cluster == cluster_id
                if not members.any():
                    # No row was assigned to this center; nothing to share.
                    continue
                # Overwrite every member row with the cluster mean (broadcast).
                key_weight[members] = key_weight[members].mean(dim=0)
                value_weight[members] = value_weight[members].mean(dim=0)

    return model


def main():
    """
    Run the clustering experiment end to end on 100 SST-2 validation sentences.

    Steps: tokenize the data, score layer sensitivity, extract attention keys
    for the low-sensitivity layers, cluster them, evaluate the unmodified model,
    apply the clustering to the model, evaluate again, and print accuracy,
    timing and the memory change.

    Relies on helpers that are not defined in this part of the file
    (``compute_sensitivity_scores``, ``identify_low_sensitivity_layers``,
    ``cluster_attention_heads``, ``evaluate_model``, ``report_memory``).
    """
    # Load dataset
    dataset = load_dataset("glue", "sst2", split="validation[:100]")
    tokenized_inputs = tokenizer(dataset["sentence"], return_tensors="pt", padding=True, truncation=True)
    tokenized_inputs = {k: v.to(DEVICE) for k, v in tokenized_inputs.items()}

    # Example labels for SST-2
    labels = torch.tensor(dataset["label"]).to(DEVICE)

    # Compute sensitivity scores
    sensitivity_scores = compute_sensitivity_scores(model, tokenized_inputs)
    low_sensitivity_layers = identify_low_sensitivity_layers(sensitivity_scores)

    # Extract attention keys
    attention_keys = extract_attention_keys(model, tokenized_inputs, low_sensitivity_layers)

    # Perform clustering
    kmeans = cluster_attention_heads(attention_keys)

    # Evaluate baseline
    baseline_accuracy, baseline_time = evaluate_model(model, tokenized_inputs, labels)
    # Measure memory right after the baseline run, mirroring the clustered
    # measurement below. Previously ``baseline_memory`` was never assigned in
    # this function, so the final print depended on an unrelated global or
    # failed with NameError.
    baseline_memory = report_memory()
    print(f"Baseline Accuracy: {baseline_accuracy}, Baseline Time: {baseline_time}")

    # Apply clustering (modifies ``model`` in place and returns the same object)
    clustered_model = optimize_kv_cache_with_clustering(model, low_sensitivity_layers, kmeans.cluster_centers_)

    # Evaluate clustered model
    clustered_accuracy, clustered_time = evaluate_model(clustered_model, tokenized_inputs, labels)
    clustered_memory = report_memory()

    print(f"Clustered Accuracy: {clustered_accuracy}, Clustered Time: {clustered_time}")
    if baseline_memory:
        print(f"Memory Reduction Percentage: {((baseline_memory - clustered_memory) / baseline_memory) * 100}")
    else:
        # NOTE(review): assumes report_memory() returns a number that can be 0
        # (e.g. when nothing is allocated on the measured device) -- confirm.
        print("Memory Reduction Percentage: n/a (baseline memory is 0)")


# Entry point: run the experiment only when this file is executed directly,
# not when it is imported as a module.
if __name__ == "__main__":
    main()


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RuntimeError: CUDA error: out of memory
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
