In [None]:
import torch

# Computation device. To use Apple-silicon GPUs instead, switch to:
#     device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
device = "cpu"
torch.set_default_device(device)  # default device for every tensor created from now on
print(f"Using device: {device}")

In [None]:

from the_well.data import WellDataset

# Windows of 80 input steps followed by 1 output step from the test split.
# NOTE(review): the dataset-class notes below say active_matter trajectories
# have 81 timesteps, which would give exactly one window per trajectory — confirm.
well_config = {
    "well_base_path": "../data/the_well/datasets",
    "well_dataset_name": "active_matter",
    "well_split_name": "test",
    "n_steps_input": 80,
    "n_steps_output": 1,
}
the_well = WellDataset(**well_config)


In [None]:
# First item straight from the dataset (not collated by a DataLoader).
sample_iterator = iter(the_well)
batch = next(sample_iterator)
x = batch["input_fields"]
x.size()

In [None]:
# Toy stand-in for one trajectory: 80 timesteps on a 1x1 grid with 11 channels,
# filled with 0, 1, 2, ... so the contents of each window are easy to read off.
x = torch.arange(80 * 11).reshape(80, 1, 1, 11)
x.size()


In [None]:
# Timesteps 2..6 — what window 1 of x.unfold(0, 5, 2) should contain.
x[2 : 2 + 5]

In [None]:
# Sliding windows of 5 timesteps starting every 2 steps (use unfold(0, 5, 1) for
# stride 1). unfold appends the window axis last, so move it to position 1;
# for the 4-D toy tensor this gives [n_windows, 5, 1, 1, 11]. Window 1 should
# equal x[2:7].
windows = x.unfold(0, 5, 2)
windows.permute(0, 4, 1, 2, 3)[1]

In [None]:
import torch
from torch.utils.data import Dataset
import h5py

from autoemulate.core.types import TensorLike

class AutoEmulateDataset(Dataset):
    # Two steps:
    # 1. Move methods from the specific subclass to the new base class
    # 2. Make it look like the well (for __getitem__)
    # TODO: add to performance issue #421
    def __init__(
        self,
        data_path: str,
        n_steps_input: int,
        n_steps_output: int,
        stride: int = 1,
        input_channel_idxs: tuple[int, ...] | None = None,
        output_channel_idxs: tuple[int, ...] | None = None,
    ):
        """
        Args:
            data_path: Path to HDF5 file
            t_in: Number of input timesteps
            t_out: Number of output timesteps  
            stride: Stride between sequences
        """
        self.n_steps_input = n_steps_input
        self.n_steps_output = n_steps_output
        self.stride = stride
        self.input_channel_idxs = input_channel_idxs # (0, 1, 2): predict from 
        self.output_channel_idxs = output_channel_idxs # (0,)
        # (0,)
        # (1, 2)
        # (3, 4, 5, 6)
        
        # Load data
        with h5py.File(data_path, 'r') as f:
            # N: n_trajectories
            self.data: TensorLike = f['data'][:]  # [N, T, W, H, C]
            # TODO: since not supported into
            self.constant_scalars: TensorLike = f['constant_scalars'][:] if 'constant_scalars' in f else None
        
        # Destructured here
        self.n_trajectories, self.n_timesteps, self.width, self.height, self.n_channels = self.data.shape
    
    # TODO: is this required and what should it be
    def __len__(self):
        return self.n_trajectories
    
    def __getitem__(self, idx):
        # # Map flat index to (sample_idx, seq_idx)
        trajectory_idx = idx
        # trajectory_idx = idx // len(self)
        # trajectory_idx = idx % len(self)
        
        # seq_idx = idx % self.sequences_per_sample
        
        # # Get start timestep for this sequence
        # start_t = seq_idx * self.stride
        # (0, 1, 2, 3, 4), (5,)
        #    (0, 1, 2, 3, 4, 5,) -> (0, 1, 2, 3, 4), (5,)
        # (1, 2, 3, 4, 5), (6,)
        # data = self.data[sample_idx]
        
        # num_int_subtrajectories: (sequence_length - (n_step_inputs + n_step_outputs)) // stride) + 1
        # 81 length for active matter
        fields = (
            self.data[trajectory_idx]
            .unfold(0, self.n_steps_input + self.n_steps_output, self.stride)
            .permute(0, -1, 1, 2, 3) # [num_int_subtrajectories, T_in + T_out, W, H, C]
        )
        input_fields = fields[:, :self.n_steps_input, ...]  # [num_int_subtrajectories, T_in, W, H, C]
        output_fields = fields[:, self.n_steps_input:, ...]  # [num_int_subtrajectories, T_out, W, H, C]

        constant_scalars = self.constant_scalars[trajectory_idx] if self.constant_scalars is not None else None
        # return torch.FloatTensor(input_seq), torch.FloatTensor(output_seq)
        return {
            "input_fields": input_fields,
            "output_fields": output_fields,
            "constant_scalars": constant_scalars,
            # Keys that are not in the well
            # "some_other_field": ...
            "space_grid": ...,
            "input_time_grid": ...,
            "output_time_grid": ...,
        }
    

class MHDDataset(AutoEmulateDataset):
    """PyTorch Dataset for MHD data.

    Keeps the short ``t_in`` / ``t_out`` argument names and forwards them to
    ``AutoEmulateDataset`` as ``n_steps_input`` / ``n_steps_output``.
    """

    def __init__(self, data_path: str, t_in: int = 5, t_out: int = 10, stride: int = 1):
        super().__init__(data_path, n_steps_input=t_in, n_steps_output=t_out, stride=stride)

    # NOTE(review): this sketch is nested inside MHDDataset (reachable as
    # MHDDataset.MyFNO), which looks accidental — consider moving it to
    # module level.
    class MyFNO:
        """Placeholder model: currently returns the input fields unchanged."""

        def forward(self, batch):
            x = batch["input_fields"]
            # Batches expose "constant_scalars" (the key AutoEmulateDataset and
            # The Well use); the former "constant_params" lookup raised KeyError.
            # TODO: not used by the forward pass yet.
            constant_scalars = batch.get("constant_scalars")
            return x

In [None]:
# Shape of the input fields of the single (uncollated) sample.
batch["input_fields"].size()

In [None]:

# Inspect the metadata object attached to the dataset.
the_well.metadata

In [None]:

from torch.utils.data import DataLoader

# Collate samples into batches of one; shuffle=False (the DataLoader default)
# keeps the dataset order.
train_loader = DataLoader(the_well, batch_size=1, shuffle=False)
batch = next(iter(train_loader))  # collation adds a leading batch axis of size 1


In [None]:
batch.keys()  # keys available in each collated batch

In [None]:
# Input-field shape after DataLoader collation (leading batch axis).
batch["input_fields"].size()

In [None]:
# Presumably the time coordinate(s) of the output step(s) — confirm against The Well docs.
batch["output_time_grid"]


In [None]:
# Number of samples the dataset exposes (this cell previously held an
# unfinished `the_well.` attribute access, a SyntaxError that broke Run-All).
len(the_well)

In [None]:

# Input and output field shapes side by side.
tuple(batch[key].shape for key in ("input_fields", "output_fields"))


In [None]:

# One 2-D spatial slice: first batch element, first timestep, channel 0.
batch["input_fields"][0, 0, ..., 0].shape


In [None]:

import matplotlib.pyplot as plt

# First batch element, first input timestep, channel 0.
fig, ax = plt.subplots()
image = ax.imshow(batch["input_fields"][0, 0, :, :, 0].cpu())
fig.colorbar(image, ax=ax)
ax.set(
    title='input_fields[0, 0, :, :, 0]',
    xlabel="grid index (tensor axis 3)",
    ylabel="grid index (tensor axis 2)",
);


In [None]:
from neuralop.models import FNO

# 3-D FNO over (time, height, width), matching the channels-first layout
# [batch, channels, time, height, width] that prepare_batch(..., with_time=True)
# produces. NOTE(review): in_channels=4 presumably means 1 field channel plus
# 3 constant scalars — confirm against the dataset.
fno_config = dict(
    n_modes=(2, 16, 16),
    hidden_channels=16,
    in_channels=4,
    out_channels=1,
)
model = FNO(**fno_config)


In [None]:
from autoemulate.core.types import TensorLike
from autoemulate.emulators.base import PyTorchBackend

import torch

def prepare_batch(sample, channels=(0,), with_constants=True, with_time=False):
    """Turn a Well-style batch dict into channels-first ``(x, y)`` tensors.

    Args:
        sample: Mapping with ``"input_fields"`` / ``"output_fields"`` of shape
            ``[batch, time, height, width, channels]`` and, when
            ``with_constants`` is true, ``"constant_scalars"`` of shape
            ``[batch, n_constants]`` (``[n_constants]`` is accepted when
            ``batch == 1``).
        channels: Indices of the field channels to keep.
        with_constants: Append each constant scalar to ``x`` as an extra
            channel that is constant over time and space (never added to ``y``).
        with_time: Keep the time axis; otherwise return only the last timestep
            of ``x`` and ``y``.

    Returns:
        ``(x, y)`` where ``x`` is
        ``[batch, len(channels) (+ n_constants), time, height, width]`` and ``y``
        is ``[batch, len(channels), time, height, width]``; the time axis is
        dropped when ``with_time`` is false.
    """
    channel_idx = list(channels)
    # [batch, time, height, width, C] -> [batch, C, time, height, width]
    x = sample["input_fields"][..., channel_idx].permute(0, 4, 1, 2, 3)
    y = sample["output_fields"][..., channel_idx].permute(0, 4, 1, 2, 3)

    if with_constants:
        constant_scalars = sample["constant_scalars"]
        batch_size, _, time_window, height, width = x.shape
        n_constants = constant_scalars.shape[-1]
        # One channel per constant, broadcast over (time, height, width) for every
        # batch element (previously hard-coded to batch size 1), and matched to
        # the fields' dtype/device so the concatenation is well-typed.
        c_broadcast = (
            constant_scalars.reshape(batch_size, n_constants, 1, 1, 1)
            .expand(-1, -1, time_window, height, width)
            .to(dtype=x.dtype, device=x.device)
        )
        x = torch.cat([x, c_broadcast], dim=1)

    if not with_time:
        # Keep only the last time step for both input and output.
        return x[:, :, -1, :, :], y[:, :, -1, :, :]
    return x, y

class FNOEmulator(PyTorchBackend):
    """Autoemulate-style wrapper around a neuralop ``FNO``.

    NOTE(review): ``_fit`` uses ``self.loss_fn`` and ``self.optimizer``, which are
    not created here; this assumes ``PyTorchBackend`` provides them — confirm.
    """

    def __init__(self, x, y, *args, **kwargs):
        """Build the wrapped FNO.

        Args:
            x, y: Not used here (presumably kept for a common emulator signature).
            *args: Ignored.
            **kwargs: Forwarded to ``FNO`` (e.g. ``n_modes``, ``hidden_channels``,
                ``in_channels``, ``out_channels``).
        """
        # If PyTorchBackend is an nn.Module, Module.__init__ must run before a
        # sub-module can be assigned as an attribute.
        # NOTE(review): assumes the base initialiser takes no required args.
        super().__init__()
        self.model = FNO(**kwargs)

    def _fit(self, x: DataLoader, y: DataLoader):
        """Run one training pass over the batches yielded by ``x``.

        Previously this ignored ``x`` and iterated the notebook-global
        ``train_loader`` with the global ``model``.

        Args:
            x: DataLoader yielding Well-style batch dicts (inputs and targets
                together).
            y: Unused; targets come from the batches of ``x``.
        """
        channels = (0,)  # which field channel(s) to train on
        print_shapes = False
        self.model.train()
        for idx, batch in enumerate(x):
            # Prepare input with constants (distinct names: don't shadow x / y).
            inputs, targets = prepare_batch(
                batch, channels=channels, with_constants=True, with_time=True
            )
            preds = self.model(inputs)

            if print_shapes:
                print(inputs.shape, targets.shape, preds.shape)

            # Take the first predicted time index as the next-step prediction.
            loss = self.loss_fn(preds[:, :, :1, ...], targets)

            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            print(f"sample {idx:5d}, loss: {loss.item():.5e}")

    def forward(self, x: TensorLike):
        """Apply the wrapped FNO to ``x``."""
        return self.model(x)

    def _predict(self, x, with_grad):
        """Delegate to the base-class prediction unchanged."""
        return super()._predict(x, with_grad)


In [None]:
# Keep only channel 0 (slice, so the channel axis survives with size 1):
# [batch, time, height, width, 1]
x = batch["input_fields"][..., :1]
y = batch["output_fields"][..., :1]
x.shape, y.shape


In [None]:
# Channels-first layout [batch, channels, time, height, width]. Stored under a
# new name so re-running this cell does not permute `x` a second time.
x_channels_first = x.permute(0, 4, 1, 2, 3)
x_channels_first.shape


In [None]:
import torch

def prepare_batch(sample, channels=(0,), with_constants=True, with_time=False):
    """Turn a Well-style batch dict into channels-first ``(x, y)`` tensors.

    Args:
        sample: Mapping with ``"input_fields"`` / ``"output_fields"`` of shape
            ``[batch, time, height, width, channels]`` and, when
            ``with_constants`` is true, ``"constant_scalars"`` of shape
            ``[batch, n_constants]`` (``[n_constants]`` is accepted when
            ``batch == 1``).
        channels: Indices of the field channels to keep.
        with_constants: Append each constant scalar to ``x`` as an extra
            channel that is constant over time and space (never added to ``y``).
        with_time: Keep the time axis; otherwise return only the last timestep
            of ``x`` and ``y``.

    Returns:
        ``(x, y)`` where ``x`` is
        ``[batch, len(channels) (+ n_constants), time, height, width]`` and ``y``
        is ``[batch, len(channels), time, height, width]``; the time axis is
        dropped when ``with_time`` is false.
    """
    channel_idx = list(channels)
    # [batch, time, height, width, C] -> [batch, C, time, height, width]
    x = sample["input_fields"][..., channel_idx].permute(0, 4, 1, 2, 3)
    y = sample["output_fields"][..., channel_idx].permute(0, 4, 1, 2, 3)

    if with_constants:
        constant_scalars = sample["constant_scalars"]
        batch_size, _, time_window, height, width = x.shape
        n_constants = constant_scalars.shape[-1]
        # One channel per constant, broadcast over (time, height, width) for every
        # batch element (previously hard-coded to batch size 1), and matched to
        # the fields' dtype/device so the concatenation is well-typed.
        c_broadcast = (
            constant_scalars.reshape(batch_size, n_constants, 1, 1, 1)
            .expand(-1, -1, time_window, height, width)
            .to(dtype=x.dtype, device=x.device)
        )
        x = torch.cat([x, c_broadcast], dim=1)

    if not with_time:
        # Keep only the last time step for both input and output.
        return x[:, :, -1, :, :], y[:, :, -1, :, :]
    return x, y


In [None]:
# Without time
x_with_constants, y = prepare_batch(batch, channels=(0,), with_time=False)
print(f"Concatenated x shape: {x_with_constants.shape}")
print(f"Output y shape: {y.shape}")


In [None]:
# With time
x_with_constants, y = prepare_batch(batch, channels=(0,), with_time=True)
# Same two lines of output as two separate print calls.
print(
    f"Concatenated x shape: {x_with_constants.shape}\n"
    f"Output y shape: {y.shape}"
)


In [None]:
# Shape of the model input (fields + broadcast constants, time kept).
prepare_batch(batch, with_time=True)[0].size()

In [None]:

# Pass through model
# Forward one prepared batch through the (untrained) model to check the output shape.
fno_input = prepare_batch(batch, with_time=True)[0]
model(fno_input).shape


In [None]:
from torch.optim import AdamW
from torch.nn import MSELoss

# Training settings — tune here rather than inside the loop.
SEED = 0
LEARNING_RATE = 1e-2
LAST_SAMPLE_IDX = 100  # quick smoke test: stop after this index (101 optimisation steps)
channels = (0,)  # which field channel(s) to train on
print_shapes = False

torch.manual_seed(SEED)  # reproducible model initialisation

# Create new model
model = FNO(
    n_modes=(2, 16, 16),
    hidden_channels=16,
    in_channels=4,
    out_channels=1,
).to(device)
model.train()

# Explicitly set shuffle=False to ensure monotonic ordering
train_loader = DataLoader(the_well, batch_size=1, shuffle=False)

optimizer = AdamW(
    model.parameters(),
    lr=LEARNING_RATE,
    # weight_decay=1e-4
)

loss_fn = MSELoss().to(device)
for idx, batch in enumerate(train_loader):
    # Prepare input with constants
    x, y = prepare_batch(batch, channels=channels, with_constants=True, with_time=True)
    # Batches arrive wherever the dataset stored them; the model lives on `device`.
    x, y = x.to(device), y.to(device)

    # Predictions
    y_pred = model(x)

    if print_shapes:
        print(x.shape, y.shape, y_pred.shape)

    # Take the first predicted time index as the next time step prediction
    loss = loss_fn(y_pred[:, :, :1, ...], y)

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"sample {idx:5d}, loss: {loss.item():.5e}")

    if idx >= LAST_SAMPLE_IDX:
        break