In [1]:
import netCDF4 as nc
import numpy
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
import os
import cartopy.feature as cfeature
import glob
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib import cm
#import seaborn as sns
import cartopy.crs as ccrs
import cartopy.feature as cf
import mplotutils as mpu # helper functions for cartopy and matplotlib
#import regionmask
from func_plots import *
from func_stats import *

In [2]:
# Regional map styling: coastlines, a 10-degree graticule and a fixed plot window
def format_axes(axes, extent=None, crs=None):
    """Draw coastlines and a 10° graticule on each map axis and zoom it to a window.

    Parameters
    ----------
    axes : iterable of cartopy GeoAxes
        E.g. ``axes.flatten()`` from ``plt.subplots``.
    extent : sequence of 4 floats, optional
        ``[lonmin, lonmax, latmin, latmax]`` of the window. Defaults to the
        module-level ``lonmin``, ``lonmax``, ``latmin``, ``latmax`` (looked up
        when the function is called).
    crs : cartopy CRS, optional
        CRS in which ``extent`` is expressed. Defaults to the module-level
        ``data_proj``.
    """
    if extent is None:
        extent = [lonmin, lonmax, latmin, latmax]
    if crs is None:
        crs = data_proj
    for ax in axes:
        ax.coastlines(resolution='50m', linewidth=0.45)
        # Country borders are off; enable with ax.add_feature(cf.BORDERS, linewidth=0.3)
        ax.set_extent(extent, crs=crs)
        ax.gridlines(draw_labels=False, linewidth=0.3, color="gray",
                     xlocs=range(-180, 180, 10), ylocs=range(-90, 90, 10))

def add_headers(fig, *, row_headers=None, col_headers=None, row_pad=1, col_pad=4, rotate_row_headers=True, **text_kwargs):
    """Annotate a subplot grid with column headers above the first row and row
    headers to the left of the first column.

    Based on https://stackoverflow.com/a/25814386

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure whose grid-placed axes are annotated.
    row_headers, col_headers : sequence of str, optional
        Indexed by the starting row / column of each axis' subplot spec.
    row_pad, col_pad : float
        Extra offset (in points) of the header from the axis edge.
    rotate_row_headers : bool
        Rotate row headers by 90 degrees when True.
    **text_kwargs
        Forwarded to ``Axes.annotate`` (e.g. ``fontsize``).
    """
    for ax in fig.get_axes():
        # Axes that are not on a GridSpec (e.g. colorbar axes created with
        # fig.add_axes) return None here, or may lack the method entirely;
        # they have no row/column, so skip them instead of crashing.
        get_spec = getattr(ax, "get_subplotspec", None)
        sbs = get_spec() if get_spec is not None else None
        if sbs is None:
            continue

        # Putting headers on cols
        if (col_headers is not None) and sbs.is_first_row():
            ax.annotate(
                col_headers[sbs.colspan.start],
                xy=(0.5, 1),
                xytext=(0, col_pad),
                xycoords="axes fraction",
                textcoords="offset points",
                ha="center",
                va="baseline",
                **text_kwargs,
            )

        # Putting headers on rows
        if (row_headers is not None) and sbs.is_first_col():
            ax.annotate(
                row_headers[sbs.rowspan.start],
                xy=(0, 0.5),
                xytext=(-ax.yaxis.labelpad - row_pad, 0),
                xycoords=ax.transAxes,  # anchor to the axes box, not the y-label
                textcoords="offset points",
                ha="center",
                va="center",  # stays vertically centred after rotation
                rotation=rotate_row_headers * 90,
                **text_kwargs,
            )

# Plot window in degrees (west, east, south, north) applied to every map.
lonmin, lonmax, latmin, latmax = -11, 37, 35, 70.5
# Wider box meant for subsetting the data to derive sensible colorbar vmin/vmax.
set_lonmin, set_lonmax, set_latmin, set_latmax = -35, 65, 30, 72.6

# CRS of the lat/lon data, and the projection used to draw the regional maps.
data_proj = ccrs.PlateCarree()
map_proj = ccrs.LambertConformal(central_longitude=15)

In [67]:
# Grid coordinates read from a control-run time-mean file; they become the
# "y"/"x" coordinates of every DataArray built by the loaders below.
filepath = '/landclim2/yiyaoy/COSMO-CLM2-simulations/timmean/clm5.0_eur0.5_control_h0_2020-2024.nc_timmean'
with nc.Dataset(filepath) as ds:
    lat, lon = (ds.variables[coord][:] for coord in ('lat', 'lon'))

# Wrap longitudes from [0, 360) into [-180, 180) in one vectorised expression.
lon = ((lon + 180) % 360) - 180

In [6]:
def get_data_mean_for_plot(var, lat, lon):
    """
    Load 2025-2059 time-mean fields of one variable from the ``h0`` history tape.

    Reads the annual, JJA and MAM time-mean files for the control,
    only_forest_broadleaf and all_grass runs into one xarray.Dataset keyed by:
      control, control_JJA, control_MAM,
      only_forest_broadleaf, only_forest_broadleaf_JJA, only_forest_broadleaf_MAM,
      all_grass, all_grass_JJA, all_grass_MAM

    Parameters
    ----------
    var : str
        Variable name inside the netCDF files (e.g. 'TSA').
    lat, lon : 1-D array-like
        Coordinates attached to the "y" and "x" dimensions of every field.

    Returns
    -------
    xarray.Dataset
        Float64 2-D fields with dims ("y", "x"). Masked cells and values
        above 1e9 (fill values) are NaN.
    """
    scenarios = {
        "control":                  "control/clm5",
        "only_forest_broadleaf":    "only_forest_broadleaf/clm5",
        "all_grass":                "all_grass/clm5",
    }
    seasons = {
        "":     "timmean",
        "_JJA": "seasonal/timmean",
        "_MAM": "seasonal/timmean",
    }

    template = (
        "/landclim2/yiyaoy/COSMO-CLM2-simulations/{folder}/post-processing/{proc}/"
        "mergetime/clm5.0_eur0.5_{scen}_h0_2025-2059{suffix}.nc_timmean_timmean"
    )

    data_vars = {}
    for scen, folder in scenarios.items():
        for suffix, proc in seasons.items():
            key  = f"{scen}{suffix}"
            path = template.format(folder=folder, proc=proc, scen=scen, suffix=suffix)

            # Read inside a context manager so every file handle is closed.
            with nc.Dataset(path) as src:
                raw = src.variables[var][:]

            # netCDF4 returns a masked array by default: cast to float so NaN
            # can be stored even for integer variables, turn masked cells into
            # NaN, then also blank out raw fill values such as 1e36.
            arr = np.squeeze(np.ma.filled(np.ma.asarray(raw, dtype=float), np.nan))
            arr[arr > 1e9] = np.nan
            data_vars[key] = xr.DataArray(arr, coords={"y": lat, "x": lon}, dims=["y", "x"])

    return xr.Dataset(data_vars)

In [9]:
# Time-mean temperature fields for the three figure rows: TSA -> "air",
# TSKIN -> "sfc", TBOT -> "atm" (loaded in that order).
tair_M, tsfc_M, tatm_M = (get_data_mean_for_plot(name, lat, lon) for name in ('TSA', 'TSKIN', 'TBOT'))

In [None]:
# Annual-mean (yearM) temperatures: control run (column 1) and scenario
# differences (columns 2-4) for the sfc, air and atm temperature rows.

# 1) contour settings
#    "seq": absolute-temperature levels in K (°C values + 273.15), labelled in °C;
#    "div": difference levels (a difference in K equals one in °C).
levels_dict = {
    "seq": np.array([-5, -2, 1, 4, 7, 10, 13, 16]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2])
}
ticks_dict = {
    "seq": ["\N{MINUS SIGN}5", "\N{MINUS SIGN}2", "1", "4", "7", "10", "13", "16"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0']
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = "$\\mathregular{^\\circ C}$"
# No vmin/vmax: with an explicit `levels` array xarray takes the colour limits
# from levels[0] and levels[-1], so vmin/vmax passed alongside are ignored.
extend = "both"

# 2) one dataset per figure row, top to bottom
datasets = {
    "sfc": tsfc_M,
    "air": tair_M,
    "atm": tatm_M,
}

# 3) one entry per figure column:
#    (high key, low key or None for the absolute field, levels-type, title prefix).
#    Raw strings keep mathtext backslashes literal; in a plain string "\m" is an
#    invalid escape sequence (SyntaxWarning since Python 3.12).
season, stat = "", "yearM"
columns = [
    ("control" + season,               None,                 "seq", "Ctl: "),
    ("only_forest_broadleaf" + season, "control" + season,   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass" + season,             "control" + season,   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf" + season, "all_grass" + season, "div", r"Brd–Def: $\Delta$"),
]
#    Full panel list:
#    (dataset key, high_key, low_key, row, col, levels-type, title-stem, panel-letter)
panels = [
    (dset_key, hi, lo, r, c, lvl_type,
     prefix + r"$\mathregular{T_{" + f"{dset_key},{stat}" + r"}}$",
     "abcdefghijkl"[r * len(columns) + c])
    for r, dset_key in enumerate(datasets)
    for c, (hi, lo, lvl_type, prefix) in enumerate(columns)
]

# 4) build the figure: one row per dataset, one column per comparison
fig, axes = plt.subplots(
    len(datasets), len(columns),
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) draw each panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # every panel gets its own colorbar, ticked at the contour levels
    cbar = mpu.colorbar(h, ax, orientation="vertical", size=0.04, extend=extend, pad=0.05, ticks=levels)
    cbar.set_ticklabels(ticks)

    ax.text(0.02, 0.93, letter, transform=ax.transAxes, fontsize=10, fontweight="bold")

In [54]:
# Upward and downward shortwave (loaded in that order); the net field is
# down minus up, computed key-by-key across the two Datasets.
SWup_M, SWdown_M = (get_data_mean_for_plot(name, lat, lon) for name in ('SWup', 'SWdown'))
SWnet_M = SWdown_M - SWup_M

In [None]:
# JJA-mean (jjaM) temperatures: control run (column 1) and scenario
# differences (columns 2-4) for the sfc, air and atm temperature rows.

# 1) contour settings
#    "seq": absolute-temperature levels in K (°C values + 273.15), labelled in °C;
#    "div": difference levels (a difference in K equals one in °C).
levels_dict = {
    "seq": np.array([5, 8, 11, 14, 17, 20, 23, 26]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2])
}
ticks_dict = {
    "seq": ["5", "8", "11", "14", "17", "20", "23", "26"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0']
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = "$\\mathregular{^\\circ C}$"
# No vmin/vmax: with an explicit `levels` array xarray takes the colour limits
# from levels[0] and levels[-1], so vmin/vmax passed alongside are ignored.
extend = "both"

# 2) one dataset per figure row, top to bottom
datasets = {
    "sfc": tsfc_M,
    "air": tair_M,
    "atm": tatm_M,
}

# 3) one entry per figure column:
#    (high key, low key or None for the absolute field, levels-type, title prefix).
#    Raw strings keep mathtext backslashes literal; in a plain string "\m" is an
#    invalid escape sequence (SyntaxWarning since Python 3.12).
season, stat = "_JJA", "jjaM"
columns = [
    ("control" + season,               None,                 "seq", "Ctl: "),
    ("only_forest_broadleaf" + season, "control" + season,   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass" + season,             "control" + season,   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf" + season, "all_grass" + season, "div", r"Brd–Def: $\Delta$"),
]
#    Full panel list:
#    (dataset key, high_key, low_key, row, col, levels-type, title-stem, panel-letter)
panels = [
    (dset_key, hi, lo, r, c, lvl_type,
     prefix + r"$\mathregular{T_{" + f"{dset_key},{stat}" + r"}}$",
     "abcdefghijkl"[r * len(columns) + c])
    for r, dset_key in enumerate(datasets)
    for c, (hi, lo, lvl_type, prefix) in enumerate(columns)
]

# 4) build the figure: one row per dataset, one column per comparison
fig, axes = plt.subplots(
    len(datasets), len(columns),
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) draw each panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # every panel gets its own colorbar, ticked at the contour levels
    cbar = mpu.colorbar(h, ax, orientation="vertical", size=0.04, extend=extend, pad=0.05, ticks=levels)
    cbar.set_ticklabels(ticks)

    ax.text(0.02, 0.93, letter, transform=ax.transAxes, fontsize=10, fontweight="bold")

In [None]:
# Heat-flux and longwave fields (latent and sensible heat flux, judging by the
# LHF/SHF names), loaded in the order EFLX_LH_TOT, FSH, LWdown.
LHF_M, SHF_M, LWdown_M = (get_data_mean_for_plot(name, lat, lon) for name in ('EFLX_LH_TOT', 'FSH', 'LWdown'))

In [None]:
# MAM-mean (mamM) temperatures: control run (column 1) and scenario
# differences (columns 2-4) for the sfc, air and atm temperature rows.

# 1) contour settings
#    "seq": absolute-temperature levels in K (°C values + 273.15), labelled in °C;
#    "div": difference levels (a difference in K equals one in °C).
levels_dict = {
    "seq": np.array([-8, -5, -2, 1, 4, 7, 10, 13]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2])
}
ticks_dict = {
    "seq": ["\N{MINUS SIGN}8", "\N{MINUS SIGN}5", "\N{MINUS SIGN}2", "1", "4", "7", "10", "13"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0']
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = "$\\mathregular{^\\circ C}$"
# No vmin/vmax: with an explicit `levels` array xarray takes the colour limits
# from levels[0] and levels[-1], so vmin/vmax passed alongside are ignored.
extend = "both"

# 2) one dataset per figure row, top to bottom
datasets = {
    "sfc": tsfc_M,
    "air": tair_M,
    "atm": tatm_M,
}

# 3) one entry per figure column:
#    (high key, low key or None for the absolute field, levels-type, title prefix).
#    Raw strings keep mathtext backslashes literal; in a plain string "\m" is an
#    invalid escape sequence (SyntaxWarning since Python 3.12).
season, stat = "_MAM", "mamM"
columns = [
    ("control" + season,               None,                 "seq", "Ctl: "),
    ("only_forest_broadleaf" + season, "control" + season,   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass" + season,             "control" + season,   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf" + season, "all_grass" + season, "div", r"Brd–Def: $\Delta$"),
]
#    Full panel list:
#    (dataset key, high_key, low_key, row, col, levels-type, title-stem, panel-letter)
panels = [
    (dset_key, hi, lo, r, c, lvl_type,
     prefix + r"$\mathregular{T_{" + f"{dset_key},{stat}" + r"}}$",
     "abcdefghijkl"[r * len(columns) + c])
    for r, dset_key in enumerate(datasets)
    for c, (hi, lo, lvl_type, prefix) in enumerate(columns)
]

# 4) build the figure: one row per dataset, one column per comparison
fig, axes = plt.subplots(
    len(datasets), len(columns),
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) draw each panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # every panel gets its own colorbar, ticked at the contour levels
    cbar = mpu.colorbar(h, ax, orientation="vertical", size=0.04, extend=extend, pad=0.05, ticks=levels)
    cbar.set_ticklabels(ticks)

    ax.text(0.02, 0.93, letter, transform=ax.transAxes, fontsize=10, fontweight="bold")

In [28]:
def get_data_max_for_plot(var, lat, lon):
    """
    Load 2025-2059 time-averaged fields of one variable from the ``h2`` history tape.

    Given the function name, the h2 tape presumably holds daily maxima
    (TODO confirm against the CLM history namelist). Reads the annual, JJA and
    MAM time-averaged files for the control, only_forest_broadleaf and
    all_grass runs into one xarray.Dataset keyed by:
      control, control_JJA, control_MAM,
      only_forest_broadleaf, only_forest_broadleaf_JJA, only_forest_broadleaf_MAM,
      all_grass, all_grass_JJA, all_grass_MAM

    Parameters
    ----------
    var : str
        Variable name inside the netCDF files (e.g. 'TSA').
    lat, lon : 1-D array-like
        Coordinates attached to the "y" and "x" dimensions of every field.

    Returns
    -------
    xarray.Dataset
        Float64 2-D fields with dims ("y", "x"). Masked cells and values
        above 1e9 (fill values) are NaN.
    """
    scenarios = {
        "control":                  "control/clm5",
        "only_forest_broadleaf":    "only_forest_broadleaf/clm5",
        "all_grass":                "all_grass/clm5",
    }
    seasons = {
        "":     "timmean",
        "_JJA": "seasonal/timmean",
        "_MAM": "seasonal/timmean",
    }

    template = (
        "/landclim2/yiyaoy/COSMO-CLM2-simulations/{folder}/post-processing/{proc}/"
        "mergetime/clm5.0_eur0.5_{scen}_h2_2025-2059{suffix}.nc_timmean_timmean"
    )

    data_vars = {}
    for scen, folder in scenarios.items():
        for suffix, proc in seasons.items():
            key  = f"{scen}{suffix}"
            path = template.format(folder=folder, proc=proc, scen=scen, suffix=suffix)

            # Read inside a context manager so every file handle is closed.
            with nc.Dataset(path) as src:
                raw = src.variables[var][:]

            # netCDF4 returns a masked array by default: cast to float so NaN
            # can be stored even for integer variables, turn masked cells into
            # NaN, then also blank out raw fill values such as 1e36.
            arr = np.squeeze(np.ma.filled(np.ma.asarray(raw, dtype=float), np.nan))
            arr[arr > 1e9] = np.nan
            data_vars[key] = xr.DataArray(arr, coords={"y": lat, "x": lon}, dims=["y", "x"])

    return xr.Dataset(data_vars)

# h2-tape temperature fields for the three figure rows: TSA -> "air",
# TSKIN -> "sfc", TBOT -> "atm" (loaded in that order).
tair_X, tsfc_X, tatm_X = (get_data_max_for_plot(name, lat, lon) for name in ('TSA', 'TSKIN', 'TBOT'))

In [None]:
# Annual yearX temperatures (loaded with get_data_max_for_plot): control run
# (column 1) and scenario differences (columns 2-4) for sfc, air and atm rows.

# 1) contour settings
#    "seq": absolute-temperature levels in K (°C values + 273.15), labelled in °C;
#    "div": difference levels (a difference in K equals one in °C).
levels_dict = {
    "seq": np.array([3, 6, 9, 12, 15, 18, 21, 24]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2])
}
ticks_dict = {
    "seq": ["3", "6", "9", "12", "15", "18", "21", "24"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0']
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = "$\\mathregular{^\\circ C}$"
# No vmin/vmax: with an explicit `levels` array xarray takes the colour limits
# from levels[0] and levels[-1], so vmin/vmax passed alongside are ignored.
extend = "both"

# 2) one dataset per figure row, top to bottom
datasets = {
    "sfc": tsfc_X,
    "air": tair_X,
    "atm": tatm_X,
}

# 3) one entry per figure column:
#    (high key, low key or None for the absolute field, levels-type, title prefix).
#    Raw strings keep mathtext backslashes literal; in a plain string "\m" is an
#    invalid escape sequence (SyntaxWarning since Python 3.12).
season, stat = "", "yearX"
columns = [
    ("control" + season,               None,                 "seq", "Ctl: "),
    ("only_forest_broadleaf" + season, "control" + season,   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass" + season,             "control" + season,   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf" + season, "all_grass" + season, "div", r"Brd–Def: $\Delta$"),
]
#    Full panel list:
#    (dataset key, high_key, low_key, row, col, levels-type, title-stem, panel-letter)
panels = [
    (dset_key, hi, lo, r, c, lvl_type,
     prefix + r"$\mathregular{T_{" + f"{dset_key},{stat}" + r"}}$",
     "abcdefghijkl"[r * len(columns) + c])
    for r, dset_key in enumerate(datasets)
    for c, (hi, lo, lvl_type, prefix) in enumerate(columns)
]

# 4) build the figure: one row per dataset, one column per comparison
fig, axes = plt.subplots(
    len(datasets), len(columns),
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) draw each panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # every panel gets its own colorbar, ticked at the contour levels
    cbar = mpu.colorbar(h, ax, orientation="vertical", size=0.04, extend=extend, pad=0.05, ticks=levels)
    cbar.set_ticklabels(ticks)

    ax.text(0.02, 0.93, letter, transform=ax.transAxes, fontsize=10, fontweight="bold")

In [None]:
# JJA jjaX temperatures (loaded with get_data_max_for_plot): control run
# (column 1) and scenario differences (columns 2-4) for sfc, air and atm rows.

# 1) contour settings
#    "seq": absolute-temperature levels in K (°C values + 273.15), labelled in °C;
#    "div": difference levels (a difference in K equals one in °C).
levels_dict = {
    "seq": np.array([15, 18, 21, 24, 27, 30, 33, 36]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2])
}
ticks_dict = {
    "seq": ["15", "18", "21", "24", "27", "30", "33", "36"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0']
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = "$\\mathregular{^\\circ C}$"
# No vmin/vmax: with an explicit `levels` array xarray takes the colour limits
# from levels[0] and levels[-1], so vmin/vmax passed alongside are ignored.
extend = "both"

# 2) one dataset per figure row, top to bottom
datasets = {
    "sfc": tsfc_X,
    "air": tair_X,
    "atm": tatm_X,
}

# 3) one entry per figure column:
#    (high key, low key or None for the absolute field, levels-type, title prefix).
#    Raw strings keep mathtext backslashes literal; in a plain string "\m" is an
#    invalid escape sequence (SyntaxWarning since Python 3.12).
season, stat = "_JJA", "jjaX"
columns = [
    ("control" + season,               None,                 "seq", "Ctl: "),
    ("only_forest_broadleaf" + season, "control" + season,   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass" + season,             "control" + season,   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf" + season, "all_grass" + season, "div", r"Brd–Def: $\Delta$"),
]
#    Full panel list:
#    (dataset key, high_key, low_key, row, col, levels-type, title-stem, panel-letter)
panels = [
    (dset_key, hi, lo, r, c, lvl_type,
     prefix + r"$\mathregular{T_{" + f"{dset_key},{stat}" + r"}}$",
     "abcdefghijkl"[r * len(columns) + c])
    for r, dset_key in enumerate(datasets)
    for c, (hi, lo, lvl_type, prefix) in enumerate(columns)
]

# 4) build the figure: one row per dataset, one column per comparison
fig, axes = plt.subplots(
    len(datasets), len(columns),
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) draw each panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # every panel gets its own colorbar, ticked at the contour levels
    cbar = mpu.colorbar(h, ax, orientation="vertical", size=0.04, extend=extend, pad=0.05, ticks=levels)
    cbar.set_ticklabels(ticks)

    ax.text(0.02, 0.93, letter, transform=ax.transAxes, fontsize=10, fontweight="bold")

In [None]:
# MAM mamX temperatures (loaded with get_data_max_for_plot): control run
# (column 1) and scenario differences (columns 2-4) for sfc, air and atm rows.

# 1) contour settings
#    "seq": absolute-temperature levels in K (°C values + 273.15), labelled in °C;
#    "div": difference levels (a difference in K equals one in °C).
levels_dict = {
    "seq": np.array([3, 6, 9, 12, 15, 18, 21, 24]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2])
}
ticks_dict = {
    "seq": ["3", "6", "9", "12", "15", "18", "21", "24"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0']
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = "$\\mathregular{^\\circ C}$"
# No vmin/vmax: with an explicit `levels` array xarray takes the colour limits
# from levels[0] and levels[-1], so vmin/vmax passed alongside are ignored.
extend = "both"

# 2) one dataset per figure row, top to bottom
datasets = {
    "sfc": tsfc_X,
    "air": tair_X,
    "atm": tatm_X,
}

# 3) one entry per figure column:
#    (high key, low key or None for the absolute field, levels-type, title prefix).
#    Raw strings keep mathtext backslashes literal; in a plain string "\m" is an
#    invalid escape sequence (SyntaxWarning since Python 3.12).
season, stat = "_MAM", "mamX"
columns = [
    ("control" + season,               None,                 "seq", "Ctl: "),
    ("only_forest_broadleaf" + season, "control" + season,   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass" + season,             "control" + season,   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf" + season, "all_grass" + season, "div", r"Brd–Def: $\Delta$"),
]
#    Full panel list:
#    (dataset key, high_key, low_key, row, col, levels-type, title-stem, panel-letter)
panels = [
    (dset_key, hi, lo, r, c, lvl_type,
     prefix + r"$\mathregular{T_{" + f"{dset_key},{stat}" + r"}}$",
     "abcdefghijkl"[r * len(columns) + c])
    for r, dset_key in enumerate(datasets)
    for c, (hi, lo, lvl_type, prefix) in enumerate(columns)
]

# 4) build the figure: one row per dataset, one column per comparison
fig, axes = plt.subplots(
    len(datasets), len(columns),
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) draw each panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # every panel gets its own colorbar, ticked at the contour levels
    cbar = mpu.colorbar(h, ax, orientation="vertical", size=0.04, extend=extend, pad=0.05, ticks=levels)
    cbar.set_ticklabels(ticks)

    ax.text(0.02, 0.93, letter, transform=ax.transAxes, fontsize=10, fontweight="bold")

In [32]:
def get_data_min_for_plot(var, lat, lon, root="/landclim2/yiyaoy/COSMO-CLM2-simulations"):
    """
    Load the post-processed 2025-2059 `timmean` fields of one variable.

    Reads the annual, JJA and MAM files for the control, only_forest_broadleaf
    and all_grass CLM5 runs into one xarray.Dataset keyed by:
      control, control_JJA, control_MAM,
      only_forest_broadleaf, only_forest_broadleaf_JJA, only_forest_broadleaf_MAM,
      all_grass, all_grass_JJA, all_grass_MAM

    Parameters
    ----------
    var : str
        Name of the netCDF variable to read (e.g. 'TSA', 'TSKIN', 'TBOT').
    lat, lon : array-like
        Coordinates attached to the 'y' and 'x' dimensions of every field.
    root : str, optional
        Directory containing the per-scenario simulation folders.
        The default is the original hard-coded location.

    Returns
    -------
    xarray.Dataset
        One squeezed ('y', 'x') variable per scenario/season key.
        Masked cells and values > 1e9 (fill values) are set to NaN.
    """
    scenarios = {
        "control":                  "control/clm5",
        "only_forest_broadleaf":    "only_forest_broadleaf/clm5",
        "all_grass":                "all_grass/clm5",
    }
    # key suffix -> post-processing sub-directory
    seasons = {
        "":     "timmean",
        "_JJA": "seasonal/timmean",
        "_MAM": "seasonal/timmean",
    }

    template = (
        "{root}/{folder}/post-processing/{proc}/"
        "mergetime/clm5.0_eur0.5_{scen}_h4_2025-2059{suffix}.nc_timmean_timmean"
    )
    root = root.rstrip("/")

    data_vars = {}
    for scen, folder in scenarios.items():
        for suffix, proc in seasons.items():
            key  = f"{scen}{suffix}"
            path = template.format(root=root, folder=folder, proc=proc, scen=scen, suffix=suffix)

            # Read inside a context manager so the file handle is closed;
            # slicing with [:] copies the data into memory first.
            with nc.Dataset(path) as src:
                raw = src.variables[var][:]

            # netCDF4 usually returns a masked array. np.array() would drop the
            # mask, so fill masked cells with NaN explicitly. Promote non-float
            # data first, because NaN cannot be stored in an integer array.
            data = np.ma.asarray(raw)
            if not np.issubdtype(data.dtype, np.floating):
                data = data.astype(float)
            arr = np.squeeze(np.ma.filled(data, np.nan))
            # Keep the original safeguard for unmasked fill values.
            arr[arr > 1e9] = np.nan

            data_vars[key] = xr.DataArray(arr, coords={"y": lat, "x": lon}, dims=["y", "x"])

    return xr.Dataset(data_vars)

# Load the "air", "sfc" and "atm" temperature datasets (CLM variables TSA,
# TSKIN and TBOT, in that order) on the shared lat/lon grid.
tair_N, tsfc_N, tatm_N = (
    get_data_min_for_plot(var_name, lat, lon) for var_name in ('TSA', 'TSKIN', 'TBOT')
)

In [None]:
# Annual ("yearN") figure: rows are the sfc / air / atm datasets, columns are
# Ctl, Brd–Ctl, Def–Ctl and Brd–Def.

# 1) contour settings
# "seq" levels are absolute temperatures in K (labelled in °C via ticks_dict);
# "div" levels are differences, identical in K and °C.
levels_dict = {
    "seq": np.array([-8, -5, -2, 1, 4, 7, 10, 13]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2]),
}
ticks_dict = {
    "seq": ["\N{MINUS SIGN}8", "\N{MINUS SIGN}5", "\N{MINUS SIGN}2", "1", "4", "7", "10", "13"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0'],
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = r"$\mathregular{^\circ C}$"
extend = "both"  # colour range comes from the explicit levels (vmin/vmax would be ignored)

# 2) pack the three xarray datasets (dict order = row order)
datasets = {
    "sfc": tsfc_N,
    "air": tair_N,
    "atm": tatm_N,
}

# 3) define each subplot as
#    (dataset key, high_key, low_key, row, col, levels-type, title, panel-letter)
season_suffix = ""       # dataset key suffix: "" (annual), "_JJA" or "_MAM"
period_tag    = "yearN"  # subscript tag shown in the titles
comparisons = [
    # (high scenario, low scenario or None, levels-type, title prefix)
    ("control",               None,        "seq", "Ctl: "),
    ("only_forest_broadleaf", "control",   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass",             "control",   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf", "all_grass", "div", r"Brd–Def: $\Delta$"),
]
panels = []
for r, dset_key in enumerate(datasets):
    # Raw strings avoid the invalid "\m" escape of the old non-raw literals.
    symbol = r"$\mathregular{T_{" + f"{dset_key},{period_tag}" + r"}}$"
    for c, (hi, lo, lvl_type, prefix) in enumerate(comparisons):
        panels.append((
            dset_key,
            hi + season_suffix,
            None if lo is None else lo + season_suffix,
            r, c, lvl_type,
            prefix + symbol,
            chr(ord("a") + r * len(comparisons) + c),
        ))

# 4) build the figure
fig, axes = plt.subplots(
    3, 4,
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300,
)
assert len(panels) == axes.size, "panel list does not match the subplot grid"
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) loop through panels
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    # Absolute field when there is no "low" key, otherwise the scenario difference.
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap,
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # One vertical colorbar per panel, relabelled with the ticks_dict strings.
    cbar = mpu.colorbar(
        h, ax,
        orientation="vertical",
        size=0.04,
        extend=extend,
        pad=0.05,
        ticks=levels,
    )
    cbar.set_ticklabels(ticks)

    # Panel letter in the upper-left corner (axes-fraction coordinates).
    ax.text(
        0.02, 0.93, letter,
        transform=ax.transAxes,
        fontsize=10, fontweight="bold",
    )

In [None]:
# JJA ("jjaN") figure: rows are the sfc / air / atm datasets, columns are
# Ctl, Brd–Ctl, Def–Ctl and Brd–Def.

# 1) contour settings
# "seq" levels are absolute temperatures in K (labelled in °C via ticks_dict);
# "div" levels are differences, identical in K and °C.
levels_dict = {
    "seq": np.array([1, 4, 7, 10, 13, 16, 19, 22]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2]),
}
ticks_dict = {
    "seq": ["1", "4", "7", "10", "13", "16", "19", "22"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0'],
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = r"$\mathregular{^\circ C}$"
extend = "both"  # colour range comes from the explicit levels (vmin/vmax would be ignored)

# 2) pack the three xarray datasets (dict order = row order)
datasets = {
    "sfc": tsfc_N,
    "air": tair_N,
    "atm": tatm_N,
}

# 3) define each subplot as
#    (dataset key, high_key, low_key, row, col, levels-type, title, panel-letter)
season_suffix = "_JJA"  # dataset key suffix: "" (annual), "_JJA" or "_MAM"
period_tag    = "jjaN"  # subscript tag shown in the titles
comparisons = [
    # (high scenario, low scenario or None, levels-type, title prefix)
    ("control",               None,        "seq", "Ctl: "),
    ("only_forest_broadleaf", "control",   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass",             "control",   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf", "all_grass", "div", r"Brd–Def: $\Delta$"),
]
panels = []
for r, dset_key in enumerate(datasets):
    # Raw strings avoid the invalid "\m" escape of the old non-raw literals.
    symbol = r"$\mathregular{T_{" + f"{dset_key},{period_tag}" + r"}}$"
    for c, (hi, lo, lvl_type, prefix) in enumerate(comparisons):
        panels.append((
            dset_key,
            hi + season_suffix,
            None if lo is None else lo + season_suffix,
            r, c, lvl_type,
            prefix + symbol,
            chr(ord("a") + r * len(comparisons) + c),
        ))

# 4) build the figure
fig, axes = plt.subplots(
    3, 4,
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300,
)
assert len(panels) == axes.size, "panel list does not match the subplot grid"
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) loop through panels
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    # Absolute field when there is no "low" key, otherwise the scenario difference.
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap,
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # One vertical colorbar per panel, relabelled with the ticks_dict strings.
    cbar = mpu.colorbar(
        h, ax,
        orientation="vertical",
        size=0.04,
        extend=extend,
        pad=0.05,
        ticks=levels,
    )
    cbar.set_ticklabels(ticks)

    # Panel letter in the upper-left corner (axes-fraction coordinates).
    ax.text(
        0.02, 0.93, letter,
        transform=ax.transAxes,
        fontsize=10, fontweight="bold",
    )

In [None]:
# MAM ("mamN") figure: rows are the sfc / air / atm datasets, columns are
# Ctl, Brd–Ctl, Def–Ctl and Brd–Def.

# 1) contour settings
# "seq" levels are absolute temperatures in K (labelled in °C via ticks_dict);
# "div" levels are differences, identical in K and °C.
levels_dict = {
    "seq": np.array([-8, -5, -2, 1, 4, 7, 10, 13]) + 273.15,
    "div": np.array([-2, -1, -0.5, -0.3, -0.1, 0.1, 0.3, 0.5, 1, 2]),
}
ticks_dict = {
    "seq": ["\N{MINUS SIGN}8", "\N{MINUS SIGN}5", "\N{MINUS SIGN}2", "1", "4", "7", "10", "13"],
    "div": ['\N{MINUS SIGN}2.0', '\N{MINUS SIGN}1.0', '\N{MINUS SIGN}0.5', '\N{MINUS SIGN}0.3', '\N{MINUS SIGN}0.1', '+0.1', '+0.3', '+0.5', '+1.0', '+2.0'],
}
cmaps = {"seq": "hot_r", "div": "RdBu_r"}
unit  = r"$\mathregular{^\circ C}$"
extend = "both"  # colour range comes from the explicit levels (vmin/vmax would be ignored)

# 2) pack the three xarray datasets (dict order = row order)
datasets = {
    "sfc": tsfc_N,
    "air": tair_N,
    "atm": tatm_N,
}

# 3) define each subplot as
#    (dataset key, high_key, low_key, row, col, levels-type, title, panel-letter)
season_suffix = "_MAM"  # dataset key suffix: "" (annual), "_JJA" or "_MAM"
period_tag    = "mamN"  # subscript tag shown in the titles
comparisons = [
    # (high scenario, low scenario or None, levels-type, title prefix)
    ("control",               None,        "seq", "Ctl: "),
    ("only_forest_broadleaf", "control",   "div", r"Brd–Ctl: $\Delta$"),
    ("all_grass",             "control",   "div", r"Def–Ctl: $\Delta$"),
    ("only_forest_broadleaf", "all_grass", "div", r"Brd–Def: $\Delta$"),
]
panels = []
for r, dset_key in enumerate(datasets):
    # Raw strings avoid the invalid "\m" escape of the old non-raw literals.
    symbol = r"$\mathregular{T_{" + f"{dset_key},{period_tag}" + r"}}$"
    for c, (hi, lo, lvl_type, prefix) in enumerate(comparisons):
        panels.append((
            dset_key,
            hi + season_suffix,
            None if lo is None else lo + season_suffix,
            r, c, lvl_type,
            prefix + symbol,
            chr(ord("a") + r * len(comparisons) + c),
        ))

# 4) build the figure
fig, axes = plt.subplots(
    3, 4,
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300,
)
assert len(panels) == axes.size, "panel list does not match the subplot grid"
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5) loop through panels
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    # Absolute field when there is no "low" key, otherwise the scenario difference.
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap,
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # One vertical colorbar per panel, relabelled with the ticks_dict strings.
    cbar = mpu.colorbar(
        h, ax,
        orientation="vertical",
        size=0.04,
        extend=extend,
        pad=0.05,
        ticks=levels,
    )
    cbar.set_ticklabels(ticks)

    # Panel letter in the upper-left corner (axes-fraction coordinates).
    ax.text(
        0.02, 0.93, letter,
        transform=ax.transAxes,
        fontsize=10, fontweight="bold",
    )