# V-Lens Paper Figures

This notebook contains all the data and plotting logic for the figures in the paper:
**"The surprising interpretability of vision tokens in LLMs"**

Each section contains:
1. The raw data needed for the plot
2. The plotting code

This notebook is self-contained and reproducible.


In [None]:
# Setup and imports
import matplotlib.pyplot as plt
import matplotlib
import seaborn as sns
import numpy as np
from pathlib import Path

# Global plot styling: seaborn whitegrid theme plus matplotlib rc overrides
sns.set_style("whitegrid")
plt.rcParams.update({
    'figure.dpi': 150,       # on-screen resolution
    'savefig.dpi': 300,      # default resolution for saved files
    'font.family': 'sans-serif',
})

# Folder that receives every saved figure (created if it does not exist yet)
OUTPUT_DIR = Path('paper_figures_output')
OUTPUT_DIR.mkdir(exist_ok=True)
print(f"Figures will be saved to: {OUTPUT_DIR.absolute()}")


---
## Figure 1: Unified Interpretability Line Plot

Three-panel figure showing interpretability across layers for:
- (a) Nearest Neighbors (static V-Lens at input layer)
- (b) Logit Lens
- (c) Contextual NN (contextual V-Lens)

Each panel shows 9 model combinations (3 LLMs × 3 vision encoders).


In [None]:
# =============================================================================
# RAW DATA: Interpretability percentages per layer
# Format: {"llm+encoder": {layer: percentage}}
#   - keys are "<llm>+<encoder>" ids; the plotting/table cells rebuild them as
#     f"{llm}+{encoder}" from LLM_ORDER and ENCODER_ORDER, so ids must match
#   - layer indices are sparse (not every layer was evaluated)
# =============================================================================

# Nearest Neighbors (Static V-Lens) - LLM Judge evaluated
# Sampled at layers 0-4 and then every 4th layer. The llama3-8b and olmo-7b
# rows run to layer 32; the qwen2-7b rows stop at layer 28.
NN_DATA = {
    "llama3-8b+dinov2-large-336": {0: 20.33, 1: 17.0, 2: 19.0, 3: 18.0, 4: 20.0, 8: 15.0, 12: 19.0, 16: 20.0, 20: 24.0, 24: 19.0, 28: 22.0, 32: 12.0},
    "llama3-8b+siglip": {0: 23.33, 1: 30.0, 2: 24.0, 3: 29.0, 4: 28.0, 8: 31.0, 12: 31.0, 16: 29.0, 20: 34.0, 24: 27.0, 28: 27.0, 32: 29.0},
    "llama3-8b+vit-l-14-336": {0: 35.33, 1: 30.0, 2: 37.0, 3: 32.0, 4: 29.0, 8: 37.0, 12: 44.0, 16: 51.0, 20: 52.0, 24: 47.0, 28: 43.0, 32: 21.0},
    "olmo-7b+dinov2-large-336": {0: 42.0, 1: 45.0, 2: 40.0, 3: 41.0, 4: 47.0, 8: 56.0, 12: 61.0, 16: 67.0, 20: 70.0, 24: 67.0, 28: 67.0, 32: 33.0},
    "olmo-7b+siglip": {0: 41.67, 1: 39.0, 2: 38.0, 3: 28.0, 4: 39.0, 8: 45.0, 12: 55.0, 16: 49.0, 20: 55.0, 24: 56.0, 28: 53.0, 32: 22.0},
    "olmo-7b+vit-l-14-336": {0: 55.0, 1: 52.0, 2: 56.0, 3: 60.0, 4: 59.0, 8: 60.0, 12: 62.0, 16: 62.0, 20: 63.0, 24: 59.0, 28: 62.0, 32: 35.0},
    "qwen2-7b+dinov2-large-336": {0: 7.0, 1: 10.0, 2: 9.0, 3: 7.0, 4: 9.0, 8: 11.0, 12: 11.0, 16: 11.0, 20: 12.0, 24: 14.0, 28: 9.0},
    "qwen2-7b+siglip": {0: 5.33, 1: 4.0, 2: 5.0, 3: 5.0, 4: 4.0, 8: 3.0, 12: 5.0, 16: 4.0, 20: 5.0, 24: 5.0, 28: 9.0},
    "qwen2-7b+vit-l-14-336": {0: 17.67, 1: 15.0, 2: 9.0, 3: 13.0, 4: 15.0, 8: 18.0, 12: 18.0, 16: 9.0, 20: 8.0, 24: 16.0, 28: 10.0},
}

# Logit Lens - LLM Judge evaluated
# Same {"llm+encoder": {layer: percentage}} layout as the other datasets.
# In addition to layers 0-4 and every 4th layer, the last layers are sampled
# densely: 29-31 for the llama3-8b / olmo-7b rows (ending at 32) and 25-27 for
# the qwen2-7b rows (ending at 28).
LOGITLENS_DATA = {
    "llama3-8b+dinov2-large-336": {0: 9.0, 1: 7.0, 2: 9.0, 3: 11.0, 4: 5.0, 8: 10.0, 12: 10.0, 16: 11.0, 20: 9.0, 24: 7.0, 28: 7.0, 29: 9.0, 30: 13.0, 31: 7.0, 32: 7.0},
    "llama3-8b+siglip": {0: 9.0, 1: 9.0, 2: 10.0, 3: 8.0, 4: 12.0, 8: 10.0, 12: 9.0, 16: 9.0, 20: 13.0, 24: 14.0, 28: 8.0, 29: 10.0, 30: 9.0, 31: 9.0, 32: 7.0},
    "llama3-8b+vit-l-14-336": {0: 13.0, 1: 10.0, 2: 10.0, 3: 11.0, 4: 14.0, 8: 10.0, 12: 7.0, 16: 12.0, 20: 26.0, 24: 50.0, 28: 52.0, 29: 62.0, 30: 64.0, 31: 76.0, 32: 81.0},
    "olmo-7b+dinov2-large-336": {0: 11.0, 1: 13.0, 2: 13.0, 3: 13.0, 4: 17.0, 8: 15.0, 12: 23.0, 16: 39.0, 20: 61.0, 24: 78.0, 28: 76.0, 29: 78.0, 30: 69.0, 31: 56.0, 32: 32.0},
    "olmo-7b+siglip": {0: 14.0, 1: 20.0, 2: 15.0, 3: 21.0, 4: 16.0, 8: 20.0, 12: 22.0, 16: 26.0, 20: 52.0, 24: 69.0, 28: 83.0, 29: 86.0, 30: 82.0, 31: 63.0, 32: 43.0},
    "olmo-7b+vit-l-14-336": {0: 11.0, 1: 8.0, 2: 18.0, 3: 19.0, 4: 19.0, 8: 22.0, 12: 25.0, 16: 23.0, 20: 49.0, 24: 75.0, 28: 78.0, 29: 82.0, 30: 74.0, 31: 59.0, 32: 31.0},
    "qwen2-7b+dinov2-large-336": {0: 15.0, 1: 9.0, 2: 10.0, 3: 12.0, 4: 9.0, 8: 7.0, 12: 8.0, 16: 13.0, 20: 14.0, 24: 25.0, 25: 34.0, 26: 42.0, 27: 56.0, 28: 45.0},
    "qwen2-7b+siglip": {0: 8.0, 1: 7.0, 2: 9.0, 3: 7.0, 4: 9.0, 8: 8.0, 12: 6.0, 16: 8.0, 20: 7.0, 24: 11.0, 25: 6.0, 26: 11.0, 27: 6.0, 28: 12.0},
    "qwen2-7b+vit-l-14-336": {0: 6.0, 1: 4.0, 2: 6.0, 3: 2.0, 4: 3.0, 8: 7.0, 12: 12.0, 16: 8.0, 20: 9.0, 24: 43.0, 25: 51.0, 26: 59.0, 27: 78.0, 28: 71.0},
}

# Contextual NN (Contextual V-Lens) - LLM Judge evaluated
# Layer 0 = static NN (same as NN_DATA at layer 0), Layer 1+ = contextual embeddings
# Sampled at layers 0, 1, 2, 4, 8, 16, 24; "qwen2-7b+vit-l-14-336" is the only
# row with an additional layer-26 entry.
# NOTE(review): "olmo-7b+vit-l-14-336" has 54.67 at layer 0 here but 55.0 in
# NN_DATA, which contradicts the "same as NN_DATA" statement above. All other
# rows match exactly. Confirm which value is correct (55.0 may be a rounding).
CONTEXTUAL_DATA = {
    "llama3-8b+dinov2-large-336": {0: 20.33, 1: 82.0, 2: 81.0, 4: 82.0, 8: 82.0, 16: 82.0, 24: 83.33},
    "llama3-8b+siglip": {0: 23.33, 1: 63.0, 2: 60.0, 4: 65.0, 8: 62.0, 16: 58.0, 24: 63.89},
    "llama3-8b+vit-l-14-336": {0: 35.33, 1: 82.0, 2: 84.0, 4: 79.0, 8: 82.0, 16: 82.0, 24: 70.0},
    "olmo-7b+dinov2-large-336": {0: 42.0, 1: 81.0, 2: 79.0, 4: 80.0, 8: 80.0, 16: 81.0, 24: 82.43},
    "olmo-7b+siglip": {0: 41.67, 1: 63.0, 2: 64.0, 4: 65.0, 8: 64.0, 16: 65.0, 24: 60.87},
    "olmo-7b+vit-l-14-336": {0: 54.67, 1: 71.0, 2: 71.0, 4: 70.0, 8: 71.0, 16: 70.0, 24: 69.7},
    "qwen2-7b+dinov2-large-336": {0: 7.0, 1: 79.0, 2: 81.0, 4: 83.0, 8: 81.0, 16: 80.0, 24: 82.81},
    "qwen2-7b+siglip": {0: 5.33, 1: 65.0, 2: 64.0, 4: 62.0, 8: 68.0, 16: 69.0, 24: 66.67},
    "qwen2-7b+vit-l-14-336": {0: 17.67, 1: 76.0, 2: 80.0, 4: 78.0, 8: 76.0, 16: 79.0, 24: 78.0, 26: 71.43},
}


In [None]:
# =============================================================================
# PLOTTING CONFIG: Colors, markers, labels
# All dicts below are keyed by the raw llm / encoder ids that make up the
# "llm+encoder" keys of the data dictionaries.
# =============================================================================

# Display names for paper (raw id -> label shown in legends and tables).
# get_display_label falls back to the raw id when an id is missing here.
LLM_DISPLAY_NAMES = {
    'llama3-8b': 'Llama3-8B',
    'olmo-7b': 'OLMo-7B',
    'qwen2-7b': 'Qwen2-7B'
}

ENCODER_DISPLAY_NAMES = {
    'vit-l-14-336': 'CLIP ViT-L/14',
    'siglip': 'SigLIP',
    'dinov2-large-336': 'DINOv2'
}

# Order for consistent legend. These lists also drive every plotting/table
# loop, so only LLM x encoder combinations listed here are ever drawn/printed.
LLM_ORDER = ['olmo-7b', 'llama3-8b', 'qwen2-7b']
ENCODER_ORDER = ['vit-l-14-336', 'siglip', 'dinov2-large-336']

# Color scheme: each LLM gets a color family, encoders get shades
LLM_BASE_COLORS = {
    'olmo-7b': plt.cm.Blues,
    'llama3-8b': plt.cm.Greens,
    'qwen2-7b': plt.cm.Reds
}
# Colormap positions, index-aligned with ENCODER_ORDER: get_color_map passes
# ENCODER_SHADE_INDICES[i] to the LLM's colormap for ENCODER_ORDER[i].
ENCODER_SHADE_INDICES = [0.5, 0.7, 0.9]

# Markers for each encoder
ENCODER_MARKERS = {
    'vit-l-14-336': '*',       # star (filled)
    'siglip': 'o',             # circle (hollow)
    'dinov2-large-336': '^'    # triangle (filled)
}
# Marker fill per encoder. None means the plotting functions do not pass
# markerfacecolor at all (marker is filled with the line color); the string
# 'none' is passed through to matplotlib and yields an unfilled marker.
ENCODER_MARKER_FACECOLORS = {
    'vit-l-14-336': None,      # filled
    'siglip': 'none',          # hollow
    'dinov2-large-336': None   # filled
}

def get_color_map():
    """Return {(llm, encoder): RGBA color} for every LLM x encoder pair.

    Each LLM contributes its colormap from LLM_BASE_COLORS; the encoder's
    position in ENCODER_ORDER selects the shade via ENCODER_SHADE_INDICES.
    """
    return {
        (llm, encoder): LLM_BASE_COLORS[llm](ENCODER_SHADE_INDICES[shade_pos])
        for llm in LLM_ORDER
        for shade_pos, encoder in enumerate(ENCODER_ORDER)
    }

def parse_model_key(key):
    """Split an 'llm+encoder' key and return its first two fields as a tuple.

    Fields after the second '+' are ignored; a key without any '+' raises
    IndexError.
    """
    fields = key.split('+')
    llm_id = fields[0]
    encoder_id = fields[1]
    return (llm_id, encoder_id)

def get_display_label(llm, encoder):
    """Build the 'LLM + Encoder' legend label, using raw ids for unknown names."""
    return (
        f"{LLM_DISPLAY_NAMES.get(llm, llm)}"
        f" + "
        f"{ENCODER_DISPLAY_NAMES.get(encoder, encoder)}"
    )


In [None]:
# =============================================================================
# PLOTTING FUNCTION: Unified 3-panel figure
# =============================================================================

def create_unified_lineplot(nn_data, logitlens_data, contextual_data,
                            output_path=None, figsize=(18, 5)):
    """
    Create unified figure with 3 subplots showing interpretability across layers.

    Panels, left to right: (a) static V-Lens / nearest neighbors, (b) logit
    lens, (c) contextual V-Lens. One line is drawn per "llm+encoder" key built
    from LLM_ORDER x ENCODER_ORDER; keys outside that grid are ignored. A
    single legend shared by all panels is placed below the figure.

    Args:
        nn_data: Dict {"llm+encoder": {layer: percentage}} for panel (a)
        logitlens_data: Dict with the same layout for panel (b)
        contextual_data: Dict with the same layout for panel (c)
        output_path: Optional path to save the figure (written at 300 dpi).
            Unless the path already ends in ".png", a 150-dpi PNG copy is
            also written next to it.
        figsize: Figure size tuple (width, height) in inches

    Returns:
        The matplotlib Figure.
    """
    color_map = get_color_map()

    # Create figure with 3 subplots
    fig, axes = plt.subplots(1, 3, figsize=figsize)

    subplot_configs = [
        {'ax': axes[0], 'data': nn_data, 'title': '(a) Static V-Lens (NN)', 'xlabel': 'Layer'},
        {'ax': axes[1], 'data': logitlens_data, 'title': '(b) Logit Lens', 'xlabel': 'Layer'},
        {'ax': axes[2], 'data': contextual_data, 'title': '(c) Contextual V-Lens', 'xlabel': 'Layer'},
    ]

    # First line drawn for each label; reused as the handle of the shared legend
    handles_dict = {}

    for config in subplot_configs:
        ax = config['ax']
        data = config['data']

        # A missing or empty dataset leaves its panel untouched
        if not data:
            continue

        # Union of the layers sampled by any model in this panel
        all_layers = sorted({layer for layer_data in data.values() for layer in layer_data})

        # Plot lines for each model combination
        for llm in LLM_ORDER:
            for encoder in ENCODER_ORDER:
                layer_data = data.get(f"{llm}+{encoder}")
                if not layer_data:
                    continue

                layers = sorted(layer_data)
                values = [layer_data[layer] for layer in layers]

                style = {
                    'marker': ENCODER_MARKERS.get(encoder, 'o'),
                    'color': color_map[(llm, encoder)],
                    'linewidth': 2.5,
                    'markersize': 10,
                }
                # Only override the face color when one is configured;
                # None keeps matplotlib's default (filled with the line color)
                marker_facecolor = ENCODER_MARKER_FACECOLORS.get(encoder)
                if marker_facecolor is not None:
                    style['markerfacecolor'] = marker_facecolor
                    style['markeredgewidth'] = 2

                line, = ax.plot(layers, values, **style)
                handles_dict.setdefault(get_display_label(llm, encoder), line)

        # Customize subplot
        ax.set_xlabel(config['xlabel'], fontsize=14, fontweight='bold')
        ax.set_ylabel('Interpretability %', fontsize=14, fontweight='bold')
        ax.set_title(config['title'], fontsize=16, fontweight='bold', pad=10)
        ax.grid(True, alpha=0.3)
        ax.set_ylim(0, 100)
        ax.tick_params(axis='both', labelsize=11)

        if all_layers:
            ax.set_xlim(min(all_layers) - 0.5, max(all_layers) + 0.5)
            # Show subset of x-ticks if too many, always keeping the last layer
            if len(all_layers) > 10:
                step = max(1, len(all_layers) // 8)
                shown_layers = all_layers[::step]
                if all_layers[-1] not in shown_layers:
                    shown_layers.append(all_layers[-1])
                ax.set_xticks(shown_layers)
            else:
                ax.set_xticks(all_layers)

    # Create shared legend at bottom, in LLM_ORDER x ENCODER_ORDER order
    ordered_handles = []
    ordered_labels = []
    for llm in LLM_ORDER:
        for encoder in ENCODER_ORDER:
            label = get_display_label(llm, encoder)
            if label in handles_dict:
                ordered_handles.append(handles_dict[label])
                ordered_labels.append(label)

    # Skip the legend entirely when nothing was plotted
    if ordered_handles:
        fig.legend(ordered_handles, ordered_labels,
                   loc='lower center',
                   bbox_to_anchor=(0.5, -0.12),
                   ncol=3,
                   fontsize=12,
                   framealpha=0.9,
                   columnspacing=2.0,
                   handlelength=2.5)

    # Operate on this figure explicitly rather than pyplot's "current" figure
    fig.tight_layout()
    fig.subplots_adjust(bottom=0.22, wspace=0.25)

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved: {output_path}")
        # Also save PNG, unless output_path is that very PNG: the 150-dpi
        # copy would otherwise overwrite the 300-dpi file just written
        png_path = Path(output_path).with_suffix('.png')
        if png_path != Path(output_path):
            fig.savefig(png_path, dpi=150, bbox_inches='tight')
            print(f"Saved: {png_path}")

    plt.show()
    return fig


In [None]:
# Figure 1: three-panel interpretability plot, saved as PDF (plus a PNG copy)
fig1 = create_unified_lineplot(
    nn_data=NN_DATA,
    logitlens_data=LOGITLENS_DATA,
    contextual_data=CONTEXTUAL_DATA,
    output_path=OUTPUT_DIR / 'fig1_unified_interpretability.pdf',
)


---
## Figure 2: Individual Method Plots

Separate plots for each method with more detail.


In [None]:
def create_single_lineplot(data, title, xlabel='Layer', ylabel='Interpretability %',
                           output_path=None, figsize=(10, 6)):
    """
    Create a single line plot for one method.

    One line is drawn per "llm+encoder" key built from LLM_ORDER x
    ENCODER_ORDER; keys outside that grid and models with no layer data are
    skipped. An empty or None dataset produces an empty, labelled plot.

    Args:
        data: Dict {"llm+encoder": {layer: percentage}}, or None
        title: Plot title
        xlabel: X-axis label
        ylabel: Y-axis label
        output_path: Optional path to save the figure (written at 300 dpi)
        figsize: Figure size tuple (width, height) in inches

    Returns:
        The matplotlib Figure.
    """
    color_map = get_color_map()
    # Treat None like an empty dataset instead of failing on attribute access
    data = data or {}

    fig, ax = plt.subplots(figsize=figsize)

    # Union of the layers sampled by any model
    all_layers = sorted({layer for layer_data in data.values() for layer in layer_data})

    # Plot lines
    lines_drawn = 0
    for llm in LLM_ORDER:
        for encoder in ENCODER_ORDER:
            layer_data = data.get(f"{llm}+{encoder}")
            if not layer_data:
                continue

            layers = sorted(layer_data)
            values = [layer_data[layer] for layer in layers]

            style = {
                'marker': ENCODER_MARKERS.get(encoder, 'o'),
                'label': get_display_label(llm, encoder),
                'color': color_map[(llm, encoder)],
                'linewidth': 2,
                'markersize': 8,
            }
            # Only override the face color when one is configured;
            # None keeps matplotlib's default (filled with the line color)
            marker_facecolor = ENCODER_MARKER_FACECOLORS.get(encoder)
            if marker_facecolor is not None:
                style['markerfacecolor'] = marker_facecolor
                style['markeredgewidth'] = 1.5

            ax.plot(layers, values, **style)
            lines_drawn += 1

    ax.set_xlabel(xlabel, fontsize=12, fontweight='bold')
    ax.set_ylabel(ylabel, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=15)
    # A legend without any labelled line only produces a matplotlib warning
    if lines_drawn:
        ax.legend(loc='best', fontsize=9, framealpha=0.9)
    ax.grid(True, alpha=0.3)
    ax.set_ylim(0, 100)

    if all_layers:
        ax.set_xlim(min(all_layers) - 0.5, max(all_layers) + 0.5)
        ax.set_xticks(all_layers)
        # Rotate tick labels when many layers would otherwise crowd the axis
        if len(all_layers) > 15:
            ax.tick_params(axis='x', rotation=45)

    # Operate on this figure explicitly rather than pyplot's "current" figure
    fig.tight_layout()

    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        print(f"Saved: {output_path}")

    plt.show()
    return fig


In [None]:
# Individual plot: static V-Lens (nearest neighbors)
fig_nn = create_single_lineplot(
    data=NN_DATA,
    title='Static V-Lens (Nearest Neighbors) Interpretability',
    output_path=OUTPUT_DIR / 'fig_nn_interpretability.pdf',
)


In [None]:
# Individual plot: logit lens
fig_logitlens = create_single_lineplot(
    data=LOGITLENS_DATA,
    title='Logit Lens Interpretability',
    output_path=OUTPUT_DIR / 'fig_logitlens_interpretability.pdf',
)


In [None]:
# Individual plot: contextual V-Lens
fig_contextual = create_single_lineplot(
    data=CONTEXTUAL_DATA,
    title='Contextual V-Lens Interpretability',
    output_path=OUTPUT_DIR / 'fig_contextual_interpretability.pdf',
)


---
## Data Tables

Print the data in tabular format for reference.


In [None]:
def print_data_table(data, title):
    """Print one dataset as a fixed-width text table.

    Rows are the model combinations from LLM_ORDER x ENCODER_ORDER that are
    present in `data`; columns are the union of all layers found in `data`.
    Layers a model was not evaluated at are shown as '---'.
    """
    label_width = 35
    cell_width = 7

    # Title banner
    print(f"\n{'='*80}")
    print(f"{title}")
    print(f"{'='*80}")

    # Column set: every layer that occurs for at least one model
    layer_ids = sorted({layer for per_layer in data.values() for layer in per_layer})

    # Header row followed by a rule spanning the full table width
    header_cells = [f"{'L'+str(layer):>7}" for layer in layer_ids]
    print(f"{'Model':<35}" + "".join(header_cells))
    print("-" * (label_width + len(layer_ids) * cell_width))

    # One row per model combination that has data
    for llm in LLM_ORDER:
        for encoder in ENCODER_ORDER:
            key = f"{llm}+{encoder}"
            if key not in data:
                continue
            row_label = get_display_label(llm, encoder)
            per_layer = data[key]
            cells = []
            for layer in layer_ids:
                value = per_layer.get(layer)
                cells.append(f"{'---':>7}" if value is None else f"{value:>6.1f}%")
            print(f"{row_label:<35}" + "".join(cells))

# Dump all three datasets as text tables
print_data_table(data=NN_DATA, title="Static V-Lens (Nearest Neighbors) Data")
print_data_table(data=LOGITLENS_DATA, title="Logit Lens Data")
print_data_table(data=CONTEXTUAL_DATA, title="Contextual V-Lens Data")


---
## Summary Statistics


In [None]:
def compute_summary_stats(data, name):
    """Print summary statistics for a dataset.

    For each model combination from LLM_ORDER x ENCODER_ORDER present in
    `data`, prints the layer-0 value (or N/A) and the max / min / mean over
    all evaluated layers. Nothing is returned.

    Args:
        data: Dict {"llm+encoder": {layer: percentage}}
        name: Heading printed above the statistics
    """
    print(f"\n{name}:")
    print("-" * 50)

    for llm in LLM_ORDER:
        for encoder in ENCODER_ORDER:
            key = f"{llm}+{encoder}"
            if key not in data:
                continue

            layer_data = data[key]
            label = get_display_label(llm, encoder)
            print(f"  {label}:")

            # max()/min() raise ValueError on an empty sequence, so report a
            # model without any evaluated layer explicitly instead of crashing
            if not layer_data:
                print("    No data")
                continue

            values = list(layer_data.values())
            if 0 in layer_data:
                print(f"    Layer 0: {layer_data[0]:.1f}%")
            else:
                print("    Layer 0: N/A")
            print(f"    Max: {max(values):.1f}% | Min: {min(values):.1f}% | Mean: {np.mean(values):.1f}%")

# Summary statistics for each method
compute_summary_stats(data=NN_DATA, name="Static V-Lens")
compute_summary_stats(data=LOGITLENS_DATA, name="Logit Lens")
compute_summary_stats(data=CONTEXTUAL_DATA, name="Contextual V-Lens")


---
## Key Findings

Based on the data above:

1. **Static V-Lens (Layer 0)**: OLMo + CLIP shows highest interpretability (~55%), Qwen2 shows lowest (~5-17%)

2. **Logit Lens**: Shows interpretability increasing dramatically in final layers for most models (e.g., Llama3+CLIP goes from 13% to 81%)

3. **Contextual V-Lens**: Even models with low static interpretability (Qwen2) achieve high interpretability (roughly 62–83%; above 75% with DINOv2 and, at most layers, with CLIP) when using contextual embeddings from layer 1+

4. **Encoder Effect**: CLIP ViT-L/14 generally provides higher interpretability than SigLIP and DINOv2 at layer 0, but DINOv2 catches up in later contextual layers

5. **LLM Effect**: OLMo shows consistently high interpretability across all methods, while Qwen2 shows the largest gap between static and contextual approaches
