# Notebook for attention probing

### The notebook is split up into the following sections:
1. Clean data
2. Split dataset into 80/20 split and tokenize
3. Finetune BERT
4. Evaluation metrics

In [1]:
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification,  BertConfig, BertTokenizerFast, BertForSequenceClassification
from pathlib import Path


# The real attention probing

In [2]:
# ---------------------------------------------------------
# CONFIG
# ---------------------------------------------------------

# Directory of the fine-tuned checkpoint; tokenizer and model are both read from it.
MODEL_NAME_OR_PATH = "./runs/bert_tuned_lr3e-5_ep3/best_model"
# Run on the GPU when torch can see one, otherwise on the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Lower-cased token strings treated as first-person pronouns.
FIRST_PERSON_PRONOUNS = {"i", "me", "my", "mine", "myself"}

print("Loading tokenizer and model...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME_OR_PATH).to(DEVICE)
model.eval()  # evaluation mode for inference

# Architecture sizes as reported by the loaded model's config.
num_layers, num_heads = (
    model.config.num_hidden_layers,
    model.config.num_attention_heads,
)
print(f"Model has {num_layers} layers and {num_heads} heads")


# ---------------------------------------------------------
# HELPERS
# ---------------------------------------------------------

def find_first_pronoun_index(tokens, pronouns=None):
    """
    Return the index of the FIRST first-person pronoun in the token list.

    A token only counts when it is a whole word: a word-initial wordpiece that
    is followed by a "##" continuation piece (e.g. "me" in "me", "##ow") is
    part of a longer word and is skipped.

    Args:
        tokens: list of token strings (WordPiece style, continuation pieces
            prefixed with "##").
        pronouns: optional collection of lower-cased pronoun strings. Defaults
            to the notebook-level FIRST_PERSON_PRONOUNS.

    Returns:
        int index into `tokens`.

    Raises:
        ValueError: if no pronoun token is found.
    """
    if pronouns is None:
        pronouns = FIRST_PERSON_PRONOUNS

    last = len(tokens) - 1
    for i, tok in enumerate(tokens):
        if tok.lower() not in pronouns:
            continue
        # A following "##..." piece means this token is only the start of a longer word.
        if i < last and tokens[i + 1].startswith("##"):
            continue
        return i
    raise ValueError(f"No first-person pronoun found in tokens: {tokens}")


def analyze_sentence(sentence, model, tokenizer, device, layer_for_tokens=-1):
    """
    Analyze one sentence.

    - Computes mean incoming attention to the FIRST first-person pronoun
      for each layer and each head (senders = every token that is neither
      a special token nor the pronoun itself).
    - Prints:
        * tokens with indices
        * Layer x Head matrix of mean attention
        * Token-level attention distribution for a chosen layer
          (averaged over heads), from all tokens -> pronoun.

    Args:
        sentence: raw input string.
        model: model that returns `attentions` when called with
            `output_attentions=True`.
        tokenizer: tokenizer matching `model`.
        device: torch device the encoded inputs are moved to.
        layer_for_tokens: 0-based layer index for the token-level view; any
            negative value selects the last layer.

    Returns:
        np.ndarray of shape (num_layers, num_heads), dtype float32. Entries are
        NaN when the sentence has no sender tokens.

    Raises:
        ValueError: if the sentence contains no first-person pronoun token.
    """

    # ----- Tokenize and run the model -----
    enc = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=False,
    ).to(device)

    with torch.no_grad():
        outputs = model(**enc, output_attentions=True, return_dict=True)

    input_ids = enc["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    # Find pronoun
    pron_idx = find_first_pronoun_index(tokens)

    # Stack the per-layer tensors and drop the batch axis:
    # (n_layers, n_heads, seq_len, seq_len), indexed [layer, head, query, key].
    # Sizes come from the returned attentions, not from notebook-level globals,
    # so the function stays correct whichever model is passed in.
    attn_all = torch.stack(tuple(outputs.attentions))[:, 0].float()
    n_layers, n_heads = attn_all.shape[0], attn_all.shape[1]

    # Identify special tokens ([CLS], [SEP], [PAD])
    special_ids = {
        tok_id
        for tok_id in (
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
        )
        if tok_id is not None
    }

    # Mask of "other tokens" that send attention to the pronoun
    sender_mask = torch.tensor(
        [
            tok_id not in special_ids and i != pron_idx
            for i, tok_id in enumerate(input_ids.tolist())
        ],
        dtype=torch.bool,
        device=attn_all.device,
    )

    # ----- Per-layer, per-head means (all layers and heads at once) -----
    # Column pron_idx of every attention matrix = attention paid TO the pronoun.
    incoming = attn_all[:, :, :, pron_idx]  # (n_layers, n_heads, seq_len)
    if bool(sender_mask.any()):
        layer_head_means = (
            incoming[:, :, sender_mask].mean(dim=-1).cpu().numpy().astype(np.float32)
        )
    else:
        layer_head_means = np.full((n_layers, n_heads), np.nan, dtype=np.float32)

    # ----- Print tokens -----
    print("\nSentence:")
    print(sentence)
    print("\nTokens (with indices):")
    for i, tok in enumerate(tokens):
        marker = "<- PRON" if i == pron_idx else ""
        print(f"{i:2d} {tok:15s} {marker}")

    # ----- Print per-layer, per-head means -----
    print("\nMean incoming attention to pronoun (other tokens -> pronoun)")
    print("Rows = layers (1..L), columns = heads (0..H-1)\n")

    # Header cells are padded to the 6-character width of the values so the
    # column labels line up with the numbers below them.
    header = "Layer " + "  ".join(f"h{h:02d}".rjust(6) for h in range(n_heads))
    print(header)
    for layer_idx in range(n_layers):
        values = "  ".join(f"{layer_head_means[layer_idx, h]:6.4f}" for h in range(n_heads))
        print(f"{layer_idx+1:2d}    {values}")

    # ----- Token-level view for one layer (averaged over heads) -----
    if layer_for_tokens < 0:
        layer_for_tokens = n_layers - 1  # default = last layer

    attn_mean_heads = attn_all[layer_for_tokens].mean(dim=0)  # (seq_len, seq_len)
    token_scores = attn_mean_heads[:, pron_idx].tolist()

    print(f"\nToken-level incoming attention to pronoun for layer {layer_for_tokens+1}:")
    print("Token\t\tAttention_to_pronoun")
    for i, (tok, score) in enumerate(zip(tokens, token_scores)):
        marker = "<- PRON" if i == pron_idx else ""
        print(f"{i:2d} {tok:15s} {score:.6f} {marker}")

    return layer_head_means

Loading tokenizer and model...
Model has 12 layers and 12 heads


### Trying different sentences

In [3]:
# Short first-person sentence containing "hopeless"; a later cell runs the same
# sentence with "hopeful" in its place.
sentence = "I feel hopeless about the future."
layer_head_means = analyze_sentence(sentence, model, tokenizer, DEVICE)




Sentence:
I feel hopeless about the future.

Tokens (with indices):
 0 [CLS]           
 1 i               <- PRON
 2 feel            
 3 hopeless        
 4 about           
 5 the             
 6 future          
 7 .               
 8 [SEP]           

Mean incoming attention to pronoun (other tokens -> pronoun)
Rows = layers (1..L), columns = heads (0..H-1)

Layer  h00  h01  h02  h03  h04  h05  h06  h07  h08  h09  h10  h11
 1    0.1483  0.0502  0.1989  0.1790  0.0638  0.0628  0.1453  0.0958  0.0685  0.0346  0.0082  0.0337
 2    0.0614  0.0027  0.0609  0.0283  0.0461  0.0184  0.0005  0.0286  0.0405  0.0381  0.0892  0.0451
 3    0.0000  0.0292  0.0182  0.0479  0.0523  0.0771  0.0078  0.0946  0.0674  0.0000  0.0797  0.0299
 4    0.0009  0.0708  0.0153  0.0210  0.1036  0.1368  0.0322  0.0374  0.1212  0.0008  0.0313  0.0827
 5    0.0046  0.1161  0.0221  0.0002  0.0305  0.0078  0.0473  0.0098  0.0157  0.0252  0.1625  0.0548
 6    0.0158  0.0135  0.0867  0.0788  0.0390  0.0019  0.0522  0

In [4]:
# The input contains the misspellings "nighs" and "dont"; the printed token list
# shows them split into wordpieces (ni / ##gh / ##s and don / ##t).
# NOTE(review): presumably the typos are deliberate — confirm.
sentence = "On nighs like this, I dont want to be here anymore."
layer_head_means = analyze_sentence(sentence, model, tokenizer, DEVICE)


Sentence:
On nighs like this, I dont want to be here anymore.

Tokens (with indices):
 0 [CLS]           
 1 on              
 2 ni              
 3 ##gh            
 4 ##s             
 5 like            
 6 this            
 7 ,               
 8 i               <- PRON
 9 don             
10 ##t             
11 want            
12 to              
13 be              
14 here            
15 anymore         
16 .               
17 [SEP]           

Mean incoming attention to pronoun (other tokens -> pronoun)
Rows = layers (1..L), columns = heads (0..H-1)

Layer  h00  h01  h02  h03  h04  h05  h06  h07  h08  h09  h10  h11
 1    0.0814  0.0282  0.0689  0.0524  0.0263  0.0342  0.0805  0.0526  0.0572  0.0233  0.0491  0.0618
 2    0.0362  0.0615  0.0687  0.0253  0.0556  0.0336  0.0034  0.0541  0.0275  0.0214  0.0443  0.0222
 3    0.0667  0.0252  0.0124  0.0216  0.0560  0.0614  0.0157  0.0481  0.0423  0.0667  0.0456  0.0143
 4    0.0009  0.0525  0.0201  0.0364  0.0615  0.0601  0.0473  0.011

In [5]:
# NOTE(review): the string ends with a closing curly quote (”) that has no opening
# counterpart; the printed token list shows it as its own token. Confirm it is intended.
sentence = "On nights like this, I don’t want to stay at this party anymore.”"
layer_head_means = analyze_sentence(sentence, model, tokenizer, DEVICE)


Sentence:
On nights like this, I don’t want to stay at this party anymore.”

Tokens (with indices):
 0 [CLS]           
 1 on              
 2 nights          
 3 like            
 4 this            
 5 ,               
 6 i               <- PRON
 7 don             
 8 ’               
 9 t               
10 want            
11 to              
12 stay            
13 at              
14 this            
15 party           
16 anymore         
17 .               
18 ”               
19 [SEP]           

Mean incoming attention to pronoun (other tokens -> pronoun)
Rows = layers (1..L), columns = heads (0..H-1)

Layer  h00  h01  h02  h03  h04  h05  h06  h07  h08  h09  h10  h11
 1    0.0776  0.0217  0.0478  0.0467  0.0297  0.0296  0.0598  0.0600  0.0562  0.0256  0.0429  0.0446
 2    0.0267  0.0383  0.0541  0.0294  0.0557  0.0313  0.0052  0.0305  0.0241  0.0188  0.0508  0.0220
 3    0.0588  0.0181  0.0140  0.0244  0.0477  0.0439  0.0165  0.0434  0.0259  0.0588  0.0458  0.0123
 4    0.0010 

In [6]:
# Counterpart of the earlier "I feel hopeless about the future." cell with
# "hopeful" substituted for "hopeless".
sentence = "I feel hopeful about the future."
layer_head_means = analyze_sentence(sentence, model, tokenizer, DEVICE)


Sentence:
I feel hopeful about the future.

Tokens (with indices):
 0 [CLS]           
 1 i               <- PRON
 2 feel            
 3 hopeful         
 4 about           
 5 the             
 6 future          
 7 .               
 8 [SEP]           

Mean incoming attention to pronoun (other tokens -> pronoun)
Rows = layers (1..L), columns = heads (0..H-1)

Layer  h00  h01  h02  h03  h04  h05  h06  h07  h08  h09  h10  h11
 1    0.1732  0.0667  0.1917  0.1730  0.0525  0.0728  0.1472  0.0937  0.0755  0.0409  0.0084  0.0463
 2    0.0590  0.0045  0.0447  0.0182  0.0426  0.0181  0.0007  0.0198  0.0378  0.0430  0.0995  0.0351
 3    0.0000  0.0335  0.0169  0.0501  0.0504  0.0817  0.0090  0.0976  0.0589  0.0000  0.0639  0.0271
 4    0.0009  0.0672  0.0169  0.0199  0.0960  0.1270  0.0347  0.0336  0.1079  0.0018  0.0321  0.0780
 5    0.0043  0.0989  0.0176  0.0002  0.0163  0.0085  0.0518  0.0084  0.0171  0.0250  0.1782  0.0555
 6    0.0145  0.0206  0.1012  0.0651  0.0470  0.0019  0.0376  0.

In [7]:
# NOTE(review): the string ends with a stray closing curly quote (”), which the
# tokenizer emits as an extra token; the plotting cell further down uses the same
# sentence without it. Confirm which form is intended.
sentence = "When everything feels pointless, I don’t see the point in staying anymore.”"
layer_head_means = analyze_sentence(sentence, model, tokenizer, DEVICE)


Sentence:
When everything feels pointless, I don’t see the point in staying anymore.”

Tokens (with indices):
 0 [CLS]           
 1 when            
 2 everything      
 3 feels           
 4 pointless       
 5 ,               
 6 i               <- PRON
 7 don             
 8 ’               
 9 t               
10 see             
11 the             
12 point           
13 in              
14 staying         
15 anymore         
16 .               
17 ”               
18 [SEP]           

Mean incoming attention to pronoun (other tokens -> pronoun)
Rows = layers (1..L), columns = heads (0..H-1)

Layer  h00  h01  h02  h03  h04  h05  h06  h07  h08  h09  h10  h11
 1    0.0785  0.0283  0.0443  0.0502  0.0272  0.0319  0.0592  0.0503  0.0483  0.0208  0.0513  0.0384
 2    0.0435  0.0321  0.0469  0.0182  0.0579  0.0215  0.0031  0.0279  0.0210  0.0192  0.0481  0.0270
 3    0.0625  0.0148  0.0132  0.0295  0.0448  0.0458  0.0196  0.0386  0.0347  0.0625  0.0610  0.0169
 4    0.0007  0.0502  0

In [8]:
# Same frame as the "When everything feels pointless, ..." sentence, with
# "everything" replaced by "the meeting".
sentence = "When the meeting feels pointless, I don’t see the point in staying anymore."
layer_head_means = analyze_sentence(sentence, model, tokenizer, DEVICE)


Sentence:
When the meeting feels pointless, I don’t see the point in staying anymore.

Tokens (with indices):
 0 [CLS]           
 1 when            
 2 the             
 3 meeting         
 4 feels           
 5 pointless       
 6 ,               
 7 i               <- PRON
 8 don             
 9 ’               
10 t               
11 see             
12 the             
13 point           
14 in              
15 staying         
16 anymore         
17 .               
18 [SEP]           

Mean incoming attention to pronoun (other tokens -> pronoun)
Rows = layers (1..L), columns = heads (0..H-1)

Layer  h00  h01  h02  h03  h04  h05  h06  h07  h08  h09  h10  h11
 1    0.0795  0.0306  0.0381  0.0442  0.0289  0.0274  0.0668  0.0500  0.0484  0.0188  0.0507  0.0415
 2    0.0680  0.0304  0.1035  0.0228  0.0565  0.0313  0.0030  0.0336  0.0263  0.0296  0.0484  0.0259
 3    0.0623  0.0148  0.0127  0.0341  0.0382  0.0478  0.0166  0.0434  0.0406  0.0624  0.0614  0.0226
 4    0.0008  0.0527  0

# BERTVIZ

In [9]:

# ===== 1. Load tokenizer + fine-tuned model with attentions =====
RUN_NAME = "bert_tuned_lr3e-5_ep3"
MODEL_DIR = Path(f"./runs/{RUN_NAME}/best_model")

# Name of the base checkpoint; no longer used for loading (see below), but the
# variable is still defined in case other cells read it.
BASE_MODEL_NAME = "bert-base-uncased"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the tokenizer saved with the fine-tuned run, the same source the first
# probing cell uses, so both analyses tokenize identically.
tokenizer = BertTokenizerFast.from_pretrained(str(MODEL_DIR))

# Start from the fine-tuned run's own config (keeps its label setup and any other
# saved settings) and only switch attention outputs on. Using the base model's
# config here would silently override what was saved with the checkpoint.
config = BertConfig.from_pretrained(str(MODEL_DIR))
config.output_attentions = True

# Load fine-tuned weights with that config
model = BertForSequenceClassification.from_pretrained(
    str(MODEL_DIR),
    config=config,
)
model.to(device)
model.eval()

print("Model loaded with attention:", model.config.output_attentions)

Model loaded with attention: True


In [10]:
from bertviz import head_view

def visualize_sentence(sentence):
    """
    Render the BertViz head view for one sentence.

    Uses the notebook-level `tokenizer`, `model` and `device`.

    Args:
        sentence: raw input string.

    Returns:
        Whatever `head_view` returns for the sentence's attentions and tokens.
    """
    # Call the tokenizer directly (encode_plus is deprecated in favour of __call__)
    # and truncate, matching how analyze_sentence encodes its input.
    inputs = tokenizer(sentence, return_tensors="pt", truncation=True).to(device)

    with torch.no_grad():
        outputs = model(**inputs, output_attentions=True, return_dict=True)

    attentions = outputs.attentions  # one attention tensor per layer
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])

    return head_view(attentions, tokens)

In [11]:
# BertViz head view for the "hopeful" example sentence.
visualize_sentence("I feel hopeful about the future.")

<IPython.core.display.Javascript object>

In [12]:
# BertViz head view for the misspelled sentence; unlike the earlier probing cell,
# this string has no final period.
visualize_sentence("On nighs like this, I dont want to be here anymore")

<IPython.core.display.Javascript object>

In [None]:
# plots
def get_layer_curve(sentence, model, tokenizer, device):
    """
    Per-layer summary of attention paid to the pronoun in `sentence`.

    Delegates to analyze_sentence (which also prints its report) and averages
    the resulting (num_layers, num_heads) matrix across heads.

    Returns:
        1D numpy array of shape (num_layers,).
    """
    per_head = analyze_sentence(sentence, model, tokenizer, device)
    # axis 1 is the head axis, so this leaves one value per layer
    return per_head.mean(axis=1)

import matplotlib.pyplot as plt
import numpy as np

# Sentence pair sharing one frame: "everything" vs "the meeting".
dep_sentence = "When everything feels pointless, I don’t see the point in staying anymore."
ctrl_sentence = "When the meeting feels pointless, I don’t see the point in staying anymore."

# One layer curve per sentence, computed in the order (depressive, control).
dep_curve, ctrl_curve = (
    get_layer_curve(text, model, tokenizer, DEVICE)
    for text in (dep_sentence, ctrl_sentence)
)

# 1-based layer numbers for the x axis.
layers = np.arange(len(dep_curve)) + 1

plt.figure(figsize=(6, 4))
plt.plot(layers, dep_curve, marker="o", label="Depressive sentence")
plt.plot(layers, ctrl_curve, marker="o", label="Neutral sentence")
plt.gca().set(
    xlabel="Layer",
    ylabel="Mean attention to pronoun",
    xticks=layers,
    title="Incoming attention to 'I' across layers",
)
plt.legend()
plt.tight_layout()
plt.show()

# Average across many sentences

In [None]:
def get_group_curve(sentences, model, tokenizer, device):
    """
    Mean and standard deviation of the per-layer pronoun-attention curves
    over a group of sentences.

    Args:
        sentences: iterable of raw sentence strings.
        model, tokenizer, device: forwarded to get_layer_curve.

    Returns:
        (mean_curve, std_curve), each a numpy array of shape (num_layers,).

    Raises:
        ValueError: if `sentences` is empty, or (from get_layer_curve) if a
            sentence contains no first-person pronoun.
    """
    curves = [get_layer_curve(s, model, tokenizer, device) for s in sentences]
    if not curves:
        # np.stack([]) would fail with an unhelpful message; say what is wrong.
        raise ValueError("No sentences were given, so there are no curves to aggregate.")
    curves = np.stack(curves, axis=0)  # (num_sentences, num_layers)
    mean_curve = curves.mean(axis=0)
    std_curve = curves.std(axis=0)
    return mean_curve, std_curve

# NOTE(review): DEPRESSIVE_SENTENCES and CONTROL_SENTENCES are first assigned in
# cells that come after this one, so running the notebook top to bottom on a fresh
# kernel raises NameError here. An identical copy of this cell follows those
# definitions.
dep_mean, dep_std = get_group_curve(DEPRESSIVE_SENTENCES, model, tokenizer, DEVICE)
ctrl_mean, ctrl_std = get_group_curve(CONTROL_SENTENCES, model, tokenizer, DEVICE)

# 1-based layer numbers for the x axis.
layers = np.arange(1, len(dep_mean) + 1)

# Mean curve per group with a shaded band of +/- one standard deviation.
plt.figure(figsize=(6,4))
plt.plot(layers, dep_mean, marker="o", label="Depressive (mean)")
plt.plot(layers, ctrl_mean, marker="o", label="Control (mean)")
plt.fill_between(layers, dep_mean - dep_std, dep_mean + dep_std, alpha=0.2)
plt.fill_between(layers, ctrl_mean - ctrl_std, ctrl_mean + ctrl_std, alpha=0.2)
plt.xlabel("Layer")
plt.ylabel("Mean attention to pronoun")
plt.xticks(layers)
plt.title("Group-level attention to first-person pronoun")
plt.legend()
plt.tight_layout()
plt.show()

In [None]:
# Ten first-person sentences passed to get_group_curve as the "depressive" group.
DEPRESSIVE_SENTENCES = [
    "Lately I feel like I can't handle anything anymore.",
    "Sometimes I wake up and I don't see the point in anything.",
    "On most days I feel completely drained and empty.",
    "When I think about the future, I just feel afraid.",
    "These days I feel like I'm only pretending to be okay.",
    "Even when I'm with people, I still feel alone.",
    "At night I lie awake and I wish I could disappear.",
    "Recently I feel like nothing I do really matters.",
    "More and more I feel stuck in the same dark place.",
    "When I look at myself, I just feel disappointed."
]

In [None]:
# Ten first-person sentences passed to get_group_curve as the "control" group.
CONTROL_SENTENCES = [
    "Lately I feel like I want to learn something new.",
    "Sometimes I wake up excited to start the day.",
    "On most days I feel pretty relaxed and content.",
    "When I think about the future, I feel curious and hopeful.",
    "These days I feel like I'm finally getting into a good routine.",
    "Even when I'm busy, I feel glad to see my friends.",
    "At night I lie in bed and think about fun things I want to do.",
    "Recently I feel like my work has been going really well.",
    "More and more I feel motivated to take better care of myself.",
    "When I look at myself, I feel proud of how far I've come."
]

In [None]:
def get_group_curve(sentences, model, tokenizer, device):
    """
    Mean and standard deviation of the per-layer pronoun-attention curves
    over a group of sentences.

    Args:
        sentences: iterable of raw sentence strings.
        model, tokenizer, device: forwarded to get_layer_curve.

    Returns:
        (mean_curve, std_curve), each a numpy array of shape (num_layers,).

    Raises:
        ValueError: if `sentences` is empty, or (from get_layer_curve) if a
            sentence contains no first-person pronoun.
    """
    curves = [get_layer_curve(s, model, tokenizer, device) for s in sentences]
    if not curves:
        # np.stack([]) would fail with an unhelpful message; say what is wrong.
        raise ValueError("No sentences were given, so there are no curves to aggregate.")
    curves = np.stack(curves, axis=0)  # (num_sentences, num_layers)
    mean_curve = curves.mean(axis=0)
    std_curve = curves.std(axis=0)
    return mean_curve, std_curve

# Aggregate each group in turn (depressive first, then control).
(dep_mean, dep_std), (ctrl_mean, ctrl_std) = (
    get_group_curve(group, model, tokenizer, DEVICE)
    for group in (DEPRESSIVE_SENTENCES, CONTROL_SENTENCES)
)

# 1-based layer numbers for the x axis.
layers = np.arange(len(dep_mean)) + 1

# Mean curve per group with a shaded band of +/- one standard deviation.
plt.figure(figsize=(6, 4))
plt.plot(layers, dep_mean, marker="o", label="Depressive (mean)")
plt.plot(layers, ctrl_mean, marker="o", label="Control (mean)")
plt.fill_between(layers, dep_mean - dep_std, dep_mean + dep_std, alpha=0.2)
plt.fill_between(layers, ctrl_mean - ctrl_std, ctrl_mean + ctrl_std, alpha=0.2)
plt.gca().set(
    xlabel="Layer",
    ylabel="Mean attention to pronoun",
    xticks=layers,
    title="Group-level attention to first-person pronoun",
)
plt.legend()
plt.tight_layout()
plt.show()

# Negative words

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ---------------------------------------------------------
# 1. Load model + tokenizer (adjust path if needed)
# ---------------------------------------------------------

# Directory of the fine-tuned checkpoint; tokenizer and model are both read from it.
MODEL_NAME_OR_PATH = "./runs/bert_tuned_lr3e-5_ep3/best_model"

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME_OR_PATH).to(device)
model.eval()  # evaluation mode for inference

# Architecture sizes as reported by the loaded model's config.
num_layers, num_heads = (
    model.config.num_hidden_layers,
    model.config.num_attention_heads,
)
print(f"Model loaded with {num_layers} layers and {num_heads} heads")

# ---------------------------------------------------------
# 2. Define depressive sentences + target token sets
# ---------------------------------------------------------

# Depressive group for this section (10 sentences). This rebinds
# DEPRESSIVE_SENTENCES, which an earlier cell assigns to a different list.
DEPRESSIVE_SENTENCES = [
    "Lately I feel completely hopeless about my life.",
    "Most days I wake up feeling empty and exhausted.",
    "When I think about the future, I just feel afraid.",
    "These days I feel worthless no matter what I do.",
    "Even when I'm with people, I still feel completely alone.",
    "More and more I feel stuck in the same dark place.",
    "At night I lie awake feeling useless and disappointed in myself.",
    "Recently I feel like everything is pointless and draining.",
    "Sometimes I feel numb and disconnected from everything around me.",
    "When I look at myself, I just feel deeply disappointed."
]


# First target set: lower-cased first-person pronoun strings.
FIRST_PERSON_PRONOUNS = {"i", "me", "my", "mine", "myself"}

# Second target set: the negative emotion words that occur in the sentences above.
NEGATIVE_EMOTION_WORDS = {
    "hopeless",
    "empty",
    "exhausted",
    "afraid",
    "worthless",
    "alone",
    "stuck",
    "dark",
    "useless",
    "disappointed",
    "pointless",
    "draining",
    "numb",
    "disconnected",
}

# ---------------------------------------------------------
# 3. Generic helpers for “target tokens”
# ---------------------------------------------------------

def find_target_indices(tokens, target_set):
    """
    Find indices of tokens that belong to a word in `target_set`.

    WordPiece tokens are first regrouped into words: a token together with
    the "##"-prefixed continuation pieces that follow it forms one word. The
    lower-cased word is then looked up in `target_set`, and the indices of ALL
    of its pieces are returned. This means

    - a target word that the tokenizer splits into several pieces is still
      found (e.g. "en", "##er", "##gized" for "energized"), and
    - a word-initial piece that merely looks like a target (e.g. "me" in
      "me", "##ow") is not matched.

    Single-token words behave exactly as a plain per-token lookup.

    Args:
        tokens: list of token strings.
        target_set: collection of lower-cased target words.

    Returns:
        Sorted list of int indices into `tokens`.

    Raises:
        ValueError: if no target word is found.
    """
    indices = []
    n_tokens = len(tokens)
    start = 0
    while start < n_tokens:
        # Extend the span over every continuation piece of this word.
        end = start + 1
        while end < n_tokens and tokens[end].startswith("##"):
            end += 1
        word = tokens[start] + "".join(piece[2:] for piece in tokens[start + 1:end])
        if word.lower() in target_set:
            indices.extend(range(start, end))
        start = end

    if not indices:
        raise ValueError(f"No target tokens from {target_set} found in tokens: {tokens}")
    return indices


def analyze_sentence_for_targets(sentence, target_set, model, tokenizer, device):
    """
    For a single sentence and target token set, compute mean incoming attention
    to the target token(s) for each layer and head.

    Senders are all tokens that are neither special tokens ([CLS], [SEP],
    [PAD]) nor targets themselves; the mean runs over every
    (sender, target) pair.

    Args:
        sentence: raw input string.
        target_set: collection of target strings, passed to find_target_indices.
        model: model that returns `attentions` when called with
            `output_attentions=True`.
        tokenizer: tokenizer matching `model`.
        device: torch device the encoded inputs are moved to.

    Returns:
        layer_head_means: np.ndarray (num_layers, num_heads), float32. All
        entries are NaN when the sentence has no sender tokens.

    Raises:
        ValueError: (from find_target_indices) if no target token is present.
    """
    enc = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=False,
    ).to(device)

    with torch.no_grad():
        outputs = model(**enc, output_attentions=True, return_dict=True)

    input_ids = enc["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    target_indices = find_target_indices(tokens, target_set)

    # Stack the per-layer tensors and drop the batch axis:
    # (num_layers, num_heads, seq_len, seq_len), indexed [layer, head, query, key].
    # Accumulate in float32 regardless of the model's own precision.
    attn_all = torch.stack(tuple(outputs.attentions))[:, 0].float()

    # Special tokens
    special_ids = {
        tok_id
        for tok_id in (
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
        )
        if tok_id is not None
    }

    # Sender mask: non-special, non-target tokens
    target_positions = set(target_indices)
    sender_mask = torch.tensor(
        [
            tok_id not in special_ids and i not in target_positions
            for i, tok_id in enumerate(input_ids.tolist())
        ],
        dtype=torch.bool,
        device=attn_all.device,
    )

    if not bool(sender_mask.any()):
        # Nothing sends attention: report NaN for every layer/head.
        return np.full(tuple(attn_all.shape[:2]), np.nan, dtype=np.float32)

    # Keep only the target columns (keys), then only the sender rows (queries),
    # and reduce over both in a single call for all layers and heads.
    target_idx = torch.tensor(target_indices, dtype=torch.long, device=attn_all.device)
    to_targets = attn_all.index_select(3, target_idx)  # (layers, heads, seq_len, num_targets)
    selected = to_targets[:, :, sender_mask, :]        # (layers, heads, num_senders, num_targets)
    layer_head_means = selected.mean(dim=(2, 3)).cpu().numpy().astype(np.float32)

    return layer_head_means


def get_layer_curve_for_targets(sentence, target_set, model, tokenizer, device):
    """
    Per-layer attention curve for `target_set` in one sentence.

    Calls analyze_sentence_for_targets and reduces its (layers, heads) matrix
    to shape (layers,) by taking the mean across heads.
    """
    per_head = analyze_sentence_for_targets(sentence, target_set, model, tokenizer, device)
    # axis 1 is the head axis
    return per_head.mean(axis=1)


# ---------------------------------------------------------
# 4. Group-level curves: pronouns vs negative emotion words
# ---------------------------------------------------------

def get_group_curve(sentences, target_set, model, tokenizer, device):
    """
    Compute mean and std layer-wise curves over a list of sentences.

    Sentences for which get_layer_curve_for_targets raises ValueError (no
    target token present) are reported and skipped.

    Args:
        sentences: iterable of raw sentence strings.
        target_set: collection of target strings.
        model, tokenizer, device: forwarded to get_layer_curve_for_targets.

    Returns:
        (mean_curve, std_curve), each a numpy array of shape (num_layers,).

    Raises:
        ValueError: if no sentence yields a curve (empty input, or every
            sentence was skipped).
    """
    curves = []
    for s in sentences:
        try:
            curve = get_layer_curve_for_targets(s, target_set, model, tokenizer, device)
            curves.append(curve)
        except ValueError as e:
            print(f"Skipping sentence (no target tokens found): {s}")
            print(e)

    if not curves:
        # np.stack([]) would fail with an unhelpful message; say what is wrong.
        raise ValueError(
            "No sentences produced a curve: the input was empty or no sentence "
            "contained a target token."
        )

    curves = np.stack(curves, axis=0)  # (num_sentences, num_layers)
    mean_curve = curves.mean(axis=0)
    std_curve = curves.std(axis=0)
    return mean_curve, std_curve


# Pronoun curve (same depressive sentences)
# Same depressive sentences, two target sets: pronouns first, then negative
# emotion words.
(dep_pron_mean, dep_pron_std), (dep_neg_mean, dep_neg_std) = (
    get_group_curve(DEPRESSIVE_SENTENCES, targets, model, tokenizer, device)
    for targets in (FIRST_PERSON_PRONOUNS, NEGATIVE_EMOTION_WORDS)
)

# ---------------------------------------------------------
# 5. Plot: pronouns vs negative emotion words (same sentences)
# ---------------------------------------------------------

# 1-based layer numbers for the x axis.
layers = np.arange(num_layers) + 1

# Mean curve per target set with a shaded band of +/- one standard deviation.
plt.figure(figsize=(6, 4))
plt.plot(layers, dep_pron_mean, marker="o", label="Pronouns (depressive)")
plt.plot(layers, dep_neg_mean, marker="o", label="Negative emotion words (depressive)")
plt.fill_between(layers, dep_pron_mean - dep_pron_std, dep_pron_mean + dep_pron_std, alpha=0.2)
plt.fill_between(layers, dep_neg_mean - dep_neg_std, dep_neg_mean + dep_neg_std, alpha=0.2)
plt.gca().set(
    xlabel="Layer",
    ylabel="Mean incoming attention (other tokens → target)",
    xticks=layers,
    title="Depressive sentences: pronouns vs negative emotion words",
)
plt.legend()
plt.tight_layout()
plt.show()

# Raw numbers behind the plot, plus the per-layer gap between the two curves.
print("Mean attention per layer (pronouns):", dep_pron_mean)
print("Mean attention per layer (neg emotions):", dep_neg_mean)
print("Difference (neg - pron):", dep_neg_mean - dep_pron_mean)

# Positive emotion words

In [None]:
#!/usr/bin/env python
"""
Control sentences: compare attention to
- first-person pronouns
- positive emotion words

Outputs a curve like:
  Pronouns (control) vs Positive emotion words (control)
"""

import torch
import numpy as np
import matplotlib.pyplot as plt
from transformers import AutoTokenizer, AutoModelForSequenceClassification

# ---------------------------------------------------------
# 1. Load model + tokenizer (adjust path if needed)
# ---------------------------------------------------------

# Directory of the fine-tuned checkpoint; tokenizer and model are both read from it.
MODEL_NAME_OR_PATH = "./runs/bert_tuned_lr3e-5_ep3/best_model"

# Run on the GPU when torch can see one, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME_OR_PATH).to(device)
model.eval()  # evaluation mode for inference

# Architecture sizes as reported by the loaded model's config.
num_layers, num_heads = (
    model.config.num_hidden_layers,
    model.config.num_attention_heads,
)
print(f"Model loaded with {num_layers} layers and {num_heads} heads")

# ---------------------------------------------------------
# 2. Control sentences + target token sets
# ---------------------------------------------------------

# Control group for this section (10 sentences). This rebinds CONTROL_SENTENCES,
# which an earlier cell assigns to a different list.
CONTROL_SENTENCES = [
    "Lately I feel genuinely hopeful about my life.",
    "Most days I wake up feeling rested and energized.",
    "When I think about the future, I just feel excited.",
    "These days I feel confident about what I do.",
    "Even when I'm on my own, I still feel completely okay.",
    "More and more I feel settled in a good place.",
    "At night I lie in bed feeling calm and grateful for my day.",
    "Recently I feel like everything is meaningful and rewarding.",
    "Sometimes I feel present and connected to everything around me.",
    "When I look at myself, I just feel genuinely proud.",
]

# First target set: lower-cased first-person pronoun strings.
FIRST_PERSON_PRONOUNS = {"i", "me", "my", "mine", "myself"}

# Second target set: the positive emotion words that occur in the sentences above.
POSITIVE_EMOTION_WORDS = {
    "hopeful",
    "rested",
    "energized",
    "excited",
    "confident",
    "okay",
    "settled",
    "calm",
    "grateful",
    "meaningful",
    "rewarding",
    "present",
    "connected",
    "proud",
}

# ---------------------------------------------------------
# 3. Generic helpers for target tokens
# ---------------------------------------------------------

def find_target_indices(tokens, target_set):
    """
    Find indices of tokens that belong to a word in `target_set`.

    WordPiece tokens are first regrouped into words: a token together with
    the "##"-prefixed continuation pieces that follow it forms one word. The
    lower-cased word is then looked up in `target_set`, and the indices of ALL
    of its pieces are returned. This means

    - a target word that the tokenizer splits into several pieces is still
      found (e.g. "en", "##er", "##gized" for "energized"), and
    - a word-initial piece that merely looks like a target (e.g. "me" in
      "me", "##ow") is not matched.

    Single-token words behave exactly as a plain per-token lookup.

    Args:
        tokens: list of token strings.
        target_set: collection of lower-cased target words.

    Returns:
        Sorted list of int indices into `tokens`.

    Raises:
        ValueError: if no target word is found.
    """
    indices = []
    n_tokens = len(tokens)
    start = 0
    while start < n_tokens:
        # Extend the span over every continuation piece of this word.
        end = start + 1
        while end < n_tokens and tokens[end].startswith("##"):
            end += 1
        word = tokens[start] + "".join(piece[2:] for piece in tokens[start + 1:end])
        if word.lower() in target_set:
            indices.extend(range(start, end))
        start = end

    if not indices:
        raise ValueError(f"No target tokens from {target_set} found in tokens: {tokens}")
    return indices


def analyze_sentence_for_targets(sentence, target_set, model, tokenizer, device):
    """
    For a single sentence and target token set, compute mean incoming attention
    to the target token(s) for each layer and head.

    Senders are all tokens that are neither special tokens ([CLS], [SEP],
    [PAD]) nor targets themselves; the mean runs over every
    (sender, target) pair.

    Args:
        sentence: raw input string.
        target_set: collection of target strings, passed to find_target_indices.
        model: model that returns `attentions` when called with
            `output_attentions=True`.
        tokenizer: tokenizer matching `model`.
        device: torch device the encoded inputs are moved to.

    Returns:
        layer_head_means: np.ndarray (num_layers, num_heads), float32. All
        entries are NaN when the sentence has no sender tokens.

    Raises:
        ValueError: (from find_target_indices) if no target token is present.
    """
    enc = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=False,
    ).to(device)

    with torch.no_grad():
        outputs = model(**enc, output_attentions=True, return_dict=True)

    input_ids = enc["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(input_ids)

    target_indices = find_target_indices(tokens, target_set)

    # Stack the per-layer tensors and drop the batch axis:
    # (num_layers, num_heads, seq_len, seq_len), indexed [layer, head, query, key].
    # Accumulate in float32 regardless of the model's own precision.
    attn_all = torch.stack(tuple(outputs.attentions))[:, 0].float()

    # Special tokens
    special_ids = {
        tok_id
        for tok_id in (
            tokenizer.cls_token_id,
            tokenizer.sep_token_id,
            tokenizer.pad_token_id,
        )
        if tok_id is not None
    }

    # Sender mask: non-special, non-target tokens
    target_positions = set(target_indices)
    sender_mask = torch.tensor(
        [
            tok_id not in special_ids and i not in target_positions
            for i, tok_id in enumerate(input_ids.tolist())
        ],
        dtype=torch.bool,
        device=attn_all.device,
    )

    if not bool(sender_mask.any()):
        # Nothing sends attention: report NaN for every layer/head.
        return np.full(tuple(attn_all.shape[:2]), np.nan, dtype=np.float32)

    # Keep only the target columns (keys), then only the sender rows (queries),
    # and reduce over both in a single call for all layers and heads.
    target_idx = torch.tensor(target_indices, dtype=torch.long, device=attn_all.device)
    to_targets = attn_all.index_select(3, target_idx)  # (layers, heads, seq_len, num_targets)
    selected = to_targets[:, :, sender_mask, :]        # (layers, heads, num_senders, num_targets)
    layer_head_means = selected.mean(dim=(2, 3)).cpu().numpy().astype(np.float32)

    return layer_head_means


def get_layer_curve_for_targets(sentence, target_set, model, tokenizer, device):
    """
    Per-layer attention curve for `target_set` in one sentence.

    Calls analyze_sentence_for_targets and reduces its (layers, heads) matrix
    to shape (layers,) by taking the mean across heads.
    """
    per_head = analyze_sentence_for_targets(sentence, target_set, model, tokenizer, device)
    # axis 1 is the head axis
    return per_head.mean(axis=1)


def get_group_curve(sentences, target_set, model, tokenizer, device):
    """
    Compute mean and std layer-wise curves over a list of sentences.

    Each sentence is reduced to a per-layer curve with
    `get_layer_curve_for_targets`. A sentence is skipped, with a printed
    notice, when:
      - that call raises ValueError, or
      - its curve contains NaN (`analyze_sentence_for_targets` stores NaN when
        there are no sender/target pairs to average); a single NaN curve would
        otherwise turn the whole group mean and std into NaN.

    Returns:
      (mean_curve, std_curve): two arrays of shape (num_layers,), taken across
      the sentences that were kept. The std is NumPy's default (ddof=0).

    Raises:
      ValueError: if `sentences` is empty or every sentence was skipped.
    """
    curves = []
    for s in sentences:
        try:
            curve = get_layer_curve_for_targets(s, target_set, model, tokenizer, device)
        except ValueError as e:
            print(f"Skipping sentence (no target tokens found): {s}")
            print(e)
            continue

        if np.isnan(curve).any():
            print(f"Skipping sentence (attention curve contains NaN): {s}")
            continue

        curves.append(curve)

    # np.stack([]) fails with an opaque message; raise the same exception type
    # with a description of what actually went wrong.
    if not curves:
        raise ValueError(
            f"No usable sentences for target set {target_set}: "
            "the sentence list was empty or every sentence was skipped"
        )

    curves = np.stack(curves, axis=0)  # (num_sentences, num_layers)
    mean_curve = curves.mean(axis=0)
    std_curve = curves.std(axis=0)
    return mean_curve, std_curve

# ---------------------------------------------------------
# 4. Compute curves: pronouns vs positive emotion words
# ---------------------------------------------------------

# NOTE(review): this cell passes lowercase `device`, while the [CLS] cells pass
# `DEVICE`; confirm both names refer to the device the model lives on.
ctrl_pron_mean, ctrl_pron_std = get_group_curve(
    CONTROL_SENTENCES, FIRST_PERSON_PRONOUNS, model, tokenizer, device
)

ctrl_pos_mean, ctrl_pos_std = get_group_curve(
    CONTROL_SENTENCES, POSITIVE_EMOTION_WORDS, model, tokenizer, device
)

# ---------------------------------------------------------
# 5. Plot
# ---------------------------------------------------------

# Take the layer count from the curve being plotted (as the [CLS] cells do)
# rather than from the global `num_layers`, so the x-axis always matches the
# data even if that global is stale or belongs to a different model.
layers = np.arange(1, len(ctrl_pron_mean) + 1)

plt.figure(figsize=(6, 4))
plt.plot(layers, ctrl_pron_mean, marker="o", label="Pronouns (control)")
plt.plot(layers, ctrl_pos_mean, marker="o", label="Positive emotion words (control)")
# Shaded bands show +/- one standard deviation across sentences
plt.fill_between(
    layers,
    ctrl_pron_mean - ctrl_pron_std,
    ctrl_pron_mean + ctrl_pron_std,
    alpha=0.2,
)
plt.fill_between(
    layers,
    ctrl_pos_mean - ctrl_pos_std,
    ctrl_pos_mean + ctrl_pos_std,
    alpha=0.2,
)
plt.xlabel("Layer")
plt.ylabel("Mean incoming attention (other tokens → target)")
plt.xticks(layers)
plt.title("Control sentences: pronouns vs positive emotion words")
plt.legend()
plt.tight_layout()
plt.show()

print("Mean attention per layer (pronouns, control):", ctrl_pron_mean)
print("Mean attention per layer (positive, control):", ctrl_pos_mean)
print("Difference (positive - pronouns):", ctrl_pos_mean - ctrl_pron_mean)

# Attention from [CLS] to target tokens

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch

def find_target_indices(tokens, target_set):
    """
    Return the positions of every token whose lowercased form is in target_set.

    The result is a list of integer indices in ascending order; it is empty
    when no token matches.
    """
    matches = []
    for position, token in enumerate(tokens):
        if token.lower() in target_set:
            matches.append(position)
    return matches


def analyze_cls_for_targets(sentence, target_set, model, tokenizer, device):
    """
    Measure how strongly [CLS] attends to the target tokens of ONE sentence.

    Steps:
      - tokenize the sentence (truncation on, no padding) and move it to `device`
      - run the model without gradients, requesting attention weights
      - locate the first [CLS] position and all target-token positions
      - for every layer and head, average the attention weights in the
        CLS row over the target columns

    Returns:
      layer_head_means: np.ndarray of shape (num_layers, num_heads), float32

    Raises:
      ValueError: if the input contains no [CLS] token, or if none of the
        tokens is in `target_set`.
    """
    encoded = tokenizer(
        sentence,
        return_tensors="pt",
        truncation=True,
        padding=False,
    ).to(device)

    with torch.no_grad():
        model_out = model(**encoded, output_attentions=True, return_dict=True)

    per_layer_attn = model_out.attentions  # one attention tensor per layer
    ids = encoded["input_ids"][0]
    tokens = tokenizer.convert_ids_to_tokens(ids)

    # Locate [CLS] by id instead of assuming it sits at position 0
    cls_hits = (ids == tokenizer.cls_token_id).nonzero(as_tuple=True)[0]
    if len(cls_hits) == 0:
        raise ValueError("No [CLS] token found in input_ids")
    cls_idx = cls_hits[0].item()

    target_indices = find_target_indices(tokens, target_set)
    if not target_indices:
        raise ValueError(f"No target tokens from {target_set} in tokens: {tokens}")

    head_count = per_layer_attn[0].shape[1]
    layer_head_means = np.zeros((len(per_layer_attn), head_count), dtype=np.float32)

    for layer_idx, layer_attn in enumerate(per_layer_attn):
        # layer_attn is expected to be (1, num_heads, seq_len, seq_len)
        for head_idx in range(head_count):
            # CLS row restricted to the target columns -> (num_targets,)
            cls_to_targets = layer_attn[0, head_idx, cls_idx, target_indices]
            layer_head_means[layer_idx, head_idx] = cls_to_targets.mean().item()

    return layer_head_means


def get_cls_layer_curve(sentence, target_set, model, tokenizer, device):
    """
    Reduce one sentence's CLS -> target attention to a single value per layer.

    Calls `analyze_cls_for_targets` with the given arguments and averages the
    resulting (num_layers, num_heads) matrix across heads.

    Returns:
      1-D array of shape (num_layers,).
    """
    per_head_scores = analyze_cls_for_targets(
        sentence, target_set, model, tokenizer, device
    )
    # axis 1 is the head axis
    per_layer_scores = per_head_scores.mean(axis=1)
    return per_layer_scores


def get_group_cls_curve(sentences, target_set, model, tokenizer, device):
    """
    Mean & std CLS→target attention across sentences.

    Each sentence is reduced to a per-layer curve with `get_cls_layer_curve`.
    Sentences for which that call raises ValueError are reported with a
    printed notice and left out of the statistics.

    Returns:
      (mean_curve, std_curve): two arrays of shape (num_layers,), taken across
      the sentences that were kept. The std is NumPy's default (ddof=0).

    Raises:
      ValueError: if `sentences` is empty or every sentence was skipped.
    """
    curves = []
    for s in sentences:
        try:
            curve = get_cls_layer_curve(s, target_set, model, tokenizer, device)
            curves.append(curve)
        except ValueError as e:
            print(f"Skipping sentence (no target tokens): {s}")
            print(e)

    # np.stack([]) fails with an opaque message; raise the same exception type
    # with a description of what actually went wrong.
    if not curves:
        raise ValueError(
            f"No usable sentences for target set {target_set}: "
            "the sentence list was empty or every sentence was skipped"
        )

    curves = np.stack(curves, axis=0)  # (num_sentences, num_layers)
    return curves.mean(axis=0), curves.std(axis=0)

In [None]:
# Depressive sentences: layer-wise [CLS] attention to first-person pronouns
# versus negative emotion words.
dep_cls_pron_mean, dep_cls_pron_std = get_group_cls_curve(
    DEPRESSIVE_SENTENCES, FIRST_PERSON_PRONOUNS, model, tokenizer, DEVICE
)
dep_cls_neg_mean, dep_cls_neg_std = get_group_cls_curve(
    DEPRESSIVE_SENTENCES, NEGATIVE_EMOTION_WORDS, model, tokenizer, DEVICE
)

# 1-based layer numbers for the x-axis, sized from the curve itself
layers = np.arange(1, len(dep_cls_pron_mean) + 1)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(layers, dep_cls_pron_mean, marker="o", label="[CLS] → pronouns (dep)")
ax.plot(layers, dep_cls_neg_mean, marker="o", label="[CLS] → neg emotion (dep)")
# Shaded bands: mean +/- one standard deviation across sentences
ax.fill_between(
    layers,
    dep_cls_pron_mean - dep_cls_pron_std,
    dep_cls_pron_mean + dep_cls_pron_std,
    alpha=0.2,
)
ax.fill_between(
    layers,
    dep_cls_neg_mean - dep_cls_neg_std,
    dep_cls_neg_mean + dep_cls_neg_std,
    alpha=0.2,
)
ax.set_xlabel("Layer")
ax.set_ylabel("Attention from [CLS] to target")
ax.set_xticks(layers)
ax.set_title("Depressive: [CLS] attention to pronouns vs negative words")
ax.legend()
fig.tight_layout()
plt.show()

In [None]:
# Control sentences: layer-wise [CLS] attention to first-person pronouns
# versus positive emotion words.
ctrl_cls_pron_mean, ctrl_cls_pron_std = get_group_cls_curve(
    CONTROL_SENTENCES, FIRST_PERSON_PRONOUNS, model, tokenizer, DEVICE
)
ctrl_cls_pos_mean, ctrl_cls_pos_std = get_group_cls_curve(
    CONTROL_SENTENCES, POSITIVE_EMOTION_WORDS, model, tokenizer, DEVICE
)

# 1-based layer numbers for the x-axis, sized from the curve itself
layers = np.arange(1, len(ctrl_cls_pron_mean) + 1)

fig, ax = plt.subplots(figsize=(6, 4))
ax.plot(layers, ctrl_cls_pron_mean, marker="o", label="[CLS] → pronouns (ctrl)")
ax.plot(layers, ctrl_cls_pos_mean, marker="o", label="[CLS] → pos emotion (ctrl)")
# Shaded bands: mean +/- one standard deviation across sentences
ax.fill_between(
    layers,
    ctrl_cls_pron_mean - ctrl_cls_pron_std,
    ctrl_cls_pron_mean + ctrl_cls_pron_std,
    alpha=0.2,
)
ax.fill_between(
    layers,
    ctrl_cls_pos_mean - ctrl_cls_pos_std,
    ctrl_cls_pos_mean + ctrl_cls_pos_std,
    alpha=0.2,
)
ax.set_xlabel("Layer")
ax.set_ylabel("Attention from [CLS] to target")
ax.set_xticks(layers)
ax.set_title("Control: [CLS] attention to pronouns vs positive words")
ax.legend()
fig.tight_layout()
plt.show()