In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Global plot appearance. The palette is applied after the style sheet so the
# style sheet does not overwrite the colour cycle chosen here.
PLOT_STYLE = 'seaborn-v0_8'
COLOR_PALETTE = "husl"

plt.style.use(PLOT_STYLE)
sns.set_palette(COLOR_PALETTE)

In [None]:
# Columns the cells of this notebook read from the results table.
REQUIRED_COLUMNS = ('model', 'seed', 'wpn', 'lml', 'mse')

df = pd.read_csv('ablation_results.csv')

# Fail fast with a readable message instead of a bare KeyError in a later cell
# when the CSV does not have the expected schema.
missing_columns = [col for col in REQUIRED_COLUMNS if col not in df.columns]
if missing_columns:
    raise ValueError(
        f"ablation_results.csv is missing required columns: {missing_columns}"
    )

# Quick overview of what was loaded. 'wpn' may be missing on some rows
# (hence the dropna) before the distinct values are sorted for display.
print(f"Loaded {len(df)} experiments")
print(f"Models: {df['model'].unique()}")
print(f"Seeds: {df['seed'].unique()}")
print(f"WPN values: {sorted(df['wpn'].dropna().unique())}")

df.head()

In [None]:
def _mean_per_group(results, model, keys):
    """Average 'lml' and 'mse' of one model's rows within each group of `keys`.

    Parameters
    ----------
    results : DataFrame with at least 'model', 'lml', 'mse' and the `keys` columns.
    model : value of the 'model' column to select.
    keys : column name or list of column names to group by.

    Returns
    -------
    DataFrame with the group keys as ordinary columns plus mean 'lml' and 'mse'.
    """
    return (
        results[results['model'] == model]
        .groupby(keys)
        .agg({'lml': 'mean', 'mse': 'mean'})
        .reset_index()
    )


def _stats_by_wpn(per_seed):
    """Summarise per-(wpn, seed) means into mean and std per wpn value.

    The std is therefore taken across the rows (seeds) sharing a wpn value.
    Named aggregation yields the flat output column names directly, so no
    positional renaming of a MultiIndex is needed.

    Returns a DataFrame with columns
    ['wpn', 'lml_mean', 'lml_std', 'mse_mean', 'mse_std'].
    """
    return (
        per_seed.groupby('wpn')
        .agg(
            lml_mean=('lml', 'mean'),
            lml_std=('lml', 'std'),
            mse_mean=('mse', 'mean'),
            mse_std=('mse', 'std'),
        )
        .reset_index()
    )


# Per-seed averages. Diffusion is grouped by seed only, so it becomes a single
# baseline value; the GRF variants are additionally split by wpn.
diffusion_data = _mean_per_group(df, 'Diffusion', 'seed')
grf_data = _mean_per_group(df, 'GRF', ['wpn', 'seed'])
grf_ablation_data = _mean_per_group(df, 'GRF-ablation', ['wpn', 'seed'])

# Baseline: mean over seeds of the per-seed means.
diffusion_mean_lml = diffusion_data['lml'].mean()
diffusion_mean_mse = diffusion_data['mse'].mean()

# Mean and spread across seeds for every wpn value.
grf_stats = _stats_by_wpn(grf_data)
grf_ablation_stats = _stats_by_wpn(grf_ablation_data)

In [None]:
def _draw_mean_band(ax, stats, metric, color, fmt, label):
    """Draw one model's curve: a mean +/- one std band and the mean line on top.

    `stats` must provide the columns 'wpn', f'{metric}_mean' and f'{metric}_std'.
    `fmt` is the matplotlib format string for the mean line (e.g. 'o-').
    """
    mean = stats[f'{metric}_mean']
    std = stats[f'{metric}_std']
    ax.fill_between(stats['wpn'], mean - std, mean + std, alpha=0.3, color=color)
    ax.plot(stats['wpn'], mean, fmt, color=color, label=label)


def _draw_panel(ax, metric, series, baseline, ylabel, title, log_y=False):
    """Fill one axes with every model curve, the Diffusion baseline and labels.

    `series` is an iterable of (stats, color, fmt, label) tuples drawn in order.
    `baseline` is the y value of the horizontal Diffusion reference line.
    The x axis is always logarithmic; the y axis only when `log_y` is true.
    """
    for stats, color, fmt, label in series:
        _draw_mean_band(ax, stats, metric, color, fmt, label)

    ax.axhline(y=baseline, color='red', linestyle='--', label='Diffusion')

    ax.set_xlabel('Walks per Node')
    ax.set_ylabel(ylabel)
    ax.set_xscale('log')
    if log_y:
        # NOTE(review): on a log axis the lower band edge (mean - std) cannot be
        # drawn faithfully wherever it is <= 0 -- check the MSE band if the
        # spread is large relative to the mean.
        ax.set_yscale('log')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


# The two GRF variants share colours and markers across both panels.
model_series = [
    (grf_stats, 'blue', 'o-', 'GRF'),
    (grf_ablation_stats, 'orange', 's-', 'GRF-ablation'),
]

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))

_draw_panel(ax1, 'lml', model_series, diffusion_mean_lml,
            ylabel='Log Marginal Likelihood',
            title='Log Marginal Likelihood vs Walks per Node')

_draw_panel(ax2, 'mse', model_series, diffusion_mean_mse,
            ylabel='Mean Squared Error',
            title='MSE vs Walks per Node',
            log_y=True)

plt.tight_layout()
plt.show()

In [None]:
def _best_row(stats, column, pick):
    """Return the row of `stats` whose `column` is best, or None if it has no values.

    `pick` is 'idxmax' or 'idxmin', naming the Series method that locates the
    best entry. An empty or all-NaN column yields None instead of raising.
    """
    values = stats[column]
    if not values.notna().any():
        return None
    return stats.loc[getattr(values, pick)()]


def _report_best(label, stats):
    """Print the best (highest) LML and best (lowest) MSE rows of a per-wpn table.

    Returns the pair (best_lml_row, best_mse_row); either element is None when
    the table holds no usable value for that metric.
    """
    print(f"\nBest {label} performance:")
    best_lml = _best_row(stats, 'lml_mean', 'idxmax')
    best_mse = _best_row(stats, 'mse_mean', 'idxmin')

    if best_lml is None:
        print(f"{label} - Best LML: n/a (no valid results)")
    else:
        print(f"{label} - Best LML: {best_lml['lml_mean']:.3f} (wpn={int(best_lml['wpn'])})")

    if best_mse is None:
        print(f"{label} - Best MSE: n/a (no valid results)")
    else:
        print(f"{label} - Best MSE: {best_mse['mse_mean']:.6f} (wpn={int(best_mse['wpn'])})")

    return best_lml, best_mse


print("Performance Summary:")
print(f"Diffusion - LML: {diffusion_mean_lml:.3f}, MSE: {diffusion_mean_mse:.6f}")

best_grf_lml, best_grf_mse = _report_best("GRF", grf_stats)
best_grf_ab_lml, best_grf_ab_mse = _report_best("GRF-ablation", grf_ablation_stats)