In [None]:
# Import necessary libraries
import torch
import matplotlib.pyplot as plt
import numpy as np
from utils.helpers import load_interpretable_model
import os
from tqdm import tqdm

# Import visualization utilities
from feature_vis_impala import (
    total_variation, 
    jitter, 
    random_scale, 
    random_rotate
)

# Import SAE-related modules
from sae_cnn import ConvSAE
from extract_sae_features import replace_layer_with_sae
from feature_vis_sae import load_sae_from_checkpoint

# Prefer the GPU when one is visible, otherwise run on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Base model: moved to the chosen device and switched to eval mode
model = load_interpretable_model()
model.to(device)
model.eval()

# Target layer and feature; read as globals by visualize_feature below
layer_name, layer_number = "conv4a", 8  # Update with your target layer
feature_idx = 1  # Change to the feature you want to visualize

# SAE checkpoint; the loader's result is indexed with [0] inside visualize_feature
sae_checkpoint_path = "../checkpoints/sae_checkpoint_step_4500000.pt"  # Update with your checkpoint path
sae = load_sae_from_checkpoint(sae_checkpoint_path, device)

# Function to visualize a feature with optional color decorrelation
def visualize_feature(model, sae, layer_name, feature_idx, num_steps=1000, 
                      use_decorrelation=False, color_matrices_path=None):
    """
    Visualize what maximally activates a specific SAE feature.

    A random image is optimized with Adam so that the mean activation of one
    SAE feature is maximized, under random jitter/scale/rotation and small
    total-variation and L2 penalties.

    Args:
        model: The base model; it is called on the cropped image scaled to [0, 1]
        sae: Indexable result of load_sae_from_checkpoint; element 0 is used
            as the SAE module
        layer_name: Name of the layer containing the feature (not read here;
            the module-level ``layer_number`` selects the layer)
        feature_idx: Index of the feature to visualize
        num_steps: Number of optimization steps (must be >= 1)
        use_decorrelation: Whether to use color decorrelation (the matrices
            are loaded but not applied yet in this version)
        color_matrices_path: Path to color matrices file

    Returns:
        (visualization_image, best_activation): the image as a numpy array
        (H, W, C) scaled by 1/255, and the highest activation value reached

    Raises:
        ValueError: If no optimization step ran, so there is no image to return
    """
    # Set up color matrices if using decorrelation
    if use_decorrelation and color_matrices_path and os.path.exists(color_matrices_path):
        print(f"Loading color matrices from {color_matrices_path}")
        data = torch.load(color_matrices_path, map_location=device)
        # Loaded for validation only; decorrelation is not applied below.
        whitening_matrix = data['whitening_matrix'].to(device)
        unwhitening_matrix = data['unwhitening_matrix'].to(device)
        mean_color = data['mean_color'].to(device)
    else:
        use_decorrelation = False
        print("Not using color decorrelation")
    
    sae = sae[0]
    # Attach SAE to the model
    sae_hook = replace_layer_with_sae(model, sae, layer_number)
    
    # Set up hooks to capture activations
    activations = {}
    
    def sae_activation_hook(module, input, output):
        # For ConvSAE, the 3rd return value contains the activations
        if isinstance(output, tuple) and len(output) >= 3:
            activations['sae_features'] = output[2]
        return output
    
    best_activation = float('-inf')
    best_img = None
    hook = None
    
    # Everything after the SAE is attached runs under try/finally so the
    # model is always restored, even if the optimization raises.
    try:
        # Register hook on the SAE
        hook = sae.register_forward_hook(sae_activation_hook)
        
        # Initialize image
        padded_size = 64 + 8  # 4 pixels on each side
        input_img = torch.randint(0, 256, (1, 3, padded_size, padded_size), 
                                device=device, 
                                dtype=torch.float32).requires_grad_(True)
        
        # Set up optimizer
        optimizer = torch.optim.Adam([input_img], lr=0.08)
        
        # Parameters for transformations
        scales = [1.0, 0.975, 1.025, 0.95, 1.05]
        angles = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
        
        # Optimization loop
        pbar = tqdm(range(num_steps))
        for step in pbar:
            optimizer.zero_grad()
            
            # Create a copy for transformations
            processed_img = input_img.clone()
            
            # Apply transformations
            ox, oy = np.random.randint(-8, 9, 2)
            processed_img = jitter(processed_img, ox, oy)
            processed_img = random_scale(processed_img, scales)
            processed_img = random_rotate(processed_img, angles)
            ox, oy = np.random.randint(-4, 5, 2)
            processed_img = jitter(processed_img, ox, oy)
            
            # Crop padding
            processed_img = processed_img[:, :, 4:-4, 4:-4]
            
            # Ensure values are in [0, 255]
            processed_img.data.clamp_(0, 255)
            
            # Normalize to [0,1] for model input
            normalized_img = processed_img / 255.0
            
            # Drop the previous step's capture so a stale activation is never
            # optimized if the hook does not fire on this forward pass.
            activations.clear()
            
            # Forward pass
            _ = model(normalized_img)
            
            # Get the activation for the target feature
            if 'sae_features' in activations:
                feature_activation = activations['sae_features'][0, feature_idx]
                activation_loss = -feature_activation.mean()  # negative because we want to maximize
            else:
                activation_loss = torch.tensor(0.0, device=device)
            
            # Calculate regularization losses
            tv_loss = 1e-3 * total_variation(input_img / 255.0)
            l2_loss = 1e-3 * torch.norm(input_img / 255.0)
            
            # Total loss
            loss = activation_loss + tv_loss + l2_loss
            
            loss.backward()
            optimizer.step()
            
            # Post-processing
            with torch.no_grad():
                input_img.data.clamp_(0, 255)
                
                # Track best activation
                current_activation = -activation_loss.item()
                if current_activation > best_activation:
                    best_activation = current_activation
                    best_img = processed_img.clone()
            
            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'act': current_activation})
    finally:
        # Clean up hooks
        if hook is not None:
            hook.remove()
        sae_hook.remove()
    
    if best_img is None:
        raise ValueError("num_steps must be at least 1 to produce a visualization")
    
    # If we used decorrelation, apply it to the final image
    if use_decorrelation:
        # This would be where we apply decorrelation, but we'll skip it for now
        # to avoid the range issues
        pass
    
    # Convert to numpy for display
    result = (best_img.detach().cpu().squeeze().permute(1, 2, 0) / 255.0).numpy()
    
    return result, best_activation

# Optimize an image for the selected feature
vis, activation = visualize_feature(
    model,
    sae,
    layer_name,
    feature_idx,
    num_steps=5000,  # Reduce for faster results
    use_decorrelation=False,  # Start with this off to ensure it works
    color_matrices_path="color_matrices.pt",
)

# Show the optimized image with its best activation in the title
fig, ax = plt.subplots(figsize=(8, 8))
ax.imshow(vis)
ax.set_title(f'SAE Feature {feature_idx} (Act: {activation:.2f})')
ax.axis('off')
plt.show()

In [None]:
# Import necessary libraries
import torch
import matplotlib.pyplot as plt
import numpy as np
from utils.helpers import load_interpretable_model
import os
from tqdm import tqdm

# Import visualization utilities
from feature_vis_impala import (
    total_variation, 
    jitter, 
    random_scale, 
    random_rotate
)

# Import SAE-related modules
from sae_cnn import ConvSAE
from extract_sae_features import replace_layer_with_sae
from feature_vis_sae import load_sae_from_checkpoint

# Prefer the GPU when one is visible, otherwise run on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Base model: moved to the chosen device and switched to eval mode
model = load_interpretable_model()
model.to(device)
model.eval()

# Target layer; layer_number is read as a global by visualize_feature
layer_name, layer_number = "conv4a", 8  # Update with your target layer

# Loader result is kept whole here; visualize_feature indexes it with [0]
sae_checkpoint_path = "../checkpoints/sae_checkpoint_step_4500000.pt"  # Adjusted with ..
sae_result = load_sae_from_checkpoint(sae_checkpoint_path, device)

# Feature indices to visualize: 0 through 127
feature_indices = [*range(128)]

# Function to apply whitening to an image tensor
def apply_whitening(x, whitening_matrix, mean_color=None):
    """Multiply every pixel of a 4-D, 3-channel image batch by ``whitening_matrix``.

    Args:
        x: Tensor with channels on dim 1 (three of them)
        whitening_matrix: 2-D matrix applied as ``pixel @ whitening_matrix.T``
        mean_color: Optional per-channel offset subtracted before the product

    Returns:
        Tensor with the same shape as ``x``
    """
    # Channels-last, so each row of the flattened view is one pixel
    channels_last = x.permute(0, 2, 3, 1)
    batch, height, width = channels_last.shape[:3]
    pixels = channels_last.reshape(-1, 3)

    if mean_color is not None:
        pixels = pixels - mean_color

    projected = torch.mm(pixels, whitening_matrix.t())

    # Restore the original channels-first layout
    return projected.reshape(batch, height, width, 3).permute(0, 3, 1, 2)

# Function to apply unwhitening to an image tensor
def apply_unwhitening(x, unwhitening_matrix, mean_color=None):
    """Multiply every pixel of a 4-D, 3-channel image batch by ``unwhitening_matrix``.

    Args:
        x: Tensor with channels on dim 1 (three of them)
        unwhitening_matrix: 2-D matrix applied as ``pixel @ unwhitening_matrix.T``
        mean_color: Optional per-channel offset added after the product

    Returns:
        Tensor with the same shape as ``x``
    """
    # Channels-last, so each row of the flattened view is one pixel
    channels_last = x.permute(0, 2, 3, 1)
    batch, height, width = channels_last.shape[:3]
    pixels = channels_last.reshape(-1, 3)

    restored = torch.mm(pixels, unwhitening_matrix.t())

    if mean_color is not None:
        restored = restored + mean_color

    # Restore the original channels-first layout
    return restored.reshape(batch, height, width, 3).permute(0, 3, 1, 2)

# Function to visualize a feature with optional color decorrelation
def visualize_feature(model, sae, layer_name, feature_idx, num_steps=1000, 
                      use_decorrelation=False, color_matrices_path=None):
    """
    Visualize what maximally activates a specific SAE feature.

    A random image is optimized with Adam so that the mean activation of one
    SAE feature is maximized, under random jitter/scale/rotation and small
    total-variation and L2 penalties. With decorrelation enabled, the model
    sees a whitened, min-max rescaled copy of the image.

    Args:
        model: The base model; it is called on the cropped image scaled to [0, 1]
        sae: Indexable result of load_sae_from_checkpoint; element 0 is used
            as the SAE module
        layer_name: Name of the layer containing the feature (not read here;
            the module-level ``layer_number`` selects the layer)
        feature_idx: Index of the feature to visualize
        num_steps: Number of optimization steps (must be >= 1)
        use_decorrelation: Whether to use color decorrelation; silently
            disabled when the matrices file is missing
        color_matrices_path: Path to color matrices file

    Returns:
        (visualization_image, best_activation): the image as a numpy array
        (H, W, C) in [0, 1], and the highest activation value reached

    Raises:
        ValueError: If no optimization step ran, so there is no image to return
    """
    # Set up color matrices if using decorrelation
    if use_decorrelation and color_matrices_path and os.path.exists(color_matrices_path):
        print(f"Loading color matrices from {color_matrices_path}")
        data = torch.load(color_matrices_path, map_location=device)
        whitening_matrix = data['whitening_matrix'].to(device)
        unwhitening_matrix = data['unwhitening_matrix'].to(device)
        mean_color = data['mean_color'].to(device)
    else:
        use_decorrelation = False
        whitening_matrix = torch.eye(3, device=device)
        unwhitening_matrix = torch.eye(3, device=device)
        mean_color = torch.zeros(3, device=device)
        print("Not using color decorrelation")
    
    # Attach SAE to the model
    sae = sae[0]
    sae_hook = replace_layer_with_sae(model, sae, layer_number)
    
    # Set up hooks to capture activations
    activations = {}
    
    def sae_activation_hook(module, input, output):
        # For ConvSAE, the 3rd return value contains the activations
        if isinstance(output, tuple) and len(output) >= 3:
            activations['sae_features'] = output[2]
        return output
    
    best_activation = float('-inf')
    best_img = None
    hook = None
    
    # Everything after the SAE is attached runs under try/finally so the
    # model is always restored, even if the optimization raises.
    try:
        # Register hook on the SAE
        hook = sae.register_forward_hook(sae_activation_hook)
        
        # Initialize image
        padded_size = 64 + 8  # 4 pixels on each side
        input_img = torch.randint(0, 256, (1, 3, padded_size, padded_size), 
                                device=device, 
                                dtype=torch.float32).requires_grad_(True)
        
        # Set up optimizer
        optimizer = torch.optim.Adam([input_img], lr=0.08)
        
        # Parameters for transformations
        scales = [1.0, 0.975, 1.025, 0.95, 1.05]
        angles = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
        
        # Optimization loop
        pbar = tqdm(range(num_steps))
        for step in pbar:
            optimizer.zero_grad()
            
            # Create a copy for transformations
            processed_img = input_img.clone()
            
            # Apply transformations
            ox, oy = np.random.randint(-8, 9, 2)
            processed_img = jitter(processed_img, ox, oy)
            processed_img = random_scale(processed_img, scales)
            processed_img = random_rotate(processed_img, angles)
            ox, oy = np.random.randint(-4, 5, 2)
            processed_img = jitter(processed_img, ox, oy)
            
            # Crop padding
            processed_img = processed_img[:, :, 4:-4, 4:-4]
            
            # Ensure values are in [0, 255]
            processed_img.data.clamp_(0, 255)
            
            # Normalize to [0,1] for model input
            normalized_img = processed_img / 255.0
            
            # Apply whitening if using decorrelation
            # Note: We optimize in the original space but do the forward pass in whitened space
            if use_decorrelation:
                whitened_img = apply_whitening(normalized_img, whitening_matrix, mean_color)
                
                # Rescale to [0,1] range for the model.
                # The range is floored so a constant image yields zeros
                # instead of a 0/0 NaN input.
                whitened_min = whitened_img.min()
                whitened_range = torch.clamp(whitened_img.max() - whitened_min, min=1e-8)
                model_input = (whitened_img - whitened_min) / whitened_range
            else:
                # Use the normalized image directly
                model_input = normalized_img
            
            # Drop the previous step's capture so a stale activation is never
            # optimized if the hook does not fire on this forward pass.
            activations.clear()
            
            # Forward pass
            _ = model(model_input)
            
            # Get the activation for the target feature
            if 'sae_features' in activations:
                feature_activation = activations['sae_features'][0, feature_idx]
                activation_loss = -feature_activation.mean()  # negative because we want to maximize
            else:
                activation_loss = torch.tensor(0.0, device=device)
            
            # Calculate regularization losses
            tv_loss = 1e-3 * total_variation(input_img / 255.0)
            l2_loss = 1e-3 * torch.norm(input_img / 255.0)
            
            # Total loss
            loss = activation_loss + tv_loss + l2_loss
            
            loss.backward()
            optimizer.step()
            
            # Post-processing
            with torch.no_grad():
                input_img.data.clamp_(0, 255)
                
                # Track best activation
                current_activation = -activation_loss.item()
                if current_activation > best_activation:
                    best_activation = current_activation
                    best_img = processed_img.clone()
            
            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'act': current_activation})
    finally:
        # Clean up hooks
        if hook is not None:
            hook.remove()
        sae_hook.remove()
    
    if best_img is None:
        raise ValueError("num_steps must be at least 1 to produce a visualization")
    
    # Process the final image
    result_img = best_img / 255.0  # Normalize to [0,1]
    
    # If we used decorrelation, apply it to the final image for visualization
    if use_decorrelation:
        # Whitening followed by unwhitening. NOTE(review): if the two matrices
        # are exact inverses this round trip only adds the clamp below —
        # confirm whether that is intended.
        whitened_img = apply_whitening(result_img, whitening_matrix, mean_color)
        unwhitened_img = apply_unwhitening(whitened_img, unwhitening_matrix, mean_color)
        
        # Ensure values are in [0,1]
        result_img = torch.clamp(unwhitened_img, 0, 1)
    
    # Convert to numpy for display
    result = result_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()
    
    return result, best_activation

# Output directory for every image written below
os.makedirs('visualizations', exist_ok=True)

# Features are handled eight at a time; each batch fills one 2x4 grid per mode
batch_size = 8
num_batches = (len(feature_indices) + batch_size - 1) // batch_size  # ceiling division

for batch_idx in range(num_batches):
    start_idx = batch_idx * batch_size
    batch_features = feature_indices[start_idx:start_idx + batch_size]
    end_idx = start_idx + len(batch_features)
    
    print(f"\nProcessing batch {batch_idx+1}/{num_batches} (features {start_idx}-{end_idx-1})")
    
    # One figure for the standard runs, one for the decorrelated runs
    fig1, grid1 = plt.subplots(2, 4, figsize=(16, 8))
    fig2, grid2 = plt.subplots(2, 4, figsize=(16, 8))
    axes1 = grid1.flatten()
    axes2 = grid2.flatten()
    
    for i, feature_idx in enumerate(batch_features):
        print(f"\nVisualizing feature {feature_idx} ({i+1}/{len(batch_features)})")
        
        # Arguments common to both runs of this feature
        shared_args = dict(
            model=model,
            sae=sae_result,
            layer_name=layer_name,
            feature_idx=feature_idx,
            num_steps=2560,
        )
        
        # Plain optimization, no color decorrelation
        print("Standard visualization...")
        vis_standard, act_standard = visualize_feature(
            use_decorrelation=False, **shared_args
        )
        
        # Same optimization with color decorrelation switched on
        print("Decorrelated visualization...")
        vis_decorrelated, act_decorrelated = visualize_feature(
            use_decorrelation=True,
            color_matrices_path="color_matrices.pt",
            **shared_args,
        )
        
        # Standard result: grid panel plus its own PNG
        ax_std = axes1[i]
        ax_std.imshow(vis_standard)
        ax_std.set_title(f'Standard - Feature {feature_idx} (Act: {act_standard:.2f})')
        ax_std.axis('off')
        plt.imsave(f'visualizations/standard_feature_{feature_idx}.png', vis_standard)
        
        # Decorrelated result: grid panel plus its own PNG
        ax_dec = axes2[i]
        ax_dec.imshow(vis_decorrelated)
        ax_dec.set_title(f'Decorrelated - Feature {feature_idx} (Act: {act_decorrelated:.2f})')
        ax_dec.axis('off')
        plt.imsave(f'visualizations/decorrelated_feature_{feature_idx}.png', vis_decorrelated)
    
    # Title, lay out and save each batch figure, then close it to free memory
    fig1.suptitle(f'Standard Visualization (Batch {batch_idx+1}/{num_batches})', fontsize=16)
    fig2.suptitle(f'Decorrelated Visualization (Batch {batch_idx+1}/{num_batches})', fontsize=16)
    
    fig1.tight_layout()
    fig1.savefig(f'visualizations/standard_batch_{batch_idx+1}.png')
    
    fig2.tight_layout()
    fig2.savefig(f'visualizations/decorrelated_batch_{batch_idx+1}.png')
    
    plt.close(fig1)
    plt.close(fig2)

print("All visualizations completed!")

In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import gym
from PIL import Image
import os
from tqdm import tqdm

# Import visualization utilities
from feature_vis_impala import (
    total_variation, 
    jitter, 
    random_scale, 
    random_rotate
)

# Import SAE-related modules
from sae_cnn import ConvSAE, generate_batch_activations_parallel
from extract_sae_features import replace_layer_with_sae
from feature_vis_sae import load_sae_from_checkpoint
from utils.helpers import load_interpretable_model, ModelActivations

# Import color decorrelation functions
from color_decorrelation import (
    apply_whitening,
    apply_unwhitening,
    load_color_matrices
)

# Prefer the GPU when one is visible, otherwise run on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
print(f"Using device: {device}")

# Base model: moved to the chosen device and switched to eval mode
model = load_interpretable_model()
model.to(device)
model.eval()

# Target layer; layer_number is read as a global by the functions below
layer_name, layer_number = "conv4a", 8  # Update with your target layer

# SAE checkpoint: element 0 of the loader's result is the SAE itself
sae_checkpoint_path = "../checkpoints/sae_checkpoint_step_4500000.pt"
sae_result = load_sae_from_checkpoint(sae_checkpoint_path, device)
sae = sae_result[0]

# Color matrices: fall back to an identity transform when the file is missing
color_matrices_path = "color_matrices.pt"
use_decorrelation = os.path.exists(color_matrices_path)
if use_decorrelation:
    cov_matrix, whitening_matrix, unwhitening_matrix, mean_color = load_color_matrices(color_matrices_path, device)
else:
    whitening_matrix = torch.eye(3, device=device)
    unwhitening_matrix = torch.eye(3, device=device)
    mean_color = torch.zeros(3, device=device)

# Function to extract activation patches
def extract_activation_patches(activation_map, observation, patch_size=16, stride=8):
    """
    Extract image patches centered on the highest-activation map locations.

    The 10 largest entries of ``activation_map`` are mapped to image
    coordinates, a ``patch_size`` square is cut around each one, and each
    patch is scored by the mean activation in the matching map window.

    Args:
        activation_map: 2-D numpy array of activations for one feature
        observation: Image tensor, either (H, W, 3) with H == 64 or (3, H, W).
            Patches that need resizing are multiplied by 255 and cast to
            uint8, so values are expected in [0, 1].
        patch_size: Side length of the extracted patches in image pixels
        stride: Accepted for compatibility; not used by the extraction

    Returns:
        List of (region_score, patch_tensor, (img_i, img_j)) tuples sorted by
        descending score; ``patch_tensor`` is in CHW layout. Patches that
        fail to resize are skipped.
    """
    patches = []
    act_h, act_w = activation_map.shape
    
    # Ensure observation is in (64,64,3) format
    obs_np = observation.detach().cpu().numpy()
    if observation.shape[0] == 64 and observation.shape[2] == 3:
        # Already in (64,64,3) format
        pass
    elif observation.shape[0] == 3:
        # In (3,64,64) format, transpose to (64,64,3)
        obs_np = obs_np.transpose(1, 2, 0)
    elif obs_np.shape[0] == 64:
        # Try to reshape to (64,64,3)
        obs_np = obs_np.transpose(0, 2, 1)  # (64,3,64) -> (64,64,3)
    
    obs_h, obs_w = obs_np.shape[:2]
    
    # Scale factors between activation map and original image
    scale_h = obs_h / act_h
    scale_w = obs_w / act_w
    
    # Ensure patch size isn't larger than the image
    patch_size = min(patch_size, obs_h, obs_w)
    
    # Patch size expressed in activation-map cells (at least one cell)
    act_patch_size = max(1, int(patch_size / scale_h))
    act_half = act_patch_size // 2
    half_size = patch_size // 2
    
    # Find regions of high activation
    flat_activations = activation_map.flatten()
    top_k = 10  # Number of high activation points to consider
    top_indices = np.argsort(flat_activations)[-top_k:]
    
    for idx in top_indices:
        # Convert flat index back to 2D coordinates
        i, j = idx // act_w, idx % act_w
        
        # Convert to image space coordinates
        img_i = int(i * scale_h)
        img_j = int(j * scale_w)
        
        # Calculate patch boundaries
        start_i = max(0, img_i - half_size)
        start_j = max(0, img_j - half_size)
        end_i = min(obs_h, img_i + half_size)
        end_j = min(obs_w, img_j + half_size)
        
        # Extract patch from numpy array
        patch = obs_np[start_i:end_i, start_j:end_j, :]
        
        # Get activation score for this region. The window always spans at
        # least one cell: with act_half == 0 the old bounds gave an empty
        # slice whose mean is NaN, which broke the sort and later comparisons.
        act_start_i = max(0, i - act_half)
        act_start_j = max(0, j - act_half)
        act_end_i = min(act_h, max(i + act_half, act_start_i + 1))
        act_end_j = min(act_w, max(j + act_half, act_start_j + 1))
        region_score = activation_map[act_start_i:act_end_i, act_start_j:act_end_j].mean().item()
        
        # Resize patch to standard size if needed
        if patch.shape[:2] != (patch_size, patch_size):
            try:
                patch_pil = Image.fromarray((patch * 255).astype(np.uint8))
                patch_pil = patch_pil.resize((patch_size, patch_size), Image.BILINEAR)
                patch = np.array(patch_pil).astype(np.float32) / 255.0
            except Exception as e:
                # Best effort: report and skip a patch that cannot be resized
                print(f"Failed to resize patch: {e}, patch shape: {patch.shape}")
                continue
        
        # Convert back to tensor in the original format
        patch_tensor = torch.from_numpy(patch).permute(2, 0, 1)  # HWC -> CHW
        patches.append((region_score, patch_tensor, (img_i, img_j)))
    
    # Sort by activation score
    patches.sort(key=lambda x: x[0], reverse=True)
    
    return patches

# Function to gather max activating samples
def gather_max_activating_samples(model, sae, layer_number, feature_indices, 
                                 iterations=10, batch_size=32, num_envs=8, 
                                 episode_length=150, top_k=4, diversity_weight=2.0, 
                                 patch_size=16, use_decorrelation=False):
    """
    Gather patches from the environment that maximally activate each SAE feature.

    For each iteration, ``batch_size`` freshly reset observations are run
    through the model with the SAE attached. For every requested feature the
    highest-activation patches are extracted, and up to ``top_k`` are kept
    per feature, ranked by activation plus a bonus for being dissimilar
    (cosine) to the patches already kept.

    Args:
        model: The base model, called on NCHW observations scaled to [0, 1]
        sae: SAE module to attach to the model and hook for activations
        layer_number: Layer index passed to ``replace_layer_with_sae``
        feature_indices: Iterable of SAE feature indices to collect for
        iterations: Number of observation batches to process
        batch_size: Observations collected per iteration
        num_envs: Accepted for compatibility; not used here
        episode_length: Accepted for compatibility; not used here
        top_k: Number of patches kept per feature
        diversity_weight: Weight of the dissimilarity bonus
        patch_size: Side length of extracted patches in pixels
        use_decorrelation: Whiten and min-max rescale the model input using
            the module-level ``whitening_matrix`` and ``mean_color``

    Returns:
        Dict mapping feature index to a list of
        (effective_score, raw_score, patch, center) tuples sorted by
        descending effective score. Patches are cut from the observation
        scaled to [0, 1].
    """
    # Set up hooks to capture SAE activations
    sae_activations = {}
    
    def sae_activation_hook(module, input, output):
        if isinstance(output, tuple) and len(output) >= 3:
            sae_activations['features'] = output[2]
        return output
    
    def cosine_similarity(patch1, patch2):
        """Compute cosine similarity between two patches"""
        vec1 = patch1.flatten()
        vec2 = patch2.flatten()
        norm1 = np.linalg.norm(vec1) + 1e-8
        norm2 = np.linalg.norm(vec2) + 1e-8
        return np.dot(vec1, vec2) / (norm1 * norm2)
    
    # Initialize best_samples dictionary
    best_samples = {c: [] for c in feature_indices}
    total_samples = 0
    
    # Set up model activations
    model_activations = ModelActivations(model)
    
    # Every resource is acquired inside the try so that a failure part-way
    # through setup (e.g. creating the environment) still releases the rest.
    sae_hook = None
    hook = None
    venv = None
    try:
        # Attach SAE to the model and register the capture hook on it
        sae_hook = replace_layer_with_sae(model, sae, layer_number)
        hook = sae.register_forward_hook(sae_activation_hook)
        
        # Create environment
        venv = gym.make('procgen:procgen-heist-v0')
        
        for it in range(iterations):
            print(f"\nIteration {it+1}/{iterations}")
            
            # Collect batch of observations
            observations = [venv.reset() for _ in range(batch_size)]
            
            # Convert to tensor
            batch_obs = torch.tensor(np.array(observations), dtype=torch.float32).to(device)
            batch_obs = batch_obs.permute(0, 3, 1, 2)  # NHWC -> NCHW
            
            # The model input and the patch extractor both work on [0, 1] images
            normalized_obs = batch_obs / 255.0
            
            # Apply whitening if using decorrelation
            if use_decorrelation:
                whitened_obs = apply_whitening(normalized_obs, whitening_matrix, mean_color)
                
                # Rescale to [0,1] range for the model; the range is floored
                # so a constant batch does not produce a 0/0 NaN input.
                whitened_min = whitened_obs.min()
                whitened_range = torch.clamp(whitened_obs.max() - whitened_min, min=1e-8)
                model_input = (whitened_obs - whitened_min) / whitened_range
            else:
                # Just normalize
                model_input = normalized_obs
            
            # Forget the previous batch so a missed hook is detected below
            # instead of silently reusing stale activations.
            sae_activations.pop('features', None)
            
            # Forward pass
            with torch.no_grad():
                _ = model(model_input)
            
            # Get SAE activations
            if 'features' not in sae_activations:
                print("No SAE activations captured!")
                continue
                
            batch_acts = sae_activations['features']
            
            print(f"Batch activations shape: {batch_acts.shape}")
            print(f"Batch observations shape: {batch_obs.shape}")
            
            # Process each sample in the batch
            for b in range(batch_acts.shape[0]):
                total_samples += 1
                # Patches come from the [0, 1] image: the extractor's resize
                # path multiplies by 255 before casting to uint8, and the
                # plotting code displays float patches as [0, 1] data.
                observation = normalized_obs[b]
                
                for c in feature_indices:
                    # Get activation map for this feature
                    activation_map = batch_acts[b, c].detach().cpu().numpy()
                    
                    # Extract patches around high activation regions
                    patches = extract_activation_patches(activation_map, observation, 
                                                      patch_size=patch_size)
                    
                    if not patches:
                        continue
                    
                    kept = best_samples[c]
                    
                    # Process top patches
                    for raw_score, patch, center in patches[:top_k*2]:  # Get more candidates for diversity
                        if not kept:
                            # For the first sample, just use raw score
                            effective_score = raw_score
                        else:
                            # Compute average similarity to existing samples
                            candidate_img = patch.detach().cpu().numpy()
                            similarities = [
                                cosine_similarity(candidate_img, stored_patch.detach().cpu().numpy())
                                for _, _, stored_patch, _ in kept
                            ]
                            
                            # Reward dissimilarity
                            diversity_bonus = diversity_weight * (1 - np.mean(similarities))
                            effective_score = raw_score + diversity_bonus
                        
                        # Update best samples
                        if len(kept) < top_k:
                            kept.append((effective_score, raw_score, patch.clone(), center))
                        else:
                            # Replace the weakest kept sample if this one beats it
                            worst_idx = min(range(len(kept)), key=lambda k: kept[k][0])
                            if effective_score > kept[worst_idx][0]:
                                kept[worst_idx] = (effective_score, raw_score, patch.clone(), center)

            print(f"Total samples processed: {total_samples}")
    
    finally:
        # Clean up
        if hook is not None:
            hook.remove()
        if sae_hook is not None:
            sae_hook.remove()
        model_activations.clear_hooks()
        if venv is not None:
            venv.close()

    # Sort by effective score and keep top_k
    for c in best_samples:
        best_samples[c] = sorted(best_samples[c], key=lambda x: x[0], reverse=True)[:top_k]
    
    return best_samples

# Function to visualize max activations
def visualize_max_activations(best_samples, target_layer_name, use_decorrelation=False):
    """Plot and save, per feature, the stored patches with their scores.

    Args:
        best_samples: Dict mapping feature index to a list of
            (effective_score, raw_score, patch, center) tuples; ``patch`` is
            a CHW tensor
        target_layer_name: Layer name shown in each figure title
        use_decorrelation: Appends " (Decorrelated)" to the title when true

    Writes one PNG per patch and one per feature into ``max_activations/``.
    """
    print(f"Visualizing samples for {len(best_samples)} features")
    
    # All images go into this directory
    os.makedirs('max_activations', exist_ok=True)
    
    suffix = " (Decorrelated)" if use_decorrelation else ""
    
    for c, samples in best_samples.items():
        num_samples = len(samples)
        if not num_samples:
            print(f"Skipping feature {c} - no samples found")
            continue
            
        print(f"Feature {c}: Found {num_samples} samples")
        # squeeze=False always yields a 2-D grid, so a single sample needs no special case
        fig, grid = plt.subplots(1, num_samples, figsize=(3 * num_samples, 3), squeeze=False)
        
        for i, (ax, sample) in enumerate(zip(grid[0], samples)):
            effective_score, raw_score, patch, center = sample
            
            # CHW tensor -> HWC array for plotting
            img = patch.detach().cpu().numpy().transpose(1, 2, 0)
            
            ax.imshow(img)
            ax.set_title(f"Raw: {raw_score:.2f}\nEff: {effective_score:.2f}\nPos: {center}")
            ax.axis('off')
            
            # Each patch is also written out on its own
            plt.imsave(f'max_activations/feature_{c}_patch_{i}.png', img)
            
        # Figure title carries the decorrelation flag
        fig.suptitle(f"Feature {c}: Max Activating Patches for {target_layer_name}" + suffix)
        
        plt.tight_layout()
        plt.savefig(f'max_activations/feature_{c}_patches.png')
        plt.close(fig)

# Function to visualize feature with decorrelation
def visualize_feature_with_decorrelation(model, sae, layer_name, feature_idx, 
                                        num_steps=1000, use_decorrelation=False):
    """Optimize an input image to maximally activate one SAE feature.

    A random 72x72 image (64x64 plus 4 px of padding per side) is optimized
    with Adam.  Each step applies jitter, random scaling and random rotation,
    crops the padding, runs the result through ``model`` with the SAE
    attached, and maximizes the mean activation of ``feature_idx`` subject to
    total-variation and L2 penalties.

    Args:
        model: The base model.
        sae: The SAE model.  Its forward output must be a tuple with at least
            three elements; the third one is read as the feature activations.
        layer_name: Accepted for interface compatibility but not used; the SAE
            is attached at the module-level ``layer_number``.
        feature_idx: Index of the SAE feature to visualize.
        num_steps: Number of optimization steps (must be at least 1).
        use_decorrelation: If True, whiten the image with the module-level
            ``whitening_matrix`` / ``mean_color`` and min-max rescale it
            before it is fed to the model.

    Returns:
        A tuple ``(result, best_activation)``: ``result`` is an HWC numpy
        array with values in [0, 1], ``best_activation`` is the highest mean
        feature activation observed during optimization.

    Raises:
        RuntimeError: If no optimization step produced a usable activation
            (for example ``num_steps < 1``).
    """
    # Attach SAE to the model
    sae_hook = replace_layer_with_sae(model, sae, layer_number)

    # Set up hooks to capture activations
    activations = {}

    def sae_activation_hook(module, input, output):
        if isinstance(output, tuple) and len(output) >= 3:
            activations['sae_features'] = output[2]
        return output

    # Register hook on the SAE
    hook = sae.register_forward_hook(sae_activation_hook)

    # Everything after hook registration runs inside the try block so the
    # hooks are removed even if tensor/optimizer setup fails.
    try:
        # Initialize image
        padded_size = 64 + 8  # 4 pixels on each side
        input_img = torch.randint(0, 256, (1, 3, padded_size, padded_size), 
                                device=device, 
                                dtype=torch.float32).requires_grad_(True)

        # Set up optimizer
        optimizer = torch.optim.Adam([input_img], lr=0.08)

        # Parameters for transformations
        scales = [1.0, 0.975, 1.025, 0.95, 1.05]
        angles = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]

        best_activation = float('-inf')
        best_img = None

        # Optimization loop
        pbar = tqdm(range(num_steps))
        for step in pbar:
            optimizer.zero_grad()

            # Create a copy for transformations
            processed_img = input_img.clone()

            # Apply transformations
            ox, oy = np.random.randint(-8, 9, 2)
            processed_img = jitter(processed_img, ox, oy)
            processed_img = random_scale(processed_img, scales)
            processed_img = random_rotate(processed_img, angles)
            ox, oy = np.random.randint(-4, 5, 2)
            processed_img = jitter(processed_img, ox, oy)

            # Crop padding
            processed_img = processed_img[:, :, 4:-4, 4:-4]

            # Ensure values are in [0, 255]
            processed_img.data.clamp_(0, 255)

            # Normalize to [0,1] for model input
            normalized_img = processed_img / 255.0

            # Apply whitening if using decorrelation
            if use_decorrelation:
                # Apply whitening
                whitened_img = apply_whitening(normalized_img, whitening_matrix, mean_color)

                # Rescale to [0,1] range for the model.  The range is floored
                # so a constant image cannot cause a division by zero, whose
                # NaNs would otherwise propagate into input_img.
                whitened_min = whitened_img.min()
                whitened_max = whitened_img.max()
                value_range = torch.clamp(whitened_max - whitened_min, min=1e-8)
                model_input = (whitened_img - whitened_min) / value_range
            else:
                # Use the normalized image directly
                model_input = normalized_img

            # Forward pass
            _ = model(model_input)

            # Get the activation for the target feature
            if 'sae_features' in activations:
                feature_activation = activations['sae_features'][0, feature_idx]
                activation_loss = -feature_activation.mean()  # negative because we want to maximize
            else:
                activation_loss = torch.tensor(0.0, device=device)

            # Calculate regularization losses
            tv_loss = 1e-3 * total_variation(input_img / 255.0)
            l2_loss = 1e-3 * torch.norm(input_img / 255.0)

            # Total loss
            loss = activation_loss + tv_loss + l2_loss

            loss.backward()
            optimizer.step()

            # Post-processing
            with torch.no_grad():
                input_img.data.clamp_(0, 255)

                # Track best activation; the stored image is the transformed
                # crop that was actually shown to the model this step.
                current_activation = -activation_loss.item()
                if current_activation > best_activation:
                    best_activation = current_activation
                    best_img = processed_img.clone()

            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'act': current_activation})

        if best_img is None:
            raise RuntimeError(
                "No image was recorded during optimization "
                "(num_steps must be >= 1 and activations must be finite)"
            )

        # Process the final image
        result_img = best_img / 255.0  # Normalize to [0,1]

        # If we used decorrelation, apply it to the final image for visualization
        if use_decorrelation:
            # Apply whitening
            whitened_img = apply_whitening(result_img, whitening_matrix, mean_color)

            # Apply unwhitening to get back to normal color space
            unwhitened_img = apply_unwhitening(whitened_img, unwhitening_matrix, mean_color)

            # Ensure values are in [0,1]
            result_img = torch.clamp(unwhitened_img, 0, 1)

        # Convert to numpy for display
        result = result_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()

        return result, best_activation

    finally:
        # Clean up hooks
        hook.remove()
        sae_hook.remove()

# Main execution
# For every selected feature: optimize one standard and one decorrelated
# synthetic image and save both, then collect the max-activating samples for
# both modes and plot them.

# Select features to visualize
feature_indices = list(range(8))  # Start with first 8 features

# Create directory for saving visualizations
os.makedirs('visualizations', exist_ok=True)

# First, generate synthetic visualizations for each feature
print("Generating synthetic visualizations...")
# NOTE: the loop variable rebinds the module-level `feature_idx`.
for feature_idx in feature_indices:
    print(f"\nVisualizing feature {feature_idx}")
    
    # Standard visualization
    print("Standard visualization...")
    vis_standard, act_standard = visualize_feature_with_decorrelation(
        model=model,
        sae=sae,
        layer_name=layer_name,
        feature_idx=feature_idx,
        num_steps=1000,
        use_decorrelation=False
    )
    
    # Decorrelated visualization
    print("Decorrelated visualization...")
    vis_decorrelated, act_decorrelated = visualize_feature_with_decorrelation(
        model=model,
        sae=sae,
        layer_name=layer_name,
        feature_idx=feature_idx,
        num_steps=1000,
        use_decorrelation=True
    )
    
    # Save individual visualizations (the returned activations are not saved)
    plt.imsave(f'visualizations/standard_feature_{feature_idx}.png', vis_standard)
    plt.imsave(f'visualizations/decorrelated_feature_{feature_idx}.png', vis_decorrelated)

# Now, gather max activating samples from the environment
# The two calls differ only in `use_decorrelation`; each keeps the top 4
# samples per feature.
print("\nGathering max activating samples...")
best_samples_standard = gather_max_activating_samples(
    model=model,
    sae=sae,
    layer_number=layer_number,
    feature_indices=feature_indices,
    iterations=5,
    batch_size=16,
    top_k=4,
    patch_size=32,
    use_decorrelation=False
)

best_samples_decorrelated = gather_max_activating_samples(
    model=model,
    sae=sae,
    layer_number=layer_number,
    feature_indices=feature_indices,
    iterations=5,
    batch_size=16,
    top_k=4,
    patch_size=32,
    use_decorrelation=True
)

# Visualize the max activating samples
print("\nVisualizing max activating samples...")
visualize_max_activations(best_samples_standard, layer_name, use_decorrelation=False)
visualize_max_activations(best_samples_decorrelated, layer_name, use_decorrelation=True)

print("All visualizations completed!")

In [None]:
# Import necessary libraries
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from utils.helpers import load_interpretable_model
import os
from tqdm import tqdm
from PIL import Image

# Import visualization utilities
from feature_vis_impala import (
    total_variation, 
    jitter, 
    random_scale, 
    random_rotate
)

# Import SAE-related modules
from sae_cnn import ConvSAE
from extract_sae_features import replace_layer_with_sae
from feature_vis_sae import load_sae_from_checkpoint

# Import color decorrelation functions
from color_decorrelation import (
    apply_whitening,
    apply_unwhitening,
    load_color_matrices
)

# Set device (CUDA when available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the base model, move it to the selected device and call eval() on it
model = load_interpretable_model()
model.to(device)
model.eval()

# Load SAE model
sae_checkpoint_path = "../checkpoints/sae_checkpoint_step_4500000.pt"
# The loader's result is indexed; only element 0 is used here, as the SAE.
sae_result = load_sae_from_checkpoint(sae_checkpoint_path, device)
sae = sae_result[0]  # Access the SAE model
layer_name = "conv4a"  # Target layer
# Passed to replace_layer_with_sae by the functions below; presumably the
# index of `layer_name` in the model - TODO confirm.
layer_number = 8

# Load color matrices
# whitening_matrix, unwhitening_matrix and mean_color are read as module-level
# globals by the functions below when use_decorrelation=True.
color_matrices_path = "color_matrices.pt"
if os.path.exists(color_matrices_path):
    cov_matrix, whitening_matrix, unwhitening_matrix, mean_color = load_color_matrices(color_matrices_path, device)
    print("Loaded color matrices successfully")
else:
    # Fallback: identity matrices and a zero mean.  `cov_matrix` is not bound
    # on this path.
    print("Color matrices not found, using identity matrices")
    whitening_matrix = torch.eye(3, device=device)
    unwhitening_matrix = torch.eye(3, device=device)
    mean_color = torch.zeros(3, device=device)

def generate_synthetic_image(model, sae, feature_idx, num_steps=1000, use_decorrelation=False):
    """
    Generate a synthetic image that maximally activates a specific SAE feature.

    A random 72x72 image (64x64 plus 4 px of padding per side) is optimized
    with Adam. Each step applies jitter, random scaling and random rotation,
    crops the padding, runs the result through the model with the SAE attached
    at the module-level ``layer_number``, and maximizes the mean activation of
    the feature subject to total-variation and L2 penalties.
    
    Args:
        model: The base model
        sae: The SAE model; its forward output must be a tuple with at least
            three elements, the third of which is read as the feature activations
        feature_idx: Index of the feature to visualize
        num_steps: Number of optimization steps (must be at least 1)
        use_decorrelation: Whether to whiten (module-level ``whitening_matrix``
            / ``mean_color``) and min-max rescale the image before the forward pass
        
    Returns:
        image: The best transformed image as a tensor with values in [0, 1]
        activation: The highest mean feature activation observed

    Raises:
        RuntimeError: If no optimization step produced a usable activation
            (for example ``num_steps < 1``)
    """
    print(f"Generating synthetic image for feature {feature_idx}")
    print(f"Using decorrelation: {use_decorrelation}")
    
    # Attach SAE to the model
    sae_hook = replace_layer_with_sae(model, sae, layer_number)
    
    # Set up hooks to capture activations
    activations = {}
    
    def sae_activation_hook(module, input, output):
        if isinstance(output, tuple) and len(output) >= 3:
            activations['sae_features'] = output[2]
        return output
    
    # Register hook on the SAE
    hook = sae.register_forward_hook(sae_activation_hook)
    
    try:
        # Initialize a random image with padding
        padded_size = 64 + 8  # 4 pixels padding on each side
        input_img = torch.randint(0, 256, (1, 3, padded_size, padded_size), 
                                device=device, 
                                dtype=torch.float32).requires_grad_(True)
        
        # Set up optimizer
        optimizer = torch.optim.Adam([input_img], lr=0.08)
        
        # Parameters for transformations
        scales = [1.0, 0.975, 1.025, 0.95, 1.05]
        angles = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
        
        best_activation = float('-inf')
        best_img = None
        
        # Optimization loop
        pbar = tqdm(range(num_steps))
        for step in pbar:
            optimizer.zero_grad()
            
            # Create a copy for transformations
            processed_img = input_img.clone()
            
            # Apply transformations
            ox, oy = np.random.randint(-4, 5, 2)
            processed_img = jitter(processed_img, ox, oy)
            processed_img = random_scale(processed_img, scales)
            processed_img = random_rotate(processed_img, angles)
            ox, oy = np.random.randint(-2, 3, 2)
            processed_img = jitter(processed_img, ox, oy)
            
            # Crop padding
            processed_img = processed_img[:, :, 4:-4, 4:-4]
            
            # Ensure values are in [0, 255]
            processed_img.data.clamp_(0, 255)
            
            # Normalize to [0,1] for model input
            normalized_img = processed_img / 255.0
            
            # Apply whitening if using decorrelation
            if use_decorrelation:
                # Apply whitening
                whitened_img = apply_whitening(normalized_img, whitening_matrix, mean_color)
                
                # Rescale to [0,1] range for the model. The range is floored so
                # a constant image cannot cause a division by zero, whose NaNs
                # would otherwise propagate into input_img.
                whitened_min = whitened_img.min()
                whitened_max = whitened_img.max()
                value_range = torch.clamp(whitened_max - whitened_min, min=1e-8)
                model_input = (whitened_img - whitened_min) / value_range
            else:
                # Use the normalized image directly
                model_input = normalized_img
            
            # Forward pass
            _ = model(model_input)
            
            # Get the activation for the target feature
            if 'sae_features' in activations:
                feature_activation = activations['sae_features'][0, feature_idx]
                activation_loss = -feature_activation.mean()  # negative because we want to maximize
            else:
                activation_loss = torch.tensor(0.0, device=device)
            
            # Calculate regularization losses
            tv_loss = 1e-3 * total_variation(input_img / 255.0)
            l2_loss = 1e-3 * torch.norm(input_img / 255.0)
            
            # Total loss
            loss = activation_loss + tv_loss + l2_loss
            
            loss.backward()
            optimizer.step()
            
            # Post-processing
            with torch.no_grad():
                input_img.data.clamp_(0, 255)
                
                # Track best activation; the stored image is the transformed
                # crop that was actually shown to the model this step
                current_activation = -activation_loss.item()
                if current_activation > best_activation:
                    best_activation = current_activation
                    best_img = processed_img.clone()
            
            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'act': current_activation})
        
        if best_img is None:
            raise RuntimeError(
                "No image was recorded during optimization "
                "(num_steps must be >= 1 and activations must be finite)"
            )
        
        # Process the final image
        result_img = best_img / 255.0  # Normalize to [0,1]
        
        # If we used decorrelation, apply it to the final image for visualization
        if use_decorrelation:
            # Apply whitening
            whitened_img = apply_whitening(result_img, whitening_matrix, mean_color)
            
            # Apply unwhitening to get back to normal color space
            unwhitened_img = apply_unwhitening(whitened_img, unwhitening_matrix, mean_color)
            
            # Ensure values are in [0,1]
            result_img = torch.clamp(unwhitened_img, 0, 1)
        
        return result_img, best_activation
        
    finally:
        # Clean up hooks
        hook.remove()
        sae_hook.remove()

def get_feature_activation_map(model, sae, image, feature_idx):
    """
    Run one image through the model with the SAE attached and return the
    activation map of a single SAE feature.

    Args:
        model: The base model
        sae: The SAE model
        image: Input image tensor (1, 3, H, W)
        feature_idx: Index of the feature whose map is returned

    Returns:
        activation_map: numpy array holding the feature's activations for the
            first item of the batch

    Raises:
        ValueError: If the SAE forward hook did not capture any features
    """
    captured = {}

    def _record_features(module, input, output):
        # Only tuple outputs with at least three entries carry the features.
        if isinstance(output, tuple) and len(output) >= 3:
            captured['sae_features'] = output[2]
        return output

    # Splice the SAE into the model, then listen on the SAE's own output.
    sae_hook = replace_layer_with_sae(model, sae, layer_number)
    hook = sae.register_forward_hook(_record_features)

    try:
        # Gradients are not needed for a pure read-out.
        with torch.no_grad():
            _ = model(image)

        if 'sae_features' not in captured:
            raise ValueError("No activations captured")

        selected = captured['sae_features'][0, feature_idx]
        return selected.detach().cpu().numpy()
    finally:
        # Always detach both hooks, whether or not the forward pass succeeded.
        hook.remove()
        sae_hook.remove()

def find_high_activation_regions(activation_map, top_k=3):
    """
    Find the cells with the highest activation in the activation map.
    
    Args:
        activation_map: 2D numpy array of activations
        top_k: Number of top cells to return; values <= 0 yield an empty list,
            values larger than the map size yield every cell
        
    Returns:
        regions: List of dictionaries sorted by descending score, each with
            'act_coords' (a (row, col) tuple of ints) and 'score' (a float)
    """
    # Unpacking also enforces that the map is exactly 2-D.
    _, act_w = activation_map.shape

    # Guard explicitly: with top_k == 0 the slice [-0:] below would select
    # the whole array instead of nothing.
    if top_k <= 0:
        return []

    flat_activations = activation_map.flatten()
    
    # Indices of the top-k activations (ascending order of score)
    top_indices = np.argsort(flat_activations)[-top_k:]
    
    # Convert each flat index to (row, col); plain Python numbers keep the
    # result independent of NumPy scalar types.
    regions = [
        {
            'act_coords': divmod(int(idx), act_w),
            'score': float(flat_activations[idx])
        }
        for idx in top_indices
    ]
    
    # Sort by score (highest first)
    regions.sort(key=lambda region: region['score'], reverse=True)
    
    return regions

def map_activation_to_image_coords(region, act_shape, img_shape):
    """
    Translate a region's activation-map cell into pixel coordinates.

    Each axis is scaled by the ratio of image size to activation-map size and
    the result is truncated to an integer.

    Args:
        region: Dictionary holding the cell under the 'act_coords' key
        act_shape: Shape of the activation map (H, W)
        img_shape: Shape of the image (H, W)

    Returns:
        img_coords: Tuple (i, j) of image coordinates
    """
    (map_rows, map_cols), (img_rows, img_cols) = act_shape, img_shape

    # Pixels per activation cell along each axis.
    rows_per_cell = img_rows / map_rows
    cols_per_cell = img_cols / map_cols

    cell_row, cell_col = region['act_coords']
    return (int(cell_row * rows_per_cell), int(cell_col * cols_per_cell))

def extract_patch(image, center_coords, patch_size=24):
    """
    Cut a window out of the image around the given center.

    The window reaches ``patch_size // 2`` pixels from the center in each
    direction and is clipped at the image border. A window whose height or
    width differs from ``patch_size`` is bilinearly resized to
    (patch_size, patch_size).

    Args:
        image: Image tensor (1, 3, H, W)
        center_coords: Tuple (i, j) of center coordinates
        patch_size: Side length of the returned patch

    Returns:
        patch: Patch tensor (1, 3, patch_size, patch_size)
    """
    row, col = center_coords
    reach = patch_size // 2

    # Window bounds, clipped to the image.
    top, left = max(0, row - reach), max(0, col - reach)
    bottom = min(image.shape[2], row + reach)
    right = min(image.shape[3], col + reach)

    window = image[:, :, top:bottom, left:right]

    # Already the requested size: hand back the slice untouched.
    if window.shape[2] == patch_size and window.shape[3] == patch_size:
        return window

    return F.interpolate(window, size=(patch_size, patch_size), mode='bilinear', align_corners=False)

def optimize_patch(model, sae, feature_idx, initial_patch, center_coords, num_steps=2000, use_decorrelation=False):
    """
    Optimize a patch to maximally activate a specific feature.

    The patch is reflect-padded by 4 px per side and optimized with Adam.
    Every step jitters, rescales and rotates the padded patch, crops the
    padding, pastes the result into an otherwise all-zero 64x64 image around
    ``center_coords`` and maximizes the mean activation of the feature in the
    3x3 activation-map neighborhood (clipped at the border) of the matching
    cell, with total-variation and L2 penalties on the padded patch.
    
    Args:
        model: The base model
        sae: The SAE model
        feature_idx: Index of the feature to optimize for
        initial_patch: Initial patch tensor (1, 3, patch_size, patch_size);
            values are clamped to [0, 1] during optimization
        center_coords: Tuple (i, j) of center coordinates in the full image
        num_steps: Number of optimization steps (must be at least 1)
        use_decorrelation: If True, the best patch goes through a
            whitening/unwhitening round trip and is clamped to [0, 1] before
            being returned; the optimization loop itself is not affected
        
    Returns:
        optimized_patch: The optimized patch tensor (1, 3, patch_size, patch_size)

    Raises:
        RuntimeError: If no optimization step produced a usable activation
            (for example ``num_steps < 1``)
    """
    print(f"Optimizing patch for feature {feature_idx} at coordinates {center_coords}")
    
    # Get patch size
    patch_size = initial_patch.shape[2]
    
    # Pad by 4 px per side so the transformations have a margin that is
    # cropped away again before the forward pass
    padded_patch = F.pad(initial_patch, (4, 4, 4, 4), mode='reflect')
    padded_patch = padded_patch.clone().detach().requires_grad_(True)
    
    # Activation-map cell that corresponds to the patch center
    img_h, img_w = 64, 64
    act_h, act_w = 8, 8  # Assuming 8x8 feature maps for conv4a
    scale_h = img_h / act_h
    scale_w = img_w / act_w
    
    act_i = int(center_coords[0] / scale_h)
    act_j = int(center_coords[1] / scale_w)
    
    # Neighborhood of that cell whose mean activation is maximized
    radius = 1
    act_start_i = max(0, act_i - radius)
    act_start_j = max(0, act_j - radius)
    act_end_i = min(act_h, act_i + radius + 1)
    act_end_j = min(act_w, act_j + radius + 1)
    
    # Where the patch is pasted into the full image. This does not change
    # between steps, so it is computed once; the decorrelation post-processing
    # below relies on the same values.
    img_i, img_j = center_coords
    half_size = patch_size // 2
    start_i = max(0, img_i - half_size)
    start_j = max(0, img_j - half_size)
    end_i = min(img_h, start_i + patch_size)
    end_j = min(img_w, start_j + patch_size)
    region_height = end_i - start_i
    region_width = end_j - start_j
    needs_resize = (region_height != patch_size) or (region_width != patch_size)
    
    # Set up optimizer
    optimizer = torch.optim.Adam([padded_patch], lr=0.08)
    
    # Parameters for transformations
    scales = [1.0, 0.975, 1.025, 0.95, 1.05]
    angles = [-5, -4, -3, -2, -1, 0, 1, 2, 3, 4, 5]
    
    best_activation = float('-inf')
    best_patch = None
    
    # Set up hooks to capture activations
    activations = {}
    
    def sae_activation_hook(module, input, output):
        if isinstance(output, tuple) and len(output) >= 3:
            activations['sae_features'] = output[2]
        return output
    
    # Attach the SAE and register the capture hook last, directly before the
    # try block that guarantees their removal
    sae_hook = replace_layer_with_sae(model, sae, layer_number)
    hook = sae.register_forward_hook(sae_activation_hook)
    
    try:
        # Optimization loop
        pbar = tqdm(range(num_steps))
        for step in pbar:
            optimizer.zero_grad()

            # Create a copy for transformations
            processed_patch = padded_patch.clone()

            # Apply transformations
            ox, oy = np.random.randint(-4, 5, 2)
            processed_patch = jitter(processed_patch, ox, oy)
            processed_patch = random_scale(processed_patch, scales)
            processed_patch = random_rotate(processed_patch, angles)
            ox, oy = np.random.randint(-2, 3, 2)
            processed_patch = jitter(processed_patch, ox, oy)

            # Crop padding
            processed_patch = processed_patch[:, :, 4:-4, 4:-4]

            # Clamp values to [0, 1]
            processed_patch.data.clamp_(0, 1)

            # Explicitly enforce size consistency after transformations
            processed_patch = F.interpolate(
                processed_patch,
                size=(patch_size, patch_size),
                mode='bilinear',
                align_corners=False
            )

            # Create a full-sized image with the patch; a fresh tensor is
            # needed every step because the patch is written into it in place
            full_img = torch.zeros((1, 3, img_h, img_w), device=device)

            # Ensure patch matches region size exactly
            if needs_resize:
                patch_to_place = F.interpolate(
                    processed_patch,
                    size=(region_height, region_width),
                    mode='bilinear',
                    align_corners=False
                )
            else:
                patch_to_place = processed_patch

            # Place the patch in the full image
            full_img[:, :, start_i:end_i, start_j:end_j] = patch_to_place

            # Forward pass
            _ = model(full_img)

            # Get activation and compute loss
            if 'sae_features' in activations:
                feature_map = activations['sae_features'][0, feature_idx]
                region_activation = feature_map[act_start_i:act_end_i, act_start_j:act_end_j]
                activation_loss = -region_activation.mean()
            else:
                activation_loss = torch.tensor(0.0, device=device)

            # Regularization losses
            tv_loss = 1e-3 * total_variation(padded_patch)
            l2_loss = 1e-3 * torch.norm(padded_patch)

            # Total loss
            loss = activation_loss + tv_loss + l2_loss

            loss.backward()
            optimizer.step()

            # Post-processing
            with torch.no_grad():
                padded_patch.data.clamp_(0, 1)

                current_activation = -activation_loss.item()
                if current_activation > best_activation:
                    best_activation = current_activation
                    best_patch = processed_patch.clone()

            # Update progress bar
            pbar.set_postfix({'loss': loss.item(), 'act': current_activation})
        
        if best_patch is None:
            raise RuntimeError(
                "No patch was recorded during optimization "
                "(num_steps must be >= 1 and activations must be finite)"
            )
        
        # Process the final patch
        result_patch = best_patch.clone()
        
        # If we used decorrelation, apply it to the final patch for visualization
        if use_decorrelation:
            # Create a full image with the patch
            full_img = torch.zeros((1, 3, img_h, img_w), device=device)
            
            # Resize the patch to match the region exactly
            patch_to_place = F.interpolate(
                result_patch, 
                size=(region_height, region_width), 
                mode='bilinear', 
                align_corners=False
            )
            
            # Place the patch in the full image
            full_img[:, :, start_i:end_i, start_j:end_j] = patch_to_place
            
            # Apply whitening
            whitened_img = apply_whitening(full_img, whitening_matrix, mean_color)
            
            # Apply unwhitening to get back to normal color space
            unwhitened_img = apply_unwhitening(whitened_img, unwhitening_matrix, mean_color)
            
            # Extract the patch again
            result_patch = unwhitened_img[:, :, start_i:end_i, start_j:end_j]
            
            # Ensure values are in [0,1]
            result_patch = torch.clamp(result_patch, 0, 1)
        
        # Resize to standard patch size if needed
        if result_patch.shape[2] != patch_size or result_patch.shape[3] != patch_size:
            result_patch = F.interpolate(
                result_patch, 
                size=(patch_size, patch_size), 
                mode='bilinear', 
                align_corners=False
            )
        
        return result_patch
        
    finally:
        # Clean up hooks
        hook.remove()
        sae_hook.remove()

def generate_and_optimize_patches(model, sae, feature_idx, num_steps=1000, use_decorrelation=False):
    """
    Build a synthetic image for a feature, locate its three strongest
    activation cells and optimize one 24x24 patch around each of them.

    Args:
        model: The base model
        sae: The SAE model
        feature_idx: Index of the feature to visualize
        num_steps: Step budget; the full-image stage and every patch stage
            each run ``num_steps // 2`` steps
        use_decorrelation: Forwarded to both optimization stages

    Returns:
        patches: List of tuples (patch, score, coordinates), where score is
            the activation found in the synthetic image and coordinates are
            image-space (i, j)
    """
    steps_per_stage = num_steps // 2

    # Stage 1: full-size synthetic image
    print(f"Step 1: Generating initial synthetic image for feature {feature_idx}")
    initial_image, _ = generate_synthetic_image(
        model,
        sae,
        feature_idx,
        num_steps=steps_per_stage,
        use_decorrelation=use_decorrelation,
    )

    # Stage 2: strongest cells of the feature's activation map
    print(f"Step 2: Identifying high-activation regions")
    activation_map = get_feature_activation_map(model, sae, initial_image, feature_idx)
    top_regions = find_high_activation_regions(activation_map, top_k=3)

    print(f"Found {len(top_regions)} high-activation regions")
    for i, region in enumerate(top_regions):
        print(f"  Region {i+1}: coords={region['act_coords']}, score={region['score']:.4f}")

    # Stage 3: one optimized patch per region
    print(f"Step 3: Extracting and optimizing patches")
    results = []
    for i, region in enumerate(top_regions):
        img_coords = map_activation_to_image_coords(region, activation_map.shape, (64, 64))
        print(f"  Region {i+1}: activation coords={region['act_coords']}, image coords={img_coords}")

        seed_patch = extract_patch(initial_image, img_coords, patch_size=24)
        tuned_patch = optimize_patch(
            model,
            sae,
            feature_idx,
            seed_patch,
            img_coords,
            num_steps=steps_per_stage,
            use_decorrelation=use_decorrelation,
        )

        results.append((tuned_patch, region['score'], img_coords))

    return results

# Create directory for saving visualizations
os.makedirs('synthetic_patches', exist_ok=True)

# Select features to visualize
feature_indices = list(range(8))  # Start with first 8 features

# Generate and optimize patches for each feature, then plot the standard
# patches (top row) above the decorrelated ones (bottom row)
for feature_idx in feature_indices:
    print(f"\nProcessing feature {feature_idx}")
    
    # Standard visualization (no decorrelation)
    print("Generating standard patches...")
    standard_patches = generate_and_optimize_patches(
        model=model,
        sae=sae,
        feature_idx=feature_idx,
        num_steps=1000,
        use_decorrelation=False
    )
    
    # Decorrelated visualization
    print("Generating decorrelated patches...")
    decorrelated_patches = generate_and_optimize_patches(
        model=model,
        sae=sae,
        feature_idx=feature_idx,
        num_steps=1000,
        use_decorrelation=True
    )
    
    # Create a figure to display the results.
    # squeeze=False keeps `axes` two-dimensional even with a single column,
    # so the axes[row, col] indexing below always works.
    num_patches = len(standard_patches)
    fig, axes = plt.subplots(2, num_patches, figsize=(4*num_patches, 8), squeeze=False)
    
    # (title label, file-name tag, patches) for each figure row
    panel_rows = (
        ('Standard', 'standard', standard_patches),
        ('Decorrelated', 'decorrelated', decorrelated_patches),
    )
    
    for row_idx, (row_label, file_tag, row_patches) in enumerate(panel_rows):
        for i, (patch, score, coords) in enumerate(row_patches):
            # (1, 3, H, W) tensor -> (H, W, 3) array in [0, 1] for plotting
            patch_np = patch.detach().cpu().squeeze().permute(1, 2, 0).numpy()
            patch_np = np.clip(patch_np, 0, 1)
            
            axes[row_idx, i].imshow(patch_np)
            axes[row_idx, i].set_title(f'{row_label}\nScore: {score:.2f}\nPos: {coords}')
            axes[row_idx, i].axis('off')
            
            # Save individual patch
            plt.imsave(f'synthetic_patches/feature_{feature_idx}_{file_tag}_patch_{i}.png', patch_np)
    
    # Add overall title
    fig.suptitle(f'Feature {feature_idx} Synthetic Patches', fontsize=16)
    
    # Save the combined figure
    plt.tight_layout()
    plt.savefig(f'synthetic_patches/feature_{feature_idx}_patches.png')
    plt.close(fig)

print("All synthetic patches generated successfully!")

In [None]:
# Import necessary libraries
import torch
import numpy as np
import matplotlib.pyplot as plt
import gym
from PIL import Image
import os
from tqdm import tqdm

# Import necessary functions
from utils.helpers import load_interpretable_model
from sae_cnn import ConvSAE
from extract_sae_features import replace_layer_with_sae
from feature_vis_sae import load_sae_from_checkpoint

# Set device (CUDA when available, otherwise CPU)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# Load the base model, move it to the selected device and call eval() on it
model = load_interpretable_model()
model.to(device)
model.eval()

# Load SAE model - explicitly set layer_number instead of trying to get it from the function
sae_checkpoint_path = "../checkpoints/sae_checkpoint_step_4500000.pt"
# Only element 0 of the loader's result is kept; it is used as the SAE.
sae = load_sae_from_checkpoint(sae_checkpoint_path, device)[0]  # Just get the SAE
layer_name = "conv4a"  # Target layer
# Presumably the index of `layer_name` in the model - TODO confirm.
layer_number = 8  # Explicitly set layer number

def cosine_similarity(patch1, patch2):
    """
    Return the cosine similarity of two patches treated as flat vectors.

    Each norm is offset by 1e-8 so that all-zero patches give 0.0 instead of
    a division by zero.

    Args:
        patch1: First patch tensor
        patch2: Second patch tensor (same number of elements as patch1)

    Returns:
        similarity: Cosine similarity as a Python float
    """
    a, b = patch1.flatten(), patch2.flatten()

    # Product of the two (stabilized) vector lengths.
    denominator = (torch.norm(a) + 1e-8) * (torch.norm(b) + 1e-8)

    return (torch.dot(a, b) / denominator).item()

def extract_activation_patches(activation_map, observation, patch_size=24, top_k=5):
    """
    Extract image patches around the strongest responses in an activation map.

    Each of the `top_k` largest activation cells is projected into the 64x64
    image, a `patch_size` window around that point is cropped (clamped to the
    image bounds), and the crop is resized back to `patch_size` x `patch_size`
    when clamping made it smaller.

    Args:
        activation_map: 2D tensor of activations for a specific feature
        observation: The corresponding input image, (3, 64, 64) or (64, 64, 3).
            The resize path scales by 255, so values are assumed to lie in
            [0, 1] (as produced by generate_batch_observations)
        patch_size: Side length of the square patches to extract
        top_k: Number of top activations to consider; values <= 0 yield no patches

    Returns:
        patches: List of (score, patch, coordinates) tuples sorted by
            descending score. `score` is the mean activation over the (up to)
            3x3 neighbourhood of the selected cell, `patch` is a CHW tensor and
            `coordinates` is the (row, col) image position of the cell. The
            list is empty when the observation is not 64x64x3.
    """
    patches = []

    # A non-positive top_k must select nothing. Without this guard the slice
    # `[-0:]` below would select every cell of the map instead of none.
    if top_k <= 0:
        return patches

    act_h, act_w = activation_map.shape

    # Convert observation to numpy (HWC) for easier cropping
    if observation.shape[0] == 3:  # CHW format
        obs_np = observation.detach().cpu().numpy().transpose(1, 2, 0)
    else:
        obs_np = observation.detach().cpu().numpy()

    # Only 64x64 RGB observations are supported by the fixed scale factors below
    if obs_np.shape != (64, 64, 3):
        print(f"Warning: Unexpected observation shape: {obs_np.shape}")
        return patches

    # Scale factors between activation map and original image
    scale_h = 64 / act_h
    scale_w = 64 / act_w

    # Flat indices of the top_k largest activations (ascending order)
    flat_activations = activation_map.flatten().detach().cpu().numpy()
    top_indices = np.argsort(flat_activations)[-top_k:]

    half_size = patch_size // 2

    for idx in top_indices:
        # Convert flat index to 2D coordinates (plain Python ints)
        i, j = divmod(int(idx), act_w)

        # Convert to image space coordinates
        img_i = int(i * scale_h)
        img_j = int(j * scale_w)

        # Patch boundaries, clamped so the crop never leaves the image
        start_i = max(0, img_i - half_size)
        start_j = max(0, img_j - half_size)
        end_i = min(64, img_i + half_size)
        end_j = min(64, img_j + half_size)

        patch = obs_np[start_i:end_i, start_j:end_j, :]

        # Score the region by the mean over the cell's clamped 3x3 neighbourhood
        act_start_i = max(0, i - 1)
        act_start_j = max(0, j - 1)
        act_end_i = min(act_h, i + 2)
        act_end_j = min(act_w, j + 2)
        region_score = activation_map[act_start_i:act_end_i, act_start_j:act_end_j].mean().item()

        # Crops clipped by the image border are resized back to the standard size
        if patch.shape[:2] != (patch_size, patch_size):
            try:
                patch_pil = Image.fromarray((patch * 255).astype(np.uint8))
                patch_pil = patch_pil.resize((patch_size, patch_size), Image.BILINEAR)
                patch = np.array(patch_pil).astype(np.float32) / 255.0
            except Exception as e:
                # Best effort: a patch that cannot be resized is reported and skipped
                print(f"Failed to resize patch: {e}, patch shape: {patch.shape}")
                continue

        # Convert back to tensor
        patch_tensor = torch.from_numpy(patch).permute(2, 0, 1)  # HWC -> CHW
        patches.append((region_score, patch_tensor, (img_i, img_j)))

    # Sort by activation score, strongest first
    patches.sort(key=lambda x: x[0], reverse=True)

    return patches

def generate_batch_observations(num_samples=100, env_name='procgen:procgen-heist-v0'):
    """
    Collect observations from the environment by following a random policy.

    Episodes are played with uniformly sampled actions, restarting whenever an
    episode ends, until `num_samples` observations have been gathered. The
    environment is driven through the classic Gym API used here: `reset()`
    returns the observation and `step()` returns a 4-tuple.

    Args:
        num_samples: Number of observations to collect
        env_name: Name of the environment

    Returns:
        observations: List of float observation tensors, each permuted from the
            environment's layout with (2, 0, 1) and divided by 255
            (presumably HWC uint8 frames -> CHW in [0, 1]; confirm against
            the environment)
    """
    print(f"Collecting {num_samples} observations from {env_name}")

    env = gym.make(env_name)
    observations = []

    # try/finally guarantees the environment is closed even if reset/step raises
    try:
        with tqdm(total=num_samples) as pbar:
            while len(observations) < num_samples:
                obs = env.reset()
                done = False

                while not done and len(observations) < num_samples:
                    # Convert observation to tensor
                    obs_tensor = torch.from_numpy(obs).permute(2, 0, 1).float() / 255.0
                    observations.append(obs_tensor)

                    # Take a random action
                    action = env.action_space.sample()
                    obs, _, done, _ = env.step(action)

                    pbar.update(1)
    finally:
        env.close()

    return observations

def get_sae_feature_activations(model, sae, observations, feature_idx):
    """
    Run every observation through the model and record one SAE feature's map.

    The SAE is spliced into the model at the module-level `layer_number`, and a
    forward hook on the SAE stores element `[0, feature_idx]` of the third item
    of its output tuple. Both hooks are removed before returning.

    Args:
        model: The base model
        sae: The SAE model
        observations: List of observation tensors
        feature_idx: Index of the feature to get activations for

    Returns:
        activations: List of (activation_map, observation) tuples; an
            observation is skipped when the hook captured nothing for it
    """
    print(f"Getting activations for feature {feature_idx}")

    # Splice the SAE into the model at the configured layer
    splice_handle = replace_layer_with_sae(model, sae, layer_number)

    # Filled by the hook below, drained once per forward pass
    captured = []

    def sae_activation_hook(module, inputs, output):
        looks_like_sae_output = isinstance(output, tuple) and len(output) >= 3
        if looks_like_sae_output:
            # Keep only the requested feature of the first batch item
            captured.append(output[2][0, feature_idx])
        return output

    capture_handle = sae.register_forward_hook(sae_activation_hook)

    try:
        paired = []
        for frame in tqdm(observations):
            # Add a batch dimension and move to the compute device
            batched = frame.unsqueeze(0).to(device)

            with torch.no_grad():
                model(batched)

            # Nothing recorded means the hook did not see an SAE-style output
            if not captured:
                continue
            paired.append((captured.pop(), frame))

        return paired

    finally:
        # Always detach both hooks so the model is left as it was found
        capture_handle.remove()
        splice_handle.remove()

def _select_diverse_patches(candidates, top_k, diversity_weight):
    """Greedily keep up to top_k candidates, rewarding dissimilarity to those already kept."""
    selected = []

    for raw_score, patch, coords in candidates:
        # The first candidate has nothing to be compared with, so it is kept
        # with its raw score. NOTE: it receives no diversity bonus and can
        # therefore be evicted later by a lower-raw-score but diverse patch.
        if not selected:
            selected.append((raw_score, raw_score, patch, coords))
            continue

        # Average similarity to the patches kept so far
        avg_similarity = sum(
            cosine_similarity(patch, kept[2]) for kept in selected
        ) / len(selected)

        # Diversity bonus grows as the candidate gets less similar
        effective_score = raw_score + diversity_weight * (1.0 - avg_similarity)
        entry = (effective_score, raw_score, patch, coords)

        if len(selected) < top_k:
            selected.append(entry)
        else:
            # Replace the weakest kept patch if this candidate beats it
            weakest = min(range(len(selected)), key=lambda k: selected[k][0])
            if effective_score > selected[weakest][0]:
                selected[weakest] = entry

        # Keep the selection ordered by effective score, best first
        selected.sort(key=lambda item: item[0], reverse=True)

    return selected


def find_max_activating_patches(model, sae, feature_idx, num_samples=100, patch_size=24, top_k=5, diversity_weight=2.0):
    """
    Find diverse patches that maximally activate a specific SAE feature.

    Fresh observations are collected on every call, the feature's activation
    map is computed for each, up to three candidate patches are extracted per
    observation, and a greedy pass keeps the `top_k` candidates with the best
    effective score (raw activation plus a diversity bonus).

    Args:
        model: The base model
        sae: The SAE model
        feature_idx: Index of the feature to find patches for
        num_samples: Number of observations to process
        patch_size: Size of patches to extract
        top_k: Number of top patches to return; values <= 0 return an empty
            list without collecting any observations
        diversity_weight: Weight for diversity bonus

    Returns:
        max_patches: List of (effective_score, raw_score, patch, coordinates)
            tuples sorted by descending effective score
    """
    # Nothing can be selected, so skip the environment rollouts and forward passes
    if top_k <= 0:
        return []

    # Generate observations
    observations = generate_batch_observations(num_samples)

    # Get activations for the feature
    activations = get_sae_feature_activations(model, sae, observations, feature_idx)

    # Extract candidate patches; the per-observation limit is fixed at 3 and is
    # independent of this function's top_k
    all_patches = []
    for activation_map, obs in activations:
        all_patches.extend(extract_activation_patches(activation_map, obs, patch_size, top_k=3))

    # Strongest raw activations are considered first
    all_patches.sort(key=lambda x: x[0], reverse=True)

    return _select_diverse_patches(all_patches, top_k, diversity_weight)

def visualize_max_activations(feature_idx, max_patches):
    """
    Visualize patches that maximally activate a specific feature.

    Each patch is written to `max_activations/feature_{feature_idx}_patch_{i}.png`,
    the combined row of patches to
    `max_activations/feature_{feature_idx}_patches.png`, and the figure is shown
    and then closed. When `max_patches` is empty a message is printed and
    nothing is drawn or written.

    Args:
        feature_idx: Index of the feature
        max_patches: List of (effective_score, raw_score, patch, coordinates) tuples
    """
    num_patches = len(max_patches)

    # plt.subplots rejects a zero-column grid, so bail out before touching matplotlib
    if num_patches == 0:
        print(f"No patches to visualize for feature {feature_idx}")
        return

    # Create directory for saving visualizations
    os.makedirs('max_activations', exist_ok=True)

    # squeeze=False always returns a 2D axes array, so a single patch needs no special case
    fig, axes = plt.subplots(1, num_patches, figsize=(4*num_patches, 4), squeeze=False)
    axes = axes[0]

    # Display each patch
    for i, (effective_score, raw_score, patch, coords) in enumerate(max_patches):
        # Convert patch (CHW tensor) to an HWC numpy image clipped to [0, 1]
        patch_np = patch.detach().cpu().numpy().transpose(1, 2, 0)
        patch_np = np.clip(patch_np, 0, 1)

        # Display patch
        axes[i].imshow(patch_np)
        axes[i].set_title(f'Raw: {raw_score:.2f}\nEff: {effective_score:.2f}\nPos: {coords}')
        axes[i].axis('off')

        # Save individual patch
        plt.imsave(f'max_activations/feature_{feature_idx}_patch_{i}.png', patch_np)

    # Add overall title
    fig.suptitle(f'Feature {feature_idx} - Max Activating Patches', fontsize=16)

    # Save the combined figure
    plt.tight_layout()
    fig.savefig(f'max_activations/feature_{feature_idx}_patches.png')

    # Display figure inline
    plt.show()

    # Release the figure so repeated calls in a loop do not accumulate open figures
    plt.close(fig)

# Select features to visualize
feature_indices = list(range(8))  # Start with first 8 features

# Find and plot max activating patches for each feature.
# Each call to find_max_activating_patches collects its own fresh batch of
# observations, so different features are evaluated on different random rollouts.
for feature_idx in feature_indices:
    print(f"\nProcessing feature {feature_idx}")
    
    # Find max activating patches with diversity
    max_patches = find_max_activating_patches(
        model=model,
        sae=sae,
        feature_idx=feature_idx,
        num_samples=100,
        patch_size=24,
        top_k=4,
        diversity_weight=2.0  # Adjust this to control diversity vs. raw activation
    )
    
    # Save the patches under max_activations/ and display them inline
    visualize_max_activations(feature_idx, max_patches)

print("All max activating patches found and visualized successfully!")