In [None]:
from pathlib import Path

import sys

sys.path.append(str(Path("../src").resolve()))

import torch
import torch.nn.functional
from torch.utils.data import DataLoader
from st_encoder_decoder import (
    SpatioTemporalModel,
)
from dataset import SSTDataset

## Read data

In [None]:
data_folder = Path("../../data/output/")

# Sort the glob result: Path.glob yields entries in arbitrary filesystem order,
# so without sorting the debug subset below (and hence the dataset) would
# change from machine to machine and run to run.
all_daily_files = sorted((data_folder / "original").glob("*_ERA5_full_ts.nc"))

# For debugging only: keep the first N_DEBUG_FILES daily files.
# Set to None to train on every file.
N_DEBUG_FILES = 2
daily_files = all_daily_files[:N_DEBUG_FILES]

monthly_files = [data_folder / "202001_mon_ERA5_full_ts.nc"]
mask_file = data_folder / "era5_lsm_bool.nc"

In [None]:
# Define spatial subset
spatial_subset = {"lon": slice(-50, -10), "lat": slice(-40, -20)}

dataset = SSTDataset(
    daily_files=daily_files,
    monthly_files=monthly_files,
    mask_file=mask_file,
    patch_size=(16, 16),
    overlap=2,
    spatial_subset=spatial_subset,
)

In [None]:
# Shuffle with a dedicated, seeded generator so the batch order is identical on
# every Restart & Run All (an unseeded shuffle draws from the global RNG state).
BATCH_SIZE = 8
SHUFFLE_SEED = 0

dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    pin_memory=False,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

## Create the model

In [None]:
# Initialize model.
# Seed PyTorch's global RNG first so the random weight initialisation is
# reproducible across kernel restarts.
MODEL_SEED = 0
torch.manual_seed(MODEL_SEED)

device = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): patch_size/overlap presumably must agree with the SSTDataset
# settings (patch_size=(16, 16), overlap=2); the leading 1 looks like a
# time-axis patch extent — confirm against st_encoder_decoder.
model = SpatioTemporalModel(embed_dim=128, patch_size=(1, 16, 16), overlap=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# Training loop: L1 loss between the predicted and the observed monthly field,
# restricted to ocean pixels whose target value is finite.
N_EPOCHS = 10

model.train()
for epoch in range(N_EPOCHS):
    epoch_loss = 0.0
    n_batches = 0  # batches that actually contributed an optimizer step
    for batch in dataloader:
        daily_data = batch["daily"].to(device)
        monthly_target = batch["monthly"].to(device)
        land_mask = batch["mask"].to(device)

        # Missing daily observations are NaN: remember where the data is valid
        # and zero-fill the gaps so the model never receives NaN inputs.
        daily_mask = ~torch.isnan(daily_data)
        daily_data = daily_data.masked_fill(~daily_mask, 0.0)

        optimizer.zero_grad()
        pred = model(daily_data, daily_mask, land_mask)

        # NOTE(review): pred[0] is taken as the predicted map and channel 0 of
        # the monthly target as the matching field; land_mask is assumed to be a
        # boolean tensor shaped like that target slice — confirm against
        # SpatioTemporalModel / SSTDataset.
        prediction = pred[0]
        target = monthly_target[:, 0, :, :]

        # Exclude land and any non-finite targets: a single NaN in the loss
        # would propagate through backward()/step() into every weight.
        valid = ~land_mask & torch.isfinite(target)
        if not valid.any():
            # All-land / all-missing batch: l1_loss over zero elements is NaN,
            # so skip it instead of taking a corrupting optimizer step.
            continue

        loss = torch.nn.functional.l1_loss(prediction[valid], target[valid])

        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        n_batches += 1

    # Average over the batches that were actually used; avoids a
    # ZeroDivisionError when the loader is empty or every batch was skipped.
    mean_loss = epoch_loss / n_batches if n_batches else float("nan")
    print(f"Epoch {epoch}: Loss = {mean_loss:.4f}")