In [None]:
def analyze_cos_similarities(width_power):
    """
    Analyze cosine similarities between before and after TopK SAE weights for a given width.

    For each of the six trainers (``trainer_0`` .. ``trainer_5``) this downloads the
    "before" and "after" ``ae.pt`` checkpoints from the Hugging Face Hub, computes the
    per-latent cosine similarity between the two decoders and between the two encoders,
    prints after/before norm ratios and summary statistics, and finally saves a bar
    chart of the 25th-75th percentile ranges to
    ``topk_cos_sim_analysis_width_2pow{width_power}.png`` in the working directory.

    Args:
        width_power (int): Power of 2 for width (14 for 2^14 or 16 for 2^16).

    Returns:
        None. Results are printed and written to the PNG file.

    Raises:
        ValueError: If ``width_power`` is not 14 or 16, or if a checkpoint does not
            contain exactly ``2 ** width_power`` latents.
    """
    # Fail fast: reject unsupported widths before importing heavy libraries or
    # downloading any checkpoint, instead of only when the plot title is built.
    if width_power == 14:
        width_label = "16k"
    elif width_power == 16:
        width_label = "65k"
    else:
        raise ValueError(f"Invalid width power: {width_power}")

    import torch
    from huggingface_hub import hf_hub_download
    import matplotlib.pyplot as plt
    import numpy as np

    def cos_sims_along(before, after, dim):
        """Cosine similarity between matching vectors of two tensors, reduced over `dim`."""
        # NOTE: a vector with zero norm produces NaN here (0 / 0), which would
        # propagate into every statistic computed from the result.
        before_unit = before / before.norm(dim=dim, keepdim=True)
        after_unit = after / after.norm(dim=dim, keepdim=True)
        return torch.sum(before_unit * after_unit, dim=dim)

    def summarize(values):
        """Reduce a 1-D tensor to a dict of plain-float summary statistics."""
        return {
            'mean': values.mean().item(),
            'median': values.median().item(),
            'std': values.std().item(),
            'min': values.min().item(),
            'max': values.max().item(),
            'p25': values.quantile(0.25).item(),
            'p75': values.quantile(0.75).item(),
        }

    def report(header, stats, show_std=True, trailer=""):
        """Print one statistics section; `trailer` is appended to the last line."""
        print(header)
        print(f"  Mean: {stats['mean']:.4f}")
        print(f"  Median: {stats['median']:.4f}")
        if show_std:
            print(f"  Std: {stats['std']:.4f}")
        print(f"  Min/Max: {stats['min']:.4f}/{stats['max']:.4f}")
        print(f"  25th/75th percentile: {stats['p25']:.4f}/{stats['p75']:.4f}{trailer}")

    trainer_ids = list(range(6))
    before_local_dir = "before_cos_sim_saes"
    after_local_dir = "after_cos_sim_saes"

    width_str = f"2pow{width_power}"
    expected_latents = 2 ** width_power

    before_repo_id = f"canrager/saebench_gemma-2-2b_width-{width_str}_date-0107"
    after_repo_id = "adamkarvonen/new_kl_finetunes"

    all_results = []

    for trainer_id in trainer_ids:
        # The same relative path is used in both repositories.
        base_path = f"gemma-2-2b_top_k_width-{width_str}_date-0107/resid_post_layer_12/trainer_{trainer_id}/ae.pt"

        path_to_before_params = hf_hub_download(
            repo_id=before_repo_id,
            filename=base_path,
            force_download=False,
            local_dir=before_local_dir,
        )

        path_to_after_params = hf_hub_download(
            repo_id=after_repo_id,
            filename=base_path,
            force_download=False,
            local_dir=after_local_dir,
        )

        # map_location="cpu" lets the analysis run on machines without the device
        # the checkpoints were saved from.
        # NOTE(security): torch.load unpickles the file; these checkpoints come from
        # remote repositories, so only run this against repos you trust.
        before_params = torch.load(path_to_before_params, map_location="cpu")
        after_params = torch.load(path_to_after_params, map_location="cpu")

        before_decoder = before_params['decoder.weight']
        after_decoder = after_params['decoder.weight']
        before_encoder = before_params['encoder.weight']
        after_encoder = after_params['encoder.weight']

        # Decoder vectors are compared column-wise (dim=0), encoder vectors
        # row-wise (dim=1); both yield one similarity per latent.
        decoder_cos_sims = cos_sims_along(before_decoder, after_decoder, dim=0)
        encoder_cos_sims = cos_sims_along(before_encoder, after_encoder, dim=1)

        # Explicit check instead of `assert` so it survives `python -O`, and so a
        # checkpoint of the wrong width is rejected rather than silently accepted.
        for name, sims in (("encoder", encoder_cos_sims), ("decoder", decoder_cos_sims)):
            if sims.shape[0] != expected_latents:
                raise ValueError(
                    f"Trainer {trainer_id}: expected {expected_latents} {name} latents, "
                    f"got {sims.shape[0]}"
                )

        decoder_norm_ratios = after_decoder.norm(dim=0) / before_decoder.norm(dim=0)
        encoder_norm_ratios = after_encoder.norm(dim=1) / before_encoder.norm(dim=1)

        print(f"\nTrainer {trainer_id} Norm Ratios:")
        report("Decoder (after/before):", summarize(decoder_norm_ratios), show_std=False)
        report("Encoder (after/before):", summarize(encoder_norm_ratios), show_std=False,
               trailer="\n")

        decoder_stats = summarize(decoder_cos_sims)
        encoder_stats = summarize(encoder_cos_sims)

        all_results.append({
            'trainer_id': trainer_id,
            'decoder': decoder_stats,
            'encoder': encoder_stats,
            'decoder_cos_sims': decoder_cos_sims.cpu().numpy(),
            'encoder_cos_sims': encoder_cos_sims.cpu().numpy(),
        })

        print(f"\nTrainer {trainer_id}:")
        report("Decoder statistics:", decoder_stats)
        report("\nEncoder statistics:", encoder_stats)

    # Plot creation
    # Tick labels for the six trainers.
    # NOTE(review): assumes trainer_0..trainer_5 correspond to these k values in
    # this order - confirm against the training configuration.
    k_values = [20, 40, 80, 160, 320, 640]
    decoder_p25 = [result['decoder']['p25'] for result in all_results]
    decoder_p75 = [result['decoder']['p75'] for result in all_results]
    encoder_p25 = [result['encoder']['p25'] for result in all_results]
    encoder_p75 = [result['encoder']['p75'] for result in all_results]

    decoder_range = np.array(decoder_p75) - np.array(decoder_p25)
    encoder_range = np.array(encoder_p75) - np.array(encoder_p25)

    x = np.arange(len(k_values))
    width = 0.35
    cap = width / 4  # half-length of the horizontal marker drawn at each quartile

    plt.figure(figsize=(12, 6))
    try:
        # Floating bars spanning the interquartile range: decoder left, encoder right.
        plt.bar(x - width/2, decoder_range, width, bottom=decoder_p25,
                label='Decoder (25th-75th)', color='skyblue', alpha=0.7)
        plt.bar(x + width/2, encoder_range, width, bottom=encoder_p25,
                label='Encoder (25th-75th)', color='lightcoral', alpha=0.7)

        # Emphasise the 25th and 75th percentile of every bar with a short line.
        for pos, d_lo, d_hi, e_lo, e_hi in zip(x, decoder_p25, decoder_p75,
                                               encoder_p25, encoder_p75):
            markers = (
                (pos - width/2, (d_lo, d_hi), 'blue'),
                (pos + width/2, (e_lo, e_hi), 'red'),
            )
            for centre, levels, colour in markers:
                for level in levels:
                    plt.plot([centre - cap, centre + cap], [level, level],
                             color=colour, linewidth=2)

        plt.xlabel('Top-k Value')
        plt.ylabel('Cosine Similarity')
        plt.title(f'25th-75th Percentile Ranges of Cosine Similarities (TopK, {width_label} Width)')
        plt.xticks(x, k_values)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        # Save the plot
        plt.savefig(f'topk_cos_sim_analysis_width_{width_str}.png')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()

# Example usage: run the TopK analysis for both supported widths (2^14, then 2^16).
if __name__ == "__main__":
    for width_power in (14, 16):
        analyze_cos_similarities(width_power)

In [None]:
def analyze_cos_similarities_relu(width_power):
    """
    Analyze cosine similarities between before and after ReLU SAE weights for a given width.

    For each of the six trainers (``trainer_0`` .. ``trainer_5``) this downloads the
    "before" and "after" ``ae.pt`` checkpoints from the Hugging Face Hub, computes the
    per-latent cosine similarity between the two decoders and between the two encoders,
    prints summary statistics, and finally saves a bar chart of the 25th-75th
    percentile ranges to ``relu_cos_sim_analysis_width_2pow{width_power}.png`` in the
    working directory.

    Args:
        width_power (int): Power of 2 for width (14 for 2^14 or 16 for 2^16).

    Returns:
        None. Results are printed and written to the PNG file.

    Raises:
        ValueError: If ``width_power`` is not 14 or 16, or if a checkpoint does not
            contain exactly ``2 ** width_power`` latents.
    """
    # Fail fast: reject unsupported widths before importing heavy libraries or
    # downloading any checkpoint, instead of only when the plot title is built.
    if width_power == 14:
        width_label = "16k"
    elif width_power == 16:
        width_label = "65k"
    else:
        raise ValueError(f"Invalid width power: {width_power}")

    import torch
    from huggingface_hub import hf_hub_download
    import matplotlib.pyplot as plt
    import numpy as np

    def cos_sims_along(before, after, dim):
        """Cosine similarity between matching vectors of two tensors, reduced over `dim`."""
        # NOTE: a vector with zero norm produces NaN here (0 / 0), which would
        # propagate into every statistic computed from the result.
        before_unit = before / before.norm(dim=dim, keepdim=True)
        after_unit = after / after.norm(dim=dim, keepdim=True)
        return torch.sum(before_unit * after_unit, dim=dim)

    def summarize(values):
        """Reduce a 1-D tensor to a dict of plain-float summary statistics."""
        return {
            'mean': values.mean().item(),
            'median': values.median().item(),
            'std': values.std().item(),
            'min': values.min().item(),
            'max': values.max().item(),
            'p25': values.quantile(0.25).item(),
            'p75': values.quantile(0.75).item(),
        }

    def report(header, stats):
        """Print one statistics section under the given header line."""
        print(header)
        print(f"  Mean: {stats['mean']:.4f}")
        print(f"  Median: {stats['median']:.4f}")
        print(f"  Std: {stats['std']:.4f}")
        print(f"  Min/Max: {stats['min']:.4f}/{stats['max']:.4f}")
        print(f"  25th/75th percentile: {stats['p25']:.4f}/{stats['p75']:.4f}")

    trainer_ids = list(range(6))
    before_local_dir = "before_relu_cos_sim_saes"
    after_local_dir = "after_relu_cos_sim_saes_v2"

    width_str = f"2pow{width_power}"
    expected_latents = 2 ** width_power

    before_repo_id = f"canrager/saebench_gemma-2-2b_width-{width_str}_date-0107"
    after_repo_id = "adamkarvonen/new_kl_finetunes"

    all_results = []

    for trainer_id in trainer_ids:
        # The same relative path is used in both repositories.
        base_path = f"gemma-2-2b_standard_new_width-{width_str}_date-0107/resid_post_layer_12/trainer_{trainer_id}/ae.pt"

        path_to_before_params = hf_hub_download(
            repo_id=before_repo_id,
            filename=base_path,
            force_download=False,
            local_dir=before_local_dir,
        )

        path_to_after_params = hf_hub_download(
            repo_id=after_repo_id,
            filename=base_path,
            force_download=False,
            local_dir=after_local_dir,
        )

        # map_location="cpu" lets the analysis run on machines without the device
        # the checkpoints were saved from.
        # NOTE(security): torch.load unpickles the file; these checkpoints come from
        # remote repositories, so only run this against repos you trust.
        before_params = torch.load(path_to_before_params, map_location="cpu")
        after_params = torch.load(path_to_after_params, map_location="cpu")

        # Decoder vectors are compared column-wise (dim=0), encoder vectors
        # row-wise (dim=1); both yield one similarity per latent.
        decoder_cos_sims = cos_sims_along(
            before_params['decoder.weight'], after_params['decoder.weight'], dim=0
        )
        encoder_cos_sims = cos_sims_along(
            before_params['encoder.weight'], after_params['encoder.weight'], dim=1
        )

        # Explicit check instead of `assert` so it survives `python -O`, and so a
        # checkpoint of the wrong width is rejected rather than silently accepted.
        for name, sims in (("encoder", encoder_cos_sims), ("decoder", decoder_cos_sims)):
            if sims.shape[0] != expected_latents:
                raise ValueError(
                    f"Trainer {trainer_id}: expected {expected_latents} {name} latents, "
                    f"got {sims.shape[0]}"
                )

        decoder_stats = summarize(decoder_cos_sims)
        encoder_stats = summarize(encoder_cos_sims)

        all_results.append({
            'trainer_id': trainer_id,
            'decoder': decoder_stats,
            'encoder': encoder_stats,
            'decoder_cos_sims': decoder_cos_sims.cpu().numpy(),
            'encoder_cos_sims': encoder_cos_sims.cpu().numpy(),
        })

        print(f"\nTrainer {trainer_id}:")
        report("Decoder statistics:", decoder_stats)
        report("\nEncoder statistics:", encoder_stats)

    # Plot creation
    decoder_p25 = [result['decoder']['p25'] for result in all_results]
    decoder_p75 = [result['decoder']['p75'] for result in all_results]
    encoder_p25 = [result['encoder']['p25'] for result in all_results]
    encoder_p75 = [result['encoder']['p75'] for result in all_results]

    decoder_range = np.array(decoder_p75) - np.array(decoder_p25)
    encoder_range = np.array(encoder_p75) - np.array(encoder_p25)

    # Only the two extremes are labelled on the x axis.
    # NOTE(review): assumes trainer_0 has the highest L0 and trainer_5 the
    # lowest - confirm against the training configuration.
    trainer_labels = ['Highest L0', '', '', '', '', 'Lowest L0']

    x = np.arange(len(trainer_labels))
    width = 0.35
    cap = width / 4  # half-length of the horizontal marker drawn at each quartile

    plt.figure(figsize=(12, 6))
    try:
        # Floating bars spanning the interquartile range: decoder left, encoder right.
        plt.bar(x - width/2, decoder_range, width, bottom=decoder_p25,
                label='Decoder (25th-75th)', color='skyblue', alpha=0.7)
        plt.bar(x + width/2, encoder_range, width, bottom=encoder_p25,
                label='Encoder (25th-75th)', color='lightcoral', alpha=0.7)

        # Emphasise the 25th and 75th percentile of every bar with a short line.
        for pos, d_lo, d_hi, e_lo, e_hi in zip(x, decoder_p25, decoder_p75,
                                               encoder_p25, encoder_p75):
            markers = (
                (pos - width/2, (d_lo, d_hi), 'blue'),
                (pos + width/2, (e_lo, e_hi), 'red'),
            )
            for centre, levels, colour in markers:
                for level in levels:
                    plt.plot([centre - cap, centre + cap], [level, level],
                             color=colour, linewidth=2)

        plt.xlabel('Trainers (Ordered by Decreasing L0 Sparsity)')
        plt.ylabel('Cosine Similarity')
        plt.title(f'25th-75th Percentile Ranges of Cosine Similarities\n(ReLU, {width_label} Width)')

        plt.xticks(x, trainer_labels, rotation=45)
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()

        # Save the plot
        plt.savefig(f'relu_cos_sim_analysis_width_{width_str}.png')
    finally:
        # Always release the figure, even if plotting or saving fails.
        plt.close()


if __name__ == "__main__":
    # Run the ReLU analysis for both supported widths (2^14, then 2^16).
    for width_power in (14, 16):
        analyze_cos_similarities_relu(width_power)