In [None]:
using DrWatson
@quickactivate "sst-mot-analysis"

import DrWatson: datadir, srcdir
import DrWatson: @quickactivate
include(srcdir("compare_random_samples.jl"))

using GH19, TMI, PythonCall, Revise, CondaPkg
using CSV, DataFrames, NCDatasets

const cfxr = pyimport("cf_xarray");

const xr = pyimport("xarray");
const xe = pyimport("xesmf");
const np = pyimport("numpy")                                


In [None]:
# Configure LGM dataset
# Load the TMI setup for version "LGM_90x45x33_G14". Of the outputs, `γ_lgm` is
# passed to `surfaceindex` and `TMIfile_lgm` is the NetCDF file later opened with
# xarray; the other outputs are not used in the cells shown here.
TMIversion_lgm = "LGM_90x45x33_G14"
A_lgm, Alu_lgm, γ_lgm, TMIfile_lgm, L_lgm, B_lgm = config(TMIversion_lgm);

In [None]:
"""
    add_boundaries(ds)

Return a Python xarray `Dataset` built from `ds` with cell bounds and cell-corner
coordinates for `lon` and `lat` added.

`ds` may be an `xarray.DataArray`, which is converted with `to_dataset()` and so must
be named, or an `xarray.Dataset`. Bounds are created with cf_xarray's `add_bounds`
(stored as `lon_bounds` / `lat_bounds`). They are then converted to vertices with
`cf_xarray.bounds_to_vertices` and stored as `lon_b` / `lat_b`, the corner names
xESMF looks for.
"""
function add_boundaries(ds)
    # Decide explicitly instead of via try/catch, so a genuine failure (e.g. an
    # unnamed DataArray or a missing lon/lat coordinate) surfaces with its own error.
    dset = pyisinstance(ds, xr.DataArray) ? ds.to_dataset() : ds
    dset = dset.cf.add_bounds(["lon", "lat"])

    for name in ("lon", "lat")
        bounds_dim = dset.cf.get_bounds_dim_name(name)
        dset["$(name)_b"] = cfxr.bounds_to_vertices(
            dset["$(name)_bounds"], bounds_dim, order=nothing,
        )
    end
    return dset
end


In [None]:


"""
    add_boundaries_2d(ds)

Attach cell-corner coordinates `lon_b` and `lat_b` to the Python xarray object `ds`,
whose `lon`/`lat` coordinates are 2-D arrays of cell centres with shape `(ny, nx)`.

Each corner is the mean of the four cell centres around it. The centre arrays are
first padded to `(ny+2, nx+2)`:
- the X axis (second dimension) is treated as periodic and wrapped;
- the Y axis (first dimension) is extended by repeating its edge rows.

Longitudes are averaged as offsets wrapped into `[-180, 180)`, so corners on the
longitude seam are not averaged to the opposite side of the globe.

Returns `ds.assign_coords(...)` with `lon_b` / `lat_b` on dims `("y_b", "x_b")`,
shape `(ny+1, nx+1)`.
"""
function add_boundaries_2d(ds)
    lon = ds.lon.values    # (ny, nx)
    lat = ds.lat.values

    lon_ext = _pad_xwrap_yedge(lon)
    lat_ext = _pad_xwrap_yedge(lat)

    # Longitude: average the offsets of the four neighbours relative to one of them,
    # each offset wrapped into [-180, 180), then add that reference back. Away from
    # the seam this equals the plain four-point mean.
    lon_ref, lon_2, lon_3, lon_4 = _corner_neighbours(lon_ext)
    wrapped_offset(x) = np.mod(x - lon_ref + 180, 360) - 180
    lon_b = lon_ref + 0.25 * (wrapped_offset(lon_2) + wrapped_offset(lon_3) + wrapped_offset(lon_4))

    # Latitude has no seam: plain four-point mean.
    lat_1, lat_2, lat_3, lat_4 = _corner_neighbours(lat_ext)
    lat_b = 0.25 * (lat_1 + lat_2 + lat_3 + lat_4)

    return ds.assign_coords(
        lon_b = (("y_b", "x_b"), lon_b),
        lat_b = (("y_b", "x_b"), lat_b),
    )
end

# Pad a (ny, nx) NumPy array to (ny+2, nx+2): wrap along X (periodic), repeat edge rows along Y.
function _pad_xwrap_yedge(a)
    a = np.pad(a, ((0, 0), (1, 1)), mode = "wrap")
    return np.pad(a, ((1, 1), (0, 0)), mode = "edge")
end

# Return the four (ny+1, nx+1) slices of a padded array that surround each corner.
function _corner_neighbours(a)
    head = pyslice(nothing, -1)   # Python a[:-1]
    tail = pyslice(1, nothing)    # Python a[1:]
    return (a[head, head], a[tail, head], a[head, tail], a[tail, tail])
end


In [None]:
# Surface θ on the LGM TMI grid, with a 0/1 ocean mask attached, prepared for xESMF.
# NOTE(review): assumes `surfaceindex(γ_lgm)` returns the 1-based depth level of the
# surface; the `- 1` converts it to the 0-based position `isel` expects — confirm.
surfidx = surfaceindex(γ_lgm)
TMI_lgm_theta = xr.open_dataset(TMIfile_lgm)["θ"].isel(depth = surfidx - 1).drop_vars("depth")
# `1 * θ` gives a DataArray with θ's coordinates; its values are overwritten below.
mask = 1 * TMI_lgm_theta

# 1 where surface θ is defined (ocean), 0 where it is NaN.
mask.values = xr.where(~np.isnan(TMI_lgm_theta), 1, 0)
# On a DataArray, `da["mask"] = ...` stores the mask as a coordinate named "mask".
TMI_lgm_theta["mask"] = mask;
# Converts to a Dataset and adds lon/lat bounds plus lon_b/lat_b corner coordinates.
TMI_lgm_theta = add_boundaries(TMI_lgm_theta);

In [None]:
 # Exploratory: display xESMF's global 1°×1° reference grid to inspect its layout.
 xe.util.grid_global(1, 1)

In [None]:
# Regrid the LGM data-assimilation SST conservatively onto a global 0.5°×0.5° grid.
lgm_DA = xr.open_dataset(datadir("lgmDA_lgm_Ocn_annual.nc"))
# 1 where SST is defined, 0 where it is NaN; xESMF uses a "mask" variable on the
# source dataset to exclude masked cells.
# NOTE(review): assumes `sst` is 2-D (no time dimension), so the mask is 2-D — confirm.
lgm_DA["mask"] = xr.where(~np.isnan(lgm_DA["sst"]), 1, 0)

# Adds lon/lat bounds and lon_b/lat_b corners (needed by cell_area and conservative regridding).
lgm_DA = add_boundaries(lgm_DA)

ds_out = xe.util.grid_global(0.5, 0.5)

# Cell areas in km², since earth_radius is given in km.
area_in  = xe.util.cell_area(lgm_DA,  earth_radius=6371.0)       # source grid
area_out = xe.util.cell_area(ds_out,  earth_radius=6371.0)  # 0.5° target grid

# "conservative_normed" renormalises by the unmasked overlap fraction, so target
# cells partly over masked (land) source cells are not diluted toward zero.
regridder = xe.Regridder(lgm_DA, ds_out, method = "conservative_normed", periodic = true, ignore_degenerate=false, )
lgm_DA_remapped = regridder(lgm_DA["sst"])

In [None]:
"""
    area_avg(ds, area, mask)

Area-weighted mean of the xarray field `ds`. The weights are `area * mask`, restricted
to cells where `ds` is not NaN. Returns the result of the xarray reductions (a Python
object).

`ds`, `area` and `mask` must be on the same grid with the same dimension names. xarray
broadcasts by dimension name, so arrays from different grids are combined as an outer
product instead of raising an error.
"""
function area_avg(ds, area, mask)
    # Drop NaN cells from the weights as well: `.sum()` skips NaN in the numerator,
    # so otherwise their area would still count in the denominator and bias the mean.
    weights = area * mask * ds.notnull()
    return (ds * weights).sum() / weights.sum()
end

println(area_avg(lgm_DA["sst"], area_in, lgm_DA["mask"]))

# `lgm_DA_remapped` is on the regridder's target grid, not the TMI grid, so mask it
# with its own valid cells rather than `TMI_lgm_theta["mask"]`.
# NOTE(review): assumes `area_out` shares dimension names with `lgm_DA_remapped` — confirm.
println(area_avg(lgm_DA_remapped, area_out, lgm_DA_remapped.notnull()))

In [None]:
# Plot `lgm_DA_remapped` (the regridded SST).
# NOTE(review): `subplots` is not imported in the visible code; presumably it comes
# from a plotting package loaded elsewhere (e.g. via the included script) — confirm.
fig, ax = subplots()
lgm_DA_remapped.plot(ax = ax)
fig

In [None]:
# Regrid the LGM DA SST onto the TMI LGM surface grid, both bilinearly and by nearest
# neighbour. The source dataset is rebuilt from scratch so this cell is self-contained.
lgm_DA = xr.open_dataset(datadir("lgmDA_lgm_Ocn_annual.nc"))
lgm_DA["mask"] = xr.where(~np.isnan(lgm_DA["sst"]), 1, 0)
lgm_DA = add_boundaries(lgm_DA)

area_in  = xe.util.cell_area(lgm_DA,  earth_radius=6371.0)       # source-grid cell areas (km²)
area_out = xe.util.cell_area(TMI_lgm_theta,  earth_radius=6371.0)  # TMI-grid cell areas (km²)

# TMI_lgm_theta carries a "mask" coordinate, which xESMF presumably also applies as
# the destination mask — confirm.
regridder = xe.Regridder(lgm_DA, TMI_lgm_theta, method = "bilinear", periodic = true, ignore_degenerate=false, )
lgm_DA_remapped = regridder(lgm_DA["sst"])

regrid_nearest = xe.Regridder(
    lgm_DA, TMI_lgm_theta,
    method="nearest_s2d",
    periodic=true,          # longitude wraps around the globe
    ignore_degenerate=true, 
    extrap_method="nearest_s2d"  # fill target cells left unmapped by the regridding
)

lgm_DA_near = regrid_nearest(lgm_DA["sst"])


In [None]:
# Area-weighted mean of the nearest-neighbour field over the TMI ocean mask.
# Here `area_out`, the mask and `lgm_DA_near` are all on the TMI grid (assuming the
# previous cell, which sets `area_out` from TMI_lgm_theta, was run last).
area_avg(lgm_DA_near, area_out, TMI_lgm_theta["mask"])

In [None]:
# Plot the TMI surface ocean mask (1 where surface θ is defined, 0 elsewhere).
fig, ax = subplots()
TMI_lgm_theta["mask"].plot(ax = ax)
fig

In [None]:
# Plot the longitude coordinate of `lgm_DA`, presumably to check its range/convention.
fig, ax = subplots()
lgm_DA.lon.plot(ax = ax)
fig

In [None]:
# Plot the source-grid SST mask of `lgm_DA` (1 where SST is defined, 0 where NaN).
fig, ax = subplots()
lgm_DA.mask.plot(ax = ax)
fig