In [None]:
from __future__ import absolute_import, division, print_function

In [None]:
# License: MIT

In [None]:
%matplotlib inline

# EOFs of NCEPv1 500 hPa geopotential height anomalies

This notebook contains the routines for computing the empirical orthogonal functions (EOFs) and principal components (PCs)
of daily 500 hPa geopotential height anomalies, which are used in the subsequent FEM-BV-VAR analysis.

## Packages

In [None]:
import os
import warnings
import time

import cartopy.crs as ccrs
import matplotlib.gridspec as gridspec
import matplotlib.path as mpath
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scipy.stats as ss
import xarray as xr

from cartopy.util import add_cyclic_point
from sklearn.decomposition import TruncatedSVD

## File paths

In [None]:
# Project root, taken as the parent of the notebook's directory.
# NOTE: os.path.abspath on a bare filename resolves against the current
# working directory, so this assumes the kernel was started in the
# directory containing the notebook.
PROJECT_DIR = os.path.join(os.path.dirname(os.path.abspath('produce_eofs.ipynb')),'..')
DATA_DIR = os.path.join(PROJECT_DIR,'data')
RESULTS_DIR = os.path.join(PROJECT_DIR,'results')
ANOM_RESULTS_DIR = os.path.join(RESULTS_DIR, 'anom')
EOFS_RESULTS_DIR = os.path.join(RESULTS_DIR, 'eofs')
EOFS_NC_DIR = os.path.join(EOFS_RESULTS_DIR, 'nc')
EOFS_PLOTS_DIR = os.path.join(EOFS_RESULTS_DIR, 'plt')

# The input data must already exist; output directories are created on demand.
if not os.path.exists(DATA_DIR):
    raise RuntimeError("Input data directory '%s' not found" % DATA_DIR)

# exist_ok=True replaces a separate exists() check, which could race with
# another process creating the same directory.
for output_dir in (RESULTS_DIR, ANOM_RESULTS_DIR, EOFS_RESULTS_DIR,
                   EOFS_NC_DIR, EOFS_PLOTS_DIR):
    os.makedirs(output_dir, exist_ok=True)

In [None]:
# NetCDF input file (inside DATA_DIR) holding the 500 hPa geopotential
# height values that the anomalies are computed from
INPUT_HGT_DATAFILE = os.path.join(DATA_DIR,
                                  'hgt.500.1948_2018.nc')

In [None]:
# Default coordinate names and data-variable name used in the NNR1 files
TIME_NAME, LAT_NAME, LON_NAME = 'time', 'lat', 'lon'
VAR_NAME = 'hgt'

## Region definitions

In [None]:
def get_regions(hemisphere):
    """Return the names of the regions defined for a hemisphere.

    Parameters
    ----------
    hemisphere : 'WG' | 'NH' | 'SH'
        Hemisphere whose named regions are requested.

    Returns
    -------
    regions : list
        New list of the region names defined for the hemisphere.

    Raises
    ------
    RuntimeError
        If the hemisphere is not recognised.
    """

    known_hemispheres = (
        ('WG', ('all',)),
        ('NH', ('all', 'eurasia', 'pacific', 'atlantic', 'atlantic2',
                'atlantic3', 'atlantic_eurasia')),
        ('SH', ('all', 'indian', 'south_america', 'pacific',
                'full_pacific', 'australian')),
    )

    for name, region_names in known_hemispheres:
        if hemisphere == name:
            # Build a fresh list each call so callers may modify it freely
            return list(region_names)

    raise RuntimeError("Invalid hemisphere '%s'" % hemisphere)

In [None]:
def get_region_lon_bounds(hemisphere, region):
    """Look up the longitude bounds of a named region.

    Bounds are ``[start, end]`` in degrees on a 0-360 grid. A region whose
    start exceeds its end (for example ``'atlantic3'``) is meant to wrap
    across the 0/360 meridian. The 'WG' and 'NH' hemispheres share the
    same table of regions.

    Parameters
    ----------
    hemisphere : 'WG' | 'NH' | 'SH'
        Hemisphere in which region is located.

    region : str
        Name of region to get longitude bounds for.

    Returns
    -------
    lon_bounds : array, shape (2,)
        Integer array containing the longitude bounds for the region.

    Raises
    ------
    ValueError
        If the hemisphere or the region name is not recognised.
    """

    if hemisphere in ('WG', 'NH'):
        region_table = (
            ('all', (0, 360)),
            ('atlantic', (250, 360)),
            ('atlantic2', (270, 360)),
            ('atlantic3', (270, 40)),
            ('atlantic_eurasia', (250, 120)),
            ('eurasia', (0, 120)),
            ('pacific', (120, 250)),
        )
    elif hemisphere == 'SH':
        region_table = (
            ('all', (0, 360)),
            ('australian', (110, 210)),
            ('full_pacific', (150, 300)),
            ('indian', (0, 120)),
            ('pacific', (120, 250)),
            ('south_america', (240, 360)),
        )
    else:
        raise ValueError("Invalid hemisphere '%s'" % hemisphere)

    for name, bounds in region_table:
        if region == name:
            return np.array(bounds)

    raise ValueError("Invalid region '%s'" % region)

In [None]:
def get_region_data(da, hemisphere='WG', region='all', season='ALL',
                    lat_name=LAT_NAME, lon_name=LON_NAME, time_name=TIME_NAME):
    """Get data restricted to given hemisphere, region and season.

    Latitudes are restricted to values >= 0 for 'NH' and <= 0 for 'SH'
    (the equator is kept in both); 'WG' applies no latitude restriction.
    Longitudes are restricted to the bounds returned by
    ``get_region_lon_bounds``; when the start bound exceeds the end bound
    the selection wraps across the 0/360 meridian. Unless ``season`` is
    'ALL', only time steps whose ``dt.season`` label equals ``season`` are
    kept.

    Parameters
    ----------
    da : xarray.DataArray
        DataArray containing the field to subset.

    hemisphere : 'WG' | 'NH' | 'SH'
        Hemisphere to restrict to.

    region : str
        Name of region to restrict data to.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to restrict data to.

    lat_name : str
        Name of the latitude coordinate.

    lon_name : str
        Name of the longitude coordinate.

    time_name : str
        Name of the time coordinate.

    Returns
    -------
    region_data : xarray.DataArray
        DataArray containing the input field restricted to the
        requested region and season.

    Raises
    ------
    ValueError
        If the hemisphere, region or season is not recognised.
    """

    if hemisphere not in ('WG', 'NH', 'SH'):
        raise ValueError("Invalid hemisphere '%s'" % hemisphere)

    # Validate the season up front: an unrecognised label would otherwise
    # match no time steps and silently produce an empty array.
    if season not in ('ALL', 'DJF', 'MAM', 'JJA', 'SON'):
        raise ValueError("Invalid season '%s'" % season)

    if hemisphere == 'NH':
        da = da.where(da[lat_name] >= 0, drop=True)
    elif hemisphere == 'SH':
        da = da.where(da[lat_name] <= 0, drop=True)

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lon = da[lon_name]

    if lon_bounds[1] < lon_bounds[0]:
        # Region wraps across 0/360: keep [start, 360] together with [0, end]
        lon_mask = (((lon >= lon_bounds[0]) & (lon <= 360)) |
                    ((lon >= 0) & (lon <= lon_bounds[1])))
    else:
        lon_mask = (lon >= lon_bounds[0]) & (lon <= lon_bounds[1])

    da = da.where(lon_mask, drop=True)

    if season != 'ALL':
        da = da.where(da[time_name].dt.season == season, drop=True)

    return da

## Calculation of anomalies

By default, anomalies are computed by subtracting the daily (day-of-year) mean climatology, calculated over the
specified base period `CLIMATOLOGY_BASE_PERIOD`, from the data. No trends are removed beforehand.

In [None]:
# First and last dates (both inclusive) of the base period over which the
# daily climatology used for the anomalies is computed
CLIMATOLOGY_BASE_PERIOD = [np.datetime64(date) for date in ('1979-01-01', '2018-12-31')]

In [None]:
# Whether to write anomalies to file (a netCDF file in ANOM_RESULTS_DIR)
SAVE_ANOMALIES = True

In [None]:
def get_anomaly_datafile_name(input_datafile, base_period, input_detrended=False):
    """Build the filename used to store computed anomalies.

    The name joins the following parts with dots: the input file's basename
    without its extension, the base period as ``YYYYMMDD_YYYYMMDD``, an
    optional ``detrended`` tag, and ``anom``. The input file's extension
    is then appended.

    Parameters
    ----------
    input_datafile : str
        Path of the input datafile; only its basename and extension are used.

    base_period : list
        List containing the first and last date of the base period used
        to compute the anomalies.

    input_detrended : bool, default: False
        If True, add a tag to the filename indicating that the anomalies
        are based on detrended data.

    Returns
    -------
    filename : str
        Name of the output file for the anomalies.
    """

    stem, extension = os.path.splitext(os.path.basename(input_datafile))

    period_tag = '_'.join(
        pd.to_datetime(base_period[index]).strftime('%Y%m%d') for index in (0, 1))

    name_parts = [stem, period_tag]
    if input_detrended:
        name_parts.append('detrended')
    name_parts.append('anom')

    return '.'.join(name_parts) + extension

In [None]:
# Load 500 hPa geopotential height data. The context manager closes the file
# even if a later step raises. The values are loaded into memory inside it,
# so nothing below depends on the open file handle.
with xr.open_dataset(INPUT_HGT_DATAFILE) as hgt_ds:
    hgt_da = hgt_ds[VAR_NAME].astype(np.float64).load()

# Get data corresponding to climatology base period (both bounds inclusive)
base_period_da = hgt_da.where((hgt_da[TIME_NAME] >= CLIMATOLOGY_BASE_PERIOD[0]) &
                              (hgt_da[TIME_NAME] <= CLIMATOLOGY_BASE_PERIOD[1]), drop=True)

# Calculate daily mean climatology (one mean per day of year)
clim_mean_da = base_period_da.groupby(
    base_period_da[TIME_NAME].dt.dayofyear).mean(TIME_NAME)

# Calculate anomalies by subtracting the climatology for the matching day of year.
# NOTE(review): grouping by dayofyear assigns dates after 28 February to a
# different day index in leap years than in other years — confirm intended.
hgt_anom_da = hgt_da.groupby(hgt_da[TIME_NAME].dt.dayofyear) - clim_mean_da

anom_output_filename = get_anomaly_datafile_name(INPUT_HGT_DATAFILE, CLIMATOLOGY_BASE_PERIOD)
anom_output_file = os.path.join(ANOM_RESULTS_DIR, anom_output_filename)

if SAVE_ANOMALIES:

    # Write to file, recording the input file and base period as global attributes
    hgt_anom_ds = hgt_anom_da.to_dataset(name=(VAR_NAME + '_anom'))

    hgt_anom_ds.attrs['input_file'] = INPUT_HGT_DATAFILE
    hgt_anom_ds.attrs['base_period_start'] = pd.to_datetime(CLIMATOLOGY_BASE_PERIOD[0]).strftime('%Y%m%d')
    hgt_anom_ds.attrs['base_period_end'] = pd.to_datetime(CLIMATOLOGY_BASE_PERIOD[1]).strftime('%Y%m%d')

    hgt_anom_ds.to_netcdf(anom_output_file)

## EOF analysis

In [None]:
def get_latitude_weights(data, weights_str, lat_name=LAT_NAME):
    """Return per-latitude weights to multiply the data by before EOF analysis.

    Parameters
    ----------
    data : xarray.DataArray or xarray.Dataset
        Object containing the data to perform EOF analysis on; only its
        latitude coordinate is used.

    weights_str : 'none' | 'cos' | 'scos'
        Kind of weighting:

            - 'none': every latitude gets weight one.
            - 'cos': cosine of latitude.
            - 'scos': square root of the cosine of latitude (the cosine is
              clipped to [0, 1] before taking the root).

    lat_name : str
        Name of latitude coordinate.

    Returns
    -------
    lat_weights : xarray.DataArray
        Weights along the latitude coordinate.

    Raises
    ------
    ValueError
        If ``weights_str`` is not one of the recognised options.
    """

    if weights_str not in ('none', 'cos', 'scos'):
        raise ValueError("Invalid weights string '%s'" % weights_str)

    latitudes = data[lat_name]

    if weights_str == 'none':
        return xr.ones_like(latitudes)

    cos_lat = np.cos(np.deg2rad(latitudes))
    if weights_str == 'cos':
        return cos_lat

    # 'scos': clip guards against tiny negative values before the square root
    return cos_lat.clip(0., 1.)**0.5

In [None]:
def fix_svd_phases(u, vh):
    """Impose fixed phase convention on singular vectors.

    Given a set of left- and right-singular vectors as the columns of u
    and rows of vh, respectively, imposes the phase convention that for
    each left-singular vector, the element with largest absolute value
    is real and positive. Each right-singular vector is multiplied by the
    complex conjugate of the factor applied to the matching left-singular
    vector, so the product ``u @ diag(s) @ vh`` is unchanged. For real
    inputs the factors are +/-1.

    Left-singular vectors that are identically zero have no well-defined
    phase and are left unchanged. The input arrays are not modified.

    Parameters
    ----------
    u : array, shape (M, K)
        Unitary array containing the left-singular vectors as columns

    vh : array, shape (K, N)
        Unitary array containing the right-singular vectors as rows.

    Returns
    -------
    u_fixed : array, shape (M, K)
        Unitary array containing the left-singular vectors as columns,
        conforming to the chosen phase convention.

    vh_fixed : array, shape (K, N)
        Unitary array containing the right-singular vectors as rows,
        conforming to the chosen phase convention.
    """

    n_cols = u.shape[1]
    max_elem_rows = np.argmax(np.abs(u), axis=0)
    max_elems = u[max_elem_rows, np.arange(n_cols)]

    if np.iscomplexobj(u):
        # Unit-modulus factors rotating each largest element onto the
        # positive real axis (angle(0) == 0, so zero columns get factor 1)
        phases = np.exp(-1j * np.angle(max_elems))
    else:
        phases = np.sign(max_elems)
        # sign(0) == 0 would wipe out an all-zero column; leave it unchanged
        phases[phases == 0] = 1

    # u -> u D with D = diag(phases) unitary requires vh -> conj(D) vh
    # to preserve u S vh (D and S are diagonal and commute)
    u_fixed = u * phases
    vh_fixed = vh * np.conj(phases)[:, np.newaxis]

    return u_fixed, vh_fixed

In [None]:
# Default maximum number of EOFs to retain
DEFAULT_MAX_EOFS = 200

def run_pca(da, sample_dim=TIME_NAME, weights=None, normalization='unit', tolerance=1e-6,
            center=True, bias=False, n_components=DEFAULT_MAX_EOFS, eofs_base_period=None,
            algorithm='arpack', tol=0, n_iter=10, random_state=None):
    """Run EOF analysis on data.

    The data are optionally weighted and then flattened to a
    (samples, features) matrix. Features containing any missing value are
    dropped. A truncated SVD of the (optionally centered) base-period data
    gives the EOFs, and PCs are obtained for every sample by projecting
    the full record onto the EOFs.

    Parameters
    ----------
    da : xarray.DataArray
        DataArray containing the field to perform EOF analysis on.

    sample_dim : str
        Dimension labelling separate samples.

    weights : xarray.DataArray, optional
        If given, weights to apply to data before performing EOF analysis.

    normalization : 'unit' | 'sqrt_variance' | 'inverse_sqrt_variance'
        Normalization convention to use for EOFs:

            - 'unit': EOFs have unit norm; PCs carry the singular values.
            - 'sqrt_variance': EOFs are scaled by the square root of each
              mode's variance.
            - 'inverse_sqrt_variance': EOFs are scaled by the inverse
              square root of each mode's variance.

    tolerance : float, default: 1e-6
        Tolerance used for checking data is centered.

    center : bool, default: True
        If True, center the data by removing the base-period feature means
        (from the base-period data before the SVD and from the full record
        before projecting).

    bias : bool, default: False
        If True, calculate estimated variance with biased number of
        degrees of freedom.

    n_components : int, optional
        Number of EOFs to retain.

    eofs_base_period : list, optional
        List of length 2 containing the first and last (inclusive) values
        of ``sample_dim`` to use in calculating the EOFs. If None, all
        samples are used.

    algorithm : str, default: 'arpack'
        SVD solver passed to ``sklearn.decomposition.TruncatedSVD``.

    tol : float, default: 0
        Tolerance passed to ``TruncatedSVD`` (used by the ARPACK solver).

    n_iter : int, default: 10
        Number of iterations passed to ``TruncatedSVD`` (used by the
        randomized solver).

    random_state : integer, RandomState or None
        If an integer, random_state is the seed used by the
        random number generator. If a RandomState instance,
        random_state is the random number generator. If None,
        the random number generator is the RandomState instance
        used by `np.random`.

    Returns
    -------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis, with
        the data variables:

            - 'eofs': calculated EOFs
            - 'principal_components': time-series of principal components
            - 'explained_variance_ratio': fraction of variance explained by
               each mode
            - 'singular_values': singular values for each mode
            - 'means': mean of each (weighted) feature over the base period

    Raises
    ------
    ValueError
        If ``normalization`` is not a recognised convention.
    """

    # Fail fast, before the expensive SVD, on an unknown normalization
    if normalization not in ('unit', 'sqrt_variance', 'inverse_sqrt_variance'):
        raise ValueError("Invalid normalization '%s'" % normalization)

    # Determine shape of individual samples for later reshaping
    feature_dims = [d for d in da.dims if d != sample_dim]
    original_shape = [da.sizes[d] for d in feature_dims]

    # Weight data if any weights are given
    if weights is not None:
        input_da = (weights * da).transpose(*da.dims)
    else:
        input_da = da

    # Ensure sample dimension is first dimension
    if input_da.get_axis_num(sample_dim) != 0:
        input_da = input_da.transpose(*([sample_dim] + feature_dims))

    # Determine the number of samples and number of features. int() is needed
    # because np.prod of an empty shape is the float 1.0; np.prod replaces
    # np.product, which was removed in NumPy 2.0.
    n_samples = input_da.sizes[sample_dim]
    n_features = int(np.prod(original_shape))

    # Mask out features with any missing values
    flat_values = input_da.values.reshape((n_samples, n_features))
    valid_features = np.logical_not(np.any(np.isnan(flat_values), axis=0))
    valid_values = flat_values[:, valid_features]

    # Select data in the requested period to use for computing the EOFs.
    # With no base period every sample is used; this avoids comparing against
    # .item() of the sample coordinate, which for datetime64[ns] is an integer.
    if eofs_base_period is None:
        base_period_mask = np.ones(n_samples, dtype=bool)
    else:
        sample_values = input_da[sample_dim].values
        base_period_mask = np.logical_and(sample_values >= eofs_base_period[0],
                                          sample_values <= eofs_base_period[1])

    base_period_values = valid_values[base_period_mask]

    # If centering the data, compute the feature means and subtract them
    base_period_means = np.mean(base_period_values, axis=0)
    if center:
        base_period_values = base_period_values - base_period_means[np.newaxis, :]

    # Double check that inputs are centered
    input_means = np.mean(base_period_values, axis=0)

    if np.any(np.abs(input_means) > tolerance):
        warnings.warn(
            'Input data does not have zero column means '
            '(got max(abs(means))=%.3e)' % np.max(np.abs(input_means)),
            UserWarning)

    # Calculate truncated SVD of the base-period data
    svd = TruncatedSVD(n_components=n_components, algorithm=algorithm, n_iter=n_iter,
                       tol=tol, random_state=random_state)
    svd.fit(base_period_values)

    s = svd.singular_values_
    vh = svd.components_

    # Project the full record onto the EOFs to get PCs for every sample.
    # When centering, remove the same base-period means first so that, over
    # the base period, the result agrees with the SVD's left-singular vectors.
    if center:
        projected_values = valid_values - base_period_means[np.newaxis, :]
    else:
        projected_values = valid_values
    u = np.dot(projected_values, vh.T) / s

    # Fix phase convention for the EOFs and PCs
    u, vh = fix_svd_phases(u, vh)

    # Calculate variance explained by each mode
    ddof = 0 if bias else 1
    ddof_factor = base_period_values.shape[0] - ddof

    eigenvalues = s ** 2 / ddof_factor
    total_variance = np.sum(np.var(base_period_values, ddof=ddof, axis=0))
    explained_variance_ratio = eigenvalues / total_variance

    # Apply normalization convention for EOFs. Broadcasting replaces
    # multiplication by explicit diagonal matrices (same values, O(k) not O(k^2)).
    if normalization == 'unit':
        pcs = u * s
        eofs = vh
    elif normalization == 'sqrt_variance':
        pcs = u * np.sqrt(ddof_factor)
        eofs = np.sqrt(eigenvalues)[:, np.newaxis] * vh
    else:  # 'inverse_sqrt_variance'
        pcs = u * eigenvalues * np.sqrt(ddof_factor)
        eofs = (1 / np.sqrt(eigenvalues))[:, np.newaxis] * vh

    # Construct arrays containing computed EOFs and PCs, restoring NaN for
    # the features that were dropped because of missing values
    full_eofs = np.full((n_components, n_features), np.nan, dtype=flat_values.dtype)
    full_eofs[:, valid_features] = eofs

    full_eofs = full_eofs.reshape([n_components] + original_shape)
    eof_coords = {d: da.coords[d] for d in feature_dims}
    eof_coords['component'] = np.arange(n_components)
    eof_dims = ['component'] + feature_dims

    feature_means = np.full((n_features,), np.nan, dtype=flat_values.dtype)
    feature_means[valid_features] = base_period_means

    feature_means = feature_means.reshape(original_shape)

    pcs_da = xr.DataArray(
        pcs,
        coords={sample_dim: da[sample_dim],
                'component': np.arange(n_components)},
        dims=[sample_dim, 'component'], name='principal_components')
    eofs_da = xr.DataArray(
        full_eofs, coords=eof_coords, dims=eof_dims, name='components')
    explained_variance_ratio_da = xr.DataArray(
        explained_variance_ratio,
        coords={'component': np.arange(n_components)},
        dims=['component'], name='explained_variance_ratio')
    singular_values_da = xr.DataArray(
        s, coords={'component': np.arange(n_components)},
        dims=['component'], name='singular_values')
    feature_means_da = xr.DataArray(
        feature_means, coords={d: da.coords[d] for d in feature_dims},
        dims=feature_dims, name='mean')

    data_vars = {'eofs': eofs_da,
                 'principal_components': pcs_da,
                 'explained_variance_ratio': explained_variance_ratio_da,
                 'singular_values': singular_values_da,
                 'means': feature_means_da}

    out_coords = dict(da.coords.items())
    out_coords['component'] = np.arange(n_components)

    return xr.Dataset(data_vars, coords=out_coords)

In [None]:
def get_eofs_output_filename(input_filename, hemisphere, region, eofs_base_period,
                             season, lat_weights, max_eofs, normalization):
    """Get filename for saving EOFs.

    The filename consists of the basename of ``input_filename`` (without
    extension), followed by dot-separated fields: hemisphere, region, base
    period as ``YYYYMMDD_YYYYMMDD``, season, ``max_eofs_<n>``, latitude
    weighting, normalization and the tag ``eofs``. The input file's
    extension is appended at the end.

    Parameters
    ----------
    input_filename : str
        Path of the input datafile; only its basename and extension are used.

    hemisphere : 'WG' | 'NH' | 'SH'
        Hemisphere to perform EOF analysis on.

    region : str
        Name of region to perform EOF analysis on.

    eofs_base_period : list
        List of length 2 containing the first and last date used
        for computing the EOFs.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to perform EOF analysis in.

    lat_weights : str
        String indicating the type of latitude weights used.

    max_eofs : int
        Number of EOFs retained (formatted as an integer).

    normalization : str
        Normalization convention used for the EOFs.

    Returns
    -------
    filename : str
        Name of the output file for the EOFs.
    """

    basename, ext = os.path.splitext(os.path.basename(input_filename))

    eofs_suffix = '{}.{}.{}_{}.{}.max_eofs_{:d}.{}.{}.eofs'.format(
        hemisphere, region, pd.to_datetime(eofs_base_period[0]).strftime('%Y%m%d'),
        pd.to_datetime(eofs_base_period[1]).strftime('%Y%m%d'), season,
        max_eofs, lat_weights, normalization)

    return '.'.join([basename, eofs_suffix]) + ext

In [None]:
# Hemispheres and seasons to compute EOFs for
HEMISPHERES = ['WG', 'NH', 'SH']
SEASONS = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

# Time period to use for calculating the EOFs
EOFS_BASE_PERIOD = CLIMATOLOGY_BASE_PERIOD

# Latitude weights to apply ('scos': square root of cosine of latitude)
LAT_WEIGHTS = 'scos'

# Normalization convention used for EOFs
EOFS_NORMALIZATION = 'unit'

# Whether to write the EOFs to file
SAVE_EOFS = True

# Random seed for truncated SVD
RANDOM_SEED = 0

# Base-period strings recorded in every output file's attributes
eofs_base_period_start_str = pd.to_datetime(EOFS_BASE_PERIOD[0]).strftime('%Y%m%d')
eofs_base_period_end_str = pd.to_datetime(EOFS_BASE_PERIOD[1]).strftime('%Y%m%d')

for hemisphere in HEMISPHERES:
    print('* Hemisphere: ', hemisphere)

    regions = get_regions(hemisphere)
    for region in regions:
        print('\t- Region:', region)

        for season in SEASONS:
            print('\t\t+ Season: ', season, end='')
            start_time = time.perf_counter()

            # Subset the anomalies and weight them by latitude before the SVD
            anom_da = get_region_data(hgt_anom_da, hemisphere=hemisphere,
                                      region=region, season=season)
            lat_weights = get_latitude_weights(anom_da, weights_str=LAT_WEIGHTS)

            eofs_ds = run_pca(anom_da,
                              weights=lat_weights,
                              n_components=DEFAULT_MAX_EOFS,
                              normalization=EOFS_NORMALIZATION,
                              eofs_base_period=EOFS_BASE_PERIOD,
                              random_state=RANDOM_SEED)

            # Provenance metadata stored alongside the results
            eofs_ds.attrs.update({
                'input_file': anom_output_file,
                'eofs_base_period_start': eofs_base_period_start_str,
                'eofs_base_period_end': eofs_base_period_end_str,
                'eofs_normalization': EOFS_NORMALIZATION,
                'lat_weights': LAT_WEIGHTS,
            })

            eofs_output_filename = get_eofs_output_filename(
                anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
                season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION)
            eofs_output_file = os.path.join(EOFS_NC_DIR, eofs_output_filename)

            if SAVE_EOFS:
                eofs_ds.to_netcdf(eofs_output_file)

            end_time = time.perf_counter()
            print(' ({:.2f} seconds)'.format(end_time - start_time))

## Plots

In [None]:
# Whether to write plots to file (saved under EOFS_PLOTS_DIR, as PDF by default)
SAVE_PLOTS = True

In [None]:
def get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot, ext='pdf'):
    """Build the filename under which a plot of EOFs is saved.

    The result is the EOF datafile's basename without its extension,
    followed by ``.k<n_eofs_to_plot>.eofs.<ext>``.

    Parameters
    ----------
    eofs_datafile : str
        Path of the datafile containing the EOFs to plot.

    n_eofs_to_plot : int
        Number of EOFs to plot.

    ext : str, default: 'pdf'
        File extension (without the leading dot).

    Returns
    -------
    filename : str
        Name of the output file for the plot.
    """

    stem = os.path.splitext(os.path.basename(eofs_datafile))[0]
    return '{}.k{}.eofs.{}'.format(stem, n_eofs_to_plot, ext)

### Whole globe

In [None]:
def plot_global_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                     lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot the leading global EOFs as a grid of Plate Carree maps.

    All panels share one color scale, which spans the smallest and largest
    EOF values among the plotted components. The scale is shown in a
    horizontal colorbar below the panels.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season the EOFs belong to (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of leading EOF patterns to plot.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure with one map per EOF and a shared colorbar.
    """

    hemisphere = 'WG'
    region = 'all'

    # Append a cyclic longitude column so the maps have no gap where the
    # longitude grid wraps around
    wrap_lon = True

    projection = ccrs.PlateCarree(central_longitude=0)

    eofs = eofs_ds['eofs']
    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values

    # Shared color limits: the extremes over all plotted components.
    # NOTE(review): the limits are not symmetric about zero, so the midpoint
    # of the diverging colormap need not correspond to zero.
    vmin = min((eofs.sel(component=c).min().item() for c in components), default=None)
    vmax = max((eofs.sel(component=c).max().item() for c in components), default=None)

    # Panel layout: 4 columns if possible, else 2, else 3
    ncols = 4 if n_eofs_to_plot % 4 == 0 else (2 if n_eofs_to_plot % 2 == 0 else 3)
    nrows = int(np.ceil(n_eofs_to_plot / ncols))

    # An extra, short bottom row holds the colorbar
    height_ratios = np.append(np.ones(nrows), 0.1)

    fig = plt.figure(constrained_layout=False, figsize=(6 * ncols, 3 * nrows))
    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.05, hspace=0.2, height_ratios=height_ratios)

    lat = eofs_ds[lat_name]
    lon = eofs_ds[lon_name]

    for panel, component in enumerate(components):
        row_index, col_index = divmod(panel, ncols)

        eof_data = eofs.sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(component=component).item()

        if wrap_lon:
            eof_data, eof_lon = add_cyclic_point(eof_data, coord=lon)

        lon_grid, lat_grid = np.meshgrid(eof_lon, lat)

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)
        ax.coastlines()
        ax.set_global()

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(component + 1, expl_var * 100), fontsize=14)
        ax.set_aspect('auto')
        # NOTE(review): redraws the whole figure after each panel; kept since
        # the final layout may depend on it
        fig.canvas.draw()

    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'WG'
region = 'all'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the file even if plotting raises; the
    # figure holds its own copies of the plotted values.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_global_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        # Save this figure explicitly rather than pyplot's "current" figure
        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    plt.close(fig)

### Northern Hemisphere

In [None]:
def plot_nh_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                 lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot the leading NH EOFs on orthographic maps centered on the North Pole.

    All panels share one color scale, which spans the smallest and largest
    EOF values among the plotted components. The scale is shown in a
    horizontal colorbar below the panels.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season the EOFs belong to (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of leading EOF patterns to plot.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : matplotlib.figure.Figure
        Figure with one map per EOF and a shared colorbar.
    """

    hemisphere = 'NH'
    region = 'all'

    # Append a cyclic longitude column so the maps have no gap where the
    # longitude grid wraps around
    wrap_lon = True

    projection = ccrs.Orthographic(central_longitude=0, central_latitude=90)

    eofs = eofs_ds['eofs']
    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values

    # Shared color limits: the extremes over all plotted components.
    # NOTE(review): the limits are not symmetric about zero, so the midpoint
    # of the diverging colormap need not correspond to zero.
    vmin = min((eofs.sel(component=c).min().item() for c in components), default=None)
    vmax = max((eofs.sel(component=c).max().item() for c in components), default=None)

    # Panel layout: 4 columns if possible, else 2, else 3
    ncols = 4 if n_eofs_to_plot % 4 == 0 else (2 if n_eofs_to_plot % 2 == 0 else 3)
    nrows = int(np.ceil(n_eofs_to_plot / ncols))

    # An extra, short bottom row holds the colorbar
    height_ratios = np.append(np.ones(nrows), 0.1)

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))
    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.01, hspace=0.2, height_ratios=height_ratios)

    lat = eofs_ds[lat_name]
    lon = eofs_ds[lon_name]

    for panel, component in enumerate(components):
        row_index, col_index = divmod(panel, ncols)

        eof_data = eofs.sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(component=component).item()

        if wrap_lon:
            eof_data, eof_lon = add_cyclic_point(eof_data, coord=lon)

        lon_grid, lat_grid = np.meshgrid(eof_lon, lat)

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)
        ax.coastlines()
        ax.set_global()

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(component + 1, expl_var * 100), fontsize=14)
        ax.set_aspect('equal')
        # NOTE(review): redraws the whole figure after each panel; kept since
        # the final layout may depend on it
        fig.canvas.draw()

    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, pad=0.05, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'NH'
region = 'all'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the file even if plotting raises; the
    # figure holds its own copies of the plotted values.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_nh_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        # Save this figure explicitly rather than pyplot's "current" figure
        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    plt.close(fig)

In [None]:
def plot_nh_atlantic_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                          lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot NH Atlantic region EOFs.

    Each of the leading EOF patterns is drawn in its own panel on an
    equidistant conic projection, clipped to the region's longitude bounds
    and to latitudes 0-90N. All panels share one color scale and a single
    horizontal colorbar placed below the grid of panels.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis. Must provide
        'eofs' and 'explained_variance_ratio' variables indexed by a
        'component' dimension; each EOF must squeeze to a (lat, lon) field.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. Also determines the panel layout.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If no EOF components are selected for plotting.
    """

    hemisphere = 'NH'
    region = 'atlantic'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [0.0, 90.0]

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=90)

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    if components.size == 0:
        raise ValueError('No EOF components to plot')

    # One color scale for all panels (extrema over every plotted EOF), so the
    # single shared colorbar is valid for each of them.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    if n_eofs_to_plot % 4 == 0:
        ncols = 4
    elif n_eofs_to_plot % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_eofs_to_plot / ncols))

    # An extra, much shorter bottom row holds the colorbar.
    height_ratios = np.ones(nrows + 1)
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values
    lon_grid, lat_grid = np.meshgrid(lon, lat)

    # The clipping outline and map extent are identical for every panel, so
    # build them once. The outline is a closed lon/lat loop: down the western
    # bound, along the equator, up the eastern bound and back along 90N.
    left_bound_verts = np.vstack([np.repeat(lon_bounds[0], 100), np.flipud(np.linspace(0, 90.0, 100))]).T
    right_bound_verts = np.vstack([np.repeat(lon_bounds[1], 100), np.linspace(0, 90.0, 100)]).T
    lower_bound_verts = np.vstack([np.linspace(lon_bounds[0], lon_bounds[1], 100),
                                   np.repeat(0.0, 100)]).T
    upper_bound_verts = np.vstack([np.flipud(np.linspace(lon_bounds[0], lon_bounds[1], 100)),
                                   np.repeat(90.0, 100)]).T
    boundary = mpath.Path(np.vstack([left_bound_verts, lower_bound_verts, right_bound_verts, upper_bound_verts]))
    extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        # Panels are filled row by row.
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()
        ax.set_boundary(boundary, transform=ccrs.PlateCarree())
        ax.set_extent(extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): redraws the whole figure after every panel, as before;
        # confirm whether a single draw after the loop would suffice.
        fig.canvas.draw()

    # All panels share vmin/vmax, so the last mesh's colorbar applies to all
    # of them. `pad` is not passed: it is ignored when `cax` is supplied.
    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'NH'
region = 'atlantic'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the netCDF file even if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_nh_atlantic_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    # Release this season's figure explicitly before plotting the next one.
    plt.close(fig)

In [None]:
def plot_nh_atlantic2_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                           lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot NH Atlantic-2 region EOFs.

    Each of the leading EOF patterns is drawn in its own panel on an
    equidistant conic projection, clipped to the region's longitude bounds
    and to latitudes 0-90N. All panels share one color scale and a single
    horizontal colorbar placed below the grid of panels.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis. Must provide
        'eofs' and 'explained_variance_ratio' variables indexed by a
        'component' dimension; each EOF must squeeze to a (lat, lon) field.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. Also determines the panel layout.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If no EOF components are selected for plotting.
    """

    hemisphere = 'NH'
    region = 'atlantic2'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [0.0, 90.0]

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=90)

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    if components.size == 0:
        raise ValueError('No EOF components to plot')

    # One color scale for all panels (extrema over every plotted EOF), so the
    # single shared colorbar is valid for each of them.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    if n_eofs_to_plot % 4 == 0:
        ncols = 4
    elif n_eofs_to_plot % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_eofs_to_plot / ncols))

    # An extra, much shorter bottom row holds the colorbar.
    height_ratios = np.ones(nrows + 1)
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values
    lon_grid, lat_grid = np.meshgrid(lon, lat)

    # The clipping outline and map extent are identical for every panel, so
    # build them once. The outline is a closed lon/lat loop: down the western
    # bound, along the equator, up the eastern bound and back along 90N.
    left_bound_verts = np.vstack([np.repeat(lon_bounds[0], 100), np.flipud(np.linspace(0, 90.0, 100))]).T
    right_bound_verts = np.vstack([np.repeat(lon_bounds[1], 100), np.linspace(0, 90.0, 100)]).T
    lower_bound_verts = np.vstack([np.linspace(lon_bounds[0], lon_bounds[1], 100),
                                   np.repeat(0.0, 100)]).T
    upper_bound_verts = np.vstack([np.flipud(np.linspace(lon_bounds[0], lon_bounds[1], 100)),
                                   np.repeat(90.0, 100)]).T
    boundary = mpath.Path(np.vstack([left_bound_verts, lower_bound_verts, right_bound_verts, upper_bound_verts]))
    extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        # Panels are filled row by row.
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()
        ax.set_boundary(boundary, transform=ccrs.PlateCarree())
        ax.set_extent(extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): redraws the whole figure after every panel, as before;
        # confirm whether a single draw after the loop would suffice.
        fig.canvas.draw()

    # All panels share vmin/vmax, so the last mesh's colorbar applies to all
    # of them. `pad` is not passed: it is ignored when `cax` is supplied.
    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'NH'
region = 'atlantic2'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the netCDF file even if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_nh_atlantic2_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    # Release this season's figure explicitly before plotting the next one.
    plt.close(fig)

In [None]:
def plot_nh_atlantic3_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                           lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot NH Atlantic-3 region EOFs.

    Each of the leading EOF patterns is drawn in its own panel on an
    equidistant conic projection, clipped to the region's longitude bounds
    and to latitudes 0-90N. All panels share one color scale and a single
    horizontal colorbar placed below the grid of panels.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis. Must provide
        'eofs' and 'explained_variance_ratio' variables indexed by a
        'component' dimension; each EOF must squeeze to a (lat, lon) field.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. Also determines the panel layout.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If no EOF components are selected for plotting.
    """

    hemisphere = 'NH'
    region = 'atlantic3'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [0.0, 90.0]

    # NOTE(review): unlike the other regions, this does not use the plain mean
    # of the bounds; for bounds outside [0, 180) the result is not the
    # midpoint of the region. Confirm this is intended for 'atlantic3'.
    central_longitude = 0.5 * (lon_bounds[1] % 180 + lon_bounds[0] % 180) % 360
    projection = ccrs.EquidistantConic(central_longitude=central_longitude, central_latitude=90)

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    if components.size == 0:
        raise ValueError('No EOF components to plot')

    # One color scale for all panels (extrema over every plotted EOF), so the
    # single shared colorbar is valid for each of them.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    if n_eofs_to_plot % 4 == 0:
        ncols = 4
    elif n_eofs_to_plot % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_eofs_to_plot / ncols))

    # An extra, much shorter bottom row holds the colorbar.
    height_ratios = np.ones(nrows + 1)
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values
    lon_grid, lat_grid = np.meshgrid(lon, lat)

    # The clipping outline and map extent are identical for every panel, so
    # build them once. The outline is a closed lon/lat loop: down the western
    # bound, along the equator, up the eastern bound and back along 90N.
    left_bound_verts = np.vstack([np.repeat(lon_bounds[0], 100), np.flipud(np.linspace(0, 90.0, 100))]).T
    right_bound_verts = np.vstack([np.repeat(lon_bounds[1], 100), np.linspace(0, 90.0, 100)]).T
    lower_bound_verts = np.vstack([np.linspace(lon_bounds[0], lon_bounds[1], 100),
                                   np.repeat(0.0, 100)]).T
    upper_bound_verts = np.vstack([np.flipud(np.linspace(lon_bounds[0], lon_bounds[1], 100)),
                                   np.repeat(90.0, 100)]).T
    boundary = mpath.Path(np.vstack([left_bound_verts, lower_bound_verts, right_bound_verts, upper_bound_verts]))
    extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        # Panels are filled row by row.
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()
        ax.set_boundary(boundary, transform=ccrs.PlateCarree())
        ax.set_extent(extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): redraws the whole figure after every panel, as before;
        # confirm whether a single draw after the loop would suffice.
        fig.canvas.draw()

    # All panels share vmin/vmax, so the last mesh's colorbar applies to all
    # of them. `pad` is not passed: it is ignored when `cax` is supplied.
    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'NH'
region = 'atlantic3'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the netCDF file even if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_nh_atlantic3_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    # Release this season's figure explicitly before plotting the next one.
    plt.close(fig)

In [None]:
def plot_nh_eurasia_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                         lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot NH Eurasia region EOFs.

    Each of the leading EOF patterns is drawn in its own panel on an
    equidistant conic projection, clipped to the region's longitude bounds
    and to latitudes 0-90N. All panels share one color scale and a single
    horizontal colorbar placed below the grid of panels.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis. Must provide
        'eofs' and 'explained_variance_ratio' variables indexed by a
        'component' dimension; each EOF must squeeze to a (lat, lon) field.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. Also determines the panel layout.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If no EOF components are selected for plotting.
    """

    hemisphere = 'NH'
    region = 'eurasia'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [0.0, 90.0]

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=90)

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    if components.size == 0:
        raise ValueError('No EOF components to plot')

    # One color scale for all panels (extrema over every plotted EOF), so the
    # single shared colorbar is valid for each of them.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    if n_eofs_to_plot % 4 == 0:
        ncols = 4
    elif n_eofs_to_plot % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_eofs_to_plot / ncols))

    # An extra, much shorter bottom row holds the colorbar.
    height_ratios = np.ones(nrows + 1)
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values
    lon_grid, lat_grid = np.meshgrid(lon, lat)

    # The clipping outline and map extent are identical for every panel, so
    # build them once. The outline is a closed lon/lat loop: down the western
    # bound, along the equator, up the eastern bound and back along 90N.
    left_bound_verts = np.vstack([np.repeat(lon_bounds[0], 100), np.flipud(np.linspace(0, 90.0, 100))]).T
    right_bound_verts = np.vstack([np.repeat(lon_bounds[1], 100), np.linspace(0, 90.0, 100)]).T
    lower_bound_verts = np.vstack([np.linspace(lon_bounds[0], lon_bounds[1], 100),
                                   np.repeat(0.0, 100)]).T
    upper_bound_verts = np.vstack([np.flipud(np.linspace(lon_bounds[0], lon_bounds[1], 100)),
                                   np.repeat(90.0, 100)]).T
    boundary = mpath.Path(np.vstack([left_bound_verts, lower_bound_verts, right_bound_verts, upper_bound_verts]))
    extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        # Panels are filled row by row.
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()
        ax.set_boundary(boundary, transform=ccrs.PlateCarree())
        ax.set_extent(extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): redraws the whole figure after every panel, as before;
        # confirm whether a single draw after the loop would suffice.
        fig.canvas.draw()

    # All panels share vmin/vmax, so the last mesh's colorbar applies to all
    # of them. `pad` is not passed: it is ignored when `cax` is supplied.
    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'NH'
region = 'eurasia'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the netCDF file even if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_nh_eurasia_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    # Release this season's figure explicitly before plotting the next one.
    plt.close(fig)

In [None]:
def plot_nh_pacific_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                         lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot NH Pacific region EOFs.

    Each of the leading EOF patterns is drawn in its own panel on an
    equidistant conic projection, clipped to the region's longitude bounds
    and to latitudes 0-90N. All panels share one color scale and a single
    horizontal colorbar placed below the grid of panels.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis. Must provide
        'eofs' and 'explained_variance_ratio' variables indexed by a
        'component' dimension; each EOF must squeeze to a (lat, lon) field.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. Also determines the panel layout.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If no EOF components are selected for plotting.
    """

    hemisphere = 'NH'
    region = 'pacific'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [0.0, 90.0]

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=90)

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    if components.size == 0:
        raise ValueError('No EOF components to plot')

    # One color scale for all panels (extrema over every plotted EOF), so the
    # single shared colorbar is valid for each of them.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    if n_eofs_to_plot % 4 == 0:
        ncols = 4
    elif n_eofs_to_plot % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_eofs_to_plot / ncols))

    # An extra, much shorter bottom row holds the colorbar.
    height_ratios = np.ones(nrows + 1)
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values
    lon_grid, lat_grid = np.meshgrid(lon, lat)

    # The clipping outline and map extent are identical for every panel, so
    # build them once. The outline is a closed lon/lat loop: down the western
    # bound, along the equator, up the eastern bound and back along 90N.
    left_bound_verts = np.vstack([np.repeat(lon_bounds[0], 100), np.flipud(np.linspace(0, 90.0, 100))]).T
    right_bound_verts = np.vstack([np.repeat(lon_bounds[1], 100), np.linspace(0, 90.0, 100)]).T
    lower_bound_verts = np.vstack([np.linspace(lon_bounds[0], lon_bounds[1], 100),
                                   np.repeat(0.0, 100)]).T
    upper_bound_verts = np.vstack([np.flipud(np.linspace(lon_bounds[0], lon_bounds[1], 100)),
                                   np.repeat(90.0, 100)]).T
    boundary = mpath.Path(np.vstack([left_bound_verts, lower_bound_verts, right_bound_verts, upper_bound_verts]))
    extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        # Panels are filled row by row.
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()
        ax.set_boundary(boundary, transform=ccrs.PlateCarree())
        ax.set_extent(extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): redraws the whole figure after every panel, as before;
        # confirm whether a single draw after the loop would suffice.
        fig.canvas.draw()

    # All panels share vmin/vmax, so the last mesh's colorbar applies to all
    # of them. `pad` is not passed: it is ignored when `cax` is supplied.
    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'NH'
region = 'pacific'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the netCDF file even if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_nh_pacific_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    # Release this season's figure explicitly before plotting the next one.
    plt.close(fig)

In [None]:
def plot_nh_atlantic_eurasia_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                                  lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot NH Atlantic-Eurasia region EOFs.

    Longitudes above 180 (in both the region bounds and the data grid) are
    shifted down by 360, and every EOF field is reordered by increasing
    shifted longitude so the plotted grid is monotonic across the Greenwich
    meridian. These conversions are done on copies; neither the input
    dataset nor the stored region definition is modified.

    Each EOF is drawn in its own panel on an equidistant conic projection
    centered on longitude 0, clipped to the region bounds and 0-90N, with a
    single shared color scale and horizontal colorbar.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis. Must provide
        'eofs' and 'explained_variance_ratio' variables indexed by a
        'component' dimension; each EOF must squeeze to a (lat, lon) field.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. Also determines the panel layout.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If no EOF components are selected for plotting.
    """

    hemisphere = 'NH'
    region = 'atlantic_eurasia'

    # np.where builds a new array, so whatever get_region_lon_bounds returned
    # (possibly a shared region definition) is left untouched.
    raw_lon_bounds = np.asarray(get_region_lon_bounds(hemisphere, region))
    lon_bounds = np.where(raw_lon_bounds > 180, raw_lon_bounds - 360, raw_lon_bounds)
    lat_bounds = [0.0, 90.0]

    projection = ccrs.EquidistantConic(central_longitude=0, central_latitude=90)

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    if components.size == 0:
        raise ValueError('No EOF components to plot')

    # One color scale for all panels (extrema over every plotted EOF), so the
    # single shared colorbar is valid for each of them.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    if n_eofs_to_plot % 4 == 0:
        ncols = 4
    elif n_eofs_to_plot % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_eofs_to_plot / ncols))

    # An extra, much shorter bottom row holds the colorbar.
    height_ratios = np.ones(nrows + 1)
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 3 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.01,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values

    # Shift longitudes into a new array. Subtracting in place on `.values`
    # would write into the dataset's index coordinate, or fail if that array
    # is read-only.
    raw_lon = eofs_ds[lon_name].values
    lon = np.where(raw_lon > 180, raw_lon - 360, raw_lon)

    # Column permutation that sorts the shifted longitudes; computed once and
    # applied to every EOF field below.
    lon_order = np.argsort(lon)
    lon_grid, lat_grid = np.meshgrid(lon[lon_order], lat)

    # The clipping outline and map extent are identical for every panel, so
    # build them once. The outline is a closed lon/lat loop: down the western
    # bound, along the equator, up the eastern bound and back along 90N.
    left_bound_verts = np.vstack([np.repeat(lon_bounds[0], 100), np.flipud(np.linspace(0, 90.0, 100))]).T
    right_bound_verts = np.vstack([np.repeat(lon_bounds[1], 100), np.linspace(0, 90.0, 100)]).T
    lower_bound_verts = np.vstack([np.linspace(lon_bounds[0], lon_bounds[1], 100), np.repeat(0.0, 100)]).T
    upper_bound_verts = np.vstack([np.flipud(np.linspace(lon_bounds[0], lon_bounds[1], 100)),
                                   np.repeat(90.0, 100)]).T
    boundary = mpath.Path(np.vstack([left_bound_verts, lower_bound_verts, right_bound_verts, upper_bound_verts]))
    extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        # Panels are filled row by row.
        row_index, col_index = divmod(i, ncols)

        # Fancy indexing returns a reordered copy; the dataset is not modified.
        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values[:, lon_order]
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()
        ax.set_boundary(boundary, transform=ccrs.PlateCarree())
        ax.set_extent(extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)
        ax.set_aspect('equal')

        # NOTE(review): redraws the whole figure after every panel, as before;
        # confirm whether a single draw after the loop would suffice.
        fig.canvas.draw()

    # All panels share vmin/vmax, so the last mesh's colorbar applies to all
    # of them. `pad` is not passed: it is ignored when `cax` is supplied.
    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'NH'
region = 'atlantic_eurasia'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the netCDF file even if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_nh_atlantic_eurasia_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    # Release this season's figure explicitly before plotting the next one.
    plt.close(fig)

### Southern Hemisphere

In [None]:
def plot_sh_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                 lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot SH EOFs on a south-polar orthographic projection.

    All panels share a single colour scale, spanning the minimum and maximum
    loading over the plotted components, so that the one colorbar drawn
    beneath the panels applies to every EOF.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis, with 'eofs' and
        'explained_variance_ratio' variables indexed by 'component'.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. If the dataset holds fewer
        components, only those available are plotted and the panel grid
        is sized to match.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If the dataset contains no EOF components to plot.
    """

    hemisphere = 'SH'
    region = 'all'

    wrap_lon = True

    projection = ccrs.Orthographic(central_longitude=0, central_latitude=-90)

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    n_components = len(components)
    if n_components == 0:
        raise ValueError('No EOF components available to plot')

    # A single colour scale is shared by all panels so that the one colorbar
    # drawn beneath them is valid for every EOF.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    # Lay the panels out on a grid sized to the number of components plotted.
    if n_components % 4 == 0:
        ncols = 4
    elif n_components % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_components / ncols))
    height_ratios = np.ones(nrows + 1)
    # The extra, shorter final row holds the shared colorbar.
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.01, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values

    for i, component in enumerate(components):
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        if wrap_lon:
            # Append a cyclic column so the global plot has no gap where
            # the longitude grid wraps around.
            eof_data, eof_lon = add_cyclic_point(eof_data, coord=lon)
        else:
            eof_lon = lon

        lon_grid, lat_grid = np.meshgrid(eof_lon, lat)

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()
        ax.set_global()

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)
        ax.set_aspect('equal')
        # NOTE(review): this redraws the whole figure after every panel;
        # confirm it is needed before moving a single draw after the loop.
        fig.canvas.draw()

    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'SH'
region = 'all'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

# Plot (and, if SAVE_PLOTS is set, save) the leading EOFs for each season.
for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the file as soon as the figure is built,
    # and also if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_sh_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        # Save the returned figure explicitly rather than pyplot's "current" figure.
        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    plt.close(fig)

In [None]:
def plot_sh_australian_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                            lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot SH Australian region EOFs.

    Each EOF is drawn on an equidistant conic projection whose visible area
    is clipped to the region's longitude bounds (from
    ``get_region_lon_bounds``) and to latitudes between 90S and the equator.
    All panels share a single colour scale, spanning the minimum and maximum
    loading over the plotted components, so one colorbar applies to every
    panel.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis, with 'eofs' and
        'explained_variance_ratio' variables indexed by 'component'.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. If the dataset holds fewer
        components, only those available are plotted and the panel grid
        is sized to match.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If the dataset contains no EOF components to plot.
    """

    hemisphere = 'SH'
    region = 'australian'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [-90.0, 0.0]

    wrap_lon = False

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=-90,
                                       standard_parallels=(-20, -50))

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    n_components = len(components)
    if n_components == 0:
        raise ValueError('No EOF components available to plot')

    # A single colour scale is shared by all panels so that the one colorbar
    # drawn beneath them is valid for every EOF.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    # Lay the panels out on a grid sized to the number of components plotted.
    if n_components % 4 == 0:
        ncols = 4
    elif n_components % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_components / ncols))
    height_ratios = np.ones(nrows + 1)
    # The extra, shorter final row holds the shared colorbar.
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values

    # The map boundary is identical for every panel, so build it once: the
    # four edges of the region's lon/lat box traced as a closed loop in
    # PlateCarree coordinates.
    n_verts = 100
    box_lons = np.linspace(lon_bounds[0], lon_bounds[1], n_verts)
    box_lats = np.linspace(lat_bounds[0], lat_bounds[1], n_verts)
    boundary_verts = np.vstack([
        np.column_stack([np.repeat(lon_bounds[0], n_verts), box_lats[::-1]]),
        np.column_stack([box_lons, np.repeat(lat_bounds[0], n_verts)]),
        np.column_stack([np.repeat(lon_bounds[1], n_verts), box_lats]),
        np.column_stack([box_lons[::-1], np.repeat(lat_bounds[1], n_verts)])])
    map_extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        if wrap_lon:
            eof_data, eof_lon = add_cyclic_point(eof_data, coord=lon)
        else:
            eof_lon = lon

        lon_grid, lat_grid = np.meshgrid(eof_lon, lat)

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()

        # Clip the visible map to the region's lon/lat box.
        ax.set_boundary(mpath.Path(boundary_verts), transform=ccrs.PlateCarree())
        ax.set_extent(map_extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): this redraws the whole figure after every panel;
        # confirm it is needed before moving a single draw after the loop.
        fig.canvas.draw()

    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'SH'
region = 'australian'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

# Plot (and, if SAVE_PLOTS is set, save) the leading EOFs for each season.
for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the file as soon as the figure is built,
    # and also if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_sh_australian_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        # Save the returned figure explicitly rather than pyplot's "current" figure.
        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    plt.close(fig)

In [None]:
def plot_sh_full_pacific_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                              lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot SH full Pacific region EOFs.

    Each EOF is drawn on an equidistant conic projection whose visible area
    is clipped to the region's longitude bounds (from
    ``get_region_lon_bounds``) and to latitudes between 90S and the equator.
    All panels share a single colour scale, spanning the minimum and maximum
    loading over the plotted components, so one colorbar applies to every
    panel.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis, with 'eofs' and
        'explained_variance_ratio' variables indexed by 'component'.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. If the dataset holds fewer
        components, only those available are plotted and the panel grid
        is sized to match.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If the dataset contains no EOF components to plot.
    """

    hemisphere = 'SH'
    region = 'full_pacific'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [-90.0, 0.0]

    wrap_lon = False

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=-90,
                                       standard_parallels=(-20, -50))

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    n_components = len(components)
    if n_components == 0:
        raise ValueError('No EOF components available to plot')

    # A single colour scale is shared by all panels so that the one colorbar
    # drawn beneath them is valid for every EOF.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    # Lay the panels out on a grid sized to the number of components plotted.
    if n_components % 4 == 0:
        ncols = 4
    elif n_components % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_components / ncols))
    height_ratios = np.ones(nrows + 1)
    # The extra, shorter final row holds the shared colorbar.
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values

    # The map boundary is identical for every panel, so build it once: the
    # four edges of the region's lon/lat box traced as a closed loop in
    # PlateCarree coordinates.
    n_verts = 100
    box_lons = np.linspace(lon_bounds[0], lon_bounds[1], n_verts)
    box_lats = np.linspace(lat_bounds[0], lat_bounds[1], n_verts)
    boundary_verts = np.vstack([
        np.column_stack([np.repeat(lon_bounds[0], n_verts), box_lats[::-1]]),
        np.column_stack([box_lons, np.repeat(lat_bounds[0], n_verts)]),
        np.column_stack([np.repeat(lon_bounds[1], n_verts), box_lats]),
        np.column_stack([box_lons[::-1], np.repeat(lat_bounds[1], n_verts)])])
    map_extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        if wrap_lon:
            eof_data, eof_lon = add_cyclic_point(eof_data, coord=lon)
        else:
            eof_lon = lon

        lon_grid, lat_grid = np.meshgrid(eof_lon, lat)

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()

        # Clip the visible map to the region's lon/lat box.
        ax.set_boundary(mpath.Path(boundary_verts), transform=ccrs.PlateCarree())
        ax.set_extent(map_extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): this redraws the whole figure after every panel;
        # confirm it is needed before moving a single draw after the loop.
        fig.canvas.draw()

    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'SH'
region = 'full_pacific'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

# Plot (and, if SAVE_PLOTS is set, save) the leading EOFs for each season.
for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the file as soon as the figure is built,
    # and also if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_sh_full_pacific_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        # Save the returned figure explicitly rather than pyplot's "current" figure.
        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    plt.close(fig)

In [None]:
def plot_sh_indian_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                        lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot SH Indian region EOFs.

    Each EOF is drawn on an equidistant conic projection whose visible area
    is clipped to the region's longitude bounds (from
    ``get_region_lon_bounds``) and to latitudes between 90S and the equator.
    All panels share a single colour scale, spanning the minimum and maximum
    loading over the plotted components, so one colorbar applies to every
    panel.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis, with 'eofs' and
        'explained_variance_ratio' variables indexed by 'component'.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. If the dataset holds fewer
        components, only those available are plotted and the panel grid
        is sized to match.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If the dataset contains no EOF components to plot.
    """

    hemisphere = 'SH'
    region = 'indian'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [-90.0, 0.0]

    wrap_lon = False

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=-90,
                                       standard_parallels=(-20, -50))

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    n_components = len(components)
    if n_components == 0:
        raise ValueError('No EOF components available to plot')

    # A single colour scale is shared by all panels so that the one colorbar
    # drawn beneath them is valid for every EOF.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    # Lay the panels out on a grid sized to the number of components plotted.
    if n_components % 4 == 0:
        ncols = 4
    elif n_components % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_components / ncols))
    height_ratios = np.ones(nrows + 1)
    # The extra, shorter final row holds the shared colorbar.
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values

    # The map boundary is identical for every panel, so build it once: the
    # four edges of the region's lon/lat box traced as a closed loop in
    # PlateCarree coordinates.
    n_verts = 100
    box_lons = np.linspace(lon_bounds[0], lon_bounds[1], n_verts)
    box_lats = np.linspace(lat_bounds[0], lat_bounds[1], n_verts)
    boundary_verts = np.vstack([
        np.column_stack([np.repeat(lon_bounds[0], n_verts), box_lats[::-1]]),
        np.column_stack([box_lons, np.repeat(lat_bounds[0], n_verts)]),
        np.column_stack([np.repeat(lon_bounds[1], n_verts), box_lats]),
        np.column_stack([box_lons[::-1], np.repeat(lat_bounds[1], n_verts)])])
    map_extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        if wrap_lon:
            eof_data, eof_lon = add_cyclic_point(eof_data, coord=lon)
        else:
            eof_lon = lon

        lon_grid, lat_grid = np.meshgrid(eof_lon, lat)

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()

        # Clip the visible map to the region's lon/lat box.
        ax.set_boundary(mpath.Path(boundary_verts), transform=ccrs.PlateCarree())
        ax.set_extent(map_extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): this redraws the whole figure after every panel;
        # confirm it is needed before moving a single draw after the loop.
        fig.canvas.draw()

    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'SH'
region = 'indian'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

# Plot (and, if SAVE_PLOTS is set, save) the leading EOFs for each season.
for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the file as soon as the figure is built,
    # and also if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_sh_indian_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        # Save the returned figure explicitly rather than pyplot's "current" figure.
        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    plt.close(fig)

In [None]:
def plot_sh_pacific_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                         lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot SH Pacific region EOFs.

    Each EOF is drawn on an equidistant conic projection whose visible area
    is clipped to the region's longitude bounds (from
    ``get_region_lon_bounds``) and to latitudes between 90S and the equator.
    All panels share a single colour scale, spanning the minimum and maximum
    loading over the plotted components, so one colorbar applies to every
    panel.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis, with 'eofs' and
        'explained_variance_ratio' variables indexed by 'component'.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. If the dataset holds fewer
        components, only those available are plotted and the panel grid
        is sized to match.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If the dataset contains no EOF components to plot.
    """

    hemisphere = 'SH'
    region = 'pacific'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [-90.0, 0.0]

    wrap_lon = False

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=-90,
                                       standard_parallels=(-20, -50))

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    n_components = len(components)
    if n_components == 0:
        raise ValueError('No EOF components available to plot')

    # A single colour scale is shared by all panels so that the one colorbar
    # drawn beneath them is valid for every EOF.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    # Lay the panels out on a grid sized to the number of components plotted.
    if n_components % 4 == 0:
        ncols = 4
    elif n_components % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_components / ncols))
    height_ratios = np.ones(nrows + 1)
    # The extra, shorter final row holds the shared colorbar.
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values

    # The map boundary is identical for every panel, so build it once: the
    # four edges of the region's lon/lat box traced as a closed loop in
    # PlateCarree coordinates.
    n_verts = 100
    box_lons = np.linspace(lon_bounds[0], lon_bounds[1], n_verts)
    box_lats = np.linspace(lat_bounds[0], lat_bounds[1], n_verts)
    boundary_verts = np.vstack([
        np.column_stack([np.repeat(lon_bounds[0], n_verts), box_lats[::-1]]),
        np.column_stack([box_lons, np.repeat(lat_bounds[0], n_verts)]),
        np.column_stack([np.repeat(lon_bounds[1], n_verts), box_lats]),
        np.column_stack([box_lons[::-1], np.repeat(lat_bounds[1], n_verts)])])
    map_extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        if wrap_lon:
            eof_data, eof_lon = add_cyclic_point(eof_data, coord=lon)
        else:
            eof_lon = lon

        lon_grid, lat_grid = np.meshgrid(eof_lon, lat)

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()

        # Clip the visible map to the region's lon/lat box.
        ax.set_boundary(mpath.Path(boundary_verts), transform=ccrs.PlateCarree())
        ax.set_extent(map_extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): this redraws the whole figure after every panel;
        # confirm it is needed before moving a single draw after the loop.
        fig.canvas.draw()

    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'SH'
region = 'pacific'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

# Plot (and, if SAVE_PLOTS is set, save) the leading EOFs for each season.
for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the file as soon as the figure is built,
    # and also if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_sh_pacific_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        # Save the returned figure explicitly rather than pyplot's "current" figure.
        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    plt.close(fig)

In [None]:
def plot_sh_south_america_eofs(eofs_ds, season, n_eofs_to_plot=20, cmap=plt.cm.RdBu_r,
                               lat_name=LAT_NAME, lon_name=LON_NAME):
    """Plot SH South America region EOFs.

    Each EOF is drawn on an equidistant conic projection whose visible area
    is clipped to the region's longitude bounds (from
    ``get_region_lon_bounds``) and to latitudes between 90S and the equator.
    All panels share a single colour scale, spanning the minimum and maximum
    loading over the plotted components, so one colorbar applies to every
    panel.

    Parameters
    ----------
    eofs_ds : xarray.Dataset
        Dataset containing the results of the EOF analysis, with 'eofs' and
        'explained_variance_ratio' variables indexed by 'component'.

    season : 'ALL' | 'DJF' | 'MAM' | 'JJA' | 'SON'
        Season to plot EOFs for (used in the figure title).

    n_eofs_to_plot : int, default: 20
        Number of EOF patterns to plot. If the dataset holds fewer
        components, only those available are plotted and the panel grid
        is sized to match.

    cmap : object
        Colormap to use.

    lat_name : str
        Name of latitude coordinate.

    lon_name : str
        Name of longitude coordinate.

    Returns
    -------
    fig : figure
        Plot of EOFs.

    Raises
    ------
    ValueError
        If the dataset contains no EOF components to plot.
    """

    hemisphere = 'SH'
    region = 'south_america'

    lon_bounds = get_region_lon_bounds(hemisphere, region)
    lat_bounds = [-90.0, 0.0]

    wrap_lon = False

    projection = ccrs.EquidistantConic(central_longitude=np.mean(lon_bounds), central_latitude=-90,
                                       standard_parallels=(-20, -50))

    components = eofs_ds['component'].isel(component=slice(0, n_eofs_to_plot)).values
    n_components = len(components)
    if n_components == 0:
        raise ValueError('No EOF components available to plot')

    # A single colour scale is shared by all panels so that the one colorbar
    # drawn beneath them is valid for every EOF.
    plotted_eofs = eofs_ds['eofs'].sel(component=components)
    vmin = plotted_eofs.min().item()
    vmax = plotted_eofs.max().item()

    # Lay the panels out on a grid sized to the number of components plotted.
    if n_components % 4 == 0:
        ncols = 4
    elif n_components % 2 == 0:
        ncols = 2
    else:
        ncols = 3

    nrows = int(np.ceil(n_components / ncols))
    height_ratios = np.ones(nrows + 1)
    # The extra, shorter final row holds the shared colorbar.
    height_ratios[-1] = 0.1

    fig = plt.figure(constrained_layout=False, figsize=(4 * ncols, 4 * nrows))

    gs = gridspec.GridSpec(ncols=ncols, nrows=nrows + 1, figure=fig,
                           wspace=0.1, hspace=0.2,
                           height_ratios=height_ratios)

    lat = eofs_ds[lat_name].values
    lon = eofs_ds[lon_name].values

    # The map boundary is identical for every panel, so build it once: the
    # four edges of the region's lon/lat box traced as a closed loop in
    # PlateCarree coordinates.
    n_verts = 100
    box_lons = np.linspace(lon_bounds[0], lon_bounds[1], n_verts)
    box_lats = np.linspace(lat_bounds[0], lat_bounds[1], n_verts)
    boundary_verts = np.vstack([
        np.column_stack([np.repeat(lon_bounds[0], n_verts), box_lats[::-1]]),
        np.column_stack([box_lons, np.repeat(lat_bounds[0], n_verts)]),
        np.column_stack([np.repeat(lon_bounds[1], n_verts), box_lats]),
        np.column_stack([box_lons[::-1], np.repeat(lat_bounds[1], n_verts)])])
    map_extent = np.concatenate([lon_bounds, lat_bounds])

    for i, component in enumerate(components):
        row_index, col_index = divmod(i, ncols)

        eof_data = eofs_ds['eofs'].sel(component=component).squeeze().values
        expl_var = eofs_ds['explained_variance_ratio'].sel(
            component=component).item()

        if wrap_lon:
            eof_data, eof_lon = add_cyclic_point(eof_data, coord=lon)
        else:
            eof_lon = lon

        lon_grid, lat_grid = np.meshgrid(eof_lon, lat)

        ax = fig.add_subplot(gs[row_index, col_index], projection=projection)

        ax.coastlines()

        # Clip the visible map to the region's lon/lat box.
        ax.set_boundary(mpath.Path(boundary_verts), transform=ccrs.PlateCarree())
        ax.set_extent(map_extent, crs=ccrs.PlateCarree())

        cs = ax.pcolor(lon_grid, lat_grid, eof_data, vmin=vmin, vmax=vmax,
                       cmap=cmap, transform=ccrs.PlateCarree())

        ax.set_title('EOF {} ({:.2f}%)'.format(
            component + 1, expl_var * 100), fontsize=14)

        # NOTE(review): this redraws the whole figure after every panel;
        # confirm it is needed before moving a single draw after the loop.
        fig.canvas.draw()

    cb_ax = fig.add_subplot(gs[-1, :])
    fig.colorbar(cs, cax=cb_ax, orientation='horizontal')

    fig.suptitle('Hemisphere: {}, Region: {}, Season: {}'.format(hemisphere, region, season),
                 fontsize=16, x=0.5, y=0.92)

    return fig

In [None]:
hemisphere = 'SH'
region = 'south_america'
seasons = ['ALL', 'DJF', 'MAM', 'JJA', 'SON']

n_eofs_to_plot = 20

# Plot (and, if SAVE_PLOTS is set, save) the leading EOFs for each season.
for season in seasons:

    eofs_datafile = os.path.join(
        EOFS_NC_DIR,
        get_eofs_output_filename(
            anom_output_file, hemisphere, region, EOFS_BASE_PERIOD,
            season, LAT_WEIGHTS, DEFAULT_MAX_EOFS, EOFS_NORMALIZATION))

    # The context manager closes the file as soon as the figure is built,
    # and also if plotting raises.
    with xr.open_dataset(eofs_datafile) as eofs_ds:
        fig = plot_sh_south_america_eofs(eofs_ds, season, n_eofs_to_plot=n_eofs_to_plot)

    if SAVE_PLOTS:

        output_filename = get_eof_plot_output_filename(eofs_datafile, n_eofs_to_plot)
        output_file = os.path.join(EOFS_PLOTS_DIR, output_filename)

        # Save the returned figure explicitly rather than pyplot's "current" figure.
        fig.savefig(output_file, bbox_inches='tight')

    plt.show()

    plt.close(fig)