# Temporal ID Experiments - Figure 10 Plotting


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from pathlib import Path
from sklearn.decomposition import PCA

%matplotlib inline

## Configuration

In [None]:
# Input locations: both experiment folders sit under the shared embeddings root
EMBEDS_DIR = Path("embeds")
MIRROR_SWAP_DIR = EMBEDS_DIR / "mirror_swapping"
STEERING_DIR = EMBEDS_DIR / "temporal_steering"

# Destination for rendered figures; created here if it does not exist yet
FIG_DIR = Path("figures")
FIG_DIR.mkdir(exist_ok=True)

# Echo the configured locations so a wrong working directory is easy to spot
print(
    "Data directories:",
    f"  Mirror Swapping: {MIRROR_SWAP_DIR}",
    f"  Steering: {STEERING_DIR}",
    f"  Embeddings: {EMBEDS_DIR}",
    f"\nFigures will be saved to: {FIG_DIR}",
    sep="\n",
)

## Figure 10a: Mirror Swapping on Videos

Shows the layer-specific effect of swapping object word tokens between normal and reversed videos.

In [None]:
# Plotting helper shared by the swapping experiments
from swapping_plotting import plot_coords_across_layers

# Read the saved mirror-swapping results from disk
mirror_swap_results = torch.load(MIRROR_SWAP_DIR / "mirror_swapping_results.pt")
print(f"Loaded mirror swapping results for {len(mirror_swap_results)} videos")

# Peek at the first entry to show how the results are keyed
sample_key = list(mirror_swap_results.keys())[0]
print(f"Sample key: {sample_key}")
print(f"Available modalities: {list(mirror_swap_results[sample_key].keys())}")

# Layers 1 through 27 are drawn, with all three modalities on the same figure
layers_to_plot = [*range(1, 28)]

plot_coords_across_layers(
    mirror_swap_results,
    layers=layers_to_plot,
    coords=['text', 'image', 'text-objonly'],
    ylim=[-1, 2],
    figsize=(10.5, 4),
    save_path=FIG_DIR / "figure_10a_mirror_swapping.png",
)

print("✓ Figure 10a saved")

## Figure 10b: Temporal ID Grid (PCA Projection)

Shows a 1D PCA projection of the temporal IDs for frames 1-7, together with the "Before" and "After" text embeddings.

In [None]:
def plot_1d_pca(big_embeds_dict, word=None, layer=9, before_emb=None, after_emb=None, save_path=None):
    """
    Plot a 1D PCA projection of per-frame temporal embeddings.

    The first principal component is fitted on the frame embeddings only. The
    axis is oriented so that the first frame does not lie to the right of the
    last frame, and projections are min-max normalised using the range spanned
    by the frames. The optional "before"/"after" embeddings are projected with
    the same component and the same normalisation, so they may fall outside
    [0, 1].

    Args:
        big_embeds_dict: Embeddings indexed as [layer][frame] -> tensor when
            ``word`` is None (average_embeds_dict format), otherwise as
            [word][layer][frame] -> tensor.
        word: Object word (None for average embeddings)
        layer: Layer number to visualize
        before_emb: Embedding for "before" text (shape: hidden_dim)
        after_emb: Embedding for "after" text (shape: hidden_dim)
        save_path: Path to save the figure (nothing is saved when falsy)

    Returns:
        Explained variance ratio of the first principal component.
    """
    # Select the frame -> tensor mapping once instead of duplicating the stacking code
    per_frame = big_embeds_dict[layer] if word is None else big_embeds_dict[word][layer]
    frame_nums = sorted(per_frame.keys())
    vectors = torch.stack([
        per_frame[f].to(dtype=torch.float32).cpu()
        for f in frame_nums
    ])

    # Run PCA to get 1D projection
    pca = PCA(n_components=1)
    pcs = pca.fit_transform(vectors.numpy())  # shape (num_frames, 1)

    # Flip axis if needed: ensure early frames (low index) have lower projection values
    # This ensures "before" is on the left and "after" is on the right.
    # components_ is flipped too so that later pca.transform calls use the same orientation.
    if pcs[0, 0] > pcs[-1, 0]:
        pcs = -pcs
        pca.components_ = -pca.components_

    # Normalize to [0, 1]
    pcs_min, pcs_max = pcs.min(), pcs.max()
    span = pcs_max - pcs_min
    if span == 0:
        # Degenerate case (all frames project to the same value): avoid a
        # division by zero that would turn every plotted coordinate into NaN.
        span = 1.0
    pcs_norm = (pcs - pcs_min) / span

    def _project(emb):
        """Project one embedding onto the fitted component and apply the frame normalisation."""
        if emb is None:
            return None
        emb_np = emb.to(dtype=torch.float32).cpu().numpy()
        if emb_np.ndim == 1:
            # PCA.transform expects a 2D (n_samples, n_features) array
            emb_np = emb_np.reshape(1, -1)
        return (pca.transform(emb_np).squeeze() - pcs_min) / span

    projected_B = _project(before_emb)
    projected_A = _project(after_emb)

    # Plot
    colors = plt.get_cmap("rainbow")(np.linspace(0, 1, len(frame_nums)))
    plt.figure(figsize=(10, 1.5))

    for i, val in enumerate(pcs_norm):
        plt.scatter(val, 0, color=colors[i], label=f"Frame {frame_nums[i]}", s=80)

    if projected_B is not None:
        plt.scatter(projected_B, 0, color='black', marker='$B$', s=120, label='$B$ (Before)')

    if projected_A is not None:
        plt.scatter(projected_A, 0, color='black', marker='$A$', s=150, label='$A$ (After)')

    word_label = word if word else "Average"
    variance_ratio = pca.explained_variance_ratio_[0]
    plt.yticks([])
    plt.xlabel("First Principal Component")
    plt.title(f'1D PCA of {word_label} Embeddings at Layer {layer}\nVar explained: {variance_ratio:.2f}')
    plt.grid(True, axis='x', linestyle='--', alpha=0.5)
    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

    print(f"Explained variance ratio: {variance_ratio:.4f}")
    return variance_ratio


# Load temporal embeddings
temporal_data = torch.load(EMBEDS_DIR / "temporal_ids.pt")
big_embeds_dict = temporal_data["big_embeds_dict"]
average_embeds_dict = temporal_data["average_embeds_dict"]

print(f"Loaded embeddings for {len(big_embeds_dict)} objects")
print(f"Layers available: {sorted(average_embeds_dict.keys())}")

# Load text embeddings for "before" and "after" if available
text_embeds_path = EMBEDS_DIR / "text_embeddings.pt"
before_hidden = None
after_hidden = None

if text_embeds_path.exists():
    text_embeds = torch.load(text_embeds_path)
    before_hidden = text_embeds.get("before_hidden")
    after_hidden = text_embeds.get("after_hidden")
    print("Loaded text embeddings for 'before' and 'after'")
else:
    print("Text embeddings not found - plotting without before/after markers")


def _embedding_at_layer(hidden, layer):
    """Return ``hidden[layer]``, or None when ``hidden`` is missing or empty.

    The check is written with ``is None`` / ``len`` rather than plain truthiness:
    truth-testing a tensor with more than one element raises a RuntimeError,
    while ``len`` works for tensors, dicts and lists alike.
    """
    if hidden is None or len(hidden) == 0:
        return None
    return hidden[layer]


# Plot for selected layer
layer_to_plot = 12
before_emb = _embedding_at_layer(before_hidden, layer_to_plot)
after_emb = _embedding_at_layer(after_hidden, layer_to_plot)

plot_1d_pca(
    average_embeds_dict,
    word=None,
    layer=layer_to_plot,
    before_emb=before_emb,
    after_emb=after_emb,
    save_path=FIG_DIR / "figure_10b_temporal_id_grid.png"
)
print("✓ Figure 10b saved")

## Figure 10c: Temporal ID Steering

Shows how steering with different temporal IDs affects the model's beliefs about "before" vs. "after".

In [None]:
# Steering results are stored in .pt format (produced by temporal_steering.py)
# Layout: big_dict[(filename, true_before)][id1][layer] = {'delta_before': tensor, 'delta_after': tensor}
steering_results = torch.load(STEERING_DIR / "steering_results.pt")

print(f"Loaded steering results for {len(steering_results)} videos")

# Tally entries by the true_before flag, i.e. the second element of each key.
# Equality comparison is kept deliberately so the counts match the stored flag values.
count_true = len([key for key in steering_results if key[1] == True])
count_false = len([key for key in steering_results if key[1] == False])
print(f"  true_before=True (original answer is 'before'): {count_true}")
print(f"  true_before=False (original answer is 'after'): {count_false}")

# Shared axis definitions for the helpers below (based on steering_plotting.py):
# x_vals are the plot positions 0-6, x_labels the matching intervention IDs 1-7
x_vals = list(range(7))
x_labels = [position + 1 for position in x_vals]

def dict_to_lists(d):
    """Return the values of ``d`` as a list ordered by the module-level ``x_labels``."""
    ordered_values = []
    for intervention_id in x_labels:
        ordered_values.append(d[intervention_id])
    return ordered_values

def compute_mean_std_by_condition(big_dict, layer, answer_side):
    """
    Computes means and stds for 'delta_before' and 'delta_after' given a specified answer_side (True/False).

    answer_side=True means true_before=True, i.e., original correct answer is "before"
    answer_side=False means true_before=False, i.e., original correct answer is "after"

    Entries of ``big_dict`` are keyed by (filename, true_before); only keys whose
    second element equals ``answer_side`` contribute. Intervention IDs or layers
    missing from an entry are skipped. An intervention ID with no collected
    values reports mean 0.0 and std 0.0; one with a single value reports std 0.0.

    Returns: means_before, stds_before, means_after, stds_after (as lists ordered by x_labels)
    """
    collect_before = {k: [] for k in x_labels}
    collect_after = {k: [] for k in x_labels}

    for kkey, per_intervention in big_dict.items():
        # kkey = (filename, true_before)
        if kkey[1] != answer_side:
            continue
        for id1 in x_labels:
            if id1 not in per_intervention or layer not in per_intervention[id1]:
                continue
            data = per_intervention[id1][layer]
            collect_before[id1].append(data['delta_before'])
            collect_after[id1].append(data['delta_after'])

    means_before, stds_before = {}, {}
    means_after, stds_after = {}, {}

    for id1 in x_labels:
        means_before[id1], stds_before[id1] = _mean_and_std(collect_before[id1])
        means_after[id1], stds_after[id1] = _mean_and_std(collect_after[id1])

    return (
        dict_to_lists(means_before),
        dict_to_lists(stds_before),
        dict_to_lists(means_after),
        dict_to_lists(stds_after)
    )


def _mean_and_std(samples):
    """Return (mean, std) over all elements of the stacked ``samples`` as Python floats."""
    if not samples:
        return 0.0, 0.0
    stacked = torch.stack(samples)
    mean = torch.mean(stacked).item()
    # torch.std applies Bessel's correction by default, which yields NaN for a
    # single value; report zero spread instead so plotted bands stay finite.
    std = torch.std(stacked).item() if stacked.numel() > 1 else 0.0
    return mean, std


def print_values(layer, label, data):
    """Print a Markdown-style table of mean deltas per intervention ID.

    ``data`` is the 4-tuple (means_before, stds_before, means_after, stds_after);
    only the two mean lists are printed, paired with the module-level ``x_labels``.
    """
    before_means, _, after_means, _ = data

    # Title line followed by the table header and its separator row
    header_lines = (
        f"\n=== Layer {layer} — Original: {label} ===",
        "| Intervention | mean ΔP('before') | mean ΔP('after') | before - after |",
        "|-------------|------------------|------------------|----------------|",
    )
    for line in header_lines:
        print(line)

    # One row per intervention ID, with the before-minus-after gap in the last column
    for interv, mean_b, mean_a in zip(x_labels, before_means, after_means):
        print(f"| {interv} | {mean_b:+.4f} | {mean_a:+.4f} | {mean_b - mean_a:+.4f} |")


def plot_single_condition(ax, data, layer, condition_label):
    """Draw the mean ± std curves of one condition onto ``ax``.

    ``data`` is the 4-tuple (means_before, stds_before, means_after, stds_after).
    The 'after' series is drawn first, then the 'before' series, each as a
    line with a shaded band one standard deviation wide on either side.
    """
    m_before, s_before, m_after, s_after = data

    # Drawing order is kept (after, then before) so each series gets the same cycle colour
    series = (
        ("ΔP('after')", m_after, s_after),
        ("ΔP('before')", m_before, s_before),
    )
    for legend_label, means, stds in series:
        lower = [mean - std for mean, std in zip(means, stds)]
        upper = [mean + std for mean, std in zip(means, stds)]
        ax.plot(x_vals, means, label=legend_label)
        ax.fill_between(x_vals, lower, upper, alpha=0.3)

    # Positions 0-6 on the axis are labelled with the intervention IDs
    ax.set_xticks(x_vals)
    ax.set_xticklabels(x_labels)
    ax.set_xlabel("Intervention ID on subject")
    ax.set_ylabel("Change in Log Prob")
    ax.set_title(f"Layer {layer} — Original: {condition_label}")
    ax.legend()


def plot_all_layers(univ, layers=(13,), w1="before", w2="after"):
    """Plot steering effects for each requested layer and save the figure.

    One row of two subplots is drawn per layer: the left subplot covers entries
    whose original answer is "before" (true_before=True), the right subplot
    those whose original answer is "after" (true_before=False). The per-layer
    mean values are also printed as tables.

    Args:
        univ: Steering results keyed by (filename, true_before).
        layers: Layers to plot, one subplot row each (defaults to layer 13 only).
        w1: Label used for the true_before=True condition.
        w2: Label used for the true_before=False condition.

    The figure is written to FIG_DIR as figure_10c_temporal_steering.png and .pdf.
    """
    # Materialise once: callers may pass any iterable, and the length is needed below
    layers = list(layers)

    # squeeze=False always yields a 2D axes array, so a single layer needs no special case
    fig, axes = plt.subplots(
        len(layers), 2, figsize=(10, 3 * len(layers)), squeeze=False
    )

    for (ax_left, ax_right), layer in zip(axes, layers):
        # answer_side=True means original correct answer is "before"
        # answer_side=False means original correct answer is "after"
        data_before = compute_mean_std_by_condition(univ, layer, answer_side=True)
        data_after = compute_mean_std_by_condition(univ, layer, answer_side=False)

        # Print values
        print_values(layer, w1, data_before)
        print_values(layer, w2, data_after)

        # Left subplot: Original answer is "before" (true_before=True)
        plot_single_condition(ax_left, data_before, layer, w1)
        # Right subplot: Original answer is "after" (true_before=False)
        plot_single_condition(ax_right, data_after, layer, w2)

    # Save through the figure handle so an unrelated "current figure" is never written
    fig.tight_layout()
    fig.savefig(FIG_DIR / "figure_10c_temporal_steering.png", dpi=300, bbox_inches='tight')
    fig.savefig(FIG_DIR / "figure_10c_temporal_steering.pdf", bbox_inches='tight')
    plt.show()

# Render and save the steering figure, restricted to layer 13
plot_all_layers(steering_results, [13])
print("✓ Figure 10c saved")