# In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Data: one entry per ablation step, in the order the optimizations were added.
architectures = ['Baseline\n(2 layers)', '+ LayerNorm', '+ Skip Conn.', '+ 7 Layers']
mse_values = [0.2000, 0.1380, 0.1350, 0.0778]

# Relative MSE reduction (%) of each step versus the previous step, derived from
# mse_values so the annotations cannot drift out of sync with the bar heights.
# Rounded to one decimal to match the label precision. The first bar has no
# predecessor (None); a zero predecessor MSE makes the ratio undefined (None).
improvements = [None] + [
    round(100 * (prev - curr) / prev, 1) if prev else None
    for prev, curr in zip(mse_values, mse_values[1:])
]

# Purple-blue-green colormap: one hex color per bar.
colors = ['#440154', '#3b528d', '#21918c', '#5ec962']

fig, ax = plt.subplots(figsize=(10, 6))

# Create bars
bars = ax.bar(architectures, mse_values, color=colors, edgecolor='black', linewidth=1.5, alpha=0.8)

# Annotate each bar: absolute MSE just above it, step-over-step change inside it.
for bar, mse, pct in zip(bars, mse_values, improvements):
    x_center = bar.get_x() + bar.get_width() / 2.0
    height = bar.get_height()
    ax.text(x_center, height, f'{mse:.4f}',
            ha='center', va='bottom', fontsize=11, fontweight='bold')

    if pct is None:
        continue
    # Down arrow = MSE went down (improvement); up arrow = regression.
    # Written as \u escapes so the source stays ASCII and cannot be re-mangled
    # by a wrong encoding round-trip (the original literal had become mojibake).
    arrow = '\u2193' if pct >= 0 else '\u2191'
    ax.text(x_center, height * 0.5, f'{arrow} {abs(pct):.1f}%',
            ha='center', va='center', fontsize=10, fontweight='bold', color='white')

# Styling
ax.set_ylabel('Validation MSE', fontsize=12, fontweight='bold')
ax.set_xlabel('Model Architecture', fontsize=12, fontweight='bold')
ax.set_title('Ablation Study: Impact of Optimizations on GCN Performance',
             fontsize=13, fontweight='bold', pad=20)

# Leave ~10% headroom above the tallest bar so its value label is not clipped.
ax.set_ylim(0, max(mse_values) * 1.1)
ax.grid(axis='y', alpha=0.3, linestyle='--')
ax.set_axisbelow(True)

# Remove top and right spines
ax.spines['top'].set_visible(False)
ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig('ablation_study_plot.png', dpi=300, bbox_inches='tight')
plt.show()