In [None]:

!pip install fsspec -q
!pip install cmake dotenv hf-xet huggingface-hub
!pip install pandas pillow protobuf python-dotenv
!pip install safetensors
!pip install scikit-learn tiktoken tokenizers
!pip install torch torchvision torchaudio
!pip install tqdm transformers datasets sentencepiece
!pip install  --upgrade datasets fsspec --quiet


In [None]:
from datasets import load_dataset
from huggingface_hub import login
from dotenv import load_dotenv
import os, re, random
import torch, numpy as np
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)
from sklearn.metrics import accuracy_score, f1_score


def set_seed(seed=42):
    """Seed Python's `random`, NumPy and PyTorch (CPU and every CUDA device).

    Args:
        seed: Value handed unchanged to each generator. Defaults to 42.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    # Same order as listed: stdlib, NumPy, torch CPU, torch CUDA.
    for apply_seed in seeders:
        apply_seed(seed)

# Fix every RNG before any data or model work happens.
set_seed()

# Load variables from a local .env file (if present), then authenticate with the
# Hugging Face Hub using whatever HUGGING_FACE_HUB_TOKEN resolves to.
load_dotenv()
login(token=os.environ.get("HUGGING_FACE_HUB_TOKEN"))

# One dataset per language configuration, keyed by the configuration name.
datasets = {
    lang_code: load_dataset("brighter-dataset/BRIGHTER-emotion-categories", lang_code)
    for lang_code in ("eng", "zul", "pcm")
}

# Order of this list defines the order of the multi-label target vector.
emotion_labels = "anger disgust fear joy sadness surprise".split()

def clean_text(text):
    """Normalize a string for tokenization.

    The text is lower-cased, every run starting with "http" up to the next
    whitespace is removed, and then every character that is neither a word
    character nor whitespace is removed.

    Args:
        text: Input string.

    Returns:
        The cleaned string (whitespace is left untouched, so gaps left by
        removed URLs remain).
    """
    without_urls = re.sub(r"http\S+", "", text.lower())
    return re.sub(r"[^\w\s]", "", without_urls)

def preprocess(example):
    """Clean one record's text and attach its multi-label target vector.

    Args:
        example: Mapping with a "text" entry and, optionally, one entry per
            name in `emotion_labels`.

    Returns:
        The same `example` object, updated in place: "text" is replaced by
        `clean_text(text)` and "labels" holds one float per emotion in
        `emotion_labels` order. Missing or falsy values become 0.0.
    """
    example["text"] = clean_text(example["text"])

    targets = []
    for emotion in emotion_labels:
        raw_value = example.get(emotion, 0)
        # `or 0` maps None (and other falsy values) to 0 before the float cast.
        targets.append(float(raw_value or 0))
    example["labels"] = targets

    return example

# Run text cleaning and label-vector construction over every language's dataset.
# Re-assigning existing keys while iterating is safe: the dict's size never changes.
for lang, raw_splits in datasets.items():
    datasets[lang] = raw_splits.map(preprocess)

# The same checkpoint name is used for the tokenizer here and the classifier below.
model_name = "castorini/afriberta_base"
tokenizer = AutoTokenizer.from_pretrained(model_name)

def tokenize(batch):
    """Tokenize a batch of texts with the notebook-level `tokenizer`.

    Args:
        batch: Mapping whose "text" entry is passed to the tokenizer.

    Returns:
        The tokenizer output, padded to and truncated at 128 tokens.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
    }
    return tokenizer(batch["text"], **encode_options)

# Tokenize every language's dataset and expose only the model inputs and targets as torch tensors.
tensor_columns = ["input_ids", "attention_mask", "labels"]
for lang in datasets:
    datasets[lang] = tokenized_splits = datasets[lang].map(tokenize, batched=True)
    tokenized_splits.set_format(type="torch", columns=tensor_columns)

# Classification head sized to the emotion list, configured for multi-label targets.
classifier_options = dict(
    num_labels=len(emotion_labels),
    problem_type="multi_label_classification",
)
model = AutoModelForSequenceClassification.from_pretrained(model_name, **classifier_options)

def compute_metrics(eval_pred):
    """Compute accuracy and macro F1 for multi-label predictions.

    Args:
        eval_pred: Pair `(logits, labels)`. `labels` must support `.astype`.

    Returns:
        Dict with
        - "accuracy": accuracy over the flattened label matrix, i.e. every
          (sample, emotion) cell counts as one decision;
        - "f1": macro-averaged F1 with `zero_division=0`.
    """
    raw_logits, gold = eval_pred

    # A label is predicted when its sigmoid probability exceeds 0.5.
    probabilities = torch.sigmoid(torch.tensor(raw_logits))
    predicted = (probabilities > 0.5).int().numpy()
    gold = gold.astype(int)

    flat_accuracy = accuracy_score(gold.reshape(-1), predicted.reshape(-1))
    macro_f1 = f1_score(gold, predicted, average="macro", zero_division=0)
    return {"accuracy": flat_accuracy, "f1": macro_f1}

def get_trainer(model, train_dataset, eval_dataset, out_dir):
    """Build a `Trainer` with this notebook's fixed training configuration.

    Args:
        model: Model to train.
        train_dataset: Dataset used for training.
        eval_dataset: Dataset evaluated at the end of every epoch.
        out_dir: Value for `TrainingArguments.output_dir`.

    Returns:
        A `Trainer` wired to the notebook-level `tokenizer` and
        `compute_metrics`. Checkpoint saving is disabled
        (`save_strategy="no"`) and no external reporting is configured.
    """
    training_config = dict(
        output_dir=out_dir,
        eval_strategy="epoch",
        save_strategy="no",
        per_device_train_batch_size=8,
        per_device_eval_batch_size=8,
        num_train_epochs=3,
        logging_steps=10,
        report_to="none",
    )
    trainer_options = dict(
        model=model,
        args=TrainingArguments(**training_config),
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    return Trainer(**trainer_options)


# Sequential fine-tuning: English -> Zulu -> Pidgin, always on the same `model`
# object so each stage starts from the weights the previous stage produced.
# Every stage gets its own Trainer from get_trainer() instead of mutating the
# datasets/output_dir of an already-used Trainer, so no optimizer state is
# carried from one language to the next.

print("Training on English...")
trainer = get_trainer(model, datasets["eng"]["train"], datasets["eng"]["dev"], "./eng_model")
trainer.train()


print("\nFine-tuning on Zulu...")
# NOTE(review): this stage fits on the Zulu "dev" split and uses "test" for the
# per-epoch evaluation, unlike the other two languages (train/dev) — confirm this
# is intentional (e.g. no usable "train" split for this configuration).
trainer = get_trainer(model, datasets["zul"]["dev"], datasets["zul"]["test"], "./zul_model")
trainer.train()

print("\nEvaluating on Zulu test set...")
zulu_test_results = trainer.predict(datasets["zul"]["test"])
zulu_metrics = compute_metrics((zulu_test_results.predictions, zulu_test_results.label_ids))
print("Zulu F1:", zulu_metrics["f1"])


print("\nFine-tuning on Pidgin...")
trainer = get_trainer(model, datasets["pcm"]["train"], datasets["pcm"]["dev"], "./pcm_model")
trainer.train()

print("\nEvaluating on Pidgin test set...")
pcm_test_results = trainer.predict(datasets["pcm"]["test"])
pcm_metrics = compute_metrics((pcm_test_results.predictions, pcm_test_results.label_ids))
print("Pidgin F1:", pcm_metrics["f1"])

# Persist the weights as they are after the last (Pidgin) stage, plus the tokenizer.
model.save_pretrained("./final_model")
tokenizer.save_pretrained("./final_model")


In [None]:
from sklearn.metrics import accuracy_score, f1_score, classification_report

# Detailed Zulu test-set report. `trainer` is whatever the training cell left
# bound, i.e. the model state after its final training stage.
zulu_preds_output = trainer.predict(datasets["zul"]["test"])
logits = torch.tensor(zulu_preds_output.predictions)
preds = torch.sigmoid(logits).gt(0.5).int().numpy()
labels = zulu_preds_output.label_ids.astype(int)
flat_preds, flat_labels = preds.reshape(-1), labels.reshape(-1)

# Accuracy is computed per (sample, emotion) cell; F1 and the report per emotion column.
zero_division_option = {"zero_division": 0}
zulu_accuracy = accuracy_score(flat_labels, flat_preds)
zulu_f1_macro = f1_score(labels, preds, average="macro", **zero_division_option)
zulu_report = classification_report(labels, preds, target_names=emotion_labels, **zero_division_option)

for caption, value in (
    ("Zulu Accuracy:", zulu_accuracy),
    ("Zulu Macro F1:", zulu_f1_macro),
    ("\nZulu Classification Report:\n", zulu_report),
):
    print(caption, value)

In [None]:
from sklearn.metrics import accuracy_score, f1_score, classification_report

# Detailed Pidgin test-set report using the `trainer` left bound by the training cell.
pcm_preds_output = trainer.predict(datasets["pcm"]["test"])
logits = torch.tensor(pcm_preds_output.predictions)
preds = torch.sigmoid(logits).gt(0.5).int().numpy()
labels = pcm_preds_output.label_ids.astype(int)

# Accuracy is computed per (sample, emotion) cell; F1 and the report per emotion column.
zero_division_option = {"zero_division": 0}
pcm_accuracy = accuracy_score(labels.reshape(-1), preds.reshape(-1))
pcm_f1_macro = f1_score(labels, preds, average="macro", **zero_division_option)
pcm_report = classification_report(labels, preds, target_names=emotion_labels, **zero_division_option)

for caption, value in (
    ("Pidgin Accuracy:", pcm_accuracy),
    ("Pidgin Macro F1:", pcm_f1_macro),
    ("\nPidgin Classification Report:\n", pcm_report),
):
    print(caption, value)


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Reload the saved model and tokenizer from disk and switch the model to
# evaluation mode (`.eval()` returns the module itself).
model_path = "./final_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path).eval()

emotion_labels = "anger disgust fear joy sadness surprise".split()

# Single Pidgin example sentence to score.
sentence = "I dey happy well well today!"
inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)

# Forward pass without gradient tracking; one independent sigmoid probability per emotion.
with torch.no_grad():
    outputs = model(**inputs)
    probs = outputs.logits.sigmoid().squeeze().numpy()

# Bar chart of the per-emotion probabilities, y axis fixed to [0, 1].
plt.figure(figsize=(8, 5))
plt.bar(emotion_labels, probs, color='skyblue')
plt.ylabel("Probability")
plt.title("Emotion Probabilities for Pidgin Sentence")
plt.ylim(0, 1)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

# Same numbers as text, rounded to three decimals.
print(dict(zip(emotion_labels, (round(float(prob), 3) for prob in probs))))

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Reload the saved model and tokenizer from disk and put the model in evaluation mode.
model_path = "./final_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path)
model.eval()


emotion_labels = ['anger', 'disgust', 'fear', 'joy', 'sadness', 'surprise']


# Single Zulu example sentence to score.
sentence = "Kusolwa umkhuba ekhaya ngotholwe emotweni yake eseshonile esonakele"


inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding=True)


# Forward pass without gradient tracking; one independent sigmoid probability per emotion.
with torch.no_grad():
    outputs = model(**inputs)
    probs = torch.sigmoid(outputs.logits).squeeze().numpy()


# Bar chart of the per-emotion probabilities, y axis fixed to [0, 1].
plt.figure(figsize=(8, 5))
plt.bar(emotion_labels, probs, color='skyblue')
plt.ylabel("Probability")
# The title names the language of `sentence` above (it previously said "Pidgin").
plt.title("Emotion Probabilities for Zulu Sentence")
plt.ylim(0, 1)
plt.grid(axis='y', linestyle='--', alpha=0.7)
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()

In [None]:
!pip install --upgrade shap
!pip install bertviz lime

In [None]:
import shap
import torch
import numpy as np


# Reload the saved model/tokenizer for the explanation experiments; `.eval()`
# returns the module itself, so the model is bound already in evaluation mode.
model_path = "./final_model"
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForSequenceClassification.from_pretrained(model_path).eval()

# Texts to explain (a single Zulu sentence).
texts = [
    "Kusolwa umkhuba ekhaya ngotholwe emotweni yake eseshonile esonakele",
]

print("Testing predict function...")


def predict(texts):
    """Return per-label sigmoid probabilities for one or more texts.

    Uses the notebook-level `tokenizer` and `model`.

    Args:
        texts: A single string, or any iterable of items (list, tuple, NumPy
            array, ...). Each item yields one output row. Items that are not
            strings, or that are blank after stripping, are replaced by the
            placeholder text "empty". A non-iterable, non-string value is
            converted with `str()` and treated as one text.

    Returns:
        NumPy array with one row per input text and
        `model.config.num_labels` columns. If tokenization or the forward pass
        raises, the error is printed and an all-zero array of that shape is
        returned instead (best-effort behavior so perturbation-based
        explainers keep running).
    """
    if isinstance(texts, str):
        texts = [texts]
    elif not isinstance(texts, list):
        # Explainers may hand over arrays or tuples of strings; keep one entry
        # per element instead of collapsing the whole container into one string.
        try:
            texts = list(texts)
        except TypeError:
            texts = [str(texts)]

    processed_texts = [
        text.strip() if isinstance(text, str) and text.strip() else "empty"
        for text in texts
    ]

    try:
        inputs = tokenizer(processed_texts, padding=True, truncation=True, max_length=128, return_tensors="pt")

        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.sigmoid(outputs.logits).cpu().numpy()

        return probs

    except Exception as e:
        print(f"Error in predict: {e}")

        return np.zeros((len(processed_texts), model.config.num_labels))


# Smoke test 1: run predict() on the list of texts defined above and show the
# shape and values of what it returns.
test_pred = predict(texts)
print(f"Test prediction shape: {test_pred.shape}")
print(f"Test prediction: {test_pred}")


# Smoke test 2: a two-element batch, to check that predict() handles more than one text.
test_batch = predict([texts[0], "hello world"])
print(f"Batch prediction shape: {test_batch.shape}")


print("\nCreating SHAP explainer...")


# Primary path: SHAP with a text masker built from the tokenizer.
# Fallback path (any exception): mask one whitespace-separated word at a time
# and report how much each label's probability moves.
try:
    from transformers import pipeline  # imported by the original cell; not used in this block

    def pipeline_predict(texts):
        """Prediction function handed to SHAP.

        Args:
            texts: A single string or an iterable of items; every item is
                converted with `str()`.

        Returns:
            NumPy array of sigmoid probabilities, one row per text.
        """
        if isinstance(texts, str):
            texts = [texts]

        text_list = [str(text) for text in texts]

        inputs = tokenizer(text_list, padding=True, truncation=True, max_length=128, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.sigmoid(outputs.logits).cpu().numpy()

        return probs

    explainer = shap.Explainer(pipeline_predict, shap.maskers.Text(tokenizer))

    print("Computing SHAP values...")
    shap_values = explainer(texts, max_evals=100)

    shap.plots.text(shap_values[0])

except Exception as e:
    print(f"SHAP with Text masker failed: {e}")

    print("Falling back to manual explanation...")

    original_text = texts[0]
    tokens = original_text.split()

    # Mask with the tokenizer's own mask token. The literal "[MASK]" is only a
    # special token in BERT-style vocabularies; elsewhere it is split into
    # ordinary pieces and does not act as a mask. Keep it as a last resort when
    # the tokenizer defines no mask token.
    mask_placeholder = getattr(tokenizer, "mask_token", None) or "[MASK]"

    baseline_pred = predict([original_text])[0]

    print(f"Original text: {original_text}")
    print(f"Baseline prediction: {baseline_pred}")
    print("\nWord-level importance analysis:")

    word_importances = []
    for i, word in enumerate(tokens):
        # Replace exactly one word and re-score the sentence.
        masked_tokens = tokens.copy()
        masked_tokens[i] = mask_placeholder
        masked_text = " ".join(masked_tokens)

        masked_pred = predict([masked_text])[0]

        # Positive entries: masking the word lowered that label's probability.
        importance = baseline_pred - masked_pred
        word_importances.append(importance)

        print(f"  '{word}': {importance}")

    print(f"\nSimple text visualization:")
    for word, importance in zip(tokens, word_importances):
        # Emphasize words whose masking moved any label by more than 0.01.
        emphasis = "**" if np.max(np.abs(importance)) > 0.01 else ""
        print(f"{emphasis}{word}{emphasis}", end=" ")
    print()


print("\n" + "="*50)
print("Alternative: Using LIME for text explanation")
print("="*50)

try:
    from lime.lime_text import LimeTextExplainer

    # Prefer the notebook's emotion names when they match the model's output
    # size; otherwise fall back to generic class_<i> names.
    num_model_labels = model.config.num_labels
    known_label_names = globals().get("emotion_labels")
    if known_label_names is not None and len(known_label_names) == num_model_labels:
        lime_class_names = list(known_label_names)
    else:
        lime_class_names = [f'class_{i}' for i in range(num_model_labels)]

    lime_explainer = LimeTextExplainer(class_names=lime_class_names)

    def lime_predict(texts):
        """Adapter with the signature LIME expects; delegates to predict()."""
        probs = predict(texts)
        return probs

    # top_labels=1 makes LIME explain the highest-scoring label for this text.
    # Without it, explain_instance()/as_list() default to label index 1, i.e.
    # always the second class regardless of what the model predicts.
    exp = lime_explainer.explain_instance(
        texts[0],
        lime_predict,
        top_labels=1,
        num_features=10,
        num_samples=100
    )
    explained_label = exp.available_labels()[0]

    print("LIME explanation:")
    exp.show_in_notebook(text=True)

    print("\nLIME feature weights:")
    print(f"(explained label: '{lime_class_names[explained_label]}')")
    for feature, weight in exp.as_list(label=explained_label):
        print(f"  '{feature}': {weight}")

except ImportError:
    print("LIME not available. Install with: pip install lime")
except Exception as e:
    print(f"LIME explanation failed: {e}")

In [None]:
import torch
from bertviz import model_view, head_view
from transformers import AutoTokenizer, AutoModel
import numpy as np


# Sentence whose attention patterns are analyzed in this cell.
text = "I dey happy well well today!"

# NOTE: `model` and `tokenizer` are not created in this cell; it relies on the
# objects left in the kernel by an earlier cell.

print(f"Analyzing text: '{text}'")
print(f"Model type: {type(model)}")
print(f"Model config: {model.config}")


# Encode the sentence (with special tokens) and recover the token strings for display.
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

print(f"Tokens: {tokens}")
print(f"Number of tokens: {len(tokens)}")


# Forward pass without gradients, asking the model to return its attention weights.
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, return_dict=True)


# `attention` is the per-layer collection of attention tensors returned by the model.
attention = outputs.attentions
print(f"Number of attention layers: {len(attention)}")
print(f"Attention shape per layer: {attention[0].shape}")
print("\n" + "="*50)
print("METHOD 1: Model View")
print("="*50)

# METHOD 1: BertViz model view. Any failure is reported and the cell continues.
try:

    # Stack the per-layer attention tensors into a single tensor (layer axis first).
    attention_tensor = torch.stack(attention)


    model_view(
        attention=attention_tensor,
        tokens=tokens,
        sentence_b_start=None,
        prettify_tokens=True,
        display_mode='light'
    )
    print("Model view displayed successfully!")

except Exception as e:
    print(f"Model view failed: {e}")
    print("This might be due to Jupyter notebook requirements or display issues.")


print("\n" + "="*50)
print("METHOD 2: Head View")
print("="*50)

# METHOD 2: BertViz head view. `attention_tensor` is assigned inside the try
# block above; if that assignment did not run, the name is unbound (or stale
# from an earlier execution) and the resulting error is reported below.
try:
    head_view(
        attention=attention_tensor,
        tokens=tokens,
        sentence_b_start=None,
        prettify_tokens=True,
        display_mode='light'
    )
    print("Head view displayed successfully!")

except Exception as e:
    print(f"Head view failed: {e}")


print("\n" + "="*50)
print("METHOD 3: Manual Attention Analysis")
print("="*50)

def analyze_attention_patterns(attention_weights, tokens, special_tokens=None):
    """Print per-layer top attention targets for each token, then a global ranking.

    Args:
        attention_weights: Sequence of per-layer tensors indexed as
            [batch, head, query, key]; only batch item 0 is reported.
        tokens: Token strings, one per sequence position.
        special_tokens: Tokens to skip as query rows in the per-layer report.
            Defaults to both BERT-style ('[CLS]', '[SEP]', '[PAD]') and
            RoBERTa/XLM-R-style ('<s>', '</s>', '<pad>') markers, so the
            filter also works for SentencePiece-based tokenizers.

    Returns:
        None. Everything is written to stdout.
    """
    if special_tokens is None:
        special_tokens = ('[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>')
    skipped = set(special_tokens)

    num_layers = len(attention_weights)
    if num_layers == 0:
        # Nothing to index into; avoid an IndexError on attention_weights[0].
        print("No attention weights to analyze")
        return

    num_heads = attention_weights[0].shape[1]
    seq_len = len(tokens)

    print(f"Model has {num_layers} layers with {num_heads} heads each")
    print(f"Sequence length: {seq_len} tokens")

    for layer_idx, layer_attention in enumerate(attention_weights):
        print(f"\n--- Layer {layer_idx + 1} ---")

        # Average over heads for batch item 0 -> one (query, key) matrix per layer.
        avg_attention = layer_attention[0].mean(dim=0)

        for i, token in enumerate(tokens):
            if token in skipped:
                continue

            from_token_attention = avg_attention[i]
            top_indices = torch.topk(from_token_attention, k=min(3, len(tokens))).indices

            print(f"  '{token}' pays most attention to:")
            for idx in top_indices.tolist():
                attention_score = from_token_attention[idx].item()
                target_token = tokens[idx]
                print(f"    '{target_token}': {attention_score:.3f}")

    print(f"\n--- Overall Attention Summary ---")

    # Mean over layers (dim 0) and heads (dim 2) leaves [batch, query, key].
    all_attention = torch.stack(tuple(attention_weights)).mean(dim=(0, 2))

    # Summing over the query axis gives the attention each key position receives.
    total_attention_received = all_attention[0].sum(dim=0)
    print("Tokens ranked by total attention received:")
    sorted_indices = torch.argsort(total_attention_received, descending=True)

    for rank, idx in enumerate(sorted_indices[:5].tolist()):
        token = tokens[idx]
        attention_score = total_attention_received[idx].item()
        print(f"  {rank+1}. '{token}': {attention_score:.3f}")


# Run the manual (text-only) attention report for the current sentence.
analyze_attention_patterns(attention, tokens)

print("\n" + "="*50, "METHOD 4: Attention Heatmap", "="*50, sep="\n")

try:
    import matplotlib.pyplot as plt
    import seaborn as sns

    def _show_attention_heatmap(matrix, cmap, title):
        """Draw one annotated (query x key) heatmap labelled with the current tokens."""
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            matrix,
            xticklabels=tokens,
            yticklabels=tokens,
            annot=True,
            fmt='.2f',
            cmap=cmap,
            cbar=True
        )
        plt.title(title)
        plt.xlabel('Attending to (Keys)')
        plt.ylabel('Attending from (Queries)')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

    # First figure: a single head of a single layer (both zero-based indices).
    layer_to_viz = 0
    head_to_viz = 0
    specific_attention = attention[layer_to_viz][0, head_to_viz].cpu().numpy()
    _show_attention_heatmap(
        specific_attention,
        'Blues',
        f'Attention Heatmap - Layer {layer_to_viz + 1}, Head {head_to_viz + 1}',
    )

    print(f"Heatmap created for Layer {layer_to_viz + 1}, Head {head_to_viz + 1}")

    # Second figure: mean over layers (dim 0) and heads (dim 2), batch item 0.
    avg_attention_all = torch.stack(attention).mean(dim=(0, 2))[0].cpu().numpy()
    _show_attention_heatmap(
        avg_attention_all,
        'Reds',
        'Average Attention Across All Layers and Heads',
    )

except ImportError:
    print("Matplotlib/Seaborn not available for heatmap visualization")
except Exception as e:
    print(f"Heatmap creation failed: {e}")

print("\n" + "="*50, "METHOD 5: Interactive Analysis Functions", "="*50, sep="\n")

def get_attention_for_token(token_text, layer_idx=None, head_idx=None):
    """Print how strongly one token attends to every token, strongest first.

    Reads the notebook-level `tokens` and `attention`.

    Args:
        token_text: Token to look up. An exact match in `tokens` is preferred;
            otherwise the first token that equals `token_text` once a leading
            subword marker ('▁', 'Ġ' or '##') is removed is used, so plain
            words can be passed for subword tokenizers.
        layer_idx: Zero-based layer to inspect; None averages over all layers.
        head_idx: Zero-based head to inspect; None averages over all heads.

    Returns:
        None. Prints a "not found" message and returns if no token matches.
    """
    def _without_subword_marker(token):
        # Strip at most one leading marker used by common subword tokenizers.
        for marker in ('▁', 'Ġ', '##'):
            if token.startswith(marker):
                return token[len(marker):]
        return token

    if token_text in tokens:
        token_idx = tokens.index(token_text)
    else:
        token_idx = next(
            (i for i, token in enumerate(tokens) if _without_subword_marker(token) == token_text),
            None,
        )
        if token_idx is None:
            print(f"Token '{token_text}' not found in: {tokens}")
            return

    if layer_idx is None:
        # Mean over layers leaves [batch, head, query, key].
        avg_attention = torch.stack(tuple(attention)).mean(dim=0)
        if head_idx is None:
            token_attention = avg_attention[0].mean(dim=0)[token_idx]
            print(f"Average attention from '{token_text}' to all tokens:")
        else:
            token_attention = avg_attention[0, head_idx][token_idx]
            print(f"Attention from '{token_text}' (head {head_idx}) to all tokens:")
    else:
        if head_idx is None:
            token_attention = attention[layer_idx][0].mean(dim=0)[token_idx]
            print(f"Attention from '{token_text}' (layer {layer_idx}) to all tokens:")
        else:
            token_attention = attention[layer_idx][0, head_idx][token_idx]
            print(f"Attention from '{token_text}' (layer {layer_idx}, head {head_idx}) to all tokens:")

    sorted_indices = torch.argsort(token_attention, descending=True)
    for idx in sorted_indices.tolist():
        target_token = tokens[idx]
        score = token_attention[idx].item()
        print(f"  '{target_token}': {score:.4f}")


print("\nExample: Attention from 'happy':")
get_attention_for_token('happy')

# The sentence analyzed in this cell is "I dey happy well well today!", so the
# second example queries a word that actually occurs in it.
print("\nExample: Attention from 'today':")
get_attention_for_token('today')


print("\n" + "="*50)
print("ATTENTION INSIGHTS")
print("="*50)


def analyze_attention_types(special_tokens=None):
    """Print self-attention scores and the strongest cross-token attention pairs.

    Reads the notebook-level `tokens` and `attention`; scores are averaged over
    all layers and heads for batch item 0.

    Args:
        special_tokens: Tokens to leave out of the self-attention listing.
            Defaults to both BERT-style ('[CLS]', '[SEP]', '[PAD]') and
            RoBERTa/XLM-R-style ('<s>', '</s>', '<pad>') markers.

    Returns:
        None. Everything is written to stdout.
    """
    if special_tokens is None:
        special_tokens = ('[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>')
    skipped = set(special_tokens)

    num_tokens = len(tokens)

    # Mean over layers (dim 0) and heads (dim 2), then batch item 0 -> (query, key).
    avg_attention = torch.stack(tuple(attention)).mean(dim=(0, 2))[0]

    print("Self-attention scores (diagonal values):")
    for i, token in enumerate(tokens):
        if token not in skipped:
            self_attention = avg_attention[i, i].item()
            print(f"  '{token}': {self_attention:.4f}")

    print("\nStrongest cross-token attention pairs:")

    # Exclude the diagonal with -inf so a self-attention cell can never be
    # selected as a cross-token pair, even when off-diagonal scores are 0.
    diagonal = torch.eye(num_tokens, dtype=torch.bool, device=avg_attention.device)
    masked_attention = avg_attention.masked_fill(diagonal, float('-inf'))

    flat_attention = masked_attention.flatten()
    # Never ask for more pairs than there are off-diagonal cells.
    top_k = min(5, num_tokens * (num_tokens - 1))
    if top_k <= 0:
        return
    top_indices = torch.topk(flat_attention, k=top_k).indices

    for flat_idx in top_indices.tolist():
        i, j = divmod(flat_idx, num_tokens)
        score = avg_attention[i, j].item()
        from_token = tokens[i]
        to_token = tokens[j]
        print(f"  '{from_token}' -> '{to_token}': {score:.4f}")

# Print the self-attention / cross-token summary for the current sentence.
analyze_attention_types()

# Closing banner with a reminder of what is needed for the interactive views.
print(f"\n{'='*50}")
print("BertViz analysis complete!")
print("For best visualization, run this in a Jupyter notebook with:")
print("!pip install bertviz")
print("from bertviz import model_view, head_view")
print(f"{'='*50}")

In [None]:
import torch
from bertviz import model_view, head_view
from transformers import AutoTokenizer, AutoModel
import numpy as np


# Sentence whose attention patterns are analyzed in this cell.
text = "I am very happy today but also a bit nervous."

# NOTE: `model` and `tokenizer` are not created in this cell; it relies on the
# objects left in the kernel by an earlier cell.

print(f"Analyzing text: '{text}'")
print(f"Model type: {type(model)}")
print(f"Model config: {model.config}")


# Encode the sentence (with special tokens) and recover the token strings for display.
inputs = tokenizer(text, return_tensors="pt", add_special_tokens=True)
tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])

print(f"Tokens: {tokens}")
print(f"Number of tokens: {len(tokens)}")

# Forward pass without gradients, asking the model to return its attention weights.
with torch.no_grad():
    outputs = model(**inputs, output_attentions=True, return_dict=True)


# `attention` is the per-layer collection of attention tensors returned by the model.
attention = outputs.attentions
print(f"Number of attention layers: {len(attention)}")
print(f"Attention shape per layer: {attention[0].shape}")
print("\n" + "="*50)
print("METHOD 1: Model View")
print("="*50)

# METHOD 1: BertViz model view. Any failure is reported and the cell continues.
try:

    # Stack the per-layer attention tensors into a single tensor (layer axis first).
    attention_tensor = torch.stack(attention)

    model_view(
        attention=attention_tensor,
        tokens=tokens,
        sentence_b_start=None,
        prettify_tokens=True,
        display_mode='light'
    )
    print("Model view displayed successfully!")

except Exception as e:
    print(f"Model view failed: {e}")
    print("This might be due to Jupyter notebook requirements or display issues.")


print("\n" + "="*50)
print("METHOD 2: Head View")
print("="*50)

# METHOD 2: BertViz head view. `attention_tensor` is assigned inside the try
# block above; if that assignment did not run, the name is unbound (or stale
# from an earlier execution) and the resulting error is reported below.
try:
    head_view(
        attention=attention_tensor,
        tokens=tokens,
        sentence_b_start=None,
        prettify_tokens=True,
        display_mode='light'
    )
    print("Head view displayed successfully!")

except Exception as e:
    print(f"Head view failed: {e}")


print("\n" + "="*50)
print("METHOD 3: Manual Attention Analysis")
print("="*50)

def analyze_attention_patterns(attention_weights, tokens, special_tokens=None):
    """Print per-layer top attention targets for each token, then a global ranking.

    Args:
        attention_weights: Sequence of per-layer tensors indexed as
            [batch, head, query, key]; only batch item 0 is reported.
        tokens: Token strings, one per sequence position.
        special_tokens: Tokens to skip as query rows in the per-layer report.
            Defaults to both BERT-style ('[CLS]', '[SEP]', '[PAD]') and
            RoBERTa/XLM-R-style ('<s>', '</s>', '<pad>') markers, so the
            filter also works for SentencePiece-based tokenizers.

    Returns:
        None. Everything is written to stdout.
    """
    if special_tokens is None:
        special_tokens = ('[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>')
    skipped = set(special_tokens)

    num_layers = len(attention_weights)
    if num_layers == 0:
        # Nothing to index into; avoid an IndexError on attention_weights[0].
        print("No attention weights to analyze")
        return

    num_heads = attention_weights[0].shape[1]
    seq_len = len(tokens)

    print(f"Model has {num_layers} layers with {num_heads} heads each")
    print(f"Sequence length: {seq_len} tokens")

    for layer_idx, layer_attention in enumerate(attention_weights):
        print(f"\n--- Layer {layer_idx + 1} ---")

        # Average over heads for batch item 0 -> one (query, key) matrix per layer.
        avg_attention = layer_attention[0].mean(dim=0)
        for i, token in enumerate(tokens):
            if token in skipped:
                continue

            from_token_attention = avg_attention[i]
            top_indices = torch.topk(from_token_attention, k=min(3, len(tokens))).indices

            print(f"  '{token}' pays most attention to:")
            for idx in top_indices.tolist():
                attention_score = from_token_attention[idx].item()
                target_token = tokens[idx]
                print(f"    '{target_token}': {attention_score:.3f}")

    print(f"\n--- Overall Attention Summary ---")

    # Mean over layers (dim 0) and heads (dim 2) leaves [batch, query, key].
    all_attention = torch.stack(tuple(attention_weights)).mean(dim=(0, 2))
    # Summing over the query axis gives the attention each key position receives.
    total_attention_received = all_attention[0].sum(dim=0)

    print("Tokens ranked by total attention received:")
    sorted_indices = torch.argsort(total_attention_received, descending=True)

    for rank, idx in enumerate(sorted_indices[:5].tolist()):
        token = tokens[idx]
        attention_score = total_attention_received[idx].item()
        print(f"  {rank+1}. '{token}': {attention_score:.3f}")


# Run the manual (text-only) attention report for the current sentence.
analyze_attention_patterns(attention, tokens)

print("\n" + "="*50, "METHOD 4: Attention Heatmap", "="*50, sep="\n")

try:
    import matplotlib.pyplot as plt
    import seaborn as sns

    def _show_attention_heatmap(matrix, cmap, title):
        """Draw one annotated (query x key) heatmap labelled with the current tokens."""
        plt.figure(figsize=(10, 8))
        sns.heatmap(
            matrix,
            xticklabels=tokens,
            yticklabels=tokens,
            annot=True,
            fmt='.2f',
            cmap=cmap,
            cbar=True
        )
        plt.title(title)
        plt.xlabel('Attending to (Keys)')
        plt.ylabel('Attending from (Queries)')
        plt.xticks(rotation=45)
        plt.yticks(rotation=0)
        plt.tight_layout()
        plt.show()

    # First figure: a single head of a single layer (both zero-based indices).
    layer_to_viz = 0
    head_to_viz = 0
    specific_attention = attention[layer_to_viz][0, head_to_viz].cpu().numpy()
    _show_attention_heatmap(
        specific_attention,
        'Blues',
        f'Attention Heatmap - Layer {layer_to_viz + 1}, Head {head_to_viz + 1}',
    )

    print(f"Heatmap created for Layer {layer_to_viz + 1}, Head {head_to_viz + 1}")

    # Second figure: mean over layers (dim 0) and heads (dim 2), batch item 0.
    avg_attention_all = torch.stack(attention).mean(dim=(0, 2))[0].cpu().numpy()
    _show_attention_heatmap(
        avg_attention_all,
        'Reds',
        'Average Attention Across All Layers and Heads',
    )

except ImportError:
    print("Matplotlib/Seaborn not available for heatmap visualization")
except Exception as e:
    print(f"Heatmap creation failed: {e}")

print("\n" + "="*50, "METHOD 5: Interactive Analysis Functions", "="*50, sep="\n")

def get_attention_for_token(token_text, layer_idx=None, head_idx=None):
    """Print how strongly one token attends to every token, strongest first.

    Reads the notebook-level `tokens` and `attention`.

    Args:
        token_text: Token to look up. An exact match in `tokens` is preferred;
            otherwise the first token that equals `token_text` once a leading
            subword marker ('▁', 'Ġ' or '##') is removed is used, so plain
            words can be passed for subword tokenizers.
        layer_idx: Zero-based layer to inspect; None averages over all layers.
        head_idx: Zero-based head to inspect; None averages over all heads.

    Returns:
        None. Prints a "not found" message and returns if no token matches.
    """
    def _without_subword_marker(token):
        # Strip at most one leading marker used by common subword tokenizers.
        for marker in ('▁', 'Ġ', '##'):
            if token.startswith(marker):
                return token[len(marker):]
        return token

    if token_text in tokens:
        token_idx = tokens.index(token_text)
    else:
        token_idx = next(
            (i for i, token in enumerate(tokens) if _without_subword_marker(token) == token_text),
            None,
        )
        if token_idx is None:
            print(f"Token '{token_text}' not found in: {tokens}")
            return

    if layer_idx is None:
        # Mean over layers leaves [batch, head, query, key].
        avg_attention = torch.stack(tuple(attention)).mean(dim=0)
        if head_idx is None:
            token_attention = avg_attention[0].mean(dim=0)[token_idx]
            print(f"Average attention from '{token_text}' to all tokens:")
        else:
            token_attention = avg_attention[0, head_idx][token_idx]
            print(f"Attention from '{token_text}' (head {head_idx}) to all tokens:")
    else:
        if head_idx is None:
            token_attention = attention[layer_idx][0].mean(dim=0)[token_idx]
            print(f"Attention from '{token_text}' (layer {layer_idx}) to all tokens:")
        else:
            token_attention = attention[layer_idx][0, head_idx][token_idx]
            print(f"Attention from '{token_text}' (layer {layer_idx}, head {head_idx}) to all tokens:")

    sorted_indices = torch.argsort(token_attention, descending=True)
    for idx in sorted_indices.tolist():
        target_token = tokens[idx]
        score = token_attention[idx].item()
        print(f"  '{target_token}': {score:.4f}")


# Example lookups for two words of this cell's sentence. When the requested
# token is not present in `tokens`, get_attention_for_token() prints a
# "not found" message instead of raising.
print("\nExample: Attention from 'happy':")
get_attention_for_token('happy')

print("\nExample: Attention from 'nervous':")
get_attention_for_token('nervous')


print("\n" + "="*50)
print("ATTENTION INSIGHTS")
print("="*50)


def analyze_attention_types(special_tokens=None):
    """Print self-attention scores and the strongest cross-token attention pairs.

    Reads the notebook-level `tokens` and `attention`; scores are averaged over
    all layers and heads for batch item 0.

    Args:
        special_tokens: Tokens to leave out of the self-attention listing.
            Defaults to both BERT-style ('[CLS]', '[SEP]', '[PAD]') and
            RoBERTa/XLM-R-style ('<s>', '</s>', '<pad>') markers.

    Returns:
        None. Everything is written to stdout.
    """
    if special_tokens is None:
        special_tokens = ('[CLS]', '[SEP]', '[PAD]', '<s>', '</s>', '<pad>')
    skipped = set(special_tokens)

    num_tokens = len(tokens)

    # Mean over layers (dim 0) and heads (dim 2), then batch item 0 -> (query, key).
    avg_attention = torch.stack(tuple(attention)).mean(dim=(0, 2))[0]

    print("Self-attention scores (diagonal values):")
    for i, token in enumerate(tokens):
        if token not in skipped:
            self_attention = avg_attention[i, i].item()
            print(f"  '{token}': {self_attention:.4f}")

    print("\nStrongest cross-token attention pairs:")

    # Exclude the diagonal with -inf so a self-attention cell can never be
    # selected as a cross-token pair, even when off-diagonal scores are 0.
    diagonal = torch.eye(num_tokens, dtype=torch.bool, device=avg_attention.device)
    masked_attention = avg_attention.masked_fill(diagonal, float('-inf'))

    flat_attention = masked_attention.flatten()
    # Never ask for more pairs than there are off-diagonal cells.
    top_k = min(5, num_tokens * (num_tokens - 1))
    if top_k <= 0:
        return
    top_indices = torch.topk(flat_attention, k=top_k).indices

    for flat_idx in top_indices.tolist():
        i, j = divmod(flat_idx, num_tokens)
        score = avg_attention[i, j].item()
        from_token = tokens[i]
        to_token = tokens[j]
        print(f"  '{from_token}' -> '{to_token}': {score:.4f}")

# Print the self-attention / cross-token summary for the current sentence.
analyze_attention_types()

# Closing banner with a reminder of what is needed for the interactive views.
print(f"\n{'='*50}")
print("BertViz analysis complete!")
print("For best visualization, run this in a Jupyter notebook with:")
print("!pip install bertviz")
print("from bertviz import model_view, head_view")
print(f"{'='*50}")