In [None]:
# Import necessary libraries
import torch
import matplotlib.pyplot as plt
import numpy as np
from utils.helpers import load_interpretable_model
import os
from tqdm import tqdm

# Import visualization utilities
from feature_vis_impala import (
    total_variation, 
    jitter, 
    random_scale, 
    random_rotate
)

# Import SAE-related modules
from sae_cnn import ConvSAE
from extract_sae_features import replace_layer_with_sae
from feature_vis_sae import load_sae_from_checkpoint

# Select the compute device: CUDA when PyTorch reports it available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the base model, move it to the selected device and call .eval() on it
model = load_interpretable_model()
model.to(device)
model.eval()

# Load the SAE from a checkpoint.
# NOTE(review): visualize_feature below indexes this value with [0], so
# load_sae_from_checkpoint presumably returns a sequence whose first item is
# the SAE module -- confirm against feature_vis_sae.
sae_checkpoint_path = "../checkpoints/sae_checkpoint_step_4500000.pt"  # Update with your checkpoint path
sae = load_sae_from_checkpoint(sae_checkpoint_path, device)
# layer_name is passed to visualize_feature, which does not read it
layer_name = "conv4a"  # Update with your target layer
# layer_number is read as a module-level global inside visualize_feature and
# forwarded to replace_layer_with_sae
layer_number = 8
feature_idx = 1  # Change to the feature you want to visualize

# Function to visualize a feature by activation maximization
# (color matrices can be loaded, but decorrelation is not applied in this version)
def visualize_feature(model, sae, layer_name, feature_idx, num_steps=1000,
                      use_decorrelation=False, color_matrices_path=None):
    """
    Visualize what maximally activates a specific SAE feature.

    A random image is optimized with Adam so that the mean activation of one
    SAE feature is maximized. Each step the image is passed through the
    jitter / random_scale / random_rotate helpers, cropped by 4 pixels per
    side, clamped to [0, 255] and scaled to [0, 1] before the forward pass.

    Args:
        model: The base model; the SAE is attached to it for the duration of
            the call and detached again afterwards, even on error.
        sae: Indexable object whose element 0 is the SAE module (a forward
            hook is registered on it).
        layer_name: Accepted for interface compatibility but not read; the
            target layer comes from the module-level ``layer_number``.
        feature_idx: Index of the feature to visualize
        num_steps: Number of optimization steps (must be >= 1)
        use_decorrelation: When true and the matrices file exists, the file is
            loaded, but the matrices are NOT applied in this version.
        color_matrices_path: Path to color matrices file

    Returns:
        (image, best_activation): ``image`` is a numpy array (H, W, C) with
        values in [0, 1] -- the transformed view that scored highest, not the
        raw optimized tensor; ``best_activation`` is that score as a float.

    Raises:
        ValueError: if ``num_steps`` is smaller than 1.
        RuntimeError: if the SAE hook captured no activations during a forward
            pass, or if no step produced a comparable (non-NaN) activation.
    """
    # Fail fast: with zero steps there is no image to return.
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    # Set up color matrices if using decorrelation
    if use_decorrelation and color_matrices_path and os.path.exists(color_matrices_path):
        print(f"Loading color matrices from {color_matrices_path}")
        # SECURITY: torch.load unpickles the file; only load trusted files.
        data = torch.load(color_matrices_path, map_location=device)
        # Loaded (which validates the file's keys) but not applied below.
        whitening_matrix = data['whitening_matrix'].to(device)
        unwhitening_matrix = data['unwhitening_matrix'].to(device)
        mean_color = data['mean_color'].to(device)
    else:
        use_decorrelation = False
        print("Not using color decorrelation")

    sae_module = sae[0]

    # Filled by the forward hook on every model call.
    activations = {}

    def sae_activation_hook(module, input, output):
        # For ConvSAE, the 3rd return value contains the activations
        if isinstance(output, tuple) and len(output) >= 3:
            activations['sae_features'] = output[2]
        return output

    best_activation = float('-inf')
    best_img = None

    # Attach SAE to the model; `layer_number` is a module-level global.
    sae_hook = replace_layer_with_sae(model, sae_module, layer_number)
    capture_hook = None
    try:
        capture_hook = sae_module.register_forward_hook(sae_activation_hook)

        # Initialize image
        padded_size = 64 + 8  # 4 pixels on each side
        input_img = torch.randint(0, 256, (1, 3, padded_size, padded_size),
                                  device=device,
                                  dtype=torch.float32).requires_grad_(True)

        # Set up optimizer
        optimizer = torch.optim.Adam([input_img], lr=0.08)

        # Parameters for transformations
        scales = [1.0, 0.975, 1.025, 0.95, 1.05]
        angles = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]

        # Optimization loop
        pbar = tqdm(range(num_steps))
        for _ in pbar:
            optimizer.zero_grad()

            # Work on a copy so the transforms never touch the leaf tensor
            processed_img = input_img.clone()

            # Apply transformations
            ox, oy = np.random.randint(-8, 9, 2)
            processed_img = jitter(processed_img, ox, oy)
            processed_img = random_scale(processed_img, scales)
            processed_img = random_rotate(processed_img, angles)
            ox, oy = np.random.randint(-4, 5, 2)
            processed_img = jitter(processed_img, ox, oy)

            # Crop padding
            processed_img = processed_img[:, :, 4:-4, 4:-4]

            # Clamp values to [0, 255]; done on .data so autograd does not
            # record the clamp (kept as in the original).
            processed_img.data.clamp_(0, 255)

            # Normalize to [0,1] for model input
            normalized_img = processed_img / 255.0

            # Drop the previous capture so a hook that fails to fire is detected.
            activations.clear()
            _ = model(normalized_img)

            if 'sae_features' not in activations:
                # Without the capture the loss would contain only the
                # regularizers and the result would be meaningless.
                raise RuntimeError(
                    "SAE forward hook captured no activations; expected the SAE "
                    "to return a tuple with at least 3 elements"
                )

            feature_activation = activations['sae_features'][0, feature_idx]
            activation_loss = -feature_activation.mean()  # negative because we want to maximize

            # Regularizers are computed on the raw (untransformed) image
            tv_loss = 1e-3 * total_variation(input_img / 255.0)
            l2_loss = 1e-3 * torch.norm(input_img / 255.0)

            # Total loss
            loss = activation_loss + tv_loss + l2_loss

            loss.backward()
            optimizer.step()

            # Post-processing
            with torch.no_grad():
                input_img.data.clamp_(0, 255)

                # Track best activation and the view that produced it
                current_activation = -activation_loss.item()
                if current_activation > best_activation:
                    best_activation = current_activation
                    best_img = processed_img.clone()

            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'act': current_activation})
    finally:
        # Always detach both hooks so a failed run does not leave the model
        # patched for later calls.
        if capture_hook is not None:
            capture_hook.remove()
        sae_hook.remove()

    if best_img is None:
        # Only reachable when every activation compared false against -inf (e.g. NaN).
        raise RuntimeError("No optimization step produced a valid activation")

    # Convert to numpy for display: (1, C, H, W) in [0, 255] -> (H, W, C) in [0, 1]
    result = (best_img.detach().cpu().squeeze(0).permute(1, 2, 0) / 255.0).numpy()

    return result, best_activation

# Optimize an image for the configured feature.
# Lower num_steps for faster (rougher) results; decorrelation stays off here
# so the plain pipeline can be checked first.
vis, activation = visualize_feature(
    model,
    sae,
    layer_name,
    feature_idx,
    num_steps=5000,
    use_decorrelation=False,
    color_matrices_path="color_matrices.pt",
)

# Show the result on a single square panel without axis decorations
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(vis)
ax.set_title(f'SAE Feature {feature_idx} (Act: {activation:.2f})')
ax.axis('off')
plt.show()

In [None]:
# Import necessary libraries
import torch
import matplotlib.pyplot as plt
import numpy as np
from utils.helpers import load_interpretable_model
import os
from tqdm import tqdm

# Import visualization utilities
from feature_vis_impala import (
    total_variation, 
    jitter, 
    random_scale, 
    random_rotate
)

# Import SAE-related modules
from sae_cnn import ConvSAE
from extract_sae_features import replace_layer_with_sae
from feature_vis_sae import load_sae_from_checkpoint

# Select the compute device: CUDA when PyTorch reports it available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the base model, move it to the selected device and call .eval() on it
model = load_interpretable_model()
model.to(device)
model.eval()

# Load the SAE from a checkpoint (path is relative to the parent directory).
# NOTE(review): visualize_feature indexes its `sae` argument with [0], so
# load_sae_from_checkpoint presumably returns a sequence whose first item is
# the SAE module -- confirm against feature_vis_sae.
sae_checkpoint_path = "../checkpoints/sae_checkpoint_step_4500000.pt"  # Adjusted with ..
sae_result = load_sae_from_checkpoint(sae_checkpoint_path, device)
# layer_name is passed to visualize_feature, which does not read it
layer_name = "conv4a"  # Update with your target layer
# layer_number is read as a module-level global inside visualize_feature and
# forwarded to replace_layer_with_sae
layer_number = 8

# Feature indices to visualize, one grid panel per entry
feature_indices = [0, 1, 2, 3]  # Change to the features you want to visualize

# Function to apply whitening to an image tensor
def apply_whitening(x, whitening_matrix, mean_color=None):
    """Whiten the color channels of a (B, 3, H, W) image tensor.

    Every pixel's 3-vector is optionally centered by ``mean_color`` and then
    multiplied by the transpose of ``whitening_matrix``. The result has the
    same (B, 3, H, W) layout as the input.
    """
    # Channels-last view so each row of the flattened matrix is one pixel
    channels_last = x.permute(0, 2, 3, 1)
    batch, height, width = channels_last.shape[:3]
    pixels = channels_last.reshape(-1, 3)

    # Centering is optional
    if mean_color is not None:
        pixels = pixels - mean_color

    # Row-vector convention: pixel @ W^T
    projected = pixels.mm(whitening_matrix.t())

    # Restore the channels-first image layout
    return projected.reshape(batch, height, width, 3).permute(0, 3, 1, 2)

# Function to apply unwhitening to an image tensor
def apply_unwhitening(x, unwhitening_matrix, mean_color=None):
    """Map a (B, 3, H, W) image tensor through the unwhitening matrix.

    Every pixel's 3-vector is multiplied by the transpose of
    ``unwhitening_matrix``; ``mean_color`` is added afterwards when given.
    The result has the same (B, 3, H, W) layout as the input.
    """
    # Channels-last view so each row of the flattened matrix is one pixel
    channels_last = x.permute(0, 2, 3, 1)
    batch, height, width = channels_last.shape[:3]
    pixels = channels_last.reshape(-1, 3)

    # Row-vector convention: pixel @ U^T
    restored = pixels.mm(unwhitening_matrix.t())

    # Re-adding the mean is optional
    if mean_color is not None:
        restored = restored + mean_color

    # Restore the channels-first image layout
    return restored.reshape(batch, height, width, 3).permute(0, 3, 1, 2)

# Function to visualize a feature with optional color decorrelation
def visualize_feature(model, sae, layer_name, feature_idx, num_steps=1000,
                      use_decorrelation=False, color_matrices_path=None):
    """
    Visualize what maximally activates a specific SAE feature.

    A random image is optimized with Adam so that the mean activation of one
    SAE feature is maximized. Each step the image is passed through the
    jitter / random_scale / random_rotate helpers, cropped by 4 pixels per
    side, clamped to [0, 255] and scaled to [0, 1]. With decorrelation active
    the model is fed a whitened, min/max-rescaled view of that image instead.

    Args:
        model: The base model; the SAE is attached to it for the duration of
            the call and detached again afterwards, even on error.
        sae: Indexable object whose element 0 is the SAE module (a forward
            hook is registered on it).
        layer_name: Accepted for interface compatibility but not read; the
            target layer comes from the module-level ``layer_number``.
        feature_idx: Index of the feature to visualize
        num_steps: Number of optimization steps (must be >= 1)
        use_decorrelation: Whether to use color decorrelation. Silently
            disabled when ``color_matrices_path`` is missing or does not exist.
        color_matrices_path: Path to a file loadable with ``torch.load`` that
            holds 'whitening_matrix', 'unwhitening_matrix' and 'mean_color'.

    Returns:
        (image, best_activation): ``image`` is a numpy array (H, W, C) with
        values in [0, 1] -- the transformed view that scored highest, not the
        raw optimized tensor; ``best_activation`` is that score as a float.

    Raises:
        ValueError: if ``num_steps`` is smaller than 1.
        RuntimeError: if the SAE hook captured no activations during a forward
            pass, or if no step produced a comparable (non-NaN) activation.
    """
    # Fail fast: with zero steps there is no image to return.
    if num_steps < 1:
        raise ValueError(f"num_steps must be at least 1, got {num_steps}")

    # Set up color matrices if using decorrelation
    if use_decorrelation and color_matrices_path and os.path.exists(color_matrices_path):
        print(f"Loading color matrices from {color_matrices_path}")
        # SECURITY: torch.load unpickles the file; only load trusted files.
        data = torch.load(color_matrices_path, map_location=device)
        whitening_matrix = data['whitening_matrix'].to(device)
        unwhitening_matrix = data['unwhitening_matrix'].to(device)
        mean_color = data['mean_color'].to(device)
    else:
        # Identity transform keeps the names defined for the code below
        use_decorrelation = False
        whitening_matrix = torch.eye(3, device=device)
        unwhitening_matrix = torch.eye(3, device=device)
        mean_color = torch.zeros(3, device=device)
        print("Not using color decorrelation")

    sae_module = sae[0]

    # Filled by the forward hook on every model call.
    activations = {}

    def sae_activation_hook(module, input, output):
        # For ConvSAE, the 3rd return value contains the activations
        if isinstance(output, tuple) and len(output) >= 3:
            activations['sae_features'] = output[2]
        return output

    best_activation = float('-inf')
    best_img = None

    # Attach SAE to the model; `layer_number` is a module-level global.
    sae_hook = replace_layer_with_sae(model, sae_module, layer_number)
    capture_hook = None
    try:
        capture_hook = sae_module.register_forward_hook(sae_activation_hook)

        # Initialize image
        padded_size = 64 + 8  # 4 pixels on each side
        input_img = torch.randint(0, 256, (1, 3, padded_size, padded_size),
                                  device=device,
                                  dtype=torch.float32).requires_grad_(True)

        # Set up optimizer
        optimizer = torch.optim.Adam([input_img], lr=0.08)

        # Parameters for transformations
        scales = [1.0, 0.975, 1.025, 0.95, 1.05]
        angles = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]

        # Optimization loop
        pbar = tqdm(range(num_steps))
        for _ in pbar:
            optimizer.zero_grad()

            # Work on a copy so the transforms never touch the leaf tensor
            processed_img = input_img.clone()

            # Apply transformations
            ox, oy = np.random.randint(-8, 9, 2)
            processed_img = jitter(processed_img, ox, oy)
            processed_img = random_scale(processed_img, scales)
            processed_img = random_rotate(processed_img, angles)
            ox, oy = np.random.randint(-4, 5, 2)
            processed_img = jitter(processed_img, ox, oy)

            # Crop padding
            processed_img = processed_img[:, :, 4:-4, 4:-4]

            # Clamp values to [0, 255]; done on .data so autograd does not
            # record the clamp (kept as in the original).
            processed_img.data.clamp_(0, 255)

            # Normalize to [0,1] for model input
            normalized_img = processed_img / 255.0

            # Note: We optimize in the original space but do the forward pass
            # in whitened space
            if use_decorrelation:
                whitened_img = apply_whitening(normalized_img, whitening_matrix, mean_color)

                # Rescale to [0,1] for the model. The range is floored so a
                # constant whitened image cannot trigger a division by zero.
                whitened_min = whitened_img.min()
                value_range = torch.clamp(whitened_img.max() - whitened_min, min=1e-8)
                model_input = (whitened_img - whitened_min) / value_range
            else:
                # Use the normalized image directly
                model_input = normalized_img

            # Drop the previous capture so a hook that fails to fire is detected.
            activations.clear()
            _ = model(model_input)

            if 'sae_features' not in activations:
                # Without the capture the loss would contain only the
                # regularizers and the result would be meaningless.
                raise RuntimeError(
                    "SAE forward hook captured no activations; expected the SAE "
                    "to return a tuple with at least 3 elements"
                )

            feature_activation = activations['sae_features'][0, feature_idx]
            activation_loss = -feature_activation.mean()  # negative because we want to maximize

            # Regularizers are computed on the raw (untransformed) image
            tv_loss = 1e-3 * total_variation(input_img / 255.0)
            l2_loss = 1e-3 * torch.norm(input_img / 255.0)

            # Total loss
            loss = activation_loss + tv_loss + l2_loss

            loss.backward()
            optimizer.step()

            # Post-processing
            with torch.no_grad():
                input_img.data.clamp_(0, 255)

                # Track best activation and the view that produced it
                current_activation = -activation_loss.item()
                if current_activation > best_activation:
                    best_activation = current_activation
                    best_img = processed_img.clone()

            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'act': current_activation})
    finally:
        # Always detach both hooks so a failed run does not leave the model
        # patched for later calls.
        if capture_hook is not None:
            capture_hook.remove()
        sae_hook.remove()

    if best_img is None:
        # Only reachable when every activation compared false against -inf (e.g. NaN).
        raise RuntimeError("No optimization step produced a valid activation")

    # Process the final image
    result_img = best_img / 255.0  # Normalize to [0,1]

    # If we used decorrelation, apply it to the final image for visualization
    if use_decorrelation:
        # NOTE(review): whitening followed by unwhitening is a round trip; if
        # the two matrices are inverses of each other (not verifiable from
        # this file) this only adds float error before the clamp. Confirm
        # whether the intent was to display something else.
        whitened_img = apply_whitening(result_img, whitening_matrix, mean_color)
        unwhitened_img = apply_unwhitening(whitened_img, unwhitening_matrix, mean_color)

        # Ensure values are in [0,1]
        result_img = torch.clamp(unwhitened_img, 0, 1)

    # Convert to numpy for display: (1, C, H, W) -> (H, W, C)
    result = result_img.detach().cpu().squeeze(0).permute(1, 2, 0).numpy()

    return result, best_activation

# Lay the features out on a near-square grid so the list above can be edited
# freely; four features still give the original 2x2 layout at 12x12 inches.
n_features = len(feature_indices)
n_cols = max(1, int(np.ceil(np.sqrt(n_features))))
n_rows = max(1, -(-n_features // n_cols))  # ceiling division
grid_figsize = (6 * n_cols, 6 * n_rows)

# One figure for the standard visualizations, one for the decorrelated ones.
# squeeze=False keeps `axes` two-dimensional even for a 1x1 grid.
fig1, axes1 = plt.subplots(n_rows, n_cols, figsize=grid_figsize, squeeze=False)
axes1 = axes1.flatten()

fig2, axes2 = plt.subplots(n_rows, n_cols, figsize=grid_figsize, squeeze=False)
axes2 = axes2.flatten()

# Hide all axes up front so unused grid cells stay blank
for ax in (*axes1, *axes2):
    ax.axis('off')

# Create visualizations directory if it doesn't exist
os.makedirs('visualizations', exist_ok=True)

# Visualize each feature with both standard and decorrelated approaches
for i, feature_idx in enumerate(feature_indices):
    print(f"\nVisualizing feature {feature_idx} ({i+1}/{len(feature_indices)})")

    # Standard visualization (no decorrelation)
    print("Standard visualization...")
    vis_standard, act_standard = visualize_feature(
        model=model,
        # The checkpoint loaded in this cell is bound to `sae_result`; the
        # name `sae` only exists if an earlier cell happened to define it.
        sae=sae_result,
        layer_name=layer_name,
        feature_idx=feature_idx,
        num_steps=2560,
        use_decorrelation=False
    )

    # Decorrelated visualization
    print("Decorrelated visualization...")
    vis_decorrelated, act_decorrelated = visualize_feature(
        model=model,
        sae=sae_result,
        layer_name=layer_name,
        feature_idx=feature_idx,
        num_steps=2560,
        use_decorrelation=True,
        color_matrices_path="color_matrices.pt"
    )

    # Display and save standard visualization
    axes1[i].imshow(vis_standard)
    axes1[i].set_title(f'Standard - Feature {feature_idx} (Act: {act_standard:.2f})')
    plt.imsave(f'visualizations/standard_feature_{feature_idx}.png', vis_standard)

    # Display and save decorrelated visualization
    axes2[i].imshow(vis_decorrelated)
    axes2[i].set_title(f'Decorrelated - Feature {feature_idx} (Act: {act_decorrelated:.2f})')
    plt.imsave(f'visualizations/decorrelated_feature_{feature_idx}.png', vis_decorrelated)

# Add overall titles
fig1.suptitle('Standard Visualization', fontsize=16)
fig2.suptitle('Decorrelated Visualization', fontsize=16)

# Save the combined plots through the figure handles rather than pyplot's
# implicit "current figure" state
fig1.tight_layout()
fig1.savefig('visualizations/standard_features_combined.png')

fig2.tight_layout()
fig2.savefig('visualizations/decorrelated_features_combined.png')

plt.show()