In [None]:
from autoemulate.simulations.reaction_diffusion import ReactionDiffusion


# Reaction-diffusion simulator with the settings used throughout this notebook.
rd = ReactionDiffusion(
    n=32,
    T=10,
    dt=0.1,
    return_timeseries=True,
)

# Draw 20 spatiotemporal samples; `y` is a shortcut to the "data" entry.
data = rd.forward_samples_spatiotemporal(20)
y = data["data"]

In [None]:
# Inspect which entries the simulator's sample dictionary provides.
data.keys()


In [None]:
import matplotlib.pyplot as plt

# Quick look at one 2-D slice of the simulated field.
# NOTE(review): presumably axes are (trajectory, time, x, y, channel) — confirm.
plt.imshow(y[0, 0, :, :, 0])


In [None]:
from torch.utils.data import DataLoader

from autoemulate.experimental.data.spatiotemporal_dataset import AutoEmulateDataset

# Wrap the in-memory samples (no file path) in a dataset configured with
# n_steps_input=1 and n_steps_output=1.
dataset = AutoEmulateDataset(
    data_path=None,
    data=data,
    n_steps_input=1,
    n_steps_output=1,
)

# Pull one batch through a default DataLoader to see what the dataset yields.
batch_orig = next(iter(DataLoader(dataset)))

In [None]:
# Shapes of the three tensors carried by the batch, in a fixed order.
tuple(
    batch_orig[key].shape
    for key in ("input_fields", "output_fields", "constant_scalars")
)

In [None]:
# For autoregressive prediction, we need to split at the trajectory level.

from torch.utils.data import DataLoader
# Split at trajectory level
# The first 90% of trajectories (rounded down by int()) are used for training.
n_trajectories = dataset.n_trajectories
train_traj_count = int(0.9 * n_trajectories)

# Leading indices go to training, the remainder to validation.
train_traj_idxs = [*range(train_traj_count)]
val_traj_idxs = [*range(train_traj_count, n_trajectories)]

print(
    f"Train trajectories: {len(train_traj_idxs)}, Val trajectories: {len(val_traj_idxs)}"
)

In [None]:
# Show the trajectory indices assigned to the training split.
train_traj_idxs

In [None]:
def _select_trajectories(samples, traj_slice):
    """Build a sample dictionary restricted to the trajectories in `traj_slice`.

    'data' and 'constant_scalars' are sliced along their first axis;
    'constant_fields' is passed through untouched (the notebook notes that it
    is None for this simulation).
    """
    return {
        'data': samples["data"][traj_slice],
        'constant_scalars': samples["constant_scalars"][traj_slice],
        'constant_fields': samples["constant_fields"],
    }


# Create train data
train_data = _select_trajectories(data, slice(None, train_traj_count))

# Create val data
val_data = _select_trajectories(data, slice(train_traj_count, None))

In [None]:
# Check the shape of the training field array after the trajectory split.
train_data["data"].shape

In [None]:
# Both splits share the same settings: in-memory data (no file path),
# n_steps_input=1 and n_steps_output=1.
shared_dataset_kwargs = dict(data_path=None, n_steps_input=1, n_steps_output=1)

train_dataset = AutoEmulateDataset(data=train_data, **shared_dataset_kwargs)
val_dataset = AutoEmulateDataset(data=val_data, **shared_dataset_kwargs)

In [None]:
# Default DataLoaders (no batch size or shuffling specified) over each split.
train_loader, val_loader = DataLoader(train_dataset), DataLoader(val_dataset)

# Grab the first training batch for inspection.
batch = next(iter(train_loader))

In [None]:
# Shapes of the three tensors carried by the training batch, in a fixed order.
tuple(
    batch[key].shape
    for key in ("input_fields", "output_fields", "constant_scalars")
)

In [None]:
# Shape of the input-field tensor on its own.
batch["input_fields"].shape

In [None]:
import matplotlib.pyplot as plt
# Plot a single 2-D slice of the input fields from the inspected batch.
first_input_slice = batch["input_fields"][0, 0, :, :, 0].cpu()
plt.imshow(first_input_slice)
plt.show()

In [None]:
import matplotlib.pyplot as plt
# NOTE(review): this plots the same slice as the preceding cell — possibly it
# was meant to show batch["output_fields"]; confirm or remove.
plt.imshow(batch["input_fields"][0, 0, :, :, 0].cpu())
plt.show()

In [None]:
from autoemulate.experimental.emulators.fno import FNOEmulator

# FNO hyperparameters as used in this notebook.
# NOTE(review): in_channels=3 presumably counts the field channel plus the
# constant-scalar and time inputs — confirm against FNOEmulator / prepare_batch.
fno_kwargs = dict(
    n_modes=(1, 16, 16),
    hidden_channels=16,
    in_channels=3,
    out_channels=1,
)
emulator = FNOEmulator(**fno_kwargs)


In [None]:
# Names of the entries in a batch from the training loader.
next(iter(train_loader)).keys()

In [None]:
# Fit the emulator on the training loader.
# NOTE(review): the second positional argument is passed as None — its meaning
# is defined by FNOEmulator.fit; confirm.
emulator.fit(train_loader, None)

In [None]:
# Display the training DataLoader object itself.
train_loader

In [None]:
# Predictions for the validation loader, requested with with_grad=False.
y_pred = emulator.predict(val_loader, with_grad=False)
y_pred.shape

In [None]:
# Show one predicted 2-D slice.
# NOTE(review): index 103 is hard-coded and must be smaller than
# y_pred.shape[0]; the axis order assumed here is not visible in this notebook.
plt.imshow(y_pred[103,0,0,:,:].cpu())


In [None]:
# Get initial sample: the first batch from the validation loader
initial_sample = next(iter(val_loader))

# Autoregressive prediction for n_steps=20 starting from that sample
autoregressive_pred = emulator.predict_autoregressive(initial_sample, n_steps=20)
print(f"Autoregressive prediction shape: {autoregressive_pred.shape}")

In [None]:
# Convert the [0, 0] slice of the autoregressive prediction to a NumPy array.
# NOTE(review): this rebinds `data`, which until here referred to the
# simulator's sample dictionary; cells that index data["..."] will fail if
# re-run afterwards without regenerating the samples.
data = autoregressive_pred[0, 0].cpu().numpy()
data.shape

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Take first batch, first channel
# NOTE(review): this rebinds `data`, which earlier cells use for the simulator's
# sample dictionary; regenerate the samples before re-running cells that index
# data["..."].
data = autoregressive_pred[0, 0].cpu().numpy()  # [n_steps, height, width]

# Set consistent color scale across all frames
vmin, vmax = data.min(), data.max()

fig, ax = plt.subplots(figsize=(6, 6))
# The image is a regular artist: `animated=True` is only meaningful when
# blitting, which is disabled below.
im = ax.imshow(data[0], vmin=vmin, vmax=vmax, cmap='viridis')
plt.colorbar(im)


def animate(frame):
    """Show predicted frame `frame` and update the title to match.

    Args:
        frame: Index along the first axis of `data`.

    Returns:
        A list containing the updated image artist.
    """
    im.set_array(data[frame])
    ax.set_title(f'Time step {frame}')
    return [im]


# blit=False: `animate` changes the axes title but only returns the image
# artist. With blitting, only the returned artists are redrawn, so the title
# would not follow the frame number in interactive backends.
ani = animation.FuncAnimation(
    fig,
    animate,
    frames=data.shape[0],
    interval=500,  # Slower: 500ms between frames
    blit=False,
    repeat=True,
)
plt.show()

In [None]:
# Compare the autoregressive rollout with the simulated validation data at a
# single grid point (index 0, 0). Labels and a legend make the two curves
# distinguishable.
# NOTE(review): assumes the first axis of `data` and the second axis of
# val_data['data'] are both time, and that the rollout started from validation
# trajectory 0 — confirm, including any offset between the two series.
plt.plot(data[:, 0, 0], label="autoregressive prediction")
plt.plot(val_data['data'][0, :, 0, 0, 0], label="validation trajectory 0")
plt.xlabel("time step")
plt.ylabel("value at grid point (0, 0)")
plt.legend()
plt.show()

In [None]:
# Shape of the validation field array, for comparison with the prediction.
val_data['data'].shape

In [None]:
# Save animation as GIF file (relative path, so it lands in the working directory).
# fps=4 sets the GIF's playback rate (250 ms per frame).
ani.save('autoregressive_prediction.gif', writer='pillow', fps=4)
print("Animation saved as autoregressive_prediction.gif")

In [None]:
# Evaluate
# TODO: add to emulator perhaps as .evaluate()?
import torch
from autoemulate.experimental.emulators.fno import prepare_batch

# Ground truth matching y_pred: collect the second element returned by
# prepare_batch for every batch of val_loader — the loader y_pred was computed
# from — instead of the full dataset, so both tensors cover the same samples.
y_true = torch.cat(
    [
        prepare_batch(
            val_batch, channels=(0,), with_constants=True, with_time=True
        )[1]
        for val_batch in val_loader
    ],
    dim=0,
)


In [None]:
# TODO: fix with autoregressive prediction
# from torchmetrics import R2Score

# R2Score()(y_pred.reshape(-1).detach(), y_true.reshape(-1).detach()).item()