In [6]:
import os
import pandas as pd
import plotly.graph_objects as go

def plot_normalized_perplexity(data_folder, max_k=6):
    """
    Plot normalized perplexity for multiple datasets where each top-k value is in a separate file.

    Files are looked up as ``{dataset}_perplexity_top{k}.csv`` for k = 1..max_k. For each
    file the mean of the first column whose name contains "perplexity" (case-insensitive)
    is taken. Each dataset's series is then min-max rescaled to [0, 1] on its own, so the
    curves compare how perplexity changes with k rather than absolute levels.

    Features:
    - Square plot with clear spacing
    - Legend on the right side
    - More vibrant colors for better visibility
    - White outlines on markers

    Parameters
    ----------
    data_folder : str
        Folder containing the perplexity CSV files.
    max_k : int, optional
        Largest top-k value to look for (files top1 .. top{max_k}). Defaults to 6.

    Returns
    -------
    plotly.graph_objects.Figure
        The built figure; it is not shown or saved here.

    NOTE(review): a later cell in this notebook redefines ``plot_normalized_perplexity``;
    once that cell runs, this version is shadowed.
    """
    # Define more vibrant colors for better visibility
    vibrant_colors = {
        'aime-math': '#FF5252',       # bright red
        'arxiv': '#4CAF50',           # bright green
        'arxiv-title-abs': '#2196F3', # bright blue
        'chinese': '#FFD600',         # bright yellow
        'english': '#FF9800',         # bright orange
        'french-qa': '#9C27B0',       # bright purple
        'github': '#00BFA5',          # bright teal
        'gsm8k': '#F06292',           # bright pink
    }
    k_range = range(1, max_k + 1)

    # Only files that follow the naming scheme name a dataset; an unrelated CSV
    # (e.g. "notes.csv") would otherwise become a bogus dataset name.
    datasets = {
        file.split('_perplexity_')[0]
        for file in os.listdir(data_folder)
        if file.endswith('.csv') and '_perplexity_' in file
    }

    fig = go.Figure()

    for dataset in sorted(datasets):
        top_k_values = []
        perplexity_values = []

        for k in k_range:
            filepath = os.path.join(data_folder, f"{dataset}_perplexity_top{k}.csv")
            if not os.path.exists(filepath):
                continue
            try:
                df = pd.read_csv(filepath)
                perp_col = next(
                    (col for col in df.columns if 'perplexity' in str(col).lower()),
                    None,
                )
                if perp_col is None:
                    continue
                mean_perp = df[perp_col].mean()
            except Exception as e:
                print(f"Error reading {filepath}: {e}")
                continue
            # A NaN mean (empty file / all-missing column) would corrupt min()/max()
            # and the normalization below, so leave this k out of the series.
            if pd.isna(mean_perp):
                print(f"Skipping {filepath}: mean perplexity is NaN")
                continue
            top_k_values.append(k)
            perplexity_values.append(mean_perp)

        # Only plot datasets that produced at least one value
        if not perplexity_values:
            continue

        # Min-max normalize to [0, 1]; a constant series maps to all zeros
        min_perp = min(perplexity_values)
        max_perp = max(perplexity_values)
        if max_perp > min_perp:
            normalized_perplexity = [(p - min_perp) / (max_perp - min_perp) for p in perplexity_values]
        else:
            normalized_perplexity = [0] * len(perplexity_values)

        color = vibrant_colors.get(dataset, '#757575')  # gray for unknown datasets

        fig.add_trace(go.Scatter(
            x=top_k_values,
            y=normalized_perplexity,
            mode='lines+markers',
            name=dataset,
            line=dict(color=color, width=2.5),
            marker=dict(
                color=color,
                size=8,
                symbol='circle',
                line=dict(color='white', width=2),
            ),
        ))

    # Square figure; x-axis spans exactly the top-k values that were searched
    fig.update_layout(
        title={
            'text': 'Normalized Log Perplexity vs Top K',
            'x': 0.5,
            'xanchor': 'center'
        },
        xaxis=dict(
            title='Top K',
            tickmode='array',
            tickvals=list(k_range),
            range=[0.5, max_k + 0.5],
            gridcolor='lightgray',
            showgrid=True,
            zeroline=False
        ),
        yaxis=dict(
            title='Normalized Log Perplexity',
            range=[-0.05, 1.05],  # padding so 0 and 1 markers are not clipped
            gridcolor='lightgray',
            showgrid=True,
            zeroline=False
        ),
        width=700,
        height=700,
        plot_bgcolor='white',
        paper_bgcolor='white',
        # Legend outside the plot area on the right
        legend=dict(
            x=1.02,
            y=0.5,
            xanchor='left',
            yanchor='middle',
            bordercolor='lightgray',
            borderwidth=1
        ),
        margin=dict(l=60, r=120, t=60, b=60)
    )

    return fig

# Usage example:
# fig = plot_normalized_perplexity('pplx-csv')
# fig.show()  # Display the plot
# fig.write_image('perplexity_plot.png')  # Save as image

In [13]:
import os
import pandas as pd
import plotly.graph_objects as go

def _collect_mean_perplexities(data_folder):
    """Return ``{dataset: {top_k: mean}}`` for every ``*_perplexity_top{k}.csv`` in data_folder.

    The mean is taken over the first column whose name contains "perplexity"
    (case-insensitive). Files that fail to parse, have no such column, or whose
    mean is NaN contribute no value (the dataset key is still registered).
    """
    datasets = {}

    files = os.listdir(data_folder)
    print(f"Found {len(files)} files in directory")

    for file in files:
        if '_perplexity_top' not in file or not file.endswith('.csv'):
            continue
        parts = file.split('_perplexity_top')
        prefix = parts[0]
        top_k_str = parts[1].split('.')[0]

        try:
            top_k = int(top_k_str)
            # Register the dataset before reading, so it is counted even if every
            # one of its files fails; empty datasets are skipped when plotting.
            datasets.setdefault(prefix, {})

            df = pd.read_csv(os.path.join(data_folder, file))
            # Column may be 'log_perplexity', 'perplexity', etc.
            perp_col = next(
                (col for col in df.columns if 'perplexity' in str(col).lower()),
                None,
            )
            if perp_col is None:
                continue
            mean_perplexity = df[perp_col].mean()
        except Exception as e:
            print(f"Error processing {file}: {e}")
            continue

        # A NaN mean (empty file / all-missing column) would corrupt min()/max() in
        # the normalization, so drop this point instead of plotting a NaN.
        if pd.isna(mean_perplexity):
            print(f"Skipping {file}: mean {perp_col} is NaN")
            continue

        datasets[prefix][top_k] = mean_perplexity
        print(f"Dataset {prefix}, Top-{top_k}: Mean {perp_col} = {mean_perplexity}")

    return datasets


def _minmax_normalize(values):
    """Linearly rescale values to [0, 1]; a constant series maps to all zeros."""
    lo = min(values)
    hi = max(values)
    if hi > lo:  # avoid division by zero
        return [(x - lo) / (hi - lo) for x in values]
    return [0 for _ in values]


def plot_normalized_perplexity(data_folder, output_file=None):
    """
    Calculate mean perplexity across multiple datasets, normalize to 0-1 range, and plot.

    Files are expected to be named ``{dataset}_perplexity_top{k}.csv``. Each dataset's
    series of per-k means is min-max rescaled to [0, 1] independently, so the plot
    compares how perplexity changes with k rather than absolute levels.

    Parameters:
    -----------
    data_folder : str
        Path to folder containing perplexity CSV files
    output_file : str, optional
        Path to save the output plot, if None, the plot is displayed

    Returns:
    --------
    plotly.graph_objects.Figure
        The figure that was written to ``output_file`` or shown.
    """
    # Define vibrant colors for better visibility
    vibrant_colors = {
        'aime-math': '#FF5252',       # bright red
        'arxiv': '#4CAF50',           # bright green
        'arxiv-title-abs': '#2196F3', # bright blue
        'chinese': '#FFD600',         # bright yellow
        'english': '#FF9800',         # bright orange
        'french-qa': '#9C27B0',       # bright purple
        'github': '#00BFA5',          # bright teal
        'gsm8k': '#F06292',           # bright pink
    }
    times_font = {'family': 'Times New Roman'}

    datasets = _collect_mean_perplexities(data_folder)
    print(f"Found {len(datasets)} datasets")

    fig = go.Figure()

    for prefix, values in sorted(datasets.items()):
        if not values:  # every file for this dataset failed or was skipped
            continue
        top_k_values = sorted(values)
        normalized_perplexity = _minmax_normalize([values[k] for k in top_k_values])

        color = vibrant_colors.get(prefix, '#757575')  # Default to medium gray

        print(f"Plotting {prefix} with x={top_k_values}, y={normalized_perplexity}")

        fig.add_trace(go.Scatter(
            x=top_k_values,
            y=normalized_perplexity,
            mode='lines+markers',
            name=prefix,
            line=dict(color=color, width=3.5),
            marker=dict(
                color=color,
                size=8,
                symbol='circle',
                line=dict(color='white', width=2),
            ),
        ))

    # Any integer k is accepted from filenames, so size the x-axis from the data
    # (never narrower than the original 1..6) instead of hiding out-of-range traces.
    all_k = {k for values in datasets.values() for k in values}
    k_min = min(all_k | {1})
    k_max = max(all_k | {6})

    fig.update_layout(
        title={
            'text': 'Normalized Log Perplexity vs Top K',
            'x': 0.5,
            'xanchor': 'center',
            'font': times_font
        },
        # `title.font` replaces the deprecated `titlefont`, which Plotly 6 removed.
        # NOTE(review): the label says "Log Perplexity", but the recorded run shows a
        # column named 'perplexity' being averaged — confirm the CSV values are log-scaled.
        xaxis=dict(
            title=dict(text='Top K', font=times_font),
            tickmode='array',
            tickvals=list(range(k_min, k_max + 1)),
            range=[k_min - 0.5, k_max + 0.5],
            gridcolor='lightgray',
            showgrid=True,
            zeroline=False,
            tickfont=times_font
        ),
        yaxis=dict(
            title=dict(text='Normalized Log Perplexity', font=times_font),
            range=[-0.05, 1.05],  # padding so 0 and 1 markers are not clipped
            gridcolor='lightgray',
            showgrid=True,
            zeroline=False,
            tickfont=times_font
        ),
        # Wider than tall so the right-hand legend has room beside the plot
        width=900,
        height=700,
        plot_bgcolor='white',
        paper_bgcolor='white',
        legend=dict(
            x=1.02,
            y=0.5,
            xanchor='left',
            yanchor='middle',
            bordercolor='lightgray',
            borderwidth=1,
            font=times_font
        ),
        margin=dict(l=60, r=120, t=60, b=60)
    )

    # Show or save the plot
    if output_file:
        fig.write_image(output_file)
    else:
        fig.show()

    return fig

# Example usage:
# fig = plot_normalized_perplexity('pplx-csv')

In [14]:
fig = plot_normalized_perplexity(data_folder='pplx-csv')  # shows the figure itself; result kept for later cells

Found 46 files in directory
Dataset gsm8k, Top-3: Mean perplexity = 5.60752153754458
Dataset english, Top-6: Mean perplexity = 3.349964508076304
Dataset gsm8k, Top-2: Mean perplexity = 6.0031008823280265
Dataset english, Top-4: Mean perplexity = 3.3909041119604995
Dataset english, Top-5: Mean perplexity = 3.362496385869292
Dataset gsm8k, Top-1: Mean perplexity = 7.674630131998832
Dataset english, Top-1: Mean perplexity = 4.646167125898538
Dataset gsm8k, Top-4: Mean perplexity = 5.4906668501991716
Dataset english, Top-2: Mean perplexity = 3.661294664304281
Dataset english, Top-3: Mean perplexity = 3.4624180842920675
Dataset arxiv-title-abs, Top-3: Mean perplexity = 14.356538934707642
Dataset french-qa, Top-6: Mean perplexity = 5.998863987922668
Dataset github, Top-6: Mean perplexity = 58.362832218408585
Dataset arxiv-title-abs, Top-2: Mean perplexity = 15.852977024555207
Dataset github, Top-4: Mean perplexity = 62.94342365860939
Dataset french-qa, Top-5: Mean perplexity = 6.030495074135