# Project Setup

In [None]:
!pip install transformers datasets scikit-learn torch datasets bertviz

In [None]:
import os
import torch
from transformers import BertTokenizer, BertModel, BertForSequenceClassification, Trainer, TrainingArguments, AutoTokenizer, EarlyStoppingCallback
from sklearn.linear_model import LogisticRegression
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score
from torch.nn.functional import softmax
from datasets import load_dataset
from bertviz import head_view
from tqdm import tqdm
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Load the AG News train and test splits
train_datasets = load_dataset('ag_news', split='train')
test_dataset = load_dataset('ag_news', split='test')

# Hold out 20% of the training split for validation (fixed seed for reproducibility)
train_val_split = train_datasets.train_test_split(test_size=0.2, seed=42)
train_data, val_data = train_val_split['train'], train_val_split['test']

print("Full training dataset", len(train_datasets))
print("\nTraining dataset:", train_data)
print("Validation dataset:", val_data)
print("Test dataset:", test_dataset)

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the tokenizer and the BERT encoder, placing the encoder on `device`
tokenizer = BertTokenizer.from_pretrained("google-bert/bert-base-uncased")
model = BertModel.from_pretrained("google-bert/bert-base-uncased").to(device)
print(f"\nBERT tokenizer and model loaded onto {device}")

In [None]:
# Data distribution

def plot_label_distribution(dataset, title):
    """Draw a bar chart with the number of samples per class in `dataset`.

    Args:
        dataset: Any object `pd.DataFrame(...)` can turn into a frame that has
            a `label` column. Labels are assumed to be the integer class ids
            0-3, matching the hard-coded class names below.
        title: Text shown above the chart.
    """
    label_names = ['World', 'Sports', 'Business', 'Sci/Tech']

    # Convert to DataFrame for easier manipulation
    df = pd.DataFrame(dataset)

    # Reindex over every known class id so that a class without samples becomes
    # a zero-height bar instead of making the x and y sequences differ in length.
    label_counts = (
        df['label']
        .value_counts()
        .reindex(range(len(label_names)), fill_value=0)
    )

    # Plot
    plt.figure(figsize=(6, 4))
    sns.barplot(x=label_names, y=label_counts.values)
    plt.title(title)
    plt.ylabel('Count')
    plt.xlabel('Label')
    # Leave 10% headroom above the tallest bar
    plt.ylim(0, max(label_counts.values) * 1.1)
    plt.tight_layout()
    plt.show()

# One chart per split
plot_label_distribution(train_data, "Training Set Label Distribution")
plot_label_distribution(val_data, "Validation Set Label Distribution")
plot_label_distribution(test_dataset, "Test Set Label Distribution")

# Helper Functions

In [None]:
def knn_hyperparameter_tuning(X_train, y_train, X_val, y_val, k_values=(1, 3, 5, 7, 9, 11)):
    """Select the neighbour count that gives the best validation accuracy.

    A `KNeighborsClassifier` is fitted on the training data once per candidate
    in `k_values` and scored on the validation data with `accuracy_score`.

    Args:
        X_train, y_train: Training features and labels passed to `fit`.
        X_val, y_val: Validation features and labels used for scoring.
        k_values: Iterable of candidate `n_neighbors` values. The default is an
            immutable tuple so it cannot be altered between calls.

    Returns:
        A tuple `(best_k, best_acc, acc_dict)` where `acc_dict` maps every
        candidate to its validation accuracy. On ties the earliest candidate
        wins. When `k_values` is empty the result is `(None, 0.0, {})`.
    """
    best_k = None
    best_acc = 0.0
    acc_dict = {}

    for k in k_values:
        knn = KNeighborsClassifier(n_neighbors=k)
        knn.fit(X_train, y_train)
        acc = accuracy_score(y_val, knn.predict(X_val))
        acc_dict[k] = acc
        # `best_k is None` guarantees a candidate is selected even when every
        # accuracy is 0.0 (a strict `>` against the initial 0.0 never fires).
        if best_k is None or acc > best_acc:
            best_acc = acc
            best_k = k

    return best_k, best_acc, acc_dict

In [None]:
def tokenize(item):
    """Encode `item["text"]` with the notebook-level `tokenizer`.

    Sequences are truncated and padded to exactly 128 tokens.
    """
    encoding_options = {"truncation": True, "padding": "max_length", "max_length": 128}
    return tokenizer(item["text"], **encoding_options)

In [None]:
def save_model(model, tokenizer, path="bert-finetuned"):
    """Persist `model` and then `tokenizer` under `path` via `save_pretrained`."""
    for artifact in (model, tokenizer):
        artifact.save_pretrained(path)

def load_model(path="bert-finetuned"):
    """Load the classifier and tokenizer previously stored under `path`.

    Returns:
        `(model, tokenizer)` read with `from_pretrained`.

    Raises:
        FileNotFoundError: If `path` does not exist on disk.
    """
    # Guard clause: fail fast when nothing has been saved yet
    if not os.path.exists(path):
        raise FileNotFoundError(f"No saved model at: {path}")

    restored_model = BertForSequenceClassification.from_pretrained(path)
    restored_tokenizer = BertTokenizer.from_pretrained(path)
    return restored_model, restored_tokenizer

In [None]:
def save_features(X_train, y_train, X_test, y_test, prefix="mean", directory="features_val"):
    """Store the four arrays as `.npy` files inside `directory`.

    The directory is created when missing. File names follow the pattern
    `<X|y>_<train|test>_<prefix>.npy`.
    """
    os.makedirs(directory, exist_ok=True)

    targets = (
        (f"X_train_{prefix}.npy", X_train),
        (f"y_train_{prefix}.npy", y_train),
        (f"X_test_{prefix}.npy", X_test),
        (f"y_test_{prefix}.npy", y_test),
    )
    for filename, array in targets:
        np.save(os.path.join(directory, filename), array)

def load_features(prefix="mean", directory="features_val"):
    """Read back the four arrays written by `save_features`.

    Returns:
        `(X_train, y_train, X_test, y_test)`, or four `None` values as soon as
        any of the expected files is missing.
    """
    filenames = (
        f"X_train_{prefix}.npy",
        f"y_train_{prefix}.npy",
        f"X_test_{prefix}.npy",
        f"y_test_{prefix}.npy",
    )
    try:
        X_train, y_train, X_test, y_test = (
            np.load(os.path.join(directory, name)) for name in filenames
        )
    except FileNotFoundError:
        # A missing file means the cache is unusable; signal it with Nones
        return None, None, None, None
    return X_train, y_train, X_test, y_test

# Task 3 - Experiments

## Task 3.1 - Probing

In [None]:
def extract_features(dataset, strategy="cls", max_samples=None):
    """Embed the texts of `dataset` with the notebook-level BERT `model`.

    Each item is tokenized on its own (padded/truncated to 128 tokens), run
    through `model` without gradients, and reduced to a single vector.

    Args:
        dataset: Sized iterable of items exposing `item['text']` and
            `item['label']`.
        strategy: How the last hidden states are reduced to one vector:
            "cls"   - hidden state at position 0 ([CLS]);
            "first" - hidden state at position 1 (first token after [CLS]);
            "last"  - hidden state of the last token before [SEP];
            "mean"  - attention-mask-weighted mean over non-padding positions.
        max_samples: Optional cap on the number of items processed. `None`
            means "all"; `0` now yields no samples instead of silently
            processing the whole dataset.

    Returns:
        `(features, labels)` as NumPy arrays, one row/entry per processed item.

    Raises:
        ValueError: If `strategy` is not one of the four supported names. The
            check happens up front, before any model call.
    """
    if strategy not in ("cls", "first", "last", "mean"):
        raise ValueError("Invalid strategy")

    # Explicit `is None` test: a truthiness check would treat 0 as "no limit"
    limit = len(dataset) if max_samples is None else min(len(dataset), max_samples)

    features, labels = [], []
    for i, item in tqdm(enumerate(dataset), total=max(limit, 0)):
        if i >= limit:
            break

        inputs = tokenizer(item['text'], return_tensors='pt', padding='max_length',
                           truncation=True, max_length=128)
        inputs = {k: v.to(device) for k, v in inputs.items()}
        with torch.no_grad():
            outputs = model(**inputs)
            hidden = outputs.last_hidden_state  # (batch_size, seq_len, hidden_dim)
            attention_mask = inputs['attention_mask']

            if strategy == "cls":
                emb = hidden[:, 0, :]
            elif strategy == "first":
                emb = hidden[:, 1, :]  # first after [CLS]
            elif strategy == "last":
                # The mask sum counts [CLS] ... [SEP]; index (length - 1) is
                # [SEP], so (length - 2) is the last real token.
                seq_lengths = attention_mask.sum(dim=1)
                emb = hidden[range(hidden.size(0)), seq_lengths - 2, :]  # before [SEP]
            else:  # "mean" (validated above)
                # Zero out padding positions, then divide by the real length
                masked = hidden * attention_mask.unsqueeze(-1)
                sum_ = masked.sum(dim=1)
                count = attention_mask.sum(dim=1).unsqueeze(-1)
                emb = sum_ / count

        # Batch size is 1 here, so dropping the leading axis leaves one vector
        features.append(emb.squeeze(0).cpu().numpy())
        labels.append(item['label'])

    return np.array(features), np.array(labels)

In [None]:
# Probing experiment: for every pooling strategy, fit KNN and logistic
# regression on frozen BERT features and score them on the validation subset.
model.eval()
strategies = ["cls", "first", "last", "mean"]
k_values = range(1,21)
probing_results = {}

for strat in strategies:
    print(f"Extracting features using strategy: {strat}")

    # Try the on-disk cache first; extraction is the expensive part
    X_train, y_train, X_val, y_val = load_features(prefix=f"{strat}_val")
    if X_train is not None:
        print("→ Loaded saved features.")
    else:
        print("No saved features found. Extracting...")
        X_train, y_train = extract_features(train_data, strategy=strat, max_samples=4000)  # LARGER MAX_SAMPLES IN FINAL VERSION
        X_val, y_val = extract_features(val_data, strategy=strat, max_samples=1000)
        save_features(X_train, y_train, X_val, y_val, prefix=f"{strat}_val")

    # KNN: sweep over k and keep the best validation accuracy
    best_k, best_knn_acc, knn_accs = knn_hyperparameter_tuning(
        X_train, y_train, X_val, y_val, k_values=k_values
    )

    # Logistic regression on the same features
    logreg = LogisticRegression(max_iter=1000).fit(X_train, y_train)
    logreg_acc = accuracy_score(y_val, logreg.predict(X_val))

    probing_results[strat] = dict(
        logreg=logreg_acc,
        knn_best_k=best_k,
        knn_acc=best_knn_acc,
        knn_all=knn_accs,
    )

# One summary line per strategy
for strat, scores in probing_results.items():
    print(f"{strat:<6} \t\t LogReg Acc: {scores['logreg']:.4f} \t\t Best K: {scores['knn_best_k']} \t KNN Acc: {scores['knn_acc']:.4f}")

In [None]:
# Report the probing results collected in `probing_results`: two printed
# tables followed by two matplotlib figures.

# -------------------
# Table 1: Best accuracy per embedding strategy for each classifier
# -------------------
logreg_scores = [probing_results[s]["logreg"] for s in strategies]
knn_scores = [probing_results[s]["knn_acc"] for s in strategies]

df_summary = pd.DataFrame({
    "Embedding Strategy": strategies,
    "Logistic Regression Accuracy": logreg_scores,
    "KNN Best Accuracy": knn_scores
})

print("Table 1: Validation Accuracy by Embedding Strategy and Classifier")
print(df_summary)
print("\n")

# -------------------
# Table 2: Detailed KNN Accuracies across Different K Values
# (long format: one row per strategy / K pair)
# -------------------
rows = []
for strat in strategies:
    k_acc_dict = probing_results[strat]['knn_all']
    for k, acc in k_acc_dict.items():
        rows.append({
            "Embedding Strategy": strat,
            "K": k,
            "KNN Accuracy": acc
        })

df_knn = pd.DataFrame(rows)

print("Table 2: Validation Accuracy vs. K for Each Embedding Strategy (KNN)")
print(df_knn)

# -------------------
# Plot 1: Grouped bar chart of the Table 1 accuracies
# (two bars per strategy, offset by half a bar width to each side)
x = np.arange(len(strategies))
width = 0.35

fig, ax = plt.subplots(figsize=(8, 5))
bars1 = ax.bar(x - width/2, logreg_scores, width, label='Logistic Regression')
bars2 = ax.bar(x + width/2, knn_scores, width, label='KNN (Best K)')

ax.set_ylabel('Accuracy')
ax.set_title('Accuracy by Embedding Strategy and Classifier')
ax.set_xticks(x)
ax.set_xticklabels(strategies)
ax.set_ylim([0, 1.0])
ax.legend()

# Write each bar's value just above it
for bar in bars1 + bars2:
    height = bar.get_height()
    ax.annotate(f'{height:.3f}',
                xy=(bar.get_x() + bar.get_width() / 2, height),
                xytext=(0, 3), textcoords="offset points",
                ha='center', va='bottom')

plt.grid(True, axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()
plt.show()

# Plot 2: One line per strategy showing KNN validation accuracy against K
fig2, ax2 = plt.subplots(figsize=(8, 5))

# NOTE: rebinds the notebook-level `k_values` to the list of K values stored
# for the first strategy; it is used below for the x-axis ticks.
k_values = list(next(iter(probing_results.values()))['knn_all'].keys())
for strat in strategies:
    k_acc_dict = probing_results[strat]['knn_all']
    k_vals = list(k_acc_dict.keys())
    accs = list(k_acc_dict.values())
    ax2.plot(k_vals, accs, marker='o', label=f'{strat}')

ax2.set_xlabel("Number of Neighbors (K)")
ax2.set_ylabel("Validation Accuracy")
ax2.set_title("KNN Accuracy vs. K for Each Embedding Strategy")
ax2.set_xticks(k_values)
ax2.set_ylim([0, 1.0])
ax2.legend(title="Strategy")
ax2.grid(True, linestyle='--', alpha=0.7)

plt.tight_layout()
plt.show()

## Task 3.2 - Fine-Tuning BERT

In [None]:
# Work on a shuffled subset (20k training / 2k validation samples)
small_train = train_data.shuffle(seed=42).select(range(20000))
small_val = val_data.shuffle(seed=42).select(range(2000))

# Tokenize in batches and rename the target column from "label" to "labels"
tokenized_train = small_train.map(tokenize, batched=True).rename_column("label", "labels")
tokenized_val = small_val.map(tokenize, batched=True).rename_column("label", "labels")

# Return torch tensors for the three listed columns only
for tokenized_split in (tokenized_train, tokenized_val):
    tokenized_split.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

We chose to use Hugging Face’s transformers library for fine-tuning, instead of the model setup provided in the assignment instructions. This implementation allows full fine-tuning of all BERT parameters along with an added classification head.

Reference: https://huggingface.co/docs/transformers/en/training

In [None]:
def compute_accuracy(eval_pred):
    """Return `{"accuracy": ...}` for a Trainer evaluation result.

    `eval_pred.predictions` holds the per-class scores and
    `eval_pred.label_ids` the reference labels; the predicted class is the
    arg-max over axis 1.
    """
    predicted = eval_pred.predictions.argmax(axis=1)
    return {"accuracy": (predicted == eval_pred.label_ids).mean()}


# Reuse a previously fine-tuned model when one is saved on disk; otherwise
# fine-tune BERT and save the result.
try:
    finetune_model, tokenizer = load_model()
except FileNotFoundError:
    print("No saved model found. Fine-tuning from scratch...\n")

    finetune_model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=4)

    training_args = TrainingArguments(
        output_dir="./bert_checkpoints",
        eval_strategy="epoch",
        save_strategy="epoch",
        logging_strategy="epoch",
        learning_rate=2e-5,
        per_device_train_batch_size=16,
        per_device_eval_batch_size=16,
        num_train_epochs=5,
        weight_decay=0.01,
        logging_dir="./bert_logs",
        load_best_model_at_end=True,
        report_to="none"
    )

    trainer = Trainer(
        model=finetune_model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        compute_metrics=compute_accuracy,
        callbacks=[EarlyStoppingCallback(early_stopping_patience=1)]
    )

    trainer.train()
    save_model(finetune_model, tokenizer)
else:
    print("Loaded fine-tuned model from disk.")

    # Always build an evaluation-only trainer around the model that was just
    # loaded. Checking whether a `trainer` variable already exists would, on a
    # cell re-run, keep a stale trainer bound to an older model object.
    training_args = TrainingArguments(
        output_dir="./bert_checkpoints",
        per_device_eval_batch_size=16,
        do_train=False,
        do_eval=True,
        report_to="none"
    )
    trainer = Trainer(
        model=finetune_model,
        args=training_args,
        eval_dataset=tokenized_val,
        tokenizer=tokenizer,
        compute_metrics=compute_accuracy
    )

bert_results = trainer.evaluate()
print(f"Validation Accuracy after fine-tuning: {bert_results['eval_accuracy']:.4f}")

## Task 3.3 - Classification Performance

In [None]:
# Choose the strategy whose better classifier (LogReg or KNN) scored highest on validation
best_strategy = max(
    probing_results,
    key=lambda s: max(probing_results[s]['logreg'], probing_results[s]['knn_acc']),
)
print(f"Best embedding strategy: {best_strategy}")

# Features for the full training split and the test split (cached in "features_test")
X_full_train, y_full_train, X_test, y_test = load_features(prefix=best_strategy, directory="features_test")
if X_full_train is not None:
    print("Loaded features from 'features_test'.")
else:
    print("No saved features found in 'features_test'. Extracting now...")
    X_full_train, y_full_train = extract_features(train_datasets, strategy=best_strategy)
    X_test, y_test = extract_features(test_dataset, strategy=best_strategy)
    save_features(X_full_train, y_full_train, X_test, y_test, prefix=best_strategy, directory="features_test")

print(f"\n===  KNN and Logistic Regression Test Accuracy ===")

# KNN with the k that was best on validation for this strategy
best_k = probing_results[best_strategy]['knn_best_k']
knn_final = KNeighborsClassifier(n_neighbors=best_k).fit(X_full_train, y_full_train)
knn_test_acc = accuracy_score(y_test, knn_final.predict(X_test))
print(f"KNN Test Accuracy (strategy: {best_strategy}, k={best_k}): {knn_test_acc:.4f}")

# Logistic regression on the same features
logreg_final = LogisticRegression(max_iter=1000).fit(X_full_train, y_full_train)
logreg_test_acc = accuracy_score(y_test, logreg_final.predict(X_test))
print(f"Logistic Regression Test Accuracy (strategy: {best_strategy}): {logreg_test_acc:.4f}")

In [None]:
# Tokenize the test split the same way as the training data
tokenized_test = test_dataset.map(tokenize, batched=True).rename_column("label", "labels")
tokenized_test.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# Evaluate the fine-tuned classifier on the test split
fine_tune_results = trainer.evaluate(tokenized_test)
fine_tune_test_acc = fine_tune_results["eval_accuracy"]
print(f"Fine-tuned BERT Test Accuracy: {fine_tune_test_acc:.4f}")

In [None]:
# Bar chart comparing the three test accuracies
labels = ["KNN (Probing)", "LogReg (Probing)", "Fine-tuned BERT"]
test_accuracies = [knn_test_acc, logreg_test_acc, fine_tune_test_acc]

width = 0.7
x = np.arange(len(labels))

fig, ax = plt.subplots(figsize=(6, 5))
bars = ax.bar(x, test_accuracies, width, color=["blue", "green", "red"])

ax.set_xticks(x)
ax.set_xticklabels(labels)
ax.set(
    ylabel="Test Accuracy",
    title="Test Accuracies for Best Probing Strategy vs Fine-Tuned BERT",
    ylim=[0.7, 1.0],  # zoomed-in y-axis
)

# Print each accuracy just above its bar
for bar in bars:
    height = bar.get_height()
    ax.annotate(
        f"{height:.4f}",
        xy=(bar.get_x() + bar.get_width() / 2, height),
        xytext=(0, 3),
        textcoords="offset points",
        ha='center',
        va='bottom',
    )

plt.tight_layout()
plt.show()

In [None]:
'''
Do we need to report the test accuracy for each strategy? Or just the best one?
Phrasing is a little awkward.

If the former, we can modify this code block from
using (currently) validation accuracy to test accuracy for each strategy.
'''

# comparison_data = []
# for strat, accs in probing_results.items():
#     logreg = accs.get("logreg", "-")
#     knn = accs.get("knn_acc", "-")

#     comparison_data.append({
#         "Strategy": strat,
#         "LogReg Test Acc": f"{logreg:.4f}" if isinstance(logreg, float) else logreg,
#         "KNN Test Acc": f"{knn:.4f}" if isinstance(knn, float) else knn
#     })

# comparison_df = pd.DataFrame(comparison_data)
# comparison_df["_sort"] = pd.to_numeric(comparison_df["LogReg Test Acc"], errors="coerce")
# comparison_df = comparison_df.sort_values(by="_sort", ascending=False).drop(columns="_sort")

# print("=== Multiclass Classification Test Accuracy Comparison ===")
# print(comparison_df)

# print("\n=== Fine-tuned BERT Test Accuracy ===")
# print(f"LogReg Test Acc: {fine_tune_test_acc:.4f}")

## Task 3.4 - Attention Matrix

In [None]:
import torch.nn.functional as F
from bertviz import head_view, model_view
import seaborn as sns

# Tokenizer saved alongside the fine-tuned classifier
tokenizer = BertTokenizer.from_pretrained("bert-finetuned")

# Encoder used by the interactive head_view visualisation.
# NOTE(review): this loads the pretrained "bert-base-uncased" weights, not the
# fine-tuned ones in "bert-finetuned" - confirm that visualising the base
# model's attention is intended.
bert_attention_model = BertModel.from_pretrained(
    "bert-base-uncased",
    output_attentions=True,
    attn_implementation="eager"
)
bert_attention_model.eval()
bert_attention_model.to(device)

# Fine-tuned classifier. It is later called with `output_attentions=True`, so
# it is loaded with the eager attention implementation as well, consistent
# with `bert_attention_model` above.
clf_model = BertForSequenceClassification.from_pretrained(
    "bert-finetuned",
    attn_implementation="eager"
)
clf_model.eval()
clf_model.to(device)


def get_model_prediction(text, tokenizer, model, device):
    """
    Classify a single text and expose the model's attention weights.

    The text is tokenized (truncated and padded to 128 tokens), moved to
    `device`, and passed through `model` with `output_attentions=True` and
    gradients disabled.

    Returns a 4-tuple:
      - pred_label: index of the highest-probability class.
      - confidence: softmax probability of that class.
      - attentions: the `attentions` attribute of the model output.
      - input_ids: the tokenized input IDs (on `device`).
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    )
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        result = model(**batch, output_attentions=True)

    class_probs = result.logits.softmax(dim=-1)
    pred_label = class_probs.argmax(dim=-1).item()
    return pred_label, class_probs[0, pred_label].item(), result.attentions, batch["input_ids"]

def predict_label_with_prob(text):
    """
    Classify `text` with the notebook-level `clf_model`.

    Returns `(pred, max_prob)`: the arg-max class index and its softmax
    probability.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding="max_length",
        max_length=128
    )
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        logits = clf_model(**batch).logits
        probs = F.softmax(logits, dim=1)

    pred = torch.argmax(probs, dim=1).item()
    return pred, probs[0, pred].item()

def visualize_heatmap(text, tokenizer, model, device, layer=8, head=0):
    """
    Plot the tokens that [CLS] attends to most in one attention head.

    Runs `get_model_prediction` on `text`, takes the attention matrix of the
    given `layer` / `head` (both 0-based indices), and draws a one-row seaborn
    heatmap with the (up to) 10 highest-weighted tokens in the [CLS] row.
    Position 0 ([CLS] itself) and padding positions are excluded: the input is
    padded to a fixed length, so [PAD] entries would otherwise fill the list
    for short texts.

    The title shows the layer as `layer + 1` and the head as given, together
    with the predicted label and its confidence.
    """
    pred_label, confidence, attentions, input_ids = get_model_prediction(text, tokenizer, model, device)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    attn_matrix = attentions[layer][0, head].cpu().numpy()
    cls_attention = attn_matrix[0]  # row 0 = attention paid by [CLS]

    # Candidate positions: everything after [CLS] that is a real token
    pad_token = getattr(tokenizer, "pad_token", None)
    token_indices = [i for i in range(1, len(tokens)) if tokens[i] != pad_token]
    sorted_indices = sorted(token_indices, key=lambda i: cls_attention[i], reverse=True)
    top_indices = sorted_indices[:10]

    top_tokens = [tokens[i] for i in top_indices]
    top_values = [cls_attention[i] for i in top_indices]

    plt.figure(figsize=(10, 1.5))
    sns.heatmap([top_values], annot=True, xticklabels=top_tokens, yticklabels=["[CLS]"], cmap="viridis")
    plt.title(f"Attention from [CLS] (Layer {layer+1}, Head {head})\nPredicted: {pred_label} | Confidence: {confidence:.2f}")
    plt.show()

def visualize_attention(text):
    """
    Show the interactive bertviz `head_view` for `text`.

    The text is tokenized (truncated to 128 tokens, no padding), run through
    the notebook-level `bert_attention_model` without gradients, and the
    attention tensors are copied to the CPU before being handed to `head_view`.
    """
    encoded = tokenizer(
        text,
        return_tensors='pt',
        truncation=True,
        max_length=128
    )
    batch = {name: tensor.to(device) for name, tensor in encoded.items()}

    with torch.no_grad():
        model_output = bert_attention_model(**batch)

    tokens = tokenizer.convert_ids_to_tokens(batch["input_ids"][0])
    attention = tuple(layer_attn.cpu() for layer_attn in model_output.attentions)
    head_view(attention, tokens)

def select_examples(test_dataset, tokenizer, model, device, conf_threshold=0.9, label_list=(0, 1, 2, 3)):
    """
    Pick confident example texts per true label from `test_dataset`.

    For each label in `label_list` the first example encountered is kept for:
      - a correct prediction with confidence >= conf_threshold, and
      - an incorrect prediction with confidence >= conf_threshold.

    Iteration stops early once every slot is filled. Examples whose true label
    is not in `label_list` are skipped without running the model.

    Args:
        test_dataset: Iterable of items exposing `["text"]` and `["label"]`.
        tokenizer, model, device: Forwarded to `get_model_prediction`.
        conf_threshold: Minimum softmax confidence for an example to count.
        label_list: Labels to collect examples for (immutable default).

    Returns:
        `(correct_examples, incorrect_examples)`: two dicts keyed by true
        label, each value being the example text or `None` when none was found.
    """
    correct_examples = {label: None for label in label_list}
    incorrect_examples = {label: None for label in label_list}

    for example in tqdm(test_dataset, desc="Selecting examples for all labels"):
        text = example["text"]
        true_label = example["label"]
        if true_label not in label_list:
            continue

        pred_label, confidence, _, _ = get_model_prediction(text, tokenizer, model, device)
        if confidence < conf_threshold:
            continue

        # Keep only the first qualifying text per (label, outcome) slot
        bucket = correct_examples if pred_label == true_label else incorrect_examples
        if bucket[true_label] is None:
            bucket[true_label] = text

        # Break early if we have one each for all labels.
        all_filled = all(
            found is not None
            for found in (*correct_examples.values(), *incorrect_examples.values())
        )
        if all_filled:
            break

    return correct_examples, incorrect_examples

In [None]:
# Reload the raw test split, then pick confident correct / incorrect examples per label
test_dataset = load_dataset('ag_news', split='test')

correct_dict, incorrect_dict = select_examples(
    test_dataset=test_dataset,
    tokenizer=tokenizer,
    model=clf_model,
    device=device,
    conf_threshold=0.9,
)

### Label 0

In [None]:
print("=== Analysis for Label 0 ===")

# Confidently correct example: heatmap of the last layer plus interactive view
if not correct_dict[0]:
    print("No correctly predicted example found for label 0.")
else:
    print("Correctly Predicted Example for Label 0:")
    print(f"{correct_dict[0]}...")
    visualize_heatmap(correct_dict[0], tokenizer, clf_model, device, layer=11, head=0)
    visualize_attention(correct_dict[0])

# Confidently incorrect example
if not incorrect_dict[0]:
    print("No incorrectly predicted example found for label 0.")
else:
    print("\nIncorrectly Predicted Example for Label 0:")
    print(incorrect_dict[0])
    visualize_heatmap(incorrect_dict[0], tokenizer, clf_model, device, layer=11, head=0)
    visualize_attention(incorrect_dict[0])


### Label 1

In [None]:
print("=== Analysis for Label 1 ===")

# Confidently correct example: heatmap of the last layer plus interactive view
if not correct_dict[1]:
    print("No correctly predicted example found for label 1.")
else:
    print("Correctly Predicted Example for Label 1:")
    print(f"{correct_dict[1]}...")
    visualize_heatmap(correct_dict[1], tokenizer, clf_model, device, layer=11, head=0)
    visualize_attention(correct_dict[1])

# Confidently incorrect example
if not incorrect_dict[1]:
    print("No incorrectly predicted example found for label 1.")
else:
    print("\nIncorrectly Predicted Example for Label 1:")
    print(f"{incorrect_dict[1]}...")
    visualize_heatmap(incorrect_dict[1], tokenizer, clf_model, device, layer=11, head=0)
    visualize_attention(incorrect_dict[1])


### Label 2

In [None]:
print("=== Analysis for Label 2 ===")

# Confidently correct example: heatmap of the last layer plus interactive view
if not correct_dict[2]:
    print("No correctly predicted example found for label 2.")
else:
    print("Correctly Predicted Example for Label 2:")
    print(f"{correct_dict[2]}...")
    visualize_heatmap(correct_dict[2], tokenizer, clf_model, device, layer=11, head=0)
    visualize_attention(correct_dict[2])

# Confidently incorrect example
if not incorrect_dict[2]:
    print("No incorrectly predicted example found for label 2.")
else:
    print("\nIncorrectly Predicted Example for Label 2:")
    print(f"{incorrect_dict[2]}...")
    visualize_heatmap(incorrect_dict[2], tokenizer, clf_model, device, layer=11, head=0)
    visualize_attention(incorrect_dict[2])


### Label 3

In [None]:
print("=== Analysis for Label 3 ===")

# Confidently correct example: heatmap of the last layer plus interactive view
if not correct_dict[3]:
    print("No correctly predicted example found for label 3.")
else:
    print("Correctly Predicted Example for Label 3:")
    print(f"{correct_dict[3]}...")
    visualize_heatmap(correct_dict[3], tokenizer, clf_model, device, layer=11, head=0)
    visualize_attention(correct_dict[3])

# Confidently incorrect example
if not incorrect_dict[3]:
    print("No incorrectly predicted example found for label 3.")
else:
    print("\nIncorrectly Predicted Example for Label 3:")
    print(f"{incorrect_dict[3]}...")
    visualize_heatmap(incorrect_dict[3], tokenizer, clf_model, device, layer=11, head=0)
    visualize_attention(incorrect_dict[3])


# Extra experiment

In [None]:
def analyze_attention_statistics(text, layer=11, head=0):
    """Plot a histogram of the attention weights [CLS] assigns in one head.

    `text` is classified with the notebook-level `clf_model` through
    `get_model_prediction`; the [CLS] row of the attention matrix for the
    given 0-based `layer` / `head` is then histogrammed.

    Position 0 ([CLS] itself) and padding positions are left out. The input is
    padded to a fixed length, so including [PAD] entries would fill the
    histogram with near-zero weights that say nothing about the text.

    The title shows the layer as `layer + 1` and the head as given.
    """
    _, _, attentions, input_ids = get_model_prediction(text, tokenizer, clf_model, device)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    attn_matrix = attentions[layer][0, head].cpu().numpy()
    cls_attn = attn_matrix[0]  # row 0 = attention paid by [CLS]

    # Real-token positions after [CLS]
    pad_token = getattr(tokenizer, "pad_token", None)
    real_positions = [i for i in range(1, len(tokens)) if tokens[i] != pad_token]
    cls_attn_real = cls_attn[real_positions]

    plt.figure(figsize=(6,4))
    plt.hist(cls_attn_real, bins=20, color="skyblue", edgecolor="black")
    plt.title(f"Attention Weight Distribution (Layer {layer+1}, Head {head})")
    plt.xlabel("Attention Weight")
    plt.ylabel("Frequency")
    plt.show()

def compare_heads(text, layer=11, heads=(0, 1, 2, 3)):
    """Compare the tokens [CLS] attends to most across several heads.

    `text` is classified with the notebook-level `clf_model` through
    `get_model_prediction`. For every head index in `heads` (within the
    0-based `layer`), a bar chart of the (up to) 10 highest-weighted tokens in
    the [CLS] attention row is drawn side by side in one figure.

    Position 0 ([CLS] itself) and padding positions are excluded so that
    short, padded inputs do not surface [PAD] among the top tokens.

    Args:
        text: Input text to analyse.
        layer: 0-based index into the returned attention tuple.
        heads: Head indices to plot (immutable default).
    """
    pred_label, confidence, attentions, input_ids = get_model_prediction(text, tokenizer, clf_model, device)
    tokens = tokenizer.convert_ids_to_tokens(input_ids[0])

    # Candidate positions are the same for every head: real tokens after [CLS]
    pad_token = getattr(tokenizer, "pad_token", None)
    candidates = np.array(
        [pos for pos in range(1, len(tokens)) if tokens[pos] != pad_token],
        dtype=int,
    )

    plt.figure(figsize=(15, 4))
    for plot_pos, head in enumerate(heads, start=1):
        attn_matrix = attentions[layer][0, head].cpu().numpy()
        cls_attn = attn_matrix[0]  # row 0 = attention paid by [CLS]

        # Rank candidates by attention weight, descending, and keep the top 10
        order = np.argsort(-cls_attn[candidates])
        top_indices = candidates[order[:10]]
        top_tokens = [tokens[pos] for pos in top_indices]
        top_values = cls_attn[top_indices]

        plt.subplot(1, len(heads), plot_pos)
        plt.bar(range(len(top_tokens)), top_values, tick_label=top_tokens)
        plt.title(f"Head {head}\nPred: {pred_label} | Conf: {confidence:.2f}")
        plt.xticks(rotation=45)
    plt.tight_layout()
    plt.show()

In [None]:
print("=== Analysis for Label 0 ===")

# Confidently correct example: weight histogram plus per-head comparison
if not correct_dict[0]:
    print("No correctly predicted example found for label 0.")
else:
    print("Correctly Predicted Example for Label 0:")
    print(f"{correct_dict[0][:200]}...")
    analyze_attention_statistics(correct_dict[0], layer=11, head=0)
    compare_heads(correct_dict[0], layer=11, heads=[0,1,2,3])

# Confidently incorrect example
if not incorrect_dict[0]:
    print("No incorrectly predicted example found for label 0.")
else:
    print("\nIncorrectly Predicted Example for Label 0:")
    print(f"{incorrect_dict[0][:200]}...")
    analyze_attention_statistics(incorrect_dict[0], layer=11, head=0)
    compare_heads(incorrect_dict[0], layer=11, heads=[0,1,2,3])


In [None]:
# This cell was a verbatim copy of the label 0 cell, so label 1 was never
# analysed in this experiment; it now covers label 1.
print("=== Analysis for Label 1 ===")
if correct_dict[1]:
    print("Correctly Predicted Example for Label 1:")
    print(correct_dict[1][:200] + "...")
    analyze_attention_statistics(correct_dict[1], layer=11, head=0)
    compare_heads(correct_dict[1], layer=11, heads=[0,1,2,3])
else:
    print("No correctly predicted example found for label 1.")

if incorrect_dict[1]:
    print("\nIncorrectly Predicted Example for Label 1:")
    print(incorrect_dict[1][:200] + "...")
    analyze_attention_statistics(incorrect_dict[1], layer=11, head=0)
    compare_heads(incorrect_dict[1], layer=11, heads=[0,1,2,3])
else:
    print("No incorrectly predicted example found for label 1.")

In [None]:
print("=== Analysis for Label 2 ===")

# Confidently correct example: weight histogram plus per-head comparison
if not correct_dict[2]:
    print("No correctly predicted example found for label 2.")
else:
    print("Correctly Predicted Example for Label 2:")
    print(f"{correct_dict[2][:200]}...")
    analyze_attention_statistics(correct_dict[2], layer=11, head=0)
    compare_heads(correct_dict[2], layer=11, heads=[0,1,2,3])

# Confidently incorrect example
if not incorrect_dict[2]:
    print("No incorrectly predicted example found for label 2.")
else:
    print("\nIncorrectly Predicted Example for Label 2:")
    print(f"{incorrect_dict[2][:200]}...")
    analyze_attention_statistics(incorrect_dict[2], layer=11, head=0)
    compare_heads(incorrect_dict[2], layer=11, heads=[0,1,2,3])


In [None]:
print("=== Analysis for Label 3 ===")

# Confidently correct example: weight histogram plus per-head comparison
if not correct_dict[3]:
    print("No correctly predicted example found for label 3.")
else:
    print("Correctly Predicted Example for Label 3:")
    print(f"{correct_dict[3][:200]}...")
    analyze_attention_statistics(correct_dict[3], layer=11, head=0)
    compare_heads(correct_dict[3], layer=11, heads=[0,1,2,3])

# Confidently incorrect example
if not incorrect_dict[3]:
    print("No incorrectly predicted example found for label 3.")
else:
    print("\nIncorrectly Predicted Example for Label 3:")
    print(f"{incorrect_dict[3][:200]}...")
    analyze_attention_statistics(incorrect_dict[3], layer=11, head=0)
    compare_heads(incorrect_dict[3], layer=11, heads=[0,1,2,3])
