In [1]:
import torch
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm
import os
import seaborn as sns
import numpy as np
# Use seaborn's "husl" palette as the default color cycle for all plots
sns.set_palette("husl")

In [2]:
# Make sure the figure output folder is present; an already-existing folder is left as-is
os.makedirs(name='singular_value_plots', exist_ok=True)

# Function to load singular values
def load_singular_values(layer_num, proj_type, map_location='cpu'):
    """Load the stored singular values for one projection of one layer.

    The file is read from
    ``data/{proj_type}_proj_weight/S_{proj_type}_layer_{layer_num}.pt``,
    resolved relative to the current working directory.

    Args:
        layer_num: Layer number used in the file name.
        proj_type: Projection name used in both the directory and the file
            name (this notebook uses 'down', 'up' and 'gate').
        map_location: Forwarded to ``torch.load``. Defaults to ``'cpu'`` so
            that files saved from a GPU still load on a machine without one.

    Returns:
        The object stored in the file (presumably a 1-D tensor of singular
        values -- TODO confirm against the code that wrote the files).

    Raises:
        FileNotFoundError: If the expected file does not exist.
    """
    file_path = f'data/{proj_type}_proj_weight/S_{proj_type}_layer_{layer_num}.pt'
    # NOTE(security): torch.load unpickles the file. Only point this at files
    # you trust, or pass weights_only=True on torch versions that support it.
    return torch.load(file_path, map_location=map_location)

# Projection types to process
proj_types = ['down', 'up', 'gate']


def _log_axis_ticks(values):
    """Build tick positions and labels for a log-scaled axis covering ``values``.

    Ticks are placed at 1x..9x of every power of ten between the decade below
    the smallest positive value and the decade above the largest one. Only the
    1x ticks (exact powers of ten) get a ``$10^{k}$`` label; the rest are blank.

    Args:
        values: Array-like of numbers. Non-positive and non-finite entries are
            ignored because they have no position on a log axis.

    Returns:
        A ``(positions, labels)`` pair of equal-length lists. Both are empty
        when ``values`` holds no positive finite entry.
    """
    values = np.asarray(values)
    positive = values[np.isfinite(values) & (values > 0)]
    if positive.size == 0:
        return [], []

    exp_lo = int(np.floor(np.log10(positive.min())))
    exp_hi = int(np.ceil(np.log10(positive.max())))

    positions, labels = [], []
    for exponent in range(exp_lo, exp_hi + 1):
        for mantissa in range(1, 10):
            positions.append(mantissa * 10.0**exponent)
            # Label straight from the exponent instead of re-deriving it with
            # int(log10(tick)), which truncates toward zero and can be off by one.
            labels.append(f'$10^{{{exponent}}}$' if mantissa == 1 else '')
    return positions, labels


# Process each projection type
for proj_type in proj_types:
    print(f"Processing {proj_type}_proj_weight...")

    # Per-layer arrays, kept for the (currently commented-out) summary plots below
    all_singular_values = []

    for layer in range(1, 33):  # Layers 1-32
        singular_values = load_singular_values(layer, proj_type).detach().cpu().numpy()
        all_singular_values.append(singular_values)

        # Plot this layer's singular values on a log-scaled y-axis
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.plot(singular_values, '-o', markersize=2, alpha=0.7)
        ax.set_title(f'Singular Values Distribution - {proj_type}_proj Layer {layer}')
        ax.set_xlabel('Index')
        ax.set_ylabel('Singular Value (log scale)')
        ax.set_yscale('log')  # Using log scale for better visualization
        ax.grid(True)

        # Custom y-ticks: 1x..9x of each decade, labelled only at powers of 10.
        # Skipped when there is nothing positive to plot, so the log axis keeps
        # its defaults instead of raising on log10(0).
        tick_positions, tick_labels = _log_axis_ticks(singular_values)
        if tick_positions:
            ax.set_yticks(tick_positions)
            ax.set_yticklabels(tick_labels)

        fig.tight_layout()
        plt.show()
        # Release the figure explicitly; 96 figures are created in total and
        # non-inline backends do not free them on show().
        plt.close(fig)
    #     # Save the plot
    #     plt.savefig(f'singular_value_plots/{proj_type}_layer_{layer}_singular_values.png', dpi=300, bbox_inches='tight')
    #     plt.close()

    # # Create summary plot
    # plt.figure(figsize=(15, 8))

    # # Plot all layers
    # for layer, values in enumerate(all_singular_values, 1):
    #     plt.plot(values, alpha=0.5, label=f'Layer {layer}')

    # plt.title(f'Singular Values Distribution Across All Layers - {proj_type}_proj')
    # plt.xlabel('Index')
    # plt.ylabel('Singular Value (log scale)')
    # plt.yscale('log')
    # plt.grid(True)
    
    # # Create custom y-ticks with more detailed log scale values for summary plot
    # all_values = np.concatenate(all_singular_values)
    # min_val = all_values.min()
    # max_val = all_values.max()
    # log_min = np.floor(np.log10(min_val))
    # log_max = np.ceil(np.log10(max_val))
    
    # # Generate major ticks at powers of 10
    # major_ticks = [10**i for i in range(int(log_min), int(log_max) + 1)]
    
    # # Generate minor ticks at 1e-1, 2e-1, 3e-1, etc. for each power of 10
    # minor_ticks = []
    # for power in range(int(log_min), int(log_max) + 1):
    #     for i in range(1, 10):
    #         minor_ticks.append(i * 10**power)
    
    # plt.yticks(major_ticks + minor_ticks)
    # # Add tick labels for major ticks only in scientific notation
    # plt.gca().set_yticklabels([f'$10^{{{int(np.log10(tick))}}}$' if tick in major_ticks else '' for tick in major_ticks + minor_ticks])
    # plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
    # plt.tight_layout()
    # # Save the summary plot
    # plt.savefig(f'singular_value_plots/{proj_type}_all_layers_summary.png', dpi=300, bbox_inches='tight')
    # plt.close()

    # # Create a heatmap of singular values across layers
    # plt.figure(figsize=(15, 10))
    # singular_values_matrix = np.array(all_singular_values)

    # sns.heatmap(singular_values_matrix, 
    #             cmap='viridis',
    #             norm=LogNorm(),  # Using the correctly imported LogNorm
    #             cbar_kws={'label': 'Singular Value (log scale)'})

    # plt.title(f'Heatmap of Singular Values Across Layers - {proj_type}_proj')
    # plt.xlabel('Singular Value Index')
    # plt.ylabel('Layer')
    # plt.tight_layout()

    # # Save the heatmap
    # plt.savefig(f'singular_value_plots/{proj_type}_singular_values_heatmap.png', dpi=300, bbox_inches='tight')
    # plt.close()

SyntaxError: 'return' outside function (1182918330.py, line 51)