# Demo notebook: Regrid sea surface temperature data
The data is originally on a tripolar ocean-model grid, which keeps the grid smooth (free of a singularity) over the North Pole.  
We regrid it to a Discrete Global Grid System (DGGS) via an intermediate standard latitude-longitude grid.  

## 1. Load required libraries

In [None]:
# Install the xarray-healpy and xdggs libraries, used below for regridding to and
# exploring the Discrete Global Grid System (DGGS), straight from GitHub.
# NOTE(review): these are unpinned git installs; pin a tag/commit for reproducible runs.
%pip install git+https://github.com/IAOCEA/xarray-healpy.git git+https://github.com/xarray-contrib/xdggs.git

In [None]:
import warnings
from pathlib import Path

import cartopy.crs as ccrs  # Map projections
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import s3fs
import xarray as xr  # N-dimensional arrays with dimension, coordinate and attribute labels
from data_handling import load_grid_vertex, regrid_to_dggs, standardize_variable_names

# Silence DeprecationWarnings, then configure xarray globally: keep attributes
# through operations and collapse data/attribute sections in reprs.
warnings.filterwarnings("ignore", category=DeprecationWarning)
xr.set_options(
    keep_attrs=True,
    display_expand_attrs=False,
    display_expand_data=False,
)

## 2. Load sea surface temperature data on the tripolar grid

In [None]:
# Remote S3-compatible object store (anonymous read access) holding the case-study data
endpoint_url = "https://server-data.fair2adapt.sigma2.no"
tripolar_grid_data_path = "s3://CS1/data/model/JRAOC20TRNRPv2_hm_sst_2010-01.nc"

# The custom endpoint has to be forwarded to the underlying S3 client
client_kwargs = dict(endpoint_url=endpoint_url)
s3 = s3fs.S3FileSystem(client_kwargs=client_kwargs, anon=True)

# Open the SST file on its native tripolar grid. The file object is left open
# because xarray reads from it lazily. Opening typically takes a few minutes.
ds = xr.open_dataset(s3.open(tripolar_grid_data_path))

# Show the full dataset
ds

In [None]:
# Local alternative to the S3 cell above: uncomment these lines to read the same
# tripolar SST file from a local copy of the data instead of the object store.
# data_path = Path("./CS1-nird/data/")
# tripolar_grid_data_path = data_path / "model" / "JRAOC20TRNRPv2_hm_sst_2010-01.nc"
# ds = xr.open_dataset(tripolar_grid_data_path)

In [None]:
# Local data root. The grid below is read from S3, but the regridded files used
# in later cells are loaded from this local directory.
data_path = Path("./CS1-nird/data/")

# Geometry of the tripolar model grid. To use a local copy instead:
#     plat, plon, pclat, pclon = load_grid_vertex(data_path / "grid" / "grid.nc")
grid_file_path = "s3://CS1/data/grid/grid.nc"
plat, plon, pclat, pclon = load_grid_vertex(s3.open(grid_file_path))

# Attach the 2-D latitude/longitude arrays as auxiliary coordinates on (y, x)
# (coordinate variables, not dimensions), then normalise variable names.
ds = standardize_variable_names(
    ds.assign_coords(lat=(("y", "x"), plat), lon=(("y", "x"), plon))
)

# Wrap longitudes into [-180, 180) so the data is centred on the Greenwich meridian
ds.coords["longitude"] = (ds.coords["longitude"] + 180) % 360 - 180

ds

In [None]:
# Near-polar view of the native tripolar SST field (first time step), drawn from
# the 2-D latitude/longitude arrays returned by load_grid_vertex.
proj = ccrs.NearsidePerspective(
    central_longitude=0.0, central_latitude=80.0, satellite_height=3e6
)
fig, ax = plt.subplots(1, figsize=(8, 4.5), dpi=96, subplot_kw={"projection": proj})

# A temperature map. plon/plat are the un-wrapped grid longitudes/latitudes
# (not the [-180, 180) values assigned to ds above).
pm0 = ax.pcolormesh(
    plon,
    plat,
    ds.sst[0, :, :],
    vmin=-3,
    vmax=20,
    cmap="viridis",
    transform=ccrs.PlateCarree(),
    shading="auto",
    rasterized=True,
)

# Add coastlines, a background image and latitude gridlines every 10° from 40°N.
# The locator is passed directly rather than first setting ylocs=range(15, 76, 15)
# and then overriding it.
ax.coastlines(resolution="50m", color="black", linewidth=0.5)
ax.stock_img()
gl = ax.gridlines(
    ylocs=mpl.ticker.FixedLocator([40, 50, 60, 70, 80]), draw_labels=True
)

fig.colorbar(pm0, ax=ax, fraction=0.2, shrink=0.4, label="degC")

ax.set_title("Sea Surface Temperature")
plt.show()

## 3. Load the dataset regridded to a PlateCarree (regular latitude-longitude) grid

In [None]:
# Bilinearly regridded SST on a regular lat-lon (PlateCarree) grid. It is produced
# from the tripolar file by regrid_tripolar_to_platecarree.sh and read from local disk.
bilinear_regridded_data_path = data_path.joinpath(
    "model", "JRAOC20TRNRPv2_hm_sst_2010-01_bil.nc"
)
dr = xr.open_dataset(bilinear_regridded_data_path)

In [None]:
# Rename lat/lon to latitude/longitude. A single ``rename`` renames each dimension
# together with its coordinate variable, so both stay indexed dimension coordinates
# and the attribute assignments below land on the real coordinates. With
# ``rename_dims`` first, ``dr.latitude`` would resolve to a temporary integer range
# and the attributes would be lost.
dr = dr.rename({"lat": "latitude", "lon": "longitude"})
dr.latitude.attrs["standard_name"] = "latitude"
dr.longitude.attrs["standard_name"] = "longitude"

In [None]:
# Quick-look map of the first time step on the regular lat-lon grid
dr["sst"].isel(time=0).plot()

In [None]:
# Near-polar view of the bilinearly regridded SST (first time step), plotted on
# the same projection as the tripolar map for a side-by-side comparison.
proj = ccrs.NearsidePerspective(
    central_longitude=0.0, central_latitude=80.0, satellite_height=3e6
)
fig, ax = plt.subplots(1, figsize=(8, 4.5), dpi=96, subplot_kw={"projection": proj})

# A temperature map from the regular grid's longitude/latitude coordinates
pm0 = ax.pcolormesh(
    dr.longitude,
    dr.latitude,
    dr.sst[0, :, :],
    vmin=-3,
    vmax=20,
    cmap="viridis",
    transform=ccrs.PlateCarree(),
    shading="auto",
    rasterized=True,
)

# Add coastlines, a background image and latitude gridlines every 10° from 40°N.
# The locator is passed directly rather than first setting ylocs=range(15, 76, 15)
# and then overriding it.
ax.coastlines(resolution="50m", color="black", linewidth=0.5)
ax.stock_img()
gl = ax.gridlines(
    ylocs=mpl.ticker.FixedLocator([40, 50, 60, 70, 80]), draw_labels=True
)

fig.colorbar(pm0, ax=ax, fraction=0.2, shrink=0.4, label="degC")

ax.set_title("Sea Surface Temperature")
plt.show()

In [None]:
# Have a closer look at a region of interest (40–65°N, 15°W–30°E)
lat_min, lat_max = 40, 65
lon_min, lon_max = -15, 30

fig, ax = plt.subplots(figsize=(6, 4))
dr.sst.isel(time=0).plot(ax=ax)

# Zoom the axes to the region; the colour scale still reflects the global field.
# NOTE(review): assumes dr.longitude spans [-180, 180). If it is [0, 360), the
# part of the window west of 0° will be empty. Confirm.
ax.set(xlim=(lon_min, lon_max), ylim=(lat_min, lat_max))

plt.show()

In [None]:
# Boolean mask from the first time step: True where SST is defined (ocean),
# False where it is missing (land)
ocean_mask = dr.sst.isel(time=0).notnull()

### Load the conservatively regridded dataset and compare it with the bilinear one

In [None]:
# The same SST field remapped conservatively onto the PlateCarree grid
# (also produced by regrid_tripolar_to_platecarree.sh)
conservative_regridded_dataset_path = (
    data_path / "model" / "JRAOC20TRNRPv2_hm_sst_2010-01_con.nc"
)
dcon = xr.open_dataset(conservative_regridded_dataset_path)

# Rename dimensions and their coordinate variables in one step so they remain
# indexed dimension coordinates and the attributes below are actually stored.
# (``rename_dims`` first makes ``dcon.latitude`` a temporary integer range.)
dcon = dcon.rename({"lat": "latitude", "lon": "longitude"})
dcon.latitude.attrs["standard_name"] = "latitude"
dcon.longitude.attrs["standard_name"] = "longitude"

# Bilinear minus conservative: shows where the choice of regridding method matters
regrid_diff = dr - dcon

In [None]:
# Map of (bilinear - conservative) SST for the first time step
regrid_diff["sst"].isel(time=0).plot()

## 4. Regrid from PlateCarree to a HEALPix DGGS

In [None]:
# HEALPix resolution: each side of the 12 base faces is divided into nside parts
nside = 256
# The HEALPix level is log2(nside), which is only defined for powers of two.
# Fail loudly instead of letting int(np.log2(...)) silently truncate other values.
if nside < 1 or nside & (nside - 1):
    raise ValueError(f"nside must be a positive power of two, got {nside}")
healpy_grid_level = int(np.log2(nside))  # Healpix level
number_of_cells = 12 * nside**2  # The resulting total number of cells

# Minimum number of vertices required for a valid mapping when regridding.
# 1 is the most permissive: a single vertex suffices.
min_vertices = 2

print("nside:", nside)
print("Level:", healpy_grid_level)
print("Number of cells:", number_of_cells)

# NOTE(review): bilinear interpolation is applied to the *conservatively* regridded
# dataset, with a land mask derived from the bilinear one. Confirm this pairing is intended.
regridded = regrid_to_dggs(
    dcon, nside, min_vertices, method="bilinear", mask=ocean_mask
)
ds_regridded = regridded.sst.compute().squeeze()

In [None]:
# Interactive map of the regridded field via the xdggs ``.dggs`` accessor.
# NOTE(review): the accessor only exists once xdggs has been imported, and it needs
# the grid metadata xdggs expects. This notebook does not import xdggs directly, so
# presumably data_handling / regrid_to_dggs handle both. Confirm.
ds_regridded.dggs.explore()

## 5. Save the regridded data to Zarr

In [None]:
# Write the HEALPix SST to a Zarr store named after its refinement level.
# mode="w" overwrites any existing store at that path.
save_location = data_path.joinpath(f"SST-healpix-lvl-{healpy_grid_level}.zarr")
ds_regridded.to_zarr(save_location, mode="w")