In [None]:
import sys
import os
from pathlib import Path

# Add project root to path.
# Guarded so that re-running this cell does not keep stacking duplicate
# entries at the front of sys.path.
PROJECT_ROOT = Path('/mnt/home/mlee1/vdm_BIND')
if str(PROJECT_ROOT) not in sys.path:
    sys.path.insert(0, str(PROJECT_ROOT))

import numpy as np
import matplotlib.pyplot as plt
from tensorboard.backend.event_processing import event_accumulator
import glob
from collections import defaultdict

# Publication-quality plotting defaults, applied one rc group at a time
plt.rc('font', size=14, family='serif')
plt.rc('axes', labelsize=14, titlesize=16)
plt.rc('xtick', labelsize=12)
plt.rc('ytick', labelsize=12)
plt.rc('legend', fontsize=11)
plt.rc('figure', figsize=(12, 8), dpi=150)
plt.rc('savefig', dpi=300, bbox='tight')

# Directory that receives every figure saved by this notebook
# (created on first run, reused afterwards)
FIGURE_DIR = PROJECT_ROOT.joinpath('analysis', 'figures', 'training')
FIGURE_DIR.mkdir(parents=True, exist_ok=True)
print(f"Figures will be saved to: {FIGURE_DIR}")

## Configuration

Select the model type and version to analyze.

In [None]:
# ============================================================================
# MODEL CONFIGURATION - MODIFY THIS SECTION
# ============================================================================

# Model type: 'clean' (3-channel joint) or 'triple' (3 separate VDMs)
MODEL_TYPE = 'clean'

# Model name and version
MODEL_NAME = 'clean_vdm_aggressive_stellar'
VERSION = 2

# TensorBoard logs root
TB_LOGS_ROOT = Path('/mnt/home/mlee1/ceph/tb_logs')

# Build log path
LOG_PATH = TB_LOGS_ROOT / MODEL_NAME / f'version_{VERSION}'
print(f"Loading logs from: {LOG_PATH}")


def version_sort_key(path):
    """Sort key that orders ``version_<N>`` directories by their number.

    A plain sort on the paths is lexicographic and would list ``version_10``
    before ``version_2``. Directories whose suffix is not a plain integer are
    placed after the numbered ones, in alphabetical order.
    """
    suffix = path.name.rpartition('_')[2]
    if suffix.isdecimal():
        return (0, int(suffix), '')
    return (1, 0, path.name)


# Verify path exists; if not, list the versions that are present for this model
if not LOG_PATH.exists():
    print(f"❌ Log path does not exist!")
    print(f"Available versions:")
    available = list((TB_LOGS_ROOT / MODEL_NAME).glob('version_*'))
    for v in sorted(available, key=version_sort_key):
        print(f"  - {v.name}")
else:
    print(f"✓ Log path found")

## Load TensorBoard Events

In [None]:
def load_tensorboard_logs(log_path):
    """
    Load all scalar metrics from TensorBoard event files.

    Every ``events.out.tfevents.*`` file directly inside ``log_path`` is read
    and the scalar points of each tag are merged into one series.

    Args:
        log_path: Directory (str or Path) containing the event files.

    Returns:
        dict: {metric_name: {'steps': ndarray, 'values': ndarray,
                             'wall_times': ndarray}}
        The three arrays of a metric are aligned and ordered by step; points
        that share a step are ordered by wall time.

    Raises:
        FileNotFoundError: If ``log_path`` contains no event files.
    """
    # Sorted so the files are always read in the same order, whatever order
    # the filesystem happens to return them in.
    event_files = sorted(Path(log_path).glob('events.out.tfevents.*'))
    if not event_files:
        raise FileNotFoundError(f"No event files found in {log_path}")

    print(f"Found {len(event_files)} event file(s)")

    metrics = defaultdict(lambda: {'steps': [], 'values': [], 'wall_times': []})

    for event_file in event_files:
        ea = event_accumulator.EventAccumulator(
            str(event_file),
            size_guidance={'scalars': 0}  # Load all scalars
        )
        ea.Reload()

        for tag in ea.Tags()['scalars']:
            series = metrics[tag]
            for event in ea.Scalars(tag):
                series['steps'].append(event.step)
                series['values'].append(event.value)
                series['wall_times'].append(event.wall_time)

    # Convert to numpy arrays ordered by step. Several event files can log the
    # same step; breaking ties by wall time keeps the result deterministic
    # (a default argsort gives no guarantee about the order of equal steps).
    for tag in metrics:
        steps = np.asarray(metrics[tag]['steps'])
        values = np.asarray(metrics[tag]['values'])
        wall_times = np.asarray(metrics[tag]['wall_times'])
        order = np.lexsort((wall_times, steps))  # primary key: steps
        metrics[tag]['steps'] = steps[order]
        metrics[tag]['values'] = values[order]
        metrics[tag]['wall_times'] = wall_times[order]

    return dict(metrics)

# Read every scalar series of the configured run, then list what was found
metrics = load_tensorboard_logs(LOG_PATH)
print(f"\nLoaded {len(metrics)} metrics:")
for tag in sorted(metrics):
    n_points = len(metrics[tag]['steps'])
    print(f"  {tag}: {n_points} points")

## Plot Overall Training/Validation Loss

In [None]:
def smooth_curve(values, weight=0.9):
    """Exponential moving average smoothing.

    Args:
        values: Sequence (list or array) of numbers to smooth.
        weight: Fraction of the previous smoothed value kept at each point;
            higher values give a smoother curve.

    Returns:
        np.ndarray with one smoothed value per input value. The running
        average is seeded with the first input, so the first output equals
        the first input. Empty input yields an empty array.
    """
    # Nothing to seed the average with; return empty instead of raising
    # IndexError on values[0].
    if len(values) == 0:
        return np.array([], dtype=float)

    smoothed = []
    last = values[0]
    for v in values:
        smoothed_val = last * weight + (1 - weight) * v
        smoothed.append(smoothed_val)
        last = smoothed_val
    return np.array(smoothed)

# Two-panel figure: training loss (left) and validation loss (right)
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Left: Training loss
ax = axes[0]
if 'train/loss' in metrics:
    steps = metrics['train/loss']['steps']
    values = metrics['train/loss']['values']
    ax.plot(steps, values, alpha=0.3, color='blue', label='Raw')
    ax.plot(steps, smooth_curve(values), color='blue', linewidth=2, label='Smoothed')
else:
    print("No 'train/loss' metric found; training panel left empty.")
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
# Only draw a legend when something labelled was plotted; calling legend()
# on an empty axes makes matplotlib emit a warning.
if ax.get_legend_handles_labels()[0]:
    ax.legend()
ax.grid(True, alpha=0.3)

# Right: Validation loss ('val/loss' preferred, 'val/elbo' as fallback)
ax = axes[1]
if 'val/loss' in metrics:
    steps = metrics['val/loss']['steps']
    values = metrics['val/loss']['values']
    ax.plot(steps, values, 'o-', color='orange', markersize=3, label='Validation')
elif 'val/elbo' in metrics:
    steps = metrics['val/elbo']['steps']
    values = metrics['val/elbo']['values']
    ax.plot(steps, values, 'o-', color='orange', markersize=3, label='Validation ELBO')
else:
    print("No 'val/loss' or 'val/elbo' metric found; validation panel left empty.")
ax.set_xlabel('Step')
ax.set_ylabel('Loss')
ax.set_title('Validation Loss')
if ax.get_legend_handles_labels()[0]:
    ax.legend()
ax.grid(True, alpha=0.3)

plt.suptitle(f'{MODEL_NAME} (v{VERSION})', fontsize=16, y=1.02)
plt.tight_layout()

# Save figure
overall_loss_path = FIGURE_DIR / f'{MODEL_NAME}_v{VERSION}_overall_loss.png'
fig.savefig(overall_loss_path)
print(f"Saved: {overall_loss_path}")
plt.show()

## Plot Per-Channel Losses

Visualize the loss contribution from each output channel (DM Hydro, Gas, Stars).

In [None]:
# Collect the per-channel training-loss series, keyed by channel display name
channel_metrics = {}
channel_names = ['DM Hydro', 'Gas', 'Stars']
channel_colors = ['#1f77b4', '#2ca02c', '#d62728']  # Blue, Green, Red

# Each channel may be logged under one of several tag spellings;
# the first spelling present in the logs wins.
for idx, channel in enumerate(channel_names):
    slug = channel.lower().replace(" ", "_")
    candidates = (
        f'train/loss_ch{idx}',
        f'train/channel_{idx}_loss',
        f'train/loss_channel_{idx}',
        f'train/{slug}_loss',
    )
    matched = next((tag for tag in candidates if tag in metrics), None)
    if matched is not None:
        channel_metrics[channel] = metrics[matched]
        print(f"Found {channel} loss: {matched}")

# Nothing matched: list any tag that mentions 'diffusion' or 'channel'
# so the right spelling can be added to the candidates above.
if len(channel_metrics) == 0:
    print("\nLooking for diffusion loss metrics...")
    for tag in metrics:
        lowered = tag.lower()
        if 'diffusion' in lowered or 'channel' in lowered:
            print(f"  Found: {tag}")

In [None]:
# Plot per-channel losses if available
if channel_metrics:
    fig, ax = plt.subplots(figsize=(12, 6))

    # Look colours up by channel name rather than by position in
    # channel_metrics: when a channel is missing from the logs, a positional
    # index would shift every later channel onto the wrong colour.
    color_by_channel = dict(zip(channel_names, channel_colors))

    for name, data in channel_metrics.items():
        steps = data['steps']
        values = data['values']
        color = color_by_channel[name]

        # Faint raw curve underneath, labelled smoothed curve on top
        ax.plot(steps, values, alpha=0.3, color=color)
        ax.plot(steps, smooth_curve(values), color=color, linewidth=2, label=name)

    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title(f'Per-Channel Training Loss - {MODEL_NAME} (v{VERSION})')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.tight_layout()
    fig.savefig(FIGURE_DIR / f'{MODEL_NAME}_v{VERSION}_channel_losses.png')
    print(f"Saved: {FIGURE_DIR / f'{MODEL_NAME}_v{VERSION}_channel_losses.png'}")
    plt.show()
else:
    # Help the user find the right tag names
    print("No per-channel loss metrics found.")
    print("Available metrics containing 'loss':")
    for key in sorted(metrics.keys()):
        if 'loss' in key.lower():
            print(f"  {key}")

## Plot VDM Loss Components

The VDM training objective is made up of several loss terms: diffusion loss, latent loss, and reconstruction loss. Each term found in the logs is plotted in its own panel.

In [None]:
# Tag spellings under which each VDM loss component may have been logged
vdm_loss_keys = {
    'Diffusion': ['train/diffusion_loss', 'train/loss_diffusion', 'diffusion_loss'],
    'Latent': ['train/latent_loss', 'train/loss_latent', 'latent_loss'],
    'Reconstruction': ['train/recons_loss', 'train/reconstruction_loss', 'recons_loss'],
    'ELBO': ['train/elbo', 'elbo'],
}

# Keep, for every component, the first spelling that exists in the logs
vdm_losses = {}
for loss_name, possible_keys in vdm_loss_keys.items():
    hit = next((key for key in possible_keys if key in metrics), None)
    if hit is None:
        continue
    vdm_losses[loss_name] = metrics[hit]
    print(f"Found {loss_name}: {hit}")

if not vdm_losses:
    print("No VDM-specific loss components found.")
else:
    # One panel per component that was found
    n_losses = len(vdm_losses)
    fig, axes = plt.subplots(1, n_losses, figsize=(5*n_losses, 5))
    axes = np.atleast_1d(axes)  # a single panel comes back as a bare Axes

    colors = plt.cm.Set2(np.linspace(0, 1, n_losses))

    for panel, (name, data), color in zip(axes, vdm_losses.items(), colors):
        xs, ys = data['steps'], data['values']

        # Faint raw curve with the smoothed curve drawn over it
        panel.plot(xs, ys, alpha=0.3, color=color)
        panel.plot(xs, smooth_curve(ys), color=color, linewidth=2)
        panel.set_xlabel('Step')
        panel.set_ylabel('Loss')
        panel.set_title(f'{name} Loss')
        panel.grid(True, alpha=0.3)

    plt.suptitle(f'VDM Loss Components - {MODEL_NAME} (v{VERSION})', fontsize=16, y=1.02)
    plt.tight_layout()

    components_path = FIGURE_DIR / f'{MODEL_NAME}_v{VERSION}_vdm_components.png'
    fig.savefig(components_path)
    print(f"Saved: {components_path}")
    plt.show()

## Compare Multiple Model Versions (Optional)

Compare loss curves across different model versions or architectures.

In [None]:
# ============================================================================
# MODEL COMPARISON CONFIGURATION
# ============================================================================

COMPARE_MODELS = True  # Set to True to enable comparison

# Models to compare: list of (model_name, version, label, color)
# At least two entries are needed for the comparison to run.
MODELS_TO_COMPARE = [
    ('clean_vdm_aggressive_stellar', 2, 'Clean v2', '#1f77b4'),
    # ('clean_vdm_aggressive_stellar', 1, 'Clean v1', '#ff7f0e'),
    # ('clean_vdm_triple', 0, 'Triple v0', '#2ca02c'),
]

# Always start from an empty dict so later cells never see results left over
# from an earlier run of this cell, and the name exists even when the
# comparison is disabled.
comparison_data = {}

if COMPARE_MODELS and len(MODELS_TO_COMPARE) > 1:
    print("Loading metrics for model comparison...")

    for model_name, version, label, color in MODELS_TO_COMPARE:
        log_path = TB_LOGS_ROOT / model_name / f'version_{version}'
        if log_path.exists():
            # Best effort: a run that fails to load is reported and skipped
            # so the remaining runs can still be compared.
            try:
                comparison_data[label] = {
                    'metrics': load_tensorboard_logs(log_path),
                    'color': color
                }
                print(f"  ✓ Loaded {label}")
            except Exception as e:
                print(f"  ✗ Failed to load {label}: {e}")
        else:
            print(f"  ✗ Path not found: {log_path}")
else:
    print("Model comparison disabled or only one model specified.")

In [None]:
# Side-by-side comparison of the loaded runs: smoothed training loss on the
# left, raw validation loss on the right. Skipped unless comparison is
# enabled, at least two models are configured, and something was loaded.
if COMPARE_MODELS and len(MODELS_TO_COMPARE) > 1 and comparison_data:
    fig, axes = plt.subplots(1, 2, figsize=(14, 5))

    # Training loss comparison
    ax = axes[0]
    for label, data in comparison_data.items():
        run_metrics = data['metrics']
        if 'train/loss' not in run_metrics:
            continue
        series = run_metrics['train/loss']
        ax.plot(
            series['steps'],
            smooth_curve(series['values']),
            color=data['color'],
            linewidth=2,
            label=label,
        )
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title('Training Loss Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

    # Validation loss comparison ('val/loss' preferred, else 'val/elbo')
    ax = axes[1]
    for label, data in comparison_data.items():
        run_metrics = data['metrics']
        if 'val/loss' in run_metrics:
            val_key = 'val/loss'
        else:
            val_key = 'val/elbo'
        if val_key not in run_metrics:
            continue
        series = run_metrics[val_key]
        ax.plot(
            series['steps'],
            series['values'],
            'o-',
            color=data['color'],
            markersize=3,
            label=label,
        )
    ax.set_xlabel('Step')
    ax.set_ylabel('Loss')
    ax.set_title('Validation Loss Comparison')
    ax.legend()
    ax.grid(True, alpha=0.3)

    plt.suptitle('Model Comparison', fontsize=16, y=1.02)
    plt.tight_layout()

    comparison_path = FIGURE_DIR / 'model_comparison_losses.png'
    fig.savefig(comparison_path)
    print(f"Saved: {comparison_path}")
    plt.show()

## Summary Statistics

In [None]:
# Text summary of the run: final/best losses and training throughput
print(f"\n{'='*60}")
print(f"Training Summary: {MODEL_NAME} (v{VERSION})")
print(f"{'='*60}\n")

# Final training loss
if 'train/loss' in metrics:
    final_train = metrics['train/loss']['values'][-1]
    min_train = metrics['train/loss']['values'].min()
    print(f"Training Loss:")
    print(f"  Final: {final_train:.4f}")
    print(f"  Min:   {min_train:.4f}")
    print()

# Final validation loss ('val/loss' preferred, 'val/elbo' as fallback)
val_key = 'val/loss' if 'val/loss' in metrics else 'val/elbo'
if val_key in metrics:
    final_val = metrics[val_key]['values'][-1]
    min_val = metrics[val_key]['values'].min()
    best_step = metrics[val_key]['steps'][np.argmin(metrics[val_key]['values'])]
    print(f"Validation Loss:")
    print(f"  Final: {final_val:.4f}")
    print(f"  Best:  {min_val:.4f} (step {best_step})")
    print()

# Training time, measured between the first and last logged 'train/loss' points
if 'train/loss' in metrics:
    wall_times = metrics['train/loss']['wall_times']
    train_steps = metrics['train/loss']['steps']
    total_time = (wall_times[-1] - wall_times[0]) / 3600  # hours
    total_steps = train_steps[-1]
    # The elapsed time only spans the logged points, so the rate must use the
    # number of steps between them, not the absolute index of the last step.
    steps_covered = train_steps[-1] - train_steps[0]
    print(f"Training Time:")
    print(f"  Total: {total_time:.2f} hours")
    print(f"  Steps: {total_steps}")
    if total_time > 0:
        print(f"  Rate:  {steps_covered/total_time:.1f} steps/hour")
    else:
        # A single log point (or identical timestamps) gives no elapsed time;
        # dividing would print inf/nan.
        print(f"  Rate:  n/a (not enough logged points to measure elapsed time)")