In [1]:
import sys

# Make the parent directory importable, presumably so that `lmrecon` (imported in
# the next cell) resolves to the local checkout.
# NOTE(review): ".." is resolved against the kernel's current working directory,
# assumed to be the directory containing this notebook — confirm.
sys.path.append("..")

# Automatically reload edited modules before running code. The "complete" mode
# behaves like "all" but also picks up objects newly added to a module.
%load_ext autoreload
%autoreload complete

In [6]:
from math import ceil
from functools import partial
from pathlib import Path
import datetime

import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
import xarray as xr
import dask
import cftime

from lmrecon.io import IntakeESMLoader, save_mfdataset, open_mfdataset
from lmrecon.plotting import plot_field
from lmrecon.util import stack_state, unstack_state, is_dask_array, get_data_path, average_annually
from lmrecon.eof import EOF, EOFMethod
from lmrecon.spaces import PhysicalSpaceForecastSpaceMapper, Detrend, NanMask
from lmrecon.logger import logging_disabled

In [66]:
# Fields to analyse from the MPI-ESM1-2-LR past2k annual-anomaly archive.
VARIABLES = ["ohc700", "tas", "psl", "pr", "zg500"]
# Only the first N_TIMESTEPS entries along `time` are kept (a short subset of the record).
N_TIMESTEPS = 30

ds = open_mfdataset(get_data_path() / "cmip6_annual_anomalies/MPI-ESM1-2-LR/past2k")[VARIABLES]
ds = ds.head(time=N_TIMESTEPS)
ds

In [62]:
# Drop every latitude that contains any NaN, then flatten the field into a single
# `state` dimension (the coordinate is reused later when unstacking for plots).
data = stack_state(ds["zg500"].dropna(dim="lat"))

# Centre each state element on its mean over time. This is done in place so the
# array keeps its attributes.
data -= data.mean(dim="time")
data

In [63]:
# NOTE(review): the single positional argument to `EOF` is presumably the number of
# modes to retain — confirm against `lmrecon.eof.EOF`.
N_MODES = 10
eof = EOF(N_MODES)

# Fit on the underlying array (no xarray coordinates). The `state` coordinate is
# re-attached from `data.state` when results are plotted.
eof.fit(data.data)

In [64]:
# Plot three zg500 maps side by side: one fitted EOF pattern, the first time
# snapshot of the centred data, and that snapshot after a
# project_forwards -> project_backwards round trip through the fitted EOFs.
#
# NOTE(review): this panel used to be titled "EOF 0" while plotting
# `eof.get_component(1)`. The title now follows the index actually passed.
# Confirm whether `EOF.get_component` is 0- or 1-based and which mode is intended.
component_index = 1
snapshot = data.isel(time=0)


def _state_to_zg500_map(values, state):
    """Wrap a flat state vector with the `state` coordinate and return its transposed zg500 field."""
    return unstack_state(xr.DataArray(values, coords=dict(state=state)))["zg500"].T


fig, axs = plt.subplots(
    1,
    3,
    figsize=(15, 5),
    subplot_kw=dict(projection=ccrs.EqualEarth(central_longitude=198)),
)

axs[0].set_title(f"EOF {component_index}")
plot_field(
    axs[0],
    _state_to_zg500_map(eof.get_component(component_index), data.state),
    cmap="RdBu_r",
    same_limits=True,
)

axs[1].set_title("ZG500 original")
plot_field(axs[1], unstack_state(snapshot)["zg500"].T, cmap="RdBu_r", same_limits=True)

axs[2].set_title("ZG500 reconstructed")
# Project the snapshot onto the EOF basis and map it back to state space.
reconstruction = eof.project_backwards(eof.project_forwards(snapshot.data))
plot_field(
    axs[2],
    _state_to_zg500_map(reconstruction, data.state),
    cmap="RdBu_r",
    same_limits=True,
);