In [20]:
import pandas as pd
import numpy as np
from collections import defaultdict, Counter
import conllu
import random
from typing import List, Tuple, Dict

def load_ud_german(filepath: str):
    """Load UD German dataset"""
    with open(filepath, 'r', encoding='utf-8') as f:
        sentences = conllu.parse(f.read())
    return sentences

def extract_article_noun_pairs(sentences):
    """Extract all article-noun pairs with their contexts"""
    pairs = []
    debug_info = {'total_sentences': 0, 'articles_found': 0, 'with_gender': 0, 'sample_tokens': []}
    
    for sent_idx, sent in enumerate(sentences):
        debug_info['total_sentences'] += 1
        tokens = [token for token in sent if isinstance(token['id'], int)]
        
        if sent_idx < 3:
            debug_info['sample_tokens'].extend([
                {
                    'form': token['form'],
                    'upos': token['upos'],
                    'lemma': token.get('lemma'),
                    'feats': str(token.get('feats')) if token.get('feats') else None
                } for token in tokens[:10]  # First 10 tokens
            ])
        
        for i, token in enumerate(tokens):
            # Finding definite articles - be more flexible
            if token['upos'] == 'DET':
                debug_info['articles_found'] += 1
                
                # Debug print for first few articles
                if debug_info['articles_found'] <= 5:
                    print(f"Article found: form='{token['form']}', lemma='{token.get('lemma')}', feats='{token.get('feats')}'")
                
                # More flexible article detection
                article_forms = ['der', 'die', 'das', 'dem', 'den', 'des']
                if (token['form'].lower() in article_forms or 
                    (token.get('lemma') and token['lemma'].lower() in ['der', 'die', 'das'])):
                    
                    # Finding the noun this article modifies
                    noun_idx = None
                    
                    # Check dependency relation
                    if token.get('head') and isinstance(token['head'], int):
                        potential_noun_idx = token['head'] - 1  # Convert to 0-indexed
                        if (0 <= potential_noun_idx < len(tokens) and 
                            tokens[potential_noun_idx]['upos'] == 'NOUN'):
                            noun_idx = potential_noun_idx
                    
                    # Checking immediate next token
                    if noun_idx is None and i + 1 < len(tokens):
                        if tokens[i + 1]['upos'] == 'NOUN':
                            noun_idx = i + 1
                    
                    # Looking ahead a few tokens for noun
                    if noun_idx is None:
                        for j in range(i + 1, min(i + 4, len(tokens))):
                            if tokens[j]['upos'] == 'NOUN':
                                noun_idx = j
                                break
                    
                    if noun_idx is not None and noun_idx < len(tokens):
                        noun = tokens[noun_idx]
                        
                        # Extracts gender 
                        gender = None
                        if noun.get('feats'):
                            gender = extract_gender(noun['feats'])
                        
                        if not gender and token.get('feats'):
                            gender = extract_gender(token['feats'])
                        
                        case = None
                        if token.get('feats'):
                            case = extract_case(token['feats'])
                        
                        if gender:  # Only add if we found gender
                            debug_info['with_gender'] += 1
                            pairs.append({
                                'sentence': ' '.join([t['form'] for t in tokens]),
                                'article': token['form'].lower(),
                                'article_lemma': token.get('lemma', '').lower(),
                                'article_pos': i,
                                'noun': noun['form'],
                                'noun_lemma': noun.get('lemma', ''),
                                'noun_pos': noun_idx,
                                'gender': gender,
                                'case': case,
                                'sentence_length': len(tokens),
                                'distance': noun_idx - i
                            })
    
    print(f"Debug info: {debug_info['total_sentences']} sentences, {debug_info['articles_found']} articles found, {debug_info['with_gender']} with gender")
    
    if debug_info['sample_tokens']:
        print("\nSample tokens from first sentences:")
        for token in debug_info['sample_tokens'][:15]:
            print(f"  {token}")
    
    return pd.DataFrame(pairs)

def extract_gender(feats):
    """Extract gender from morphological features"""
    if not feats:
        return None
    
    if isinstance(feats, dict):
        gender = feats.get('Gender')
    else:
        feat_str = str(feats)
        if 'Gender=Masc' in feat_str or "'Gender': 'Masc'" in feat_str:
            gender = 'Masc'
        elif 'Gender=Fem' in feat_str or "'Gender': 'Fem'" in feat_str:
            gender = 'Fem'
        elif 'Gender=Neut' in feat_str or "'Gender': 'Neut'" in feat_str:
            gender = 'Neut'
        else:
            return None
    
    # Convert to single letter format
    if gender == 'Masc':
        return 'm'
    elif gender == 'Fem':
        return 'f'
    elif gender == 'Neut':
        return 'n'
    return None

def extract_case(feats):
    """Extract case from morphological features"""
    if not feats:
        return None
    
    if isinstance(feats, dict):
        case = feats.get('Case')
    else:
        # Parse string format
        feat_str = str(feats)
        if 'Case=Nom' in feat_str or "'Case': 'Nom'" in feat_str:
            case = 'Nom'
        elif 'Case=Acc' in feat_str or "'Case': 'Acc'" in feat_str:
            case = 'Acc'
        elif 'Case=Dat' in feat_str or "'Case': 'Dat'" in feat_str:
            case = 'Dat'
        elif 'Case=Gen' in feat_str or "'Case': 'Gen'" in feat_str:
            case = 'Gen'
        else:
            return None
    
    if case:
        return case.lower()
    return None

def extract_valid_article_gender_pairs(pairs_df):
    """Extract valid examples for each article-gender combination"""
    
    valid_combinations = {
        ('der', 'nom', 'm'), ('der', 'dat', 'f'), ('der', 'gen', 'f'),
        ('die', 'nom', 'f'), ('die', 'acc', 'f'), ('die', 'nom', 'n'), ('die', 'acc', 'n'),
        ('das', 'nom', 'n'), ('das', 'acc', 'n'),
        ('dem', 'dat', 'm'), ('dem', 'dat', 'n'),
        ('den', 'acc', 'm'),
        ('des', 'gen', 'm'), ('des', 'gen', 'n')
    }
    
    stimuli = []
    for article, case, gender in valid_combinations:
        examples = pairs_df[
            (pairs_df['article'] == article) & 
            (pairs_df['case'] == case) & 
            (pairs_df['gender'] == gender) &
            (pairs_df['distance'] <= 2) &  # Article close to noun
            (pairs_df['sentence_length'] <= 20)  # Reasonable length
        ]
        
        # Sample up to N examples per combination
        if len(examples) >= 10:
            sampled = examples.sample(n=min(50, len(examples)), random_state=42)
            stimuli.extend(sampled.to_dict('records'))
    
    return stimuli


def create_ud_based_substitution_stimuli(pairs_df):
    """Create substitution stimuli using UD sentence templates"""
    
    valid_examples = extract_valid_article_gender_pairs(pairs_df)
    
    templates_by_article_case = defaultdict(list)
    
    for example in valid_examples:
        # Creating template by replacing the noun with [NOUN]
        sentence = example['sentence']
        noun = example['noun']
        template = sentence.replace(noun, '[NOUN]', 1)
        
        # Grouping templates by article and case
        key = (example['article'], example['case'])
        templates_by_article_case[key].append({
            'template': template,
            'article': example['article'],
            'case': example['case'],
            'article_pos': example['article_pos'],
            'original_noun': noun,
            'original_gender': example['gender']
        })
    
    noun_sets = {
        'm': {
            'nom': ['Mann', 'Hund', 'Tisch', 'Baum', 'Stuhl'],
            'acc': ['Mann', 'Hund', 'Tisch', 'Baum', 'Stuhl'], 
            'dat': ['Mann', 'Hund', 'Tisch', 'Baum', 'Stuhl'],
            'gen': ['Mannes', 'Hundes', 'Tisches', 'Baumes', 'Stuhles']  # Genitive forms
        },
        'f': {
            'nom': ['Frau', 'Katze', 'Lampe', 'Blume', 'Uhr'],
            'acc': ['Frau', 'Katze', 'Lampe', 'Blume', 'Uhr'],
            'dat': ['Frau', 'Katze', 'Lampe', 'Blume', 'Uhr'],
            'gen': ['Frau', 'Katze', 'Lampe', 'Blume', 'Uhr']  # No change for feminine
        },
        'n': {
            'nom': ['Kind', 'Haus', 'Buch', 'Auto', 'Fenster'],
            'acc': ['Kind', 'Haus', 'Buch', 'Auto', 'Fenster'],
            'dat': ['Kind', 'Haus', 'Buch', 'Auto', 'Fenster'], 
            'gen': ['Kindes', 'Hauses', 'Buches', 'Autos', 'Fensters']  # Genitive forms
        }
    }
    
    stimuli = []
    
    for (article, case), template_list in templates_by_article_case.items():
        sampled_templates = random.sample(template_list, min(10, len(template_list)))
        
        for template_info in sampled_templates:
            template = template_info['template']
            
            # Testing each gender with this template
            for gender in ['m', 'f', 'n']:
                # Skip if this is the original correct combination (for comparison)
                if gender == template_info['original_gender']:
                    continue
                
                if case in noun_sets[gender]:
                    available_nouns = noun_sets[gender][case]
                    
                    for noun in available_nouns[:3]:  # Use first 3 nouns per gender
                        substituted_sentence = template.replace('[NOUN]', noun)
                        
                        stimuli.append({
                            'sentence': substituted_sentence,
                            'template': template,
                            'article': article,
                            'article_pos': template_info['article_pos'],
                            'substituted_noun': noun,
                            'substituted_gender': gender,
                            'case': case,
                            'original_noun': template_info['original_noun'],
                            'original_gender': template_info['original_gender'],
                            'is_grammatical': is_grammatically_correct(article, case, gender),
                            'is_substitution': True
                        })
                
                # Also add the original (grammatical) version for comparison
                original_sentence = template.replace('[NOUN]', template_info['original_noun'])
                stimuli.append({
                    'sentence': original_sentence,
                    'template': template,
                    'article': article,
                    'article_pos': template_info['article_pos'],
                    'substituted_noun': template_info['original_noun'],
                    'substituted_gender': template_info['original_gender'],
                    'case': case,
                    'original_noun': template_info['original_noun'],
                    'original_gender': template_info['original_gender'],
                    'is_grammatical': True,
                    'is_substitution': False
                })
    
    return stimuli

def is_grammatically_correct(article, case, gender):
    """Check if article-case-gender combination is grammatically correct"""
    correct_combinations = {
        ('der', 'nom', 'm'), ('der', 'dat', 'f'), ('der', 'gen', 'f'),
        ('die', 'nom', 'f'), ('die', 'acc', 'f'), 
        ('das', 'nom', 'n'), ('das', 'acc', 'n'),
        ('dem', 'dat', 'm'), ('dem', 'dat', 'n'),
        ('den', 'acc', 'm'),
        ('des', 'gen', 'm'), ('des', 'gen', 'n')
    }
    return (article, case, gender) in correct_combinations

def analyze_substitution_stimuli(stimuli):
    """Analyze the created substitution stimuli"""
    df = pd.DataFrame(stimuli)
    
    print("Substitution Stimuli Analysis:")
    print("="*50)
    print(f"Total stimuli: {len(df)}")
    print(f"Grammatical: {len(df[df['is_grammatical']])}")
    print(f"Ungrammatical: {len(df[~df['is_grammatical']])}")
    print(f"Original sentences: {len(df[~df['is_substitution']])}")
    print(f"Substituted sentences: {len(df[df['is_substitution']])}")
    
    print("\nBy article:")
    print(df['article'].value_counts())
    
    print("\nBy case:")
    print(df['case'].value_counts())
    
    print("\nBy substituted gender:")
    print(df['substituted_gender'].value_counts())
    
    print("\nGrammaticality by article-gender:")
    for article in df['article'].unique():
        article_data = df[df['article'] == article]
        print(f"\n{article}:")
        for gender in ['m', 'f', 'n']:
            gender_data = article_data[article_data['substituted_gender'] == gender]
            if len(gender_data) > 0:
                grammatical_pct = (gender_data['is_grammatical'].sum() / len(gender_data)) * 100
                print(f"  {gender}: {grammatical_pct:.1f}% grammatical ({len(gender_data)} examples)")
    
    return df

# Example usage:
def create_comprehensive_stimuli_dataset(ud_sentences):
    """Create comprehensive dataset combining UD extraction and substitution"""
    
    pairs_df = extract_article_noun_pairs(ud_sentences)

    print("Extracting valid UD examples...")
    valid_examples = extract_valid_article_gender_pairs(pairs_df)
    print(f"Found {len(valid_examples)} valid examples")
    
    print("\nCreating substitution stimuli...")
    substitution_stimuli = create_ud_based_substitution_stimuli(pairs_df)
    print(f"Created {len(substitution_stimuli)} substitution stimuli")
    
    # Analyze
    stimuli_df = analyze_substitution_stimuli(substitution_stimuli)
    
    return {
        'valid_examples': valid_examples,
        'substitution_stimuli': substitution_stimuli,
        'analysis_df': stimuli_df
    }
 

'\nExample stimuli created:\n\nOriginal UD sentence: "Der Mann geht nach Hause."\nTemplate: "Der [NOUN] geht nach Hause."\n\nGenerated stimuli:\n1. "Der Mann geht nach Hause." (original, grammatical)\n2. "Der Frau geht nach Hause." (substituted, ungrammatical - should be "Die Frau")  \n3. "Der Kind geht nach Hause." (substituted, ungrammatical - should be "Das Kind")\n\nThis tests whether mBERT\'s representation of "Der" encodes that it should agree with masculine nouns.\n'

In [86]:
sentences = load_ud_german("UD_German-GSD/de_gsd-ud-train.conllu")
    
    # Creating comprehensive dataset for stimulus
dataset = create_comprehensive_stimuli_dataset(sentences)
    

Article found: form='der', lemma='der', feats='{'Case': 'Gen', 'Definite': 'Def', 'Number': 'Plur', 'PronType': 'Art'}'
Article found: form='Die', lemma='der', feats='{'Case': 'Nom', 'Definite': 'Def', 'Number': 'Plur', 'PronType': 'Art'}'
Article found: form='dem', lemma='der', feats='{'Case': 'Dat', 'Definite': 'Def', 'Gender': 'Masc', 'Number': 'Sing', 'PronType': 'Art'}'
Article found: form='der', lemma='der', feats='{'Case': 'Dat', 'Definite': 'Def', 'Gender': 'Fem', 'Number': 'Sing', 'PronType': 'Art'}'
Article found: form='der', lemma='der', feats='{'Case': 'Dat', 'Definite': 'Def', 'Gender': 'Fem', 'Number': 'Sing', 'PronType': 'Art'}'
Debug info: 13814 sentences, 37305 articles found, 21734 with gender

Sample tokens from first sentences:
  {'form': 'Sehr', 'upos': 'ADV', 'lemma': 'sehr', 'feats': None}
  {'form': 'gute', 'upos': 'ADJ', 'lemma': 'gut', 'feats': "{'Case': 'Nom', 'Degree': 'Pos', 'Gender': 'Fem', 'Number': 'Sing'}"}
  {'form': 'Beratung', 'upos': 'NOUN', 'lemma'

In [87]:
dataset['substitution_stimuli'][:10]

[{'sentence': 'Der Sage nach wurden die Nelken 1270 von dem Mann des französischen Königs Ludwig IX .',
  'template': 'Der Sage nach wurden die Nelken 1270 von dem [NOUN] des französischen Königs Ludwig IX .',
  'article': 'dem',
  'article_pos': 8,
  'substituted_noun': 'Mann',
  'substituted_gender': 'm',
  'case': 'dat',
  'original_noun': 'Heer',
  'original_gender': 'n',
  'is_grammatical': True,
  'is_substitution': True},
 {'sentence': 'Der Sage nach wurden die Nelken 1270 von dem Hund des französischen Königs Ludwig IX .',
  'template': 'Der Sage nach wurden die Nelken 1270 von dem [NOUN] des französischen Königs Ludwig IX .',
  'article': 'dem',
  'article_pos': 8,
  'substituted_noun': 'Hund',
  'substituted_gender': 'm',
  'case': 'dat',
  'original_noun': 'Heer',
  'original_gender': 'n',
  'is_grammatical': True,
  'is_substitution': True},
 {'sentence': 'Der Sage nach wurden die Nelken 1270 von dem Tisch des französischen Königs Ludwig IX .',
  'template': 'Der Sage nach 

In [80]:
import torch
import pandas as pd
import numpy as np
from transformers import AutoModel, AutoTokenizer

def extract_article_representations(model, tokenizer, stimuli_list):
    """Extracts article representations from UD-based substitution stimuli"""
    data = []
    
    for stimulus in stimuli_list:
        # Extracts information from stimulus dictionary
        sentence = stimulus['sentence']
        article = stimulus['article']
        article_pos = stimulus['article_pos']  # Already provided in your data
        substituted_noun = stimulus['substituted_noun']
        substituted_gender = stimulus['substituted_gender']
        case = stimulus['case']
        is_grammatical = stimulus['is_grammatical']
        is_substitution = stimulus['is_substitution']
        
        # Tokenized sentence
        inputs = tokenizer(sentence, return_tensors="pt")
        tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
        
        # Finding actual article position in tokenized sequence
        try:
            art_token_pos = None
            
            search_range = range(max(0, article_pos - 2), min(len(tokens), article_pos + 3))
            for pos in search_range:
                if tokens[pos].lower().replace('##', '') == article.lower():
                    art_token_pos = pos
                    break
            
            # If not found, search the entire sequence
            if art_token_pos is None:
                for pos, token in enumerate(tokens):
                    if token.lower().replace('##', '') == article.lower():
                        art_token_pos = pos
                        break
            
            if art_token_pos is None:
                print(f"Warning: Could not find article '{article}' in tokens: {tokens}")
                continue
                
        except Exception as e:
            print(f"Error processing sentence: {sentence}")
            print(f"Error: {e}")
            continue
        
        # Extract hidden states
        with torch.no_grad():
            outputs = model(**inputs, output_hidden_states=True)

        
        for transformer_layer in range(model.config.num_hidden_layers):  # 0 to 11 for mBERT
            # Getting the corresponding hidden state (add 1 to skip embeddings)
            hidden_state_idx = transformer_layer + 1
            layer_repr = outputs.hidden_states[hidden_state_idx][0].numpy()
            
            # Processing each attention head
            for head in range(model.config.num_attention_heads):
                head_dim = layer_repr.shape[-1] // model.config.num_attention_heads
                start = head * head_dim
                end = (head + 1) * head_dim
                head_repr = layer_repr[art_token_pos, start:end]
                
                # Finds noun position (should be close to article)
                noun_token_pos = None
                noun_token = None
                
                # Looks for noun around article position
                for offset in [1, 2, -1, 0]:  # Check positions relative to article
                    pos = art_token_pos + offset
                    if 0 <= pos < len(tokens):
                        token = tokens[pos].replace('##', '')
                        if token.lower() == substituted_noun.lower():
                            noun_token_pos = pos
                            noun_token = tokens[pos]
                            break
                
                data.append({
                    "sentence": sentence,
                    "article": article,
                    "article_position": art_token_pos,
                    "noun": substituted_noun,
                    "noun_token": noun_token,
                    "noun_position": noun_token_pos,
                    "noun_gender": substituted_gender,
                    "case": case,
                    "is_grammatical": is_grammatical,
                    "is_substitution": is_substitution,
                    "original_gender": stimulus.get('original_gender'),
                    "layer": transformer_layer,  # ✅ FIXED: Now correctly 0-11
                    "head": head,
                    "representation": head_repr,
                    "template": stimulus.get('template', '')
                })
    
    return pd.DataFrame(data)

def analyze_article_sensitivity(df):
    """Analyzes article sensitivity to noun gender from representation data"""
    from sklearn.linear_model import LogisticRegression
    from sklearn.model_selection import cross_val_score
    from sklearn.metrics import silhouette_score
    from sklearn.cluster import KMeans
    from sklearn.decomposition import PCA
    from sklearn.preprocessing import StandardScaler  
    
    results = []
    
    for (layer, head, article, case), group in df.groupby(["layer", "head", "article", "case"]):
        if len(group) < 20:  
            continue
        
        # Prepare data
        X = np.stack(group["representation"])
        y_gender = group["noun_gender"]
        y_grammatical = group["is_grammatical"]
        
        scaler = StandardScaler()
        X_scaled = scaler.fit_transform(X)
        
        # Gender prediction accuracy
        try:
            gender_scores = cross_val_score(
                LogisticRegression(max_iter=2000, solver='liblinear', random_state=42), 
                X_scaled, y_gender, cv=min(3, len(set(y_gender)))  # ✅ FIXED: CV can't exceed classes
            )
            gender_accuracy = np.mean(gender_scores)
        except Exception as e:
            print(f"Gender classification error for layer {layer}, head {head}: {e}")
            gender_accuracy = 0.0
        
        # Grammaticality prediction accuracy
        try:
            grammar_scores = cross_val_score(
                LogisticRegression(max_iter=2000, solver='liblinear', random_state=42), 
                X_scaled, y_grammatical, cv=min(3, len(set(y_grammatical)))  # ✅ FIXED: CV can't exceed classes
            )
            grammar_accuracy = np.mean(grammar_scores)
        except Exception as e:
            print(f"Grammar classification error for layer {layer}, head {head}: {e}")
            grammar_accuracy = 0.0
        
        # Clustering analysis
        try:
            if len(set(y_gender)) >= 2:  # Need at least 2 genders
                n_clusters = min(len(set(y_gender)), len(X_scaled))  # ✅ FIXED: Can't have more clusters than samples
                kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
                cluster_labels = kmeans.fit_predict(X_scaled)
                silhouette = silhouette_score(X_scaled, cluster_labels)
            else:
                silhouette = 0.0
        except Exception as e:
            print(f"Clustering error for layer {layer}, head {head}: {e}")
            silhouette = 0.0
        
        # Variance analysis
        try:
            total_var = np.trace(np.cov(X_scaled.T))
            gender_vars = []
            for gender in set(y_gender):
                gender_data = X_scaled[y_gender == gender]
                if len(gender_data) > 1:
                    gender_vars.append(np.trace(np.cov(gender_data.T)))
            
            if gender_vars:
                within_var = np.mean(gender_vars)
                variance_ratio = total_var / within_var if within_var > 0 else 0
            else:
                variance_ratio = 0
        except Exception as e:
            print(f"Variance analysis error for layer {layer}, head {head}: {e}")
            variance_ratio = 0
        
        results.append({
            "layer": layer,
            "head": head,
            "article": article,
            "case": case,
            "n_examples": len(group),
            "gender_accuracy": gender_accuracy,
            "grammar_accuracy": grammar_accuracy,
            "silhouette_score": silhouette,
            "variance_ratio": variance_ratio,
            "agreement_score": (gender_accuracy + grammar_accuracy + silhouette) / 3
        })
    
    return pd.DataFrame(results).sort_values("agreement_score", ascending=False)


# Example usage:
def run_full_analysis():
    """Running complete analysis pipeline"""
    model = AutoModel.from_pretrained('bert-base-multilingual-cased', output_attentions=True)
    tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

    print("Extracting representations...")
    repr_df = extract_article_representations(model, tokenizer, dataset['substitution_stimuli'])
    print(f"Extracted representations for {len(repr_df)} examples")
    
    print("\nAnalyzing gender sensitivity...")
    analysis_df = analyze_article_sensitivity(repr_df)
    print(f"Analyzed {len(analysis_df)} layer-head combinations")
    
    print("\nTop 10 gender-sensitive heads:")
    print(analysis_df.head(10)[['layer', 'head', 'article', 'case', 'gender_accuracy', 'agreement_score']])
    
    return repr_df, analysis_df


In [81]:
df1, df2 = run_full_analysis()

Extracting representations...
Extracted representations for 115200 examples

Analyzing gender sensitivity...
Analyzed 1440 layer-head combinations

Top 10 gender-sensitive heads:
     layer  head article case  gender_accuracy  agreement_score
186      1     6     der  nom         0.825261         0.683244
546      4     6     der  nom         0.923552         0.666211
181      1     6     das  nom         0.838557         0.662879
311      2     7     das  nom         0.887464         0.653940
286      2     4     der  nom         0.788224         0.653716
211      1     9     das  nom         0.825261         0.653599
510      4     3     das  acc         0.912156         0.653019
111      0    11     das  nom         0.750712         0.652995
246      2     0     der  nom         0.824786         0.652186
176      1     5     der  nom         0.787749         0.652101


In [82]:
import pandas as pd
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import cross_val_score
from collections import defaultdict


def probe_fixed_article_case_context(repr_df, min_examples=20):
    """
    H1 (Article Gender): Representations constant → Low accuracy (~33%), only morphological gender is attended to
    H2 (Noun Gender): Representations vary by noun → High accuracy (>70%), context determines the gender of the article, specifically the noun
    
    For each (layer, head, article, case) context, probe whether representations
    encode article's gender (H1) or noun's gender (H2).
    
    """
    results = []
    
    print("Probing fixed (article, case) contexts...")
    print("H1: Article's morphological gender encoding:: Low accuracy (~33%)")  
    print("H2: Noun affects gender encoding:: High accuracy (>70%)")
    
    # Group by (layer, head, article, case) - the fixed context
    for (layer, head, article, case), group in repr_df.groupby(['layer', 'head', 'article', 'case']):
        
        if len(group) < min_examples:
            continue
            
        gender_counts = group['noun_gender'].value_counts()
        if len(gender_counts) < 2:  # At least 2 different genders are needed from the stimuli dataset, with varying noun genders
            continue
            
       
        X = np.stack(group['representation'])  # Extracting article representations
        y = group['noun_gender']               # Noun genders (target)
        
        # Training a simple Logistic Regression classifier to predict noun gender from article representation
        try:
            from sklearn.preprocessing import StandardScaler
            
            # Scaling the representations for convergence
            scaler = StandardScaler()
            X_scaled = scaler.fit_transform(X)
            
            classifier = LogisticRegression(
                max_iter=2000,          
                solver='liblinear',     
                random_state=42         
            )
            
            scores = cross_val_score(
                classifier, 
                X_scaled,  # Using scaled data
                y, 
                cv=min(5, len(set(y))),  
                scoring='accuracy'
            )
            accuracy = np.mean(scores)
            
            # Accuracy determines hypothesis 
            if accuracy > 0.9:
                hypothesis = "H2 (Noun Gender)"
                evidence_strength = "Very strong"
            if accuracy > 0.7:
                hypothesis = "H2 (Noun Gender)"
                evidence_strength = "Strong"
            elif accuracy > 0.5:
                hypothesis = "H2 (Noun Gender)" 
                evidence_strength = "Moderate"
            elif accuracy < 0.4:
                hypothesis = "H1 (Article Gender)"
                evidence_strength = "Strong" if accuracy < 0.35 else "Moderate"
            else:
                hypothesis = "Unclear"
                evidence_strength = "Weak"
            
            results.append({
                'layer': layer,
                'head': head,
                'article': article,
                'case': case,
                'n_examples': len(group),
                'n_genders': len(gender_counts),
                'gender_distribution': dict(gender_counts),
                'accuracy': accuracy,
                'hypothesis': hypothesis,
                'evidence_strength': evidence_strength,
                'grammatical_count': len(group[group['is_grammatical'] == True]),
                'ungrammatical_count': len(group[group['is_grammatical'] == False])
            })
            
        except Exception as e:
            print(f"Error processing {layer}, {head}, {article}, {case}: {e}")
            continue
    
    return pd.DataFrame(results)

def analyze_hypothesis_results(results_df):
    """Analyze the results of the hypothesis testing"""
    
    print("\nHypothesis Testing Results:")
    print("="*80)
    
    hypothesis_counts = results_df['hypothesis'].value_counts()
    print(f"\nOverall Distribution:")
    for hyp, count in hypothesis_counts.items():
        pct = (count / len(results_df)) * 100
        print(f"  {hyp}: {count} contexts ({pct:.1f}%)")

    strong_evidence = results_df[results_df['evidence_strength'] == 'Strong']
    print(f"\nStrong Evidence Cases ({len(strong_evidence)} contexts):")
    
    strong_h1 = strong_evidence[strong_evidence['hypothesis'] == 'H1 (Article Gender)']
    strong_h2 = strong_evidence[strong_evidence['hypothesis'] == 'H2 (Noun Gender)']
    
    print(f"  Strong H1 (Article Gender) evidence: {len(strong_h1)} contexts")
    print(f"  Strong H2 (Noun Gender) evidence: {len(strong_h2)} contexts")
    
    # Grouping by layer, hypothesis combination for layer-wise analysis
    print(f"\nHypothesis by Layer:")
    layer_analysis = results_df.groupby(['layer', 'hypothesis']).size().unstack(fill_value=0)
    print(layer_analysis)
    
    print(f"\nTop 10 Noun Gender Encoding Contexts (of H2 context):")
    top_h2 = results_df[results_df['hypothesis'] == 'H2 (Noun Gender)'].sort_values('accuracy', ascending=False).head(10)
    
    print("Layer Head Article Case  Accuracy  N_Examples  Evidence")
    for _, row in top_h2.iterrows():
        print(f"{row['layer']:5d} {row['head']:4d} {row['article']:7s} {row['case']:4s}  "
              f"{row['accuracy']:8.3f}  {row['n_examples']:10d}  {row['evidence_strength']}")
    
    return {
        'hypothesis_counts': hypothesis_counts,
        'strong_h1': strong_h1,
        'strong_h2': strong_h2,
        'layer_analysis': layer_analysis
    }

def find_consistent_noun_gender_heads(results_df, min_contexts=3):
    """Grouping for heads that consistently encode noun gender across multiple contexts"""
    
    # Group by (layer, head) and counting for H2 contexts
    head_summary = results_df.groupby(['layer', 'head']).agg({
        'hypothesis': lambda x: (x == 'H2 (Noun Gender)').sum(),  # Count H2
        'accuracy': 'mean',
        'article': 'count'  # Total contexts
    }).rename(columns={'hypothesis': 'h2_count', 'article': 'total_contexts'})
    
    # H2 percentage calculated
    head_summary['h2_percentage'] = (head_summary['h2_count'] / head_summary['total_contexts']) * 100
    
    consistent_heads = head_summary[
        (head_summary['total_contexts'] >= min_contexts) &
        (head_summary['h2_percentage'] >= 90)  # At least 70% H2 contexts
    ].sort_values('h2_percentage', ascending=False)
    
    print(f"\nConsistent Noun Gender Encoding Heads (≥{min_contexts} contexts, ≥70% H2):")
    print("Layer Head  Total_Contexts  H2_Count  H2_Percentage  Avg_Accuracy")
    print("-" * 65)
    
    for (layer, head), row in consistent_heads.iterrows():
        print(f"{layer:5d} {head:4d}  {row['total_contexts']:13.0f}  "
              f"{row['h2_count']:8.0f}  {row['h2_percentage']:12.1f}%  "
              f"{row['accuracy']:12.3f}")
    
    return consistent_heads

def run_fixed_context_probe(repr_df):
    """Running the complete fixed context probing analysis"""
    
    print("Running Fixed Article-Case Context Gender Probe")
    
    results = probe_fixed_article_case_context(repr_df)
    
    if len(results) == 0:
        print("No valid contexts found!")
        return None
    
    print(f"\nAnalyzed {len(results)} (layer, head, article, case) contexts")
    
    # Analyzing results
    analysis = analyze_hypothesis_results(results)
    
    # Finding consistent heads
    consistent_heads = find_consistent_noun_gender_heads(results)
    
    
    return {
        'results': results,
        'analysis': analysis,
        'consistent_heads': consistent_heads
    }

# Usage example:
# probe_results = run_fixed_context_probe(repr_df)

In [83]:
probe_results = run_fixed_context_probe(df1)

Running Fixed Article-Case Context Gender Probe
Probing fixed (article, case) contexts...
H1: Article gender encoding → Low accuracy (~33%)
H2: Noun gender encoding → High accuracy (>70%)

Analyzed 1440 (layer, head, article, case) contexts

Hypothesis Testing Results:

Overall Distribution:
  H2 (Noun Gender): 1304 contexts (90.6%)
  Unclear: 123 contexts (8.5%)
  H1 (Article Gender): 13 contexts (0.9%)

Strong Evidence Cases (613 contexts):
  Strong H1 (Article Gender) evidence: 0 contexts
  Strong H2 (Noun Gender) evidence: 613 contexts

Hypothesis by Layer:
hypothesis  H1 (Article Gender)  H2 (Noun Gender)  Unclear
layer                                                     
0                             1                91       28
1                             2               108       10
2                             4               108        8
3                             3               108        9
4                             0               115        5
5                  

In [84]:
#sorted_data = sorted(probe_results["consistent_heads"], key=lambda x: x["accuracy"], reverse=True)
df_sorted = probe_results["consistent_heads"].sort_values(by='accuracy', ascending=False)


In [85]:
df_sorted[:10]

Unnamed: 0_level_0,Unnamed: 1_level_0,h2_count,accuracy,total_contexts,h2_percentage
layer,head,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1
8,5,9,0.749573,10,90.0
6,0,10,0.748528,10,100.0
6,4,10,0.743732,10,100.0
11,7,10,0.734046,10,100.0
8,0,10,0.732146,10,100.0
6,9,9,0.731102,10,90.0
6,1,10,0.730769,10,100.0
6,6,9,0.72982,10,90.0
6,10,10,0.728965,10,100.0
10,7,9,0.727208,10,90.0
