In [None]:
import torch

# Run everything on the CPU. To try Apple-silicon GPUs instead, use:
#   device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = "cpu"

# Newly created tensors are placed on `device` unless told otherwise; report the choice.
torch.set_default_device(device)
print(f"Using device: {device}")

In [None]:
import torch
from torch.utils.data import Dataset
import h5py

from autoemulate.core.types import TensorLike

from autoemulate.experimental.data.spatio_temporal_dataset import AutoEmulateDataset , MHDDataset

In [None]:
# Example with fusion (MHD) data

import os

from torch.utils.data import DataLoader

# Default is the original author's cluster path; set the MHD_DATASET_PATH environment
# variable to point at a local copy instead of editing this cell.
MHD_DATASET_PATH = os.environ.get(
    "MHD_DATASET_PATH",
    "/bask/homes/h/hdjd5168/vjgo8416-ai-phy-sys/marj/AE_exploratory/FNO/complete_mhd_dataset/mhd_dataset.h5",
)

dataset = MHDDataset(MHD_DATASET_PATH, t_in=2, t_out=53)

# Temporary: the simulated data is stored channels-first rather than in The Well's
# channels-last format, so permute every trajectory [T, C, H, W] -> [T, H, W, C].
for i in range(len(dataset.all_input_fields)):
    dataset.all_input_fields[i] = dataset.all_input_fields[i].permute(0, 2, 3, 1)
    dataset.all_output_fields[i] = dataset.all_output_fields[i].permute(0, 2, 3, 1)

dataloader = DataLoader(dataset, batch_size=1, shuffle=False)

# First batch, used for shape checks in the cells below.
batch_fusion = next(iter(dataloader))



In [None]:
tuple(batch_fusion[key].shape for key in ("input_fields", "output_fields"))  # (input shape, output shape)

In [None]:
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
import numpy as np

# Which tensor to animate: "input_fields" (t_in frames) or "output_fields" (t_out frames).
FIELD_KEY = "input_fields"
N_PANELS = 2  # maximum number of channels shown side by side
SEED = 0

# Seeded draw over the whole dataset (previously a hard-coded, unseeded range of 0..198).
rng = np.random.default_rng(SEED)
idx = int(rng.integers(len(dataset)))
random_sample = dataset[idx]
data = random_sample[FIELD_KEY].detach().cpu().numpy()
print(data.shape)
# Layout is [T, H, W, C] after the permute applied when the dataset was loaded.
time_steps, height, width, num_channels = data.shape

# One panel per displayed channel; never ask for more channels than exist.
n_panels = min(N_PANELS, num_channels)
fig, axes = plt.subplots(1, n_panels, figsize=(5 * n_panels, 5), squeeze=False)
axes = axes[0]

# Every panel starts at frame 0 of its own channel, matching what `animate` draws.
ims = [
    ax.imshow(data[0, :, :, ch], aspect="auto", cmap="viridis")
    for ch, ax in enumerate(axes)
]
for ch, ax in enumerate(axes):
    ax.set_title(f"sample {idx}: {FIELD_KEY}, channel {ch}")


def animate(frame):
    """Draw time step ``frame`` of each displayed channel, rescaling colours per frame."""
    for ch, im in enumerate(ims):
        frame_data = data[frame, :, :, ch]
        im.set_array(frame_data)
        im.set_clim(frame_data.min(), frame_data.max())
    return ims


anim = animation.FuncAnimation(fig, animate, frames=time_steps, interval=200, repeat=False)
plt.close(fig)  # suppress the extra static copy of the figure under the animation
HTML(anim.to_jshtml())

In [None]:
from autoemulate.experimental.emulators.fno import FNOEmulator

# FNO hyper-parameters collected in one place so they are easy to tweak and log.
# NOTE(review): the meaning of `n_vars` and of the last entry of `n_modes` (possibly a
# temporal mode count; the dataset above uses t_out=53) is defined by FNOEmulator —
# confirm these values against its documentation.
fno_params = dict(
    n_vars=1,
    n_modes=(64, 64, 1),
    hidden_channels=64,
    lr=1e-3,
    epochs=100,
)
emulator = FNOEmulator(**fno_params)
emulator.fit(dataloader, None)


In [None]:
# First channel only; layout is [batch, time, height, width, channels].
# Uses `batch_fusion`, the batch drawn from the dataloader above.
x = batch_fusion["input_fields"][..., :1]
y = batch_fusion["output_fields"][..., :1]
x.shape, y.shape


In [None]:
# Channels-first view: [batch, channels, time, height, width].
# Bound to a new name so re-running this cell does not permute `x` a second time.
x_channels_first = x.permute(0, 4, 1, 2, 3)
x_channels_first.shape


In [None]:
import torch


def prepare_batch(sample, channels=(0,), with_constants=True, with_time=False):
    """Turn a dataloader batch into channels-first ``(x, y)`` tensors for the model.

    Args:
        sample: Batch dict holding ``"input_fields"`` and ``"output_fields"`` tensors in
            ``[batch, time, height, width, channel]`` layout (the layout the permute
            below assumes). When ``with_constants`` is True it must also hold
            ``"constant_scalars"`` of shape ``[batch, n_constants]``.
        channels: Indices of the field channels to keep; the same selection is applied
            to inputs and outputs.
        with_constants: If True, broadcast every constant scalar over all time/space
            locations and append it to ``x`` as extra channels. ``y`` never gets them.
        with_time: If False, keep only the last time step of ``x`` and ``y``.

    Returns:
        ``(x, y)`` with ``x`` of shape
        ``[batch, len(channels) (+ n_constants), time, height, width]`` and ``y`` of shape
        ``[batch, len(channels), time, height, width]``. The ``time`` axis is dropped
        when ``with_time`` is False.
    """
    channel_idx = list(channels)

    # Select the requested channels, then move them in front of time/height/width.
    x = sample["input_fields"][..., channel_idx].permute(0, 4, 1, 2, 3)
    y = sample["output_fields"][..., channel_idx].permute(0, 4, 1, 2, 3)

    # Constants are appended to the input only.
    if with_constants:
        constant_scalars = sample["constant_scalars"]
        batch_size, _, time_window, height, width = x.shape
        n_constants = constant_scalars.shape[-1]
        # One value per (sample, constant), repeated over every time/space location.
        # Using the real batch size (not a hard-coded 1) keeps batch_size > 1 working;
        # matching dtype/device keeps the concatenated input homogeneous.
        c_broadcast = (
            constant_scalars.reshape(batch_size, n_constants, 1, 1, 1)
            .expand(batch_size, n_constants, time_window, height, width)
            .to(dtype=x.dtype, device=x.device)
        )
        x = torch.cat([x, c_broadcast], dim=1)

    if not with_time:
        # Last time step only, for both input and output.
        return x[:, :, -1], y[:, :, -1]
    return x, y


In [None]:
# Without time
x_with_constants, y = prepare_batch(batch, channels=(0,), with_time=True)
print(f"Concatenated x shape: {x_with_constants.shape}")
print(f"Output y shape: {y.shape}")


In [None]:
# With time
x_with_constants, y = prepare_batch(batch, channels=(0,), with_time=True)
print(f"Concatenated x shape: {x_with_constants.shape}")
print(f"Output y shape: {y.shape}")


In [None]:
prepare_batch(batch_fusion, with_time=True)[0].shape  # shape of the model input x

In [None]:

# Pass through model
# NOTE(review): `model` is not defined anywhere in this notebook, so this cell raises
# NameError on a fresh kernel. Bind it to the trained network first — TODO confirm
# which attribute of `emulator` (if any) exposes it.
model(prepare_batch(batch_fusion, with_time=True)[0]).shape
