## Notebook Description and Execution Requirements

This notebook analyzes and visualizes the **Heave** and **Spice** components of ocean salinity and temperature change, derived from CMIP6 climate model data. The analysis focuses on zonal-mean anomalies (future minus historical) in the upper 2000 m of the ocean.

### What the Notebook Does:

1. **CMIP6 Data Retrieval**: Uses `gdown` to transfer preprocessed NetCDF (`.nc`) files into the notebook environment. The files contain salinity and temperature data from multiple CMIP6 models for the historical period and the projection scenarios (ssp370, ssp585). They were prepared beforehand and are essential for the analysis.
2. **Definition of Auxiliary Functions**: Includes functions for robust interpolation, calculation of the Heave and Spice components (using the `neutralocean` library), alignment of depth coordinates between datasets, and determination of “nice” limits for contour plots.
3. **Data Processing by Model and Scenario**:
    * For each model and scenario, loads the historical and projection datasets.
    * Computes the *spice* and *heave* fields for salinity and temperature.
    * Computes the zonal mean of these fields and selects data within the 0–2000 m depth range.
    * Saves these preprocessed zonal-mean data into individual NetCDF files (`preprocessed_data/zonal_mean_MODEL_SCENARIO.nc`).
    * Records the maximum absolute values of each component (Heave/Spice for salinity/temperature) in a JSON file (`raw_contour_limits.json`).
4. **Normalization of Contour Limits**: Adjusts the color and contour limits so that, for each model, the Heave and Spice panels of a variable share the same symmetric limits across both scenarios, enabling fair visual comparison. The finalized limits are saved in `final_plot_limits.json`.
5. **Ensemble Mean Calculation**: Computes the mean across all processed models for each scenario, creating an ensemble mean for the Heave and Spice components. These ensemble data are saved in `preprocessed_data/ensemble_zonal_mean_SCENARIO.nc`.
6. **Generation of Final Plots**: Produces detailed visualizations for each individual model and for the ensemble mean. The plots display the Heave and Spice components of temperature and salinity. The ensemble-mean panels also include **stippling** to mark areas where anomalies are statistically significant (p < 0.05) based on a t-test across ensemble members.
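
The stippling criterion in the last step can be sketched as follows. This is a minimal illustration rather than the notebook's own code: it assumes a one-sample, two-sided t-test of the ensemble members against a zero anomaly, and the array `anomalies` and its dimensions are invented for the example.

```python
import numpy as np
from scipy import stats

rng = np.random.default_rng(0)

# Hypothetical ensemble: 6 members on a small (lev, lat) grid of pure noise.
n_members, n_lev, n_lat = 6, 4, 5
anomalies = rng.normal(loc=0.0, scale=1.0, size=(n_members, n_lev, n_lat))

# Give one latitude column a strong signal that all members agree on.
anomalies[:, :, 2] += 10.0

# Assumed test: one-sample t-test across members (axis 0) against zero.
t_stat, p_value = stats.ttest_1samp(anomalies, popmean=0.0, axis=0)

# Boolean mask of grid points that would receive stippling.
significant = p_value < 0.05
print(significant[:, 2])
```

A mask like `significant` could then be overlaid on the ensemble-mean panel, for instance with a hatched `contourf` drawn without fill colors.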

### Required Files and Folders for Running on Another Machine:

To reproduce this notebook on another machine, you will need the following:

* **NetCDF Data Files (`.nc`)**: The preprocessed CMIP6 model data files downloaded via `gdown`. These are essential inputs.  
  The notebook downloads them directly from Google Drive, so the machine must have internet access. Alternatively, download them manually and place them in the same folder as the notebook (`/content/` in the original run).  
  The file names follow the pattern:  
  `CMIP.INSTITUTION.MODEL.historical.Omon.gn.nc` and `ScenarioMIP.INSTITUTION.MODEL.SCENARIO.Omon.gn.nc`.

* **Folder Structure (Generated by the Notebook)**:
    * `preprocessed_data/`: Automatically created by the notebook to store intermediate results (zonal means for each model and the ensemble).

* **JSON Files (Generated by the Notebook)**:
    * `raw_contour_limits.json`: Generated after computing the raw contour limits.
    * `final_plot_limits.json`: Generated after normalizing the contour limits.

* **Output Images (Generated by the Notebook)**:
    * `ssp370_temperature_plots_with_ensemble_stippling.jpeg`
    * `ssp370_salinity_plots_with_ensemble_stippling.jpeg`
    * `ssp585_temperature_plots_with_ensemble_stippling.jpeg`
    * `ssp585_salinity_plots_with_ensemble_stippling.jpeg`

  These JPEG files contain the final plots and will be saved in the notebook’s root directory.

**In summary**: The main external requirement is the CMIP6 `.nc` data files. The notebook is designed to download them, and all other folders and JSON files are generated automatically during execution. The machine must have internet access for data download and installation of required libraries.

In [None]:
#Install libraries
!pip install cftime
!pip install xarray==2024.6.0
!pip install gsw
!pip install netcdf4
!pip install neutralocean==2.1.3

Collecting cftime
  Downloading cftime-1.6.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (8.7 kB)
Downloading cftime-1.6.5-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (1.6 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.6/1.6 MB 19.0 MB/s eta 0:00:00
Installing collected packages: cftime
Successfully installed cftime-1.6.5
Collecting xarray==2024.6.0
  Downloading xarray-2024.6.0-py3-none-any.whl.metadata (11 kB)
Downloading xarray-2024.6.0-py3-none-any.whl (1.2 MB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.2/1.2 MB 16.3 MB/s eta 0:00:00
Installing collected packages: xarray
  Attempting uninstall: xarray
    Found existing installation: xarray 2025.10.1
    Uninstalling xarray-2025.10.1:
      Successfully uninstalled xarray-2025.10.1
Successfully installed xarray-2024.6.0
Collecting gsw
  Downloading gsw-3.6.20-cp312-cp312-manylinux2014_x86_64.manyl

In [None]:
# @title Essential Imports
import xarray as xr
import numpy as np
import os
import json
import glob
import matplotlib.pyplot as plt
import math
from matplotlib import cm
from matplotlib.colors import ListedColormap
from scipy import stats
from neutralocean.traj import ntp_bottle_to_cast
import gsw
import matplotlib.ticker as mticker  # provides MaxNLocator

In [None]:
# Download the preprocessed NetCDF files from Google Drive, one group of three IDs per model.
# Each ID is fetched once; downloading the same ID again would only overwrite the same file.
!gdown 1Q9rb8Hx9HOt1T962PtkrntK27t4N88FC  # CNRM-ESM2-1 historical
!gdown 106h7gu0WGxVFdFWnapQU4maskdPbKT1n  # CNRM-ESM2-1 ssp585
!gdown 1oY0-iWcZ-9V0_uAMKKoHLW_P5PSIWVzj  # CNRM-ESM2-1 ssp370

!gdown 13jhtN_mJ67yWA_hfIBa46sZWe1gdgYqj  # CAMS-CSM1-0 historical
!gdown 1jL-gPxHb29z0ZLPHGX37t0kPUNkISv5C
!gdown 19_1YRA-9xkdXA5WSCWyIohQaxWcOM3Oy

!gdown 1bY2m7UlSL-VsbpbuyTM0KWTA9dxCYJ1D
!gdown 1RFoWsRUHLVj2jnVMVXL3G4xI2cz57RHh
!gdown 1c6xKdMtM4HZZEV8OQzagYvFa7pMYx2aj

!gdown 1GFnnZXe4wjkWSj78Iwf1Mbke-nb-nkKI
!gdown 15syf8Oa0tXoM3PfWgQmU0VSFBFlJd6Xd
!gdown 1UeCijOcs2mrSH-CsB7MYhCvEEbXgnlJj

!gdown 1ncZLX0-218XRlWgAGyeIyj5Sut9QlwEL
!gdown 11lFfQZlOB5vpRZEzwRJxGF3swvT9fndo
!gdown 1xUNAkCFRUkuApdTQkXPAp42vLqaPWnHr

!gdown 1bxaVTcR5veuUyZLhjTFC3fzH1a0Bsen7
!gdown 1lXituflJ0ft93lhxVoJxGJHZQeWD6vso
!gdown 1AZ3FrHTzdhMwoa8gBUKDLtHStoyJQNHu

Downloading...
From: https://drive.google.com/uc?id=1Q9rb8Hx9HOt1T962PtkrntK27t4N88FC
To: /content/CMIP.CNRM-CERFACS.CNRM-ESM2-1.historical.Omon.gn.nc
100% 4.87M/4.87M [00:00<00:00, 148MB/s]
Downloading...
From: https://drive.google.com/uc?id=106h7gu0WGxVFdFWnapQU4maskdPbKT1n
To: /content/ScenarioMIP.CNRM-CERFACS.CNRM-ESM2-1.ssp585.Omon.gn (2).nc
100% 4.87M/4.87M [00:00<00:00, 45.9MB/s]
Downloading...
From: https://drive.google.com/uc?id=1Q9rb8Hx9HOt1T962PtkrntK27t4N88FC
To: /content/CMIP.CNRM-CERFACS.CNRM-ESM2-1.historical.Omon.gn.nc
100% 4.87M/4.87M [00:00<00:00, 119MB/s]
Downloading...
From: https://drive.google.com/uc?id=1oY0-iWcZ-9V0_uAMKKoHLW_P5PSIWVzj
To: /content/ScenarioMIP.CNRM-CERFACS.CNRM-ESM2-1.ssp370.Omon.gn.nc
100% 4.87M/4.87M [00:00<00:00, 181MB/s]
Downloading...
From: https://drive.google.com/uc?id=13jhtN_mJ67yWA_hfIBa46sZWe1gdgYqj
To: /content/CMIP.CAMS.CAMS-CSM1-0.historical.Omon.gn.nc
100% 3.25M/3.25M [00:00<00:00, 173MB/s]
Downloading...
From: https://drive.google.

In [None]:
# @title Helper Functions
def _interp(x, y, xnew):
    """Robust 1D linear interpolation (ignores NaN and sorts x)."""
    m = np.isfinite(x) & np.isfinite(y)
    if m.sum() < 2:
        return np.full_like(xnew, np.nan, dtype=float)
    xx, yy = x[m], y[m]
    if np.any(np.diff(xx) <= 0):
        o = np.argsort(xx)
        xx, yy = xx[o], yy[o]
    return np.interp(xnew, xx, yy, left=np.nan, right=np.nan)

def _heave_spice_sa(sa_h, th_h, p_h, sa_f, th_f, p_f):
    """
    1D column -> (spice, heave) for SALINITY on p_f levels (full grid).
    Inputs and outputs are 1D along 'lev'. Returns arrays of SAME size as p_f.
    """
    p_f_full = p_f.copy()

    ok_h = np.isfinite(sa_h) & np.isfinite(th_h) & np.isfinite(p_h)
    ok_f = np.isfinite(sa_f) & np.isfinite(th_f) & np.isfinite(p_f)
    if ok_f.sum() < 2 or ok_h.sum() < 2:  # need at least two valid levels in each cast
        n = p_f_full.size
        return np.full(n, np.nan), np.full(n, np.nan)


    sa_h_v, th_h_v, p_h_v = sa_h[ok_h], th_h[ok_h], p_h[ok_h]
    sa_f_v, th_f_v, p_f_v = sa_f[ok_f], th_f[ok_f], p_f[ok_f]


    # neutral bottles (historical -> future)
    s1_list, p1_list = [], []
    for s0, t0, p0 in zip(sa_h_v, th_h_v, p_h_v):
        s_ntp, t_ntp, p_ntp = ntp_bottle_to_cast(s0, t0, p0, sa_f_v, th_f_v, p_f_v)
        s1_list.append(s_ntp); p1_list.append(p_ntp)
    s1 = np.asarray(s1_list); p1 = np.asarray(p1_list)

    # spice at historical depths -> mapped to future levels (p_f_full)
    spice_hist = s1 - sa_h_v
    spice_pf   = _interp(p1,   spice_hist, p_f_full)

    # historical resampled on p_f_full and total
    sa_h_pf    = _interp(p_h_v, sa_h_v,    p_f_full)
    diff       = sa_f - sa_h_pf

    # heave = total - spice
    heave_pf   = diff - spice_pf
    return spice_pf, heave_pf


def _heave_spice_th(sa_h, th_h, p_h, sa_f, th_f, p_f):
    """
    1D column -> (spice, heave) for TEMPERATURE on p_f levels (full grid).
    """
    p_f_full = p_f.copy()

    ok_h = np.isfinite(sa_h) & np.isfinite(th_h) & np.isfinite(p_h)
    ok_f = np.isfinite(sa_f) & np.isfinite(th_f) & np.isfinite(p_f)
    if ok_f.sum() < 2 or ok_h.sum() < 2:  # need at least two valid levels in each cast
        n = p_f_full.size
        return np.full(n, np.nan), np.full(n, np.nan)


    sa_h_v, th_h_v, p_h_v = sa_h[ok_h], th_h[ok_h], p_h[ok_h]
    sa_f_v, th_f_v, p_f_v = sa_f[ok_f], th_f[ok_f], p_f[ok_f]

    t1_list, p1_list = [], []
    for s0, t0, p0 in zip(sa_h_v, th_h_v, p_h_v):
        s_ntp, t_ntp, p_ntp = ntp_bottle_to_cast(s0, t0, p0, sa_f_v, th_f_v, p_f_v)
        t1_list.append(t_ntp); p1_list.append(p_ntp)
    t1 = np.asarray(t1_list); p1 = np.asarray(p1_list)

    spice_hist = t1 - th_h_v
    spice_pf   = _interp(p1,   spice_hist, p_f_full)

    th_h_pf    = _interp(p_h_v, th_h_v,    p_f_full)
    diff       = th_f - th_h_pf
    heave_pf   = diff - spice_pf
    return spice_pf, heave_pf


def ensure_same_lev(ds_h: xr.Dataset, ds_ssp: xr.Dataset, *, strict=False):
    """
    Ensures ds_h and ds_ssp have the same 'lev' coordinate.
    - strict=True: raises an error if they differ.
    - strict=False: interpolates ds_h to ds_ssp's grid, if necessary.
    Returns (ds_h2, ds_ssp2) with identical and sorted 'lev'.
    """
    # sort for safety
    ds_h   = ds_h.sortby('lev')
    ds_ssp = ds_ssp.sortby('lev')

    same_size = ds_h.sizes.get('lev', None) == ds_ssp.sizes.get('lev', None)
    same_vals = same_size and np.allclose(ds_h['lev'].values, ds_ssp['lev'].values, equal_nan=False)

    if same_vals:
        return ds_h, ds_ssp  # already aligned

    if strict:
        raise ValueError(
            "ds_h and ds_ssp have different 'lev' coordinates. "
            f"ds_h.lev.size={ds_h.sizes.get('lev')}, ds_ssp.lev.size={ds_ssp.sizes.get('lev')}"
        )

    # align by interpolating the historical to the future grid
    ds_h_interp = ds_h.interp(lev=ds_ssp['lev'])
    return ds_h_interp, ds_ssp

def get_nice_limits(vmax_abs):
    """
    Calculates "nice" symmetric limits and contour levels
    using Matplotlib's MaxNLocator.
    """
    # Handles NaN or zero values
    if vmax_abs == 0 or np.isnan(vmax_abs):
        vmax_abs = 0.1 # Sets a small default to avoid errors

    # Uses MaxNLocator to find "nice" tick locations (at most 10 intervals)
    # symmetric=True ensures ticks are symmetric around 0
    # prune='both' drops the outermost tick on each side
    locator = mticker.MaxNLocator(nbins=10, symmetric=True, prune='both')

    # Gets ticks for the range [-vmax, +vmax]
    ticks = locator.tick_values(-vmax_abs, vmax_abs)

    # The new "nice" limits are the minimum and maximum of these ticks
    nice_vmin = ticks[0]
    nice_vmax = ticks[-1]

    # Contour line levels are the ticks themselves
    line_levels = ticks.tolist()

    # Fill levels are a finer linspace between these "nice" limits
    fill_levels = np.linspace(nice_vmin, nice_vmax, 51).tolist()

    return nice_vmin, nice_vmax, fill_levels, line_levels
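
Of the helpers above, the robust interpolation is the easiest to check in isolation. The sketch below re-implements the same idea under a hypothetical name, `robust_interp`, with invented pressure and salinity values; it is meant only to show how unsorted levels, NaN entries, and out-of-range targets are handled.

```python
import numpy as np

def robust_interp(x, y, xnew):
    """Linear interpolation that drops non-finite pairs and sorts x first."""
    x, y, xnew = (np.asarray(a, dtype=float) for a in (x, y, xnew))
    ok = np.isfinite(x) & np.isfinite(y)
    if ok.sum() < 2:
        # Not enough valid points to define a line segment.
        return np.full(xnew.shape, np.nan)
    order = np.argsort(x[ok])
    # Targets outside the valid range become NaN instead of being clamped.
    return np.interp(xnew, x[ok][order], y[ok][order], left=np.nan, right=np.nan)

# Invented column: unsorted pressures with one missing level.
pressure = np.array([500.0, 0.0, np.nan, 1000.0])
salinity = np.array([35.0, 34.0, 34.5, 36.0])
targets = np.array([250.0, 750.0, 1500.0])

result = robust_interp(pressure, salinity, targets)
print(result)  # [34.5 35.5  nan]
```

Returning NaN outside the valid range keeps extrapolated values out of the spice and heave fields, at the cost of gaps near the top and bottom of a column.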

In [None]:
# @title Setup: File paths, Models, and Scenarios
nc_files = glob.glob('*.nc')
scenarios = ['ssp370', 'ssp585']

# Redefine models by extracting unique model names from nc_files
models = list(set([f.split('.')[2] for f in nc_files if len(f.split('.')) >= 4]))
print(f"Identified models: {models}")

# Create a dictionary to organize the file paths by model and scenario
file_dict = {}
for f in nc_files:
    parts = f.split('.')
    if len(parts) >= 4:
        model = parts[2]
        scenario = parts[3]
        if model not in file_dict:
            file_dict[model] = {}
        file_dict[model][scenario] = f
    else:
        print(f"Skipping file with unexpected format: {f}")

Identified models: ['IPSL-CM6A-LR', 'CAMS-CSM1-0', 'GFDL-ESM4', 'MIROC6', 'CNRM-ESM2-1', 'CESM2']
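
The grouping above keys on the third and fourth dot-separated fields of each file name. As a small illustration, the hypothetical helper `parse_cmip_filename` below splits one of the downloaded file names into labeled parts; the field labels (`activity`, `institution`, and so on) are chosen for this example and do not come from the notebook.

```python
def parse_cmip_filename(filename):
    """Split a dot-separated file name into labeled fields (labels are illustrative)."""
    parts = filename.split(".")
    if len(parts) < 4:
        raise ValueError(f"Unexpected file name format: {filename}")
    return {
        "activity": parts[0],
        "institution": parts[1],
        "model": parts[2],       # same field the setup cell uses as the model
        "experiment": parts[3],  # same field the setup cell uses as the scenario
    }

info = parse_cmip_filename("ScenarioMIP.CNRM-CERFACS.CNRM-ESM2-1.ssp370.Omon.gn.nc")
print(info["model"], info["experiment"])  # CNRM-ESM2-1 ssp370
```

Because the split is purely positional, a file renamed to drop or add a leading field would be assigned the wrong model and scenario.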


In [None]:
# @title Calculate and Save Raw Contour Limits
# Create directory for preprocessed data
output_data_dir = "preprocessed_data"
os.makedirs(output_data_dir, exist_ok=True)

raw_contour_limits = {}

for model in models:
    if 'historical' not in file_dict.get(model, {}):
        print(f"Skipping model {model} as historical data is missing.")
        continue

    if model not in raw_contour_limits:
        raw_contour_limits[model] = {}

    for scenario in scenarios:
        if scenario not in file_dict.get(model, {}):
            print(f"Skipping scenario {scenario} for model {model} as data is missing.")
            continue

        print(f"Processing model: {model}, scenario: {scenario}")

        # Open historical dataset
        ds_h = xr.open_dataset(file_dict[model]['historical'])

        # Open scenario dataset
        ds_ssp = xr.open_dataset(file_dict[model][scenario])

        # Ensure same 'lev' coordinate
        try:
            ds_h, ds_ssp = ensure_same_lev(ds_h, ds_ssp, strict=True)
        except ValueError as e:
            print(f"Error ensuring same 'lev' for model {model}, scenario {scenario}: {e}")
            ds_h.close()
            ds_ssp.close()
            continue

        # --- Heave/Spice Calculation ---
        # Calculate heave and spice for salinity
        spice_sa, heave_sa = xr.apply_ufunc(
            _heave_spice_sa,
            ds_h['sa'], ds_h['thetao'], ds_h['press'],
            ds_ssp['sa'], ds_ssp['thetao'], ds_ssp['press'],
            input_core_dims=[['lev'], ['lev'], ['lev'], ['lev'], ['lev'], ['lev']],
            output_core_dims=[['lev'], ['lev']],
            vectorize=True, dask='parallelized', output_dtypes=[float, float]
        )

        # Calculate heave and spice for temperature
        spice_th, heave_th = xr.apply_ufunc(
            _heave_spice_th,
            ds_h['sa'], ds_h['thetao'], ds_h['press'],
            ds_ssp['sa'], ds_ssp['thetao'], ds_ssp['press'],
            input_core_dims=[['lev'], ['lev'], ['lev'], ['lev'], ['lev'], ['lev']],
            output_core_dims=[['lev'], ['lev']],
            vectorize=True, dask='parallelized', output_dtypes=[float, float]
        )

        # Create a new dataset to store the results
        ds_out = xr.Dataset(
            {
                'spice_salinity'   : spice_sa.transpose('lev','lat','lon'),
                'heave_salinity'   : heave_sa.transpose('lev','lat','lon'),
                'spice_temperature': spice_th.transpose('lev','lat','lon'),
                'heave_temperature': heave_th.transpose('lev','lat','lon'),
            },
            coords={'lev': ds_ssp.lev, 'lat': ds_ssp.lat, 'lon': ds_ssp.lon}
        )

        # Calculate zonal mean and select depths up to 2000m
        heave_temp_zonal_mean = ds_out.heave_temperature.mean('lon').sel(lev=slice(0, 2000)).isel(lev=slice(None, None, -1))
        spice_temp_zonal_mean = ds_out.spice_temperature.mean('lon').sel(lev=slice(0, 2000)).isel(lev=slice(None, None, -1))
        heave_sal_zonal_mean = ds_out.heave_salinity.mean('lon').sel(lev=slice(0, 2000)).isel(lev=slice(None, None, -1))
        spice_sal_zonal_mean = ds_out.spice_salinity.mean('lon').sel(lev=slice(0, 2000)).isel(lev=slice(None, None, -1))

        # --- SAVE ZONAL MEAN DATA ---
        ds_zonal = xr.Dataset({
            'heave_temperature': heave_temp_zonal_mean,
            'spice_temperature': spice_temp_zonal_mean,
            'heave_salinity': heave_sal_zonal_mean,
            'spice_salinity': spice_sal_zonal_mean,
        })
        output_path = os.path.join(output_data_dir, f"zonal_mean_{model}_{scenario}.nc")
        ds_zonal.to_netcdf(output_path)
        print(f"Saved preprocessed data to {output_path}")

        # --- Calculate RAW limits ---
        max_abs_heave_temp = np.nanmax(np.abs(heave_temp_zonal_mean))
        max_abs_spice_temp = np.nanmax(np.abs(spice_temp_zonal_mean))
        max_abs_heave_sal = np.nanmax(np.abs(heave_sal_zonal_mean))
        max_abs_spice_sal = np.nanmax(np.abs(spice_sal_zonal_mean))

        # --- Store RAW limits ---
        raw_contour_limits[model][scenario] = {
            'heave_temp': {'vmax': float(max_abs_heave_temp)},
            'spice_temp': {'vmax': float(max_abs_spice_temp)},
            'heave_sal': {'vmax': float(max_abs_heave_sal)},
            'spice_sal': {'vmax': float(max_abs_spice_sal)},
        }

        # Close datasets
        ds_h.close()
        ds_ssp.close()
        ds_out.close()
        ds_zonal.close()

# --- SAVE RAW LIMITS TO A JSON FILE ---
limits_file = 'raw_contour_limits.json'
with open(limits_file, 'w') as f:
    json.dump(raw_contour_limits, f, indent=4)

print(f"Finished processing all data. Raw limits saved to {limits_file}")

Processing model: IPSL-CM6A-LR, scenario: ssp370
Saved preprocessed data to preprocessed_data/zonal_mean_IPSL-CM6A-LR_ssp370.nc
Processing model: IPSL-CM6A-LR, scenario: ssp585
Saved preprocessed data to preprocessed_data/zonal_mean_IPSL-CM6A-LR_ssp585.nc
Processing model: CAMS-CSM1-0, scenario: ssp370
Saved preprocessed data to preprocessed_data/zonal_mean_CAMS-CSM1-0_ssp370.nc
Processing model: CAMS-CSM1-0, scenario: ssp585
Saved preprocessed data to preprocessed_data/zonal_mean_CAMS-CSM1-0_ssp585.nc
Processing model: GFDL-ESM4, scenario: ssp370
Saved preprocessed data to preprocessed_data/zonal_mean_GFDL-ESM4_ssp370.nc
Processing model: GFDL-ESM4, scenario: ssp585
Saved preprocessed data to preprocessed_data/zonal_mean_GFDL-ESM4_ssp585.nc
Processing model: MIROC6, scenario: ssp370
Saved preprocessed data to preprocessed_data/zonal_mean_MIROC6_ssp370.nc
Processing model: MIROC6, scenario: ssp585
Saved preprocessed data to preprocessed_data/zonal_mean_MIROC6_ssp585.nc
Processing model
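
The reduction applied to each field in the cell above (zonal mean, depth selection, maximum absolute value) can be sketched with plain NumPy. The grid and values here are invented, and `np.nanmean` stands in for xarray's `mean`, which likewise skips NaN for floating-point data by default.

```python
import numpy as np

# Invented field on a (lev, lat, lon) grid; lev is depth in metres.
lev = np.array([5.0, 500.0, 1500.0, 3000.0])
field = np.arange(4 * 2 * 3, dtype=float).reshape(4, 2, 3)
field[0, 0, 0] = np.nan  # pretend this grid point is land

# Zonal mean: average over longitude, ignoring missing values.
zonal = np.nanmean(field, axis=2)  # shape (lev, lat)

# Keep only levels in the upper 2000 m.
upper = zonal[lev <= 2000.0]

# Raw contour limit: largest absolute value in the selected layer.
vmax = float(np.nanmax(np.abs(upper)))
print(upper.shape, vmax)  # (3, 2) 16.0
```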

In [None]:
# @title Normalize Contour Limits
print("Normalizing contour limits across all scenarios AND between heave/spice...")

try:
    with open('raw_contour_limits.json', 'r') as f:
        raw_contour_limits = json.load(f)
except FileNotFoundError as exc:
    # Stop here: the loop below cannot run without the raw limits.
    raise FileNotFoundError(
        "'raw_contour_limits.json' not found. "
        "Run the 'Calculate and Save Raw Contour Limits' cell first."
    ) from exc

final_plot_limits = {}

for model in raw_contour_limits:
    final_plot_limits[model] = {}

    # 1. Find the absolute maximum value for Temp (Heave AND Spice)
    #    and Sal (Heave AND Spice) ACROSS all scenarios
    max_abs_temp = 0.0
    max_abs_sal = 0.0

    scenarios_for_model = list(raw_contour_limits[model].keys())
    if not scenarios_for_model:
        continue

    for scenario in scenarios_for_model:
        limits = raw_contour_limits[model][scenario]

        # Compare the current max TEMP with heave_temp and spice_temp of this scenario
        max_abs_temp = max(
            max_abs_temp,
            limits['heave_temp']['vmax'],
            limits['spice_temp']['vmax']
        )

        # Compare the current max SAL with heave_sal and spice_sal of this scenario
        max_abs_sal = max(
            max_abs_sal,
            limits['heave_sal']['vmax'],
            limits['spice_sal']['vmax']
        )

    # 2. Get the "nice" limits based on these global maximums
    vmin_t, vmax_t, fill_t, lines_t = get_nice_limits(max_abs_temp)
    vmin_s, vmax_s, fill_s, lines_s = get_nice_limits(max_abs_sal)

    # 3. Create the limit dictionaries.
    #    Note that heave/spice for TEMP will use THE SAME '..._t' limits
    global_nice_limits_temp = {
        'vmin': vmin_t, 'vmax': vmax_t, 'fill_levels': fill_t, 'line_levels': lines_t
    }
    global_nice_limits_sal = {
        'vmin': vmin_s, 'vmax': vmax_s, 'fill_levels': fill_s, 'line_levels': lines_s
    }

    # 4. Apply these limits to all scenarios for that model
    for scenario in scenarios_for_model:
         if scenario in raw_contour_limits[model]:
            final_plot_limits[model][scenario] = {
                'heave_temp': global_nice_limits_temp,
                'spice_temp': global_nice_limits_temp, # SAME temp limit
                'heave_sal': global_nice_limits_sal,
                'spice_sal': global_nice_limits_sal,  # SAME sal limit
            }

# 5. Save the final and "nice" limits
final_limits_file = 'final_plot_limits.json'
with open(final_limits_file, 'w') as f:
    json.dump(final_plot_limits, f, indent=4)

print(f"Normalization complete. Final limits saved to {final_limits_file}")

Normalizing contour limits across all scenarios AND between heave/spice...
Normalization complete. Final limits saved to final_plot_limits.json
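
The core of the normalization is a maximum taken over both scenarios and over the heave and spice components of one variable. A minimal sketch with invented numbers is shown below; `shared_vmax` is a hypothetical helper, and the rounding to "nice" values that the notebook applies afterwards is left out.

```python
import numpy as np

# Invented raw limits for a single model (largest absolute value per component).
raw_limits = {
    "ssp370": {"heave_temp": 1.8, "spice_temp": 0.9, "heave_sal": 0.12, "spice_sal": 0.30},
    "ssp585": {"heave_temp": 2.4, "spice_temp": 1.1, "heave_sal": 0.15, "spice_sal": 0.27},
}

def shared_vmax(limits, keys):
    """Largest value over all scenarios for the given components."""
    return max(limits[scenario][key] for scenario in limits for key in keys)

temp_vmax = shared_vmax(raw_limits, ("heave_temp", "spice_temp"))
sal_vmax = shared_vmax(raw_limits, ("heave_sal", "spice_sal"))

# Symmetric fill levels around zero, so white sits at zero anomaly.
temp_fill_levels = np.linspace(-temp_vmax, temp_vmax, 51)
print(temp_vmax, sal_vmax)  # 2.4 0.3
```

Sharing one limit between heave and spice makes their magnitudes directly comparable in the panels, although the smaller component then uses only part of the color range.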


In [None]:
# @title Calculate and Save Ensemble Mean Data
ensemble_data = {}
input_data_dir = "preprocessed_data"

for scenario in scenarios:
    print(f"Calculating ensemble mean for scenario: {scenario}")
    scenario_datasets = []
    valid_models_for_scenario = []

    for model in models:
        input_path = os.path.join(input_data_dir, f"zonal_mean_{model}_{scenario}.nc")
        if os.path.exists(input_path):
            try:
                ds_zonal = xr.open_dataset(input_path)
                scenario_datasets.append(ds_zonal)
                valid_models_for_scenario.append(model)
            except Exception as e:
                print(f"  Error opening dataset for model {model}, scenario {scenario}: {e}. Skipping.")
        # else:
            # print(f"  Missing preprocessed data for model {model}, scenario {scenario}. Skipping.")

    if not scenario_datasets:
        print(f"  No valid data found for scenario {scenario}. Cannot calculate ensemble mean.")
        continue

    # Ensure all datasets in scenario_datasets have the same 'lev' before concatenating
    # Find the union of all 'lev' coordinates across the datasets for this scenario
    all_levs = sorted(list(set(np.concatenate([ds['lev'].values for ds in scenario_datasets]))))
    target_lev = xr.DataArray(all_levs, dims=['lev'], coords={'lev': all_levs})

    aligned_scenario_datasets = []
    for ds in scenario_datasets:
        try:
            # Interpolate each dataset to the union of 'lev' coordinates
            aligned_ds = ds.interp(lev=target_lev, method='linear', kwargs={"fill_value": "extrapolate"})
            aligned_scenario_datasets.append(aligned_ds)
        except Exception as e:
            print(f"  Error interpolating dataset for a model in scenario {scenario}: {e}. Skipping this dataset from ensemble.")
        finally:
            ds.close()  # Always close the original dataset, whether or not interpolation succeeded


    if not aligned_scenario_datasets:
        print(f"  No datasets could be aligned for scenario {scenario}. Skipping ensemble mean calculation.")
        continue

    try:
        # Concatenate and calculate the mean
        ensemble_data[scenario] = xr.concat(aligned_scenario_datasets, dim='model').mean(dim='model')
        print(f"Finished calculating ensemble mean for scenario: {scenario}")
    except Exception as e:
        print(f"  Error concatenating or calculating mean for scenario {scenario}: {e}. Skipping.")
        # Close any datasets in aligned_scenario_datasets that might still be open
        for ds in aligned_scenario_datasets:
             ds.close()
        continue


# Now, save the ensemble data
output_data_dir = "preprocessed_data"
os.makedirs(output_data_dir, exist_ok=True)

for scenario, ds_ensemble in ensemble_data.items():
    output_path = os.path.join(output_data_dir, f"ensemble_zonal_mean_{scenario}.nc")
    ds_ensemble.to_netcdf(output_path)
    print(f"Saved ensemble data for scenario {scenario} to {output_path}")
    ds_ensemble.close()

Calculating ensemble mean for scenario: ssp370
Finished calculating ensemble mean for scenario: ssp370
Calculating ensemble mean for scenario: ssp585
Finished calculating ensemble mean for scenario: ssp585
Saved ensemble data for scenario ssp370 to preprocessed_data/ensemble_zonal_mean_ssp370.nc
Saved ensemble data for scenario ssp585 to preprocessed_data/ensemble_zonal_mean_ssp585.nc
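
The alignment step matters because models may report different depth levels. The sketch below shows the idea on two invented profiles using plain NumPy: both are interpolated onto the union of their levels and then averaged. It uses `np.interp` rather than xarray, and the two profiles span the same depth range, so the extrapolation that the notebook enables is not exercised here.

```python
import numpy as np

# Two invented members with different depth levels (metres).
lev_a = np.array([5.0, 100.0, 500.0])
val_a = np.array([1.0, 2.0, 4.0])
lev_b = np.array([5.0, 250.0, 500.0])
val_b = np.array([3.0, 4.0, 6.0])

# Union of all levels, sorted and without duplicates.
union_lev = np.union1d(lev_a, lev_b)  # [5, 100, 250, 500]

# Interpolate each member onto the common grid, then stack along a member axis.
members = np.stack([
    np.interp(union_lev, lev_a, val_a),
    np.interp(union_lev, lev_b, val_b),
])

# Ensemble mean across members at every common level.
ensemble_mean = members.mean(axis=0)
print(union_lev, ensemble_mean)
```

Interpolating onto the union keeps every level that any model provides, but the values at levels a model lacks are interpolated rather than simulated.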


In [None]:
# @title Generate Final Plots with Ensemble Mean and Stippling
# --- Colormap Definition ---
cmap_base = plt.get_cmap('RdBu_r', 256)  # cm.get_cmap was removed in Matplotlib 3.9
white_index = 128
colors = cmap_base(np.linspace(0, 1, 256))
colors[white_index, :] = [1, 1, 1, 1]
RdBu_white_centered = ListedColormap(colors)

# --- Load FINALIZED limits ---
try:
    with open('final_plot_limits.json', 'r') as f:
        final_plot_limits = json.load(f)
except FileNotFoundError as exc:
    # Stop here: plotting cannot proceed without the normalized limits.
    raise FileNotFoundError(
        "'final_plot_limits.json' not found. "
        "Run the 'Normalize Contour Limits' cell first."
    ) from exc

input_data_dir = "preprocessed_data"
NCOLS = 3 # Defines the number of model columns

for scenario in scenarios:
    print(f"Generating plots for scenario: {scenario}")

    # 1. Count how many models have processed data for this scenario
    models_to_plot = []
    for model in models:
        input_path = os.path.join(input_data_dir, f"zonal_mean_{model}_{scenario}.nc")
        if os.path.exists(input_path):
             models_to_plot.append(model)

    n_models = len(models_to_plot)
    if n_models == 0:
        print(f"No processed data found for scenario {scenario}. Skipping.")
        continue

    # 2. Calculate the figure layout
    # Each model occupies 2 rows (Heave, Spice)
    # Adds 2 extra rows for the ensemble (Heave, Spice)
    n_model_rows = int(math.ceil(n_models / NCOLS))
    n_fig_rows = n_model_rows * 2 + 2 # +2 for ensemble plots

    # 3. Create TWO figures: one for Temp, one for Sal
    # squeeze=False ensures 'axes' is always a 2D array
    fig_temp, axes_temp = plt.subplots(
        n_fig_rows, NCOLS,
        figsize=(NCOLS * 8, n_fig_rows * 4), # (total_width, total_height)
        layout='constrained',
        squeeze=False
    )

    fig_sal, axes_sal = plt.subplots(
        n_fig_rows, NCOLS,
        figsize=(NCOLS * 8, n_fig_rows * 4),
        layout='constrained',
        squeeze=False
    )

    # --- 4. Loop for Plotting Individual Models ---
    for i, model in enumerate(models_to_plot):
        print(f"  Plotting model: {model}")

        # Calculate position in the grid
        col_idx = i % NCOLS
        model_row_idx = i // NCOLS
        heave_row_idx = model_row_idx * 2
        spice_row_idx = model_row_idx * 2 + 1

        # Select the correct axes
        ax_ht = axes_temp[heave_row_idx, col_idx]
        ax_st = axes_temp[spice_row_idx, col_idx]
        ax_hs = axes_sal[heave_row_idx, col_idx]
        ax_ss = axes_sal[spice_row_idx, col_idx]

        # Load preprocessed data
        input_path = os.path.join(input_data_dir, f"zonal_mean_{model}_{scenario}.nc")
        ds_zonal = xr.open_dataset(input_path)

        # Get limits (now identical for heave/spice)
        limits = final_plot_limits[model][scenario]
        limits_temp = limits['heave_temp'] # It's the same for heave_temp and spice_temp
        limits_sal = limits['heave_sal']   # It's the same for heave_sal and spice_sal

        # --- PLOT TEMPERATURE (Individual Model) ---

        # Heave Temperature (Top Plot)
        cf_ht = ax_ht.contourf(
            ds_zonal['lat'], ds_zonal['lev'], ds_zonal['heave_temperature'],
            levels=limits_temp['fill_levels'], cmap=RdBu_white_centered, extend='both',
            vmin=limits_temp['vmin'], vmax=limits_temp['vmax']
        )
        cs_ht = ax_ht.contour(
            ds_zonal['lat'], ds_zonal['lev'], ds_zonal['heave_temperature'],
            levels=limits_temp['line_levels'], colors='k', linewidths=0.6
        )
        ax_ht.clabel(cs_ht, fmt='%0.1f', fontsize=7)
        ax_ht.invert_yaxis()
        ax_ht.set_title(f"{model} - Heave (θ)")
        # Remove X-axis labels for all except the last row of models
        if model_row_idx < n_model_rows - 1 or n_model_rows == 0: # If not the last model row or if there are no models
             ax_ht.set_xticklabels([])
        fig_temp.colorbar(cf_ht, ax=ax_ht, orientation='vertical', label='°C',
                         ticks=limits_temp['line_levels'])

        # Spice Temperature (Bottom Plot)
        cf_st = ax_st.contourf(
            ds_zonal['lat'], ds_zonal['lev'], ds_zonal['spice_temperature'],
            levels=limits_temp['fill_levels'], cmap=RdBu_white_centered, extend='both',
            vmin=limits_temp['vmin'], vmax=limits_temp['vmax']
        )
        cs_st = ax_st.contour(
            ds_zonal['lat'], ds_zonal['lev'], ds_zonal['spice_temperature'],
            levels=limits_temp['line_levels'], colors='k', linewidths=0.6
        )
        ax_st.clabel(cs_st, fmt='%0.1f', fontsize=7)
        ax_st.invert_yaxis()
        ax_st.set_title(f"{model} - Spice (θ)")
        # Remove X-axis labels for all except the last row of models
        if model_row_idx < n_model_rows - 1 or n_model_rows == 0: # If not the last model row or if there are no models
             ax_st.set_xticklabels([])
        else:
             ax_st.set_xlabel("Latitude") # Add label only to the last model row (if any)
        fig_temp.colorbar(cf_st, ax=ax_st, orientation='vertical', label='°C',
                         ticks=limits_temp['line_levels'])


        # --- PLOT SALINITY (Individual Model) ---

        # Heave Salinity (Top Plot)
        cf_hs = ax_hs.contourf(
            ds_zonal['lat'], ds_zonal['lev'], ds_zonal['heave_salinity'],
            levels=limits_sal['fill_levels'], cmap=RdBu_white_centered, extend='both',
            vmin=limits_sal['vmin'], vmax=limits_sal['vmax']
        )
        cs_hs = ax_hs.contour(
            ds_zonal['lat'], ds_zonal['lev'], ds_zonal['heave_salinity'],
            levels=limits_sal['line_levels'], colors='k', linewidths=0.6
        )
        ax_hs.clabel(cs_hs, fmt='%0.2f', fontsize=7)
        ax_hs.invert_yaxis()
        ax_hs.set_title(f"{model} - Heave (Salinity)")
        # Remove X-axis labels for all except the last row of models
        if model_row_idx < n_model_rows - 1 or n_model_rows == 0: # If not the last model row or if there are no models
             ax_hs.set_xticklabels([])
        fig_sal.colorbar(cf_hs, ax=ax_hs, orientation='vertical', label='g/kg',
                         ticks=limits_sal['line_levels'])

        # Spice Salinity (Bottom Plot)
        cf_ss = ax_ss.contourf(
            ds_zonal['lat'], ds_zonal['lev'], ds_zonal['spice_salinity'],
            levels=limits_sal['fill_levels'], cmap=RdBu_white_centered, extend='both',
            vmin=limits_sal['vmin'], vmax=limits_sal['vmax']
        )
        cs_ss = ax_ss.contour(
            ds_zonal['lat'], ds_zonal['lev'], ds_zonal['spice_salinity'],
            levels=limits_sal['line_levels'], colors='k', linewidths=0.6
        )
        ax_ss.clabel(cs_ss, fmt='%0.2f', fontsize=7)
        ax_ss.invert_yaxis()
        ax_ss.set_title(f"{model} - Spice (Salinity)")
        # Remove X-axis labels for all except the last row of models
        if model_row_idx < n_model_rows - 1 or n_model_rows == 0: # If not the last model row or if there are no models
             ax_ss.set_xticklabels([])
        else:
             ax_ss.set_xlabel("Latitude") # Add label only to the last model row (if any)
        fig_sal.colorbar(cf_ss, ax=ax_ss, orientation='vertical', label='g/kg',
                         ticks=limits_sal['line_levels'])

        # Close the dataset
        ds_zonal.close()

        # --- Y-axis Title Management for Individual Models ---
        # Only show "Depth (m)" in the left column (col_idx == 0)
        if col_idx == 0:
            ax_ht.set_ylabel("Depth (m)")
            ax_st.set_ylabel("Depth (m)")
            ax_hs.set_ylabel("Depth (m)")
            ax_ss.set_ylabel("Depth (m)")
        else:
            ax_ht.set_yticklabels([])
            ax_st.set_yticklabels([])
            ax_hs.set_yticklabels([])
            ax_ss.set_yticklabels([])

    # --- 5. Plot the Ensemble Mean ---
    ensemble_input_path = os.path.join(input_data_dir, f"ensemble_zonal_mean_{scenario}.nc")
    if os.path.exists(ensemble_input_path):
        print(f"  Plotting Ensemble Mean for scenario: {scenario}")
        ds_ensemble_zonal = xr.open_dataset(ensemble_input_path)

        # Get the limits from one of the models (since they are normalized across models)
        first_model = models_to_plot[0] if models_to_plot else None
        if first_model and first_model in final_plot_limits and scenario in final_plot_limits[first_model]:
             limits_temp = final_plot_limits[first_model][scenario]['heave_temp']
             limits_sal = final_plot_limits[first_model][scenario]['heave_sal']
        else:
             print(f"  Could not retrieve normalized limits for ensemble plot for scenario {scenario}. Skipping ensemble plot.")
             ds_ensemble_zonal.close()
             # Need to remove the allocated ensemble axes if plotting is skipped
             for col_idx in range(NCOLS):
                 if n_model_rows * 2 + 0 < n_fig_rows: # Check if first ensemble row exists
                     fig_temp.delaxes(axes_temp[n_model_rows * 2 + 0, col_idx])
                     fig_sal.delaxes(axes_sal[n_model_rows * 2 + 0, col_idx])
                 if n_model_rows * 2 + 1 < n_fig_rows: # Check if second ensemble row exists
                     fig_temp.delaxes(axes_temp[n_model_rows * 2 + 1, col_idx])
                     fig_sal.delaxes(axes_sal[n_model_rows * 2 + 1, col_idx])
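             # 'continue' skips the clean-up and save steps below, so close both figures here to avoid leaking them
             plt.close(fig_temp)
             plt.close(fig_sal)
             continue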


        # Calculate the starting row for the ensemble plots
        ensemble_heave_row_idx = n_model_rows * 2
        ensemble_spice_row_idx = n_model_rows * 2 + 1

        # --- STATISTICAL SIGNIFICANCE CALCULATION (T-TEST) ---
        # Need to load the individual model zonal mean data again to calculate std deviation across models
        all_models_heave_temp = []
        all_models_spice_temp = []
        all_models_heave_sal = []
        all_models_spice_sal = []

        for model in models_to_plot:
             input_path = os.path.join(input_data_dir, f"zonal_mean_{model}_{scenario}.nc")
             if os.path.exists(input_path):
                 try:
                     # The context manager closes the file even if interpolation fails
                     with xr.open_dataset(input_path) as ds_zonal_model:
                         # Interpolate every field onto the ensemble mean's 'lev' grid
                         target_lev = ds_ensemble_zonal['lev']
                         interp_opts = dict(method='linear', kwargs={"fill_value": "extrapolate"})
                         aligned = [
                             ds_zonal_model[name].interp(lev=target_lev, **interp_opts)
                             for name in ('heave_temperature', 'spice_temperature',
                                          'heave_salinity', 'spice_salinity')
                         ]
                     # Append only after all four fields aligned, so the lists stay the same length
                     all_models_heave_temp.append(aligned[0])
                     all_models_spice_temp.append(aligned[1])
                     all_models_heave_sal.append(aligned[2])
                     all_models_spice_sal.append(aligned[3])
                 except Exception as e:
                     print(f"  Error loading or aligning individual model data for significance test for model {model}, scenario {scenario}: {e}. Skipping this model for stippling calculation.")


        n_models_for_stippling = len(all_models_heave_temp)

        if n_models_for_stippling > 1:
            # Concatenate aligned data along a new 'model' dimension
            concatenated_heave_temp = xr.concat(all_models_heave_temp, dim='model')
            concatenated_spice_temp = xr.concat(all_models_spice_temp, dim='model')
            concatenated_heave_sal = xr.concat(all_models_heave_sal, dim='model')
            concatenated_spice_sal = xr.concat(all_models_spice_sal, dim='model')

            # Calculate the sample standard deviation across the 'model' dimension
            # (ddof=1, consistent with the n - 1 degrees of freedom used below)
            ensemble_std_heave_temp = concatenated_heave_temp.std(dim="model", ddof=1)
            ensemble_std_spice_temp = concatenated_spice_temp.std(dim="model", ddof=1)
            ensemble_std_heave_sal = concatenated_heave_sal.std(dim="model", ddof=1)
            ensemble_std_spice_sal = concatenated_spice_sal.std(dim="model", ddof=1)

            # Calculate the t-statistic: t = (mean - 0) / (std / sqrt(n))
            # Mean is the ensemble mean loaded from ds_ensemble_zonal
            with np.errstate(divide='ignore', invalid='ignore'):
                t_stat_heave_temp = (ds_ensemble_zonal['heave_temperature'] / (ensemble_std_heave_temp / np.sqrt(n_models_for_stippling))).fillna(0)
                t_stat_spice_temp = (ds_ensemble_zonal['spice_temperature'] / (ensemble_std_spice_temp / np.sqrt(n_models_for_stippling))).fillna(0)
                t_stat_heave_sal = (ds_ensemble_zonal['heave_salinity'] / (ensemble_std_heave_sal / np.sqrt(n_models_for_stippling))).fillna(0)
                t_stat_spice_sal = (ds_ensemble_zonal['spice_salinity'] / (ensemble_std_spice_sal / np.sqrt(n_models_for_stippling))).fillna(0)

            # Degrees of freedom = n - 1
            df = n_models_for_stippling - 1

            # Calculate p-values (using numpy arrays and scipy.stats)
            p_values_heave_temp_np = stats.t.sf(np.abs(t_stat_heave_temp.values), df=df) * 2
            p_values_spice_temp_np = stats.t.sf(np.abs(t_stat_spice_temp.values), df=df) * 2
            p_values_heave_sal_np = stats.t.sf(np.abs(t_stat_heave_sal.values), df=df) * 2
            p_values_spice_sal_np = stats.t.sf(np.abs(t_stat_spice_sal.values), df=df) * 2

            # Convert p_values back to xarray.DataArray with coordinates
            p_values_heave_temp_da = xr.DataArray(p_values_heave_temp_np, coords=t_stat_heave_temp.coords, dims=t_stat_heave_temp.dims)
            p_values_spice_temp_da = xr.DataArray(p_values_spice_temp_np, coords=t_stat_spice_temp.coords, dims=t_stat_spice_temp.dims)
            p_values_heave_sal_da = xr.DataArray(p_values_heave_sal_np, coords=t_stat_heave_sal.coords, dims=t_stat_heave_sal.dims)
            p_values_spice_sal_da = xr.DataArray(p_values_spice_sal_np, coords=t_stat_spice_sal.coords, dims=t_stat_spice_sal.dims)

            # Create the significance mask (where p < 0.05, put 1, otherwise NaN)
            significance_mask_heave_temp = xr.where(p_values_heave_temp_da < 0.05, 1, np.nan)
            significance_mask_spice_temp = xr.where(p_values_spice_temp_da < 0.05, 1, np.nan)
            significance_mask_heave_sal = xr.where(p_values_heave_sal_da < 0.05, 1, np.nan)
            significance_mask_spice_sal = xr.where(p_values_spice_sal_da < 0.05, 1, np.nan)

        else:
            significance_mask_heave_temp = None
            significance_mask_spice_temp = None
            significance_mask_heave_sal = None
            significance_mask_spice_sal = None
            print(f"  Not enough models ({n_models_for_stippling}) for significance test for scenario {scenario}. Skipping stippling.")


        # --- PLOT TEMPERATURE (Ensemble Mean) ---

        # Heave Temperature (Ensemble)
        ax_ht_ens = axes_temp[ensemble_heave_row_idx, 0] # Always in the first column
        cf_ht_ens = ax_ht_ens.contourf(
            ds_ensemble_zonal['lat'], ds_ensemble_zonal['lev'], ds_ensemble_zonal['heave_temperature'],
            levels=limits_temp['fill_levels'], cmap=RdBu_white_centered, extend='both',
            vmin=limits_temp['vmin'], vmax=limits_temp['vmax']
        )
        cs_ht_ens = ax_ht_ens.contour(
            ds_ensemble_zonal['lat'], ds_ensemble_zonal['lev'], ds_ensemble_zonal['heave_temperature'],
            levels=limits_temp['line_levels'], colors='k', linewidths=0.6
        )
        ax_ht_ens.clabel(cs_ht_ens, fmt='%0.1f', fontsize=7)
        ax_ht_ens.invert_yaxis()
        ax_ht_ens.set_title("Ensemble Mean - Heave (θ)")
        ax_ht_ens.set_xlabel("Latitude") # Add xlabel to the bottom row
        ax_ht_ens.set_ylabel("Depth (m)") # Add ylabel to the leftmost column
        fig_temp.colorbar(cf_ht_ens, ax=ax_ht_ens, orientation='vertical', label='°C',
                         ticks=limits_temp['line_levels'])

        # Add stippling for Heave Temperature if available
        if significance_mask_heave_temp is not None:
            ax_ht_ens.contourf(
                significance_mask_heave_temp['lat'], significance_mask_heave_temp['lev'], significance_mask_heave_temp,
                levels=[0.5, 1.5], # Plot where mask is 1
                hatches=['...'], colors='none' # Use dots for stippling
            )


        # Spice Temperature (Ensemble)
        ax_st_ens = axes_temp[ensemble_spice_row_idx, 0] # Always in the first column
        cf_st_ens = ax_st_ens.contourf(
            ds_ensemble_zonal['lat'], ds_ensemble_zonal['lev'], ds_ensemble_zonal['spice_temperature'],
            levels=limits_temp['fill_levels'], cmap=RdBu_white_centered, extend='both',
            vmin=limits_temp['vmin'], vmax=limits_temp['vmax']
        )
        cs_st_ens = ax_st_ens.contour(
            ds_ensemble_zonal['lat'], ds_ensemble_zonal['lev'], ds_ensemble_zonal['spice_temperature'],
            levels=limits_temp['line_levels'], colors='k', linewidths=0.6
        )
        ax_st_ens.clabel(cs_st_ens, fmt='%0.1f', fontsize=7)
        ax_st_ens.invert_yaxis()
        ax_st_ens.set_title("Ensemble Mean - Spice (θ)")
        ax_st_ens.set_xlabel("Latitude") # Add xlabel to the bottom row
        ax_st_ens.set_ylabel("Depth (m)") # Add ylabel to the leftmost column
        fig_temp.colorbar(cf_st_ens, ax=ax_st_ens, orientation='vertical', label='°C',
                         ticks=limits_temp['line_levels'])

        # Add stippling for Spice Temperature if available
        if significance_mask_spice_temp is not None:
            ax_st_ens.contourf(
                significance_mask_spice_temp['lat'], significance_mask_spice_temp['lev'], significance_mask_spice_temp,
                levels=[0.5, 1.5], # Plot where mask is 1
                hatches=['...'], colors='none' # Use dots for stippling
            )


        # --- PLOT SALINITY (Ensemble Mean) ---

        # Heave Salinity (Ensemble)
        ax_hs_ens = axes_sal[ensemble_heave_row_idx, 0] # Always in the first column
        cf_hs_ens = ax_hs_ens.contourf(
            ds_ensemble_zonal['lat'], ds_ensemble_zonal['lev'], ds_ensemble_zonal['heave_salinity'],
            levels=limits_sal['fill_levels'], cmap=RdBu_white_centered, extend='both',
            vmin=limits_sal['vmin'], vmax=limits_sal['vmax']
        )
        cs_hs_ens = ax_hs_ens.contour(
            ds_ensemble_zonal['lat'], ds_ensemble_zonal['lev'], ds_ensemble_zonal['heave_salinity'],
            levels=limits_sal['line_levels'], colors='k', linewidths=0.6
        )
        ax_hs_ens.clabel(cs_hs_ens, fmt='%0.2f', fontsize=7)
        ax_hs_ens.invert_yaxis()
        ax_hs_ens.set_title("Ensemble Mean - Heave (Salinity)")
        ax_hs_ens.set_xlabel("Latitude") # Add xlabel to the bottom row
        ax_hs_ens.set_ylabel("Depth (m)") # Add ylabel to the leftmost column
        fig_sal.colorbar(cf_hs_ens, ax=ax_hs_ens, orientation='vertical', label='g/kg',
                         ticks=limits_sal['line_levels'])

        # Add stippling for Heave Salinity if available
        if significance_mask_heave_sal is not None:
            ax_hs_ens.contourf(
                significance_mask_heave_sal['lat'], significance_mask_heave_sal['lev'], significance_mask_heave_sal,
                levels=[0.5, 1.5], # Plot where mask is 1
                hatches=['...'], colors='none' # Use dots for stippling
            )


        # Spice Salinity (Ensemble)
        ax_ss_ens = axes_sal[ensemble_spice_row_idx, 0] # Always in the first column
        cf_ss_ens = ax_ss_ens.contourf(
            ds_ensemble_zonal['lat'], ds_ensemble_zonal['lev'], ds_ensemble_zonal['spice_salinity'],
            levels=limits_sal['fill_levels'], cmap=RdBu_white_centered, extend='both',
            vmin=limits_sal['vmin'], vmax=limits_sal['vmax']
        )
        cs_ss_ens = ax_ss_ens.contour(
            ds_ensemble_zonal['lat'], ds_ensemble_zonal['lev'], ds_ensemble_zonal['spice_salinity'],
            levels=limits_sal['line_levels'], colors='k', linewidths=0.6
        )
        ax_ss_ens.clabel(cs_ss_ens, fmt='%0.2f', fontsize=7)
        ax_ss_ens.invert_yaxis()
        ax_ss_ens.set_title("Ensemble Mean - Spice (Salinity)")
        ax_ss_ens.set_xlabel("Latitude") # Add xlabel to the bottom row
        ax_ss_ens.set_ylabel("Depth (m)") # Add ylabel to the leftmost column
        fig_sal.colorbar(cf_ss_ens, ax=ax_ss_ens, orientation='vertical', label='g/kg',
                         ticks=limits_sal['line_levels'])

        # Add stippling for Spice Salinity if available
        if significance_mask_spice_sal is not None:
            ax_ss_ens.contourf(
                significance_mask_spice_sal['lat'], significance_mask_spice_sal['lev'], significance_mask_spice_sal,
                levels=[0.5, 1.5], # Plot where mask is 1
                hatches=['...'], colors='none' # Use dots for stippling
            )

        # Close the ensemble dataset
        ds_ensemble_zonal.close()

        # --- Clean up unused ensemble axes ---
        # The ensemble only uses the first column. Remove remaining axes in the ensemble rows.
        for col_idx in range(1, NCOLS):
             if ensemble_heave_row_idx < n_fig_rows:
                 fig_temp.delaxes(axes_temp[ensemble_heave_row_idx, col_idx])
                 fig_sal.delaxes(axes_sal[ensemble_heave_row_idx, col_idx])
             if ensemble_spice_row_idx < n_fig_rows:
                 fig_temp.delaxes(axes_temp[ensemble_spice_row_idx, col_idx])
                 fig_sal.delaxes(axes_sal[ensemble_spice_row_idx, col_idx])


    # --- 6. Clean up remaining empty axes (if any) ---
    # Assumes each model occupies one column of a (Heave, Spice) row pair, as in the
    # plotting loop above. Unused axes can then only occur in the last model row pair,
    # in the columns after the last plotted model.
    n_used_in_last_row = n_models - (n_model_rows - 1) * NCOLS if n_model_rows > 0 else NCOLS
    for col_idx in range(n_used_in_last_row, NCOLS):
        for row_idx in ((n_model_rows - 1) * 2, (n_model_rows - 1) * 2 + 1):
            if row_idx < n_fig_rows: # Ensure we don't go out of bounds
                fig_temp.delaxes(axes_temp[row_idx, col_idx])
                fig_sal.delaxes(axes_sal[row_idx, col_idx])


    # --- 7. Save the Figures ---
    fig_temp.savefig(f'{scenario}_temperature_plots_with_ensemble_stippling.jpeg', dpi=600, bbox_inches='tight')
    fig_sal.savefig(f'{scenario}_salinity_plots_with_ensemble_stippling.jpeg', dpi=600, bbox_inches='tight')

    plt.close(fig_temp)
    plt.close(fig_sal)
    print(f"Finished generating plots for scenario: {scenario}")



Generating plots for scenario: ssp370
  Plotting model: IPSL-CM6A-LR
  Plotting model: CAMS-CSM1-0
  Plotting model: GFDL-ESM4
  Plotting model: MIROC6
  Plotting model: CNRM-ESM2-1
  Plotting model: CESM2
  Plotting Ensemble Mean for scenario: ssp370
Finished generating plots for scenario: ssp370
Generating plots for scenario: ssp585
  Plotting model: IPSL-CM6A-LR
  Plotting model: CAMS-CSM1-0
  Plotting model: GFDL-ESM4
  Plotting model: MIROC6
  Plotting model: CNRM-ESM2-1
  Plotting model: CESM2
  Plotting Ensemble Mean for scenario: ssp585
Finished generating plots for scenario: ssp585
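
The stippling on the ensemble panels comes from a two-sided, one-sample t-test of the inter-model anomalies against zero. The sketch below reproduces that calculation on synthetic data so it can be checked in isolation. It is a minimal illustration, not the notebook's code: the function name `significance_mask`, the array shapes, and the random "models" are all assumptions, and it uses plain NumPy arrays instead of xarray objects.

```python
import numpy as np
from scipy import stats


def significance_mask(anomalies, alpha=0.05):
    """Two-sided one-sample t-test against zero along axis 0 (the model axis).

    anomalies: array of shape (n_models, n_lev, n_lat).
    Returns (mask, p_values), each of shape (n_lev, n_lat); mask is 1.0 where
    p < alpha and NaN elsewhere, which is the form used for hatching.
    """
    anomalies = np.asarray(anomalies, dtype=float)
    n_models = anomalies.shape[0]
    if n_models < 2:
        raise ValueError("need at least two models for a t-test")

    mean = anomalies.mean(axis=0)
    # Sample standard deviation (ddof=1) matches the n - 1 degrees of freedom
    std = anomalies.std(axis=0, ddof=1)

    with np.errstate(divide="ignore", invalid="ignore"):
        t_stat = mean / (std / np.sqrt(n_models))
    # 0/0 (all models identical and zero) gives NaN; treat it as "no signal"
    t_stat = np.nan_to_num(t_stat, nan=0.0, posinf=np.inf, neginf=-np.inf)

    p_values = 2.0 * stats.t.sf(np.abs(t_stat), df=n_models - 1)
    mask = np.where(p_values < alpha, 1.0, np.nan)
    return mask, p_values


# Synthetic example: 6 hypothetical models on a 4 x 5 (lev, lat) grid
rng = np.random.default_rng(0)
noise = rng.normal(0.0, 0.1, size=(6, 4, 5))
signal = np.zeros((4, 5))
signal[:, :2] = 1.0  # strong, consistent anomaly in the first two latitude columns
fake_anomalies = noise + signal

mask, p_values = significance_mask(fake_anomalies)
```

The p-values can be cross-checked against `scipy.stats.ttest_1samp(fake_anomalies, 0.0, axis=0)`, which performs the same test. Because the test depends on the sample standard deviation, computing the spread with `ddof=0` would make the stippling slightly too generous for a small ensemble.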
