In [None]:
import xarray as xr
import numpy as np
import matplotlib.pyplot as plt
import cartopy.crs as ccrs
from pathlib import Path
import re
import yaml
import torch

from fme.core.data_loading.perturbation import (
    ConstantConfig,
    GreensFunctionConfig,
    PerturbationSelector,
)
# Larger default font size for the multi-panel map figures produced below.
plt.rcParams['font.size'] = 16

### Read inference forcing data and control run data

In [None]:
# Forcing dataset used by the inference runs, restricted to the analysis window.
# NOTE(review): label-based time slices are inclusive on both ends, so a 1981-01-01
# timestamp (if present) is included -- confirm this matches the inference period.
forcing = xr.open_mfdataset(
    "/pscratch/sd/e/elynnwu/e3smv2-fme-data/gfmip-annual-repeating-sst-sic-with-hybrid-amip-1970-2020/*.nc"
).sel(time=slice("1971-01-01", "1981-01-01"))

# Time-mean diagnostics of the 10-year control run (the reference for all patch runs).
control = xr.open_dataset(
    "/pscratch/sd/e/elynnwu/fme-output/greens-experiment-v3/control_run_climSST_10yr/time_mean_diagnostics.nc"
)

### Load grid area directly from grid file

In [None]:
# Grid-cell areas of the 1-degree Gaussian grid, stored as a flat array in the grid file.
one_deg_grid = xr.open_dataset("/pscratch/sd/e/elynnwu/ace-run-inference/greens-function-test/ncremap_gaussian_grid_180_by_360.nc")
# Reshape to the control run's (lat, lon) size instead of a hard-coded (180, 360), so a
# grid mismatch raises here rather than silently mis-weighting later area averages.
area = one_deg_grid["grid_area"].values.reshape((control.lat.size, control.lon.size))

### Prepare ocean fraction, ocean mask, and ice free mask

- Ocean mask: time-mean ocean fraction >= 0.5
- Ice-free mask: time-mean sea-ice fraction <= 0.05

**Note:** the time mean is taken over the same period as the inference runs (1971–1981).


In [None]:
# Thresholds applied to time means of the (already time-sliced) forcing fields.
OCEAN_MIN_FRACTION = 0.5
# NOTE(review): confirm the intended ice-free threshold; this analysis uses 0.05.
ICE_FREE_MAX_FRACTION = 0.05

# Reduce each (dask-backed) field once; the masks are derived from the numpy results
# instead of recomputing the OCNFRAC time mean a second time.
ocean_fraction = forcing.OCNFRAC.mean(dim="time").values
ice_fraction = forcing.ICEFRAC.mean(dim="time").values

lat = control.lat.values
lon = control.lon.values
lons, lats = np.meshgrid(lon, lat)  # 2-D coordinate grids for contour plots
# cos(latitude) weights for area-weighted averages along the latitude axis.
area_weights = np.cos(np.deg2rad(lat))

# Boolean masks; NaN fractions compare False and are therefore excluded.
ice_free_mask = ice_fraction <= ICE_FREE_MAX_FRACTION
ocean_mask = ocean_fraction >= OCEAN_MIN_FRACTION

### Get ACE's SST perturbation
Instead of reading TS from the inference output, we recompute the anomaly directly from the perturbation definition to obtain a clean SST perturbation. Using TS from the inference output directly makes it hard to filter out sea-ice regions, because TS also changes in those grid cells.

In [None]:
def get_ace_sst_perturbation(config, ocean_fraction, lat, lon):
    """Recompute ACE's SST perturbation for a single Green's-function patch.

    A ``greens_function`` perturbation is built from ``config`` and applied to an
    all-zero field on the (lat, lon) grid, so the returned array is the anomaly alone.

    Args:
        config: perturbation config mapping (amplitude, lat/lon center and width).
        ocean_fraction: numpy array handed to the perturbation as the ocean fraction;
            presumably shaped (len(lat), len(lon)) -- TODO confirm.
        lat: 1-D numpy array of latitudes.
        lon: 1-D numpy array of longitudes.

    Returns:
        numpy array of shape (len(lat), len(lon)) holding the SST anomaly.
    """
    perturbation = PerturbationSelector(name="greens_function", config=config).perturbation
    lat_grid, lon_grid = torch.meshgrid(
        torch.from_numpy(lat), torch.from_numpy(lon), indexing="ij"
    )
    ocean_fraction_tensor = torch.from_numpy(ocean_fraction)
    # Zero background: whatever apply_perturbation writes into it is the pure anomaly.
    anomaly = torch.zeros(len(lat), len(lon), device="cpu")
    perturbation.apply_perturbation(
        anomaly, lat_grid.to("cpu"), lon_grid.to("cpu"), ocean_fraction_tensor
    )
    return anomaly.data.numpy()

### Read in all patch inference output

In [None]:
# Collect the 10-year SST-patch runs, keyed by "<lat_center>_<lon_center>_<amplitude>K".
base_directory = Path("/pscratch/sd/e/elynnwu/fme-output/greens-experiment-v3")
patch_sim_dict = {}
# Sorted so that, when two run directories describe the same patch, the one kept is
# deterministic (iterdir() order is arbitrary).
for patch in sorted(base_directory.iterdir()):
    if not (patch.is_dir() and "sst_patch_" in patch.name and "10yr" in patch.name):
        continue
    with open(patch / "config.yaml") as config_file:
        config = yaml.safe_load(config_file)
    sst_config = config["forcing_loader"]["perturbations"]["sst"][0]["config"]
    amplitude = sst_config["amplitude"]
    lon_center = sst_config["lon_center"]
    lat_center = sst_config["lat_center"]
    # Format as floats so integer-valued YAML entries still yield "<number>.<number>" keys
    # that the regex used later can parse.
    current_patch_string = f"{float(lat_center)}_{float(lon_center)}_{float(amplitude)}K"
    if current_patch_string in patch_sim_dict:
        continue  # duplicate patch: keep the first run, don't open another dataset
    patch_sim_dict[current_patch_string] = xr.open_dataset(patch / "time_mean_diagnostics.nc")

In [None]:
# Sanity check: number of distinct patch runs loaded.
print("Total number of patches:", len(patch_sim_dict))

In [None]:
# Patch keys have the format "<lat_center>_<lon_center>_<amplitude>K", e.g. "-10.0_200.0_2.0K".
# Match signed numbers with an optional fractional part so integer-valued keys parse too.
pattern = r"-?\d+(?:\.\d+)?"

### Equation 3 from [Bloch‐Johnson et al. 2024](https://agupubs.onlinelibrary.wiley.com/doi/epdf/10.1029/2023MS003700)
$\frac{df}{dSST^*_{i}} \approx \frac{\sum_{p} \bigl(\Delta f_p / \langle \Delta \overrightarrow{SST_p} \rangle \bigr) \Delta SST_{p,i}}{\sum_{p} \Delta SST_{p,i}}$

In [None]:
def get_normalized_derivative_of_N(VAR, patch_amplitude, control_run, lat_width=20.0, lon_width=80.0):
    """Accumulate the numerator and denominator of Eq. 3 over all patches of one amplitude.

    For every run in ``patch_sim_dict`` whose amplitude equals ``patch_amplitude``, the SST
    anomaly is recomputed with ``get_ace_sst_perturbation``. It is restricted to cells where
    the anomaly is non-zero AND ``ice_free_mask`` is True. The change in the cos(lat)-weighted
    global mean of ``VAR`` (patch minus control) is divided by the area-weighted global-mean
    anomaly of that restricted region. The result is then spread over the patch cells in
    proportion to the local anomaly.

    Relies on notebook globals: ``patch_sim_dict``, ``pattern``, ``ocean_fraction``,
    ``lat``, ``lon``, ``area``, ``area_weights`` and ``ice_free_mask``.

    Args:
        VAR: name of the field f in both the patch and control datasets.
        patch_amplitude: only patches whose key amplitude equals this value are used.
        control_run: control-run dataset providing f_control and the output grid template.
        lat_width: latitude width put in the reconstructed perturbation config.
        lon_width: longitude width put in the reconstructed perturbation config.
            NOTE(review): both widths must match those used by the patch runs.

    Returns:
        ``(sum_of_eq3_numerator, delta_sst_p_i_sum)``: arrays shaped like
        ``control_run["gen_map-TS"]`` with the Eq. 3 numerator and denominator sums.
    """
    template = control_run["gen_map-TS"].values
    delta_sst_p_i_sum = np.zeros_like(template)     # Eq. 3 denominator: sum_p dSST_{p,i}
    sum_of_eq3_numerator = np.zeros_like(template)  # Eq. 3 numerator
    # Loop-invariant quantities, computed once rather than per patch.
    f_control = np.average(control_run[VAR], weights=area_weights, axis=0).mean()
    total_area = np.sum(area)
    for patch, patch_run in patch_sim_dict.items():
        info = re.findall(pattern, patch)
        amplitude = float(info[-1])
        if amplitude != patch_amplitude:
            continue
        config = {
            "amplitude": amplitude,
            "lat_center": float(info[0]),
            "lon_center": float(info[1]),
            "lat_width": lat_width,
            "lon_width": lon_width,
        }
        sst_anom = get_ace_sst_perturbation(config, ocean_fraction, lat, lon)
        # Parentheses are required: `&` binds tighter than `!=`, so the unparenthesized
        # `sst_anom != 0 & ice_free_mask` meant `sst_anom != (0 & ice_free_mask)` and
        # silently ignored the ice-free mask.
        patch_and_icefree = (sst_anom != 0) & ice_free_mask
        if not patch_and_icefree.any():
            continue  # patch lies entirely over land and/or sea ice
        # <dSST_p>: area-weighted global mean of the anomaly (zero outside the masked patch).
        delta_sst_patch_avg = np.sum(sst_anom[patch_and_icefree] * area[patch_and_icefree]) / total_area
        f_patch = np.average(patch_run[VAR], weights=area_weights, axis=0).mean()
        delta_fp = f_patch - f_control
        sum_of_eq3_numerator[patch_and_icefree] += delta_fp / delta_sst_patch_avg * sst_anom[patch_and_icefree]
        delta_sst_p_i_sum += sst_anom
    return sum_of_eq3_numerator, delta_sst_p_i_sum

In [None]:
# Field f in Eq. 3: net TOA energy flux into the atmosphere (plotted below as dN/dSST).
VAR = "gen_map-net_energy_flux_toa_into_atmosphere"

In [None]:
def get_eq3_warming_cooling_avg(control_run, var=None, amplitude=2):
    """Evaluate Eq. 3 for warming and cooling patches, and their average.

    Args:
        control_run: control-run dataset passed to ``get_normalized_derivative_of_N``.
        var: field name for f; defaults to the notebook-level ``VAR``.
        amplitude: patch amplitude in K; warming uses ``+amplitude``, cooling ``-amplitude``.

    Returns:
        ``(eq3_warming, eq3_cooling, eq3_avg)``: maps shaped like ``control_run["gen_map-TS"]``,
        NaN outside ``ocean_mask & ice_free_mask``.
    """
    if var is None:
        var = VAR
    valid = ocean_mask & ice_free_mask

    def _eq3_map(patch_amplitude):
        # Eq. 3 ratio (numerator / denominator) on valid cells, NaN everywhere else.
        numerator, denominator = get_normalized_derivative_of_N(var, patch_amplitude, control_run)
        eq3 = np.full_like(control_run["gen_map-TS"].values, np.nan)
        # Valid cells covered by no patch give 0/0 -> NaN; that warning is expected.
        with np.errstate(invalid="ignore", divide="ignore"):
            eq3[valid] = numerator[valid] / denominator[valid]
        return eq3

    eq3_warming = _eq3_map(amplitude)
    eq3_cooling = _eq3_map(-amplitude)
    eq3_avg = (eq3_warming + eq3_cooling) / 2.0
    return eq3_warming, eq3_cooling, eq3_avg

In [None]:
# Eq. 3 maps for the 10-year control run: +2 K patches, -2 K patches, and their mean.
control10_eq3_warming, control10_eq3_cooling, control10_eq3_avg = get_eq3_warming_cooling_avg(control)

In [None]:
# ACE Eq. 3 sensitivity maps: warming patches, cooling patches, and their mean.
fig, axes = plt.subplots(1, 3, figsize=(20, 9), subplot_kw={'projection': ccrs.PlateCarree(central_longitude=180)})
axes = axes.flatten()
ACE_results = [control10_eq3_warming, control10_eq3_cooling, control10_eq3_avg]
labels = ["Warming", "Cooling", "Average"]
for ax, result, label in zip(axes, ACE_results, labels):
    cf = ax.contourf(lons, lats, result, levels=np.arange(-30, 31, 2), extend="both",
                     transform=ccrs.PlateCarree(), cmap="seismic")
    # Attach each colorbar to its own panel; without `ax=`, older Matplotlib versions
    # take the space from the current axes (the last subplot) instead of this one.
    cbar = fig.colorbar(cf, ax=ax, orientation="horizontal", pad=0.05)
    cbar.set_label("dN/dSST [W/m^2/K]")
    ax.coastlines()
    ax.set_global()
    ax.set_title(label)
    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                      linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.top_labels = False

In [None]:
def get_paper_fig2_data(model_name, case, data_dir="/pscratch/sd/e/elynnwu/ace-run-inference/greens-function-test/preliminary_spatial_feedbacks/data"):
    """Load one model's spatial-feedback map from the reference Fig. 2 data files.

    Args:
        model_name: file stem of ``<data_dir>/<model_name>.nc`` (e.g. "cam5").
        case: "warming" or "cooling"; any other value selects the combined ("both") field.
        data_dir: directory containing the per-model NetCDF files.

    Returns:
        ``(model_lons, model_lats, data)``: 2-D longitude/latitude grids from
        ``np.meshgrid`` and the selected feedback field as a numpy array.
    """
    if case == "warming":
        var_name = "spatial_feedbacks_from_warming"
    elif case == "cooling":
        var_name = "spatial_feedbacks_from_cooling"
    else:
        var_name = "spatial_feedbacks_from_both"
    # Context manager closes the file; all values are materialized before it exits.
    with xr.open_dataset(f"{data_dir}/{model_name}.nc") as model:
        model_lons, model_lats = np.meshgrid(model.longitude.values, model.latitude.values)
        data = model[var_name].values
    return model_lons, model_lats, data

In [None]:
# (lon grid, lat grid, feedback map) per model: three reference models from the
# Fig. 2 data files, plus ACE's warming/cooling average on its own grid.
case_name = "both"
model_results = {
    "CAM5": get_paper_fig2_data("cam5", case_name),
    "HadCM3": get_paper_fig2_data("hadcm3", case_name),
    "GFDL-AM4": get_paper_fig2_data("gfdlam4", case_name),
    "ACE": (lons, lats, control10_eq3_avg),
}

In [None]:
# Side-by-side dN/dSST maps for the reference models and ACE.
fig, axes = plt.subplots(1, 4, figsize=(20, 9), subplot_kw={'projection': ccrs.PlateCarree(central_longitude=180)})
axes = axes.flatten()
# Unpack into model_* names: reusing `lons, lats` here would overwrite the ACE grid
# globals that other cells depend on.
for ax, (model, (model_lons, model_lats, avg)) in zip(axes, model_results.items()):
    cf = ax.contourf(model_lons, model_lats, avg, levels=np.arange(-30, 31, 2), extend="both",
                     transform=ccrs.PlateCarree(), cmap="seismic")
    # Explicit `ax=` keeps each colorbar under its own panel across Matplotlib versions.
    cbar = fig.colorbar(cf, ax=ax, orientation="horizontal", pad=0.05)
    cbar.set_label("dN/dSST [W/m^2/K]")
    ax.coastlines()
    ax.set_global()
    ax.set_title(model)
    gl = ax.gridlines(crs=ccrs.PlateCarree(), draw_labels=True,
                      linewidth=0.5, color='gray', alpha=0.5, linestyle='--')
    gl.top_labels = False
