# Depth Diagnosis: Figure 7a


In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Load both experiment result files, remapping any stored tensors onto the CPU.
#   height results -> "above"/"below" experiments
#   depth results  -> "front"/"behind" experiments
height_results, depth_results = (
    torch.load(results_path, map_location="cpu")
    for results_path in (
        "embeds/depth_diagnosis/depth_diagnosis_results_height.pt",
        "embeds/depth_diagnosis/depth_diagnosis_results_depth.pt",
    )
)

# Report how many top-level entries (images) each results file contains.
print(f"Height experiments: {len(height_results)} images")
print(f"Depth experiments: {len(depth_results)} images")

In [None]:
def compute_mean_std_collapsed(results_dict, layer, x_labels):
    """
    Pool per-image logprob deltas for one layer and summarise them per intervention.

    ``results_dict[image][layer][coord]`` is read as a pair whose element 0 is the
    change in log P(word1) and element 1 the change in log P(word2). Images that
    lack ``layer``, and coordinates missing from an image, are skipped.

    Args:
        results_dict: Mapping of image key -> {layer -> {coord -> (delta_w1, delta_w2)}}.
        layer: Layer key to read from each image's results.
        x_labels: Intervention coordinates to summarise, in output order.

    Returns:
        Tuple ``(means_w1, stds_w1, means_w2, stds_w2)`` of lists aligned with
        ``x_labels``. Std is ``torch.std``'s default (Bessel-corrected) estimate,
        so a coordinate with exactly one sample yields NaN; a coordinate with no
        samples reports 0 for both mean and std.
    """
    # Per-coordinate sample buckets: (word1 deltas, word2 deltas).
    samples = {coord: ([], []) for coord in x_labels}

    for per_image in results_dict.values():
        if layer not in per_image:
            continue
        by_coord = per_image[layer]
        for coord in x_labels:
            if coord not in by_coord:
                continue
            w1_bucket, w2_bucket = samples[coord]
            w1_bucket.append(by_coord[coord][0])
            w2_bucket.append(by_coord[coord][1])

    def _summarize(values):
        # Empty buckets are reported as zero rather than dropped, keeping alignment.
        if not values:
            return 0, 0
        stacked = torch.tensor(values)
        return torch.mean(stacked).item(), torch.std(stacked).item()

    means_w1, stds_w1, means_w2, stds_w2 = [], [], [], []
    for coord in x_labels:
        w1_values, w2_values = samples[coord]
        mean_1, std_1 = _summarize(w1_values)
        mean_2, std_2 = _summarize(w2_values)
        means_w1.append(mean_1)
        stds_w1.append(std_1)
        means_w2.append(mean_2)
        stds_w2.append(std_2)

    return means_w1, stds_w1, means_w2, stds_w2

In [None]:
# Setup: vary the y-coordinate (height) of the intervention, keeping x fixed at 0.
x_labels = [(0, 0, 224), (0, 1, 224), (0, 2, 224), (0, 3, 224)]  # Bottom to top
x_vals = list(range(len(x_labels)))
layer = 12  # Modality binding layer

# Compute statistics.
# height_results holds the above/below experiments and depth_results the
# front/behind ones (per the loading cell), so each must feed the matching
# pair of word curves; otherwise each panel would plot the other experiment.
# NOTE(review): word1/word2 are assumed to be ("below", "above") for height and
# ("front", "behind") for depth, matching the original unpacking order of each
# name pair -- confirm against the script that produced the .pt files.
m_below, s_below, m_above, s_above = compute_mean_std_collapsed(height_results, layer, x_labels)
m_front, s_front, m_behind, s_behind = compute_mean_std_collapsed(depth_results, layer, x_labels)


def _plot_mean_band(ax, xs, means, stds, label):
    """Draw a mean curve on ``ax`` with a shaded band of +/- one std around it."""
    ax.plot(xs, means, label=label)
    ax.fill_between(
        xs,
        [m - s for m, s in zip(means, stds)],
        [m + s for m, s in zip(means, stds)],
        alpha=0.25,
    )


# Tick labels show the first two coordinates of each intervention ID,
# derived from x_labels so they stay in sync if the sweep changes.
tick_labels = [str(coord[:2]) for coord in x_labels]

# Create figure: left panel (ax2) = above/below, right panel (ax1) = front/behind.
fig, (ax2, ax1) = plt.subplots(1, 2, figsize=(6, 3))

_plot_mean_band(ax1, x_vals, m_behind, s_behind, "ΔP('behind')")
_plot_mean_band(ax1, x_vals, m_front, s_front, "ΔP('front')")

_plot_mean_band(ax2, x_vals, m_above, s_above, "ΔP('above')")
_plot_mean_band(ax2, x_vals, m_below, s_below, "ΔP('below')")

for ax in (ax1, ax2):
    ax.set_xticks(x_vals)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel('Change in Log Prob')
    ax.set_xlabel('Intervention ID on subject')
    ax.set_title(f'Layer {layer}')  # follows `layer` instead of a hard-coded 12
    ax.legend()

plt.tight_layout()
plt.savefig('fig7a_depth_diagnosis.pdf', dpi=300, bbox_inches='tight')
plt.savefig('fig7a_depth_diagnosis.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nFigure saved as:")
print("  fig7a_depth_diagnosis.pdf")
print("  fig7a_depth_diagnosis.png")