# Probing analysis

## Setup

In [None]:
import os
import experiment.db_utils as db
import json
import csv
import re
from pprint import pprint
from collections import defaultdict
import probing.utils
# Names of all high-level nodes, taken from the keys of the probing label-space table.
high_nodes = list(probing.utils.HIGH_NODE_LABEL_SPACE)

print(high_nodes)


## Load data

In [None]:
# BERT result rows. Probing rows are filtered on status == 1 and interchange
# rows on status == 2 (presumably the "finished" codes of each pipeline --
# TODO confirm against experiment.db_utils).
bert_easy_probing_rows = db.select(
    "data/probing/probing-bert-easy.db",
    "results",
    cond_dict={"status": 1},
)
print(len(bert_easy_probing_rows))
bert_easy_interx_rows = db.select(
    "data/interchange/bert-easy.db",
    "results",
    cond_dict={"status": 2},
)
print(len(bert_easy_interx_rows))

bert_hard_probing_rows = db.select(
    "data/probing/probing-bert-hard.db",
    "results",
    cond_dict={"status": 1},
)
print(len(bert_hard_probing_rows))
bert_hard_interx_rows = db.select(
    "data/interchange/bert-hard.db",
    "results",
    cond_dict={"status": 2},
)
print(len(bert_hard_interx_rows))

In [None]:
# LSTM result rows, selected with the same status filters as the BERT cell
# (status == 1 for probing, status == 2 for interchange).
lstm_easy_probing_rows = db.select(
    "data/probing/probing-lstm-easy.db",
    "results",
    cond_dict={"status": 1},
)
print(len(lstm_easy_probing_rows))
lstm_easy_interx_rows = db.select(
    "data/interchange/lstm-easy.db",
    "results",
    cond_dict={"status": 2},
)
print(len(lstm_easy_interx_rows))

lstm_hard_probing_rows = db.select(
    "data/probing/probing-lstm-hard.db",
    "results",
    cond_dict={"status": 1},
)
print(len(lstm_hard_probing_rows))
lstm_hard_interx_rows = db.select(
    "data/interchange/lstm-hard.db",
    "results",
    cond_dict={"status": 2},
)
print(len(lstm_hard_interx_rows))

## Copy probing csv results from cluster

In [None]:
def pull_data_from_cluster(rows, remote="hansonlu@sc.stanford.edu", remote_root="Interchange"):
    """Copy the result file of every row from the cluster via ``scp``.

    For each row the remote file ``<remote>:<remote_root>/<res_save_path>`` is
    copied to the local path ``res_save_path``; the local directory
    ``res_save_dir`` is created first when it does not exist.

    Args:
        rows: Iterable of mappings providing ``"res_save_path"`` and
            ``"res_save_dir"``.
        remote: ``user@host`` specification passed to ``scp``.
        remote_root: Directory on the remote host that the saved paths are
            relative to.

    The transfer is best-effort: a failed copy is reported on stdout and the
    loop continues with the next row.
    """
    import subprocess  # local import so the cell does not depend on the setup cell

    for row in rows:
        res_save_path = row["res_save_path"]
        res_save_dir = row["res_save_dir"]
        # makedirs also creates missing parent directories, unlike os.mkdir.
        os.makedirs(res_save_dir, exist_ok=True)
        src_path = os.path.join(remote_root, res_save_path)

        # Argument list instead of a shell string: paths containing spaces or
        # shell metacharacters are passed through to scp untouched.
        try:
            completed = subprocess.run(
                ["scp", f"{remote}:{src_path}", res_save_path],
                check=False,
            )
        except OSError as err:  # e.g. scp is not installed
            print(f"failed to pull data from {src_path}: {err}")
            continue

        # Only claim success when scp actually exited cleanly.
        if completed.returncode == 0:
            print(f"successfully pulled data from {src_path}")
        else:
            print(f"failed to pull data from {src_path} (scp exit status {completed.returncode})")

In [None]:
# pull_data_from_cluster(bert_easy_probing_rows)
# pull_data_from_cluster(lstm_easy_probing_rows)

# pull_data_from_cluster(bert_hard_probing_rows)
# pull_data_from_cluster(lstm_hard_probing_rows)


## Find highest accuracy among all grid search instances

In [None]:
def aggregate_selectivity(rows):
    """Return the best probing selectivity of every probe over all runs.

    Within one run, the selectivity of a probe is accumulated as the sum of
    ``train_acc`` over its CSV lines, with control lines
    (``is_control == "True"``) counted negatively.

    Args:
        rows: Sequence of DB rows. Each must provide ``"res_save_path"``, the
            path of a CSV file with the columns ``high_node``, ``low_node``,
            ``low_loc``, ``train_acc`` and ``is_control``.

    Returns:
        ``{high_node: {(low_node, low_loc): max selectivity}}`` with one
        (possibly empty) inner dict for every name in the module-level
        ``high_nodes``. A run whose CSV lacks a probe that appears in another
        run contributes 0.0 for that probe.

    Raises:
        KeyError: If a CSV names a ``high_node`` that is not in ``high_nodes``.
    """
    num_runs = len(rows)
    selectivity = {
        high_node: defaultdict(lambda: [0.0] * num_runs)
        for high_node in high_nodes
    }

    # `db_row` / `record` are kept distinct: the original reused one name for
    # both the DB row and the CSV line.
    for run_idx, db_row in enumerate(rows):
        # newline="" is what the csv module expects for files it parses.
        with open(db_row["res_save_path"], "r", newline="") as f:
            for record in csv.DictReader(f):
                high_node = record["high_node"]
                probe = (record["low_node"], record["low_loc"])
                sign = -1.0 if record["is_control"] == "True" else 1.0
                selectivity[high_node][probe][run_idx] += float(record["train_acc"]) * sign

    # Every per-probe list has one slot per run, so it is never empty here.
    return {
        high_node: {probe: max(vals) for probe, vals in per_probe.items()}
        for high_node, per_probe in selectivity.items()
    }


def aggregate_accuracy(rows):
    """Return the best non-control ``train_acc`` of every probe over all runs.

    Control lines (``is_control == "True"``) are ignored. If a run lists the
    same probe more than once, the last line wins for that run.

    Args:
        rows: Sequence of DB rows. Each must provide ``"res_save_path"``, the
            path of a CSV file with the columns ``high_node``, ``low_node``,
            ``low_loc``, ``train_acc`` and ``is_control``.

    Returns:
        ``{high_node: {(low_node, low_loc): max accuracy}}`` with one
        (possibly empty) inner dict for every name in the module-level
        ``high_nodes``. A run whose CSV lacks a probe that appears in another
        run contributes 0.0 for that probe.

    Raises:
        KeyError: If a CSV names a ``high_node`` that is not in ``high_nodes``.
    """
    num_runs = len(rows)
    accuracy = {
        high_node: defaultdict(lambda: [0.0] * num_runs)
        for high_node in high_nodes
    }

    # `db_row` / `record` are kept distinct: the original reused one name for
    # both the DB row and the CSV line.
    for run_idx, db_row in enumerate(rows):
        # newline="" is what the csv module expects for files it parses.
        with open(db_row["res_save_path"], "r", newline="") as f:
            for record in csv.DictReader(f):
                if record["is_control"] == "True":
                    continue
                high_node = record["high_node"]
                probe = (record["low_node"], record["low_loc"])
                accuracy[high_node][probe][run_idx] = float(record["train_acc"])

    # Every per-probe list has one slot per run, so it is never empty here.
    return {
        high_node: {probe: max(vals) for probe, vals in per_probe.items()}
        for high_node, per_probe in accuracy.items()
    }

# def aggregate_accuracy_results(rows):
#     acc_all_res = {high_node: defaultdict(list) for high_node in high_nodes}
#     acc_max_res = {high_node: defaultdict(float) for high_node in high_nodes}
    
#     selectivity = {high_node: defaultdict(lambda: [0. for _ in range(len(rows))]) for high_node in high_nodes}

#     for i, row in enumerate(rows):
#         with open(row["res_save_path"], "r") as f:
#             reader = csv.DictReader(f)
#             for row in reader:
#                 high_node = row["high_node"]
#                 probe = (row["low_node"], row["low_loc"])
                
#                 train_acc = float(row["train_acc"]) * (-1. if row["is_control"] == "True" else 1.)
                
#                 selectivity[high_node][probe][i] += train_acc
                
#     max_selectivity = {high_node: {} for high_node in high_nodes}
    
#     for high_node, d in selectivity.items():
#         for probe, vals in d.items():
#             for val in vals:
#                 max_selectivity[high_node][probe] = max(max_selectivity[high_node].get(probe, -float("inf")), val)

#     return max_selectivity

    

In [None]:
# Best selectivity per (high node, probe) over all grid-search runs,
# computed separately for each model / dataset combination.
bert_easy_probing_max_sel = aggregate_selectivity(bert_easy_probing_rows)
lstm_easy_probing_max_sel = aggregate_selectivity(lstm_easy_probing_rows)

bert_hard_probing_max_sel = aggregate_selectivity(bert_hard_probing_rows)
lstm_hard_probing_max_sel = aggregate_selectivity(lstm_hard_probing_rows)

# Plain (non-control) accuracy is only aggregated for the BERT-Hard runs.
bert_hard_probing_max_acc = aggregate_accuracy(bert_hard_probing_rows)


## Analyze Probing Grid Search results

In [None]:
def aggregate_by_grid_search(rows):
    """Return the mean probing selectivity of each grid-search run.

    For every run, ``train_acc`` is summed over all CSV lines with control
    lines (``is_control == "True"``) counted negatively, and the total is
    divided by the number of control lines in that file.

    Args:
        rows: Iterable of DB rows. Each must provide ``"res_save_path"``, the
            path of a CSV file with at least the columns ``train_acc`` and
            ``is_control``.

    Returns:
        A list with one float per row, in input order. A run whose file has
        no control line yields ``nan`` instead of raising ZeroDivisionError.
    """
    means = []
    for db_row in rows:
        total = 0.0
        num_controls = 0
        # newline="" is what the csv module expects for files it parses.
        with open(db_row["res_save_path"], "r", newline="") as f:
            for record in csv.DictReader(f):
                acc = float(record["train_acc"])
                if record["is_control"] == "True":
                    total -= acc
                    num_controls += 1
                else:
                    total += acc
        # The mean is undefined without control lines; report it as NaN.
        means.append(total / num_controls if num_controls else float("nan"))
    return means

def compare_lr(values):
    """Split *values* into groups of four and total each half of the groups.

    Within every group of four consecutive entries the first two are added to
    the first total and the last two to the second total. The names suggest
    the halves are the low / high learning-rate settings of the grid search
    (assumption -- the grid layout is not visible here).

    Returns:
        ``(low_lr_sum, high_lr_sum)``.

    Raises:
        IndexError: If ``len(values)`` is not a multiple of four.
    """
    totals = [0, 0]  # [low, high]
    for start in range(0, len(values), 4):
        for offset in range(4):
            # Offsets 0-1 feed the first total, offsets 2-3 the second.
            totals[offset // 2] += values[start + offset]
    return totals[0], totals[1]

def compare_wn(values):
    """Total the even-indexed and the odd-indexed entries of *values*.

    Entries are consumed in pairs: the first of each pair goes to the first
    total, the second to the second total. The names suggest the two halves
    are the low / high "wn" settings of the grid search (assumption -- the
    grid layout is not visible here).

    Returns:
        ``(low_wn_sum, high_wn_sum)``.

    Raises:
        IndexError: If ``len(values)`` is odd.
    """
    even_total = 0
    odd_total = 0
    idx = 0
    while idx < len(values):
        even_total += values[idx]
        odd_total += values[idx + 1]
        idx += 2
    return even_total, odd_total

In [None]:
# Mean selectivity of every grid-search run, one list per model / dataset.
bert_easy_gs_sel = aggregate_by_grid_search(bert_easy_probing_rows)
lstm_easy_gs_sel = aggregate_by_grid_search(lstm_easy_probing_rows)
bert_hard_gs_sel = aggregate_by_grid_search(bert_hard_probing_rows)
lstm_hard_gs_sel = aggregate_by_grid_search(lstm_hard_probing_rows)

# LSTM-Easy: total over the even-indexed runs, then over the odd-indexed runs.
pprint(sum(lstm_easy_gs_sel[0::2]))
pprint(sum(lstm_easy_gs_sel[1::2]))

# Runs 8..15 of each grid search, in the order BERT-Easy, BERT-Hard,
# LSTM-Easy, LSTM-Hard (why this window is chosen is not visible here).
_gs_windows = [
    gs_sel[8:16]
    for gs_sel in (bert_easy_gs_sel, bert_hard_gs_sel, lstm_easy_gs_sel, lstm_hard_gs_sel)
]

for _window in _gs_windows:
    print(compare_lr(_window))
print("")
for _window in _gs_windows:
    print(compare_wn(_window))

## Organize intervention results

In [None]:
def aggregate_success_rates(rows):
    """Collect interchange success rates per high node and low-level location.

    Each row holds JSON-encoded fields: ``abstraction`` (high node name at
    index 0, low node name at ``[1][0]``), ``mappings`` (a list whose items
    give the low-level location string under ``[high_node][low_node]``) and
    the per-mapping count lists ``res_2_counts``, ``res_3_counts``,
    ``res_6_counts`` and ``res_7_counts``.

    For mapping ``j`` the rates are computed as::

        causal = res_3[j] / (res_2[j] + res_3[j])
        total  = (res_3[j] + res_7[j]) / (res_2[j] + res_3[j] + res_6[j] + res_7[j])

    Args:
        rows: Iterable of DB rows with the fields described above.

    Returns:
        ``(causal_success_rates, total_success_rates)``; both map
        ``high_node -> {(low_node, low_loc): rate}`` where ``low_loc`` is the
        first integer found in the location string. A later row overwrites an
        earlier one for the same key.

    Raises:
        KeyError: If a row's high node is not in the module-level ``high_nodes``.
        ZeroDivisionError: If the counts of a mapping sum to zero.
        IndexError: If a location string contains no digits.
    """
    causal_success_rates = {high_node: defaultdict(float) for high_node in high_nodes}
    total_success_rates = {high_node: defaultdict(float) for high_node in high_nodes}

    for row in rows:
        abstraction = json.loads(row["abstraction"])
        high_node = abstraction[0]
        low_node = abstraction[1][0]
        mappings = json.loads(row["mappings"])
        res_2_counts = json.loads(row["res_2_counts"])
        res_3_counts = json.loads(row["res_3_counts"])
        res_6_counts = json.loads(row["res_6_counts"])
        res_7_counts = json.loads(row["res_7_counts"])

        for j, mapping in enumerate(mappings):
            # Raw string: '\d' in a plain literal is an invalid escape sequence.
            low_loc = int(re.findall(r"\d+", mapping[high_node][low_node])[0])
            key = (low_node, low_loc)

            causal_total = res_2_counts[j] + res_3_counts[j]
            overall_total = causal_total + res_6_counts[j] + res_7_counts[j]
            causal_success_rates[high_node][key] = res_3_counts[j] / causal_total
            total_success_rates[high_node][key] = (res_3_counts[j] + res_7_counts[j]) / overall_total

    return causal_success_rates, total_success_rates


def aggregate_clique_sizes(rows):
    """Collect the maximum clique size per high node and low-level location.

    Each row holds JSON-encoded fields: ``abstraction`` (high node name at
    index 0, low node name at ``[1][0]``), ``mappings`` (a list whose items
    give the low-level location string under ``[high_node][low_node]``) and
    ``max_clique_sizes`` (one value per mapping).

    Args:
        rows: Iterable of DB rows with the fields described above.

    Returns:
        ``{high_node: {(low_node, low_loc): clique size}}`` where ``low_loc``
        is the first integer found in the location string. A later row
        overwrites an earlier one for the same key.

    Raises:
        KeyError: If a row's high node is not in the module-level ``high_nodes``.
        IndexError: If a location string contains no digits.
    """
    clique_sizes = {high_node: defaultdict(float) for high_node in high_nodes}

    for row in rows:
        abstraction = json.loads(row["abstraction"])
        high_node = abstraction[0]
        low_node = abstraction[1][0]
        mappings = json.loads(row["mappings"])
        row_clique_sizes = json.loads(row["max_clique_sizes"])

        for j, mapping in enumerate(mappings):
            # Raw string: '\d' in a plain literal is an invalid escape sequence.
            low_loc = int(re.findall(r"\d+", mapping[high_node][low_node])[0])
            clique_sizes[high_node][(low_node, low_loc)] = row_clique_sizes[j]

    return clique_sizes

In [None]:
# Causal and total interchange success rates per model / dataset.
bert_easy_interx_causal, bert_easy_interx_total = aggregate_success_rates(bert_easy_interx_rows)
lstm_easy_interx_causal, lstm_easy_interx_total = aggregate_success_rates(lstm_easy_interx_rows)

bert_hard_interx_causal, bert_hard_interx_total = aggregate_success_rates(bert_hard_interx_rows)
lstm_hard_interx_causal, lstm_hard_interx_total = aggregate_success_rates(lstm_hard_interx_rows)

# Maximum clique sizes from the same interchange rows.
bert_easy_clq_sizes = aggregate_clique_sizes(bert_easy_interx_rows)
bert_hard_clq_sizes = aggregate_clique_sizes(bert_hard_interx_rows)
lstm_easy_clq_sizes = aggregate_clique_sizes(lstm_easy_interx_rows)
lstm_hard_clq_sizes = aggregate_clique_sizes(lstm_hard_interx_rows)

## Tools for Heatmap Plotting

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import re

def test_heatmap():
    """Smoke-test the plotting stack by drawing a 16x16 heatmap of random values."""
    random_grid = np.random.random((16, 16))
    sns.heatmap(random_grid, cmap="YlGnBu")

def get_loc_name(loc, model_name):
    """Return the display label of token position *loc*.

    The label table has 27 entries: ``[CLS]``, twelve premise tokens
    (suffix ``p``) followed by ``[SEP]``, then twelve hypothesis tokens
    (suffix ``h``) followed by ``[SEP]``.

    Args:
        loc: Index into the label table.
        model_name: Currently unused; the LSTM-specific layout below is
            disabled, so every model shares the same table.
    """
    names = ["[CLS]"]
    for tag in ("p", "h"):
        names += [
            f"Q_S_{tag}(1)", f"Q_S_{tag}(2)", f"Adj_S_{tag}", f"N_S_{tag}",
            f"Neg_{tag}(1)", f"Neg_{tag}(2)", f"Adv_{tag}", f"V_{tag}",
            f"Q_O_{tag}(1)", f"Q_O_{tag}(2)", f"Adj_O_{tag}", f"N_O_{tag}",
            "[SEP]",
        ]
#     if "lstm" in model_name.lower():
#         loc_to_name = ["Q_S_p", "Adj_S_p", "N_S_p", "Neg_p", "Adv_p", "V_p", "Q_O_p", "Adj_O_p", "N_O_p",
#                       "[SEP]",
#                       "Q_S_h", "Adj_S_h", "N_S_h", "Neg_h", "Adv_h", "V_h", "Q_O_h", "Adj_O_h", "N_O_h"]
    return names[loc]

def _first_int(text):
    """Return the first run of decimal digits in *text* as an int."""
    return int(re.findall(r"\d+", text)[0])


def heatmap(high_node, max_info, model_name, title="probing-sel", high_node_name=None):
    """Draw a layer-by-location heatmap for one high-level node.

    Args:
        high_node: Key into *max_info* selecting the node to plot.
        max_info: ``{high_node: {(layer, loc): value}}``. ``layer`` is a
            string containing the layer number; ``loc`` is either an int or a
            string containing the location number.
        model_name: Model label used in the plot title and passed on to
            ``get_loc_name`` for the x tick labels.
        title: Kind of plot; one of ``"probing-sel"``, ``"probing-acc"``,
            ``"interx"`` or ``"clq-size"``. Selects the colour map and the
            title text.
        high_node_name: Optional display name for the node; defaults to
            *high_node*.

    Rows are layers and columns are locations, both in ascending order, with
    the y axis inverted after plotting. Cells without a value stay 0.

    Raises:
        KeyError: If *high_node* is missing from *max_info* or *title* is not
            one of the supported kinds.
    """
    # Resolve every (layer, loc) key to integer coordinates once.
    cells = []
    for (layer, loc), value in max_info[high_node].items():
        layer_num = _first_int(layer)
        loc_num = _first_int(loc) if isinstance(loc, str) else loc
        cells.append((layer_num, loc_num, value))

    layers = sorted({layer_num for layer_num, _, _ in cells})
    locs = sorted({loc_num for _, loc_num, _ in cells})
    # Map layer / location numbers to matrix positions. Indexing rows by the
    # raw layer number breaks as soon as the layers are not exactly 0..n-1.
    layer_to_row = {layer_num: i for i, layer_num in enumerate(layers)}
    loc_to_col = {loc_num: i for i, loc_num in enumerate(locs)}

    heatmap_data = np.zeros((len(layers), len(locs)))
    for layer_num, loc_num, value in cells:
        heatmap_data[layer_to_row[layer_num], loc_to_col[loc_num]] = value

    color_dict = {"probing-sel": "YlGnBu", "probing-acc": "YlGnBu",
                  "interx": "YlOrRd", "clq-size": "YlOrRd"}
    title_prefixes = {"probing-sel": "Probing selectivity",
                      "probing-acc": "Probing accuracy",
                      "interx": "Interchange success rate",
                      "clq-size": "Clique sizes"}

    ax = sns.heatmap(heatmap_data,
                     cmap=color_dict[title],
                     xticklabels=[get_loc_name(loc_num, model_name) for loc_num in locs],
                     yticklabels=layers)
    ax.invert_yaxis()
    shown_name = high_node_name if high_node_name else high_node
    ax.set_title(f"{title_prefixes[title]} for {shown_name} in {model_name}")
    
    

## Probing and interchange heatmaps per high-level node

### Sentence_q

In [None]:
# test_heatmap()
# Heatmaps for the "sentence_q" node: probing selectivity (default plot kind),
# clique sizes ("clq-size") and interchange success rates ("interx").
heatmap("sentence_q", bert_hard_probing_max_sel, "Bert-Hard")

In [None]:
heatmap("sentence_q", bert_hard_clq_sizes, "Bert-Hard", "clq-size")

In [None]:
heatmap("sentence_q", bert_hard_interx_total, "Bert-Hard", "interx")

In [None]:
heatmap("sentence_q", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("sentence_q", lstm_hard_interx_causal, "LSTM-Hard", "interx")

In [None]:
heatmap("sentence_q", bert_easy_probing_max_sel, "Bert-Easy")

In [None]:
heatmap("sentence_q", bert_easy_interx_causal, "Bert-Easy", "interx")

In [None]:
heatmap("sentence_q", lstm_easy_probing_max_sel, "LSTM-Easy")

### Subj

In [None]:
heatmap("subj", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="NP_subj")

In [None]:
heatmap("subj", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="NP_subj")

In [None]:
heatmap("subj", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="NP_subj")

In [None]:
heatmap("subj", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("subj", lstm_hard_interx_causal, "LSTM-Hard", "interx")

In [None]:
heatmap("sentence_q", bert_easy_probing_max_sel, "Bert-Easy")

In [None]:
heatmap("sentence_q", bert_easy_interx_causal, "Bert-Easy", "interx")

### Subj_adj

In [None]:
# Heatmaps for the "subj_adj" node (shown as "Adj_subj" in the BERT-Hard titles).
heatmap("subj_adj", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="Adj_subj")

In [None]:
heatmap("subj_adj", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="Adj_subj")

In [None]:
heatmap("subj_adj", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("subj_adj", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="Adj_subj")

In [None]:
heatmap("subj_adj", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### subj_noun

In [None]:
# Heatmaps for the "subj_noun" node (shown as "N_subj" in the BERT-Hard titles).
heatmap("subj_noun", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="N_subj")

In [None]:
heatmap("subj_noun", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="N_subj")

In [None]:
heatmap("subj_noun", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("subj_noun", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="N_subj")

In [None]:
heatmap("subj_noun", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### negp

In [None]:
# Heatmaps for the "negp" node (shown as "NegP" in the BERT-Hard titles).
heatmap("negp", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="NegP")

In [None]:
heatmap("negp", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="NegP")

In [None]:
heatmap("negp", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("negp", bert_hard_interx_causal, "Bert-Hard", "interx",high_node_name="NegP")

In [None]:
heatmap("negp", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### neg

In [None]:
# Heatmaps for the "neg" node: probing selectivity, then interchange success rates.
heatmap("neg", bert_hard_probing_max_sel, "Bert-Hard")

In [None]:
heatmap("neg", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("neg", bert_hard_interx_causal, "Bert-Hard", "interx")

In [None]:
heatmap("neg", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### vp

In [None]:
# Heatmaps for the "vp" node (shown as "QP_Obj" in the BERT-Hard titles).
heatmap("vp", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="QP_Obj")

In [None]:
heatmap("vp", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="QP_Obj")

In [None]:
heatmap("vp", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("vp", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="QP_Obj")

In [None]:
heatmap("vp", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### vp_q

In [None]:
# Heatmaps for the "vp_q" node: probing selectivity, then interchange success rates.
heatmap("vp_q", bert_hard_probing_max_sel, "Bert-Hard")

In [None]:
heatmap("vp_q", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("vp_q", bert_hard_interx_causal, "Bert-Hard", "interx")

In [None]:
heatmap("vp_q", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### v_bar

In [None]:
# Heatmaps for the "v_bar" node (shown as "VP" in the BERT-Hard titles).
heatmap("v_bar", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="VP")

In [None]:
heatmap("v_bar", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="VP")

In [None]:
heatmap("v_bar", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("v_bar", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="VP")

In [None]:
heatmap("v_bar", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### v_adv

In [None]:
# Heatmaps for the "v_adv" node (shown as "Adv" in the BERT-Hard titles).
heatmap("v_adv", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="Adv")

In [None]:
heatmap("v_adv", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="Adv")

In [None]:
heatmap("v_adv", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("v_adv", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="Adv")

In [None]:
heatmap("v_adv", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### v_verb

In [None]:
# Heatmaps for the "v_verb" node (shown as "V" in the BERT-Hard titles).
heatmap("v_verb", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="V")

In [None]:
heatmap("v_verb", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="V")

In [None]:
heatmap("v_verb", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("v_verb", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="V")

In [None]:
heatmap("v_verb", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### obj

In [None]:
# Heatmaps for the "obj" node (shown as "NP_obj" in the BERT-Hard titles).
heatmap("obj", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="NP_obj")

In [None]:
heatmap("obj", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="NP_obj")

In [None]:
heatmap("obj", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("obj", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="NP_obj")

In [None]:
heatmap("obj", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### obj_adj

In [None]:
heatmap("obj_adj", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="Adj_obj")

In [None]:
heatmap("obj_adj", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="Adj_obj")

In [None]:
heatmap("obj_adj", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("obj_adj", bert_hard_interx_causal, "Bert-Hard", "interx", high_node_name="Adv_obj")

In [None]:
heatmap("obj_adj", lstm_hard_interx_causal, "LSTM-Hard", "interx")

### obj_noun

In [None]:
heatmap("obj_noun", bert_hard_probing_max_sel, "Bert-Hard", high_node_name="N_obj")

In [None]:
heatmap("obj_noun", bert_hard_probing_max_acc, "Bert-Hard", "probing-acc", high_node_name="N_obj")

In [None]:
heatmap("obj_noun", lstm_hard_probing_max_sel, "LSTM-Hard")

In [None]:
heatmap("obj_noun", bert_hard_interx_causal, "Bert-Hard", "interx")