In [None]:
from pathlib import Path
import xarray as xr
import numpy as np

import sys

sys.path.append(str(Path("../src").resolve()))

import torch
import torch.nn.functional
from torch.utils.data import DataLoader
from st_encoder_decoder import (
    SpatioTemporalModel,
)
from dataset import SSTDataset

## Read data

In [None]:
# Folder holding the pre-processed ERA5 files used by this example.
data_folder = Path("../../data/output/")

# Daily (masked) and monthly (full) surface-temperature datasets for 2020-01.
daily_data = xr.open_dataset(data_folder / "202001_day_ERA5_masked_ts.nc")
monthly_data = xr.open_dataset(data_folder / "202001_mon_ERA5_full_ts.nc")

# Land-sea mask: downloaded from ERA5 and regridded using the function
# `regrid_to_boundary_centered_grid`. Convention: land 1, ocean 0.
# Only the first time step of the "lsm" variable is kept.
lsm_mask = xr.open_dataset(data_folder / "era5_lsm_bool.nc")["lsm"].isel(time=0)

## Subset data (for fast example)

In [None]:
# Longitude/latitude window of the example region (coordinate slices).
lon_subset = slice(-50, -30)
lat_subset = slice(-40, -20)

# Apply the same coordinate selection to the daily data, the monthly data
# and the land-sea mask so that all three cover the same region.
region = {"lon": lon_subset, "lat": lat_subset}
daily_subset = daily_data.sel(**region)
monthly_subset = monthly_data.sel(**region)
lsm_subset = lsm_mask.sel(**region)

In [None]:
# Show the daily subset as the cell output to inspect its contents.
daily_subset

In [None]:
# Show the monthly subset as the cell output to inspect its contents.
monthly_subset

## Create the model

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Build the spatio-temporal model and move it to the selected device.
model = SpatioTemporalModel(
    embed_dim=128,
    patch_size=(1, 16, 16),
    overlap=2,
).to(device)

# Adam optimiser over all model parameters.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# Keep a handle on the decoder; its scale/bias are set before training.
decoder = model.decoder

In [None]:
# Build the patch dataset from the regional subsets.
# The land-sea mask must cover the same region as the daily/monthly inputs,
# so the subsetted mask (`lsm_subset`) is used rather than the full-domain
# `lsm_mask`; this also matches the mask used when inspecting the results.
dataset = SSTDataset(
    daily_ds=daily_subset,
    monthly_ds=monthly_subset,
    mask_da=lsm_subset,
    patch_size=(16, 16),
    overlap=2,
)

## Train the model

In [None]:
# Batches of 32 samples, reshuffled every epoch; no pinned host memory.
loader_options = {
    "batch_size": 32,
    "shuffle": True,
    "pin_memory": False,
}
dataloader = DataLoader(dataset, **loader_options)

In [None]:
# Early-stopping state
best_loss = float("inf")
patience = 10  # stop if no improvement for <patience> epochs
counter = 0  # number of consecutive epochs without improvement

# Set decoder scale/bias from subset monthly stats (NaN-aware mean and std)
monthly_subset_values = monthly_subset["ts"].data
monthly_mean = float(np.nanmean(monthly_subset_values))
monthly_std = float(np.nanstd(monthly_subset_values))
with torch.no_grad():  # plain in-place fill; must not be tracked by autograd
    decoder.bias.fill_(monthly_mean)
    decoder.scale.fill_(monthly_std + 1e-6)  # small epsilon to avoid zero

# Training loop with DataLoader
model.train()
for epoch in range(201):
    for batch in dataloader:
        # Get batch data
        daily_batch = batch["daily"].to(device)
        monthly_target = batch["monthly"].to(device)
        land_mask = batch["mask"].to(device)

        # Initialize gradients
        optimizer.zero_grad()

        # Batch prediction: missing daily values are flagged in `daily_mask`
        # and replaced by 0.0 before being fed to the model.
        daily_mask = ~torch.isnan(daily_batch)
        daily_batch[~daily_mask] = 0.0
        pred = model(daily_batch, daily_mask, land_mask)

        # Compute loss on ocean points only (the complement of the land mask)
        # NOTE(review): `pred[0]` is indexed with a mask derived from the
        # whole batch; confirm the layout of the model output against
        # `SpatioTemporalModel` for batch sizes larger than one.
        ocean = ~land_mask
        loss = torch.nn.functional.l1_loss(
            pred[0][ocean], monthly_target[:, 0, :, :][ocean]
        )  # For SST there is only one channel
        loss.backward()
        optimizer.step()

    # Early stopping check.
    # `loss` is the loss of the last batch processed in this epoch.
    # The counter is tied to the improvement test itself: it is reset when the
    # loss improves and incremented otherwise, independently of the logging
    # condition below.
    current_loss = loss.item()
    if current_loss < best_loss:
        best_loss = current_loss
        counter = 0  # reset counter if improved
    else:
        counter += 1

    # Periodic progress report
    if epoch % 20 == 0:
        print(f"The loss is {best_loss} at epoch {epoch}")

    if counter >= patience:
        print(
            f"No improvement for {patience} epochs, stopping early at epoch {epoch}."
        )
        break

print("training done!")
print(loss.item())

## Inspect results and compare

In [None]:
# Build model inputs for the whole subset as a single sample.
daily_data_input = torch.tensor(
    daily_subset["ts"].data[np.newaxis, np.newaxis, ...]
)  # [batch, band, ...]

# Flag valid (non-NaN) daily values and zero-fill the missing ones.
daily_mask_input = ~torch.isnan(daily_data_input)
daily_data_input[~daily_mask_input] = 0.0

# NOTE(review): the training loop passes the land mask as the third model
# argument, whereas the inverted land-sea mask is passed here; confirm which
# convention `SpatioTemporalModel` expects.
ocean_mask_subset_input = ~torch.tensor(
    lsm_subset.data[np.newaxis, ...]
)  # [batch, ...]

# Inference: switch to evaluation mode, disable gradient tracking, and feed
# copies of the inputs that live on the same device as the model. The CPU
# tensors above are left untouched so they can still be plotted directly.
model.eval()
with torch.no_grad():
    monthly_prediction = model(
        daily_data_input.to(device),
        daily_mask_input.to(device),
        ocean_mask_subset_input.to(device),
    )

In [None]:
from matplotlib import pyplot as plt

# Quick look at the model input: drop the singleton dimensions and average
# over the first remaining axis (presumably time; verify against the data).
# Missing values were replaced by 0.0 beforehand, so they enter this mean.
daily_input_mean = daily_data_input.squeeze().mean(axis=0)
plt.imshow(daily_input_mean)
plt.colorbar()

In [None]:
# Store the prediction next to the target in the monthly dataset.
# Index 0 removes the batch dimension; the tensor is moved to the CPU and
# detached before being converted to a NumPy array.
prediction_values = monthly_prediction[0].cpu().detach().numpy()
prediction_dims = ("time", "lat", "lon")
monthly_subset["ts_pred"] = (prediction_dims, prediction_values)

In [None]:
# Plot the predicted monthly field stored in the dataset.
monthly_subset["ts_pred"].plot()

In [None]:
# Original monthly target, keeping only the points where the land-sea mask
# is False (ocean); the remaining points are masked out.
ocean_target = monthly_subset["ts"].where(~lsm_subset.values)
ocean_target.plot()

In [None]:
# Relative absolute error map: the point-wise absolute difference between
# the ocean-only target and the prediction, divided by the standard deviation
# of the target (computed over all points, skipping NaNs).
target = monthly_subset["ts"].where(~lsm_subset.values)
absolute_error = abs(target - monthly_subset["ts_pred"])
target_spread = target.std(skipna=True)
rmae = absolute_error / target_spread
rmae.plot()