In [None]:
from the_well.data import WellDataset

# Dataset options: request a single input step and a single output step per sample.
# NOTE(review): this loads the "test" split, yet a later cell fits the model on
# it — confirm that training on this split is intentional.
dataset_options = dict(
    well_base_path="../data/the_well/datasets",
    well_dataset_name="active_matter",
    well_split_name="test",
    n_steps_input=1,
    n_steps_output=1,
)
the_well = WellDataset(**dataset_options)

In [None]:
from torch.utils.data import DataLoader

# Wrap the dataset in a loader that yields one sample per batch.
train_loader = DataLoader(dataset=the_well, batch_size=1)

# Pull the first batch so the following cells can inspect it.
batch_iterator = iter(train_loader)
batch = next(batch_iterator)


In [None]:
# Show the shapes of the input and target tensors side by side.
tuple(batch[key].shape for key in ("input_fields", "output_fields"))


In [None]:

# Shape of one 2-D slice: first sample, first time step, first field channel.
batch["input_fields"][0, 0][:, :, 0].shape


In [None]:

import matplotlib.pyplot as plt

# Render the first sample / first time step / first channel as an image.
plt.imshow(batch["input_fields"][0, 0][:, :, 0])


In [None]:
from neuralop.models import FNO, FNO2d, FNO3d


# TODO: open question — what happens with more than one input channel, and how
# does the model combine them?
# NOTE(review): in_channels=4 presumably equals one field channel plus the
# constant scalars appended during batch preparation — confirm against the data.
fno_config = dict(
    n_modes_height=16,
    n_modes_width=16,
    hidden_channels=16,
    in_channels=4,
    out_channels=1,
)
model = FNO2d(**fno_config)

In [None]:
# Keep only the first field channel of both tensors. A length-1 slice (rather
# than an integer index) keeps the trailing channel axis in place.
# Layout: [batch, time, height, width, channels]
x, y = (
    batch[key][:, :, :, :, :1]
    for key in ("input_fields", "output_fields")
)
x.shape, y.shape


In [None]:
# Build the channels-first tensor straight from `batch` rather than from `x`
# itself, so re-running this cell cannot permute an already-permuted tensor.
# [batch, time, height, width, channels] -> [batch, channels, time, height, width]
x = batch["input_fields"][:, :, :, :, :1].permute(0, 4, 1, 2, 3)
x.shape


In [None]:
import torch

def prepare_batch(sample, channels = (0,), with_constants=True, with_time=False):
    """Turn a dataset batch into channels-first ``(x, y)`` tensors.

    Args:
        sample: Mapping holding ``"input_fields"`` and ``"output_fields"`` tensors
            laid out as [batch, time, height, width, channels]. The
            ``"constant_scalars"`` entry ([batch, n_constants]) is only read when
            ``with_constants`` is true.
        channels: Indices of the field channels to keep, applied to both the
            input and the output fields.
        with_constants: If true, every constant scalar is broadcast over
            time/height/width and appended to ``x`` as an extra channel.
            ``y`` never receives the constants.
        with_time: If false, only the last time step is returned and the time
            axis is dropped. If true, the time axis is kept.

    Returns:
        Tuple ``(x, y)``. With ``with_time=False`` the shapes are
        [batch, channels, height, width]; with ``with_time=True`` they are
        [batch, channels, time, height, width]. ``x`` has
        ``len(channels) + n_constants`` channels when ``with_constants`` is true.
    """
    # Select the requested field channels: [batch, time, height, width, len(channels)]
    x = sample["input_fields"][:, :, :, :, channels]
    y = sample["output_fields"][:, :, :, :, channels]

    # Move channels in front of the time/space axes:
    # [batch, len(channels), time, height, width]
    x = x.permute(0, 4, 1, 2, 3)
    y = y.permute(0, 4, 1, 2, 3)

    # Only the input is augmented with constants, never the target.
    if with_constants:
        # Read lazily so samples without constants work when with_constants=False.
        constant_scalars = sample["constant_scalars"]
        batch_size, _, time_window, height, width = x.shape
        n_constants = constant_scalars.shape[-1]

        # Use the real batch size instead of a hard-coded 1, so each sample keeps
        # its own constants. The leading -1 also accepts an unbatched
        # [n_constants] vector, which is then shared across the batch.
        c_broadcast = constant_scalars.reshape(-1, n_constants, 1, 1, 1).expand(
            batch_size, n_constants, time_window, height, width
        )

        # Append the constants as extra channels.
        x = torch.cat([x, c_broadcast], dim=1)

    if not with_time:
        # Keep only the last time step of both input and output.
        return x[:, :, -1, :, :], y[:, :, -1, :, :]
    # Otherwise keep the time axis.
    return x, y


In [None]:
# Default path (with_time=False): only the last time step is returned.
x_with_constants, y = prepare_batch(batch, channels=(0,))
print(
    f"Concatenated x shape: {x_with_constants.shape}",
    f"Output y shape: {y.shape}",
    sep="\n",
)


In [None]:
# with_time=True: the time axis is kept in both tensors.
x_with_constants, y = prepare_batch(batch, channels=(0,), with_time=True)
print(
    f"Concatenated x shape: {x_with_constants.shape}",
    f"Output y shape: {y.shape}",
    sep="\n",
)


In [None]:
# Smoke test: run one prepared input through the model and inspect the output shape.
model_input = prepare_batch(batch)[0]
model(model_input).shape


In [None]:
from torch.optim import AdamW
from torch.nn import MSELoss

# Hyperparameters for this short trial run.
LEARNING_RATE = 8e-3
WEIGHT_DECAY = 1e-4
LAST_SAMPLE_INDEX = 100  # the loop stops once this sample index has been processed

# shuffle=False is passed explicitly so samples arrive in dataset order.
train_loader = DataLoader(the_well, shuffle=False, batch_size=1)

optimizer = AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
loss_fn = MSELoss()
channels = (0,)  # field channel(s) used as both model input and target

for idx, sample in enumerate(train_loader):
    # Channels-first input (constants appended) and target, last time step only.
    x, y = prepare_batch(sample, channels=channels, with_constants=True, with_time=False)

    # Forward pass and loss.
    y_pred = model(x)
    loss = loss_fn(y_pred, y)

    # Backward pass, parameter update, then clear gradients for the next step.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

    print(f"sample {idx:5d}, loss: {loss.item():.5e}")

    # Early exit: this is a quick test run, not a full pass over the data.
    if idx >= LAST_SAMPLE_INDEX:
        break