In [15]:
from __future__ import annotations
import shutil
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import dotenv
import numpy as np
import os
from pathlib import Path
import tqdm
import json
from pydantic import BaseModel, ConfigDict
"""
8B04545E-B159-4A65-AE77-D474D853FE2E

We simply look at the length statistics of both the generations:
"""
# Locate the steering generations: one JSON file per (magnitude, layer) run.
assert "USEABLES_DIR" in os.environ
results_dir = Path.cwd() / os.environ["USEABLES_DIR"] / "steering_generations"
# Alternative input location:
# assert os.environ.get("OUTPUT_DIR", None) is not None
# results_dir = Path.cwd() / os.environ["OUTPUT_DIR"] / "batched_steering"
assert results_dir.is_dir()
# The directory is expected to be flat (no sub-directories).
assert all(entry.is_file() for entry in results_dir.iterdir())
jsons = [*results_dir.glob("*.json")]
# (dataset, layer) -> list of per-magnitude statistics, filled by the loading loop.
dataset_layer_all_histos = dict()
class AllHistosInfo(BaseModel):
    """Length statistics for one (dataset, layer, magnitude) group of generations.

    One instance is created per dataset per results JSON file by the loading loop;
    ``lens`` keeps the raw per-generation lengths so histograms can be plotted later.
    """
    # numpy arrays are not a pydantic-native type; this allows the `lens` field.
    model_config = ConfigDict(arbitrary_types_allowed=True)
    mean: float  # mean of `lens`
    max: float  # largest value in `lens`
    min: float  # smallest value in `lens`
    stdev: float  # standard deviation of `lens`, as supplied by the caller
    stderr: float  # intended as the standard error of the mean; value is whatever the caller computes
    lens: np.ndarray  # one length per generation
    mag: float  # steering magnitude parsed from the source filename
    layer: int  # layer index parsed from the source filename

def get_mag_layer(name: str) -> tuple[float, int]:
    """Parse ``(magnitude, layer)`` out of a results filename such as ``mag1.0_layer8.json``.

    The magnitude is the text after the first ``mag`` up to the next ``_``; the layer
    is the text after the first ``_layer`` up to the next ``.``.
    """
    after_mag = name.split("mag")[1]
    magnitude = float(after_mag.split("_")[0])
    after_layer = name.split("_layer")[1]
    layer_idx = int(after_layer.split(".")[0])
    return magnitude, layer_idx
print("Total num jsons: ", len(jsons))
for j in jsons:
    # Filenames look like `mag1.0_layer8.json`: extract magnitude and layer from them.
    mag, layer = get_mag_layer(j.name)
    # Each file maps dataset name -> list of interactions (generations).
    # JSON is UTF-8 by spec, so don't depend on the platform's default encoding.
    dataset2interactions = json.loads(j.read_text(encoding="utf-8"))
    for dataset, interactions in dataset2interactions.items():
        # One length per interaction. NOTE(review): whether len() counts characters,
        # messages or something else depends on the JSON schema, not visible here.
        lens = np.array([len(x) for x in interactions])
        std = lens.std()  # population standard deviation (ddof=0)
        # Standard error of the mean is std / sqrt(n); dividing by n instead would
        # understate it by a factor of sqrt(n).
        stderr = std / np.sqrt(len(lens))
        dataset_layer_all_histos.setdefault((dataset, layer), []).append(
            AllHistosInfo(
                mean=lens.mean().item(),
                max=lens.max().item(),
                min=lens.min().item(),
                stdev=std.item(),
                stderr=stderr.item(),
                lens=lens,
                mag=mag,
                layer=layer,
            )
        )

# Distinct colors (matplotlib's Tableau palette), one per magnitude within a plot.
colors = list(mcolors.TABLEAU_COLORS.values())
assert len(colors) == len(set(colors))
# Within each (dataset, layer) group, order the per-magnitude stats by magnitude.
dataset_layer_all_histos = {
    k: sorted(v, key=lambda x: x.mag) for k, v in dataset_layer_all_histos.items()
}
# Order groups by layer (outer), then by dataset name (inner).
items = sorted(dataset_layer_all_histos.items(), key=lambda kv: (kv[0][1], kv[0][0]))
num_histos_per_layer_dataset = {k: len(v) for k, v in items}
print(num_histos_per_layer_dataset)
# Every (dataset, layer) group must have been run at the same number of magnitudes.
num_histos_per_layer_dataset = np.array([len(v) for v in dataset_layer_all_histos.values()])
_max, _min = num_histos_per_layer_dataset.max(), num_histos_per_layer_dataset.min()
assert _max == _min

assert os.environ.get("OUTPUT_DIR", None) is not None
output_dir = Path.cwd() / os.environ["OUTPUT_DIR"] / "viz"
# Always start from a fresh directory. Removing only non-empty directories would leave
# an empty one in place and trip the assert below.
if output_dir.exists():
    shutil.rmtree(output_dir)
assert not output_dir.exists()
output_dir.mkdir(parents=True, exist_ok=True)

do_show: bool = False
for (dataset, layer), all_histos in tqdm.tqdm(items, desc="Processing datasets"):
    plt.figure(figsize=(10, 6))
    assert len(all_histos) <= len(colors)  # one distinct color per magnitude

    # Pass 1: draw every histogram first so the y-axis limits are final before the
    # mean/stderr markers are positioned relative to them.
    for j, histo in enumerate(all_histos):
        plt.hist(
            histo.lens,
            bins=20,
            alpha=0.5,
            color=colors[j],
            label=f'mag={histo.mag}, (mean={histo.mean:.2f}, stderr={histo.stderr:.4f})',
        )
    _, ymax = plt.ylim()

    # Pass 2: a dashed vertical line at each mean, plus a black +/- 1 stderr bar.
    # Bars are stacked downwards (one row per magnitude) so they don't overlap.
    for j, histo in enumerate(all_histos):
        plt.vlines(x=histo.mean, ymin=0, ymax=ymax * 0.9, color=colors[j], linestyle='--', linewidth=2)
        bar_y = ymax * 0.85 - j * (ymax * 0.05)
        lo, hi = histo.mean - histo.stderr, histo.mean + histo.stderr
        plt.hlines(y=bar_y, xmin=lo, xmax=hi, color='black', linewidth=2)
        plt.plot([lo, hi], [bar_y, bar_y], color='black', marker='|', markersize=8, linestyle='')

    # `all_histos` has one entry per magnitude; the sample count is the number of lengths.
    n_samples = sum(len(histo.lens) for histo in all_histos)
    plt.title(
        f"Layer `{layer}`, dataset `{dataset}` Length Distribution "
        f"(Magnitudes: {len(all_histos)}, total samples: {n_samples})"
    )
    plt.xlabel("Length")
    plt.ylabel("Frequency")
    plt.legend()
    plt.grid(alpha=0.3)
    # Save BEFORE showing: pyplot closes the figure once show() returns, so saving
    # afterwards writes a blank image.
    plt.savefig(output_dir / f"length_histogram_{dataset}_{layer}.png")
    if do_show:
        plt.show()
    plt.close()


Total num jsons:  90
{('awesome', 2): 6, ('gsm8k', 2): 6, ('leetcode', 2): 6, ('reasoning', 2): 6, ('awesome', 3): 6, ('gsm8k', 3): 6, ('leetcode', 3): 6, ('reasoning', 3): 6, ('awesome', 4): 6, ('gsm8k', 4): 6, ('leetcode', 4): 6, ('reasoning', 4): 6, ('awesome', 5): 6, ('gsm8k', 5): 6, ('leetcode', 5): 6, ('reasoning', 5): 6, ('awesome', 6): 6, ('gsm8k', 6): 6, ('leetcode', 6): 6, ('reasoning', 6): 6, ('awesome', 7): 6, ('gsm8k', 7): 6, ('leetcode', 7): 6, ('reasoning', 7): 6, ('awesome', 8): 6, ('gsm8k', 8): 6, ('leetcode', 8): 6, ('reasoning', 8): 6, ('awesome', 9): 6, ('gsm8k', 9): 6, ('leetcode', 9): 6, ('reasoning', 9): 6, ('awesome', 10): 6, ('gsm8k', 10): 6, ('leetcode', 10): 6, ('reasoning', 10): 6, ('awesome', 11): 6, ('gsm8k', 11): 6, ('leetcode', 11): 6, ('reasoning', 11): 6, ('awesome', 12): 6, ('gsm8k', 12): 6, ('leetcode', 12): 6, ('reasoning', 12): 6, ('awesome', 13): 6, ('gsm8k', 13): 6, ('leetcode', 13): 6, ('reasoning', 13): 6, ('awesome', 14): 6, ('gsm8k', 14): 6, 

Processing datasets: 100%|██████████| 60/60 [00:15<00:00,  3.94it/s]


In [21]:
# Create heatmap tables of each mean length expressed as a ratio to the zero-magnitude mean length

# Helper functions
def extract_data_tensor(
    items: list[tuple[tuple[str, int], list[AllHistosInfo]]],
) -> tuple[np.ndarray, list[str], list[int], list[float], int]:
    """Pack per-(dataset, layer, magnitude) mean lengths into a dense tensor.

    Args:
        items: ``((dataset, layer), histos)`` pairs; each histo must expose ``.mag``
            and ``.mean``.

    Returns:
        ``(tensor, datasets, layers, mags, zero_mag_idx)``: ``tensor`` has shape
        ``(len(datasets), len(layers), len(mags))`` and holds each group's mean length;
        ``datasets``, ``layers`` and ``mags`` are the sorted unique axis labels; and
        ``zero_mag_idx`` is the index of magnitude 0 in ``mags``.

    Raises:
        AssertionError: if the (dataset, layer, magnitude) grid is incomplete or has
            duplicates, or if magnitude 0 is absent.
    """
    # Sorted unique labels for each axis.
    datasets = sorted({key[0] for key, _ in items})
    layers = sorted({key[1] for key, _ in items})
    all_mags = sorted({histo.mag for _, histos in items for histo in histos})

    dataset_to_idx = {dataset: i for i, dataset in enumerate(datasets)}
    layer_to_idx = {layer: i for i, layer in enumerate(layers)}
    mag_to_idx = {mag: i for i, mag in enumerate(all_mags)}

    # The grid must be complete: one group per (dataset, layer), each covering every
    # magnitude exactly once.
    assert len(items) == len(datasets) * len(layers)
    assert all(len(histos) == len(all_mags) for _, histos in items)
    assert all({histo.mag for histo in histos} == set(all_mags) for _, histos in items)

    # NaN marks "no value"; the asserts above mean every cell gets filled.
    tensor = np.full((len(datasets), len(layers), len(all_mags)), np.nan)
    for (dataset, layer), histos in items:
        d_idx = dataset_to_idx[dataset]
        l_idx = layer_to_idx[layer]
        for histo in histos:
            tensor[d_idx, l_idx, mag_to_idx[histo.mag]] = histo.mean

    # Magnitude 0 is the baseline the tables are normalized against.
    assert 0 in all_mags
    zero_mag_idx = all_mags.index(0)
    return tensor, datasets, layers, all_mags, zero_mag_idx

def normalize_tensor(tensor: np.ndarray, zero_mag_idx: int) -> np.ndarray:
    """Divide every ``[dataset, layer, :]`` row by that row's zero-magnitude entry.

    Rows whose zero-magnitude value is NaN or exactly 0 cannot be normalized and are
    left filled with zeros.

    Args:
        tensor: 3-D array indexed as ``[dataset, layer, magnitude]``.
        zero_mag_idx: position of the zero magnitude along the last axis.

    Returns:
        A new array with the same shape and dtype as ``tensor``.
    """
    # Keep a trailing length-1 axis so the baseline broadcasts across magnitudes.
    baseline = tensor[:, :, zero_mag_idx, np.newaxis]
    usable = ~np.isnan(baseline) & (baseline != 0)
    result = np.zeros_like(tensor)
    # Only divide where the baseline is usable; other cells keep their initial 0.
    np.divide(tensor, baseline, out=result, where=usable, casting="unsafe")
    return result

def visualize_tables(normalized_tensor, datasets, layers, magnitudes):
    """Save one annotated heatmap per dataset into a single figure.

    Args:
        normalized_tensor: array indexed as ``[dataset, layer, magnitude]`` holding each
            mean length divided by the zero-magnitude mean of the same (dataset, layer).
        datasets: dataset names, in the order of axis 0.
        layers: layer labels, in the order of axis 1 (heatmap rows).
        magnitudes: magnitude labels, in the order of axis 2 (heatmap columns).

    Writes ``magnitude_impact_tables.png`` into the module-level ``output_dir``.
    """
    # Two panels per row; enough rows for every dataset (4 datasets -> the usual 2x2).
    ncols = 2
    nrows = max(1, (len(datasets) + ncols - 1) // ncols)
    fig, axes = plt.subplots(nrows, ncols, figsize=(18, 6 * nrows), squeeze=False)
    axes = axes.flatten()

    # One shared color scale for all panels; nan-aware so missing cells don't break it.
    vmin = np.nanmin(normalized_tensor).item()
    vmax = np.nanmax(normalized_tensor).item()

    for d_idx, dataset in enumerate(datasets):
        ax = axes[d_idx]
        # A DataFrame gives the heatmap labelled rows (layers) and columns (magnitudes).
        df = pd.DataFrame(normalized_tensor[d_idx], index=layers, columns=magnitudes)
        # Values are ratios (not percentages), so label them as such and show two
        # decimals: ".1f" would render e.g. 0.96 and 1.04 both as "1.0".
        sns.heatmap(
            df,
            annot=True,
            fmt=".2f",
            cmap="YlGnBu",
            ax=ax,
            vmin=vmin,
            vmax=vmax,
            cbar_kws={'label': 'Mean length / zero-magnitude mean length'},
        )
        ax.set_title(f"Dataset: {dataset}")
        ax.set_ylabel("Layer")
        ax.set_xlabel("Magnitude")

    # Hide grid cells left over when the dataset count is odd.
    for ax in axes[len(datasets):]:
        ax.set_visible(False)

    plt.tight_layout()
    fig.savefig(output_dir / "magnitude_impact_tables.png")
    plt.close(fig)

# Main execution
# Third-party plotting dependencies used by visualize_tables; imported up front
# rather than between pipeline steps.
import pandas as pd
import seaborn as sns

# 1. Extract data into tensor
data_tensor, datasets, layers, magnitudes, zero_mag_idx = extract_data_tensor(items)
assert data_tensor.shape == (len(datasets), len(layers), len(magnitudes)), "Tensor shape mismatch"

# 2. Normalize by zero-magnitude values
normalized_tensor = normalize_tensor(data_tensor, zero_mag_idx)
assert normalized_tensor.shape == data_tensor.shape, "Normalized tensor shape mismatch"

# 3. Visualize the tables
visualize_tables(normalized_tensor, datasets, layers, magnitudes)

print(f"Tables visualization saved to {output_dir / 'magnitude_impact_tables.png'}")


Tables visualization saved to /mnt/align4_drive2/adrianoh/git2/MiscInterpExperiments/length-experiments-2025-feb-mar/output_steering_experiments/viz/magnitude_impact_tables.png
