# Cascaded Prediction Demo

This notebook demonstrates cascaded predictions where:
1. The temperature model predicts the next temperature field
2. The predicted temperature is used as input to the microstructure model to predict the next microstructure

We evaluate on all slices in timestep 18 (first test timestep).

## 1. Setup and Imports

In [None]:
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch

sys.path.insert(0, "../src")  # make the local `lasernet` package importable

from lasernet.data.dataset import LaserDataset
from lasernet.utils import get_num_of_slices, load_model_from_path
from lasernet.visualize import plot_temperature_prediction, plot_microstructure_prediction
from lasernet.cascaded import (
    cascaded_prediction,
    cascaded_prediction_timestep,
    autoregressive_cascaded_prediction,
    compute_cascaded_metrics,
    load_ground_truth_frame,
)

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 2. Load the Trained Models

In [None]:
# Trained checkpoints; the loader in the next cells picks the model class from the filename.
temp_checkpoint_path = Path("../models") / "best_predrnn_large_temperature_mseloss.ckpt"
micro_checkpoint_path = Path("../models") / "best_deepcnn_lstm_large_microstructure_mseloss.ckpt"

# Fail fast with an explicit message instead of an opaque error from the loader.
if not temp_checkpoint_path.exists():
    raise FileNotFoundError(f"Temperature model not found at {temp_checkpoint_path}")
if not micro_checkpoint_path.exists():
    raise FileNotFoundError(f"Microstructure model not found at {micro_checkpoint_path}")

print(
    f"Loading temperature model from {temp_checkpoint_path}\n"
    f"Loading microstructure model from {micro_checkpoint_path}"
)

In [None]:
# Load the temperature model (load_model_from_path detects the model class from the
# filename), move it to `device`, cast it to float16 to save memory and switch to eval
# mode. Module.to / .half / .eval each return the module itself, so the calls chain.
# NOTE(review): on the CPU fallback, float16 inference may hit ops without Half kernels
# depending on the PyTorch version — TODO confirm the CPU path works.
temp_model = load_model_from_path(temp_checkpoint_path).to(device).half().eval()

print(
    "Temperature model loaded successfully\n"
    f"  Model class: {type(temp_model).__name__}\n"
    f"  Parameters: {temp_model.count_parameters():,}\n"
    f"  dtype: {next(temp_model.parameters()).dtype}"
)

In [None]:
# Load the microstructure model (class detected from the filename), move it to `device`,
# cast it to float16 to save memory and switch to eval mode; the Module methods return
# the module itself, so they chain.
# NOTE(review): float16 on the CPU fallback may not be supported for every op — TODO confirm.
micro_model = load_model_from_path(micro_checkpoint_path).to(device).half().eval()

print(
    "Microstructure model loaded successfully\n"
    f"  Model class: {type(micro_model).__name__}\n"
    f"  Parameters: {micro_model.count_parameters():,}\n"
    f"  dtype: {next(micro_model.parameters()).dtype}"
)

## 3. Load Datasets

In [None]:
# Temperature data: the training split supplies the normalizer, which the test split
# reuses so both are scaled identically. Windows are 3 input frames -> frame +1.
temp_train_dataset = LaserDataset(
    data_path=Path("..") / "data" / "processed",
    split="train",
    field_type="temperature",
    plane="xz",
    normalize=True,
    sequence_length=3,
    target_offset=1,
)
print(f"Temperature train dataset: {len(temp_train_dataset)} samples")

temp_test_dataset = LaserDataset(
    data_path=Path("..") / "data" / "processed",
    split="test",
    field_type="temperature",
    plane="xz",
    normalize=True,
    normalizer=temp_train_dataset.normalizer,
    sequence_length=3,
    target_offset=1,
)
print(f"Temperature test dataset: {len(temp_test_dataset)} samples")
print(f"Temperature data shape: {temp_test_dataset.shape}")

In [None]:
# Microstructure data: as for temperature, the test split reuses the training-split
# normalizer. Windows are 3 input frames -> frame +1.
micro_train_dataset = LaserDataset(
    data_path=Path("..") / "data" / "processed",
    split="train",
    field_type="microstructure",
    plane="xz",
    normalize=True,
    sequence_length=3,
    target_offset=1,
)
print(f"Microstructure train dataset: {len(micro_train_dataset)} samples")

micro_test_dataset = LaserDataset(
    data_path=Path("..") / "data" / "processed",
    split="test",
    field_type="microstructure",
    plane="xz",
    normalize=True,
    normalizer=micro_train_dataset.normalizer,
    sequence_length=3,
    target_offset=1,
)
print(f"Microstructure test dataset: {len(micro_test_dataset)} samples")
print(f"Microstructure data shape: {micro_test_dataset.shape}")

## 4. Cascaded Prediction for Timestep 18

Timestep 18 is the first timestep in the test split. We'll predict all 94 slices (xz plane).

In [None]:
# Evaluation target: every slice of the first test timestep in the xz plane.
target_timestep = 18
plane = "xz"
num_slices = get_num_of_slices(plane)  # 94 for the xz plane

print(f"Target timestep: {target_timestep}\nPlane: {plane}\nNumber of slices: {num_slices}")

In [None]:
# `cascaded_prediction` (imported from lasernet.cascaded) performs:
# 1. Temperature prediction from the input sequence
# 2. Renormalization of the predicted temperature to the microstructure scale
# 3. Microstructure prediction using the predicted temperature (cascaded)
# 4. Microstructure prediction using the ground-truth temperature (standard, for comparison)
# The next cell calls `cascaded_prediction_timestep` to run this cascade for all slices of the timestep.

In [None]:
# Cascade over every slice of timestep 18; `results` is indexed by slice number.
results = cascaded_prediction_timestep(
    temp_model=temp_model,
    temp_dataset=temp_test_dataset,
    micro_model=micro_model,
    micro_dataset=micro_test_dataset,
    plane=plane,
    device=device,
)

## 5. Compute Metrics

In [None]:
# Aggregate per-slice metrics for temperature and both microstructure variants.
metrics = compute_cascaded_metrics(results, temp_test_dataset, micro_test_dataset)

print(
    "Temperature Prediction Metrics (all slices in timestep 18):\n"
    f"  Mean MSE (normalized): {metrics['temperature']['mse_mean']:.6f} ± {metrics['temperature']['mse_std']:.6f}\n"
    f"  Mean MAE (actual temp): {metrics['temperature']['mae_mean']:.2f} ± {metrics['temperature']['mae_std']:.2f} K\n"
    f"  Median MAE: {metrics['temperature']['mae_median']:.2f} K"
)

In [None]:
# Report microstructure metrics (already in `metrics`) for the cascaded and standard runs.
print(
    "Microstructure Prediction Metrics (all slices in timestep 18):\n"
    "\n  CASCADED (using predicted temperature):\n"
    f"    Mean MSE (normalized): {metrics['microstructure_cascaded']['mse_mean']:.6f} ± {metrics['microstructure_cascaded']['mse_std']:.6f}\n"
    f"    Mean MAE (IPF-X channels): {metrics['microstructure_cascaded']['mae_mean']:.4f} ± {metrics['microstructure_cascaded']['mae_std']:.4f}\n"
    "\n  STANDARD (using ground truth temperature):\n"
    f"    Mean MSE (normalized): {metrics['microstructure_standard']['mse_mean']:.6f} ± {metrics['microstructure_standard']['mse_std']:.6f}\n"
    f"    Mean MAE (IPF-X channels): {metrics['microstructure_standard']['mae_mean']:.4f} ± {metrics['microstructure_standard']['mae_std']:.4f}\n"
    "\n  Difference (Cascaded - Standard):\n"
    f"    MSE difference: {metrics['mse_difference']['mean']:.6f}"
)

# Per-slice lists used by the aggregate plots further down.
temp_mse_list, temp_mae_list = (metrics['temperature'][key] for key in ('mse_list', 'mae_list'))
micro_mse_cascaded_list, micro_mae_cascaded_list = (
    metrics['microstructure_cascaded'][key] for key in ('mse_list', 'mae_list')
)
micro_mse_standard_list, micro_mae_standard_list = (
    metrics['microstructure_standard'][key] for key in ('mse_list', 'mae_list')
)

## 6. Visualize Results

In [None]:
# The slice in the middle of the domain is the representative example for the next figures.
viz_slice_idx = num_slices // 2
result = results[viz_slice_idx]

print(f"Visualizing slice {viz_slice_idx} (Y-slice at middle of domain)")

In [None]:
# Temperature prediction for the selected slice, denormalized back to temperature units.
temp_input_denorm, temp_target_denorm, temp_pred_denorm = (
    temp_test_dataset.denormalize(result.temp_input_seq),
    temp_test_dataset.denormalize(result.temp_target),
    temp_test_dataset.denormalize(result.temp_pred),
)

fig = plot_temperature_prediction(
    title=f"Temperature Prediction - Slice {viz_slice_idx}",
    input_seq=temp_input_denorm,
    target=temp_target_denorm,
    prediction=temp_pred_denorm,
)
plt.show()

In [None]:
# CASCADED microstructure prediction (model fed the predicted temperature).
# Inputs use `denormalize`, targets/predictions use `denormalize_target`.
micro_input_denorm, micro_target_denorm, micro_pred_cascaded_denorm = (
    micro_test_dataset.denormalize(result.micro_input_seq),
    micro_test_dataset.denormalize_target(result.micro_target),
    micro_test_dataset.denormalize_target(result.micro_pred_cascaded),
)

fig = plot_microstructure_prediction(
    title=f"Microstructure Prediction (CASCADED) - Slice {viz_slice_idx}",
    input_seq=micro_input_denorm,
    target=micro_target_denorm,
    prediction=micro_pred_cascaded_denorm,
)
plt.show()

In [None]:
# STANDARD microstructure prediction (model fed the ground-truth temperature), plotted
# against the same input sequence and target as the cascaded figure.
micro_pred_standard_denorm = micro_test_dataset.denormalize_target(
    result.micro_pred_standard
)

fig = plot_microstructure_prediction(
    title=f"Microstructure Prediction (STANDARD) - Slice {viz_slice_idx}",
    prediction=micro_pred_standard_denorm,
    target=micro_target_denorm,
    input_seq=micro_input_denorm,
)
plt.show()

In [None]:
# Compare cascaded vs standard microstructure predictions side by side for the selected slice.
# Reuses micro_target_denorm / micro_pred_cascaded_denorm / micro_pred_standard_denorm
# computed in the visualization cells above.
fig, axes = plt.subplots(2, 3, figsize=(15, 10))

# Top row: ground truth, cascaded prediction and standard prediction
# IPF-X RGB visualization: first 3 channels, transposed with (2, 1, 0) so channels are last
# for imshow, then clipped to the displayable [0, 1] range
target_rgb = np.clip(np.transpose(micro_target_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)
cascaded_rgb = np.clip(np.transpose(micro_pred_cascaded_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)
standard_rgb = np.clip(np.transpose(micro_pred_standard_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)

axes[0, 0].imshow(target_rgb, aspect='equal', origin='lower')
axes[0, 0].set_title('Ground Truth')
axes[0, 0].set_xlabel('X coordinate')
axes[0, 0].set_ylabel('Z coordinate')

axes[0, 1].imshow(cascaded_rgb, aspect='equal', origin='lower')
axes[0, 1].set_title('Cascaded Prediction\n(using predicted temperature)')
axes[0, 1].set_xlabel('X coordinate')
axes[0, 1].set_ylabel('Z coordinate')

axes[0, 2].imshow(standard_rgb, aspect='equal', origin='lower')
axes[0, 2].set_title('Standard Prediction\n(using ground truth temperature)')
axes[0, 2].set_xlabel('X coordinate')
axes[0, 2].set_ylabel('Z coordinate')

# Bottom row: Error maps (per-pixel squared error averaged over the 3 IPF-X channels, axis 0)
cascaded_error = np.mean((micro_target_denorm[0:3].numpy() - micro_pred_cascaded_denorm[0:3].numpy()) ** 2, axis=0).astype(np.float32)
standard_error = np.mean((micro_target_denorm[0:3].numpy() - micro_pred_standard_denorm[0:3].numpy()) ** 2, axis=0).astype(np.float32)
error_diff = cascaded_error - standard_error  # Positive = cascaded is worse

# Use same color scale for error maps so the two panels are directly comparable
vmax_error = max(cascaded_error.max(), standard_error.max())

im0 = axes[1, 0].imshow(cascaded_error.T, cmap='RdYlBu_r', aspect='equal', origin='lower', vmin=0, vmax=vmax_error)
axes[1, 0].set_title(f'Cascaded Error (MSE: {cascaded_error.mean():.4f})')
axes[1, 0].set_xlabel('X coordinate')
axes[1, 0].set_ylabel('Z coordinate')
plt.colorbar(im0, ax=axes[1, 0], fraction=0.046, pad=0.04)

im1 = axes[1, 1].imshow(standard_error.T, cmap='RdYlBu_r', aspect='equal', origin='lower', vmin=0, vmax=vmax_error)
axes[1, 1].set_title(f'Standard Error (MSE: {standard_error.mean():.4f})')
axes[1, 1].set_xlabel('X coordinate')
axes[1, 1].set_ylabel('Z coordinate')
plt.colorbar(im1, ax=axes[1, 1], fraction=0.046, pad=0.04)

# Difference map (positive = cascaded is worse); symmetric limits keep zero at the
# centre of the diverging colormap
vmax_diff = max(abs(error_diff.min()), abs(error_diff.max()))
im2 = axes[1, 2].imshow(error_diff.T, cmap='RdBu_r', aspect='equal', origin='lower', vmin=-vmax_diff, vmax=vmax_diff)
axes[1, 2].set_title(f'Error Difference\n(Cascaded - Standard)')
axes[1, 2].set_xlabel('X coordinate')
axes[1, 2].set_ylabel('Z coordinate')
plt.colorbar(im2, ax=axes[1, 2], fraction=0.046, pad=0.04)

plt.suptitle(f'Cascaded vs Standard Microstructure Prediction - Slice {viz_slice_idx}', fontsize=14)
plt.tight_layout()
plt.savefig('cascaded_vs_standard_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

## 7. Aggregate Metrics Visualization

In [None]:
# Plot MSE across slices: temperature histogram, microstructure MSE per slice, and the
# per-slice cascaded-minus-standard difference; then print a short summary.
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Temperature MSE
axes[0].hist(temp_mse_list, bins=20, edgecolor='black', alpha=0.7, color='coral')
axes[0].axvline(np.mean(temp_mse_list), color='red', linestyle='--', linewidth=2, label=f'Mean: {np.mean(temp_mse_list):.6f}')
axes[0].set_xlabel('MSE (normalized)')
axes[0].set_ylabel('Frequency')
axes[0].set_title('Temperature MSE Distribution')
axes[0].legend()
axes[0].grid(alpha=0.3)

# Microstructure MSE comparison. Each series gets an explicit property-cycle colour
# (C0/C1, what scatter would pick by default) and its dashed mean line reuses that
# same colour, so the lines can be matched to their series.
x = np.arange(len(micro_mse_cascaded_list))
axes[1].scatter(x, micro_mse_cascaded_list, alpha=0.6, label='Cascaded', s=10, color='C0')
axes[1].scatter(x, micro_mse_standard_list, alpha=0.6, label='Standard', s=10, color='C1')
axes[1].axhline(np.mean(micro_mse_cascaded_list), color='C0', linestyle='--', alpha=0.7)
axes[1].axhline(np.mean(micro_mse_standard_list), color='C1', linestyle='--', alpha=0.7)
axes[1].set_xlabel('Slice Index')
axes[1].set_ylabel('MSE (normalized)')
axes[1].set_title('Microstructure MSE by Slice')
axes[1].legend()
axes[1].grid(alpha=0.3)

# MSE difference per slice (red = cascaded worse, green = cascaded better)
mse_diff = np.array(micro_mse_cascaded_list) - np.array(micro_mse_standard_list)
axes[2].bar(x, mse_diff, alpha=0.7, color=np.where(mse_diff > 0, 'red', 'green'))
axes[2].axhline(0, color='black', linewidth=1)
axes[2].axhline(np.mean(mse_diff), color='purple', linestyle='--', linewidth=2, label=f'Mean diff: {np.mean(mse_diff):.6f}')
axes[2].set_xlabel('Slice Index')
axes[2].set_ylabel('MSE Difference (Cascaded - Standard)')
axes[2].set_title('Microstructure MSE Difference')
axes[2].legend()
axes[2].grid(alpha=0.3)

plt.tight_layout()
plt.savefig('mse_distribution.png', dpi=150, bbox_inches='tight')
plt.show()

# Summary. The relative change is undefined when the standard MSE is zero, so report nan
# instead of dividing by zero.
mean_mse_diff = np.mean(mse_diff)
mean_standard_mse = np.mean(micro_mse_standard_list)
relative_change_pct = mean_mse_diff / mean_standard_mse * 100 if mean_standard_mse != 0 else float('nan')
print("\nSummary:")
print(f"  Slices where cascaded is WORSE: {np.sum(mse_diff > 0)}/{len(mse_diff)}")
print(f"  Slices where cascaded is BETTER: {np.sum(mse_diff < 0)}/{len(mse_diff)}")
print(f"  Mean MSE increase from using predicted temperature: {mean_mse_diff:.6f} ({relative_change_pct:.2f}%)")

## 8. Visualize Multiple Slices

In [None]:
# Visualize a few representative slices: first, quarter, middle, three-quarter and last.
# Loop variables use slice_* names so this cell does not overwrite `result`,
# `micro_target_denorm`, `target_rgb`, ... that the single-slice cells above define;
# otherwise re-running one of those cells afterwards would plot the last slice under the
# `viz_slice_idx` title.
slice_indices = [0, num_slices // 4, num_slices // 2, 3 * num_slices // 4, num_slices - 1]


def _to_ipf_rgb(micro_denorm):
    """Return the first three (IPF-X) channels of a denormalized microstructure tensor as
    an imshow-ready image: transposed with (2, 1, 0), clipped to [0, 1], float32."""
    return np.clip(np.transpose(micro_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)


fig, axes = plt.subplots(len(slice_indices), 4, figsize=(16, 4 * len(slice_indices)))

for row, slice_idx in enumerate(slice_indices):
    slice_result = results[slice_idx]

    # Denormalize target and both predictions back to the un-normalized scale
    slice_target = micro_test_dataset.denormalize_target(slice_result.micro_target)
    slice_cascaded = micro_test_dataset.denormalize_target(slice_result.micro_pred_cascaded)
    slice_standard = micro_test_dataset.denormalize_target(slice_result.micro_pred_standard)

    # Per-pixel squared error of the cascaded prediction, averaged over the 3 IPF-X channels
    slice_cascaded_error = np.mean(
        (slice_target[0:3].numpy() - slice_cascaded[0:3].numpy()) ** 2, axis=0
    ).astype(np.float32)

    axes[row, 0].imshow(_to_ipf_rgb(slice_target), aspect='equal', origin='lower')
    axes[row, 0].set_title(f'Slice {slice_idx}: Ground Truth')
    axes[row, 0].set_ylabel('Z')

    axes[row, 1].imshow(_to_ipf_rgb(slice_cascaded), aspect='equal', origin='lower')
    axes[row, 1].set_title('Cascaded Prediction')

    axes[row, 2].imshow(_to_ipf_rgb(slice_standard), aspect='equal', origin='lower')
    axes[row, 2].set_title('Standard Prediction')

    error_im = axes[row, 3].imshow(slice_cascaded_error.T, cmap='RdYlBu_r', aspect='equal', origin='lower')
    axes[row, 3].set_title(f'Cascaded Error (MSE: {slice_cascaded_error.mean():.4f})')
    plt.colorbar(error_im, ax=axes[row, 3], fraction=0.046, pad=0.04)

for ax in axes[-1, :]:
    ax.set_xlabel('X')

plt.suptitle('Cascaded Predictions Across Slices (Timestep 18)', fontsize=14, y=1.01)
plt.tight_layout()
plt.savefig('cascaded_predictions_all_slices.png', dpi=150, bbox_inches='tight')
plt.show()

## 9. Autoregressive Multi-Timestep Prediction

In this section, we perform autoregressive prediction where the output from one timestep is used as input for the next timestep. This tests how errors accumulate over time in a realistic deployment scenario.

**Pipeline for each timestep:**
1. Predict temperature using the previous 3 temperature frames (either from ground truth or previous predictions)
2. Use the predicted temperature + previous microstructure frames to predict the next microstructure
3. Store predictions and use them as inputs for subsequent timesteps

In [None]:
# Configuration for autoregressive prediction
num_autoregressive_steps = 3  # Number of timesteps to predict autoregressively
sequence_length = 3  # Input sequence length required by the model
start_timestep = 18  # First test timestep (relative timestep 0 in test split)
selected_slice = num_slices // 2  # Use middle slice for demonstration

# Ground-truth frames start_timestep .. start_timestep + sequence_length - 1 are the
# initial context, so the first predicted frame is start_timestep + sequence_length
# (the same convention used by the plots and tables below).
first_predicted_timestep = start_timestep + sequence_length
last_predicted_timestep = first_predicted_timestep + num_autoregressive_steps - 1

print("Autoregressive prediction settings:")
print(f"  Starting timestep: {start_timestep}")
print(f"  Number of autoregressive steps: {num_autoregressive_steps}")
print(f"  Selected slice: {selected_slice} (middle of domain)")
print(f"  Sequence length: {sequence_length}")
print(f"  Will predict timesteps: {first_predicted_timestep} to {last_predicted_timestep}")

In [None]:
# The autoregressive_cascaded_prediction and load_ground_truth_frame functions
# are now imported from lasernet.cascaded
#
# autoregressive_cascaded_prediction performs multi-step prediction where:
# 1. Builds input sequences from previous frames (GT or predicted)
# 2. Predicts temperature
# 3. Predicts microstructure using predicted temperature
# 4. Stores predictions for subsequent timesteps

In [None]:
# Roll the cascade forward from ground-truth context for the selected slice.
print(
    f"Running autoregressive cascaded prediction for slice {selected_slice}...\n"
    f"Using ground truth frames from timesteps {start_timestep} to {start_timestep + sequence_length - 1} as initial context\n"
)

autoregressive_results = autoregressive_cascaded_prediction(
    temp_model=temp_model,
    temp_dataset=temp_test_dataset,
    micro_model=micro_model,
    micro_dataset=micro_test_dataset,
    slice_idx=selected_slice,
    num_steps=num_autoregressive_steps,
    sequence_length=sequence_length,
    plane=plane,
    device=device,
)

print("\nAutoregressive prediction complete!")

### 9.1 Analyze Error Accumulation Over Time

In [None]:
# Plot error accumulation over autoregressive steps and print a per-step summary table.
fig, axes = plt.subplots(1, 2, figsize=(12, 4))

steps = np.arange(1, num_autoregressive_steps + 1)
# The first `sequence_length` ground-truth frames are context, so step 1 predicts
# t = start_timestep + sequence_length.
timesteps = [start_timestep + sequence_length + i for i in range(num_autoregressive_steps)]
step_labels = [f"Step {s}\n(t={t})" for s, t in zip(steps, timesteps)]


def _annotate_relative_change(ax, mse_values):
    """Annotate the last point on ``ax`` with the % change in MSE relative to step 1.

    Skipped when there is a single step, or when the step-1 MSE is zero (the relative
    change is undefined and would divide by zero).
    """
    if len(mse_values) < 2:
        return
    first_mse, last_mse = mse_values[0], mse_values[-1]
    if first_mse == 0:
        return
    pct_increase = ((last_mse - first_mse) / first_mse) * 100
    ax.annotate(f'{pct_increase:+.1f}% from step 1',
                xy=(steps[-1], last_mse),
                xytext=(steps[-1] - 0.5, last_mse * 1.2),
                fontsize=10, ha='center')


# (axis, MSE per step, line style, colour, title)
error_panels = [
    (axes[0], autoregressive_results.temp_mse, 'o-', 'coral', 'Temperature Prediction Error Over Time'),
    (axes[1], autoregressive_results.micro_mse, 's-', 'steelblue', 'Microstructure Prediction Error Over Time'),
]
for ax, mse_values, line_style, color, title in error_panels:
    ax.plot(steps, mse_values, line_style, color=color, linewidth=2, markersize=8)
    ax.set_xlabel('Autoregressive Step')
    ax.set_ylabel('MSE (normalized)')
    ax.set_title(title)
    ax.set_xticks(steps)
    ax.set_xticklabels(step_labels)
    ax.grid(alpha=0.3)
    _annotate_relative_change(ax, mse_values)

plt.suptitle(f'Error Accumulation in Autoregressive Prediction (Slice {selected_slice})', fontsize=12)
plt.tight_layout()
plt.savefig('autoregressive_error_accumulation.png', dpi=150, bbox_inches='tight')
plt.show()

# Per-step summary; separators match the header width and the "Total increase" label
# spans the Step + Timestep columns (8 + 1 + 10 = 19) so its values line up.
summary_header = f"{'Step':<8} {'Timestep':<10} {'Temp MSE':<15} {'Micro MSE':<15}"
print("\nError Accumulation Summary:")
print("-" * len(summary_header))
print(summary_header)
print("-" * len(summary_header))
for i, (t_mse, m_mse) in enumerate(zip(autoregressive_results.temp_mse, autoregressive_results.micro_mse)):
    print(f"{i+1:<8} {timesteps[i]:<10} {t_mse:<15.6f} {m_mse:<15.6f}")
print("-" * len(summary_header))
print(f"{'Total increase:':<19} {autoregressive_results.temp_mse[-1] - autoregressive_results.temp_mse[0]:<15.6f} "
      f"{autoregressive_results.micro_mse[-1] - autoregressive_results.micro_mse[0]:<15.6f}")

### 9.2 Visualize Predictions at Each Autoregressive Step

In [None]:
# Visualize temperature predictions at each autoregressive step: one row per step with
# ground truth, prediction and absolute error (denormalized, in K).
# squeeze=False keeps `axes` 2-D even when num_autoregressive_steps == 1, so the
# axes[step, col] and axes[-1, :] indexing below works for any step count.
fig, axes = plt.subplots(num_autoregressive_steps, 3, figsize=(12, 4 * num_autoregressive_steps), squeeze=False)

for step in range(num_autoregressive_steps):
    # Denormalize for visualization.
    # NOTE(review): `[0]` drops the leading dimension — presumably a singleton channel; TODO confirm.
    temp_pred = autoregressive_results.temp_predictions[step]
    temp_target = autoregressive_results.temp_targets[step]

    temp_pred_denorm = temp_test_dataset.denormalize(temp_pred).numpy()[0]
    temp_target_denorm = temp_test_dataset.denormalize(temp_target).numpy()[0]
    temp_error = np.abs(temp_pred_denorm - temp_target_denorm)

    timestep = start_timestep + sequence_length + step

    # Shared colorbar limits for prediction and target
    vmin = min(temp_pred_denorm.min(), temp_target_denorm.min())
    vmax = max(temp_pred_denorm.max(), temp_target_denorm.max())

    # Ground truth
    im0 = axes[step, 0].imshow(temp_target_denorm.T, cmap='hot', aspect='equal', origin='lower', vmin=vmin, vmax=vmax)
    axes[step, 0].set_title(f'Step {step+1} (t={timestep}): Ground Truth')
    axes[step, 0].set_ylabel('Z')
    plt.colorbar(im0, ax=axes[step, 0], fraction=0.046, pad=0.04, label='T [K]')

    # Prediction
    im1 = axes[step, 1].imshow(temp_pred_denorm.T, cmap='hot', aspect='equal', origin='lower', vmin=vmin, vmax=vmax)
    axes[step, 1].set_title(f'Prediction (MSE: {autoregressive_results.temp_mse[step]:.6f})')
    plt.colorbar(im1, ax=axes[step, 1], fraction=0.046, pad=0.04, label='T [K]')

    # Error
    im2 = axes[step, 2].imshow(temp_error.T, cmap='RdYlBu_r', aspect='equal', origin='lower')
    axes[step, 2].set_title(f'Absolute Error (MAE: {temp_error.mean():.2f} K)')
    plt.colorbar(im2, ax=axes[step, 2], fraction=0.046, pad=0.04, label='|Error| [K]')

for ax in axes[-1, :]:
    ax.set_xlabel('X')

plt.suptitle(f'Temperature Predictions - Autoregressive Steps (Slice {selected_slice})', fontsize=14, y=1.01)
plt.tight_layout()
plt.savefig('autoregressive_temperature_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

In [None]:
# Visualize microstructure predictions at each autoregressive step: one row per step with
# ground truth, prediction, scaled |difference| and a per-pixel MSE map.
# squeeze=False keeps `axes` 2-D even when num_autoregressive_steps == 1, so the
# axes[step, col] and axes[-1, :] indexing below works for any step count.
fig, axes = plt.subplots(num_autoregressive_steps, 4, figsize=(16, 4 * num_autoregressive_steps), squeeze=False)

for step in range(num_autoregressive_steps):
    # Denormalize for visualization
    micro_pred = autoregressive_results.micro_predictions[step]
    micro_target = autoregressive_results.micro_targets[step]

    micro_pred_denorm = micro_test_dataset.denormalize_target(micro_pred)
    micro_target_denorm = micro_test_dataset.denormalize_target(micro_target)

    timestep = start_timestep + sequence_length + step

    # IPF-X RGB visualization (first 3 channels, channels moved last, clipped to [0, 1])
    target_rgb = np.clip(np.transpose(micro_target_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)
    pred_rgb = np.clip(np.transpose(micro_pred_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)

    # Compute error (per-pixel squared error averaged over the 3 IPF-X channels)
    error = np.mean((micro_target_denorm[0:3].numpy() - micro_pred_denorm[0:3].numpy()) ** 2, axis=0).astype(np.float32)

    # Difference image (scaled for visibility)
    diff = np.abs(target_rgb - pred_rgb)
    diff_scaled = np.clip(diff * 5, 0, 1)  # Scale up for visibility

    # Ground truth
    axes[step, 0].imshow(target_rgb, aspect='equal', origin='lower')
    axes[step, 0].set_title(f'Step {step+1} (t={timestep}): Ground Truth')
    axes[step, 0].set_ylabel('Z')

    # Prediction
    axes[step, 1].imshow(pred_rgb, aspect='equal', origin='lower')
    axes[step, 1].set_title(f'Prediction (MSE: {autoregressive_results.micro_mse[step]:.6f})')

    # Difference (scaled)
    axes[step, 2].imshow(diff_scaled, aspect='equal', origin='lower')
    axes[step, 2].set_title('|Difference| (5x scaled)')

    # MSE error map
    im = axes[step, 3].imshow(error.T, cmap='RdYlBu_r', aspect='equal', origin='lower')
    axes[step, 3].set_title('MSE Error Map')
    plt.colorbar(im, ax=axes[step, 3], fraction=0.046, pad=0.04)

for ax in axes[-1, :]:
    ax.set_xlabel('X')

plt.suptitle(f'Microstructure Predictions (IPF-X) - Autoregressive Steps (Slice {selected_slice})', fontsize=14, y=1.01)
plt.tight_layout()
plt.savefig('autoregressive_microstructure_predictions.png', dpi=150, bbox_inches='tight')
plt.show()

### 9.3 Compare Autoregressive vs Single-Step Predictions

Compare the quality of predictions from the autoregressive approach against single-step predictions (where ground truth is always used as input).

In [None]:
# Single-step baseline for the same target timesteps: every prediction uses ground-truth
# input frames, so the gap to the autoregressive run isolates error accumulation.
#
# Dataset indexing (assumed — TODO confirm against LaserDataset.__getitem__): test
# samples are ordered slice-fastest, so
#     idx = slice_idx + k * num_slices
# returns the k-th input window of that slice. k = 0 uses the first `sequence_length`
# test frames and targets the next one, i.e. the same frame as autoregressive step 1.
#
# NOTE(review): MSE here is over the whole normalized tensor (the mask returned by the
# dataset is ignored); confirm autoregressive_cascaded_prediction computes its MSE the
# same way so the two curves are comparable.

single_step_temp_mse = []
single_step_micro_mse = []

print("Running single-step predictions for comparison (using ground truth inputs)...")
for step in range(num_autoregressive_steps):
    dataset_idx = selected_slice + step * num_slices
    timestep = start_timestep + sequence_length + step  # same convention as the autoregressive plots

    # Temperature: GT input window -> next frame. Inputs are cast to the model's own
    # parameter dtype (float16 after .half()) rather than hard-coding .half(), so this
    # keeps working if the model is left in float32.
    temp_sample = temp_test_dataset[dataset_idx]
    temp_input, temp_target = temp_sample[0], temp_sample[1]
    temp_dtype = next(temp_model.parameters()).dtype
    temp_input_batch = temp_input.unsqueeze(0).to(device=device, dtype=temp_dtype)
    with torch.no_grad():
        temp_pred = temp_model(temp_input_batch)
    temp_mse = torch.nn.functional.mse_loss(temp_pred[0].cpu().float(), temp_target.float()).item()
    single_step_temp_mse.append(temp_mse)

    # Microstructure ("standard"): GT input window, including GT temperature -> next frame.
    micro_sample = micro_test_dataset[dataset_idx]
    micro_input, micro_target = micro_sample[0], micro_sample[1]
    micro_dtype = next(micro_model.parameters()).dtype
    micro_input_batch = micro_input.unsqueeze(0).to(device=device, dtype=micro_dtype)
    with torch.no_grad():
        micro_pred = micro_model(micro_input_batch)
    micro_mse = torch.nn.functional.mse_loss(micro_pred[0].cpu().float(), micro_target.float()).item()
    single_step_micro_mse.append(micro_mse)

    print(f"  Step {step + 1}: t={timestep}, Temp MSE={temp_mse:.6f}, Micro MSE={micro_mse:.6f}")

print("\nSingle-step predictions complete!")

In [None]:
# Plot comparison: Autoregressive vs Single-Step, then tabulate per-step degradation.
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

steps = np.arange(1, num_autoregressive_steps + 1)
timesteps = [start_timestep + sequence_length + i for i in range(num_autoregressive_steps)]
step_labels = [f"Step {s}\n(t={t})" for s, t in zip(steps, timesteps)]

# (axis, field name, autoregressive MSEs, single-step MSEs, AR colour, single-step colour)
comparison_panels = [
    (axes[0], 'Temperature', autoregressive_results.temp_mse, single_step_temp_mse, 'coral', 'darkred'),
    (axes[1], 'Microstructure', autoregressive_results.micro_mse, single_step_micro_mse, 'steelblue', 'darkblue'),
]
for ax, field_name, ar_values, ss_values, ar_color, ss_color in comparison_panels:
    ax.plot(steps, ar_values, 'o-', color=ar_color, linewidth=2, markersize=10, label='Autoregressive')
    ax.plot(steps, ss_values, 's--', color=ss_color, linewidth=2, markersize=10, label='Single-step (GT input)')
    ax.set_xlabel('Prediction Step')
    ax.set_ylabel('MSE (normalized)')
    ax.set_title(f'{field_name}: Autoregressive vs Single-Step')
    ax.set_xticks(steps)
    ax.set_xticklabels(step_labels)
    ax.legend()
    ax.grid(alpha=0.3)

plt.suptitle(f'Autoregressive vs Single-Step Prediction Comparison (Slice {selected_slice})', fontsize=12)
plt.tight_layout()
plt.savefig('autoregressive_vs_single_step.png', dpi=150, bbox_inches='tight')
plt.show()

# Degradation table: Δ = autoregressive MSE - single-step MSE (positive = AR is worse).
# Separators are sized from the header so they match the row width.
degradation_header = (
    f"{'Step':<8} {'Timestep':<10} {'Temp (AR)':<12} {'Temp (SS)':<12} {'Δ Temp':<12} "
    f"{'Micro (AR)':<12} {'Micro (SS)':<12} {'Δ Micro':<12}"
)
print("\nError Degradation Analysis:")
print("=" * len(degradation_header))
print(degradation_header)
print("-" * len(degradation_header))
for i in range(num_autoregressive_steps):
    ar_temp = autoregressive_results.temp_mse[i]
    ss_temp = single_step_temp_mse[i]
    ar_micro = autoregressive_results.micro_mse[i]
    ss_micro = single_step_micro_mse[i]

    delta_temp = ar_temp - ss_temp
    delta_micro = ar_micro - ss_micro

    print(f"{i+1:<8} {timesteps[i]:<10} {ar_temp:<12.6f} {ss_temp:<12.6f} {delta_temp:<+12.6f} {ar_micro:<12.6f} {ss_micro:<12.6f} {delta_micro:<+12.6f}")

print("-" * len(degradation_header))
avg_temp_degradation = np.mean(np.array(autoregressive_results.temp_mse) - np.array(single_step_temp_mse))
avg_micro_degradation = np.mean(np.array(autoregressive_results.micro_mse) - np.array(single_step_micro_mse))
# Label width 45 (+1 separator) spans Step..Temp (SS), so each average lands under its
# Δ column; the two blank cells skip the Micro (AR) and Micro (SS) columns.
print(f"{'Average degradation:':<45} {avg_temp_degradation:<+12.6f} {'':<12} {'':<12} {avg_micro_degradation:<+12.6f}")

### 9.4 Evolution Visualization - Side-by-Side Comparison Over Time

Visualize the temporal evolution of both ground truth and autoregressive predictions to see how well the model captures the dynamics.

In [None]:
# Create a comprehensive evolution visualization.
# Row 0: ground truth, row 1: autoregressive prediction. Column 0 is the last
# ground-truth context frame (shared by both rows); later columns are predicted steps.
# load_ground_truth_frame comes from lasernet.cascaded (imported in the first cell).

fig, axes = plt.subplots(2, num_autoregressive_steps + 1, figsize=(4 * (num_autoregressive_steps + 1), 8))

# Get the initial microstructure state (last frame of input sequence). Only the first
# returned element is needed; indexing avoids rebinding IPython's `_` output cache.
initial_micro_frame = load_ground_truth_frame(temp_test_dataset, micro_test_dataset, selected_slice, sequence_length - 1, plane)[0]
initial_micro_denorm = micro_test_dataset.denormalize(initial_micro_frame.unsqueeze(0))[0]
initial_rgb = np.clip(np.transpose(initial_micro_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)

# First column: Initial state
axes[0, 0].imshow(initial_rgb, aspect='equal', origin='lower')
axes[0, 0].set_title(f'Initial State\n(t={start_timestep + sequence_length - 1})')
axes[0, 0].set_ylabel('Ground Truth\nZ coordinate')

axes[1, 0].imshow(initial_rgb, aspect='equal', origin='lower')
axes[1, 0].set_title(f'Initial State\n(t={start_timestep + sequence_length - 1})')
axes[1, 0].set_ylabel('Prediction\nZ coordinate')

# Subsequent columns: Evolution over time
for step in range(num_autoregressive_steps):
    timestep = start_timestep + sequence_length + step

    # Ground truth
    micro_target = autoregressive_results.micro_targets[step]
    micro_target_denorm = micro_test_dataset.denormalize_target(micro_target)
    target_rgb = np.clip(np.transpose(micro_target_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)

    axes[0, step + 1].imshow(target_rgb, aspect='equal', origin='lower')
    axes[0, step + 1].set_title(f't={timestep}')

    # Prediction
    micro_pred = autoregressive_results.micro_predictions[step]
    micro_pred_denorm = micro_test_dataset.denormalize_target(micro_pred)
    pred_rgb = np.clip(np.transpose(micro_pred_denorm[0:3].numpy(), (2, 1, 0)), 0, 1).astype(np.float32)

    axes[1, step + 1].imshow(pred_rgb, aspect='equal', origin='lower')
    mse = autoregressive_results.micro_mse[step]
    axes[1, step + 1].set_title(f't={timestep}\nMSE: {mse:.4f}')

# Set x labels for bottom row
for ax in axes[1, :]:
    ax.set_xlabel('X coordinate')

plt.suptitle(f'Microstructure Evolution: Ground Truth vs Autoregressive Prediction (Slice {selected_slice})', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('microstructure_evolution_comparison.png', dpi=150, bbox_inches='tight')
plt.show()

### 9.5 Summary and Conclusions

In [None]:
# Final summary of autoregressive cascaded prediction results

# Heuristic thresholds for the "key observations" text.
# Absolute threshold on the (normalized) average micro MSE degradation.
MICRO_STABLE_DEGRADATION = 0.001
# Per-step geometric growth factor below which temperature error counts as stable.
TEMP_STABLE_GROWTH = 1.5


def _safe_ratio(numerator, denominator):
    """Return ``numerator / denominator`` as a float, or NaN if the denominator is zero.

    Guards the summary against a ZeroDivisionError (plain floats) or a silent
    inf (NumPy scalars) when a baseline MSE happens to be exactly zero.
    """
    if denominator == 0:
        return float('nan')
    return float(numerator) / float(denominator)


def _per_step_growth(mse_values, n_steps):
    """Geometric-mean per-step growth factor of MSE from the first to the last step.

    Returns 1.0 when there is at most one step, since no growth can be measured.
    """
    if n_steps <= 1:
        return 1.0
    return _safe_ratio(mse_values[-1], mse_values[0]) ** (1 / (n_steps - 1))


def _print_section(title, ar_mse, ss_mse, avg_degradation):
    """Print the autoregressive-vs-single-step MSE summary for one predicted field.

    Args:
        title: Section heading, e.g. "TEMPERATURE PREDICTION".
        ar_mse: Per-step MSEs from the autoregressive rollout.
        ss_mse: Per-step MSEs from single-step prediction with ground-truth input.
        avg_degradation: Mean of (ar_mse - ss_mse) over all steps.
    """
    ss_mean = np.mean(ss_mse)
    growth_pct = (_safe_ratio(ar_mse[-1], ar_mse[0]) - 1) * 100
    degradation_pct = _safe_ratio(avg_degradation, ss_mean) * 100
    print("-" * 80)
    print(title)
    print("-" * 80)
    print("  Autoregressive:")
    print(f"    - Mean MSE: {np.mean(ar_mse):.6f}")
    print(f"    - Final step MSE: {ar_mse[-1]:.6f}")
    print(f"    - Error growth: {growth_pct:.1f}%")
    print("  Single-step (baseline):")
    print(f"    - Mean MSE: {ss_mean:.6f}")
    print("  Degradation from autoregressive:")
    print(f"    - Average: {avg_degradation:+.6f} ({degradation_pct:+.1f}%)")
    print()


print("=" * 80)
print("AUTOREGRESSIVE CASCADED PREDICTION - SUMMARY")
print("=" * 80)
print()
print("Configuration:")
print(f"  - Selected slice: {selected_slice} (middle of domain)")
print(f"  - Starting timestep: {start_timestep}")
print(f"  - Input sequence length: {sequence_length}")
print(f"  - Number of autoregressive steps: {num_autoregressive_steps}")
print(f"  - Predicted timesteps: {[start_timestep + sequence_length + i for i in range(num_autoregressive_steps)]}")
print()
_print_section("TEMPERATURE PREDICTION", autoregressive_results.temp_mse, single_step_temp_mse, avg_temp_degradation)
_print_section("MICROSTRUCTURE PREDICTION", autoregressive_results.micro_mse, single_step_micro_mse, avg_micro_degradation)
print("-" * 80)
print("KEY OBSERVATIONS")
print("-" * 80)

# Analyze error trends
temp_growth_rate = _per_step_growth(autoregressive_results.temp_mse, num_autoregressive_steps)
micro_growth_rate = _per_step_growth(autoregressive_results.micro_mse, num_autoregressive_steps)

print(f"  1. Temperature error growth rate: {temp_growth_rate:.3f}x per step")
print(f"  2. Microstructure error growth rate: {micro_growth_rate:.3f}x per step")

if avg_micro_degradation < MICRO_STABLE_DEGRADATION:
    print("  3. Microstructure predictions are relatively stable despite using predicted temperature")
else:
    print("  3. Error accumulation causes noticeable degradation in microstructure predictions")

if temp_growth_rate < TEMP_STABLE_GROWTH:
    print(f"  4. Temperature predictions remain relatively stable over {num_autoregressive_steps} steps")
else:
    print(f"  4. Temperature prediction error grows significantly over {num_autoregressive_steps} steps")

print()
print("=" * 80)