# Setup

In [None]:
from transformers import BertTokenizer, BertForSequenceClassification
from datasets import load_dataset
import torch
import numpy as np
import pandas as pd
from tqdm import tqdm

# Colab-only: mount Google Drive at /content/drive so the CSV results below can
# be written to / read from /content/drive/MyDrive/Colab Notebooks/.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# Label constants.
# Label order used by the pretrained checkpoint (raw model output ids).
ID_TO_LABEL = {0: "contradiction", 1: "entailment", 2: "neutral"}
LABEL_TO_ID = {name: idx for idx, name in ID_TO_LABEL.items()}

# SNLI dataset label ids: 0 = entailment, 1 = neutral, 2 = contradiction.
_SNLI_LABEL_ORDER = ("entailment", "neutral", "contradiction")

# Raw model id -> SNLI dataset id, derived from the names: {0: 2, 1: 0, 2: 1}.
LABEL_MAP = {
    model_id: _SNLI_LABEL_ORDER.index(label_name)
    for model_id, label_name in ID_TO_LABEL.items()
}

# Calculating the cross-attention scores

In [None]:
# Load the pretrained SNLI checkpoint with attention maps enabled in its outputs.
model_name = "textattack/bert-base-uncased-snli"
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BertForSequenceClassification.from_pretrained(model_name, output_attentions=True)

# Attach label names in the checkpoint's own id order (see ID_TO_LABEL).
model.config.id2label = ID_TO_LABEL
model.config.label2id = LABEL_TO_ID

device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
model.to(device).eval()

# SNLI test split with unlabeled (-1) pairs dropped; the first 1000 examples
# drive the attention-head analysis.
dataset = load_dataset("snli", split="test").filter(lambda x: x["label"] != -1)
subset = dataset.select(range(1000))  # index 0–999

# Define Cross Attention Metric
def compute_cross_attention(attentions, sentence_b_start: int, batch_index: int = 0):
    """
    Compute the Cross Attention Score (sentence A -> sentence B) for all layers/heads.

    For layer ``l`` and head ``h`` the score is the mean attention weight that the
    query positions of segment A (``[0, sentence_b_start)``) place on the key
    positions of segment B (``[sentence_b_start, seq_len)``). For a BERT-style pair
    encoding this presumably includes [CLS] and A's [SEP] in A, and B's [SEP] in B.

    Args:
        attentions: sequence of per-layer attention maps (torch tensors or NumPy
            arrays), each of shape (batch, num_heads, seq_len, seq_len).
        sentence_b_start: index of the first token of segment B.
        batch_index: which batch element to score (default 0, as before).

    Returns:
        np.ndarray of shape (num_layers, num_heads), dtype float32. All zeros when
        either segment is empty (``sentence_b_start`` is 0 or >= seq_len).
    """
    num_layers = len(attentions)
    num_heads = attentions[0].shape[1]
    seq_len = attentions[0].shape[-1]
    cross_scores = np.zeros((num_layers, num_heads), dtype=np.float32)

    b_start = int(sentence_b_start)
    # A = [0, b_start) and B = [b_start, seq_len) must both be non-empty; this is
    # independent of layer/head, so check it once.
    if not 0 < b_start < seq_len:
        return cross_scores

    for l, layer_attn in enumerate(attentions):
        if isinstance(layer_attn, torch.Tensor):
            # Upcast so half-precision maps are averaged in float32.
            layer_attn = layer_attn.detach().float().cpu().numpy()
        attn = np.asarray(layer_attn)[batch_index]  # (num_heads, seq_len, seq_len)
        # Mean over the A-query x B-key block, vectorised across all heads.
        cross_scores[l] = attn[:, :b_start, b_start:].mean(axis=(1, 2))
    return cross_scores

# Inference
rows = []
for i, example in tqdm(enumerate(subset), total=len(subset)):
    premise = example["premise"]
    hypothesis = example["hypothesis"]
    true_label = example["label"]

    inputs = tokenizer(premise, hypothesis, return_tensors="pt", truncation=True,
                       padding=True, max_length=128).to(device)

    with torch.no_grad():
        outputs = model(**inputs)
    # Raw model ids follow ID_TO_LABEL; LABEL_MAP converts them to SNLI dataset ids.
    pred_label_raw = torch.argmax(outputs.logits, dim=1).item()
    pred_label = LABEL_MAP[pred_label_raw]
    attentions = outputs.attentions

    # Segment B starts at the first position whose token_type_id is 1.
    token_type_ids = inputs["token_type_ids"][0].cpu().numpy()
    b_positions = np.flatnonzero(token_type_ids == 1)
    # Guard against an example with no segment-B token (indexing [0] would raise
    # IndexError); pointing B past the end makes compute_cross_attention return
    # its all-zero "empty segment" result.
    sentence_b_start = int(b_positions[0]) if b_positions.size else len(token_type_ids)

    # Compute cross attention: a (num_layers, num_heads) score matrix.
    cross_scores = compute_cross_attention(attentions, sentence_b_start)

    # Flatten to columns "L{layer}_H{head}" in layer-major order, using the real
    # score shape rather than assuming 12 x 12.
    n_layers, n_heads = cross_scores.shape
    flat_scores = {
        f"L{l}_H{h}": cross_scores[l, h]
        for l in range(n_layers)
        for h in range(n_heads)
    }

    rows.append({
        "index": i,
        "premise": premise,
        "hypothesis": hypothesis,
        "true_label": true_label,
        "pred_label": pred_label,
        **flat_scores
    })

df = pd.DataFrame(rows)
output_path = "/content/drive/MyDrive/Colab Notebooks/snli_cross_attention_scores.csv"

# Write the same table twice: to the mounted Drive folder, then to the working directory.
for target in (output_path, "snli_cross_attention_scores.csv"):
    df.to_csv(target, index=False)
    print(f"Saved to {target}")

# Visualize attention score distribution

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_csv("/content/drive/MyDrive/Colab Notebooks/snli_cross_attention_scores.csv")

# Keep the correct results: only examples the model classified correctly.
correct = df[df["true_label"] == df["pred_label"]].copy()
print(f"Correct predictions: {len(correct)} / {len(df)}")

# Per-head score columns are named "L{layer}_H{head}"; also require "_H" so a
# column that merely starts with "L" is not picked up.
attn_cols = [c for c in df.columns if c.startswith("L") and "_H" in c]
num_layers = len({c.split("_")[0] for c in attn_cols})
num_heads = len({c.split("_")[1] for c in attn_cols})

# attention by class: mean score per (layer, head) over the correct examples of
# each gold class (SNLI dataset ids 0/1/2 = entailment/neutral/contradiction).
mean_by_class = {}
for label, name in enumerate(["entailment", "neutral", "contradiction"]):
    # Dedicated name so the global `subset` (the HF dataset slice) is not
    # overwritten with a DataFrame.
    class_rows = correct[correct["true_label"] == label]
    if class_rows.empty:
        print(f"Warning: no correct predictions for class '{name}'; its means will be NaN.")
    mean_by_class[name] = class_rows[attn_cols].mean().values.reshape(num_layers, num_heads)


# One heatmap per gold class: rows are layers, columns are heads.
for class_name in mean_by_class:
    plt.figure(figsize=(8, 6), dpi=1200)
    sns.heatmap(mean_by_class[class_name], cmap="YlGnBu")
    plt.title(f"Cross Attention Mean – Correct Predictions ({class_name})")
    plt.xlabel("Head Index")
    plt.ylabel("Layer Index")
    plt.show()


# Across-class variance of each head's mean cross-attention score: a head whose
# score changes a lot between the three classes is treated as discriminative.
stacked = np.stack([mean_by_class[cls_name] for cls_name in mean_by_class], axis=0)  # (3, num_layers, num_heads)
head_variance = np.var(stacked, axis=0)

# Rank every head by variance (ascending) and keep the N extremes at each end.
N = 20
sorted_indices = np.argsort(head_variance.ravel())
grid_shape = head_variance.shape
bottom_indices = np.unravel_index(sorted_indices[:N], grid_shape)
top_indices = np.unravel_index(sorted_indices[-N:], grid_shape)

best_heads = [(l, h, head_variance[l, h]) for l, h in zip(*top_indices)]
worst_heads = [(l, h, head_variance[l, h]) for l, h in zip(*bottom_indices)]

print(f"Top {N} most discriminative heads (layer, head, variance):")
for l, h, v in sorted(best_heads, key=lambda entry: -entry[2]):
    print(f"   Layer {l:<2}, Head {h:<2}, variance={v:.6f}")

print(f"Bottom {N} least discriminative heads (layer, head, variance):")
for l, h, v in sorted(worst_heads, key=lambda entry: entry[2]):
    print(f"   Layer {l:<2}, Head {h:<2}, variance={v:.6f}")

# Heatmap of per-head variance with the N most (lime) and least (yellow)
# discriminative heads outlined. (The unused `mask` array was dropped.)
n_layers, n_heads = head_variance.shape
plt.figure(figsize=(9, 7), dpi=1200)
sns.heatmap(head_variance, cmap="coolwarm", annot=False, cbar_kws={"label": "Variance"})

ax = plt.gca()
for indices, colour in ((top_indices, "lime"), (bottom_indices, "yellow")):
    for l, h in zip(*indices):
        # seaborn places cell (row=l, col=h) with its corner at (x=h, y=l).
        ax.add_patch(plt.Rectangle((h, l), 1, 1, fill=False, edgecolor=colour, lw=2.5))

plt.title(f"Head Variance Heatmap ({n_layers} Layers x {n_heads} Heads)\n"
          "Yellow = Insignificant, Green = Significant", fontsize=13)
plt.xlabel(f"Head Index (0–{n_heads - 1})")
plt.ylabel(f"Layer Index (0–{n_layers - 1})")
plt.tight_layout()
plt.show()

# Classification results of the original model

In [None]:
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import seaborn as sns


# Baseline (no heads removed) on the same first 1000 examples.
subset = dataset.select(range(1000))

preds, labels = [], []
for example in subset:
    inputs = tokenizer(example["premise"], example["hypothesis"],
                       return_tensors="pt", truncation=True,
                       padding=True, max_length=128).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        pred = torch.argmax(outputs.logits, dim=1).item()
    preds.append(pred)
    labels.append(example["label"])

# correct the labels: raw model ids -> SNLI dataset ids
mapped_preds = [LABEL_MAP[p] for p in preds]

# Pin the label ids so target_names always lines up with ids 0/1/2; otherwise
# sklearn infers the label set from the data and a missing class would
# misalign the names (or raise).
class_ids = [0, 1, 2]
target_names = ["entailment", "neutral", "contradiction"]

acc = accuracy_score(labels, mapped_preds)
report = classification_report(labels, mapped_preds, labels=class_ids,
                               target_names=target_names, digits=4,
                               output_dict=True, zero_division=0)

print()
print(f"Baseline Accuracy: {acc:.4f}")
print(classification_report(labels, mapped_preds, labels=class_ids,
                            target_names=target_names, digits=4, zero_division=0))

# Confusion matrix (rows = true class, columns = predicted class)
plt.figure(figsize=(6,5), dpi=600)
cm = confusion_matrix(labels, mapped_preds, labels=class_ids)
sns.heatmap(cm, annot=True, fmt="d", cmap="OrRd",
            xticklabels=target_names, yticklabels=target_names)
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Confusion Matrix – Baseline Model (No Head Removed)")
plt.show()

print("\nPer-class F1:", {k: round(v['f1-score'], 3) for k, v in report.items() if k in target_names})

# Removing redundant attention heads

In [None]:
threshold = 0.000001  # variance threshold: heads with variance <= threshold are removed

# layer index -> list of heads whose across-class variance is at or below the threshold
remove_heads = {}
for layer in range(num_layers):
    low_heads = [h for h in range(num_heads) if head_variance[layer, h] <= threshold]
    if low_heads:
        remove_heads[layer] = low_heads

total_removed = sum(len(hs) for hs in remove_heads.values())

print(f"\n Total removed heads: {total_removed}/{num_layers * num_heads} "
      f"({100*total_removed/(num_layers*num_heads):.1f}%)")

# Highlight removed attention heads
plt.figure(figsize=(9, 7), dpi=300)
sns.heatmap(head_variance, cmap="coolwarm", cbar_kws={"label": "Variance"})
# The selection rule above is `<=`, so the title uses "≤".
plt.title(f"Head Variance Heatmap (variance ≤ {threshold})\nRed boxes = removed heads")

ax = plt.gca()
for l, hs in remove_heads.items():
    for h in hs:
        ax.add_patch(plt.Rectangle((h, l), 1, 1, fill=False, edgecolor='red', lw=2.5))

plt.xlabel(f"Head Index (0–{num_heads - 1})")
plt.ylabel(f"Layer Index (0–{num_layers - 1})")
plt.tight_layout()
plt.show()

# Remove redundant attention heads
def apply_head_removal(model, remove_heads):
    """
    Disable the given attention heads in place by zeroing their Q/K/V projections.

    For every listed head, the output rows of the query, key and value
    ``nn.Linear`` weights belonging to that head's slice of the hidden dimension
    are zeroed, and so are the matching bias entries. Zeroing the value weight
    *and* bias makes that head's value vectors exactly zero, so its context
    output is zero. Masking the weights alone would leave Q/K/V equal to their
    biases, giving uniform attention over a constant value vector, so the
    "removed" head would still inject a constant bias-derived signal.

    Args:
        model: a BERT model exposing ``model.bert.encoder.layer[i].attention.self``
            with ``num_attention_heads``, ``attention_head_size`` and
            ``query``/``key``/``value`` linear layers.
        remove_heads: mapping ``layer index -> list of head indices`` to disable.
            Layers not present in the mapping are left untouched.

    Returns:
        None. The model's parameters are modified in place.
    """
    for layer_idx, layer in enumerate(model.bert.encoder.layer):
        heads_to_remove = remove_heads.get(layer_idx, [])
        if not heads_to_remove:
            continue
        self_attn = layer.attention.self
        num_heads = self_attn.num_attention_heads
        head_dim = self_attn.attention_head_size

        # Keep-mask over output features: 1 for kept heads, 0 for removed heads.
        keep = torch.ones(num_heads)
        keep[[int(h) for h in heads_to_remove]] = 0.0
        keep = keep.repeat_interleave(head_dim)  # (num_heads * head_dim,)

        for proj_name in ("query", "key", "value"):
            proj = getattr(self_attn, proj_name)
            with torch.no_grad():
                row_mask = keep.to(device=proj.weight.device, dtype=proj.weight.dtype)
                # nn.Linear weight is (out_features, in_features): zero whole output rows.
                proj.weight.mul_(row_mask.unsqueeze(1))
                if proj.bias is not None:
                    proj.bias.mul_(row_mask)
        print(f"Layer {layer_idx}: removed heads {heads_to_remove}")


def evaluate_model(model, tokenizer, dataset):
    """
    Run `model` over every example in `dataset` and return macro-averaged metrics.

    Raw model predictions are converted to SNLI dataset label ids with LABEL_MAP
    before scoring.

    Args:
        model: sequence-classification model whose output exposes `.logits`.
        tokenizer: tokenizer called as ``tokenizer(premise, hypothesis, ...)``.
        dataset: iterable of examples with "premise", "hypothesis" and "label".

    Returns:
        Tuple ``(accuracy, macro_precision, macro_recall, macro_f1)``.
    """
    # Send inputs to wherever the model's weights live, so a model that was not
    # moved to the global `device` does not cause a device-mismatch error.
    model_device = next(model.parameters()).device

    preds, labels = [], []
    for example in dataset:
        inputs = tokenizer(example["premise"], example["hypothesis"],
                           return_tensors="pt", truncation=True,
                           padding=True, max_length=128).to(model_device)
        with torch.no_grad():
            outputs = model(**inputs)
        preds.append(torch.argmax(outputs.logits, dim=1).item())
        labels.append(example["label"])

    # Model ids (0=contradiction, 1=entailment, 2=neutral) -> SNLI dataset ids
    # (0=entailment, 1=neutral, 2=contradiction).
    mapped_preds = [LABEL_MAP[p] for p in preds]

    acc = accuracy_score(labels, mapped_preds)
    prec = precision_score(labels, mapped_preds, average='macro', zero_division=0)
    rec = recall_score(labels, mapped_preds, average='macro', zero_division=0)
    f1 = f1_score(labels, mapped_preds, average='macro', zero_division=0)

    return acc, prec, rec, f1

# Apply to the rest of the test set (8824 samples)

In [None]:
# Held-out check: evaluate on every example after the first 1000 (those first
# 1000 were used to compute head_variance and choose the heads to remove).
subset_test = dataset.select(range(1000, len(dataset)))
print(f" Validation subset: {len(subset_test)} samples")

# Untouched checkpoint as the reference point.
baseline_model = BertForSequenceClassification.from_pretrained(model_name)
baseline_model.to(device).eval()
baseline_metrics = evaluate_model(baseline_model, tokenizer, subset_test)
baseline_acc, baseline_prec, baseline_rec, baseline_f1 = baseline_metrics
print(f"\n Baseline — Acc={baseline_acc:.4f}, Prec={baseline_prec:.4f}, Rec={baseline_rec:.4f}, F1={baseline_f1:.4f}")

# Fresh copy of the checkpoint with the heads listed in `remove_heads` disabled.
model_removed = BertForSequenceClassification.from_pretrained(model_name)
model_removed.to(device).eval()
apply_head_removal(model_removed, remove_heads)
removed_metrics = evaluate_model(model_removed, tokenizer, subset_test)
removed_acc, removed_prec, removed_rec, removed_f1 = removed_metrics
print(f"\n After Removal — Acc={removed_acc:.4f}, Prec={removed_prec:.4f}, Rec={removed_rec:.4f}, F1={removed_f1:.4f}")

# Side-by-side bars: baseline vs. head-removed model on the held-out set.
metrics = ["Accuracy", "Precision", "Recall", "F1-score"]
baseline_values = [baseline_acc, baseline_prec, baseline_rec, baseline_f1]
removed_values = [removed_acc, removed_prec, removed_rec, removed_f1]

x = np.arange(len(metrics))
width = 0.35

plt.figure(figsize=(7,5), dpi=300)
plt.bar(x - width/2, baseline_values, width, label="Baseline", color="#4C72B0")
# Heads were selected with `variance <= threshold`, so label with "≤".
plt.bar(x + width/2, removed_values, width, label=f"Removed (≤{threshold})", color="#DD8452")

# Annotate each bar with its value.
for i, (b, r) in enumerate(zip(baseline_values, removed_values)):
    plt.text(i - width/2, b + 0.002, f"{b:.4f}", ha="center", va="bottom", fontsize=9)
    plt.text(i + width/2, r + 0.002, f"{r:.4f}", ha="center", va="bottom", fontsize=9)

plt.xticks(x, metrics)
plt.ylim(0, 1.0)
plt.ylabel("Score", fontsize=11)
plt.title(f"Validation on Remaining Test Set ({len(subset_test)} samples)\n"
          f"Comparison of Baseline vs Head Removal (variance ≤ {threshold})", fontsize=12)
plt.legend()
plt.grid(axis="y", linestyle="--", alpha=0.5)
plt.tight_layout()
plt.show()

In [None]:
# Sensitivity analysis: sweep the variance threshold and measure performance.
thresholds = [0.0] + list(np.linspace(0.000001, 0.00001, 10))
results = []

total_heads = num_layers * num_heads  # 12 * 12 = 144 for this checkpoint

# NOTE: `subset` is the same first-1000 slice that produced head_variance, so
# these numbers are in-sample; the held-out comparison uses `subset_test`.
# Loop names `thr` / `heads_to_remove` avoid overwriting the globals
# `threshold` / `remove_heads` used by the comparison cell.
for thr in tqdm(thresholds, desc="Running sensitivity analysis"):
    model_copy = BertForSequenceClassification.from_pretrained(model_name)
    # The model must be on the same device as the tokenized inputs.
    model_copy.to(device).eval()

    heads_to_remove = {}
    for layer in range(num_layers):
        low_heads = [h for h in range(num_heads) if head_variance[layer, h] <= thr]
        if low_heads:
            heads_to_remove[layer] = low_heads
    removed_count = sum(len(hs) for hs in heads_to_remove.values())

    apply_head_removal(model_copy, heads_to_remove)

    removed_ratio = removed_count / total_heads

    acc, prec, rec, f1 = evaluate_model(model_copy, tokenizer, subset)
    results.append((thr, acc, prec, rec, f1, removed_ratio))
    print(f"Threshold={thr:.6f} | Removed={removed_count}/{total_heads} ({100*removed_ratio:.1f}%) | Acc={acc:.4f}, F1={f1:.4f}")


results_df = pd.DataFrame(results, columns=["threshold", "accuracy", "precision", "recall", "f1", "removed_ratio"])

fig, ax1 = plt.subplots(figsize=(9,6), dpi=300)

# Left axis: performance metrics vs. threshold.
ax1.plot(results_df["threshold"], results_df["accuracy"], marker='o', label="Accuracy")
ax1.plot(results_df["threshold"], results_df["f1"], marker='d', label="F1-score")
ax1.plot(results_df["threshold"], results_df["precision"], marker='s', label="Precision")
ax1.plot(results_df["threshold"], results_df["recall"], marker='^', label="Recall")
ax1.set_xlabel("Variance Threshold", fontsize=12)
ax1.set_ylabel("Performance Metric", fontsize=12)
ax1.grid(True, linestyle="--", alpha=0.6)

# Right axis: share of heads removed at each threshold.
ax2 = ax1.twinx()
ax2.plot(results_df["threshold"], results_df["removed_ratio"]*100, color='red', linestyle='--', marker='x', label="Heads Removed (%)")
ax2.set_ylabel("Removed Heads (%)", color='red', fontsize=12)
ax2.tick_params(axis='y', labelcolor='red')

# One legend for both axes: ax1.legend() alone omits the red line that lives on
# the twin axis. Drawing it on ax2 keeps it above the twin axis' artists.
metric_handles, metric_names = ax1.get_legend_handles_labels()
ratio_handles, ratio_names = ax2.get_legend_handles_labels()
ax2.legend(metric_handles + ratio_handles, metric_names + ratio_names, loc="upper left")

plt.title("Sensitivity Analysis: Performance vs. Head Removal Threshold", fontsize=13)
fig.tight_layout()
plt.show()

In [None]:
# Sensitivity analysis by removal ratio: drop the lowest-variance heads first.
removal_rates = np.arange(0.0, 1.0, 0.1)  # 0%, 10%, ..., 90%
results = []

# All (layer, head) pairs ordered from lowest to highest variance.
flat_indices = np.unravel_index(np.argsort(head_variance.ravel()), head_variance.shape)
layer_order, head_order = flat_indices
sorted_heads = list(zip(layer_order, head_order))

def get_predictions(model, tokenizer, subset, label_map=None):
    """
    Predict SNLI labels for every example in `subset`.

    Args:
        model: sequence-classification model whose output exposes `.logits`.
        tokenizer: called as ``tokenizer(premise, hypothesis, ...)``; its result
            must support ``.to(device)`` and be splattable into ``model(**inputs)``.
        subset: iterable of examples with "premise", "hypothesis" and "label".
        label_map: optional mapping from raw model ids to dataset label ids.
            Defaults to the module-level LABEL_MAP ({0: 2, 1: 0, 2: 1}).

    Returns:
        ``(mapped_preds, labels)`` as NumPy arrays of dataset label ids.
    """
    if label_map is None:
        label_map = LABEL_MAP
    # Follow the model's own device instead of the global `device`.
    model_device = next(model.parameters()).device

    preds, labels = [], []
    for example in subset:
        inputs = tokenizer(example["premise"], example["hypothesis"],
                           return_tensors="pt", truncation=True,
                           padding=True, max_length=128).to(model_device)
        with torch.no_grad():
            outputs = model(**inputs)
        preds.append(torch.argmax(outputs.logits, dim=1).item())
        labels.append(example["label"])

    mapped_preds = [label_map[p] for p in preds]
    return np.array(mapped_preds), np.array(labels)

# Recomputed here so this cell does not depend on the threshold-sweep cell.
total_heads = num_layers * num_heads

for rate in tqdm(removal_rates, desc="Running removal-rate sensitivity analysis"):
    model_copy = BertForSequenceClassification.from_pretrained(model_name).to(device).eval()

    # Disable the n_remove lowest-variance heads (sorted_heads is ascending).
    n_remove = int(rate * total_heads)
    heads_to_remove = {}
    for l, h in sorted_heads[:n_remove]:
        heads_to_remove.setdefault(l, []).append(h)
    apply_head_removal(model_copy, heads_to_remove)

    # Score against the labels returned with these predictions, not the global
    # `labels` list left over from an earlier cell.
    preds, true_labels = get_predictions(model_copy, tokenizer, subset)

    # zero_division=0 matches evaluate_model and silences warnings when heavy
    # pruning makes the model stop predicting some class.
    acc = accuracy_score(true_labels, preds)
    prec = precision_score(true_labels, preds, average='macro', zero_division=0)
    rec = recall_score(true_labels, preds, average='macro', zero_division=0)
    f1 = f1_score(true_labels, preds, average='macro', zero_division=0)

    results.append((rate, n_remove, acc, prec, rec, f1))
    print(f"Removed={rate*100:.1f}% ({n_remove}/{total_heads}) | Acc={acc:.4f}, Prec={prec:.4f}, Rec={rec:.4f}, F1={f1:.4f}")

results_df = pd.DataFrame(results, columns=["removal_rate", "removed_count", "accuracy", "precision", "recall", "f1"])

# Metrics vs. percentage of heads removed (lowest-variance heads first).
plt.figure(figsize=(8,6), dpi=300)
x = results_df["removal_rate"] * 100

curve_specs = (
    ("accuracy", 'o', "Accuracy"),
    ("f1", 'd', "F1-score"),
    ("precision", 's', "Precision"),
    ("recall", '^', "Recall"),
)
for column, marker_style, legend_name in curve_specs:
    plt.plot(x, results_df[column], marker=marker_style, label=legend_name)

plt.xlabel("Removed Heads (%)", fontsize=12)
plt.ylabel("Performance Metric", fontsize=12)
plt.title("Sensitivity Analysis: Performance vs. Head Removal Rate", fontsize=13)
plt.grid(True, linestyle="--", alpha=0.6)
plt.legend()

plt.tight_layout()
plt.show()