In [1]:
pip install conllu

Collecting conllu
  Downloading conllu-6.0.0-py3-none-any.whl.metadata (21 kB)
Downloading conllu-6.0.0-py3-none-any.whl (16 kB)
Installing collected packages: conllu
Successfully installed conllu-6.0.0


In [2]:
from collections import defaultdict
import conllu

def extract_minimal_pairs(conllu_file):
    """
    Creating the minimal pairs dataset where the articles are varied according to the same case, but with incorrect gender
    """

    GERMAN_CONTRACTIONS = {
    'am': ['an', 'dem'],
    'ans': ['an', 'das'],
    'aufs': ['auf', 'das'],
    'beim': ['bei', 'dem'],
    'durchs': ['durch', 'das'],
    'fürs': ['für', 'das'],
    'im': ['in', 'dem'],
    'ins': ['in', 'das'],
    'unterm': ['unter', 'dem'],
    'unters': ['unter', 'das'],
    'vom': ['von', 'dem'],
    'vorm': ['vor', 'dem'],
    'zum': ['zu', 'dem'],
    'zur': ['zu', 'der'],
    'überm': ['über', 'dem'],
    'übers': ['über', 'das'],
    'hinterm': ['hinter', 'dem'],
    'hinters': ['hinter', 'das'],
    }

    # Extracting gender/case minimal pairs from CoNLL-U file, while skipping contracted articles.
    with open(conllu_file, "r", encoding="utf-8") as f:
        data = f.read()

    sentences = conllu.parse(data)
    pairs = defaultdict(list)

    for sentence in sentences:
        tokens = [token["form"] for token in sentence] # no expansion of contractions done

        # Skipping the entire sentence if it contains ANY contractions
        has_contractions = any(token.lower() in GERMAN_CONTRACTIONS for token in tokens)
        if has_contractions:
            continue

        # Only processing sentences with no contractions at all
        for i, token in enumerate(sentence):
            feats = token.get("feats") or {}

            if (token["upos"] in ["NOUN", "PRON"] and
                "Case" in feats and
                "Gender" in feats and
                feats.get("Number") == "Sing"):  # working with only singular nouns

                # Defining the case mappings
                case_mappings = {
                    # Masculine to Feminine
                    ("Masc", "Nom", "der"): "Die",
                    ("Masc", "Acc", "den"): "die",
                    ("Masc", "Dat", "dem"): "der",

                    # Feminine to Masculine
                    ("Fem", "Nom", "die"): "Der",
                    ("Fem", "Acc", "die"): "den",
                    ("Fem", "Dat", "der"): "dem",

                    # Masculine to Neuter
                    ("Masc", "Nom", "der"): "Das",
                    ("Masc", "Acc", "den"): "das",
                    ("Masc", "Dat", "dem"): "dem",

                    # Neuter to Masculine
                    ("Neut", "Nom", "das"): "Der",
                    ("Neut", "Acc", "das"): "den",
                    ("Neut", "Dat", "dem"): "dem",

                    # Feminine to Neuter
                    ("Fem", "Nom", "die"): "Das",
                    ("Fem", "Acc", "die"): "das",
                    ("Fem", "Dat", "der"): "dem",

                    # Neuter to Feminine
                    ("Neut", "Nom", "das"): "Die",
                    ("Neut", "Acc", "das"): "die",
                    ("Neut", "Dat", "dem"): "der"
                }

                gender = feats["Gender"]
                case = feats["Case"]

                if gender in ["Masc", "Fem", "Neut"] and case in ["Nom", "Acc", "Dat"]:
                    idx = i - 1  # Looking for the determiner before the noun
                    if idx >= 0:
                        det = sentence[idx]
                        det_feats = det.get("feats") or {}
                        if det["upos"] == "DET" and det_feats.get("Case") == case:
                            det_form = det["form"]  # Keeping original capitalization for sentence starts
                            det_form_lower = det_form.lower()

                            # Only processing standard uncontracted articles
                            if det_form_lower not in ['der', 'die', 'das', 'dem', 'den']:
                                continue  # else, skipping

                            # Checks for the mapping for this combination
                            mapping_key = (gender, case, det_form_lower)
                            if mapping_key in case_mappings:
                                replacement = tokens.copy()
                                new_article = case_mappings[mapping_key]

                                if det_form[0].isupper():
                                    new_article = new_article.capitalize()
                                else:
                                    new_article = new_article.lower()

                                replacement[idx] = new_article

                                original_sent = " ".join(tokens)
                                modified_sent = " ".join(replacement)

                                if original_sent != modified_sent:
                                    # Stored as tuple: (original_sentence, modified_sentence, original_gender)
                                    pairs["case"].append((original_sent, modified_sent, gender))

    return pairs

In [4]:
pairs = extract_minimal_pairs("de_gsd-ud-train.conllu")

In [8]:
def patch_attention_head_activations(model, tokenizer, clean_sent, corrupted_sent,
                                         target_layer, target_head, article_positions):
    """
    Attention head patching for all heads of mBERT
    """

    if isinstance(article_positions, int):
        article_positions = [article_positions]
    elif not isinstance(article_positions, list):
        raise ValueError(f"article_positions must be int or list, got {type(article_positions)}")

    # Tokenizing sentences
    clean_tokens = tokenizer(clean_sent, return_tensors="pt")
    corrupted_tokens = tokenizer(corrupted_sent, return_tensors="pt")

    def get_hidden_states(outputs):
        if hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
            return outputs.hidden_states[-1]
        else:
            raise AttributeError(f"Cannot find hidden states in {type(outputs)}")

    # Get baseline hidden states
    with torch.no_grad():
        orig_outputs = model(**clean_tokens, output_attentions=True)
        baseline_hidden = get_hidden_states(orig_outputs)[0, article_positions[0]]  # Now works!


    # Get intervention attention pattern
    with torch.no_grad():
        corrupted_outputs = model(**corrupted_tokens, output_attentions=True)

        # Check if sequences are compatible
        clean_seq_len = orig_outputs.attentions[target_layer].shape[-1]
        corrupted_seq_len = corrupted_outputs.attentions[target_layer].shape[-1]

        if abs(clean_seq_len - corrupted_seq_len) > 3:
            print(f"Sequence length mismatch too large: {clean_seq_len} vs {corrupted_seq_len}")
            return 0.0

        intervention_attention = corrupted_outputs.attentions[target_layer][0, target_head]

    # Safe intervention: modify model temporarily
    hook_handle = None
    try:
        # Access the attention module for mBERT
        if hasattr(model, 'encoder'):
            attention_layer = model.encoder.layer[target_layer].attention
        elif hasattr(model, 'bert'):
            attention_layer = model.bert.encoder.layer[target_layer].attention
        else:
            raise AttributeError("Cannot find encoder layers")

        # Store original attention dropout for restoration
        original_dropout = attention_layer.self.dropout

        # Create custom dropout that replaces attention for our specific head
        class InterventionDropout(torch.nn.Module):
            def __init__(self, original_dropout, intervention_attn, target_head, min_seq_len):
                super().__init__()
                self.original_dropout = original_dropout
                self.intervention_attn = intervention_attn
                self.target_head = target_head
                self.training = original_dropout.training
                self.min_seq_len = min_seq_len

            def forward(self, attention_probs):
                # Replace the target head's attention safely
                if attention_probs.dim() == 4:  # [batch, heads, seq, seq]
                    modified_probs = attention_probs.clone()

                    # Only patch if dimensions are compatible
                    current_seq_len = attention_probs.shape[-1]
                    patch_len = min(current_seq_len, self.min_seq_len)

                    if patch_len > 0:
                        modified_probs[0, self.target_head, :patch_len, :patch_len] = \
                            self.intervention_attn[:patch_len, :patch_len]

                    return self.original_dropout(modified_probs)
                else:
                    return self.original_dropout(attention_probs)

        # Calculate safe sequence length
        min_seq_len = min(clean_seq_len, corrupted_seq_len)

        # Replace dropout temporarily
        attention_layer.self.dropout = InterventionDropout(
            original_dropout, intervention_attention, target_head, min_seq_len
        )

        # Forward pass with intervention
        intervened_outputs = model(**clean_tokens, output_attentions=True)
        intervened_hidden = get_hidden_states(intervened_outputs)[0, article_positions[0]]

        # Calculate effect
        effect = torch.norm(intervened_hidden - baseline_hidden).item()

        # Restore original dropout
        attention_layer.self.dropout = original_dropout

    except Exception as e:
        print(f"Error with head {target_layer}.{target_head}: {e}")
        return 0.0

    return effect

In [9]:
import torch
import numpy as np
from transformers import AutoModel, AutoTokenizer

def get_attention_head_effect(model, tokenizer, orig_sent, pert_sent):
    """
    Get the effect of each attention head using safe attention patching
    Optimized for mBERT
    """

    # Tokenize sentences
    orig_tokens = tokenizer(orig_sent, return_tensors="pt")
    pert_tokens = tokenizer(pert_sent, return_tensors="pt")

    # Checking if lengths are compatible
    orig_len = orig_tokens.input_ids.shape[1]
    pert_len = pert_tokens.input_ids.shape[1]

    if abs(orig_len - pert_len) > 3:
        print(f"Skipping pair due to length mismatch: {orig_len} vs {pert_len}")
        return {}

    # Getting article position
    orig_ids = orig_tokens['input_ids'][0]
    pert_ids = pert_tokens['input_ids'][0]

    diff_positions = []
    min_len = min(len(orig_ids), len(pert_ids))

    for i in range(min_len):
        if orig_ids[i] != pert_ids[i]:
            diff_positions.append(i)

    if not diff_positions:
        print("No differences found between sentences")
        return {}

    article_pos = diff_positions[0]

    if article_pos >= min(orig_len, pert_len):
        print(f"Article position {article_pos} out of bounds")
        return {}

    # Articles obtained
    correct_article = tokenizer.decode(orig_ids[article_pos]).strip()
    incorrect_article = tokenizer.decode(pert_ids[article_pos]).strip()


    # Testing each attention head (mBERT has 12 layers, 12 heads each)
    head_effects = {}

    for layer in range(model.config.num_hidden_layers):  # Should be 12 for mBERT
        for head in range(model.config.num_attention_heads):  # Should be 12 for mBERT
            try:
                effect = patch_attention_head_activations(
                    model, tokenizer, orig_sent, pert_sent, layer, head,
                    [article_pos]
                )
                head_effects[(layer, head)] = effect

            except Exception as e:
                print(f"Error with head {layer}.{head}: {e}")
                head_effects[(layer, head)] = 0.0

    return head_effects


In [10]:
def find_token_position(tokenizer, sentence, target_word):
    tokens = tokenizer.tokenize(sentence)
    for i, token in enumerate(tokens):
        if token.replace("##", "") == target_word:
            return i
    raise ValueError(f"Target word '{target_word}' not found in sentence: {sentence}")

def get_article_probabilities(logits, correct_article, incorrect_article):
    """To extract probabilities for correct/incorrect articles"""
    probs = torch.softmax(logits, dim=-1)

    correct_id = tokenizer.convert_tokens_to_ids(correct_article)
    incorrect_id = tokenizer.convert_tokens_to_ids(incorrect_article)

    p_correct = probs[correct_id]
    p_incorrect = probs[incorrect_id]

    return {
        'correct': p_correct.item(),
        'incorrect': p_incorrect.item(),
        'ratio': p_incorrect / p_correct
    }

In [11]:
def store_attention_head_results(model, tokenizer, sentence_pairs, max_sentences):
    """
    Storing results of attention analysis

    Args:
        model: BERT model with output_attentions=True
        tokenizer: Corresponding tokenizer
        sentence_pairs: List of (original_sentence, perturbed_sentence) tuples

    Returns:
        List of dictionaries with structured results
    """

    all_results = []

    for i, (orig_sent, pert_sent, gender) in enumerate(sentence_pairs[:max_sentences]):
        #print(f"\n Analyzing: {orig_sent[:50]}...")

        try:
            # Getting attention head effects
            head_effects = get_attention_head_effect(model, tokenizer, orig_sent, pert_sent)

            # Finding article position and the perturbed difference
            orig_tokens = tokenizer(orig_sent, return_tensors="pt")
            pert_tokens = tokenizer(pert_sent, return_tensors="pt")

            orig_ids = orig_tokens['input_ids'][0]
            pert_ids = pert_tokens['input_ids'][0]

            # Find the position where tokens differ
            article_pos = None
            correct_article = None
            incorrect_article = None

            for pos in range(min(len(orig_ids), len(pert_ids))):
                if orig_ids[pos] != pert_ids[pos]:
                    article_pos = pos
                    correct_article = tokenizer.decode(orig_ids[pos]).strip()
                    incorrect_article = tokenizer.decode(pert_ids[pos]).strip()
                    break

            if article_pos is None:
                print(f"  Warning: No difference found between sentences, skipping...")
                continue

            # Store structured result
            result = {
                'sentence': orig_sent,
                'article_pos': article_pos,
                'correct_article': correct_article,
                'incorrect_article': incorrect_article,
                'head_effects': head_effects,  # Full dictionary of all heads
            }

            all_results.append(result)

            sorted_heads = sorted(head_effects.items(), key=lambda x: x[1], reverse=True)


        except Exception as e:
            print(f" Error processing sentence {i+1}: {e}")
            continue

    print(f"\n Successfully processed {len(all_results)} sentences")
    return all_results


In [12]:
model = AutoModel.from_pretrained('bert-base-multilingual-cased', output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/625 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/714M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/996k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.96M [00:00<?, ?B/s]

In [13]:
# Run analysis and store results
results = store_attention_head_results(model, tokenizer, pairs["case"], max_sentences=200)

  return forward_call(*args, **kwargs)


KeyboardInterrupt: 

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from collections import defaultdict, Counter

def analyze_head_consistency(all_results):
    """
    Analyzes how consistently attention heads contribute across different sentences

    all_results contains:
    {
        'sentence': str,
        'article_pos': int,
        'correct_article': str,
        'incorrect_article': str,
        'head_effects': dict
    }

    """

    # Consistency Analysis
    head_appearances = defaultdict(list)
    head_sentence_count = defaultdict(int)

    sentence_data = []

    for result in all_results:
        sentence = result['sentence']
        article_pos = result['article_pos']
        head_effects = result['head_effects']

        # Get top 10 heads for this sentence
        top_10_heads = sorted(head_effects.items(), key=lambda x: x[1], reverse=True)[:10]

        for (layer, head), effect in head_effects.items():
            head_appearances[(layer, head)].append(effect)

            # Count if this head is in top 10 for this sentence
            if (layer, head) in [h[0] for h in top_10_heads]:
                head_sentence_count[(layer, head)] += 1

        # Storing sentence metadata
        sentence_data.append({
            'sentence': sentence[:50] + "...",  # Truncate for display
            'article_pos': article_pos,
            'correct_article': result['correct_article'],
            'incorrect_article': result['incorrect_article'],
            'top_head': top_10_heads[0][0],  # (layer, head) of most important
            'top_effect': top_10_heads[0][1],  # effect of most important
        })

    # Finding most consistent heads
    total_sentences = len(all_results)
    consistency_scores = {}

    for head, appearances in head_appearances.items():
        if len(appearances) >= 3:  # Only consider heads that appear in multiple sentences
            consistency_scores[head] = {
                'mean_effect': np.mean(appearances),
                'std_effect': np.std(appearances),
                'appearance_rate': head_sentence_count[head] / total_sentences,
                'total_appearances': head_sentence_count[head],
                'coefficient_of_variation': np.std(appearances) / np.mean(appearances) if np.mean(appearances) > 0 else float('inf')
            }

    print("=" * 80)
    print("ATTENTION HEAD CONSISTENCY ANALYSIS")
    print("=" * 80)

    print(f"\nAnalyzed {total_sentences} sentences with gender agreement")

    # Most consistent heads (appear frequently with stable effects)
    print("\n MOST CONSISTENT HEADS (appear in many sentences):")
    print("-" * 60)
    consistent_heads = sorted(
        [(head, stats) for head, stats in consistency_scores.items()],
        key=lambda x: x[1]['appearance_rate'],
        reverse=True
    )

    for i, ((layer, head), stats) in enumerate(consistent_heads[:15]):
        print(f"{i+1:2d}. Layer {layer:2d}, Head {head:2d}: "
              f"appears {stats['total_appearances']:2d}/{total_sentences} sentences "
              f"({stats['appearance_rate']:.1%}) | "
              f"avg effect: {stats['mean_effect']:.3f} ± {stats['std_effect']:.3f}")

    # Most variable heads (context-dependent)
    print("\nMOST CONTEXT-DEPENDENT HEADS (high variability):")
    print("-" * 60)
    variable_heads = sorted(
        [(head, stats) for head, stats in consistency_scores.items()
         if stats['total_appearances'] >= 3],
        key=lambda x: x[1]['coefficient_of_variation'],
        reverse=True
    )

    for i, ((layer, head), stats) in enumerate(variable_heads[:10]):
        print(f"{i+1:2d}. Layer {layer:2d}, Head {head:2d}: "
              f"CV: {stats['coefficient_of_variation']:.2f} | "
              f"range: {min(head_appearances[(layer, head)]):.3f} - "
              f"{max(head_appearances[(layer, head)]):.3f}")

    # Sentence-specific analysis
    print("\n SENTENCE-SPECIFIC PATTERNS:")
    print("-" * 60)
    sentence_df = pd.DataFrame(sentence_data)

    # Group by article position
    pos_groups = sentence_df.groupby('article_pos')['top_head'].apply(list)
    print("\nTop heads by article position:")
    for pos in sorted(pos_groups.index):
        heads = pos_groups[pos]
        head_counts = Counter(heads)
        print(f"  Position {pos:2d}: {dict(head_counts.most_common(3))}")

    # Group by article type
    article_groups = sentence_df.groupby(['correct_article', 'incorrect_article'])['top_head'].apply(list)
    print(f"\nTop heads by article transition:")
    for (correct, incorrect), heads in article_groups.items():
        head_counts = Counter(heads)
        print(f"  '{correct}' → '{incorrect}': {dict(head_counts.most_common(2))}")

    return {
        'consistency_scores': consistency_scores,
        'sentence_data': sentence_df,
        'head_appearances': dict(head_appearances)
    }



def compare_sentence_types(all_results):
    """Compare head usage across different types of sentences"""

    print("\nDETAILED SENTENCE TYPE ANALYSIS:")
    print("=" * 60)

    # Group by article position ranges
    early_pos = [r for r in all_results if r['article_pos'] <= 5]
    mid_pos = [r for r in all_results if 6 <= r['article_pos'] <= 15]
    late_pos = [r for r in all_results if r['article_pos'] > 15]

    position_groups = {
        'Early (pos 1-5)': early_pos,
        'Middle (pos 6-15)': mid_pos,
        'Late (pos 16+)': late_pos
    }

    for group_name, group_results in position_groups.items():
        if not group_results:
            continue

        print(f"\n{group_name}: {len(group_results)} sentences")

        # Get top heads for this group
        all_head_effects = defaultdict(list)
        for result in group_results:
            for head, effect in result['head_effects'].items():
                all_head_effects[head].append(effect)

        # Calculate average effects and find top heads
        avg_effects = {
            head: np.mean(effects)
            for head, effects in all_head_effects.items()
            if len(effects) >= len(group_results) * 0.3  # Appear in at least 30% of sentences
        }

        top_heads = sorted(avg_effects.items(), key=lambda x: x[1], reverse=True)[:5]

        for i, ((layer, head), avg_effect) in enumerate(top_heads):
            appearance_count = len(all_head_effects[(layer, head)])
            print(f"  {i+1}. Layer {layer:2d}, Head {head:2d}: "
                  f"avg {avg_effect:.3f} "
                  f"(appears {appearance_count}/{len(group_results)} times)")



In [None]:
analysis = analyze_head_consistency(results)
# visualize_head_patterns(analysis)
compare_sentence_types(results)

ATTENTION HEAD CONSISTENCY ANALYSIS

Analyzed 200 sentences with gender agreement

 MOST CONSISTENT HEADS (appear in many sentences):
------------------------------------------------------------
 1. Layer  4, Head  4: appears 156/200 sentences (78.0%) | avg effect: 0.649 ± 0.340
 2. Layer  6, Head  2: appears 79/200 sentences (39.5%) | avg effect: 0.392 ± 0.195
 3. Layer  7, Head  4: appears 77/200 sentences (38.5%) | avg effect: 0.371 ± 0.198
 4. Layer  5, Head  6: appears 75/200 sentences (37.5%) | avg effect: 0.375 ± 0.183
 5. Layer  8, Head 11: appears 75/200 sentences (37.5%) | avg effect: 0.391 ± 0.205
 6. Layer  6, Head  0: appears 67/200 sentences (33.5%) | avg effect: 0.363 ± 0.221
 7. Layer  6, Head  1: appears 64/200 sentences (32.0%) | avg effect: 0.385 ± 0.225
 8. Layer  4, Head  2: appears 59/200 sentences (29.5%) | avg effect: 0.348 ± 0.182
 9. Layer  9, Head  4: appears 59/200 sentences (29.5%) | avg effect: 0.389 ± 0.259
10. Layer  5, Head  0: appears 57/200 sentences 

In [None]:
import pandas as pd

def find_article_position_from_pairs(orig_sent, pert_sent, tokenizer):
    """
    Getting article position by comparing original and perturbed sentences

    Args:
        orig_sent: Original sentence
        pert_sent: Perturbed sentence (with different article)
        tokenizer: The tokenizer

    Returns:
        int or None: Position where sentences differ (article position)
    """
    orig_tokens = tokenizer(orig_sent, return_tensors="pt")
    pert_tokens = tokenizer(pert_sent, return_tensors="pt")

    # Finding the differing article position
    orig_ids = orig_tokens['input_ids'][0]
    pert_ids = pert_tokens['input_ids'][0]

    diff_positions = []
    min_len = min(len(orig_ids), len(pert_ids))

    for i in range(min_len):
        if orig_ids[i] != pert_ids[i]:
            diff_positions.append(i)

    if not diff_positions:
        print("No differences found between sentences")
        return None

    # The first difference is taken as article position
    return diff_positions[0]

def identify_gender_mapping_heads(model, tokenizer, minimal_pairs):
    """
    Finding heads that consistently encode gender information using minimal pairs

    Args:
        minimal_pairs: List of tuples like [
            ("Leider war das Schnitzel aus der Fritöse .",
             "Leider war das Schnitzel aus dem Fritöse .",
             "Fem"),
            ("Das Restaurant schließt schon Um 21 Uhr .",
             "Die Restaurant schließt schon Um 21 Uhr .",
             "Neut"),
            ...
        ]
    """

    gender_mapping_data = []

    for orig_sent, pert_sent, gender in minimal_pairs:

        # Find article position by comparing the two sentences
        article_pos = find_article_position_from_pairs(orig_sent, pert_sent, tokenizer)

        if article_pos is None:
            continue

        # Process only the original sentence
        tokens = tokenizer(orig_sent, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**tokens, output_attentions=True, output_hidden_states=True)

        # Extract representations for each head
        for layer in range(model.config.num_hidden_layers):
            for head in range(model.config.num_attention_heads):

                hidden_states = outputs.hidden_states[layer][0]
                head_dim = hidden_states.size(-1) // model.config.num_attention_heads
                start_idx = head * head_dim
                end_idx = (head + 1) * head_dim

                # Get head representation at article position
                head_repr = hidden_states[article_pos, start_idx:end_idx]

                # Also get attention pattern from article
                attention = outputs.attentions[layer][0, head]
                article_attention = attention[article_pos, :]

                gender_mapping_data.append({
                    'sentence': orig_sent,
                    'gender': gender,
                    'layer': layer,
                    'head': head,
                    'article_position': article_pos,
                    'head_representation': head_repr.numpy(),
                    'attention_pattern': article_attention.numpy(),
                    # Store some key stats
                    'repr_norm': torch.norm(head_repr).item(),
                    'attention_entropy': -torch.sum(article_attention * torch.log(article_attention + 1e-10)).item()
                })

    return pd.DataFrame(gender_mapping_data)

In [None]:
import numpy as np
import pandas as pd
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
from sklearn.preprocessing import LabelEncoder

def simple_gender_analysis(gender_mapping_df):

    results = []

    # Group by each head (layer, head combination)
    for (layer, head), head_data in gender_mapping_df.groupby(['layer', 'head']):

        # Skip if we don't have enough data
        if len(head_data) < 10:
            continue

        if head_data['gender'].nunique() < 2:
            continue

        print(f"Analyzing Layer {layer}, Head {head} - {len(head_data)} samples")

        # X
        representations = np.stack(head_data['head_representation'].values)

        # Getting the gender labels and converting to labels, y
        genders = head_data['gender'].values

        le = LabelEncoder()
        gender_numbers = le.fit_transform(genders)

        # Training a simple classifier
        classifier = RandomForestClassifier(n_estimators=50, random_state=42)

        # Using cross-validation
        cv_scores = cross_val_score(classifier, representations, gender_numbers,
                                  cv=min(5, len(head_data)), scoring='accuracy')

        # Store results
        results.append({
            'layer': layer,
            'head': head,
            'accuracy': cv_scores.mean(),
            'accuracy_std': cv_scores.std(),
            'sample_count': len(head_data),
            'gender_count': len(np.unique(genders))
        })

    results_df = pd.DataFrame(results)
    results_df = results_df.sort_values('accuracy', ascending=False)

    return results_df

In [None]:
model = AutoModel.from_pretrained('bert-base-multilingual-cased', output_attentions=True)
tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

gender_mapping_df = identify_gender_mapping_heads(model, tokenizer, pairs["case"][:200])



In [None]:
results_df = simple_gender_analysis(gender_mapping_df)

Analyzing Layer 0, Head 0 - 200 samples
Analyzing Layer 0, Head 1 - 200 samples
Analyzing Layer 0, Head 2 - 200 samples
Analyzing Layer 0, Head 3 - 200 samples
Analyzing Layer 0, Head 4 - 200 samples
Analyzing Layer 0, Head 5 - 200 samples
Analyzing Layer 0, Head 6 - 200 samples
Analyzing Layer 0, Head 7 - 200 samples
Analyzing Layer 0, Head 8 - 200 samples
Analyzing Layer 0, Head 9 - 200 samples
Analyzing Layer 0, Head 10 - 200 samples
Analyzing Layer 0, Head 11 - 200 samples
Analyzing Layer 1, Head 0 - 200 samples
Analyzing Layer 1, Head 1 - 200 samples
Analyzing Layer 1, Head 2 - 200 samples
Analyzing Layer 1, Head 3 - 200 samples
Analyzing Layer 1, Head 4 - 200 samples
Analyzing Layer 1, Head 5 - 200 samples
Analyzing Layer 1, Head 6 - 200 samples
Analyzing Layer 1, Head 7 - 200 samples
Analyzing Layer 1, Head 8 - 200 samples
Analyzing Layer 1, Head 9 - 200 samples
Analyzing Layer 1, Head 10 - 200 samples
Analyzing Layer 1, Head 11 - 200 samples
Analyzing Layer 2, Head 0 - 200 samp

In [None]:
# Ultimately, not using these results as a more thorough design is done in Gender_Encoding.ipynb

def analyze_results(results_df, top_n=10):


    print("=== GENDER ENCODING ANALYSIS RESULTS ===")

    print("- Random guessing would be ~33% for 3 genders (Masc/Fem/Neut), so good heads should have >60% accuracy")


    print(f"\n=== TOP {top_n} GENDER-ENCODING HEADS ===")

    top_heads = results_df.head(top_n)

    for idx, row in top_heads.iterrows():
        accuracy_pct = row['accuracy'] * 100
        print(f"Layer {row['layer']}, Head {row['head']}: {accuracy_pct:.1f}% accuracy "
              f"({row['sample_count']} samples, {row['gender_count']} genders)")

    print(f"\n=== SUMMARY ===")
    print(f"Total heads analyzed: {len(results_df)}")
    good_heads = results_df[results_df['accuracy'] > 0.6]
    print(f"Heads with >60% accuracy: {len(good_heads)}")
    best_accuracy = results_df['accuracy'].max() * 100
    print(f"Best accuracy achieved: {best_accuracy:.1f}%")

    return top_heads


In [None]:
top_gender_predictive_heads = analyze_results(results_df, 10)

=== GENDER ENCODING ANALYSIS RESULTS ===

What this means:
- Accuracy = How well we can predict gender from this head's representations
- Higher accuracy = This head encodes gender information better
- Random guessing would be ~33% for 3 genders (Masc/Fem/Neut)
- Good heads should have >60% accuracy

=== TOP 10 GENDER-ENCODING HEADS ===
Layer 4.0, Head 3.0: 100.0% accuracy (200.0 samples, 3.0 genders)
Layer 5.0, Head 5.0: 99.5% accuracy (200.0 samples, 3.0 genders)
Layer 3.0, Head 10.0: 99.5% accuracy (200.0 samples, 3.0 genders)
Layer 4.0, Head 10.0: 99.5% accuracy (200.0 samples, 3.0 genders)
Layer 4.0, Head 2.0: 99.5% accuracy (200.0 samples, 3.0 genders)
Layer 4.0, Head 5.0: 99.5% accuracy (200.0 samples, 3.0 genders)
Layer 3.0, Head 4.0: 99.5% accuracy (200.0 samples, 3.0 genders)
Layer 5.0, Head 6.0: 99.0% accuracy (200.0 samples, 3.0 genders)
Layer 3.0, Head 6.0: 99.0% accuracy (200.0 samples, 3.0 genders)
Layer 5.0, Head 3.0: 99.0% accuracy (200.0 samples, 3.0 genders)

=== SUM

In [None]:
top_gender_predictive_heads = list(zip(top_gender_predictive_heads['layer'], top_gender_predictive_heads['head']))


In [None]:
import torch
import pandas as pd
import numpy as np
from typing import List, Tuple, Dict, Optional


def probe_gender_differences_perturbed(model, tokenizer, sentence_pairs):
    """
    Focus on the DIFFERENCES between original and perturbed sentences
    This is perfect for understanding what changes when gender is wrong!
    """
    difference_data = []

    for pair_idx, (orig_sentence, pert_sentence, gender) in enumerate(sentence_pairs):

        article_pos = find_article_position_from_perturbation(
            tokenizer, orig_sentence, pert_sentence
        )

        if article_pos is None:
            continue

        # Process both sentences
        orig_tokens = tokenizer(orig_sentence, return_tensors="pt")
        pert_tokens = tokenizer(pert_sentence, return_tensors="pt")

        # Check if sequences have different lengths
        orig_length = orig_tokens.input_ids.shape[1]
        pert_length = pert_tokens.input_ids.shape[1]

        # Skip if article position is out of bounds for either sentence
        if article_pos >= orig_length or article_pos >= pert_length:
            print(f"Warning: Article position {article_pos} out of bounds for pair {pair_idx}")
            continue

        with torch.no_grad():
            orig_outputs = model(**orig_tokens, output_attentions=True, output_hidden_states=True)
            pert_outputs = model(**pert_tokens, output_attentions=True, output_hidden_states=True)

        # Extract article info
        orig_article = extract_article_from_sentence(orig_sentence, article_pos, tokenizer)
        pert_article = extract_article_from_sentence(pert_sentence, article_pos, tokenizer)

        #orig_gender = classify_german_article(orig_article)
        #pert_gender = classify_german_article(pert_article)

        # Compare representations across heads
        for layer in range(model.config.num_hidden_layers):
            for head in range(model.config.num_attention_heads):

                # Getting representations
                orig_hidden = orig_outputs.hidden_states[layer][0] # shape: [seq_len, hidden_size]
                pert_hidden = pert_outputs.hidden_states[layer][0]

                # making sure that the article position is not greater than the number of tokens
                if article_pos >= orig_hidden.shape[0] or article_pos >= pert_hidden.shape[0]:
                    continue

                head_dim = orig_hidden.size(-1) // model.config.num_attention_heads   # 768//12 = 64
                start_idx = head * head_dim
                end_idx = (head + 1) * head_dim

                orig_head_repr = orig_hidden[article_pos, start_idx:end_idx] # each head's representation at the article position
                pert_head_repr = pert_hidden[article_pos, start_idx:end_idx]

                # Calculate differences
                repr_diff = pert_head_repr - orig_head_repr  # vector difference
                repr_distance = torch.norm(repr_diff).item()  # magnitude difference
                cosine_sim = torch.cosine_similarity(
                    orig_head_repr.unsqueeze(0),
                    pert_head_repr.unsqueeze(0)
                ).item()

                # Attention differences - handle different sequence lengths
                orig_attention = orig_outputs.attentions[layer][0, head]
                pert_attention = pert_outputs.attentions[layer][0, head]

                # Only compare attention if both sequences are long enough
                attention_diff = None
                if (article_pos < orig_attention.shape[0] and # again, checking to see that attention position is within the attention bounds
                    article_pos < pert_attention.shape[0]):   # attention matrix of perturbed and original tokens might be of different size

                    # Compare attention patterns up to the minimum length
                    min_seq_len = min(orig_attention.shape[0], pert_attention.shape[0])

                    orig_attn_slice = orig_attention[article_pos, :min_seq_len]
                    pert_attn_slice = pert_attention[article_pos, :min_seq_len]

                    attention_diff = torch.norm(pert_attn_slice - orig_attn_slice).item()

                difference_data.append({
                    'pair_id': pair_idx,
                    'layer': layer,
                    'head': head,
                    'orig_sentence': orig_sentence,
                    'pert_sentence': pert_sentence,
                    'orig_article': orig_article,
                    'pert_article': pert_article,
                    #'orig_gender': orig_gender['gender'],
                    #'pert_gender': pert_gender['gender'],
                    #'gender_changed': orig_gender['gender'] != pert_gender['gender'],
                    'orig_length': orig_length,
                    'pert_length': pert_length,
                    'length_diff': abs(orig_length - pert_length),

                    # Representation differences
                    'representation_distance': repr_distance,
                    'cosine_similarity': cosine_sim,
                    'difference_vector': repr_diff.numpy(),

                    # Attention differences (can be None if sequences too different)
                    'attention_change': attention_diff,

                    # Original representations for probing
                    'orig_representation': orig_head_repr.numpy(),
                    'pert_representation': pert_head_repr.numpy(),
                })

    return pd.DataFrame(difference_data)

In [None]:
def calculate_error_sensitivity_score_from_raw(difference_df):
    """
    To calculate sensitivity scores from raw difference data (before aggregation)

    Args:
        difference_df: DataFrame with columns:
            - layer, head, representation_distance, cosine_similarity, attention_change

    Returns:
        DataFrame: Aggregated head sensitivity with composite scores
    """

    # Aggregating the raw data by head (same as analyze_gender_sensitivity_by_head)
    head_sensitivity = difference_df.groupby(['layer', 'head']).agg({
        'representation_distance': ['mean', 'std', 'max'],
        'cosine_similarity': ['mean', 'std', 'min'],
        'attention_change': ['mean', 'std', 'max'],
        #'gender_changed': 'sum'  # Count of gender changes
    }).round(4)

    # Flatten column names
    head_sensitivity.columns = ['_'.join(col).strip() for col in head_sensitivity.columns]


    repr_dist = head_sensitivity['representation_distance_mean'].fillna(0)
    cosine_sim = head_sensitivity['cosine_similarity_mean'].fillna(head_sensitivity['cosine_similarity_mean'].median())
    attention_change = head_sensitivity['attention_change_mean'].fillna(0)

    # Normalize each metric to 0-1 scale
    if repr_dist.max() > repr_dist.min():
        repr_dist_norm = (repr_dist - repr_dist.min()) / (repr_dist.max() - repr_dist.min())
    else:
        repr_dist_norm = repr_dist * 0  # All zeros if no variation

    # For cosine similarity, we want LOW values (more different = more sensitive)
    if cosine_sim.max() > cosine_sim.min():
        cosine_dissim_norm = 1 - ((cosine_sim - cosine_sim.min()) / (cosine_sim.max() - cosine_sim.min()))
    else:
        cosine_dissim_norm = cosine_sim * 0  # All zeros if no variation

    if attention_change.max() > attention_change.min():
        attention_change_norm = (attention_change - attention_change.min()) / (attention_change.max() - attention_change.min())
    else:
        attention_change_norm = attention_change * 0  # All zeros if no variation

    # Weighted combination
    sensitivity_score = (0.5 * repr_dist_norm +      # 50% weight to representation distance
                        0.3 * cosine_dissim_norm +    # 30% weight to cosine dissimilarity
                        0.2 * attention_change_norm)  # 20% weight to attention change

    # Add sensitivity score to the dataframe
    head_sensitivity['sensitivity_score'] = sensitivity_score

    # Sort by sensitivity score (highest first)
    head_sensitivity = head_sensitivity.sort_values('sensitivity_score', ascending=False)

    return head_sensitivity

def analyze_error_sensitivity_complete(difference_df):
    """
    Complete analysis pipeline from raw difference data to ranked sensitive heads
    """
    print(" COMPLETE ERROR SENSITIVITY ANALYSIS")
    print("=" * 60)

    # Calculate sensitivity scores
    head_sensitivity = calculate_error_sensitivity_score_from_raw(difference_df)

    # Display results
    print(f"\n Analyzed {len(head_sensitivity)} attention heads")
    print(f" Data from {len(difference_df)} sentence pairs")

    print("\n TOP 10 MOST ERROR-SENSITIVE HEADS:")
    print("-" * 50)

    top_10 = head_sensitivity.head(10)
    for (layer, head), row in top_10.iterrows():
        print(f"Layer {layer:2d}, Head {head:2d}: "
              f"Score={row['sensitivity_score']:.3f} "
              f"(RepDist={row['representation_distance_mean']:.3f}, "
              f"CosSim={row['cosine_similarity_mean']:.3f}, "
              f"AttnChg={row['attention_change_mean']:.3f})")

    print(f"\n SENSITIVITY DISTRIBUTION:")
    print(f"Mean sensitivity score: {head_sensitivity['sensitivity_score'].mean():.3f}")
    print(f"Max sensitivity score:  {head_sensitivity['sensitivity_score'].max():.3f}")
    print(f"Heads with score > 0.7: {len(head_sensitivity[head_sensitivity['sensitivity_score'] > 0.7])}")
    print(f"Heads with score > 0.5: {len(head_sensitivity[head_sensitivity['sensitivity_score'] > 0.5])}")

    return head_sensitivity


In [None]:
diff_df = probe_gender_differences_perturbed(model,tokenizer, pairs["case"][:200])

head_sensitive_heads = analyze_error_sensitivity_complete(diff_df)

 COMPLETE ERROR SENSITIVITY ANALYSIS

 Analyzed 144 attention heads
 Data from 28800 sentence pairs

 TOP 10 MOST ERROR-SENSITIVE HEADS:
--------------------------------------------------
Layer  0, Head  8: Score=0.837 (RepDist=6.196, CosSim=0.414, AttnChg=0.116)
Layer  0, Head  0: Score=0.828 (RepDist=5.952, CosSim=0.430, AttnChg=0.228)
Layer  0, Head  3: Score=0.811 (RepDist=6.104, CosSim=0.431, AttnChg=0.105)
Layer  0, Head  4: Score=0.780 (RepDist=5.968, CosSim=0.479, AttnChg=0.149)
Layer  0, Head 11: Score=0.778 (RepDist=5.928, CosSim=0.497, AttnChg=0.189)
Layer  0, Head  1: Score=0.717 (RepDist=5.623, CosSim=0.521, AttnChg=0.181)
Layer  0, Head  7: Score=0.708 (RepDist=5.628, CosSim=0.540, AttnChg=0.183)
Layer  1, Head  0: Score=0.702 (RepDist=5.631, CosSim=0.497, AttnChg=0.092)
Layer  0, Head  2: Score=0.679 (RepDist=5.551, CosSim=0.568, AttnChg=0.173)
Layer  4, Head  4: Score=0.673 (RepDist=4.840, CosSim=0.651, AttnChg=0.621)

 SENSITIVITY DISTRIBUTION:
Mean sensitivity score: 

In [14]:
import torch

def find_article_position_simple(tokenizer, orig_sentence, pert_sentence):
    """Helper function to find article position"""
    orig_tokens = tokenizer(orig_sentence, return_tensors="pt")
    pert_tokens = tokenizer(pert_sentence, return_tensors="pt")

    orig_ids = orig_tokens['input_ids'][0]
    pert_ids = pert_tokens['input_ids'][0]

    min_len = min(len(orig_ids), len(pert_ids))

    for i in range(min_len):
        if orig_ids[i] != pert_ids[i]:
            return i

    return 0  # Fallback to first position

def patch_single_head(model, tokenizer, clean_sent, corrupt_sent, layer, head):
    """
    Patch a single head using the existing intervention method
    """
    article_pos = find_article_position_simple(tokenizer, clean_sent, corrupt_sent)

    # Using existing intervention function
    effect = patch_attention_head_activations(
        model, tokenizer, clean_sent, corrupt_sent, layer, head,
        article_pos
    )

    return effect

def patch_multiple_heads(model, tokenizer, clean_sent, corrupt_sent, head_list):
    """
    Patch multiple heads simultaneously
    """
    # Tokenizing the sentences
    clean_tokens = tokenizer(clean_sent, return_tensors="pt")
    corrupt_tokens = tokenizer(corrupt_sent, return_tensors="pt")

    # For the article position, we look at hidden state
    article_pos = find_article_position_simple(tokenizer, clean_sent, corrupt_sent)

    def get_hidden_states(outputs):
        if hasattr(outputs, 'last_hidden_state'):
            return outputs.last_hidden_state
        elif hasattr(outputs, 'hidden_states') and outputs.hidden_states is not None:
            return outputs.hidden_states[-1]
        else:
            raise AttributeError(f"Cannot find hidden states in {type(outputs)}")

    # Getting baseline hidden states for clean input
    with torch.no_grad():
        orig_outputs = model(**clean_tokens, output_attentions=True)
        baseline_hidden = get_hidden_states(orig_outputs)[0, article_pos]

    # Getting intervened attention for all heads
    with torch.no_grad():
        pert_outputs = model(**corrupt_tokens, output_attentions=True)
        intervention_attentions = {}
        for layer, head in head_list:
            intervention_attentions[(layer, head)] = pert_outputs.attentions[layer][0, head]

    # Applying multi-head intervention
    original_dropouts = {}
    intervention_dropouts = {}

    try:
        if hasattr(model, 'encoder'):
            encoder_layers = model.encoder.layer
        elif hasattr(model, 'bert'):
            encoder_layers = model.bert.encoder.layer
        else:
            raise AttributeError("Cannot find encoder layers")

        # Intervention dropouts created for each layer that has heads to patch
        layers_to_patch = set(layer for layer, head in head_list)

        for layer_idx in layers_to_patch:
            attention_layer = encoder_layers[layer_idx].attention
            original_dropouts[layer_idx] = attention_layer.self.dropout

            # Get all heads to patch in this layer
            heads_in_layer = [head for layer, head in head_list if layer == layer_idx]

            # Create multi-head intervention dropout class
            class MultiHeadInterventionDropout(torch.nn.Module):
                def __init__(self, original_dropout, intervention_attns, target_heads):
                    super().__init__()
                    self.original_dropout = original_dropout
                    self.intervention_attns = intervention_attns
                    self.target_heads = target_heads
                    self.training = original_dropout.training

                def forward(self, attention_probs):
                    if attention_probs.dim() == 4:  # [batch, heads, seq, seq]
                        modified_probs = attention_probs.clone()
                        for head_idx in self.target_heads:
                            if head_idx < modified_probs.shape[1]:
                                modified_probs[0, head_idx] = self.intervention_attns[head_idx]
                        return self.original_dropout(modified_probs)
                    else:
                        return self.original_dropout(attention_probs)

            # Preparing intervention attentions for this layer and replacing dropout
            layer_interventions = {}
            for head in heads_in_layer:
                layer_interventions[head] = intervention_attentions[(layer_idx, head)]

            intervention_dropouts[layer_idx] = MultiHeadInterventionDropout(
                original_dropouts[layer_idx], layer_interventions, heads_in_layer
            )
            attention_layer.self.dropout = intervention_dropouts[layer_idx]

        # Doing a Forward pass with multi-head intervention
        intervened_outputs = model(**clean_tokens, output_attentions=True)
        intervened_hidden = get_hidden_states(intervened_outputs)[0, article_pos]

        # finally, calculating effect
        effect = torch.norm(intervened_hidden - baseline_hidden).item()

    except Exception as e:
        print(f"Error with multi-head patching: {e}")
        effect = 0.0

    finally:
        # Restoring all original dropouts
        for layer_idx, original_dropout in original_dropouts.items():
            encoder_layers[layer_idx].attention.self.dropout = original_dropout

    return effect

def get_head_representation(model, tokenizer, sentence, layer, head):
    """
    Helper function to get representation from a specific head
    """
    # Tokenizing sentence
    tokens = tokenizer(sentence, return_tensors="pt")

    article_pos = 0
    if len(tokens['input_ids'][0]) > 1:
        for i, token_id in enumerate(tokens['input_ids'][0]):
            token = tokenizer.decode([token_id]).strip().lower()
            if token in ['der', 'die', 'das', 'den', 'dem', 'des']:
                article_pos = i
                break

    # With the model outputs, we extract head representations
    with torch.no_grad():
        outputs = model(**tokens, output_attentions=True, output_hidden_states=True)

    # Extracting head representation od size [seq_len, hidden_size]
    hidden_states = outputs.hidden_states[layer][0]

    # Getting the head dimensions
    head_dim = hidden_states.size(-1) // model.config.num_attention_heads
    start_idx = head * head_dim
    end_idx = (head + 1) * head_dim

    # Finding head representation at article position
    head_repr = hidden_states[article_pos, start_idx:end_idx]

    return head_repr

def get_head_representation_with_patch(model, tokenizer, clean_sent, corrupt_sent, source_head, target_head):
    """
    Helper function to get target head representation when source head is patched
    """
    clean_tokens = tokenizer(clean_sent, return_tensors="pt")
    corrupt_tokens = tokenizer(corrupt_sent, return_tensors="pt")
    article_pos = find_article_position_simple(tokenizer, clean_sent, corrupt_sent)

    source_layer, source_head_idx = source_head
    target_layer, target_head_idx = target_head

    # Patching commences
    with torch.no_grad():
        pert_outputs = model(**corrupt_tokens, output_attentions=True)
        intervention_attention = pert_outputs.attentions[source_layer][0, source_head_idx]

    original_dropout = None

    try:
        # Access encoder layers
        if hasattr(model, 'encoder'):
            encoder_layers = model.encoder.layer
        elif hasattr(model, 'bert'):
            encoder_layers = model.bert.encoder.layer
        else:
            raise AttributeError("Cannot find encoder layers")

        # Patch source head
        source_attention_layer = encoder_layers[source_layer].attention
        original_dropout = source_attention_layer.self.dropout

        class SourceInterventionDropout(torch.nn.Module):
            def __init__(self, original_dropout, intervention_attn, target_head):
                super().__init__()
                self.original_dropout = original_dropout
                self.intervention_attn = intervention_attn
                self.target_head = target_head
                self.training = original_dropout.training

            def forward(self, attention_probs):
                if attention_probs.dim() == 4:  # [batch, heads, seq, seq]
                    modified_probs = attention_probs.clone()
                    modified_probs[0, self.target_head] = self.intervention_attn
                    return self.original_dropout(modified_probs)
                else:
                    return self.original_dropout(attention_probs)

        # Replacing source head dropout with intervened attention
        source_attention_layer.self.dropout = SourceInterventionDropout(
            original_dropout, intervention_attention, source_head_idx
        )

        # Forward pass with source head patched
        outputs = model(**clean_tokens, output_attentions=True, output_hidden_states=True)

        # Extracting target head representation
        target_hidden_states = outputs.hidden_states[target_layer][0]
        head_dim = target_hidden_states.size(-1) // model.config.num_attention_heads
        start_idx = target_head_idx * head_dim
        end_idx = (target_head_idx + 1) * head_dim

        target_head_repr = target_hidden_states[article_pos, start_idx:end_idx]

    except Exception as e:
        print(f"Error in causal influence measurement: {e}")
        # Return zero representation as fallback
        hidden_size = model.config.hidden_size
        head_dim = hidden_size // model.config.num_attention_heads
        target_head_repr = torch.zeros(head_dim)

    finally:
        # Restoring original dropout
        if original_dropout is not None:
            source_attention_layer.self.dropout = original_dropout

    return target_head_repr


In [17]:
import torch
import numpy as np
from itertools import combinations
import pandas as pd
from tqdm import tqdm
import random

def test_head_synergy_intervention_comprehensive(model, tokenizer, test_pairs,
                                               error_sensitive_heads, gender_predictive_heads,
                                               config):
    """

    Args:
        test_pairs: List of (correct_sentence, incorrect_sentence, gender) tuples
        error_sensitive_heads: List of (layer, head) tuples from error sensitivity analysis
        gender_predictive_heads: List of (layer, head) tuples from classification analysis
        config: Dict with testing configuration

    Returns:
        Dict with detailed synergy analysis results
    """

    results = {
        'individual_effects': {},
        'pairwise_synergy': {},
        'group_effects': {},
        'causal_chains': {},
        'config': config
    }

    print(" COMPREHENSIVE HEAD SYNERGY ANALYSIS")
    print(f"   Configuration:")
    print(f"   - Individual heads: {config['num_heads_individual']}")
    print(f"   - Synergy pairs: {config['num_heads_synergy']} x {config['num_heads_synergy']}")
    print(f"   - Test sentences: {len(test_pairs)} total")
    print(f"   - Cross-validation: {config['cross_validation']} folds")

    # 1. Comprehensive individual head effects
    print("\n Testing individual head effects ")
    #results['individual_effects'] = test_individual_head_effects(
    #    model, tokenizer, test_pairs,
    #    error_sensitive_heads + gender_predictive_heads,
    #    config
    #)

    # 2. Pairwise synergy testing
    #print("\n Testing pairwise synergy")
    #results['pairwise_synergy'] = test_pairwise_synergy(
    #    model, tokenizer, test_pairs,
    #    error_sensitive_heads, gender_predictive_heads,
    #    config
    #)

    # 3. Multi-size group effects
    #print("\n Testing group effects")
    #results['group_effects'] = test_group_effects(
    #    model, tokenizer, test_pairs,
    #    error_sensitive_heads, gender_predictive_heads,
    #    config
    #)

    # 4. Comprehensive causal chains
    print("\n Testing causal chains")
    results['causal_chains'] = test_causal_chains(
        model, tokenizer, test_pairs,
        error_sensitive_heads, gender_predictive_heads,
        config
    )

    return results

def test_individual_head_effects(model, tokenizer, test_pairs, all_heads, config):
    """
    Test individual head effects with cross-validation and larger sample sizes
    """
    individual_effects = {}

    heads_to_test = all_heads[:config['num_heads_individual']]
    sentences_per_head = config['num_sentences_individual']

    print(f"   Testing {len(heads_to_test)} heads on {sentences_per_head} sentences each...")

    for layer, head in tqdm(heads_to_test, desc="Individual heads"):

        cv_effects = []

        for cv_fold in range(config['cross_validation']):
            # Sampling different sentences for each fold of cross-validation
            fold_pairs = random.sample(test_pairs, min(sentences_per_head, len(test_pairs)))

            total_effect = 0
            valid_pairs = 0

            for correct_sent, incorrect_sent, gender in fold_pairs:
                try:
                    effect = patch_single_head(
                        model, tokenizer, correct_sent, incorrect_sent, layer, head
                    )
                    total_effect += effect
                    valid_pairs += 1
                except Exception as e:
                    continue

            if valid_pairs >= config['min_valid_pairs']:
                cv_effects.append(total_effect / valid_pairs)

        # Storing results
        if len(cv_effects) > 0:
            individual_effects[(layer, head)] = {
                'mean_effect': np.mean(cv_effects),
                'std_effect': np.std(cv_effects),
                'cv_effects': cv_effects,
                'valid_folds': len(cv_effects)
            }

    return individual_effects

def test_pairwise_synergy(model, tokenizer, test_pairs,
                                      error_sensitive_heads, gender_predictive_heads, config):
    """
    Comprehensive pairwise synergy testing with more combinations
    """
    synergy_scores = {}

    # Test more head combinations
    error_heads_to_test = error_sensitive_heads[:config['num_heads_synergy']]
    pred_heads_to_test = gender_predictive_heads[:config['num_heads_synergy']]

    # Also test within-category synergy
    combinations_to_test = []

    # Cross-category combinations (error + predictive)
    for error_head in error_heads_to_test:
        for pred_head in pred_heads_to_test:
            if error_head != pred_head:  # Don't test head with itself
                combinations_to_test.append(('cross', error_head, pred_head))

    # Within-category combinations (error + error, predictive + predictive)
    for head1, head2 in combinations(error_heads_to_test, 2):
        combinations_to_test.append(('error_error', head1, head2))

    for head1, head2 in combinations(pred_heads_to_test, 2):
        combinations_to_test.append(('pred_pred', head1, head2))

    print(f"   Testing {len(combinations_to_test)} head pair combinations...")

    for combination_type, head1, head2 in tqdm(combinations_to_test, desc="Synergy pairs"):

        synergy_score = calculate_synergy_score(
            model, tokenizer, test_pairs[:config['num_sentences_synergy']],
            head1, head2, config
        )

        synergy_scores[((head1, head2), combination_type)] = synergy_score

    return synergy_scores

def calculate_synergy_score(model, tokenizer, test_pairs, head1, head2, config):
    """
    Calculate synergy score to see if effect of patching 2 heads is greater than patching the heads individually. Typically, one head
    from the error-sensitive heads which are in earlier layers and one from later layers is used. Or both from same category but at different layers.
    """

    synergy_scores = []

    for correct_sent, incorrect_sent, gender in test_pairs:
        try:
            # Effect of head1 alone
            effect1 = patch_single_head(model, tokenizer, correct_sent, incorrect_sent, *head1)

            # Effect of head2 alone
            effect2 = patch_single_head(model, tokenizer, correct_sent, incorrect_sent, *head2)

            # Effect of both heads together
            effect_both = patch_multiple_heads(model, tokenizer, correct_sent, incorrect_sent, [head1, head2])

            # Synergy = joint effect - sum of individual effects
            synergy = effect_both - (effect1 + effect2)
            synergy_scores.append(synergy)

        except Exception as e:
            continue

    if len(synergy_scores) >= config['min_valid_pairs']:
        return {
            'mean_synergy': np.mean(synergy_scores),
            'std_synergy': np.std(synergy_scores),
            'median_synergy': np.median(synergy_scores),
            'valid_pairs': len(synergy_scores),
            'raw_scores': synergy_scores
        }
    else:
        return None

def test_group_effects(model, tokenizer, test_pairs,
                                   error_sensitive_heads, gender_predictive_heads, config):
    """
    Test group effects with multiple group sizes and compositions
    """
    group_effects = {}

    sentences_to_use = test_pairs[:config['num_sentences_group']]

    for group_size in config['group_sizes']:
        print(f"   Testing groups of size {group_size}...")

        # Test different group compositions
        group_compositions = {
            f'top_{group_size}_error': error_sensitive_heads[:group_size],
            f'top_{group_size}_predictive': gender_predictive_heads[:group_size],
            #f'mixed_{group_size}': (error_sensitive_heads[:group_size//2] +
            #                       gender_predictive_heads[:group_size//2]),
            f'combined_{group_size}': (error_sensitive_heads[:group_size//2] +
                                      gender_predictive_heads[:group_size//2])
        }

        for group_name, head_group in group_compositions.items():
            if len(head_group) == group_size:  # Ensure we have enough heads

                group_effect = test_head_group(
                    model, tokenizer, sentences_to_use, head_group, config
                )

                group_effects[group_name] = group_effect

    return group_effects

def test_head_group(model, tokenizer, test_pairs, head_group, config):

    group_effects = []

    for correct_sent, incorrect_sent, gender in test_pairs:
        try:
            effect = patch_multiple_heads(model, tokenizer, correct_sent, incorrect_sent, head_group)
            group_effects.append(effect)
        except Exception as e:
            continue

    if len(group_effects) >= config['min_valid_pairs']:
        return {
            'mean_effect': np.mean(group_effects),
            'std_effect': np.std(group_effects),
            'median_effect': np.median(group_effects),
            'valid_pairs': len(group_effects),
            'head_count': len(head_group),
            'heads': head_group
        }
    else:
        return None

def test_causal_chains(model, tokenizer, test_pairs,
                                   error_sensitive_heads, gender_predictive_heads, config):
    """
    Comprehensive causal chain testing
    """
    causal_effects = {}

    # Test more heads for causal relationships
    error_heads_to_test = error_sensitive_heads[:config['num_heads_causal']]
    pred_heads_to_test = gender_predictive_heads[:config['num_heads_causal']]

    sentences_to_use = test_pairs[:config['num_sentences_causal']]

    causal_pairs = []

    # Error → Predictive chains
    for error_head in error_heads_to_test:
        for pred_head in pred_heads_to_test:
            if error_head != pred_head and error_head[0] <= pred_head[0]:
                causal_pairs.append(('error_to_pred', error_head, pred_head))

    # Within-category chains (earlier → later layers)
    for head1 in error_heads_to_test:
        for head2 in error_heads_to_test:
            if head1[0] < head2[0]:  # Earlier layer to later layer
                causal_pairs.append(('error_to_error', head1, head2))

    for head1 in pred_heads_to_test:
        for head2 in pred_heads_to_test:
            if head1[0] < head2[0]:  # Earlier layer to later layer
                causal_pairs.append(('pred_to_pred', head1, head2))

    print(f"   Testing {len(causal_pairs)} causal relationships...")

    for causal_type, source_head, target_head in tqdm(causal_pairs, desc="Causal chains"):

        causal_effect = measure_causal_influence(
            model, tokenizer, sentences_to_use, source_head, target_head, config
        )

        if causal_effect is not None:
            causal_effects[((source_head, target_head), causal_type)] = causal_effect

    return causal_effects

def measure_causal_influence(model, tokenizer, test_pairs, source_head, target_head, config):
    """
    Measure causal influence where the source head is patched and the effect on the target head is noted
    """
    influences = []

    for correct_sent, incorrect_sent, gender in test_pairs:
        try:
            # Getting baseline target head representation
            baseline_repr = get_head_representation(
                model, tokenizer, correct_sent, *target_head
            )

            # Patching source head and see how target head changes
            patched_repr = get_head_representation_with_patch(
                model, tokenizer, correct_sent, incorrect_sent, source_head, target_head
            )

            baseline_norm = torch.norm(baseline_repr).item()
            if baseline_norm > 1e-8:
                influence = torch.norm(patched_repr - baseline_repr).item() / baseline_norm
            else:
                influence = torch.norm(patched_repr - baseline_repr).item()

            influences.append(influence)

        except Exception as e:
            continue

    if len(influences) >= config['min_valid_pairs']:
        return {
            'mean_influence': np.mean(influences),
            'std_influence': np.std(influences),
            'median_influence': np.median(influences),
            'valid_pairs': len(influences),
            'raw_influences': influences
        }
    else:
        return None

def analyze_synergy_results_comprehensive(synergy_results, config):
    """
    Comprehensive analysis and interpretation of synergy test results for individual, pairwise, group effects as well as causal chains
    """
    print("\n COMPREHENSIVE SYNERGY ANALYSIS RESULTS") # with confidence intervals
    print("=" * 60)


    #individual = synergy_results['individual_effects']
    #if individual:
    #    print("\n TOP INDIVIDUAL HEAD EFFECTS :")
    #    sorted_individual = sorted(
    #        [(k, v) for k, v in individual.items() if v.get('valid_folds', 0) > 0],
    #        key=lambda x: x[1]['mean_effect'], reverse=True
    #    )

    #    for (layer, head), data in sorted_individual[:10]:
    #        mean_eff = data['mean_effect']
    #        std_eff = data['std_effect']
    #        ci_lower = mean_eff - 1.96 * std_eff / np.sqrt(data['valid_folds'])
    #        ci_upper = mean_eff + 1.96 * std_eff / np.sqrt(data['valid_folds'])

    #        print(f"  Layer {layer:2d}, Head {head:2d}: {mean_eff:.4f} ± {std_eff:.4f} "
    #              f"[{ci_lower:.4f}, {ci_upper:.4f}] (n={data['valid_folds']})")

    pairwise = synergy_results['pairwise_synergy']
    if pairwise:
        print("\nTOP SYNERGISTIC PAIRS:") # (with significance)

        valid_pairs = [(k, v) for k, v in pairwise.items() if v is not None]
        sorted_synergy = sorted(valid_pairs, key=lambda x: x[1]['mean_synergy'], reverse=True)

        for ((head1, head2), pair_type), data in sorted_synergy[:10]:
            mean_syn = data['mean_synergy']
            std_syn = data['std_synergy']

            # Simple significance test checking if the mean is significantly different from 0?
            t_stat = mean_syn / (std_syn / np.sqrt(data['valid_pairs'])) if std_syn > 0 else 0
            p_value = 2 * (1 - stats.t.cdf(abs(t_stat), data['valid_pairs'] - 1))

            synergy_type = " Synergistic" if mean_syn > 0.05 else " Interfering" if mean_syn < -0.05 else " Independent"
            significance = "*" if p_value < 0.05 else ""

            print(f"  {head1} + {head2} ({pair_type}): {mean_syn:.4f} ± {std_syn:.4f} "
                  f"({synergy_type}){significance} p={p_value:.3f}")

    group = synergy_results['group_effects']
    if group:
        print("\nGROUP EFFECTS:")
        for group_name, data in group.items():
            if data is not None:
                print(f"  {group_name}: {data['mean_effect']:.4f} ± {data['std_effect']:.4f} "
                      f"(n={data['valid_pairs']}, heads={data['head_count']})")

    chains = synergy_results['causal_chains']
    if chains:
        print("\n STRONGEST CAUSAL INFLUENCES:")

        valid_chains = [(k, v) for k, v in chains.items() if v is not None]
        sorted_chains = sorted(valid_chains, key=lambda x: x[1]['mean_influence'], reverse=True)

        for ((source, target), chain_type), data in sorted_chains[:10]:
            mean_inf = data['mean_influence']
            std_inf = data['std_influence']

            print(f"  {source} → {target} ({chain_type}): {mean_inf:.4f} ± {std_inf:.4f} "
                  f"(n={data['valid_pairs']})")


def run_efficient_synergy_analysis(model, tokenizer, test_pairs,
                                 error_sensitive_heads, gender_predictive_heads):
    """
    Defining config

    """

    efficient_config = {
        'num_heads_individual': 14,      # Testing all heads individually
        'num_heads_synergy': 5,          # Top 5 from each (25 combinations)
        'num_heads_causal': 5,           # Top 5
        'num_sentences_individual': 80,
        'num_sentences_synergy': 50,
        'num_sentences_group': 70,
        'num_sentences_causal': 30,
        'group_sizes': [3, 5],           # Just two group sizes
        'cross_validation': 3,           # Just 3 folds
        'min_valid_pairs': 5
    }

    print(f"Starting efficient analysis with {len(test_pairs)} test pairs...")
    print(f"Testing: {efficient_config['num_heads_synergy']}×{efficient_config['num_heads_synergy']} = {efficient_config['num_heads_synergy']**2} synergy pairs")

    synergy_results = test_head_synergy_intervention_comprehensive(
        model, tokenizer, test_pairs,
        error_sensitive_heads, gender_predictive_heads,
        efficient_config
    )

    analyze_synergy_results_comprehensive(synergy_results, efficient_config)

    return synergy_results

In [18]:
test_pairs = pairs["case"][200:300]
top_gender_predictive_heads = [(8,5),(6,0),(6,4),(11,7),(8,0),(6,9),(6,1),(6,6),(6,10),(10,7)] ## from the heads collected in Gender_Encoding.ipynb
top_error_sensitive_heads = [(0,8),(0,0),(0,3),(0,4),(0,11),(0,1),(0,7),(1,0),(0,2),(4,4)]

run_efficient_synergy_analysis(
     model, tokenizer, test_pairs,
     top_error_sensitive_heads[:7], top_gender_predictive_heads[:7])


Starting efficient analysis with 100 test pairs...
Testing: 5×5 = 25 synergy pairs
 COMPREHENSIVE HEAD SYNERGY ANALYSIS
   Configuration:
   - Individual heads: 14
   - Synergy pairs: 5 x 5
   - Test sentences: 100 total
   - Cross-validation: 3 folds

 Testing individual head effects 

 Testing causal chains
   Testing 33 causal relationships...


Causal chains: 100%|██████████| 33/33 [02:13<00:00,  4.06s/it]


 COMPREHENSIVE SYNERGY ANALYSIS RESULTS

 STRONGEST CAUSAL INFLUENCES:
  (6, 0) → (8, 5) (pred_to_pred): 0.3239 ± 0.4697 (n=30)
  (6, 0) → (8, 0) (pred_to_pred): 0.3234 ± 0.4768 (n=30)
  (6, 4) → (8, 5) (pred_to_pred): 0.3218 ± 0.4717 (n=30)
  (0, 0) → (8, 5) (error_to_pred): 0.3211 ± 0.4721 (n=30)
  (6, 4) → (8, 0) (pred_to_pred): 0.3211 ± 0.4776 (n=30)
  (0, 0) → (8, 0) (error_to_pred): 0.3203 ± 0.4788 (n=30)
  (0, 11) → (8, 5) (error_to_pred): 0.3188 ± 0.4750 (n=30)
  (0, 11) → (8, 0) (error_to_pred): 0.3181 ± 0.4795 (n=30)
  (0, 4) → (8, 5) (error_to_pred): 0.3153 ± 0.4762 (n=30)
  (0, 0) → (6, 0) (error_to_pred): 0.3149 ± 0.4671 (n=30)





{'individual_effects': {},
 'pairwise_synergy': {},
 'group_effects': {},
 'causal_chains': {(((0, 8), (8, 5)),
   'error_to_pred'): {'mean_influence': np.float64(0.3139770758071471), 'std_influence': np.float64(0.4776290972799598), 'median_influence': np.float64(0.008366465743970111), 'valid_pairs': 30, 'raw_influences': [1.0410605731495473,
    0.006288949010171902,
    1.1017667660747108,
    0.005520208381914462,
    1.0286425073700398,
    0.0089668298107244,
    0.008595327363224119,
    0.0027632690727929553,
    0.006245707105633919,
    0.007035099419349268,
    0.8264772896250354,
    1.3361900693804658,
    0.0034144596339652613,
    0.003734178427415919,
    0.01466724249516617,
    0.009849652792046698,
    0.007045786568529879,
    0.017301256923706747,
    1.0411246184855287,
    0.010624240582884895,
    0.006137104100608376,
    1.214927864523783,
    0.005870091090076657,
    0.008137604124716103,
    0.8629840226119876,
    0.0024261807697640407,
    0.00714406831124

In [None]:
def extract_ud_test_sentences(ud_file_path):
    """
    Extracting German test sentences from UD dataset to create a dataset to just test logit distribution after ablation

    Args:
        ud_file_path: Path to UD German file (e.g., 'de_gsd-ud-test.conllu')

    Returns:
        List of (sentence, article, gender) tuples
    """

    # German articles and their genders
    article_genders = {
        'das': 'Neut', 'der': 'Masc', 'die': 'Fem',
        'den': 'Masc', 'dem': 'Masc', 'des': 'Masc'
    }

    test_sentences = []
    current_sentence = []
    current_words = []

    try:
        with open(ud_file_path, 'r', encoding='utf-8') as file:
            for line in file:
                line = line.strip()

                # End of sentence
                if not line or line.startswith('#'):
                    if current_sentence and current_words:

                        sentence_text = ' '.join(current_words)

                        # Finding articles in this sentence
                        for token in current_sentence:
                            word_original = token['form']  # Keeping original capitalization
                            word_lower = word_original.lower()
                            upos = token['upos']

                            if upos == 'DET' and word_lower in article_genders:
                                gender = article_genders[word_lower]
                                test_sentences.append((sentence_text, word_original, gender))
                                break  # Only taking the first article per sentence

                    current_sentence = []
                    current_words = []
                    continue

                # Parsing token line
                if '\t' in line:
                    fields = line.split('\t')
                    if len(fields) >= 4 and '-' not in fields[0]:  # Skip multiword tokens
                        token_info = {
                            'form': fields[1],
                            'upos': fields[3]
                        }
                        current_sentence.append(token_info)
                        current_words.append(fields[1])

    except FileNotFoundError:
        print(f"File not found: {ud_file_path}")
        return []

    print(f"Extracted {len(test_sentences)} sentences")
    return test_sentences

test_sentences = extract_ud_test_sentences('UD_German-GSD/de_gsd-ud-test.conllu')


Extracted 753 sentences


In [None]:
test_sentences[:2]

[('Der Hauptgang war in Ordnung , aber alles andere als umwerfend .',
  'Der',
  'Masc'),
 ('Ich habe dort 2007 meinen OWD gemacht und weil mir das Tauchen so gefiel hab ich dort noch in dem selben Jahr den AOWD und den Deep drangehängt .',
  'das',
  'Neut')]

In [None]:
import torch
import torch.nn.functional as F
import numpy as np

def compute_kl_divergence(orig_probs, ablated_probs, epsilon=1e-10):
    """
    Computing KL divergence: KL(orig || ablated)
    """
    ablated_probs_safe = ablated_probs + epsilon #adding epsilon to prevent log 0
    kl_div = torch.sum(orig_probs * torch.log(orig_probs / ablated_probs_safe))
    return kl_div.item()

def compute_js_divergence(orig_probs, ablated_probs, epsilon=1e-10):
    """
    Computing Jensen-Shannon divergence (symmetric version of KL)
    """
    orig_probs_safe = orig_probs + epsilon
    ablated_probs_safe = ablated_probs + epsilon

    # Average distribution
    m = 0.5 * (orig_probs_safe + ablated_probs_safe)

    # JS divergence = 0.5 * KL(P || M) + 0.5 * KL(Q || M)
    kl1 = torch.sum(orig_probs_safe * torch.log(orig_probs_safe / m))
    kl2 = torch.sum(ablated_probs_safe * torch.log(ablated_probs_safe / m))

    js_div = 0.5 * kl1 + 0.5 * kl2
    return js_div.item()

def compute_distribution_distances(orig_logits, ablated_logits):
    """
    Computing the various distance metrics between original and ablated model outputs
    """
    # Convert to probabilities
    orig_probs = F.softmax(orig_logits, dim=0)
    ablated_probs = F.softmax(ablated_logits, dim=0)

    # Distance metrics
    results = {}

    # 1. KL divergence (asymmetric)
    results['kl_orig_to_ablated'] = compute_kl_divergence(orig_probs, ablated_probs)
    results['kl_ablated_to_orig'] = compute_kl_divergence(ablated_probs, orig_probs)

    # 2. Jensen-Shannon divergence (symmetric)
    results['js_divergence'] = compute_js_divergence(orig_probs, ablated_probs)

    # 3. L2 distance on logits (unnormalized scores)
    results['l2_logits'] = torch.norm(orig_logits - ablated_logits, p=2).item()

    # 4. L2 distance on probabilities
    results['l2_probs'] = torch.norm(orig_probs - ablated_probs, p=2).item()

    # 5. Cosine similarity on logits
    cos_sim = F.cosine_similarity(orig_logits.unsqueeze(0), ablated_logits.unsqueeze(0))
    results['cosine_similarity'] = cos_sim.item()
    results['cosine_distance'] = 1 - cos_sim.item()

    return results

def get_masked_logits_full(model, tokenizer, masked_sentence):
    """
    Get full logits for the masked token (entire vocabulary)
    """
    # Tokenizing masked sentence
    inputs = tokenizer(masked_sentence, return_tensors="pt")

    # Finding [MASK] position
    mask_token_id = tokenizer.mask_token_id
    mask_positions = (inputs.input_ids == mask_token_id).nonzero(as_tuple=True)[1]

    if len(mask_positions) == 0:
        raise ValueError(f"No [MASK] token found in: {masked_sentence}")

    mask_pos = mask_positions[0].item()

    # Getting logits
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits[0, mask_pos]  # Full vocabulary logits

    return logits

def create_masked_sentence(tokenizer, sentence, article_position):
    """Replacing article with [MASK] token"""
    tokens = tokenizer.tokenize(sentence)

    words = sentence.split()

    # Simple approach: replace the article word with [MASK] by finding which word contains the article
    tokens_with_positions = tokenizer(sentence, return_tensors="pt", return_offsets_mapping=True)

    if hasattr(tokens_with_positions, 'offset_mapping'):
        offset_map = tokens_with_positions.offset_mapping[0]
        if article_position < len(offset_map):
            start_char, end_char = offset_map[article_position]

            # Replacing the character span with [MASK]
            masked_sentence = sentence[:start_char] + "[MASK]" + sentence[end_char:]
            return masked_sentence

    words = sentence.split()
    article_word = get_article_at_position(tokenizer, sentence, article_position)

    for i, word in enumerate(words):
        if article_word in word:
            words[i] = "[MASK]"
            break

    return " ".join(words)

def find_article_in_sentence(tokenizer, sentence, target_article):
    """Find position of specific article in sentence"""
    tokens = tokenizer(sentence, return_tensors="pt")
    token_ids = tokens.input_ids[0]

    target_id = tokenizer.convert_tokens_to_ids(target_article)

    for i, token_id in enumerate(token_ids):
        if token_id == target_id:
            return i

    return None

def get_article_at_position(tokenizer, sentence, position):
    tokens = tokenizer(sentence, return_tensors="pt")
    token_ids = tokens.input_ids[0]

    if position < len(token_ids):
        token_id = token_ids[position]
        return tokenizer.decode([token_id])

    return None

def masked_article_test_with_distances(model, tokenizer, test_sentences):
    """
    Test using KL-divergence and L2 distances instead of accuracy

    Args:
        test_sentences: List of (sentence, correct_article, gender) tuples

    Returns:
        List of results with distance metrics
    """

    print("MASKED ARTICLE TEST - Distance-based Analysis")
    print("=" * 60)

    results = []
    total_tests = 0

    for i, (sentence, correct_article, gender) in enumerate(test_sentences):

        article_pos = find_article_in_sentence(tokenizer, sentence, correct_article)

        if article_pos is None:
            print(f"Sentence {i+1}: Could not find article '{correct_article}'")
            continue

        masked_sentence = create_masked_sentence(tokenizer, sentence, article_pos)

        try:
            # Getting original model logits
            orig_logits = get_masked_logits_full(model, tokenizer, masked_sentence)

            # Storing baseline result
            results.append({
                'sentence_id': i + 1,
                'gender': gender,
                'original_sentence': sentence,
                'masked_sentence': masked_sentence,
                'correct_article': correct_article,
                'orig_logits': orig_logits.clone(),
                'is_baseline': True
            })

            total_tests += 1

        except Exception as e:
            print(f"Error processing sentence {i+1}: {e}")
            continue

    print(f"Successfully processed {total_tests} sentences for baseline")
    return results

def simple_ablation_test_distances(model, tokenizer, test_sentences, heads_to_ablate):
    """
    Ablation test using KL-divergence and L2 distances

    Args:
        model: mBERT model
        tokenizer: mBERT tokenizer
        test_sentences: List of (sentence, correct_article, gender) tuples
        heads_to_ablate: List of (layer, head) tuples to ablate

    Returns:
        Dict with baseline and ablated results plus distance metrics
    """

    print(f"ABLATION TEST WITH DISTANCE METRICS")
    print(f"Ablating heads: {heads_to_ablate}")
    print("=" * 60)

    print("Getting baseline logits (no ablation)...")
    baseline_results = masked_article_test_with_distances(model, tokenizer, test_sentences)

    if not baseline_results:
        print("No baseline results obtained!")
        return None

    # Applying ablation
    print(f"\nApplying ablation to heads: {heads_to_ablate}")
    original_weights = ablate_heads(model, heads_to_ablate)

    try:
        # Getting ablated logits
        print("Getting ablated logits...")
        distance_results = []

        for baseline_result in baseline_results:
            sentence_id = baseline_result['sentence_id']
            masked_sentence = baseline_result['masked_sentence']
            orig_logits = baseline_result['orig_logits']

            try:
                ablated_logits = get_masked_logits_full(model, tokenizer, masked_sentence)

                distances = compute_distribution_distances(orig_logits, ablated_logits)

                # Storing result
                result = {
                    **baseline_result,  # Include all baseline info
                    'ablated_logits': ablated_logits.clone(),
                    **distances,  # Include all distance metrics
                    'is_baseline': False
                }
                distance_results.append(result)

            except Exception as e:
                print(f"Error processing ablated sentence {sentence_id}: {e}")
                continue

        if distance_results:
            analyze_distance_results(distance_results, heads_to_ablate)

        results = {
            'heads_ablated': heads_to_ablate,
            'baseline_results': baseline_results,
            'distance_results': distance_results,
            'n_sentences': len(distance_results)
        }

    finally:
        restore_heads(model, original_weights)
        print(f"\nRestored original model weights")

    return results

def analyze_distance_results(results, heads_ablated):
    """
    Analyze and summarize the distance-based ablation results
    """
    print(f"\nDISTANCE-BASED ABLATION ANALYSIS")
    print("=" * 50)

    # Extract distance metrics
    kl_orig_to_ablated = [r['kl_orig_to_ablated'] for r in results]
    kl_ablated_to_orig = [r['kl_ablated_to_orig'] for r in results]
    js_divergences = [r['js_divergence'] for r in results]
    l2_logits = [r['l2_logits'] for r in results]
    l2_probs = [r['l2_probs'] for r in results]
    cosine_distances = [r['cosine_distance'] for r in results]

    metrics = {
        'KL(orig||ablated)': kl_orig_to_ablated,
        'KL(ablated||orig)': kl_ablated_to_orig,
        'JS Divergence': js_divergences,
        'L2 Logits': l2_logits,
        'L2 Probs': l2_probs,
        'Cosine Distance': cosine_distances
    }

    print(f"Ablated heads: {heads_ablated}")
    print(f"Number of sentences: {len(results)}")
    print()

    for metric_name, values in metrics.items():
        mean_val = np.mean(values)
        std_val = np.std(values)
        max_val = np.max(values)
        min_val = np.min(values)

        print(f"{metric_name}:")
        print(f"  Mean: {mean_val:.4f} ± {std_val:.4f}")
        print(f"  Range: [{min_val:.4f}, {max_val:.4f}]")

    mean_js = np.mean(js_divergences)
    mean_l2_logits = np.mean(l2_logits)

    print(f"\nOVERALL IMPACT ASSESSMENT:")

    if mean_js > 0.5:
        js_impact = "MASSIVE"
    elif mean_js > 0.2:
        js_impact = "STRONG"
    elif mean_js > 0.05:
        js_impact = "MODERATE"
    else:
        js_impact = "MINIMAL"

    if mean_l2_logits > 50:
        l2_impact = "MASSIVE"
    elif mean_l2_logits > 20:
        l2_impact = "STRONG"
    elif mean_l2_logits > 5:
        l2_impact = "MODERATE"
    else:
        l2_impact = "MINIMAL"

    print(f"  JS Divergence Impact: {js_impact} (mean = {mean_js:.4f})")
    print(f"  L2 Logits Impact: {l2_impact} (mean = {mean_l2_logits:.4f})")

    print(f"\nMOST AFFECTED SENTENCES (by JS Divergence):")
    sorted_results = sorted(results, key=lambda x: x['js_divergence'], reverse=True)[:3]

    for i, result in enumerate(sorted_results, 1):
        print(f"{i}. Sentence {result['sentence_id']} (JS = {result['js_divergence']:.4f}):")
        print(f"   Original: {result['original_sentence']}")
        print(f"   Gender: {result['gender']}, Target: {result['correct_article']}")

def ablate_heads(model, heads_to_ablate):
    """Zero out heads with validation and confirmation"""

    original_weights = {}

    for layer_idx, head_idx in heads_to_ablate:
        try:
            if hasattr(model, 'bert'):
                attention_layer = model.bert.encoder.layer[layer_idx].attention
            else:
                attention_layer = model.encoder.layer[layer_idx].attention

            self_attn = attention_layer.self
            output_attn = attention_layer.output

            head_dim = self_attn.attention_head_size
            start_idx = head_idx * head_dim
            end_idx = (head_idx + 1) * head_dim

            if end_idx > self_attn.query.weight.data.shape[1]:
                print(f"ERROR: Head ({layer_idx}, {head_idx}) out of range!")
                continue

            key = (layer_idx, head_idx)
            original_weights[key] = {
                'query': self_attn.query.weight.data[:, start_idx:end_idx].clone(),
                'key': self_attn.key.weight.data[:, start_idx:end_idx].clone(),
                'value': self_attn.value.weight.data[:, start_idx:end_idx].clone(),
                'output_dense': output_attn.dense.weight.data[:, start_idx:end_idx].clone()
            }

            # Verify we have non-zero weights before ablation
            orig_norm = torch.norm(original_weights[key]['query']).item()
            print(f"   Head {layer_idx}.{head_idx} original norm: {orig_norm:.4f}")

            # Zeroing out this head's weights
            self_attn.query.weight.data[:, start_idx:end_idx] = 0
            self_attn.key.weight.data[:, start_idx:end_idx] = 0
            self_attn.value.weight.data[:, start_idx:end_idx] = 0
            output_attn.dense.weight.data[:, start_idx:end_idx] = 0

            new_norm = torch.norm(self_attn.query.weight.data[:, start_idx:end_idx]).item()
            print(f"   Head {layer_idx}.{head_idx} after ablation...: {new_norm:.4f}")

            if new_norm > 1e-6:
                print(f"   WARNING: Ablation may have failed for head ({layer_idx}, {head_idx})")

        except Exception as e:
            print(f"ERROR ablating head ({layer_idx}, {head_idx}): {e}")
            continue

    return original_weights
def restore_heads(model, original_weights):
    """To restore original attention head weights (query/key/value/output)"""

    for (layer_idx, head_idx), weights in original_weights.items():
        if hasattr(model, 'bert'):
            attention_layer = model.bert.encoder.layer[layer_idx].attention
        else:
            attention_layer = model.encoder.layer[layer_idx].attention

        self_attn = attention_layer.self
        output_attn = attention_layer.output

        head_dim = self_attn.attention_head_size
        start_idx = head_idx * head_dim
        end_idx = (head_idx + 1) * head_dim

        self_attn.query.weight.data[:, start_idx:end_idx] = weights['query']
        self_attn.key.weight.data[:, start_idx:end_idx] = weights['key']
        self_attn.value.weight.data[:, start_idx:end_idx] = weights['value']

        output_attn.dense.weight.data[:, start_idx:end_idx] = weights['output_dense']


In [None]:
## For Gender-Encoding heads

model = AutoModelForMaskedLM.from_pretrained("bert-base-multilingual-cased")

heads_to_test = [(8,5),(6,0),(6,4),(11,7),(8,0),(6,9),(6,1),(6,6),(6,10),(10,7)]

results = simple_ablation_test_distances(model, tokenizer, test_sentences, heads_to_test)


Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ABLATION TEST WITH DISTANCE METRICS
Ablating heads: [(8, 5), (6, 0), (6, 4), (11, 7), (8, 0), (6, 9), (6, 1), (6, 6), (6, 10), (10, 7)]
Getting baseline logits (no ablation)...
MASKED ARTICLE TEST - Distance-based Analysis
Sentence 609: Could not find article 'Einem'
Successfully processed 752 sentences for baseline

Applying ablation to heads: [(8, 5), (6, 0), (6, 4), (11, 7), (8, 0), (6, 9), (6, 1), (6, 6), (6, 10), (10, 7)]
   Head 8.5 original norm: 9.9065
   Head 8.5 after ablation...: 0.0000
   Head 6.0 original norm: 10.5033
   Head 6.0 after ablation...: 0.0000
   Head 6.4 original norm: 10.5122
   Head 6.4 after ablation...: 0.0000
   Head 11.7 original norm: 10.1050
   Head 11.7 after ablation...: 0.0000
   Head 8.0 original norm: 9.8228
   Head 8.0 after ablation...: 0.0000
   Head 6.9 original norm: 10.5352
   Head 6.9 after ablation...: 0.0000
   Head 6.1 original norm: 10.5665
   Head 6.1 after ablation...: 0.0000
   Head 6.6 original norm: 10.7673
   Head 6.6 after ablat

In [None]:
# For error-sensitive heads

model = AutoModelForMaskedLM.from_pretrained("bert-base-multilingual-cased")

heads_to_test = [(0,8),(0,0),(0,3),(0,4),(0,11),(0,1),(0,7),(1,0),(0,2),(4,4)]

results = simple_ablation_test_distances(model, tokenizer, test_sentences, heads_to_test)


Some weights of the model checkpoint at bert-base-multilingual-cased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


ABLATION TEST WITH DISTANCE METRICS
Ablating heads: [(0, 8), (0, 0), (0, 3), (0, 4), (0, 11), (0, 1), (0, 7), (1, 0), (0, 2), (4, 4)]
Getting baseline logits (no ablation)...
MASKED ARTICLE TEST - Distance-based Analysis
Sentence 609: Could not find article 'Einem'
Successfully processed 752 sentences for baseline

Applying ablation to heads: [(0, 8), (0, 0), (0, 3), (0, 4), (0, 11), (0, 1), (0, 7), (1, 0), (0, 2), (4, 4)]
   Head 0.8 original norm: 9.2777
   Head 0.8 after ablation...: 0.0000
   Head 0.0 original norm: 9.1679
   Head 0.0 after ablation...: 0.0000
   Head 0.3 original norm: 8.9877
   Head 0.3 after ablation...: 0.0000
   Head 0.4 original norm: 9.0872
   Head 0.4 after ablation...: 0.0000
   Head 0.11 original norm: 9.0908
   Head 0.11 after ablation...: 0.0000
   Head 0.1 original norm: 9.1184
   Head 0.1 after ablation...: 0.0000
   Head 0.7 original norm: 9.1842
   Head 0.7 after ablation...: 0.0000
   Head 1.0 original norm: 9.4222
   Head 1.0 after ablation...: 0.