# Tutorial 4: Visualization Guide

## Creating Figures with ASCICat

This tutorial covers advanced visualization techniques for presenting ASCI results:

1. Understanding the 4-panel standard figure
2. Creating custom visualizations (volcano overlays, correlation heatmaps, score distributions, and stacked bar charts)
3. 3D score space visualization
4. Radar charts for catalyst comparison
5. Exporting high-resolution figures

---

## 1. Setup and Data Preparation

In [None]:
import sys
sys.path.insert(0, '..')

from ascicat import ASCICalculator
from ascicat.visualizer import Visualizer
from ascicat.config import REACTION_CONFIGS

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import matplotlib.patches as mpatches
from matplotlib.colors import LinearSegmentedColormap
import seaborn as sns
import os

# Figure defaults for the rest of the notebook. On-screen previews render at
# 150 DPI; savefig() calls that do not pass dpi= fall back to 600 DPI.
_base_style = {
    'figure.dpi': 150,
    'savefig.dpi': 600,
    'figure.figsize': (10, 8),
}
_font_style = {
    'font.size': 11,
    'axes.labelsize': 12,
    'axes.titlesize': 13,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 10,
}
_axes_style = {
    'axes.linewidth': 1.2,
    'axes.grid': True,
    'grid.alpha': 0.3,
}
plt.rcParams.update({**_base_style, **_font_style, **_axes_style})

print("Visualization setup complete!")

In [None]:
# Load the cleaned HER dataset and compute ASCI with near-equal weights
calc = ASCICalculator(reaction='HER', scoring_method='linear')
calc.load_data('../data/HER_clean.csv')
asci_weights = dict(w_a=0.33, w_s=0.33, w_c=0.34)
results = calc.calculate_asci(**asci_weights)

asci_col = results['ASCI']
print(f"Loaded {len(results)} catalysts")
print(f"ASCI range: [{asci_col.min():.3f}, {asci_col.max():.3f}]")

## 2. The Standard 4-Panel Figure

ASCICat generates a standardized 4-panel figure:

- **Panel A**: 3D score space (activity, stability, cost)
- **Panel B**: ASCI rank vs adsorption energy
- **Panel C**: Volcano optimization landscape
- **Panel D**: Top performers breakdown

In [None]:
# Render ASCICat's built-in publication figures into a dedicated folder
output_dir = '../results/tutorial_viz_figures'
os.makedirs(output_dir, exist_ok=True)  # exist_ok: re-running the cell is safe

viz = Visualizer(results, calc.config)
viz.generate_publication_figures(output_dir)
print(f"Standard figures saved to: {output_dir}")

## 3. Custom 3D Score Space Visualization

In [None]:
# Create a custom 3D view of the (activity, stability, cost) score space
fig = plt.figure(figsize=(12, 10))
ax = fig.add_subplot(111, projection='3d')

# Plot a seeded, uniform random subsample (at most `max_points` rows) to keep
# the cloud readable. NOTE: this is NOT stratified by ASCI, so sparse score
# regions may be under-represented. The top 10 below are taken from the full
# `results` table, so they are always shown even if they were not sampled.
max_points = 500
n_samples = min(max_points, len(results))
sampled = results.sample(n=n_samples, random_state=42)

# Background cloud, coloured by ASCI
scatter = ax.scatter(
    sampled['activity_score'],
    sampled['stability_score'],
    sampled['cost_score'],
    c=sampled['ASCI'],
    cmap='viridis',
    s=30,
    alpha=0.6
)

# Highlight the 10 rows with the smallest 'rank' value (labelled "Top 10 ASCI")
top_10 = results.nsmallest(10, 'rank')
ax.scatter(
    top_10['activity_score'],
    top_10['stability_score'],
    top_10['cost_score'],
    c='red',
    s=150,
    marker='*',
    edgecolors='black',
    linewidths=1,
    label='Top 10 ASCI'
)

# Labels and title
ax.set_xlabel('Activity Score', fontsize=12, labelpad=10)
ax.set_ylabel('Stability Score', fontsize=12, labelpad=10)
ax.set_zlabel('Cost Score', fontsize=12, labelpad=10)
ax.set_title('3D ASCI Score Space', fontsize=14, fontweight='bold', pad=20)

# Colorbar for the ASCI colouring of the background cloud
cbar = fig.colorbar(scatter, ax=ax, shrink=0.6, pad=0.1)
cbar.set_label('ASCI Score', fontsize=11)

ax.legend(loc='upper left')
ax.view_init(elev=25, azim=45)

plt.tight_layout()
score_space_path = f'{output_dir}/3D_score_space.png'
plt.savefig(score_space_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved: {score_space_path}")

## 4. Volcano Plot with ASCI Overlay

In [None]:
# Volcano plot: activity score vs adsorption energy, coloured by ASCI
fig, ax = plt.subplots(figsize=(12, 8))

# Main scatter over all catalysts
scatter = ax.scatter(
    results['DFT_ads_E'],
    results['activity_score'],
    c=results['ASCI'],
    cmap='viridis',
    s=40,
    alpha=0.6,
    edgecolors='none'
)

# Mark the Sabatier optimum taken from the reaction config
e_opt = calc.config.optimal_energy
ax.axvline(e_opt, color='red', linestyle='--',
           linewidth=2, label=f'Sabatier Optimum ({e_opt} eV)')

# Shade optimum +/- activity_width
sigma = calc.config.activity_width
ax.axvspan(e_opt - sigma,
           e_opt + sigma,
           alpha=0.1, color='red', label=f'Optimal Window (\u00b1{sigma} eV)')

# Highlight the top 10 (smallest 'rank' values)
top_10 = results.nsmallest(10, 'rank')
ax.scatter(top_10['DFT_ads_E'], top_10['activity_score'],
           c='red', s=120, marker='*', edgecolors='black',
           linewidths=1, zorder=5, label='Top 10 ASCI')

# Annotate the top 5. Use 'symbol' when it is present and not missing;
# otherwise fall back to the rank instead of printing "nan".
for _, row in top_10.head(5).iterrows():
    symbol = row.get('symbol')
    label = symbol if pd.notna(symbol) else f'#{int(row["rank"])}'
    ax.annotate(str(label), (row['DFT_ads_E'], row['activity_score']),
                xytext=(5, 5), textcoords='offset points', fontsize=9,
                fontweight='bold')

# Styling
ax.set_xlabel('Adsorption Energy (eV)', fontsize=13)
ax.set_ylabel('Activity Score', fontsize=13)
ax.set_title('HER Volcano Plot with ASCI Overlay', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)

# Colorbar for the ASCI colouring
cbar = fig.colorbar(scatter, ax=ax)
cbar.set_label('ASCI Score', fontsize=11)

plt.tight_layout()
volcano_path = f'{output_dir}/volcano_ASCI_overlay.png'
plt.savefig(volcano_path, dpi=300, bbox_inches='tight')
plt.show()
print(f"Saved: {volcano_path}")

## 5. Radar Chart: Top Catalyst Comparison

In [None]:
def create_radar_chart(catalysts_df, title="Catalyst Comparison",
                       score_cols=('activity_score', 'stability_score', 'cost_score'),
                       categories=('Activity', 'Stability', 'Cost')):
    """
    Create a radar (spider) chart comparing per-catalyst scores.

    Parameters
    ----------
    catalysts_df : pandas.DataFrame
        One row per catalyst. Must contain the columns in ``score_cols`` plus
        ``'ASCI'`` and ``'rank'``. An optional ``'symbol'`` column is used for
        legend labels when present and not missing; otherwise the rank is used.
    title : str
        Axes title.
    score_cols : sequence of str
        Columns plotted as radar spokes, in angular order.
    categories : sequence of str
        Spoke labels, one per entry of ``score_cols``.

    Returns
    -------
    (Figure, Axes)
        The polar figure and axes. The caller saves or shows them.

    Raises
    ------
    ValueError
        If ``score_cols`` and ``categories`` have different lengths.

    Notes
    -----
    The radial axis is fixed to [0, 1]. This assumes the scores are normalized
    to that range; confirm for custom ``score_cols``.
    """
    score_cols = list(score_cols)
    categories = list(categories)
    if len(score_cols) != len(categories):
        raise ValueError(
            f"score_cols ({len(score_cols)}) and categories "
            f"({len(categories)}) must have the same length"
        )
    n_axes = len(categories)

    # Evenly spaced spokes; repeat the first angle to close the polygon
    angles = np.linspace(0, 2 * np.pi, n_axes, endpoint=False).tolist()
    angles += angles[:1]

    fig, ax = plt.subplots(figsize=(10, 8), subplot_kw=dict(polar=True))

    # tab10 is a qualitative map: index its colours directly (cycling after
    # 10) instead of sampling it like a continuous colormap
    colors = plt.cm.tab10(np.arange(len(catalysts_df)) % 10)

    for color, (_, row) in zip(colors, catalysts_df.iterrows()):
        values = [row[col] for col in score_cols]
        values += values[:1]  # close the polygon

        # Fall back to the rank when 'symbol' is absent or NaN
        symbol = row.get('symbol')
        label = symbol if pd.notna(symbol) else f"Rank #{int(row['rank'])}"

        ax.plot(angles, values, 'o-', linewidth=2,
                label=f"{label} (ASCI: {row['ASCI']:.3f})", color=color)
        ax.fill(angles, values, alpha=0.1, color=color)

    # Spoke labels
    ax.set_xticks(angles[:-1])
    ax.set_xticklabels(categories, fontsize=12)

    # Radial limits
    ax.set_ylim(0, 1)
    ax.set_yticks([0.2, 0.4, 0.6, 0.8, 1.0])
    ax.set_yticklabels(['0.2', '0.4', '0.6', '0.8', '1.0'], fontsize=9)

    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)
    # An empty frame has no labelled artists, and legend() would only warn
    if len(catalysts_df):
        ax.legend(loc='upper right', bbox_to_anchor=(1.3, 1.0), fontsize=10)

    return fig, ax

# Compare the five best-ranked catalysts on one radar plot
top_5 = results.nsmallest(5, 'rank')
radar_title = "Top 5 HER Catalysts: Score Breakdown"
fig, ax = create_radar_chart(top_5, radar_title)
plt.tight_layout()
radar_path = f'{output_dir}/radar_top5.png'
plt.savefig(radar_path, dpi=300, bbox_inches='tight')
plt.show()

## 6. Heatmap: Score Correlation Matrix

In [None]:
# Display name for each column, in plotting order. The keys also define
# which columns enter the correlation matrix.
rename_dict = {
    'DFT_ads_E': 'Ads. Energy',
    'surface_energy': 'Surf. Energy',
    'Cost': 'Material Cost',
    'activity_score': 'S_activity',
    'stability_score': 'S_stability',
    'cost_score': 'S_cost',
    'ASCI': 'ASCI'
}
score_cols = list(rename_dict)

# Pairwise correlations (pandas default method), relabelled for display
corr_matrix = (
    results[score_cols]
    .corr()
    .rename(index=rename_dict, columns=rename_dict)
)

fig, ax = plt.subplots(figsize=(10, 8))

# Hide the strict upper triangle (k=1 keeps the diagonal visible)
upper_triangle = np.triu(np.ones_like(corr_matrix, dtype=bool), k=1)
sns.heatmap(
    corr_matrix,
    mask=upper_triangle,
    annot=True,
    fmt='.2f',
    cmap='RdBu_r',
    center=0,
    square=True,
    linewidths=1,
    cbar_kws={'label': 'Correlation Coefficient'},
    ax=ax,
)
ax.set_title('Property and Score Correlation Matrix', fontsize=14, fontweight='bold')

plt.tight_layout()
plt.savefig(f'{output_dir}/correlation_heatmap.png', dpi=300, bbox_inches='tight')
plt.show()

## 7. Distribution Plots with Statistical Annotations

In [None]:
from scipy import stats  # import once for the cell, not on every loop pass

fig, axes = plt.subplots(2, 2, figsize=(12, 10))

score_info = [
    ('activity_score', 'Activity Score', '#2ecc71'),
    ('stability_score', 'Stability Score', '#3498db'),
    ('cost_score', 'Cost Score', '#e74c3c'),
    ('ASCI', 'ASCI Score', '#9b59b6')
]

for ax, (col, title, color) in zip(axes.flat, score_info):
    # Drop missing values: histogram range detection fails on NaN, and the
    # KDE cannot be fitted with NaN present
    data = results[col].dropna()

    # Normalized histogram, so the KDE curve shares its y-scale
    ax.hist(data, bins=40, color=color, alpha=0.6, edgecolor='white', density=True)

    # A Gaussian KDE needs at least two distinct values. A constant column
    # gives a singular covariance and gaussian_kde raises, so skip the curve.
    if data.nunique() > 1:
        kde = stats.gaussian_kde(data)
        x_range = np.linspace(data.min(), data.max(), 100)
        ax.plot(x_range, kde(x_range), color='black', linewidth=2)

    # Statistical annotations
    mean_val = data.mean()
    std_val = data.std()
    median_val = data.median()

    ax.axvline(mean_val, color='red', linestyle='--', linewidth=2, label=f'Mean: {mean_val:.3f}')
    ax.axvline(median_val, color='blue', linestyle=':', linewidth=2, label=f'Median: {median_val:.3f}')

    # Text box with summary statistics (n counts non-missing values)
    textstr = f'n = {len(data)}\n$\\mu$ = {mean_val:.3f}\n$\\sigma$ = {std_val:.3f}'
    props = dict(boxstyle='round', facecolor='white', alpha=0.8)
    ax.text(0.95, 0.95, textstr, transform=ax.transAxes, fontsize=10,
            verticalalignment='top', horizontalalignment='right', bbox=props)

    ax.set_xlabel(title, fontsize=12)
    ax.set_ylabel('Density', fontsize=12)
    ax.set_title(f'{title} Distribution', fontsize=13, fontweight='bold')
    ax.legend(loc='upper left', fontsize=9)

plt.tight_layout()
plt.savefig(f'{output_dir}/score_distributions.png', dpi=300, bbox_inches='tight')
plt.show()

## 8. Ranked Bar Chart: Top Performers

In [None]:
# Horizontal stacked bar chart for the top 20 (smallest 'rank' values)
top_20 = results.nsmallest(20, 'rank').copy()
top_20 = top_20.iloc[::-1]  # barh draws bottom-up; reverse so the best is on top

fig, ax = plt.subplots(figsize=(12, 10))

# Tick labels: use 'symbol' when present and not missing, else the rank
labels = []
for _, row in top_20.iterrows():
    symbol = row.get('symbol')
    labels.append(symbol if pd.notna(symbol) else f"#{int(row['rank'])}")
y_pos = np.arange(len(labels))

# Weighted contributions. These weights must match the ones passed to
# calc.calculate_asci() earlier in the notebook.
w_a, w_s, w_c = 0.33, 0.33, 0.34
activity_contrib = top_20['activity_score'] * w_a
stability_contrib = top_20['stability_score'] * w_s
cost_contrib = top_20['cost_score'] * w_c

# Sanity check: the ASCI labels are drawn at the ASCI value, which should be
# where the stacked bars end. Warn (not fail) if they disagree. The weights
# may have drifted, or ASCI may not be a plain weighted sum.
stacked_total = activity_contrib + stability_contrib + cost_contrib
max_gap = (stacked_total - top_20['ASCI']).abs().max()
if max_gap > 1e-3:
    print(f"Note: stacked contributions differ from the 'ASCI' column by up to "
          f"{max_gap:.3f}; the weights here may not match those used in calculate_asci().")

# Stacked bars
ax.barh(y_pos, activity_contrib, color='#2ecc71', label=f'Activity ({w_a})')
ax.barh(y_pos, stability_contrib, left=activity_contrib,
        color='#3498db', label=f'Stability ({w_s})')
ax.barh(y_pos, cost_contrib, left=activity_contrib + stability_contrib,
        color='#e74c3c', label=f'Cost ({w_c})')

# ASCI value printed just right of each bar
for i, asci in enumerate(top_20['ASCI']):
    ax.text(asci + 0.01, i, f"{asci:.3f}", va='center', fontsize=9)

ax.set_yticks(y_pos)
ax.set_yticklabels(labels)
ax.set_xlabel('ASCI Score (Weighted Contribution)', fontsize=12)
ax.set_title('Top 20 Catalysts: ASCI Score Breakdown', fontsize=14, fontweight='bold')
ax.legend(loc='lower right', fontsize=10)
ax.set_xlim(0, 1.1)

plt.tight_layout()
plt.savefig(f'{output_dir}/top20_breakdown.png', dpi=300, bbox_inches='tight')
plt.show()

## 9. Export High-Resolution Figures

Best practices for exporting figures:

In [None]:
def save_figure(fig, filename, formats=('png', 'pdf', 'svg')):
    """
    Save a figure in several file formats.

    Parameters:
    -----------
    fig : matplotlib figure
        Any object exposing ``savefig(path, **kwargs)``.
    filename : str
        Base path without extension; ``.<fmt>`` is appended for each format.
        Missing parent directories are created.
    formats : str or iterable of str
        Output formats, e.g. ``('png', 'pdf')``. A single string such as
        ``'png'`` is treated as one format, and a leading dot is ignored.

    Returns:
    --------
    list of str
        The paths passed to ``savefig``, in the order of ``formats``.
    """
    # A bare string would otherwise be iterated as 'p', 'n', 'g'
    if isinstance(formats, str):
        formats = (formats,)

    parent = os.path.dirname(filename)
    if parent:
        os.makedirs(parent, exist_ok=True)

    saved = []
    for fmt in formats:
        fmt = fmt.lstrip('.')
        filepath = f"{filename}.{fmt}"
        fig.savefig(filepath,
                    # 600 DPI for PNG. For vector formats, DPI only affects
                    # rasterized elements.
                    dpi=600 if fmt.lower() == 'png' else 300,
                    bbox_inches='tight',
                    facecolor='white',
                    edgecolor='none')
        print(f"Saved: {filepath}")
        saved.append(filepath)
    return saved

# Demo: build a simple screening scatter and export it with save_figure()
fig, ax = plt.subplots(figsize=(8, 6))

asci = results['ASCI']
scatter = ax.scatter(results['DFT_ads_E'], asci, c=asci,
                     cmap='viridis', s=25, alpha=0.7)

ax.axvline(calc.config.optimal_energy, color='red', linestyle='--', linewidth=2)
ax.set(xlabel='Adsorption Energy (eV)',
       ylabel='ASCI Score',
       title='HER Catalyst Screening Results')

cbar = fig.colorbar(scatter, ax=ax, label='ASCI')

plt.tight_layout()

# PNG (raster) plus PDF (vector) copies of the same figure
save_figure(fig, f'{output_dir}/example_figure', formats=['png', 'pdf'])
plt.show()

## 10. Summary of Generated Figures

List every file in the output directory, including the standard figures from Section 2 and the figures created in this tutorial:

In [None]:
# Print every file in the output directory with its size in KB
print("Generated Figures")
print("=" * 60)

if not os.path.exists(output_dir):
    print("Output directory not found.")
else:
    for name in sorted(os.listdir(output_dir)):
        size_kb = os.path.getsize(os.path.join(output_dir, name)) / 1024
        print(f"  {name:<40} ({size_kb:.1f} KB)")

## Summary

In this tutorial, we covered:

1. **Standard 4-panel figures** using the built-in Visualizer
2. **3D score space** visualization for exploring the multi-objective landscape
3. **Volcano plots** with ASCI overlay
4. **Radar charts** for comparing individual catalysts
5. **Correlation heatmaps** for understanding property relationships
6. **Distribution plots** with statistical annotations
7. **Stacked bar charts** showing ASCI component breakdown
8. **High-resolution export** techniques

### Tips for Figures
- Use 600 DPI for final PNG exports (the `save_figure` helper does this; the quick `plt.savefig` calls above use 300 DPI)
- Export PDF/SVG for vector graphics
- Maintain consistent color schemes
- Include clear legends and axis labels
- Use appropriate marker sizes for data density

---

**Next:** Tutorial 5 - Sensitivity Analysis