# Figure 4: Mirror Swapping Experiment

This notebook generates Figure 4 from the paper, showing the effects of mirror swapping and attribute swapping interventions.

## Setup

Before running this notebook, make sure you've run the experiments:
```bash
bash run_experiments.sh <data_dir> outputs/
```

In [None]:
import torch
import matplotlib.pyplot as plt
import numpy as np

## Load Experiment Results

Update `OUTPUT_DIR` below so that it points to the output directory you passed to `run_experiments.sh`.

In [None]:
# Directory containing the result files produced by run_experiments.sh
OUTPUT_DIR = "embeds/mirror_swapping/"

# All stored tensors are remapped onto the CPU when the files are read
cpu_device = torch.device('cpu')

# Figure 4A input: mirror swapping results
mirror_swap_path = f"{OUTPUT_DIR}/mirror_swap_objwords_dict_of_all_res.pt"
mirror_swap_objwords = torch.load(mirror_swap_path, map_location=cpu_device)

# Figure 4B input: attribute swapping results (control)
attr_swap_path = f"{OUTPUT_DIR}/attribute_swap_objwords_dict_of_all_res.pt"
attr_swap_objwords = torch.load(attr_swap_path, map_location=cpu_device)

# Report how many entries each loaded object holds
print(f"Loaded mirror swap data: {len(mirror_swap_objwords)} samples")
print(f"Loaded attribute swap data: {len(attr_swap_objwords)} samples")

## Define Plotting Functions

In [None]:
def compute_token_stats(result_dict, coord, layers):
    """
    Compute per-layer mean and standard deviation of the 'left' and 'right' token values.

    For every layer, the values ``result_dict[key][layer][coord][0]`` ('left') and
    ``result_dict[key][layer][coord][1]`` ('right') are gathered over all keys of
    ``result_dict``. NaN and +/-Inf entries are replaced by 0 before the statistics
    are taken, so such entries still count as samples.

    Args:
        result_dict: Dictionary with keys (img_id, direction) containing layer-wise results
        coord: Type of intervention ('text_objwords', 'image', 'text')
        layers: Iterable of layer indices to analyze

    Returns:
        means_left, stds_left, means_right, stds_right: Four lists of Python floats,
        each holding one value per entry of ``layers`` (in the same order).

    Note:
        The standard deviation comes from ``torch.std`` with its default settings,
        which returns NaN when only a single value is available.
    """
    means_left, stds_left, means_right, stds_right = [], [], [], []

    # Index 0 of the per-coordinate entry is the 'left' token, index 1 the 'right' token.
    outputs_by_token = (
        (0, means_left, stds_left),
        (1, means_right, stds_right),
    )

    for layer in layers:
        for token_idx, means, stds in outputs_by_token:
            values = torch.stack(
                [result_dict[key][layer][coord][token_idx] for key in result_dict]
            )
            # Zero out NaN/Inf on the stacked tensor. Unlike substituting a fresh
            # torch.tensor(0.0) per entry, this keeps the shape, dtype and device
            # of the stored values, so the stack can never become heterogeneous.
            values = torch.nan_to_num(values, nan=0.0, posinf=0.0, neginf=0.0)

            means.append(torch.mean(values).item())
            stds.append(torch.std(values).item())

    return means_left, stds_left, means_right, stds_right

In [None]:
def plot_intervention_comparison(result_dicts, layers, coord='text_objwords', titles=None):
    """
    Plot comparison of different interventions (mirror swap vs attribute swap).

    One subplot is drawn per result dictionary. Each subplot shows, for the given
    ``coord``, the per-layer mean of the 'left' and 'right' token values as a line
    with a shaded band of one standard deviation around it (statistics come from
    ``compute_token_stats``).

    Args:
        result_dicts: List of result dictionaries to plot
        layers: List of layer indices (used as x-values)
        coord: Type of intervention to plot
        titles: List of titles for each subplot; defaults to "Intervention 1", ...

    Returns:
        The matplotlib Figure holding all subplots.
    """
    n_plots = len(result_dicts)
    # squeeze=False always returns a 2-D array of axes, so a single panel needs
    # no special-casing; the first (only) row is the list of panels.
    fig, axes_grid = plt.subplots(1, n_plots, figsize=(3.5 * n_plots, 4),
                                  sharey=True, squeeze=False)
    axes = axes_grid[0]

    if titles is None:
        titles = [f"Intervention {i+1}" for i in range(n_plots)]

    for i, (result_dict, title) in enumerate(zip(result_dicts, titles)):
        ax = axes[i]
        m_left, s_left, m_right, s_right = compute_token_stats(result_dict, coord, layers)

        # Legend labels are plain text: matplotlib does not interpret markdown,
        # so emphasis asterisks would be drawn literally.
        series = (
            (m_left, s_left, "Δ\"left\"", 'C0'),
            (m_right, s_right, "Δ\"right\"", 'C1'),
        )
        for means, stds, label, color in series:
            ax.plot(layers, means, label=label, linestyle='-', color=color)
            # Shaded band: mean ± one standard deviation
            ax.fill_between(layers,
                            [m - s for m, s in zip(means, stds)],
                            [m + s for m, s in zip(means, stds)],
                            alpha=0.3, color=color)

        ax.set_title(title, fontsize=15)
        ax.set_xlabel("Layer", fontsize=15)

        # The y-axis is shared, so only the leftmost panel carries the label
        if i == 0:
            ax.set_ylabel("Belief Shift", fontsize=15)

        ax.tick_params(axis='both', which='major', labelsize=13)
        ax.set_yticks([-0.5, 0, 0.5, 1.0])
        ax.set_xticks([5*p for p in range(7)])  # ticks at 0, 5, ..., 30
        ax.legend(fontsize=11, loc='best')
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)  # zero reference line

    plt.tight_layout()
    return fig

## Generate Figure 4

This reproduces Figure 4 from the paper:
- **(A) Mirror Swapping**: Shows distinct binary belief swaps when swapping object word tokens
- **(B) Attribute Swapping**: Control experiment showing noise when swapping color attributes

In [None]:
# Generate Figure 4 (A: mirror swapping, B: attribute swapping)
layers = range(1, 32)

# Each coordinate key is paired explicitly with its panel title so the two can
# never drift apart. Previously the two lists were ordered differently, which
# put the title "Swapped All Text" on the 'image' panel and "Swapped All Image"
# on the 'text' panel.
# NOTE(review): panel order follows the original title order - confirm against
# the published figure.
interventions = [
    ('text', 'Swapped All Text'),
    ('image', 'Swapped All Image'),
    ('text_objwords', 'Swapped Obj Tokens'),
]
intervention_types = [coord for coord, _ in interventions]
intervention_titles = [title for _, title in interventions]


def _plot_swap_figure(result_dict, panels, layers, suptitle, yticks=None):
    """
    Draw one row of panels (one per intervention) for a single result dictionary.

    Each panel shows the per-layer mean of the 'left' and 'right' token values
    with a shaded band of one standard deviation, as computed by
    ``compute_token_stats``.

    Args:
        result_dict: Result dictionary passed through to ``compute_token_stats``
        panels: Sequence of (coord, title) pairs, one per panel
        layers: Layer indices used as x-values
        suptitle: Title drawn above the whole figure
        yticks: Optional fixed y-tick positions; when None the y-ticks are
            left to matplotlib's defaults

    Returns:
        (fig, axes): the Figure and the 1-D array of its panel axes.
    """
    fig, axes_grid = plt.subplots(1, len(panels), figsize=(10, 4),
                                  sharey=True, squeeze=False)
    axes = axes_grid[0]

    for i, (coord, title) in enumerate(panels):
        ax = axes[i]
        m_left, s_left, m_right, s_right = compute_token_stats(result_dict, coord, layers)

        series = (
            (m_left, s_left, "Δ\"left\"", 'C0'),
            (m_right, s_right, "Δ\"right\"", 'C1'),
        )
        for means, stds, label, color in series:
            ax.plot(layers, means, label=label, linestyle='-', color=color)
            # Shaded band: mean ± one standard deviation
            ax.fill_between(layers,
                            [m - s for m, s in zip(means, stds)],
                            [m + s for m, s in zip(means, stds)],
                            alpha=0.3, color=color)

        ax.set_title(title, fontsize=13)
        ax.set_xlabel("Layer", fontsize=12)
        ax.tick_params(axis='both', which='major', labelsize=10)
        if yticks is not None:
            ax.set_yticks(yticks)
        ax.set_xticks([5*p for p in range(7)])  # ticks at 0, 5, ..., 30
        ax.legend(fontsize=9, loc='best')
        ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)  # zero reference line

        # The y-axis is shared, so only the leftmost panel carries the label
        if i == 0:
            ax.set_ylabel("Belief Shift", fontsize=12)

    plt.suptitle(suptitle, fontsize=14, y=1.02)
    plt.tight_layout()
    plt.show()
    return fig, axes


# Figure 4A: Mirror Swapping (fixed y-ticks, as in the original cell)
fig, axes = _plot_swap_figure(
    mirror_swap_objwords, interventions, layers,
    "(A) Mirror Swapping", yticks=[-0.5, 0, 0.5, 1.0],
)

# Figure 4B: Attribute Swapping (y-ticks left to matplotlib, as in the original cell)
fig, axes = _plot_swap_figure(
    attr_swap_objwords, interventions, layers,
    "(B) Attribute Swapping",
)