# Muon vs. AdamW Experiment Analysis

This notebook analyzes the results of the OMat24 experiments, comparing the performance and learning geometry of the Muon optimizer against AdamW.


In [None]:
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Load the metrics logged by all runs. A missing file means there is nothing to
# analyse: report it, then re-raise. Every later cell needs `df_full`, so
# swallowing the error would only turn it into a confusing NameError further down.
try:
    df_full = pd.read_csv('results/metrics.csv')
except FileNotFoundError:
    print("Error: results/metrics.csv not found.")
    print("Please make sure you have downloaded the results from the GPU.")
    raise

# The first run in the file is the baseline, which was logged without a name;
# give it a proper one. Its rows are the contiguous block at the top of the file
# that shares row 0's `run` value.
# The match must be NaN-aware. An empty run field is read as NaN, and
# `NaN != NaN` is always True. A plain `!=` test would therefore flag every row
# (including row 0) as "a different run" and relabel nothing.
# Using a boolean mask with `.loc` also avoids assuming a 0..n-1 RangeIndex.
run_names = df_full['run']
baseline_renamed = False
if not run_names.empty:
    first_run = run_names.iloc[0]
    same_as_first = run_names.isna() if pd.isna(first_run) else run_names.eq(first_run)
    # True only for rows before the first row that belongs to another run.
    leading_rows = (~same_as_first).cumsum().eq(0)
    # If every row matches, the file holds a single run and there is no
    # boundary to find; leave the name alone and warn instead.
    if not leading_rows.all():
        df_full.loc[leading_rows, 'run'] = 'omat24_baseline'
        baseline_renamed = True

if baseline_renamed:
    print("Data loaded successfully.")
    display(df_full.head())
    print("\nFinal epoch metrics for each run:")
    # GroupBy.last() takes the last non-null value of each column within a run.
    display(df_full.groupby('run').last())
else:
    print("Warning: Could not automatically rename the baseline run. The CSV might only contain one experiment type.")
    display(df_full.head())


# Set plot style
sns.set_theme(style="whitegrid")


In [None]:
# Validation MAE per epoch for every run, drawn on a log y-axis.
fig, ax = plt.subplots(figsize=(12, 8))
sns.lineplot(data=df_full, x='epoch', y='val_mae', hue='run',
             marker='o', alpha=0.8, ax=ax)

ax.set_title('Validation MAE vs. Epoch for Different Optimizers', fontsize=16)
ax.set_xlabel('Epoch', fontsize=12)
ax.set_ylabel('Validation Mean Absolute Error (MAE)', fontsize=12)
ax.legend(title='Optimizer Run')
ax.set_yscale('log')
ax.grid(True, which="both", ls="--")

# Mark each run's best (lowest) validation MAE. Rows without a val_mae value
# are dropped first, so runs that never logged one produce no annotation.
val_rows = df_full.dropna(subset=['val_mae'])
for run_name, run_df in val_rows.groupby('run'):
    best_row = run_df.loc[run_df['val_mae'].idxmin()]
    ax.annotate(
        f'{best_row["val_mae"]:.4f}',
        (best_row['epoch'], best_row['val_mae']),
        textcoords="offset points",
        xytext=(0, 10),
        ha='center',
        arrowprops=dict(arrowstyle="->", connectionstyle="arc3,rad=.2"),
    )

plt.tight_layout()
plt.show()


In [None]:
# Final test-set metrics, copied by hand from the eval script output.
# The three lists are parallel: position i in each list describes the same run.
test_results = {
    "run": [
        "omat24_baseline",
        "omat24_muon_full",
        "omat24_muon_no_ortho",
        "omat24_muon_firstk",
    ],
    "Final Test MAE": [0.3296, 0.4087, 0.3844, 0.3238],
    "Final Test RMSE": [0.4183, 0.5095, 0.4737, 0.4107],
}
df_test = pd.DataFrame.from_dict(test_results)

# Rank runs by test MAE, lowest (best) first, so the winner is the top row.
df_test_ranked = df_test.sort_values(by="Final Test MAE")
print("--- Final Test Set Performance ---")
display(df_test_ranked)


In [None]:
# Short, readable x-axis labels, e.g. 'omat24_muon_full' -> 'Muon Full'.
df_test['run_label'] = (
    df_test['run']
    .str.replace('omat24_', '')
    .str.replace('_', ' ')
    .str.title()
)

# Bar plot of final test MAE, sorted from best (lowest) to worst.
fig, ax = plt.subplots(figsize=(12, 7))
sns.barplot(
    data=df_test.sort_values("Final Test MAE"),
    x='run_label', y='Final Test MAE',
    hue='run_label', palette='plasma', dodge=False,
    ax=ax,
)

ax.set_title('Final Test Set MAE by Optimizer Strategy', fontsize=16)
ax.set_xlabel('Optimizer Strategy', fontsize=12)
ax.set_ylabel('Final Mean Absolute Error (MAE)', fontsize=12)
plt.xticks(rotation=15, ha='right')

# hue duplicates x only to give each bar its own colour, so a legend adds nothing.
# seaborn >= 0.13 already omits the legend when hue is redundant with x, and then
# get_legend() returns None. Calling .remove() on None would raise AttributeError,
# so guard the call to work on both old and new seaborn.
legend = ax.get_legend()
if legend is not None:
    legend.remove()

# Label each bar with its height. Only positive heights are annotated. This also
# skips NaN, because NaN > 0 is False: with hue + dodge=False, older seaborn draws
# a NaN-height bar for every (x, hue) pair that has no data. The same guard is
# used in the rattle-level plot.
for p in ax.patches:
    height = p.get_height()
    if not height > 0:
        continue
    ax.annotate(f'{height:.4f}',
                (p.get_x() + p.get_width() / 2., height),
                ha='center', va='center',
                xytext=(0, 9),
                textcoords='offset points',
                fontsize=12, weight='bold')

plt.tight_layout()
plt.show()


In [None]:
# Per-run summary: GroupBy.last() keeps the last non-null value of each column.
final_metrics = df_full.groupby('run').last().reset_index()

# Break the final validation MAE down by structure rattle level.
rattle_mae_cols = ['val_mae_low_rattle', 'val_mae_medium_rattle', 'val_mae_high_rattle']
plot_data = final_metrics.loc[:, ['run', *rattle_mae_cols]]

# Long format (one row per run x rattle level) for a grouped bar plot. Column
# names become readable labels, e.g. 'val_mae_low_rattle' -> 'Low Rattle'.
plot_data_melted = plot_data.melt(id_vars='run', var_name='Rattle Level', value_name='Final MAE')
level_labels = (
    plot_data_melted['Rattle Level']
    .str.replace('val_mae_', '')
    .str.replace('_', ' ')
    .str.title()
)
plot_data_melted['Rattle Level'] = level_labels

fig, ax = plt.subplots(figsize=(14, 8))
sns.barplot(data=plot_data_melted, x='run', y='Final MAE',
            hue='Rattle Level', palette='viridis', ax=ax)

ax.set_title('Final Validation MAE by Optimizer and Structure Rattle Level', fontsize=16)
ax.set_xlabel('Optimizer Run', fontsize=12)
ax.set_ylabel('Final Validation MAE', fontsize=12)
plt.xticks(rotation=15, ha='right')

# Write each bar's height above it. The test is `not h > 0` rather than `h <= 0`
# so that NaN heights are skipped too (NaN compares False either way).
for bar in ax.patches:
    bar_height = bar.get_height()
    if not bar_height > 0:
        continue
    ax.annotate(f'{bar_height:.4f}',
                (bar.get_x() + bar.get_width() / 2., bar_height),
                ha='center', va='center',
                xytext=(0, 9),
                textcoords='offset points')

plt.tight_layout()
plt.show()
