In [6]:
import netCDF4 as nc
import numpy
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning)
warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)
import os
import cartopy.feature as cfeature
import glob
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.colors as colors
from matplotlib import cm
#import seaborn as sns
import cartopy.crs as ccrs
import cartopy.feature as cf
import mplotutils as mpu # helper functions for cartopy and matplotlib
#import regionmask
from func_plots import *
from func_stats import *

In [7]:
# Regional map styling: coastlines, plot extent and gridlines (country borders are currently disabled)
def format_axes(axes):
    """Apply the shared map styling to every axes in ``axes``.

    Each axes gets 50 m coastlines, the plotting window taken from the
    module-level ``lonmin``/``lonmax``/``latmin``/``latmax`` (passed with
    ``crs=data_proj``), and thin gray, unlabeled gridlines located every
    10 units in both directions.

    Parameters
    ----------
    axes : iterable
        Axes objects exposing ``coastlines``, ``set_extent`` and ``gridlines``
        (presumably cartopy GeoAxes - confirm against callers).
    """
    for map_ax in axes:
        map_ax.coastlines(resolution='50m', linewidth=0.45)
        # Country borders are currently disabled; to enable them:
        # map_ax.add_feature(cf.BORDERS, linewidth=0.3)
        map_ax.set_extent([lonmin, lonmax, latmin, latmax], crs=data_proj)
        map_ax.gridlines(
            draw_labels=False,
            linewidth=0.3,
            color="gray",
            xlocs=range(-180, 180, 10),
            ylocs=range(-90, 90, 10),
        )

def add_headers(fig, *, row_headers=None, col_headers=None, row_pad=1, col_pad=4, rotate_row_headers=True, **text_kwargs):
    """Annotate a subplot grid with column and/or row header texts.

    Based on https://stackoverflow.com/a/25814386

    Parameters
    ----------
    fig : figure
        Object whose ``get_axes()`` returns the axes to inspect.
    row_headers, col_headers : sequence of str, optional
        Header texts, indexed by the start of each axes' row span / column
        span. ``None`` disables that kind of header.
    row_pad, col_pad : number
        Extra offset, in points, between the axes and the header text.
    rotate_row_headers : bool
        Row headers are drawn with ``rotation = rotate_row_headers * 90``.
    **text_kwargs
        Forwarded unchanged to every ``ax.annotate`` call.
    """
    want_col_headers = col_headers is not None
    want_row_headers = row_headers is not None

    for ax in fig.get_axes():
        spec = ax.get_subplotspec()

        # Column header: centered just above every axes in the first row
        if want_col_headers and spec.is_first_row():
            col_text = col_headers[spec.colspan.start]
            ax.annotate(
                col_text,
                xy=(0.5, 1),
                xytext=(0, col_pad),
                xycoords="axes fraction",
                textcoords="offset points",
                ha="center",
                va="baseline",
                **text_kwargs,
            )

        # Row header: to the left of every axes in the first column,
        # anchored in axes coordinates (ax.transAxes) at mid-height
        if want_row_headers and spec.is_first_col():
            row_text = row_headers[spec.rowspan.start]
            left_offset = -ax.yaxis.labelpad - row_pad
            ax.annotate(
                row_text,
                xy=(0, 0.5),
                xytext=(left_offset, 0),
                xycoords=ax.transAxes,
                textcoords="offset points",
                ha="center",
                va="center",
                rotation=rotate_row_headers * 90,
                **text_kwargs,
            )

# Window for plotting (read by format_axes via set_extent)
lonmin, lonmax = -11, 37
latmin, latmax = 35, 70.5

# Wider box for subsetting the data to get sensible vmin and vmax for the colorbar
# NOTE(review): the set_* values are not referenced in the cells visible here
set_lonmin, set_lonmax = -35, 65
set_latmin, set_latmax = 30, 72.6

# data_proj is handed to set_extent(crs=...) and plot(transform=...);
# map_proj is the projection of the map axes (regional maps)
data_proj = ccrs.PlateCarree()
map_proj = ccrs.LambertConformal(central_longitude=15)

In [8]:
# The grid coordinates are read once from the control-run time-mean file.
filepath = '/landclim2/yiyaoy/COSMO-CLM2-simulations/timmean/clm5.0_eur0.5_control_h0_2020-2024.nc_timmean'
with nc.Dataset(filepath) as ds:
    # NOTE(review): netCDF4 may hand back masked arrays here rather than plain
    # ndarrays, depending on its auto-mask settings - confirm if it matters.
    lat, lon = (ds.variables[name][:] for name in ('lat', 'lon'))

# Longitudes above 180 are shifted by -360 (i.e. [0, 360] -> [-180, 180]);
# all other values are kept as they are.
lon = np.where(
    lon > 180,
    lon - 360,
    lon,
)

In [9]:
def _load_case(var, lat, lon, scenario, season_suffix, subdir):
    """
    Helper to load one file, mask >1e9, and wrap in an xarray.DataArray.
    """
    fn = (
        f"/landclim2/yiyaoy/COSMO-CLM2-simulations/"
        f"{subdir}/post-processing/{'timmean' if season_suffix == '' else 'seasonal/timmean'}"
        f"/mergetime/"
        f"clm5.0_eur0.5_{scenario}_h0_2025-2059"
        f"{season_suffix}.nc_timmean_timmean"
    )
    ds = nc.Dataset(fn)
    arr = np.squeeze(np.array(ds.variables[var][:]))
    arr[arr > 1e9] = np.nan
    return xr.DataArray(arr, coords={'y': lat, 'x': lon}, dims=['y', 'x'])

def get_mean_data_aff_afb(var, lat, lon):
    """Collect the time-mean field of ``var`` for every scenario and season.

    For each of the four scenarios (control, all_forest, all_grass,
    all_forest_broadleaf) the annual, JJA and MAM files are loaded through
    ``_load_case`` from the ``<scenario>/clm5`` folder.

    Parameters
    ----------
    var : str
        Variable name forwarded to ``_load_case``.
    lat, lon : array-like
        Coordinates forwarded to ``_load_case``.

    Returns
    -------
    xarray.Dataset
        One entry per scenario-season, keyed ``'<scenario><suffix>'``,
        e.g. ``'control'``, ``'control_JJA'``, ``'all_forest_MAM'``.
    """
    scenarios = ('control', 'all_forest', 'all_grass', 'all_forest_broadleaf')
    season_suffixes = ('', '_JJA', '_MAM')  # '' selects the annual mean

    # Scenario-major order: all seasons of one scenario before the next one
    fields = {
        f"{scenario}{suffix}": _load_case(var, lat, lon, scenario, suffix, f"{scenario}/clm5")
        for scenario in scenarios
        for suffix in season_suffixes
    }
    return xr.Dataset(fields)

In [11]:
# Time-mean 'SWup' and 'SWdown' fields for every scenario/season
SWup_M, SWdown_M = (
    get_mean_data_aff_afb(flux_name, lat, lon) for flux_name in ('SWup', 'SWdown')
)
# Net field: SWdown minus SWup, entry by entry
SWnet_M = SWdown_M - SWup_M

In [None]:
# Annual-mean shortwave maps: one row per dataset (up / down / net),
# one column per scenario difference.

# 1. Levels, colorbar tick labels and colormaps, keyed by level type
levels_dict = {
    # NOTE(review): 'seq' carries a +273.15 offset and is not used by any panel below
    'seq': np.array([-5, 0, 5, 10, 15, 20, 25, 30]) + 273.15,
    "div": np.array([-16, -8, -4, -2, -1, 1, 2, 4, 8, 16])
}
ticks_dict = {
    "seq": ["\N{MINUS SIGN}5", "0", "5", "10", "15", "20", "25", "30"],
    "div": ['\N{MINUS SIGN}16', '\N{MINUS SIGN}8', '\N{MINUS SIGN}4', '\N{MINUS SIGN}2', '\N{MINUS SIGN}1', '+1', '+2', '+4', '+8', '+16']
}
cmaps = {'seq': 'hot_r', 'div': 'RdBu_r'}

unit = '$\\mathregular{W/m^2}$'
# NOTE(review): explicit `levels` are passed to every plot call below, so
# xarray presumably takes the color limits from the levels and vmin/vmax
# have no effect - confirm before relying on them.
vmin, vmax, extend = 0, 42, 'both'

# 2. One tuple per panel:
#    (dataset key, field, field to subtract or None, row, col, level type, title stem, panel letter)
#    Title stems name the plotted difference (field minus subtracted field);
#    backslashes are escaped so the strings contain no invalid escape sequences.
panels = [
    ("up", "all_forest",           "control",    0, 0, "div", "Aff–Ctl: $\\Delta$$\\mathregular{SW_{up,yearM}}$", "a"),
    ("up", "all_forest_broadleaf", "control",    0, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{up,yearM}}$", "b"),
    ("up", "all_forest_broadleaf", "all_forest", 0, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{up,yearM}}$", "c"),
    ("up", "all_forest_broadleaf", "all_grass",  0, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{up,yearM}}$", "d"),

    ("down", "all_forest",           "control",    1, 0, "div", "Aff–Ctl: $\\Delta$$\\mathregular{SW_{down,yearM}}$", "e"),
    ("down", "all_forest_broadleaf", "control",    1, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{down,yearM}}$", "f"),
    ("down", "all_forest_broadleaf", "all_forest", 1, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{down,yearM}}$", "g"),
    ("down", "all_forest_broadleaf", "all_grass",  1, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{down,yearM}}$", "h"),

    ("net", "all_forest",           "control",    2, 0, "div", "Aff–Ctl: $\\Delta$$\\mathregular{SW_{net,yearM}}$", "i"),
    ("net", "all_forest_broadleaf", "control",    2, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{net,yearM}}$", "j"),
    ("net", "all_forest_broadleaf", "all_forest", 2, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{net,yearM}}$", "k"),
    ("net", "all_forest_broadleaf", "all_grass",  2, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{net,yearM}}$", "l"),
]

# 3. Datasets addressed by the first field of each panel tuple
datasets = {
    "up": SWup_M,
    "down": SWdown_M,
    "net": SWnet_M,
}

# 4. Build the 3 x 4 figure and apply the shared map styling
fig, axes = plt.subplots(
    3, 4,
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5. Draw every panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    # lo=None plots the field itself; otherwise the difference hi - lo
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        vmin=vmin, vmax=vmax,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # Every panel gets its own colorbar, with one labelled tick per level
    cbar = mpu.colorbar(
        h, ax,
        orientation="vertical",
        size=0.04,
        extend=extend,
        pad=0.05,
        ticks=levels
    )
    cbar.set_ticklabels(ticks)

    # Panel letter in the upper-left corner (axes coordinates)
    ax.text(
        0.02, 0.93, letter,
        transform=ax.transAxes,
        fontsize=10, fontweight="bold"
    )

In [None]:
# JJA-mean shortwave maps: one row per dataset (up / down / net),
# one column per scenario difference.

# 1. Levels, colorbar tick labels and colormaps, keyed by level type
levels_dict = {
    # NOTE(review): 'seq' carries a +273.15 offset and is not used by any panel below
    'seq': np.array([-5, 0, 5, 10, 15, 20, 25, 30]) + 273.15,
    "div": np.array([-16, -8, -4, -2, -1, 1, 2, 4, 8, 16])
}
ticks_dict = {
    "seq": ["\N{MINUS SIGN}5", "0", "5", "10", "15", "20", "25", "30"],
    "div": ['\N{MINUS SIGN}16', '\N{MINUS SIGN}8', '\N{MINUS SIGN}4', '\N{MINUS SIGN}2', '\N{MINUS SIGN}1', '+1', '+2', '+4', '+8', '+16']
}
cmaps = {'seq': 'hot_r', 'div': 'RdBu_r'}

unit = '$\\mathregular{W/m^2}$'
# NOTE(review): explicit `levels` are passed to every plot call below, so
# xarray presumably takes the color limits from the levels and vmin/vmax
# have no effect - confirm before relying on them.
vmin, vmax, extend = 0, 42, 'both'

# 2. One tuple per panel:
#    (dataset key, field, field to subtract or None, row, col, level type, title stem, panel letter)
#    Every field key carries the _JJA suffix so that only JJA means are
#    differenced; backslashes are escaped so the strings contain no invalid
#    escape sequences.
panels = [
    ("up", "all_forest_JJA",           "control_JJA",    0, 0, "div", "Aff-Ctl: $\\Delta$$\\mathregular{SW_{up,jjaM}}$", "a"),
    ("up", "all_forest_broadleaf_JJA", "control_JJA",    0, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{up,jjaM}}$", "b"),
    ("up", "all_forest_broadleaf_JJA", "all_forest_JJA", 0, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{up,jjaM}}$", "c"),
    ("up", "all_forest_broadleaf_JJA", "all_grass_JJA",  0, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{up,jjaM}}$", "d"),

    ("down", "all_forest_JJA",           "control_JJA",    1, 0, "div", "Aff-Ctl: $\\Delta$$\\mathregular{SW_{down,jjaM}}$", "e"),
    ("down", "all_forest_broadleaf_JJA", "control_JJA",    1, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{down,jjaM}}$", "f"),
    ("down", "all_forest_broadleaf_JJA", "all_forest_JJA", 1, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{down,jjaM}}$", "g"),
    ("down", "all_forest_broadleaf_JJA", "all_grass_JJA",  1, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{down,jjaM}}$", "h"),

    ("net", "all_forest_JJA",           "control_JJA",    2, 0, "div", "Aff-Ctl: $\\Delta$$\\mathregular{SW_{net,jjaM}}$", "i"),
    ("net", "all_forest_broadleaf_JJA", "control_JJA",    2, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{net,jjaM}}$", "j"),
    ("net", "all_forest_broadleaf_JJA", "all_forest_JJA", 2, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{net,jjaM}}$", "k"),
    ("net", "all_forest_broadleaf_JJA", "all_grass_JJA",  2, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{net,jjaM}}$", "l"),
]

# 3. Datasets addressed by the first field of each panel tuple
datasets = {
    "up": SWup_M,
    "down": SWdown_M,
    "net": SWnet_M,
}

# 4. Build the 3 x 4 figure and apply the shared map styling
fig, axes = plt.subplots(
    3, 4,
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5. Draw every panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    # lo=None plots the field itself; otherwise the difference hi - lo
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        vmin=vmin, vmax=vmax,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # Every panel gets its own colorbar, with one labelled tick per level
    cbar = mpu.colorbar(
        h, ax,
        orientation="vertical",
        size=0.04,
        extend=extend,
        pad=0.05,
        ticks=levels
    )
    cbar.set_ticklabels(ticks)

    # Panel letter in the upper-left corner (axes coordinates)
    ax.text(
        0.02, 0.93, letter,
        transform=ax.transAxes,
        fontsize=10, fontweight="bold"
    )

In [None]:
# MAM-mean shortwave maps: one row per dataset (up / down / net),
# one column per scenario difference.

# 1. Levels, colorbar tick labels and colormaps, keyed by level type
levels_dict = {
    # NOTE(review): 'seq' carries a +273.15 offset and is not used by any panel below
    'seq': np.array([-5, 0, 5, 10, 15, 20, 25, 30]) + 273.15,
    "div": np.array([-16, -8, -4, -2, -1, 1, 2, 4, 8, 16])
}
ticks_dict = {
    "seq": ["\N{MINUS SIGN}5", "0", "5", "10", "15", "20", "25", "30"],
    "div": ['\N{MINUS SIGN}16', '\N{MINUS SIGN}8', '\N{MINUS SIGN}4', '\N{MINUS SIGN}2', '\N{MINUS SIGN}1', '+1', '+2', '+4', '+8', '+16']
}
cmaps = {'seq': 'hot_r', 'div': 'RdBu_r'}

unit = '$\\mathregular{W/m^2}$'
# NOTE(review): explicit `levels` are passed to every plot call below, so
# xarray presumably takes the color limits from the levels and vmin/vmax
# have no effect - confirm before relying on them.
vmin, vmax, extend = 0, 42, 'both'

# 2. One tuple per panel:
#    (dataset key, field, field to subtract or None, row, col, level type, title stem, panel letter)
#    Title stems name the plotted difference (field minus subtracted field);
#    backslashes are escaped so the strings contain no invalid escape sequences.
panels = [
    ("up", "all_forest_MAM",           "control_MAM",    0, 0, "div", "Aff–Ctl: $\\Delta$$\\mathregular{SW_{up,mamM}}$", "a"),
    ("up", "all_forest_broadleaf_MAM", "control_MAM",    0, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{up,mamM}}$", "b"),
    ("up", "all_forest_broadleaf_MAM", "all_forest_MAM", 0, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{up,mamM}}$", "c"),
    ("up", "all_forest_broadleaf_MAM", "all_grass_MAM",  0, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{up,mamM}}$", "d"),

    ("down", "all_forest_MAM",           "control_MAM",    1, 0, "div", "Aff–Ctl: $\\Delta$$\\mathregular{SW_{down,mamM}}$", "e"),
    ("down", "all_forest_broadleaf_MAM", "control_MAM",    1, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{down,mamM}}$", "f"),
    ("down", "all_forest_broadleaf_MAM", "all_forest_MAM", 1, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{down,mamM}}$", "g"),
    ("down", "all_forest_broadleaf_MAM", "all_grass_MAM",  1, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{down,mamM}}$", "h"),

    ("net", "all_forest_MAM",           "control_MAM",    2, 0, "div", "Aff–Ctl: $\\Delta$$\\mathregular{SW_{net,mamM}}$", "i"),
    ("net", "all_forest_broadleaf_MAM", "control_MAM",    2, 1, "div", "AfB–Ctl: $\\Delta$$\\mathregular{SW_{net,mamM}}$", "j"),
    ("net", "all_forest_broadleaf_MAM", "all_forest_MAM", 2, 2, "div", "AfB–Aff: $\\Delta$$\\mathregular{SW_{net,mamM}}$", "k"),
    ("net", "all_forest_broadleaf_MAM", "all_grass_MAM",  2, 3, "div", "AfB–Def: $\\Delta$$\\mathregular{SW_{net,mamM}}$", "l"),
]

# 3. Datasets addressed by the first field of each panel tuple
datasets = {
    "up": SWup_M,
    "down": SWdown_M,
    "net": SWnet_M,
}

# 4. Build the 3 x 4 figure and apply the shared map styling
fig, axes = plt.subplots(
    3, 4,
    figsize=(12, 8),
    subplot_kw=dict(projection=map_proj, facecolor="lightgrey"),
    dpi=300
)
format_axes(axes.flatten())
fig.subplots_adjust(hspace=0.2, wspace=0.1, left=0.03, right=0.935, bottom=0.03, top=0.9)

# 5. Draw every panel
for dset_key, hi, lo, r, c, lvl_type, title_stem, letter in panels:
    ax     = axes[r, c]
    ds     = datasets[dset_key]
    # lo=None plots the field itself; otherwise the difference hi - lo
    arr    = ds[hi] if lo is None else (ds[hi] - ds[lo])
    levels = levels_dict[lvl_type]
    ticks  = ticks_dict[lvl_type]
    cmap   = cmaps[lvl_type]

    h = arr.plot(
        ax=ax,
        transform=data_proj,
        vmin=vmin, vmax=vmax,
        levels=levels,
        extend=extend,
        add_colorbar=False,
        cmap=cmap
    )
    ax.set_title(f"{title_stem} ({unit})")
    ax.add_feature(cfeature.OCEAN, color="whitesmoke")

    # Every panel gets its own colorbar, with one labelled tick per level
    cbar = mpu.colorbar(
        h, ax,
        orientation="vertical",
        size=0.04,
        extend=extend,
        pad=0.05,
        ticks=levels
    )
    cbar.set_ticklabels(ticks)

    # Panel letter in the upper-left corner (axes coordinates)
    ax.text(
        0.02, 0.93, letter,
        transform=ax.transAxes,
        fontsize=10, fontweight="bold"
    )