# GDyNet-ferro Results Analysis

This notebook demonstrates comprehensive analysis of GDyNet training results and predictions.

**Features:**
- Training metrics visualization (loss, VAMP1, VAMP2)
- State population analysis (pie chart, bar chart)
- Koopman operator analysis (relaxation timescales, CK tests, eigenanalysis)
- 3D visualization of state probabilities

**Updated for current codebase structure with postprocess module**

## 1. Setup and Imports

In [None]:
import os
import sys

# Add the project root to the import path so the project-specific imports
# further down in this cell (e.g. `postprocess`) can be found.
# The root is taken to be the parent of the current working directory, so the
# notebook is expected to be launched from a direct subdirectory of the project.
project_root = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.insert(0, project_root)

import glob
import json
from pathlib import Path

import numpy as np
import pandas as pd
import torch

import matplotlib
import matplotlib.pyplot as plt
import matplotlib.cm as cmx
from mpl_toolkits.mplot3d import Axes3D
from sklearn.preprocessing import MinMaxScaler
import seaborn as sns

%matplotlib inline

# Project-specific imports
from postprocess.postprocess import GDYNetAnalyzer, analyze_predictions
from postprocess.koopman_postprocess import (
    KoopmanAnalysis,
    plot_timescales,
    plot_ck_tests,
    plot_eigenanalysis,
)

# Plotting configuration.
# The seaborn style must be applied FIRST: sns.set_style() rewrites every
# style-related rcParam (including 'axes.grid' and the 'axes.spines.*' flags),
# so calling it after rcParams.update() silently re-enables the top/right
# spines that the overrides below are meant to hide.
sns.set_style('whitegrid')
plt.rcParams.update({
    'figure.figsize': (10, 6),
    'axes.grid': True,
    'grid.alpha': 0.3,
    'axes.spines.top': False,
    'axes.spines.right': False,
})

print(f"Project root: {project_root}")
print("Imports successful!")

## 2. Configuration

Update these paths to match your experiment output directory.

In [None]:
# ===== UPDATE THESE SETTINGS =====
# Which model's results to analyse: 'gdynet_vanilla' or 'gdynet_ferro'
MODEL_NAME = 'gdynet_vanilla'

# Root of the experiment outputs (created by trainer.py)
OUTPUT_DIR = f'../output/{MODEL_NAME}'

# Physical time between two consecutive saved frames, in nanoseconds.
# For LAMMPS this is timestep * dump_rate: e.g. a 0.25 fs timestep dumped every
# 4 steps gives 1 fs per frame, i.e. 1e-6 ns.
# The value below corresponds to 0.1 ps per frame (0.1 ps = 1e-4 ns).
TIME_UNIT_NS = 1e-4

# ===== Derived file locations =====
predictions_path = f'{OUTPUT_DIR}/predictions/{MODEL_NAME}_predictions.npy'
metrics_path = f'{OUTPUT_DIR}/metrics/metrics.json'

for caption, value in (
    ("Model", MODEL_NAME),
    ("Output directory", OUTPUT_DIR),
    ("Predictions", predictions_path),
    ("Metrics", metrics_path),
):
    print(f"{caption}: {value}")

# Record which inputs are present; the metrics flag is consulted when the
# analyzer is built
pred_exists = Path(predictions_path).exists()
metrics_exists = Path(metrics_path).exists()

pred_status = 'Yes' if pred_exists else 'NO - Run evaluation first'
metrics_status = 'Yes' if metrics_exists else 'NO - Run training first'

print("\nFiles exist:")
print(f"  Predictions: {pred_status}")
print(f"  Metrics: {metrics_status}")

## 3. Load Results

In [None]:
# Build the analyzer; metrics are optional and are only passed along when the
# metrics file was found
analyzer = GDYNetAnalyzer(
    predictions_path=predictions_path,
    metrics_path=(metrics_path if metrics_exists else None),
)

# Predicted state probabilities; the report below reads the axes as
# (frames, atoms/batch, states)
preds = analyzer.predictions
n_states = preds.shape[2]

shape_report = [
    f"Predictions shape: {preds.shape}",
    f"  Frames: {preds.shape[0]}",
    f"  Atoms/Batch: {preds.shape[1]}",
    f"  States: {n_states}",
]
print("\n".join(shape_report))

## 4. Training Metrics Visualization

In [None]:
if analyzer.metrics:
    metrics = analyzer.metrics

    # One entry per panel: (y-axis label, title, train key, validation key)
    panel_specs = (
        ('Loss', 'Training and Validation Loss',
         'train_losses_avg', 'val_losses_avg'),
        ('VAMP1 Score', 'VAMP1 Metric',
         'train_vamp1_scores_avg', 'val_vamp1_scores_avg'),
        ('VAMP2 Score', 'VAMP2 Metric',
         'train_vamp2_scores_avg', 'val_vamp2_scores_avg'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    for ax, (y_label, panel_title, train_key, val_key) in zip(axes, panel_specs):
        # Each curve is optional: a key missing from the metrics simply
        # leaves that line out of the panel
        for curve_label, key in (('Train', train_key), ('Val', val_key)):
            if key in metrics:
                ax.plot(metrics[key], label=curve_label, linewidth=2)
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(panel_title)
        ax.legend()

    plt.tight_layout()
    plt.savefig(f'{OUTPUT_DIR}/analysis_training_metrics.png', dpi=300, bbox_inches='tight')
    plt.show()
else:
    print("No metrics available. Run training first.")

## 5. State Populations

Visualize the population of each dynamical state using pie chart and bar chart.

In [None]:
# Population of each state: probability mass summed over the first two axes
# (frames and atoms), then normalised so the populations add up to 1
state_mass = preds.sum(axis=(0, 1))
probs = state_mass / state_mass.sum()

state_count = len(probs)
labels = [f'State {k}' for k in range(state_count)]
colors = plt.cm.Set2(np.linspace(0, 1, state_count))

fig, (pie_ax, bar_ax) = plt.subplots(1, 2, figsize=(12, 5))

# Left panel: pie chart of the fractions
pie_ax.pie(probs, labels=labels, autopct='%1.2f%%', colors=colors)
pie_ax.axis('equal')
pie_ax.set_title('State Populations')

# Right panel: the same populations as percentages
bar_ax.bar(labels, probs * 100, color=colors, edgecolor='black')
bar_ax.set_ylabel('Population (%)')
bar_ax.set_title('State Populations')

plt.tight_layout()
plt.savefig(f'{OUTPUT_DIR}/analysis_state_populations.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nState Populations:")
for state_idx, share in enumerate(probs):
    print(f"  State {state_idx}: {share*100:.2f}%")

## 6. Relaxation Timescales

Analyze how the relaxation timescales evolve as a function of lag time. The dynamics can be treated as Markovian at lag times beyond which the timescales stay approximately constant.

In [None]:
# Lag times (in frames) at which the relaxation timescales are evaluated.
# The upper bound is 500 frames or a tenth of the trajectory, whichever is
# smaller, and the step is chosen so that roughly 50 lags are sampled.
# The bound is clamped to at least 2 so that a very short trajectory
# (fewer than 20 frames) still yields the single lag [1] instead of an empty
# array.
max_tau = max(2, min(500, preds.shape[0] // 10))
lag = np.arange(1, max_tau, max(1, max_tau // 50))

print(f"Analyzing timescales with lags from 1 to {max_tau}...")

plot_timescales(
    preds,
    lag,
    n_splits=1,
    split_axis=0,
    time_unit_in_ns=TIME_UNIT_NS
)

plt.savefig(f'{OUTPUT_DIR}/analysis_timescales.png', dpi=300, bbox_inches='tight')
plt.show()

## 7. Chapman-Kolmogorov Test

Validate the Markovian assumption by comparing predicted (blue) and estimated (red) transition probabilities. They should match well at the chosen lag time.

In [None]:
# Lag time (in frames) of the Markov model. Ideally choose it from the
# timescale analysis (the lag where the timescales become constant); the
# default below is 100 frames or 1/20 of the trajectory, whichever is smaller.
# Clamped to at least 1: with fewer than 20 frames the integer division yields
# 0, and a zero lag is not a meaningful lag time for the CK test or for the
# Koopman estimate that reuses tau_msm.
tau_msm = max(1, min(100, preds.shape[0] // 20))

print(f"Performing CK test with tau_msm = {tau_msm}...")

plot_ck_tests(
    preds,
    tau_msm=tau_msm,
    steps=5,
    n_splits=1,
    split_axis=0,
    time_unit_in_ns=TIME_UNIT_NS
)

plt.savefig(f'{OUTPUT_DIR}/analysis_ck_test.png', dpi=300, bbox_inches='tight')
plt.show()

## 8. Koopman Operator Eigenanalysis

Analyze the Koopman operator eigenvectors to understand the dynamical modes.

- The largest eigenvalue (≈ 1) corresponds to the stationary distribution
- Subsequent eigenvectors show transitions between states

In [None]:
# Create Koopman analyzer
vamp_analyzer = KoopmanAnalysis(epsilon=1e-5)

# Estimate the Koopman operator at the lag time tau_msm
koopman_op = vamp_analyzer.estimate_koopman_op(preds, tau_msm)

print("Koopman Operator:")
print(koopman_op)

# Plot eigenanalysis
eigvals, eigvecs = plot_eigenanalysis(
    koopman_op,
    tau_msm,
    time_unit_in_ns=TIME_UNIT_NS
)

plt.savefig(f'{OUTPUT_DIR}/analysis_eigenanalysis.png', dpi=300, bbox_inches='tight')
plt.show()

print("\nEigenvalues:", eigvals)
print("\nRelaxation timescales (ns):")
# The first eigenvalue is skipped as the stationary one.
# NOTE(review): this assumes plot_eigenanalysis returns the eigenvalues sorted
# in descending order - confirm against its implementation.
for i, ev in enumerate(eigvals[1:], 1):
    # Eigenvalues of a non-symmetric matrix may be complex, and ordering
    # comparisons on complex numbers raise TypeError. Only (numerically) real
    # eigenvalues strictly inside (0, 1) give a finite, positive timescale.
    if not np.isclose(np.imag(ev), 0.0):
        continue
    ev = float(np.real(ev))
    if 0 < ev < 1:
        ts = -tau_msm * TIME_UNIT_NS / np.log(ev)
        print(f"  Mode {i}: {ts:.4f} ns")

## 9. Transition Matrix Visualization

In [None]:
# Heatmap of the estimated operator, one labelled row/column per state
state_ticks = [f'State {k}' for k in range(koopman_op.shape[0])]

fig, ax = plt.subplots(figsize=(8, 6))

sns.heatmap(
    koopman_op,
    annot=True,
    fmt='.3f',
    cmap='YlOrRd',
    xticklabels=state_ticks,
    yticklabels=state_ticks,
    vmin=0,
    vmax=1,
    cbar_kws={'label': 'Transition Probability'},
    ax=ax,
)

ax.set_xlabel('To State (t + tau)')
ax.set_ylabel('From State (t)')
ax.set_title(f'Transition Matrix (tau={tau_msm})')
fig.tight_layout()
fig.savefig(f'{OUTPUT_DIR}/analysis_transition_matrix.png', dpi=300, bbox_inches='tight')
plt.show()

# Sanity check: the rows of a transition matrix should each sum to about 1
print(f"Row sums (should be ~1.0): {koopman_op.sum(axis=1)}")

## 10. Stationary Distribution

In [None]:
# Eigendecomposition of K^T: its eigenvectors are the left eigenvectors of K,
# and the one belonging to eigenvalue 1 is the stationary distribution
eigvals_full, eigvecs_full = np.linalg.eig(koopman_op.T)

# Pick the eigenvalue closest to 1 rather than the one of largest magnitude:
# for an estimated operator the largest |eigenvalue| is not guaranteed to be
# the stationary one (it can be close to -1, or exceed 1 through noise).
idx = np.argmin(np.abs(eigvals_full - 1.0))
stationary = np.real(eigvecs_full[:, idx])
# Eigenvectors are only defined up to a scale factor and sign, so take
# magnitudes and normalise them to sum to 1
stationary = np.abs(stationary) / np.abs(stationary).sum()

print("Stationary Distribution:")
for i, s in enumerate(stationary):
    print(f"  State {i}: {s*100:.2f}%")

# Compare with the empirical populations (`probs`) from the state-population
# analysis
print("\nComparison with Empirical Populations:")
for i in range(len(stationary)):
    print(f"  State {i}: Stationary={stationary[i]*100:.2f}%, Empirical={probs[i]*100:.2f}%")

## 11. 3D Visualization of State Probabilities

Visualize the probability distribution of states in 3D space.

**Note:** Requires trajectory coordinates. Update the paths below if you have them.

In [None]:
def scatter3d(x, y, z, cs, title, colorsMap='jet', angle=30):
    """Plot a 3D scatter of points coloured by a per-point scalar value.

    Parameters
    ----------
    x, y, z : array-like
        Coordinates of the points, one entry per point.
    cs : array-like
        Scalar value per point used for colouring (a state probability in
        this notebook). It is flattened before use.
    title : str
        Figure title.
    colorsMap : str, default 'jet'
        Name of the matplotlib colormap.
    angle : float, default 30
        Azimuth of the camera in degrees; the elevation is fixed at 10.

    Returns
    -------
    matplotlib.figure.Figure
        The figure containing the 3D scatter and its colorbar.
    """
    values = np.asarray(cs, dtype=float).ravel()

    # Stretch the colormap over the observed range of `cs`. Normalising the
    # raw values produces the same point colours as min-max rescaling them to
    # [0, 1] first, but the colorbar ticks then show the actual values instead
    # of rescaled ones (the colorbar is labelled 'Probability'). NaN-aware
    # reductions keep NaN entries from poisoning the range.
    cNorm = matplotlib.colors.Normalize(vmin=np.nanmin(values), vmax=np.nanmax(values))
    scalarMap = cmx.ScalarMappable(norm=cNorm, cmap=plt.get_cmap(colorsMap))
    scalarMap.set_array(values)

    fig = plt.figure(figsize=(8, 6))
    ax = fig.add_subplot(111, projection='3d', proj_type='ortho')
    ax.axis('auto')
    ax.scatter(x, y, z, s=2, c=scalarMap.to_rgba(values), marker='.', alpha=0.8)
    ax.view_init(10, angle)

    # Colorbar in a dedicated, manually placed axes on the right-hand side
    cbar_ax = fig.add_axes([0.88, 0.15, 0.02, 0.7])
    cbar = fig.colorbar(scalarMap, cax=cbar_ax)
    cbar.set_label('Probability')

    fig.suptitle(title)
    plt.tight_layout()
    return fig

In [None]:
# ===== UPDATE THESE PATHS FOR YOUR DATA =====
# Optional template: uncomment and adapt it if you have trajectory coordinates.
# It draws one 3D scatter per state, colouring each point by that state's
# predicted probability.
# NOTE(review): the indexing below assumes `traj_coords` is
# (frames, atoms, 3) and that `target_index` selects the same atoms, in the
# same order, as the atom axis of `preds` - confirm against your graph file.

# Example for NPZ graph file:
# graph_path = '/path/to/your/graphs.npz'
# if Path(graph_path).exists():
#     graph = np.load(graph_path)
#     traj_coords = graph['traj_coords']
#     target_index = graph['target_index']
#
#     print(f"Trajectory shape: {traj_coords.shape}")
#     print(f"Target indices shape: {target_index.shape}")
#
#     # Subsample for visualization (every 100 frames)
#     subsample = 100
#     ppred = preds[:traj_coords.shape[0]]  # Match frames
#
#     for i in range(n_states):
#         fig = scatter3d(
#             traj_coords[::subsample, target_index, 0].flatten(),
#             traj_coords[::subsample, target_index, 1].flatten(),
#             traj_coords[::subsample, target_index, 2].flatten(),
#             cs=ppred[::subsample].reshape(-1, n_states)[:, i],
#             title=f'State {i}',
#             angle=260
#         )
#         fig.savefig(f'{OUTPUT_DIR}/analysis_3d_state_{i}.png', dpi=300, bbox_inches='tight')
#         plt.show()

# Nothing is plotted until the template above is enabled; tell the reader why.
print("3D visualization requires trajectory coordinates.")
print("Uncomment and modify the code above if you have graph files with traj_coords.")

## 12. State Transitions Over Time

In [None]:
# Most probable state of every atom in every frame -> (n_frames, n_atoms)
dominant_states = preds.argmax(axis=2)

# True wherever an atom's dominant state differs from the previous frame
changed = dominant_states[1:] != dominant_states[:-1]

# Overall switch statistics
switches = changed.sum()
frame_count = dominant_states.shape[0]
atom_count = dominant_states.shape[1]
print(f"Total state switches: {switches}")
print(f"Switches per frame: {switches / (frame_count - 1):.2f}")
print(f"Switches per atom: {switches / atom_count:.2f}")

# Number of atoms that switch at each frame transition
switches_per_frame = changed.sum(axis=1)

fig, ax = plt.subplots(figsize=(14, 5))
ax.plot(switches_per_frame, alpha=0.7, linewidth=1)
ax.set_xlabel('Frame')
ax.set_ylabel('Number of Atoms Switching State')
ax.set_title('State Switching Events Over Time')
ax.grid(True, alpha=0.3)
fig.savefig(f'{OUTPUT_DIR}/analysis_state_switches.png', dpi=300, bbox_inches='tight')
plt.show()

## 13. Export Summary

In [None]:
# Write the analyzer's summary next to the plots
summary_path = f'{OUTPUT_DIR}/analysis_summary.json'
analyzer.export_summary(summary_path)

print(f"Analysis summary exported to: {summary_path}")

# List every analysis plot currently present in the output directory,
# followed by the summary file
rule = "=" * 60
print("\n" + rule)
print("SAVED FILES")
print(rule)
saved_plots = sorted(glob.glob(f'{OUTPUT_DIR}/analysis_*.png'))
for plot_file in saved_plots:
    print(f"  {os.path.basename(plot_file)}")
print(f"  {os.path.basename(summary_path)}")

---

## Summary

This notebook demonstrated:

1. **Training Metrics** - Loss and VAMP score visualization
2. **State Populations** - Pie chart and bar chart of dynamical states
3. **Relaxation Timescales** - Identification of Markovian lag time
4. **Chapman-Kolmogorov Test** - Validation of Markovian dynamics
5. **Eigenanalysis** - Koopman operator decomposition
6. **Transition Matrix** - State transition probabilities
7. **Stationary Distribution** - Long-time state populations
8. **3D Visualization** - Spatial distribution of states
9. **State Transitions** - Switching events over time

All analysis plots are saved to:
```
output/{model_name}/
├── analysis_training_metrics.png
├── analysis_state_populations.png
├── analysis_timescales.png
├── analysis_ck_test.png
├── analysis_eigenanalysis.png
├── analysis_transition_matrix.png
├── analysis_state_switches.png
└── analysis_summary.json
```