In [None]:
pip install datasets

In [None]:
import time
import torch
import numpy as np
from sklearn.cluster import KMeans
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertForSequenceClassification, AutoTokenizer

# Constants
# ---------------------------------------------------------------------------
# Model / hardware
# ---------------------------------------------------------------------------
# Hugging Face checkpoint the classifier is loaded from in main().
MODEL_NAME = "bert-base-uncased"
# Use the GPU when PyTorch can see one, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# ---------------------------------------------------------------------------
# Compression / layer selection
# ---------------------------------------------------------------------------
# k-means centroids per attention key/value projection (see cluster_layers).
NUM_CLUSTERS = 16
# Fraction of encoder layers, ranked by sensitivity, that get fine-tuned.
SENSITIVE_LAYER_PERCENTAGE = 0.3

# ---------------------------------------------------------------------------
# Data / training
# ---------------------------------------------------------------------------
# Upper bound on dataset rows examined during preprocessing.
MAX_SAMPLES = 1_000
EPOCHS = 3
BATCH_SIZE = 16
LEARNING_RATE = 1e-4

# Preprocessing Data
def preprocess_data(dataset, max_samples=100, max_length=512):
    inputs = []
    labels = []

    for i, example in enumerate(dataset):
        if i >= max_samples:
            break

        # Extract text input based on dataset structure
        if "question" in example:  # For BoolQ-like datasets
            text = f"Question: {example['question']} Context: {example.get('context', example.get('passage', ''))}"
        elif "ctx" in example and "endings" in example:  # For HellaSwag
            text = f"Context: {example['ctx']} Ending: {example['endings'][0]}"  # Using the first ending
        else:
            raise ValueError("Unsupported dataset format or missing keys.")

        # Extract the label dynamically
        if "answer" in example:  # BoolQ-like datasets
            label = int(example["answer"])  # Convert boolean to integer (True=1, False=0)
        elif "label" in example:
            label = example["label"]
        elif "gold_label" in example:
            label = example["gold_label"]
        else:
            label = None  # Default if no valid label is found

        if label is None:
            print(f"Skipping example due to missing label: {example}")
            continue  # Skip this example

        inputs.append(text)
        labels.append(label)
    MODEL_NAME = "bert-base-cased"
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    tokenized_inputs = tokenizer(inputs, padding=True, truncation=True, return_tensors="pt")
    return tokenized_inputs, torch.tensor(labels, dtype=torch.long)


# Calculating Layer Sensitivities
def calculate_layer_sensitivities(model):
    """Score each BERT encoder layer by the spread of its attention weights.

    A layer's score is ``np.var(W_key) + np.var(W_value)``. Each variance is
    taken over all elements of that self-attention projection's weight matrix.

    Args:
        model: Object exposing ``model.bert.encoder.layer``, for example a
            Hugging Face ``BertForSequenceClassification``.

    Returns:
        list: One score per encoder layer, in layer order.
    """
    def _weight_variance(linear):
        # NumPy needs a detached tensor that lives on the CPU.
        return np.var(linear.weight.detach().cpu().numpy())

    return [
        _weight_variance(block.attention.self.key)
        + _weight_variance(block.attention.self.value)
        for block in model.bert.encoder.layer
    ]

# Identifying Top Sensitive Layers
def get_top_sensitive_layers(sensitivities, percentage):
    """Return the indices of the highest-scoring layers, in ascending order.

    Args:
        sensitivities: One score per layer (higher = more sensitive).
        percentage: Fraction of layers to select. The number of layers is
            ``int(len(sensitivities) * percentage)``, clamped to
            ``[0, len(sensitivities)]``.

    Returns:
        list[int]: Selected layer indices sorted ascending. Empty when the
        computed count is zero.
    """
    num_layers = len(sensitivities)
    num_sensitive_layers = min(max(int(num_layers * percentage), 0), num_layers)

    # Guard the zero case explicitly: ``argsort(...)[-0:]`` is the WHOLE array,
    # so without this a small percentage would select every layer.
    if num_sensitive_layers == 0:
        return []

    top_layers = np.argsort(sensitivities)[-num_sensitive_layers:]
    # Plain Python ints print cleanly and format identically in name patterns.
    return sorted(int(i) for i in top_layers)

# Clustering Layer Weights
def cluster_layers(model, num_clusters):
    """Quantize every encoder layer's attention key/value weights with k-means.

    Each projection's weight matrix is flattened to scalars and clustered into
    ``num_clusters`` groups with scikit-learn ``KMeans(random_state=0)``. Every
    weight is then overwritten in place by its cluster centre, so each matrix
    keeps at most ``num_clusters`` distinct values.

    Args:
        model: Object exposing ``model.bert.encoder.layer``.
        num_clusters: Number of k-means centroids per projection.
    """
    for layer in model.bert.encoder.layer:
        for proj_name in ("key", "value"):
            weight = getattr(layer.attention.self, proj_name).weight
            flattened_weights = weight.detach().cpu().numpy().reshape(-1, 1)

            # Apply KMeans clustering and map every weight to its centroid.
            kmeans = KMeans(n_clusters=num_clusters, random_state=0).fit(flattened_weights)
            clustered_weights = kmeans.cluster_centers_[kmeans.labels_].reshape(tuple(weight.shape))

            # Build the replacement on the weight's own device and dtype, not
            # the global DEVICE. This works wherever the model lives and avoids
            # a needless cross-device copy.
            with torch.no_grad():
                weight.copy_(
                    torch.as_tensor(clustered_weights, dtype=weight.dtype, device=weight.device)
                )

# Fine-Tuning Sensitive Layers
def fine_tune_model(model, dataloader, sensitive_layers, *, train_head=True,
                    epochs=None, learning_rate=None):
    """Train only the selected encoder layers, plus the classification head.

    Every parameter is frozen except those whose name contains
    ``bert.encoder.layer.<i>.`` for some ``i`` in ``sensitive_layers``. When
    ``train_head`` is true, the ``classifier.*`` parameters are trainable too.
    Trainable parameters are optimized with AdamW on cross-entropy loss.

    Args:
        model: Module whose forward accepts ``input_ids``/``attention_mask``
            and returns an object with ``.logits``.
        dataloader: Iterable of ``(input_ids, attention_mask, labels)`` batches.
        sensitive_layers: Encoder layer indices to unfreeze.
        train_head: Also train the classification head. Loading a plain
            pre-trained BERT checkpoint into a sequence classifier creates this
            head with random weights, so freezing it leaves the output layer
            untrained.
        epochs: Passes over ``dataloader``. Defaults to ``EPOCHS``.
        learning_rate: AdamW learning rate. Defaults to ``LEARNING_RATE``.

    Returns:
        list[float]: Mean training loss of each epoch (NaN for an empty epoch).

    Raises:
        ValueError: If no parameter is left trainable.
    """
    epochs = EPOCHS if epochs is None else epochs
    learning_rate = LEARNING_RATE if learning_rate is None else learning_rate

    # The trailing "." keeps layer 1 from also matching layers 10, 11, ...
    layer_tags = [f"bert.encoder.layer.{i}." for i in sensitive_layers]
    trainable_params = []
    for name, param in model.named_parameters():
        in_sensitive_layer = any(tag in name for tag in layer_tags)
        is_head = bool(train_head) and name.startswith("classifier.")
        param.requires_grad = in_sensitive_layer or is_head
        if param.requires_grad:
            trainable_params.append(param)

    if not trainable_params:
        raise ValueError(
            "fine_tune_model: no trainable parameters "
            "(nothing matched sensitive_layers and train_head is off)"
        )

    optimizer = torch.optim.AdamW(trainable_params, lr=learning_rate)
    criterion = torch.nn.CrossEntropyLoss()
    # Send batches to wherever the model's weights live.
    device = next(model.parameters()).device

    model.train()
    epoch_losses = []
    for _ in range(epochs):
        running_loss, num_batches = 0.0, 0
        for input_ids, attention_mask, labels in dataloader:
            optimizer.zero_grad()
            outputs = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
            )
            loss = criterion(outputs.logits, labels.to(device))
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        epoch_losses.append(running_loss / num_batches if num_batches else float("nan"))
    return epoch_losses

# Evaluating the Model
def evaluate_model(model, dataloader):
    """Return the classification accuracy of ``model`` over ``dataloader``.

    The model is switched to eval mode and left there. Predictions are the
    arg-max of ``outputs.logits``, and no gradients are tracked.

    Args:
        model: Module whose forward takes ``input_ids``/``attention_mask`` and
            returns an object with ``.logits``.
        dataloader: Iterable of ``(input_ids, attention_mask, labels)`` batches.

    Returns:
        float: Fraction of examples predicted correctly, in ``[0, 1]``.

    Raises:
        ValueError: If ``dataloader`` yields no examples, since accuracy is
            then undefined.
    """
    model.eval()
    # Send batches to wherever the model's weights live.
    device = next(model.parameters()).device
    correct, total = 0, 0
    with torch.no_grad():
        for input_ids, attention_mask, labels in dataloader:
            outputs = model(
                input_ids=input_ids.to(device),
                attention_mask=attention_mask.to(device),
            )
            predictions = outputs.logits.argmax(dim=-1)
            correct += (predictions == labels.to(device)).sum().item()
            total += labels.size(0)

    if total == 0:
        raise ValueError("Cannot compute accuracy: the dataloader yielded no examples.")
    return correct / total

# Create Dataloader
def create_dataloader(inputs, labels, batch_size, shuffle=True):
    """Wrap tokenized inputs and labels in a ``DataLoader``.

    Args:
        inputs: Mapping with ``input_ids`` and ``attention_mask`` tensors, for
            example a tokenizer ``BatchEncoding``. Both must have one row per
            example, like ``labels``.
        labels: Tensor or sequence of integer class labels, one per example.
        batch_size: Examples per batch.
        shuffle: Reshuffle each epoch. Defaults to ``True``, as before; pass
            ``False`` for a deterministic order.

    Returns:
        DataLoader yielding ``(input_ids, attention_mask, labels)`` batches.
    """
    # as_tensor avoids the copy-construct warning that torch.tensor() emits
    # when ``labels`` is already a tensor. CrossEntropyLoss expects long class
    # indices.
    label_tensor = torch.as_tensor(labels, dtype=torch.long)
    dataset = TensorDataset(inputs["input_ids"], inputs["attention_mask"], label_tensor)
    return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

# Main Workflow
def main(dataset, dataset_name):
    """Run the cluster-then-fine-tune experiment on one dataset and print results.

    Steps:
      1. Load the model and tokenize up to ``MAX_SAMPLES`` rows.
      2. Rank encoder layers by key/value weight variance.
      3. Quantize every layer's key/value weights with k-means.
      4. Fine-tune the top ``SENSITIVE_LAYER_PERCENTAGE`` of layers.
      5. Report accuracy before and after, plus the elapsed time.

    Args:
        dataset: Iterable of raw rows accepted by ``preprocess_data``.
        dataset_name: Label used only in the printed report.
    """
    model = BertForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2).to(DEVICE)

    # Pass max_samples by keyword. The third positional slot of
    # preprocess_data is max_length, not a dataset name.
    tokenized_inputs, labels = preprocess_data(dataset, max_samples=MAX_SAMPLES)
    dataloader = create_dataloader(tokenized_inputs, labels, BATCH_SIZE)

    sensitivities = calculate_layer_sensitivities(model)
    sensitive_layers = get_top_sensitive_layers(sensitivities, SENSITIVE_LAYER_PERCENTAGE)
    print(f"Sensitive Layers for {dataset_name}:", sensitive_layers)

    cluster_layers(model, NUM_CLUSTERS)

    # NOTE(review): accuracy is measured on the same examples the model is
    # fine-tuned on, so "after" is training accuracy, not generalization.
    start_time = time.perf_counter()
    accuracy_before = evaluate_model(model, dataloader)
    print(f"Accuracy Before Fine-Tuning on {dataset_name}: {accuracy_before}")

    fine_tune_model(model, dataloader, sensitive_layers)

    accuracy_after = evaluate_model(model, dataloader)
    end_time = time.perf_counter()
    print(f"Accuracy After Fine-Tuning on {dataset_name}: {accuracy_after}")
    print(f"Accuracy Drop on {dataset_name}: {accuracy_before - accuracy_after}")
    print(f"Total Time for {dataset_name}: {end_time - start_time} seconds")

# Example Dataset Placeholder
from datasets import load_dataset

# Datasets to run, keyed by the name used in the printed report. Each value is
# the keyword arguments for ``load_dataset``. Only BoolQ (the SuperGLUE
# validation split) is enabled. Any added entry must yield rows that
# ``preprocess_data`` understands.
DATASET_SPECS = {
    "boolq": {"path": "super_glue", "name": "boolq", "split": "validation"},
}

if __name__ == "__main__":
    for dataset_name, load_kwargs in DATASET_SPECS.items():
        dataset = load_dataset(**load_kwargs)
        main(dataset, dataset_name)
