In [None]:
import torch
import matplotlib.pyplot as plt

# Module-level plotting configuration.
# The helper functions defined in this cell read `x_labels` and `x_vals` as
# globals at call time, so every figure cell must assign both of them before
# calling plot_all_layers.
#   x_labels: intervention keys used to index the per-layer result dicts and
#             as x tick labels.
#   x_vals:   x coordinates of the plotted points and ticks.
x_labels = []
x_vals = []
# Alternative label sets that were tried (kept for reference):
# x_labels = [(p,q) for q in range(4) for p in range(4)]

# x_labels = [(0,0), (0,1), (0,2), (0,3)]

# Alternative word pairs for w1 / w2:
# w1,w2 = "near", "far"
# w1, w2  = "behind", "front"

# Convention used in this notebook: right == True, left == False
def dict_to_lists(d):
    """Return the values of `d` as a list ordered like the global `x_labels`.

    Every entry of `x_labels` must be a key of `d`; a missing key raises
    KeyError. Keys of `d` that are not in `x_labels` are ignored.
    """
    ordered = []
    for label in x_labels:
        ordered.append(d[label])
    return ordered

def compute_mean_std_by_condition(univ, layer, answer_side):
    """
    Aggregate 'left' / 'right' logprob deltas over all entries of `univ`
    whose key has `key[1] == answer_side`.

    Args:
        univ: mapping indexed as univ[key][layer][label], where `label` runs
            over the global `x_labels`. Each leaf is indexed with [0]
            ('left' logprob delta) and [1] ('right' logprob delta); both must
            be tensors because they are passed to torch.stack.
        layer: layer key selecting the per-layer results of each entry.
        answer_side: value compared against the second element of each key.

    Returns:
        (means_left, stds_left, means_right, stds_right): four lists of Python
        floats ordered like `x_labels`. Mean and std are taken over all
        elements of the stacked samples.

    Raises:
        ValueError: if `x_labels` is non-empty and no key of `univ` matches
            `answer_side` (there would be nothing to average).
    """
    samples_left = {label: [] for label in x_labels}
    samples_right = {label: [] for label in x_labels}

    for key in univ:
        if key[1] != answer_side:
            continue
        per_label = univ[key][layer]
        for label in samples_left:
            samples_left[label].append(per_label[label][0])   # logprob 'left'
            samples_right[label].append(per_label[label][1])  # logprob 'right'

    means_left, stds_left, means_right, stds_right = [], [], [], []
    for label in x_labels:
        if not samples_left[label]:
            # torch.stack([]) would fail with an opaque RuntimeError; say what is wrong.
            raise ValueError(
                f"no entries with answer_side={answer_side!r} found for layer {layer!r}"
            )
        stacked_left = torch.stack(samples_left[label])
        stacked_right = torch.stack(samples_right[label])
        means_left.append(torch.mean(stacked_left).item())
        # torch.std applies Bessel's correction by default, so a single scalar
        # sample yields nan rather than 0.
        stds_left.append(torch.std(stacked_left).item())
        means_right.append(torch.mean(stacked_right).item())
        stds_right.append(torch.std(stacked_right).item())

    return means_left, stds_left, means_right, stds_right


def print_values(layer, label, data):
    """Print a markdown-style table of mean deltas for one layer and condition.

    `data` is a 4-sequence; only element 0 (first column of means) and
    element 2 (second column of means) are used, the std entries are ignored.
    One row is printed per entry of the global `x_labels`; the last column is
    the second mean minus the first mean.

    NOTE(review): the header labels the last column 'yes - no' while the value
    printed is (second column - first column); the header wording is also fixed
    to the in-between task regardless of `label`. Confirm which is intended.
    """
    first_means, _first_stds, second_means, _second_stds = data

    print(f"\n=== Layer {layer} — Original: object @ {label} ===")
    print("| Intervention | mean ΔP('yes' = in-between) | mean ΔP('no' = not in-between) | yes - no |")
    print("|-------------|------------------|------------------|------------|")

    rows = zip(x_labels, first_means, second_means)
    for interv, first, second in rows:
        delta = second - first
        print(f"| {interv} | {first:+.4f} | {second:+.4f} | {delta:+.4f} |")


def plot_layer_comparison(ax_left, ax_right, data_left, data_right, layer, w1="left", w2="right"):
    """
    Draw the two panels for one layer.

    `data_left` is drawn on `ax_left` and titled "Original: {w1}";
    `data_right` is drawn on `ax_right` and titled "Original: {w2}".

    Each data argument is a 4-sequence (means, stds, means, stds): the first
    mean/std pair is plotted with the legend label ΔP('{w1}') and the second
    pair with ΔP('{w2}'). In each panel the w2 curve is drawn first, then the
    w1 curve, each with a shaded mean ± std band. The x positions and tick
    labels come from the globals `x_vals` and `x_labels`.
    """
    # Unpack both inputs up front so a malformed argument fails before any drawing.
    w1_mean_a, w1_std_a, w2_mean_a, w2_std_a = data_left
    w1_mean_b, w1_std_b, w2_mean_b, w2_std_b = data_right

    panels = (
        (ax_left, w1, ((w2, w2_mean_a, w2_std_a), (w1, w1_mean_a, w1_std_a))),
        (ax_right, w2, ((w2, w2_mean_b, w2_std_b), (w1, w1_mean_b, w1_std_b))),
    )

    for ax, origin, curves in panels:
        for word, means, stds in curves:
            ax.plot(x_vals, means, label=f"ΔP('{word}')")
            lower = [m - s for m, s in zip(means, stds)]
            upper = [m + s for m, s in zip(means, stds)]
            ax.fill_between(x_vals, lower, upper, alpha=0.3)
        ax.set_xticks(x_vals)
        ax.set_xticklabels(x_labels)
        ax.set_ylabel("Change in Log Prob")
        ax.set_xlabel("Intervention ID on subject")
        ax.set_title(f"Layer {layer} — Original: {origin}")
        ax.legend()


def plot_all_layers(univ, layers=(12, 13, 14), w1="left", w2="right"):
    """
    Print the summary tables and draw one row of two panels per layer.

    For each layer the statistics are computed twice, once for entries whose
    answer side is `w1` and once for `w2`. The w1 statistics go to the left
    column and the w2 statistics to the right column, matching the
    "Original: {w1}" / "Original: {w2}" titles that plot_layer_comparison
    puts on its first and second axes and the labels used by print_values.

    Args:
        univ: results mapping forwarded to compute_mean_std_by_condition.
        layers: iterable of layer keys; one subplot row per layer.
        w1, w2: the two answer words (legend labels and panel titles).

    Raises:
        ValueError: if `layers` is empty.
    """
    layers = list(layers)
    if not layers:
        raise ValueError("layers must contain at least one layer")

    # squeeze=False keeps `axes` two-dimensional, so axes[row] is a pair of
    # Axes even when only one layer is requested.
    fig, axes = plt.subplots(len(layers), 2, figsize=(6, 3 * len(layers)), squeeze=False)

    for row, layer in enumerate(layers):
        data_w1_origin = compute_mean_std_by_condition(univ, layer, answer_side=w1)
        data_w2_origin = compute_mean_std_by_condition(univ, layer, answer_side=w2)

        print_values(layer, w1, data_w1_origin)
        print_values(layer, w2, data_w2_origin)

        ax_left, ax_right = axes[row]
        # Pass axes and data in the callee's own order so each title names the
        # condition that is actually plotted underneath it.
        plot_layer_comparison(ax_left, ax_right, data_w1_origin, data_w2_origin, layer, w1, w2)

    plt.tight_layout()
    plt.show()


## Fig 6b

In [None]:
# Load the stored results onto the CPU.
# torch.load unpickles the file, so only point it at result files you trust.
univ = torch.load("embeds/batch_lr/dict_of_all_res.pt", map_location="cpu")
# The plotting helpers read x_labels / x_vals as module globals: set both
# before calling plot_all_layers.
x_labels = [(step, 0) for step in range(4)]
x_vals = [*range(len(x_labels))]
plot_all_layers(univ, layers=[12, 13], w1="left", w2="right")

## Fig 6c

In [None]:
# Load the stored results onto the CPU.
# torch.load unpickles the file, so only point it at result files you trust.
univ = torch.load("embeds/batch_nearfar/dict_of_all_res.pt", map_location="cpu")
# The plotting helpers read x_labels / x_vals as module globals: set both
# before calling plot_all_layers.
x_labels = [(step, 0) for step in range(4)]
x_vals = [*range(len(x_labels))]
plot_all_layers(univ, layers=[12, 13], w1="near", w2="far")

## Fig 6d

In [None]:
# Load the stored results onto the CPU.
# torch.load unpickles the file, so only point it at result files you trust.
univ = torch.load("embeds/batch_inbetween/dict_of_all_res.pt", map_location="cpu")
# This figure varies the second component of the intervention key, with the
# first component fixed at 2. The helpers read x_labels / x_vals as globals.
x_labels = [(2, step) for step in range(4)]
x_vals = [*range(len(x_labels))]
plot_all_layers(univ, layers=[12, 13], w1="true", w2="false")