In [1]:
import sys

# Make the parent directory importable (presumably so the local `lmrecon` package resolves
# when the notebook runs from a subfolder — TODO confirm layout). ".." is resolved
# relative to the kernel's current working directory.
sys.path.append("..")

# Automatically reload edited modules before each cell runs. The "complete" mode
# requires a recent IPython — see the IPython autoreload documentation.
# (Comments are kept on separate lines: trailing text would be passed to the magic.)
%load_ext autoreload
%autoreload complete

In [2]:
from math import ceil
from functools import partial
from pathlib import Path
import json

import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
import cartopy.crs as ccrs
import xarray as xr
import cf_xarray as cfxr
import dask.array
import cftime
import scipy

from lmrecon.io import IntakeESMLoader, save_mfdataset, open_mfdataset
from lmrecon.plotting import plot_field, format_plot, save_plot
from lmrecon.util import (
    stack_state,
    unstack_state,
    is_dask_array,
    get_data_path, NanMask,
)
from lmrecon.logger import logging_disabled
from lmrecon.eof import EOF
from lmrecon.spaces import PhysicalSpaceForecastSpaceMapper
from lmrecon.lim import LIM
from lmrecon.time import month_name
from lmrecon.stats import area_weighted_mean, Detrend

In [3]:
# Output directory passed to `save_plot`; relative to the kernel's working directory.
plot_folder = Path("..") / "plots"

In [4]:
# Identifier (a timestamp) of the reduced-space run whose mapper is loaded below.
reduced_space_dataset = "2024-06-03T01-20-30"
# NOTE(review): the `.pkl` extension suggests pickle-based loading — only load trusted files.
mapper = PhysicalSpaceForecastSpaceMapper.load(
    get_data_path()
    / "reduced_space"
    / reduced_space_dataset
    / "mapper.pkl"
)

In [5]:
def plot_eofs(fields, n_eofs=3):
    """Plot the leading EOF patterns of several fields as a grid of maps and save it.

    Column ``j`` shows ``fields[j]`` (titled with the field name). Row ``n`` shows that
    field's EOF component ``n``, mapped back to the lat/lon grid through the field's NaN
    mask. The figure is saved with ``save_plot(plot_folder, "eofs")``.

    Reads the notebook-level globals ``mapper`` and ``plot_folder``.

    Args:
        fields: Field names, used as keys into ``mapper.nan_masks`` and
            ``mapper.eofs_individual`` and as the ``field`` selector on
            ``mapper.state_coords``.
        n_eofs: Number of leading EOF components (rows) to plot per field. Defaults to 3.
    """
    fig, axs = plt.subplots(
        n_eofs,
        len(fields),
        # 5 inches tall for 3 rows (the original layout), scaled linearly with the row count.
        figsize=(len(fields) * 3, 5 * n_eofs / 3),
        subplot_kw=dict(projection=ccrs.EqualEarth(central_longitude=198)),
        # Always get a 2-D axes array. With the default squeeze=True, a single field or a
        # single EOF would yield a 1-D array (or a bare Axes), breaking `ax_col[0]` below.
        squeeze=False,
    )

    for ax_col, field in zip(axs.T, fields):
        ax_col[0].set_title(field)

        # Per-field objects do not depend on the EOF index; look them up once per column.
        nan_mask = mapper.nan_masks[field]
        field_eofs = mapper.eofs_individual[field]
        field_coords = mapper.state_coords.sel(field=field)

        for n, ax in enumerate(ax_col):
            state = (
                xr.DataArray(
                    # Add a trailing length-1 axis to match the single dummy `time` coordinate.
                    nan_mask.backward(field_eofs.get_component(n)[:, np.newaxis]),
                    coords=dict(state=field_coords, time=[0]),
                )
                .unstack("state")
                .squeeze()
                .drop_vars("field")
                .transpose(..., "lat", "lon")
            )
            plot_field(ax, state, n_level=50, colorbar=False, cmap="RdBu_r", same_limits=True)
            ax.coastlines()

    save_plot(plot_folder, "eofs")


# One map column per field, one row per leading EOF component.
plot_eofs(
    fields=["tas", "tos", "ohc700", "rsut", "rlut"],
)