# Load Data

In [None]:
import torch
from auto_cast.data.dataset import BOUTDataset
from auto_cast.data.datamodule import SpatioTemporalDataModule
from torch.utils.data import DataLoader



In [None]:
from auto_cast.data.dataset import BOUTDataset
from auto_cast.data.datamodule import SpatioTemporalDataModule

# Load with datamodule
# Build the data module over the BOUT dataset stored under data/bout_split.
# NOTE(review): a later cell rebuilds `datamodule` with n_steps_output=5, so the
# 40-frame horizon configured here is overwritten once that cell runs.
datamodule = SpatioTemporalDataModule(
    data_path="data/bout_split",
    dataset_cls=BOUTDataset,
    n_steps_input=5,   # Use 5 input frames
    n_steps_output=40,  # Predict 40 future frames
    stride=1,
    batch_size=4,
    dtype=torch.float32,  # Convert from float64 to float32
    ftype="torch",
    verbose=True,
)




In [None]:
# Get a batch
# Build the training loader and pull the first batch from a fresh iterator over it.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

In [None]:
# Display the shapes of the tensors carried by the batch.
# `constant_scalars` is treated as optional by the later summary cell (it checks
# for None before reading `.shape`), so apply the same guard here instead of
# raising AttributeError when the field is absent.
(
    batch.input_fields.shape,
    batch.output_fields.shape,
    batch.constant_scalars.shape if batch.constant_scalars is not None else None,
)

In [None]:
import torch
import torch.nn as nn
from azula.noise import CosineSchedule
from auto_cast.data.dataset import BOUTDataset
from auto_cast.data.datamodule import SpatioTemporalDataModule
from auto_cast.types import EncodedBatch


# ============================================================================
# 2. Load Data
# ============================================================================

# Three-line banner, emitted with a single print call.
print("=" * 70, "Loading BOUT Dataset", "=" * 70, sep="\n")

# Same dataset and loader settings as before, but with a 5-frame output horizon.
datamodule = SpatioTemporalDataModule(
    data_path="data/bout_split",
    dataset_cls=BOUTDataset,
    n_steps_input=5,
    n_steps_output=5,
    stride=1,
    batch_size=4,
    dtype=torch.float32,
    ftype="torch",
    verbose=True,
)

# First batch of a fresh pass over the training loader.
train_loader = datamodule.train_dataloader()
batch = next(iter(train_loader))

# Shape summary. The author expected [4, 5, 1, 256, 256] for both input and
# output -- not verified here. `constant_scalars` may be None.
print("\nðŸ“Š Batch shapes (after collate_batches):")
print(f"   Input:  {batch.input_fields.shape}")
print(f"   Output: {batch.output_fields.shape}")
print(f"   Const scalars: {None if batch.constant_scalars is None else batch.constant_scalars.shape}")



In [None]:
from azula.nn.unet import UNet
from azula.nn.embedding import SineEncoding

class TemporalUNetBackbone(nn.Module):
    """Azula UNet applied frame-by-frame, conditioned on an embedded time value.

    The temporal axis of the input is folded into the batch axis so that a
    spatial UNet can process every frame independently; each frame of a sample
    receives that sample's time embedding as the UNet modulation vector.

    Args:
        in_channels: Channels per input frame.
        out_channels: Channels per output frame.
        mod_features: Width of the time embedding / UNet modulation vector.
        hid_channels: Hidden channel widths forwarded to the UNet.
        hid_blocks: Blocks per level forwarded to the UNet.
        spatial: Number of spatial dimensions forwarded to the UNet.
        periodic: Periodic-padding flag forwarded to the UNet.
    """

    def __init__(
        self,
        in_channels: int = 1,
        out_channels: int = 1,
        mod_features: int = 256,
        hid_channels: tuple = (32, 64, 128),
        hid_blocks: tuple = (2, 2, 2),
        spatial: int = 2,
        periodic: bool = False,
    ):
        super().__init__()

        # Time embedding: sine encoding followed by a two-layer MLP.
        self.time_embedding = nn.Sequential(
            SineEncoding(mod_features),
            nn.Linear(mod_features, mod_features),
            nn.SiLU(),
            nn.Linear(mod_features, mod_features),
        )

        self.unet = UNet(
            in_channels=in_channels,
            out_channels=out_channels,
            cond_channels=0,
            mod_features=mod_features,
            hid_channels=hid_channels,
            hid_blocks=hid_blocks,
            kernel_size=3,
            stride=2,
            spatial=spatial,
            periodic=periodic,
        )

    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Run the UNet on every frame of ``x`` with time conditioning ``t``.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.
            t: One time value per sample (assumed shape ``(B,)`` -- TODO confirm
                against the caller).

        Returns:
            Tensor of shape ``(B, T, ...)`` where the trailing dimensions are
            whatever the UNet produced per frame. With ``out_channels`` equal to
            ``in_channels`` and size-preserving convolutions this matches
            ``x.shape``.
        """
        B, T, C, H, W = x.shape

        # Embed time: one row per sample.
        t_emb = self.time_embedding(t)

        # Fold time into the batch axis (sample-major), and repeat each sample's
        # embedding T times so row i of the embedding lines up with frame i.
        x_flat = x.reshape(B * T, C, H, W)
        t_emb_expanded = t_emb.repeat_interleave(T, dim=0)

        # Process all frames in one UNet call.
        out_flat = self.unet(x_flat, mod=t_emb_expanded)

        # Unfold using the OUTPUT's own channel/spatial dims; reusing the input
        # channel count here would fail whenever out_channels != in_channels.
        return out_flat.reshape(B, T, *out_flat.shape[1:])




# ============================================================================
# 3. Create DiffusionProcessor
# ============================================================================

In [None]:
from auto_cast.processors.diffusion import DiffusionProcessor

# Read the batch dimensions off the target tensor; only C is used below.
# (assumes output_fields is 5-D: batch, time, channel, height, width -- the
# unpacking fails otherwise)
B, T, C, H, W = batch.output_fields.shape

# Create backbone
# Input and output channel counts are both taken from the data.
backbone = TemporalUNetBackbone(
    in_channels=C,
    out_channels=C,
    mod_features=256,
    hid_channels=(16, 32, 64),  # Small for testing
    hid_blocks=(2, 2, 2),
    spatial=2,
    periodic=False,
)
# Create schedule
schedule = CosineSchedule(alpha_min=0.001, sigma_min=0.001)

# Create processor
# Wires the backbone and noise schedule into the diffusion processor.
processor = DiffusionProcessor(
    backbone=backbone,
    schedule=schedule,
    denoiser_type='karras',
    learning_rate=1e-4,
)


# ============================================================================
# 4. Test Forward Pass
# ============================================================================

In [None]:
# Package the raw batch tensors into the container the processor consumes.
encoded_batch = EncodedBatch(
    encoded_inputs=batch.input_fields,
    encoded_output_fields=batch.output_fields,
    encoded_info={},
)

# Smoke-test the training path: one step with batch index 0, report the loss.
loss = processor.training_step(encoded_batch, 0)
print(f"Loss: {loss.item():.4f}")

# Smoke-test the prediction path and check the result keeps the input's shape.
# NOTE(review): the reference shape is the *input* tensor's, which only equals
# the target shape while n_steps_input == n_steps_output -- confirm intent.
expected_shape = batch.input_fields.shape
output = processor.map(batch.input_fields)
print(f"Output shape: {output.shape}")
print(f"Expected shape: {expected_shape}")
assert output.shape == expected_shape, "Shape mismatch!"




# ============================================================================
# 5. Quick Training Loop
# ============================================================================



In [None]:
# TODO: the quick training loop for this section has not been written yet.
# This cell previously contained only a stray `n`, which raised NameError on a
# fresh kernel and stopped "Restart & Run All" at this point.

# ============================================================================
# 6. Visualize Results
# ============================================================================

In [None]:
import matplotlib.pyplot as plt
import numpy as np


processor.eval()


def _noise_at(clean, level):
    """Return `clean` noised at diffusion time `level` via the processor's schedule."""
    times = torch.full((clean.shape[0],), level, device=clean.device)
    alpha, sigma = processor.schedule(times)
    # Reshape the per-sample coefficients so they broadcast over a 5-D tensor.
    alpha = alpha.view(-1, 1, 1, 1, 1)
    sigma = sigma.view(-1, 1, 1, 1, 1)
    return alpha * clean + sigma * torch.randn_like(clean)


def _first_frame(tensor):
    """Sample 0, time step 0, channel 0 of `tensor` as a NumPy array."""
    return tensor[0, 0, 0].cpu().numpy()


# Evaluation batch (drawn from the *training* loader) and its encoded wrapper.
# NOTE(review): `test_encoded` is built but not used anywhere in this cell.
test_batch = next(iter(train_loader))
test_encoded = EncodedBatch(
    encoded_inputs=test_batch.input_fields,
    encoded_output_fields=test_batch.output_fields,
    encoded_info={},
)

with torch.no_grad():
    # Reference target, (B, T, C, H, W) per the original annotation.
    x_gt = test_batch.output_fields

    # Model prediction from the input frames.
    x_pred = processor.map(test_batch.input_fields)

    # Forward-diffused copies of the target at increasing noise levels,
    # generated in level order so the random draws occur in the same sequence.
    B = x_gt.shape[0]
    t_levels = [0.0, 0.25, 0.5, 0.75, 1.0]
    noisy_samples = [_noise_at(x_gt, level) for level in t_levels]

# NumPy views of the first frame of the first sample, for plotting.
x_gt_np = _first_frame(x_gt)
x_pred_np = _first_frame(x_pred)
noisy_np = [_first_frame(sample) for sample in noisy_samples]


In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Shared colorbar settings
cbar_kw = {'fraction': 0.03, 'pad': 0.02}

# ============================================================================
# Plot 1: Ground Truth vs Prediction
# ============================================================================
# Each panel is drawn with its own colour scale (no shared vmin/vmax), so the
# two colorbars are not directly comparable.
fig, axes = plt.subplots(1, 2, figsize=(6, 2.5), dpi=100)
im0 = axes[0].imshow(x_gt_np, cmap='viridis')
axes[0].set_title('Ground Truth', fontsize=9)
plt.colorbar(im0, ax=axes[0], **cbar_kw)

im1 = axes[1].imshow(x_pred_np, cmap='viridis')
axes[1].set_title('Prediction', fontsize=9)
plt.colorbar(im1, ax=axes[1], **cbar_kw)

plt.suptitle('BOUT Vorticity Field', fontsize=10, y=0.98)
plt.tight_layout()
plt.show()

# ============================================================================
# Plot 2: Noise Schedule (Compact)
# ============================================================================
# One panel per entry of `t_levels`, showing the noised first frame at that level.
fig, axes = plt.subplots(1, len(t_levels), figsize=(10, 1.8), dpi=100)
for idx, (t_val, noisy) in enumerate(zip(t_levels, noisy_np)):
    im = axes[idx].imshow(noisy, cmap='viridis')
    axes[idx].set_title(f't={t_val:.1f}', fontsize=8)

plt.suptitle('Forward Diffusion', fontsize=10, y=1.02)
plt.tight_layout()
plt.show()

# ============================================================================
# Plot 3: Error Analysis (Compact)
# ============================================================================
# Absolute per-pixel error of the first frame; its mean is shown in the title.
error = np.abs(x_gt_np - x_pred_np)
fig, axes = plt.subplots(1, 3, figsize=(8, 2.2), dpi=100)

im0 = axes[0].imshow(x_gt_np, cmap='viridis')
axes[0].set_title('Ground Truth', fontsize=9)

im1 = axes[1].imshow(x_pred_np, cmap='viridis')
axes[1].set_title('Prediction', fontsize=9)

im2 = axes[2].imshow(error, cmap='Reds')
axes[2].set_title(f'Error (MAE={error.mean():.3f})', fontsize=9)

plt.suptitle('Prediction Quality', fontsize=10, y=0.98)
plt.tight_layout()
plt.show()

# ============================================================================
# Statistics (Compact)
# ============================================================================
# Unlike the plot above, these are computed over the full tensors (every
# sample, time step and channel), not just the first frame.
with torch.no_grad():
    mse = ((x_gt - x_pred) ** 2).mean().item()
    mae = (x_gt - x_pred).abs().mean().item()

print(f"MSE:        {mse:.6f}")
print(f"MAE:        {mae:.6f}")
print(f"GT Range:   [{x_gt.min().item():.3f}, {x_gt.max().item():.3f}]")
print(f"Pred Range: [{x_pred.min().item():.3f}, {x_pred.max().item():.3f}]")


In [None]:
# ============================================================================
# Plot 4: Time Series Comparison (every `step_size` steps)
# ============================================================================

# Pick the time steps to show: every `step_size`-th frame (currently every
# frame), always ending on the final one.
T = x_gt.shape[1]  # Total number of time steps
step_size = 1
time_indices = list(range(0, T, step_size))
if time_indices[-1] != T - 1:
    time_indices.append(T - 1)

n_frames = len(time_indices)
# squeeze=False keeps `axes` two-dimensional even when there is a single column.
fig, axes = plt.subplots(2, n_frames, figsize=(4*n_frames, 8), squeeze=False)

# Row 0 shows the ground truth, row 1 the prediction (first sample, first channel).
row_sources = (('GT', x_gt), ('Pred', x_pred))
for plot_idx, t_idx in enumerate(time_indices):
    is_last_column = plot_idx == n_frames - 1
    for row, (label, frames) in enumerate(row_sources):
        ax = axes[row, plot_idx]
        image = ax.imshow(frames[0, t_idx, 0].cpu().numpy(), cmap='viridis')
        ax.set_title(f'{label} t={t_idx}', fontsize=12)
        if is_last_column:  # one colorbar per row, attached to the last panel
            plt.colorbar(image, ax=ax, fraction=0.046)

axes[0, 0].set_ylabel('Ground Truth', fontsize=14, rotation=0, ha='right', va='center')
axes[1, 0].set_ylabel('Prediction', fontsize=14, rotation=0, ha='right', va='center')

plt.suptitle(f'Temporal Evolution (Every {step_size} Steps)', fontsize=16, y=0.99)
plt.tight_layout()

plt.show()

In [None]:
from matplotlib.animation import FuncAnimation
from IPython.display import HTML, Image as IPImage


# Get data for video: first sample, first channel, all time steps.
x_gt_video = x_gt[0, :, 0].cpu().numpy()    # (T, H, W) assuming x_gt is 5-D
x_pred_video = x_pred[0, :, 0].cpu().numpy()  # (T, H, W) assuming x_pred is 5-D
T = x_gt_video.shape[0]

# Determine common color scale so both panels are directly comparable.
vmin = min(x_gt_video.min(), x_pred_video.min())
vmax = max(x_gt_video.max(), x_pred_video.max())

# Create figure
fig, axes = plt.subplots(1, 2, figsize=(10, 4), dpi=80)

# Initial frames
im0 = axes[0].imshow(x_gt_video[0], cmap='viridis', vmin=vmin, vmax=vmax)
axes[0].set_title('Ground Truth', fontsize=12)
plt.colorbar(im0, ax=axes[0], fraction=0.03, pad=0.02)

im1 = axes[1].imshow(x_pred_video[0], cmap='viridis', vmin=vmin, vmax=vmax)
axes[1].set_title('Prediction', fontsize=12)
plt.colorbar(im1, ax=axes[1], fraction=0.03, pad=0.02)

# Time text
time_text = fig.suptitle(f'Time Step: 0/{T-1}', fontsize=13, y=0.98)

def update(frame):
    """Redraw both panels and the title for animation frame `frame`."""
    im0.set_data(x_gt_video[frame])
    im1.set_data(x_pred_video[frame])
    time_text.set_text(f'Time Step: {frame}/{T-1}')
    return [im0, im1, time_text]

# Create animation
anim = FuncAnimation(fig, update, frames=T, interval=100, blit=False, repeat=True)

# Display the animation in the notebook.
# IPython's Image expects encoded image bytes, a filename or a URL, so it
# cannot be handed the FuncAnimation object directly. Render the animation to
# an HTML/JS player instead, then close the figure so the static first frame
# is not shown a second time below the player.
player_html = anim.to_jshtml()
plt.close(fig)
display(HTML(player_html))

# ============================================================================
# ROLLOUT TRAINING - Multi-step Autoregressive Prediction
# ============================================================================


In [None]:
# --- Configuration Setup ---
# Note: You need to set n_steps_output high enough to cover the max_rollout_steps
# (10 steps * 4 output frames/step = 40 total frames needed for full supervision).
# NOTE(review): the "4 output frames per step" figure is not set anywhere in
# this notebook -- confirm it against the processor's rollout implementation.

TOTAL_GT_FRAMES = 40  # Must be 4 * 10

# Load data with sufficient output steps for full rollout supervision.
# Same dataset and loader settings as the earlier data module, but with a single
# input frame and a TOTAL_GT_FRAMES-long target window.
datamodule_rollout = SpatioTemporalDataModule(
    data_path="data/bout_split",
    dataset_cls=BOUTDataset,
    n_steps_input=1,      # 1 input frame
    n_steps_output=TOTAL_GT_FRAMES, # 40 output frames (REQUIRED for 10-step supervision)
    stride=1,
    batch_size=4,
    dtype=torch.float32,
    ftype="torch",
    verbose=True,
)

train_loader_rollout = datamodule_rollout.train_dataloader()


# Optional: Add a validation step here calling a separate inference function.

In [None]:

# Create processor with rollout settings.
# NOTE(review): `backbone` and `schedule` are the same objects that were passed
# to the earlier `processor`, so training here also updates the weights that
# `processor` uses.
processor_rollout = DiffusionProcessor(
    backbone=backbone,
    schedule=schedule,
    denoiser_type='karras',
    teacher_forcing_ratio=0.5,  # 50% teacher forcing
    stride=1,
    max_rollout_steps=10,
    learning_rate=1e-4,
)

processor_rollout.train()
optimizer = torch.optim.Adam(processor_rollout.parameters(), lr=1e-4)

# --- Autoregressive Training Loop ---

# Create the batch iterator ONCE. Calling iter() on the loader inside the loop
# would start a fresh pass on every step, so each step would only ever see the
# first batch of a new pass and the StopIteration branch could never do its job.
rollout_batches = iter(train_loader_rollout)

for step in range(100):
    # 1. Fetch and package batch, starting a new pass when the loader is exhausted.
    try:
        batch = next(rollout_batches)
    except StopIteration:
        rollout_batches = iter(train_loader_rollout)
        batch = next(rollout_batches)

    encoded_batch = EncodedBatch(
        encoded_inputs=batch.input_fields,
        encoded_output_fields=batch.output_fields,
        encoded_info={}
    )

    optimizer.zero_grad()

    # 2. Rollout pass. The result is used as a scalar loss here.
    # NOTE(review): the test cell unpacks `rollout(...)` into
    # (predictions, ground_truth); presumably the return value depends on
    # train/eval mode -- confirm against DiffusionProcessor.rollout.
    loss = processor_rollout.rollout(encoded_batch)

    # 3. Optimization Step
    loss.backward()
    optimizer.step()

    if step % 10 == 0:
        print(f"   Step {step}: Autoregressive Rollout Loss (Mean Diffusion) = {loss.item():.4f}")

# ============================================================================
# TEST ROLLOUT - Generate Full Time Series
# ============================================================================

In [None]:
# Switch the rollout processor to evaluation mode before generating predictions.
processor_rollout.eval()

# Get a test batch (drawn from the rollout *training* loader).
test_batch = next(iter(train_loader_rollout))
test_encoded = EncodedBatch(
    encoded_inputs=test_batch.input_fields,
    encoded_output_fields=test_batch.output_fields,
    encoded_info={}
)

with torch.no_grad():
    # Get rollout predictions
    # NOTE(review): here `rollout` is unpacked into two values, whereas the
    # training loop treats its result as a single loss tensor -- presumably the
    # return type depends on train/eval mode; confirm against the processor.
    predictions, ground_truth = processor_rollout.rollout(test_encoded)
    
    print(f"\nðŸ“Š Rollout Results:")
    print(f"   Input shape:       {test_batch.input_fields.shape}")
    print(f"   Predictions shape: {predictions.shape}")
    print(f"   Ground truth shape: {ground_truth.shape if ground_truth is not None else 'None'}")


In [None]:

# ============================================================================
# VISUALIZE ROLLOUT - Full Time Series Video
# ============================================================================

print("\n" + "="*70)
print("Creating Rollout Video")
print("="*70)

from matplotlib.animation import FuncAnimation
from IPython.display import Image as IPImage

# Get full time series data: first sample, first channel, every time step.
x_gt_full = ground_truth[0, :, 0].cpu().numpy()     # (T_total, H, W)
x_pred_full = predictions[0, :, 0].cpu().numpy()     # (T_total, H, W)
T_full = x_gt_full.shape[0]

print(f"\nðŸŽ¬ Creating animation with {T_full} frames...")

# Common color scale shared by both panels and by the comparison plot further down.
vmin = min(x_gt_full.min(), x_pred_full.min())
vmax = max(x_gt_full.max(), x_pred_full.max())

# Create compact figure
fig, axes = plt.subplots(1, 2, figsize=(10, 4), dpi=100)

# Initial frames
im0 = axes[0].imshow(x_gt_full[0].squeeze(), cmap='viridis', vmin=vmin, vmax=vmax)
axes[0].set_title('Ground Truth', fontsize=12)
axes[0].axis('off')
plt.colorbar(im0, ax=axes[0], fraction=0.03, pad=0.02)

im1 = axes[1].imshow(x_pred_full[0].squeeze(), cmap='viridis', vmin=vmin, vmax=vmax)
axes[1].set_title('Rollout Prediction', fontsize=12)
axes[1].axis('off')
plt.colorbar(im1, ax=axes[1], fraction=0.03, pad=0.02)

# Time text with error (placeholder MAE until the first frame update runs)
time_text = fig.suptitle(f'Time Step: 0/{T_full-1} | MAE: 0.000', fontsize=13, y=0.98)

def update(frame):
    """Redraw both panels for `frame` and refresh the title with that frame's MAE."""
    # NOTE(review): frames are passed here without .squeeze(), unlike the
    # initial imshow calls above; equivalent as long as each frame is already 2-D.
    im0.set_data(x_gt_full[frame])
    im1.set_data(x_pred_full[frame])
    
    # Compute error for this frame
    mae = np.abs(x_gt_full[frame].squeeze() - x_pred_full[frame].squeeze()).mean()
    time_text.set_text(f'Time Step: {frame}/{T_full-1} | MAE: {mae:.4f}')
    
    return [im0, im1, time_text]

# Create animation and write it to disk as a GIF, then close the figure so the
# static first frame is not also rendered in the cell output.
anim = FuncAnimation(fig, update, frames=T_full, interval=100, repeat=True)
anim.save('rollout_video.gif', writer='pillow', fps=10, dpi=100)
plt.close()

print("âœ… Rollout video saved: rollout_video.gif")

# Display in notebook (loads the GIF file that was just written)
display(IPImage('rollout_video.gif'))

# ============================================================================
# ROLLOUT ERROR ANALYSIS
# ============================================================================

# Section banner (blank line, rule, title, rule).
print("\n" + "=" * 70, "Rollout Error Analysis", "=" * 70, sep="\n")

# Per-time-step error: average over every axis except the time axis (dim 1).
with torch.no_grad():
    residual = ground_truth - predictions
    non_time_dims = [0, 2, 3, 4]
    mse_per_frame = residual.pow(2).mean(dim=non_time_dims).cpu().numpy()
    mae_per_frame = residual.abs().mean(dim=non_time_dims).cpu().numpy()

# One panel per metric; only the MAE curve overrides the default line colour.
fig, axes = plt.subplots(1, 2, figsize=(12, 4), dpi=100)
panels = (
    (mse_per_frame, 'MSE', 'Mean Squared Error over Time', {}),
    (mae_per_frame, 'MAE', 'Mean Absolute Error over Time', {'color': 'orange'}),
)
for ax, (series, metric, title, line_style) in zip(axes, panels):
    ax.plot(series, marker='o', linewidth=2, markersize=4, **line_style)
    ax.set_xlabel('Time Step', fontsize=11)
    ax.set_ylabel(metric, fontsize=11)
    ax.set_title(title, fontsize=12)
    ax.grid(alpha=0.3)

plt.suptitle('Rollout Error Accumulation', fontsize=14, y=1.02)
plt.tight_layout()
plt.savefig('rollout_error_analysis.png', dpi=150, bbox_inches='tight')
print("âœ… Saved: rollout_error_analysis.png")
plt.show()

# Summary statistics; "error growth" is the ratio of the last to the first MAE.
first_mae, last_mae = mae_per_frame[0], mae_per_frame[-1]
print(f"\nðŸ“Š Rollout Statistics:")
print(f"   Initial MAE:  {first_mae:.6f}")
print(f"   Final MAE:    {last_mae:.6f}")
print(f"   Mean MAE:     {mae_per_frame.mean():.6f}")
print(f"   Error growth: {last_mae / first_mae:.2f}x")

# ============================================================================
# SIDE-BY-SIDE COMPARISON: Single-Step vs Rollout
# ============================================================================

# Section banner (blank line, rule, title, rule).
print("\n" + "=" * 70, "Comparison: Single-Step vs Rollout", "=" * 70, sep="\n")

# Prediction from the original (non-rollout) processor on the same inputs.
with torch.no_grad():
    single_step_pred = processor.map(test_batch.input_fields)

# Five sample times spread across the rollout horizon.
time_points = [0, T_full//4, T_full//2, 3*T_full//4, T_full-1]

fig, axes = plt.subplots(3, len(time_points), figsize=(15, 8), dpi=100)

# Rows: ground truth, single-step prediction, rollout prediction.
for idx, t in enumerate(time_points):
    # Columns beyond the single-step horizon are left empty.
    if t >= single_step_pred.shape[1]:
        continue

    column_frames = (
        x_gt_full[t].squeeze(),
        single_step_pred[0, t, 0].cpu().numpy(),
        x_pred_full[t].squeeze(),
    )
    for row, frame in enumerate(column_frames):
        ax = axes[row, idx]
        ax.imshow(frame, cmap='viridis', vmin=vmin, vmax=vmax)
        if row == 0:
            ax.set_title(f't={t}', fontsize=10)
        ax.axis('off')

# NOTE(review): the first column's axes are switched off above, which hides
# axis labels, so these row labels may not be visible -- confirm.
for row, row_label in enumerate(('Ground Truth', 'Single-Step', 'Rollout')):
    axes[row, 0].set_ylabel(row_label, fontsize=11, rotation=0, ha='right', va='center')

plt.suptitle('Prediction Comparison Over Time', fontsize=14, y=0.98)
plt.tight_layout()
plt.savefig('comparison_single_vs_rollout.png', dpi=150, bbox_inches='tight')
print("âœ… Saved: comparison_single_vs_rollout.png")
plt.show()

print("\nâœ… Rollout testing complete!")