# Perplexity Violin Plot Analysis

This notebook reads perplexity CSV files from a results directory (set by `dir_name` in the code, for example `ncbi_downloads_sequences_train_40` or `prokaryotic_host_sequences`) and creates a violin plot of the perplexity distribution at several training checkpoints alongside a baseline model.

## Data Sources:
Plot labels are one-based, so checkpoint `step=9` appears as "Step 10".
- **Step 10**: Early training checkpoint (`step=9`, 160 consumed samples)
- **Step 50**: Training checkpoint (`step=49`, 800 consumed samples)
- **Step 100**: Training checkpoint (`step=99`, 1600 consumed samples)
- **Step 200**: Training checkpoint (`step=199`, 3200 consumed samples)
- **Baseline**: nemo2_evo2_7b model

Run the cells below to generate the violin plot.
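
The checkpoints listed above pair a training step with a consumed-sample count. Assuming steps are zero-based (as the plot labels "Step 10" for `step=9` suggest), all four are consistent with 16 samples per step (160 / 10, 800 / 50, 1600 / 100, 3200 / 200). The sketch below pulls those fields out of a results filename; `parse_checkpoint` is a hypothetical helper that is not part of the notebook, and the value of 16 is inferred from the numbers rather than stated anywhere in the source.

```python
import re

# Pattern for the checkpoint part of a results filename, e.g.
# "ppl_results_<name>_epoch=0-step=9-consumed_samples=160.csv".
CKPT_RE = re.compile(r"epoch=(\d+)-step=(\d+)-consumed_samples=(\d+)")


def parse_checkpoint(filename):
    """Return epoch, step, and consumed samples, or None if the fields are absent."""
    match = CKPT_RE.search(filename)
    if match is None:
        return None  # e.g. the baseline file carries no checkpoint fields
    epoch, step, consumed = (int(group) for group in match.groups())
    return {
        "epoch": epoch,
        "step": step,
        "consumed_samples": consumed,
        # Assumption: steps are zero-based, so step 9 is the 10th step.
        "samples_per_step": consumed / (step + 1),
    }


info = parse_checkpoint("ppl_results_demo_epoch=0-step=9-consumed_samples=160.csv")
print(info)
```

A filename without the `epoch=...-step=...` fields, such as the baseline results file, yields `None`, which makes it easy to tell checkpoints and baseline apart.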


In [14]:
# Install required packages if they are not already available
import importlib
import subprocess
import sys

def install(package):
    """Install a package into the current kernel's environment with pip."""
    subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# plotly and pandas are needed for plotting and data handling;
# kaleido is only needed for static image export (saving the plot as PDF).
for package in ("plotly", "pandas", "kaleido"):
    try:
        importlib.import_module(package)
    except ImportError:
        print(f"Installing {package}...")
        install(package)
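
If installing packages as a side effect of running the notebook is undesirable, one alternative is to only report what is missing. This is a minimal sketch using the standard library's `importlib.util.find_spec`; `missing_packages` is a hypothetical helper name, and the sketch assumes the import name matches the pip package name, which holds for `plotly`, `pandas`, and `kaleido`.

```python
from importlib.util import find_spec


def missing_packages(names):
    """Return the top-level module names that cannot be found in this environment."""
    return [name for name in names if find_spec(name) is None]


required = ["plotly", "pandas", "kaleido"]
missing = missing_packages(required)
if missing:
    print("Missing packages: " + ", ".join(missing))
    print("Install with: pip install " + " ".join(missing))
else:
    print("All required packages are available.")
```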


In [4]:
# Violin plot of perplexity per checkpoint, with summary statistics
import pandas as pd
import plotly.express as px
import numpy as np
from pathlib import Path

# dir_name = "ncbi_downloads_sequences_train_40"
# dir_name = "ncbi_downloads_sequences_test_60"
dir_name = "prokaryotic_host_sequences"
fna_name = "prokaryotic_host_sequences"
# fna_name = "merged"
# Define the directory containing CSV files
csv_dir = Path(dir_name)
list_name = dir_name

# Define the CSV files to read with order for display
csv_files = [
    ("Baseline", f"ppl_results_{fna_name}_nemo2_evo2_7b_8k.csv"),
    ("Step 10", f"ppl_results_{fna_name}_epoch=0-step=9-consumed_samples=160.csv"),
    ("Step 50", f"ppl_results_{fna_name}_epoch=0-step=49-consumed_samples=800.csv"),
    ("Step 100", f"ppl_results_{fna_name}_epoch=0-step=99-consumed_samples=1600.csv"),
    ("Step 200", f"ppl_results_{fna_name}_epoch=0-step=199-consumed_samples=3200.csv"),
    
]

# Read all CSV files and store data
perplexity_results = []
stats_summary = []

for trace_name, csv_file in csv_files:
    file_path = csv_dir / csv_file
    df = pd.read_csv(file_path)
    perplexity_values = df['perplexity'].values
    
    # Add all perplexity values with their trace name
    for ppl in perplexity_values:
        perplexity_results.append({
            'trace': trace_name,
            'perplexity': ppl
        })
    
    # Calculate statistics
    stats = {
        'Model': trace_name,
        'Count': len(perplexity_values),
        'Mean': np.mean(perplexity_values),
        'Median': np.median(perplexity_values),
        'Std': np.std(perplexity_values),  # population std (NumPy default, ddof=0)
        'Min': np.min(perplexity_values),
        'Max': np.max(perplexity_values),
        'Q1': np.percentile(perplexity_values, 25),
        'Q3': np.percentile(perplexity_values, 75)
    }
    stats_summary.append(stats)
    
    print(f"Loaded {trace_name}: {len(df)} samples")
    print(f"  - Mean: {stats['Mean']:.4f}, Median: {stats['Median']:.4f}")
    print(f"  - Range: [{stats['Min']:.4f}, {stats['Max']:.4f}]")
    print()

# Convert to DataFrame for plotting
df_plot = pd.DataFrame(perplexity_results)

# Create violin plot using plotly express
fig = px.violin(
    df_plot,
    x="trace",
    y="perplexity",
    box=True,
    points="all",
    labels={"perplexity": "Perplexity", "trace": "Model/Step"},
    title=f"Perplexity Distribution across Training Steps (File: {list_name})",
    category_orders={"trace": [name for name, _ in csv_files]},  # Preserve order
    color="trace",  # Color by trace
    color_discrete_sequence=['#3498db', '#e74c3c', '#2ecc71', '#f39c12', '#9b59b6']
)

# Update layout
fig.update_layout(
    title_x=0.5,
    template="plotly_white",
    width=1200,
    height=700,
    font=dict(size=14),
    showlegend=False
)

# Use a fixed y-axis range so plots are comparable across datasets.
# Note: points above y_max are not shown; raise y_max if the data exceeds it.
y_min = 0.0
y_max = 6.0
fig.update_yaxes(range=[y_min, y_max])

# Update x-axis
# fig.update_xaxes(tickangle=-30)

# Draw a black frame around the plot with light grey grid lines
fig.update_layout(plot_bgcolor='white')
fig.update_xaxes(
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
)
fig.update_yaxes(
    gridcolor='lightgrey',
    zerolinecolor='lightgrey',
    mirror=True,
    ticks='outside',
    showline=True,
    linecolor='black',
    title_standoff=0,
)

# Show the plot
fig.show()

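# Save outputs (PDF export requires kaleido)
out_file_pdf = f"ppl_violin_plot_{list_name}.pdf"
out_file_html = f"ppl_violin_plot_{list_name}.html"

fig.write_image(out_file_pdf)
fig.write_html(out_file_html)

print("\nPlots saved as:")
print(f"  - {out_file_pdf}")
print(f"  - {out_file_html}")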



# Display statistics summary
print("\n" + "="*80)
print("Statistical Summary:")
print("="*80)
stats_df = pd.DataFrame(stats_summary)
print(stats_df.to_string(index=False, float_format=lambda x: f'{x:.4f}'))


Loaded Baseline: 1000 samples
  - Mean: 2.6104, Median: 2.5562
  - Range: [1.0777, 4.4479]

Loaded Step 10: 1000 samples
  - Mean: 2.7120, Median: 2.6442
  - Range: [1.1349, 4.8555]

Loaded Step 50: 1000 samples
  - Mean: 2.6287, Median: 2.6053
  - Range: [1.1172, 4.4635]

Loaded Step 100: 1000 samples
  - Mean: 2.6687, Median: 2.6289
  - Range: [1.1099, 4.7015]

Loaded Step 200: 1000 samples
  - Mean: 2.6527, Median: 2.6243
  - Range: [1.1377, 6.6059]
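
The log above shows Step 200 reaching 6.6059, which is above the fixed upper limit of 6.0 used for the y-axis, so that point falls outside the plotted range. One way to avoid this is to derive the upper limit from the data. A minimal sketch, where `padded_upper_bound` and its `floor` and `pad` parameters are hypothetical names introduced for illustration:

```python
def padded_upper_bound(values, floor=6.0, pad=0.05):
    """Return an axis upper limit that covers max(values) with relative padding.

    The limit never drops below `floor`, so plots stay comparable when
    all values are small.
    """
    return max(floor, max(values) * (1.0 + pad))


# Maxima reported in the load log for each trace
maxima = [4.4479, 4.8555, 4.4635, 4.7015, 6.6059]
y_max = padded_upper_bound(maxima)
print(round(y_max, 4))
```

The result could then replace the constant in `fig.update_yaxes(range=[y_min, y_max])`.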




Plots saved as:
  - ppl_violin_plot_prokaryotic_host_sequences.pdf
  - ppl_violin_plot_prokaryotic_host_sequences.html

Statistical Summary:
   Model  Count   Mean  Median    Std    Min    Max     Q1     Q3
Baseline   1000 2.6104  2.5562 0.7924 1.0777 4.4479 1.9932 3.1484
 Step 10   1000 2.7120  2.6442 0.8493 1.1349 4.8555 2.0533 3.2983
 Step 50   1000 2.6287  2.6053 0.7509 1.1172 4.4635 2.0502 3.1417
Step 100   1000 2.6687  2.6289 0.7887 1.1099 4.7015 2.0682 3.2066
Step 200   1000 2.6527  2.6243 0.7474 1.1377 6.6059 2.0803 3.1544
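
The same kind of summary can be computed directly from a long-format frame such as `df_plot` with a pandas `groupby`. This is a sketch on a small synthetic frame, not the notebook's data, and the helper names `pop_std`, `q1`, and `q3` are introduced here for illustration. Note that `np.std` defaults to the population standard deviation (`ddof=0`) while pandas `Series.std` defaults to the sample version (`ddof=1`), so `ddof=0` is passed explicitly to match the notebook's calculation.

```python
import pandas as pd

# Small synthetic long-format frame standing in for df_plot
demo = pd.DataFrame({
    "trace": ["Baseline"] * 4 + ["Step 10"] * 4,
    "perplexity": [1.0, 2.0, 3.0, 4.0, 2.0, 2.0, 4.0, 4.0],
})


def pop_std(s):
    """Population standard deviation, matching np.std's default (ddof=0)."""
    return s.std(ddof=0)


def q1(s):
    """First quartile with linear interpolation (the pandas default)."""
    return s.quantile(0.25)


def q3(s):
    """Third quartile with linear interpolation (the pandas default)."""
    return s.quantile(0.75)


# sort=False keeps the groups in order of first appearance
summary = demo.groupby("trace", sort=False)["perplexity"].agg(
    Count="count",
    Mean="mean",
    Median="median",
    Std=pop_std,
    Min="min",
    Max="max",
    Q1=q1,
    Q3=q3,
)
print(summary.to_string(float_format=lambda x: f"{x:.4f}"))
```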
