# Probing experiments

## Set-up

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import transformers
from sklearn.linear_model import LogisticRegression, LogisticRegressionCV
from sklearn.model_selection import cross_validate, StratifiedKFold, StratifiedShuffleSplit
import random
import pickle
import warnings

In [2]:
# Only show transformers log messages at ERROR level or above.
transformers.logging.set_verbosity(transformers.logging.ERROR)

## Utilities

In [3]:
def probing_experiment(df, colname='break_rep_layer12', target_colname='meaning', control=False, mincount=10,
                       *, n_splits=20, test_size=0.20, C=10000, seed=42):
    """Cross-validate a logistic-regression probe predicting ``target_colname`` from ``colname``.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per example. ``df[colname]`` holds one torch tensor per row;
        the tensors are stacked with ``torch.vstack`` (presumably 1-D or
        ``(1, dim)`` each -- TODO confirm against the code that builds the
        pickles). ``df[target_colname]`` holds the label.
    colname : str
        Column containing the representations to probe.
    target_colname : str
        Column containing the labels to predict.
    control : bool
        If True, the labels are shuffled with a fixed seed before probing,
        giving a control task for the same inputs.
    mincount : int
        Labels with fewer than ``mincount`` examples are dropped before
        probing (very rare classes cannot be stratified across splits).
    n_splits, test_size : int, float
        Passed to ``StratifiedShuffleSplit``.
    C : float
        Inverse L2 regularisation strength; the large default makes the probe
        nearly unregularised.
    seed : int
        Seed for the control-label shuffle and for the CV splitter.

    Returns
    -------
    dict
        The ``sklearn.model_selection.cross_validate`` result:
        ``test_score`` holds the macro-F1 per split and ``estimator`` the
        fitted probes.

    Raises
    ------
    ValueError
        If fewer than two labels have at least ``mincount`` examples.
    """
    y_dist = df[target_colname].value_counts()
    keepers = list(y_dist[y_dist >= mincount].index)
    if len(keepers) < 2:
        # A classifier needs at least two classes; fail with a clear message
        # instead of an obscure torch/sklearn error further down.
        raise ValueError(
            f"Need at least 2 labels with >= {mincount} examples in "
            f"{target_colname!r}; found {len(keepers)}.")
    probe_df = df[df[target_colname].isin(keepers)]
    X = torch.vstack(list(probe_df[colname].values)).numpy()
    y = list(probe_df[target_colname].values)
    if control:
        # A private RNG gives the same permutation as random.seed(seed) +
        # random.shuffle(y), without reseeding the global `random` module.
        random.Random(seed).shuffle(y)
    with warnings.catch_warnings():
        # Suppress all warnings (e.g. sklearn convergence warnings) for the
        # duration of the cross-validation run only.
        warnings.simplefilter("ignore")
        mod = LogisticRegression(max_iter=3000, fit_intercept=True, penalty='l2', C=C)
        scores = cross_validate(
            estimator=mod,
            X=X,
            y=y,
            cv=StratifiedShuffleSplit(n_splits=n_splits, test_size=test_size, random_state=seed),
            return_estimator=True,
            scoring='f1_macro')
    return scores

In [4]:
def run_probing_experiment(weights_name, colname='break_rep_layer12', target_colname='meaning', mincount=10):
    """Probe one model/layer and its shuffled-label control; return per-split results.

    Loads the pickled DataFrame ``reps/<weights_name, '/' -> '_'>_df.pickle``.
    Pickle can execute arbitrary code, so only load files you produced yourself.
    It then runs ``probing_experiment`` twice: once on the real labels and once
    as a control with shuffled labels.

    ``colname`` must look like ``'break_rep_layer<N>'``; ``N`` is stored as the
    ``layer`` column. Any other form makes ``int`` raise ``ValueError``.

    The returned DataFrame has one row per CV split. It holds the
    ``cross_validate`` outputs plus the columns ``model``, ``layer``,
    ``control`` (the control run's test scores, aligned by split position)
    and ``mincount``.
    """
    file_stem = weights_name.replace('/', '_')
    with open(f"reps/{file_stem}_df.pickle", "rb") as pickle_file:
        reps_df = pickle.load(pickle_file)

    shared_kwargs = dict(colname=colname, target_colname=target_colname, mincount=mincount)
    probe_scores = probing_experiment(reps_df, control=False, **shared_kwargs)
    control_scores = probing_experiment(reps_df, control=True, **shared_kwargs)

    layer_index = int(colname.replace("break_rep_layer", ""))
    probe_scores.update(
        model=weights_name,
        layer=layer_index,
        control=control_scores['test_score'],
        mincount=mincount,
    )
    return pd.DataFrame(probe_scores)

In [5]:
# Model weights names to probe, grouped by family. The representation pickles
# are read from reps/<name>_df.pickle with '/' replaced by '_'.
all_weights = (
    # BERT
    "bert-base-cased",
    "bert-large-cased",
    # RoBERTa
    "roberta-base",
    "roberta-large",
    # DeBERTa
    "microsoft/deberta-base",
    "microsoft/deberta-large",
    # DeBERTa-v3
    "microsoft/deberta-v3-base",
    "microsoft/deberta-v3-large",
)

In [6]:
def run_all_probing_experiments(target_colname='meaning', weights=None, mincount=10):
    """Run probes for every model/layer combination and summarise them.

    Parameters
    ----------
    target_colname : str
        Label column to probe for (e.g. ``'meaning'`` or ``'construction'``).
    weights : iterable of str, optional
        Model names to probe. Defaults to the notebook-level ``all_weights``,
        looked up at call time so that a redefined ``all_weights`` is honoured.
    mincount : int
        Minimum examples per label, forwarded to ``run_probing_experiment``.

    Returns
    -------
    (pandas.DataFrame, pandas.DataFrame)
        ``results_df`` has one row per model/layer/CV split. ``mu_df`` is
        indexed by ``(model, layer)`` and has the columns ``Probe`` (mean
        test score), ``Control`` (mean control score) and ``Selectivity``
        (mean of per-split probe minus control).
    """
    if weights is None:
        weights = all_weights

    all_results = []
    for model_name in weights:
        print(model_name)
        # Names containing 'large' are probed up to layer 24, others up to 12.
        layers = (1, 6, 12, 18, 24) if 'large' in model_name else (1, 6, 12)
        for layer in layers:
            all_results.append(run_probing_experiment(
                model_name,
                colname=f'break_rep_layer{layer}',
                target_colname=target_colname,
                mincount=mincount))
    results_df = pd.concat(all_results)

    # Vectorised per-group means. groupby().apply() would also pass the
    # grouping columns to the callable, which pandas 2.2 deprecates.
    mu_df = (
        results_df
        .assign(Selectivity=results_df['test_score'].to_numpy() - results_df['control'].to_numpy())
        .groupby(["model", "layer"])[['test_score', 'control', 'Selectivity']]
        .mean()
        .rename(columns={'test_score': 'Probe', 'control': 'Control'})
    )
    return results_df, mu_df

## Meaning probes

In [7]:
# Probe every model/layer for the 'meaning' labels: per-split results plus per-(model, layer) means.
meaning_results_df, meaning_mu_df = run_all_probing_experiments('meaning')

bert-base-cased
bert-large-cased
roberta-base
roberta-large
microsoft/deberta-base
microsoft/deberta-large
microsoft/deberta-v3-base
microsoft/deberta-v3-large


In [8]:
meaning_results_df.to_csv("results/probes-meaning.csv", index=False)

# meaning_mu_df is indexed by (model, layer). Move the index into columns so the
# CSV records which model/layer each mean belongs to. Passing index=None here
# would silently drop those identifiers.
meaning_mu_df.reset_index().to_csv("results/probes-meaning-means.csv", index=False)

In [9]:
# Render the meaning-probe means (rounded to 2 decimals) as a LaTeX table.
meaning_latex = meaning_mu_df.round(decimals=2).to_latex()
print(meaning_latex)

\begin{tabular}{llrrr}
\toprule
              &    &  Probe &  Control &  Selectivity \\
model & layer &        &          &              \\
\midrule
bert-base-cased & 1  &   0.64 &     0.04 &         0.60 \\
              & 6  &   0.80 &     0.03 &         0.77 \\
              & 12 &   0.81 &     0.03 &         0.78 \\
bert-large-cased & 1  &   0.65 &     0.04 &         0.61 \\
              & 6  &   0.78 &     0.03 &         0.75 \\
              & 12 &   0.83 &     0.03 &         0.80 \\
              & 18 &   0.83 &     0.03 &         0.81 \\
              & 24 &   0.84 &     0.03 &         0.81 \\
microsoft/deberta-base & 1  &   0.72 &     0.03 &         0.68 \\
              & 6  &   0.81 &     0.03 &         0.78 \\
              & 12 &   0.85 &     0.03 &         0.82 \\
microsoft/deberta-large & 1  &   0.72 &     0.04 &         0.68 \\
              & 6  &   0.84 &     0.03 &         0.81 \\
              & 12 &   0.81 &     0.04 &         0.78 \\
              & 18 &   0.78 

## Construction probes

In [10]:
# Probe every model/layer for the 'construction' labels: per-split results plus per-(model, layer) means.
construction_results_df, construction_mu_df = run_all_probing_experiments('construction')

bert-base-cased
bert-large-cased
roberta-base
roberta-large
microsoft/deberta-base
microsoft/deberta-large
microsoft/deberta-v3-base
microsoft/deberta-v3-large


In [11]:
construction_results_df.to_csv("results/probes-construction.csv", index=False)

# construction_mu_df is indexed by (model, layer). Move the index into columns so
# the CSV records which model/layer each mean belongs to. Passing index=None here
# would silently drop those identifiers.
construction_mu_df.reset_index().to_csv("results/probes-construction-means.csv", index=False)

In [12]:
# Notebook display: show the per-(model, layer) means of the construction probes.
construction_mu_df

Unnamed: 0_level_0,Unnamed: 1_level_0,Probe,Control,Selectivity
model,layer,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
bert-base-cased,1,0.747368,0.342704,0.404664
bert-base-cased,6,0.933982,0.338092,0.59589
bert-base-cased,12,0.95359,0.327532,0.626058
bert-large-cased,1,0.720326,0.333306,0.387021
bert-large-cased,6,0.910715,0.344671,0.566044
bert-large-cased,12,0.943793,0.328348,0.615445
bert-large-cased,18,0.966229,0.349327,0.616902
bert-large-cased,24,0.967429,0.334955,0.632474
microsoft/deberta-base,1,0.881986,0.339654,0.542332
microsoft/deberta-base,6,0.957958,0.34143,0.616528


In [13]:
# Render the construction-probe means (rounded to 2 decimals) as a LaTeX table.
construction_latex = construction_mu_df.round(decimals=2).to_latex()
print(construction_latex)

\begin{tabular}{llrrr}
\toprule
              &    &  Probe &  Control &  Selectivity \\
model & layer &        &          &              \\
\midrule
bert-base-cased & 1  &   0.75 &     0.34 &         0.40 \\
              & 6  &   0.93 &     0.34 &         0.60 \\
              & 12 &   0.95 &     0.33 &         0.63 \\
bert-large-cased & 1  &   0.72 &     0.33 &         0.39 \\
              & 6  &   0.91 &     0.34 &         0.57 \\
              & 12 &   0.94 &     0.33 &         0.62 \\
              & 18 &   0.97 &     0.35 &         0.62 \\
              & 24 &   0.97 &     0.33 &         0.63 \\
microsoft/deberta-base & 1  &   0.88 &     0.34 &         0.54 \\
              & 6  &   0.96 &     0.34 &         0.62 \\
              & 12 &   0.97 &     0.32 &         0.64 \\
microsoft/deberta-large & 1  &   0.86 &     0.33 &         0.53 \\
              & 6  &   0.96 &     0.33 &         0.63 \\
              & 12 &   0.96 &     0.33 &         0.64 \\
              & 18 &   0.95 