In [1]:
pip install datasets

Defaulting to user installation because normal site-packages is not writeable
Note: you may need to restart the kernel to use updated packages.


In [1]:
import torch
import numpy as np
from tqdm import tqdm
from sklearn.cluster import KMeans
from transformers import AutoTokenizer, BertForSequenceClassification
from datasets import load_dataset
from sklearn.metrics import accuracy_score
from torch import nn

# Paths and parameters
MODEL_NAME = "bert-base-uncased"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
NUM_CLUSTERS = 200  # Number of clusters for attention heads
MAX_SAMPLES = 100  # Limit for dataset size during evaluation

# Load model and tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2).to(DEVICE)
model.eval()


def preprocess_data(dataset, max_samples=100, max_length=512):
    """
    Preprocess the PIQA dataset for binary classification.
    Each example consists of a goal and two solutions (sol1, sol2).
    The task is to predict which solution is better (label: 0 or 1).
    """
    inputs = []
    labels = []

    for i, example in enumerate(dataset):
        if i >= max_samples:
            break
        goal = example["goal"]
        sol1 = example["sol1"]
        sol2 = example["sol2"]
        label = example["label"]

        # Concatenate the goal and solutions into a single input
        text = f"Goal: {goal} [SEP] Solution 1: {sol1} [SEP] Solution 2: {sol2}"
        inputs.append(text)
        labels.append(label)

    # Tokenize the inputs
    tokenized_inputs = tokenizer(
        inputs, return_tensors="pt", padding=True, truncation=True, max_length=max_length
    )
    return tokenized_inputs, torch.tensor(labels)

def extract_attention_heads(model, tokenized_inputs):
    """
    Extract attention keys and values for all layers and heads.
    """
    input_ids = tokenized_inputs["input_ids"].to(DEVICE)
    attention_mask = tokenized_inputs["attention_mask"].to(DEVICE)
    keys, values = [], []

    def key_hook(module, input, output):
        # Dynamically compute dimensions
        batch_size, seq_length, embed_dim = output.size()
        num_heads = model.config.num_attention_heads
        head_dim = embed_dim // num_heads

        print(f"Key output shape: {output.shape}")
        print(f"Batch size: {batch_size}, Seq length: {seq_length}, Embed dim: {embed_dim}")
        print(f"Num heads: {num_heads}, Head dim: {head_dim}")

        reshaped_keys = output.view(batch_size, seq_length, num_heads, head_dim)
        reshaped_keys = reshaped_keys.permute(0, 2, 1, 3)  # [batch_size, num_heads, seq_length, head_dim]
        keys.append(reshaped_keys.detach().cpu().numpy())

    def value_hook(module, input, output):
        # Dynamically compute dimensions
        batch_size, seq_length, embed_dim = output.size()
        num_heads = model.config.num_attention_heads
        head_dim = embed_dim // num_heads

        print(f"Value output shape: {output.shape}")
        print(f"Batch size: {batch_size}, Seq length: {seq_length}, Embed dim: {embed_dim}")
        print(f"Num heads: {num_heads}, Head dim: {head_dim}")

        reshaped_values = output.view(batch_size, seq_length, num_heads, head_dim)
        reshaped_values = reshaped_values.permute(0, 2, 1, 3)  # [batch_size, num_heads, seq_length, head_dim]
        values.append(reshaped_values.detach().cpu().numpy())

    with torch.no_grad():
        for layer in range(model.config.num_hidden_layers):
            key_handle = model.bert.encoder.layer[layer].attention.self.key.register_forward_hook(key_hook)
            value_handle = model.bert.encoder.layer[layer].attention.self.value.register_forward_hook(value_hook)
            model(input_ids, attention_mask=attention_mask)
            key_handle.remove()
            value_handle.remove()

    # Concatenate across layers
    keys = np.concatenate(keys, axis=1)  # Combine all layers for keys
    values = np.concatenate(values, axis=1)  # Combine all layers for values
    return keys, values



def cluster_heads(keys, values, num_clusters):
    """
    Cluster keys and values using KMeans and return the cluster centers.
    """
    flattened_keys = keys.reshape(-1, keys.shape[-1])  # Flatten [num_heads * seq_length, head_dim]
    flattened_values = values.reshape(-1, values.shape[-1])  # Flatten [num_heads * seq_length, head_dim]

    print(f"Clustering keys and values with shapes: {flattened_keys.shape}, {flattened_values.shape}")
    kmeans_keys = KMeans(n_clusters=num_clusters, random_state=42).fit(flattened_keys)
    kmeans_values = KMeans(n_clusters=num_clusters, random_state=42).fit(flattened_values)

    return kmeans_keys.cluster_centers_, kmeans_values.cluster_centers_


def apply_clustered_projections(model, cluster_centers_keys, cluster_centers_values, num_clusters):
    """
    Apply clustering to all layers in the model without considering sensitivities.
    """
    for layer in range(model.config.num_hidden_layers):
        attention_layer = model.bert.encoder.layer[layer].attention.self
        num_heads = model.config.num_attention_heads
        embed_dim = attention_layer.key.weight.size(1)
        head_dim = embed_dim // num_heads
        reduced_dim = num_clusters * head_dim

        # Create new projection matrices based on the cluster centers
        new_key_weight = torch.zeros(
            (reduced_dim, embed_dim), device=attention_layer.key.weight.device
        )
        new_value_weight = torch.zeros(
            (reduced_dim, embed_dim), device=attention_layer.value.weight.device
        )

        # Populate the new weights using cluster centers
        for cluster_idx in range(num_clusters):
            cluster_key_center = torch.tensor(cluster_centers_keys[cluster_idx], device=new_key_weight.device)
            cluster_value_center = torch.tensor(cluster_centers_values[cluster_idx], device=new_value_weight.device)

            start_idx = cluster_idx * head_dim
            end_idx = (cluster_idx + 1) * head_dim

            # Use the cluster centers as the new weight values
            new_key_weight[start_idx:end_idx, :] = cluster_key_center.unsqueeze(1).expand(head_dim, embed_dim)
            new_value_weight[start_idx:end_idx, :] = cluster_value_center.unsqueeze(1).expand(head_dim, embed_dim)

        # Replace the original projection weights with clustered weights
        with torch.no_grad():
            attention_layer.key = nn.Linear(embed_dim, reduced_dim, bias=False)
            attention_layer.key.weight = nn.Parameter(new_key_weight)

            attention_layer.value = nn.Linear(embed_dim, reduced_dim, bias=False)
            attention_layer.value.weight = nn.Parameter(new_value_weight)

        print(f"Layer {layer}: Key and Value weights replaced with clustered projections.")

    # Update the model configuration to reflect the reduced number of clusters
    model.config.num_attention_heads = num_clusters
    return model


def evaluate_model(model, tokenized_inputs, labels):
    """
    Evaluate the model and compute accuracy.
    """
    input_ids = tokenized_inputs["input_ids"].to(DEVICE)
    attention_mask = tokenized_inputs["attention_mask"].to(DEVICE)

    predictions = []
    with torch.no_grad():
        logits = model(input_ids, attention_mask=attention_mask).logits
        predictions = torch.argmax(logits, dim=-1).cpu().numpy()

    return accuracy_score(labels.cpu().numpy(), predictions)


def main():
    dataset = load_dataset("piqa", split="validation")
    tokenized_inputs, labels = preprocess_data(dataset, max_samples=MAX_SAMPLES)
    
    # Extract attention keys and values
    keys, values = extract_attention_heads(model, tokenized_inputs)

    # Cluster the keys and values
    cluster_centers_keys, cluster_centers_values = cluster_heads(keys, values, NUM_CLUSTERS)

    # Apply clustered projections
    clustered_model = apply_clustered_projections(model, cluster_centers_keys, cluster_centers_values, NUM_CLUSTERS)

    # Evaluate clustered model
    clustered_accuracy = evaluate_model(clustered_model, tokenized_inputs, labels)
    print(f"Clustered Accuracy: {clustered_accuracy}")


if __name__ == "__main__":
    main()


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]



model.safetensors:   0%|          | 0.00/440M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


piqa.py:   0%|          | 0.00/5.36k [00:00<?, ?B/s]

README.md:   0%|          | 0.00/8.41k [00:00<?, ?B/s]

The repository for piqa contains custom code which must be executed to correctly load the dataset. You can inspect the repository content at https://hf.co/datasets/piqa.
You can avoid this prompt in future by passing the argument `trust_remote_code=True`.

Do you wish to run the custom code? [y/N]  y


Downloading data:   0%|          | 0.00/1.82M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/815k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/16113 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3084 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1838 [00:00<?, ? examples/s]

Key output shape: torch.Size([100, 218, 768])
Batch size: 100, Seq length: 218, Embed dim: 768
Num heads: 12, Head dim: 64
Value output shape: torch.Size([100, 218, 768])
Batch size: 100, Seq length: 218, Embed dim: 768
Num heads: 12, Head dim: 64
Key output shape: torch.Size([100, 218, 768])
Batch size: 100, Seq length: 218, Embed dim: 768
Num heads: 12, Head dim: 64
Value output shape: torch.Size([100, 218, 768])
Batch size: 100, Seq length: 218, Embed dim: 768
Num heads: 12, Head dim: 64
Key output shape: torch.Size([100, 218, 768])
Batch size: 100, Seq length: 218, Embed dim: 768
Num heads: 12, Head dim: 64
Value output shape: torch.Size([100, 218, 768])
Batch size: 100, Seq length: 218, Embed dim: 768
Num heads: 12, Head dim: 64
Key output shape: torch.Size([100, 218, 768])
Batch size: 100, Seq length: 218, Embed dim: 768
Num heads: 12, Head dim: 64
Value output shape: torch.Size([100, 218, 768])
Batch size: 100, Seq length: 218, Embed dim: 768
Num heads: 12, Head dim: 64
Key outp

RuntimeError: shape '[100, 218, 12, 64]' is invalid for input of size 279040000