In [None]:
import pandas as pd
import os
import matplotlib.pyplot as plt
import shap
import numpy as np
from explain import load_and_prepare_data, find_model_and_scaler_paths
from models.model_utils import load_model
from models.model_explainer import ModelExplainer
from matplotlib import rcParams
from matplotlib.mathtext import _mathtext as mathtext

"""
Extended Data Figure 4
Customizing the plotting logic of the SHAP heatmap through call arguments alone is highly complex, so we modified the source code of shap.plots.heatmap directly (the modified version is listed in a later cell). We strongly recommend that readers back up the original file before replacing it when reproducing this figure.
"""

# Global typography for every figure in this notebook: Arial for normal text and for the
# mathtext ($...$) runs used in labels such as CH$_4$ and ha$^{-1}$.
# NOTE(review): FontConstantsBase lives in the private module matplotlib.mathtext._mathtext;
# `sup1` is the superscript raise factor used by mathtext. Private API — may change between
# matplotlib versions.
mathtext.FontConstantsBase.sup1 = 0.30
# `rcParams` and `plt.rcParams` are the same object, so one update is enough. font.family
# ends up as Arial (a previous 'serif' assignment was immediately overwritten anyway).
plt.rcParams.update({
    'font.family': 'Arial',
    'font.serif': ['Arial'],
    'mathtext.fontset': 'custom',
    'mathtext.rm': 'Arial',
    'mathtext.it': 'Arial',
    'mathtext.bf': 'Arial',
})


# Prediction targets, in the order their panels are produced.
TARGETS = ['CH4', 'N2O', 'SOCSR', 'Yield', 'NH3', 'NL', 'NR', 'NO']

# Management-practice columns that are plotted, mapped to their y-axis display label.
# (The misspelt name is kept because the later cells refer to it.)
MANAGMENTS_DICTS = dict(
    TYPE_IN="IN type",
    IN="IN",
    IP="IP",
    IK="IK",
    OrgN="ON",
    BC="BC",
    STRAW="Straw",
    tillage="Tillage",
    WM_CF="CF",
    WM_MD="MD",
    WM_AWD="II",
    WM_RF="RF",
)

# Panel title for each target (mathtext subscripts for the chemical formulas).
SHOW_DICT = dict(
    CH4=f'CH$_{4}$',
    N2O=f'N$_{2}$O',
    SOCSR='ΔSOC',
    Yield='Yield',
    NH3=f'NH$_{3}$',
    NL='N leaching',
    NR='N runoff',
    NO='NO',
)

# Unit suffix for the heatmap colour-bar title. Yield is given in t ha^-1 because the
# heatmap call below divides Yield SHAP values by 1000 (cmtickalpha).
UNIT_DICT = dict(
    CH4=f"(kg ha$^{{-1}}$)",
    N2O=f"(kg ha$^{{-1}}$)",
    SOCSR=f"(t C ha$^{{-1}}$)",
    Yield=f"(t ha$^{{-1}}$)",
    NH3=f"(kg N ha$^{{-1}}$)",
    NL=f"(kg N ha$^{{-1}}$)",
    NR=f"(kg N ha$^{{-1}}$)",
    NO=f"(kg N ha$^{{-1}}$)",
)

# Bold panel letter per target: a, b, c, ... in TARGETS order.
NUMBER_DICT = dict(zip(TARGETS, "abcdefgh"))

# `labelpad` for the heatmap colour-bar title (passed as cmlabelpad), tuned per target.
LABELPAD_DICT = dict(
    CH4=-140,
    N2O=-95,
    SOCSR=-95,
    Yield=-95,
    NH3=-115,
    NL=-95,
    NR=-95,
    NO=-95,
)

# Input/output locations, relative to the notebook's working directory.
model_dir = "../../../runs/save/1/best_model"
data_dir = "../../../data/ml_dataset"
save_dir = "save"

os.makedirs(save_dir, exist_ok=True)


# One SHAP heatmap per target, restricted to the management-practice features.
# NOTE: relies on the patched shap.plots.heatmap from the next cell; cmTitle, cmlabelpad,
# cmtickalpha, ... are arguments added by that patch.
for tar in TARGETS:

    model_path, scaler_path = find_model_and_scaler_paths(model_dir, tar)

    # Targets without a saved model/scaler are skipped (the scaler is needed by the loader).
    if model_path is None or scaler_path is None:
        print(f"✗ {tar}: Model or scaler file not found")
        continue

    data_path = os.path.join(data_dir, f'{tar}.csv')

    if not os.path.exists(data_path):
        print(f"✗ {tar}: The data file does not exist: {data_path}")
        continue

    # Only the raw (unstandardized) feature table is used for plotting.
    features = load_and_prepare_data(data_path, scaler_path)[0]

    concern_cols = [col for col in MANAGMENTS_DICTS if col in features.columns]

    if not concern_cols:
        print(f"✗ {tar}: No features found in MANAGMENTS-DICTS")
        continue

    # SHAP values are precomputed per target, so no model is loaded or evaluated here.
    shap_df = pd.read_csv(f"data/{tar}_shap_all.csv")

    shap_values_filtered = shap_df[concern_cols].values
    feature_values_real = features[concern_cols].values
    feature_names_filtered = [MANAGMENTS_DICTS[col] for col in concern_cols]

    explanation = shap.Explanation(
        values=shap_values_filtered,
        feature_names=feature_names_filtered,
        data=feature_values_real,
    )

    fig, ax = plt.subplots(figsize=(18, 12))

    shap.plots.heatmap(
        explanation,
        max_display=min(15, len(feature_names_filtered)),
        show=False,
        instance_order=explanation.sum(1),
        cmTitle=f"SHAP value {UNIT_DICT[tar]}",
        cmlabelpad=LABELPAD_DICT[tar],
        plot_width=12,
        row_height=0.4,
        cmTitleSize=44,
        cmTickSize=44,
        # The patched heatmap divides the values by this; presumably kg -> t for Yield,
        # matching UNIT_DICT['Yield'].
        cmtickalpha=1000 if tar == "Yield" else 1,
        figsize=(14, 12))

    # Centre the y tick labels and push them away from the axis.
    plt.setp(ax.get_yticklabels(), ha='center', va='center', ma='center')
    ax.yaxis.set_tick_params(pad=75)

    # The patched heatmap labels its mean line "$f(x)$"; blank that (first) label.
    tick_texts = [label.get_text() for label in ax.get_yticklabels()]
    fx_index = next((i for i, text in enumerate(tick_texts) if '$f(x)$' in text), None)
    if fx_index is not None:
        tick_texts[fx_index] = tick_texts[fx_index].replace('$f(x)$', '')
        ax.set_yticklabels(tick_texts)

    ax.set_xticks([])
    ax.set_xlabel(ax.get_xlabel(), fontsize=44, fontname='Arial')
    ax.tick_params(axis='y', labelsize=44)
    ax.spines['bottom'].set_visible(True)

    # Bold panel letter (top-left) and target name (centred) in axes coordinates.
    ax.text(-0.15, 0.98, NUMBER_DICT[tar], transform=ax.transAxes,
            fontsize=44, fontweight='bold',
            verticalalignment='center', horizontalalignment='center')
    ax.text(0.5, 0.98, SHOW_DICT[tar], transform=ax.transAxes,
            fontsize=44, verticalalignment='center', horizontalalignment='center')

    save_path = os.path.join(save_dir, f"{tar}_shap_heatmap.png")
    fig.savefig(save_path, dpi=200, bbox_inches='tight', facecolor='white')
    plt.close(fig)
    print(f"✓ Saved: {save_path}. Total {len(concern_cols)} features")

print("done")

In [39]:
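# Modified source code of shap.plots.heatmap. Replace the original .py file that defines it with this version (back up the original first).
# DO NOT run this cell directly: the relative imports below (`from .. import ...`) only resolve inside the shap package.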
import matplotlib.pyplot as pl
import numpy as np
from .. import Explanation
from ..utils import OpChain
from . import colors
from ._labels import labels
from ._utils import convert_ordering

def heatmap(shap_values, instance_order=Explanation.hclust(), feature_values=Explanation.abs.mean(0),
            feature_order=None, max_display=10, cmap=colors.red_white_blue, show=True,
            plot_width=8, withcmap=True, cmTitle=None, cmlabelpad=None, isBar=False, row_height=None, cmTitleSize=None,
            xlabel=None, cmTickSize=20, cmtickalpha=1, figsize=None):
    """Create a heatmap plot of a set of SHAP values.

    This plot is designed to show the population substructure of a dataset using supervised
    clustering and a heatmap.
    Supervised clustering involves clustering data points not by their original
    feature values but by their explanations.
    By default, we cluster using :func:`shap.utils.hclust_ordering`,
    but any clustering can be used to order the samples.

    Parameters
    ----------
    shap_values : shap.Explanation
        A multi-row :class:`.Explanation` object that we want to visualize in a
        cluster ordering.

    instance_order : OpChain or numpy.ndarray
        A function that returns a sort ordering given a matrix of SHAP values and an axis, or
        a direct sample ordering given as an ``numpy.ndarray``.

    feature_values : OpChain or numpy.ndarray
        A function that returns a global summary value for each input feature, or an array of such values.

    feature_order : None, OpChain, or numpy.ndarray
        A function that returns a sort ordering given a matrix of SHAP values and an axis, or
        a direct input feature ordering given as an ``numpy.ndarray``.
        If ``None``, then we use ``feature_values.argsort``.

    max_display : int
        The maximum number of features to display (default is 10).

    show : bool
        Whether ``matplotlib.pyplot.show()`` is called before returning.
        Setting this to ``False`` allows the plot
        to be customized further after it has been created.

    plot_width: int, default 8
        The width of the heatmap plot. Only used when ``figsize`` is ``None``.

    withcmap : bool, default True
        Whether to draw the colour bar next to the heatmap.

    cmTitle : str or None
        Colour-bar label; ``None`` uses ``labels["VALUE"]``.

    cmlabelpad : float or None
        ``labelpad`` of the colour-bar label; ``None`` uses -10.

    isBar : bool, default False
        Whether to draw a bar chart of ``feature_values`` along the right edge of the heatmap.

    row_height : float or None
        Height per feature row, used for figure sizing, spine bounds and y limits;
        ``None`` uses 0.5.

    cmTitleSize : float or None
        Font size of the colour-bar label (passed through as ``size``).

    xlabel : str or None
        Optional x-axis label.

    cmTickSize : float, default 20
        Font size of the colour-bar tick labels.

    cmtickalpha : float, default 1
        Divisor applied to the SHAP values before plotting (e.g. 1000 to rescale units).
        The colour scale, colour-bar ticks and f(x) line use the rescaled values;
        ``feature_values`` (ordering and bars) are not rescaled.

    figsize : tuple of (width, height) or None
        Explicit figure size in inches; overrides the ``plot_width``/``row_height`` sizing.

    Examples
    --------
    See `heatmap plot examples <https://shap.readthedocs.io/en/latest/example_notebooks/api_examples/plots/heatmap.html>`_.

    """
    # sort the SHAP values matrix by rows and columns
    values = shap_values.values
    if issubclass(type(feature_values), OpChain):
        feature_values = feature_values.apply(Explanation(values))
    if issubclass(type(feature_values), Explanation):
        feature_values = feature_values.values
    if feature_order is None:
        feature_order = np.argsort(-feature_values)
    elif issubclass(type(feature_order), OpChain):
        feature_order = feature_order.apply(Explanation(values))
    elif not hasattr(feature_order, "__len__"):
        raise Exception("Unsupported feature_order: %s!" % str(feature_order))

    instance_order = convert_ordering(instance_order, shap_values)

    feature_names = np.array(shap_values.feature_names)[feature_order]
    # cmtickalpha rescales the plotted values (e.g. unit conversion)
    values = shap_values.values[instance_order][:, feature_order] / cmtickalpha
    feature_values = feature_values[feature_order]

    # if we have more features than `max_display`, then group all the excess features
    # into a single feature
    if values.shape[1] > max_display:
        new_values = np.zeros((values.shape[0], max_display))
        new_values[:, :-1] = values[:, :max_display-1]
        new_values[:, -1] = values[:, max_display-1:].sum(1)
        new_feature_values = np.zeros(max_display)
        new_feature_values[:-1] = feature_values[:max_display-1]
        new_feature_values[-1] = feature_values[max_display-1:].sum()
        feature_names = [
            *feature_names[:max_display-1],
            f"Sum of {values.shape[1] - max_display + 1} other features",
        ]
        values = new_values
        feature_values = new_feature_values

    # define the plot size based on how many features we are plotting
    row_height = 0.5 if row_height is None else row_height
    if figsize is not None:
        # figsize is expected to be a (width, height) tuple in inches
        pl.gcf().set_size_inches(figsize[0], figsize[1])
    else:
        pl.gcf().set_size_inches(plot_width, values.shape[1] * row_height + 2.5)
    ax = pl.gca()

    # plot the matrix of SHAP values as a heat map, with a colour range symmetric around 0
    vmin, vmax = np.nanpercentile(values.flatten(), [1, 99])
    color_lo = min(vmin, -vmax)
    color_hi = max(-vmin, vmax)
    ax.imshow(
        values.T,
        aspect='auto',
        interpolation="nearest",
        vmin=color_lo,
        vmax=color_hi,
        cmap=cmap,
    )

    # adjust the axes ticks and spines for the heat map + f(x) line chart
    ax.xaxis.set_ticks_position("bottom")
    ax.yaxis.set_ticks_position("left")
    ax.spines[["left", "right"]].set_visible(True)
    ax.spines[["left", "right"]].set_bounds(values.shape[1] - row_height, -row_height)
    ax.spines[["top", "bottom"]].set_visible(False)
    ax.tick_params(axis="both", direction="out")

    ax.set_ylim(values.shape[1] - row_height, -3)
    heatmap_yticks_pos = np.arange(values.shape[1])
    heatmap_yticks_labels = feature_names
    ax.yaxis.set_ticks(
        [-1.5, *heatmap_yticks_pos],
        [r"$f(x)$", *heatmap_yticks_labels],
        fontsize=13,
    )
    # remove the y-tick line for the f(x) label
    ax.yaxis.get_ticklines()[0].set_visible(False)

    ax.set_xlim(-0.5, values.shape[0] - 0.5)
    if xlabel is not None:
        ax.set_xlabel(xlabel)

    # plot the f(x) line chart above the heat map: the per-instance mean of the plotted
    # SHAP values, normalised by its largest magnitude and drawn around y = -1.5
    ax.axhline(-1.5, color="#aaaaaa", linestyle="--", linewidth=1.5)
    fx = values.T.mean(0)
    fx_scale = np.abs(fx).max()
    if fx_scale == 0:
        # all-zero f(x): avoid 0/0 (NaN) so a flat line is drawn on the guide
        fx_scale = 1.0
    ax.plot(
        -fx / fx_scale - 1.5,
        color="#000000",
        linewidth=2.5,
    )

    # plot the bar plot on the right spine of the heat map
    if isBar:
        bar_scale = np.abs(feature_values).max()
        if bar_scale == 0:
            # all-zero feature values: draw zero-length bars instead of NaN widths
            bar_scale = 1.0
        bar_container = ax.barh(
            heatmap_yticks_pos,
            (feature_values / bar_scale) * values.shape[0] / 20,
            height=0.7,
            align="center",
            color="#000000",
            left=values.shape[0] * 1.0 - 0.5,
        )

        for b in bar_container:
            b.set_clip_on(False)

    if withcmap:
        # draw the color bar
        import matplotlib.cm as cm
        import matplotlib.ticker as ticker
        m = cm.ScalarMappable(cmap=cmap)
        m.set_array([color_lo, color_hi])
        cb = pl.colorbar(
            m,
            ticks=[color_lo, color_hi],
            ax=ax,
            aspect=65,
            fraction=0.02,
            pad=0.01,  # padding between the cb and the main axes
        )
        cb.ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))
        cb.set_label(
            labels["VALUE"] if cmTitle is None else cmTitle,
            size=cmTitleSize,
            labelpad=-10 if cmlabelpad is None else cmlabelpad,
        )
        cb.ax.tick_params(labelsize=cmTickSize, length=0)
        cb.set_alpha(1)
        cb.outline.set_visible(False)

    if show:
        pl.show()


✓ 已保存: save\CH4_shap_heatmap.png (包含 12 个管理特征)
✓ 已保存: save\N2O_shap_heatmap.png (包含 12 个管理特征)
✓ 已保存: save\SOCSR_shap_heatmap.png (包含 12 个管理特征)
✓ 已保存: save\Yield_shap_heatmap.png (包含 12 个管理特征)
✓ 已保存: save\NH3_shap_heatmap.png (包含 12 个管理特征)
✓ 已保存: save\NL_shap_heatmap.png (包含 12 个管理特征)
✓ 已保存: save\NR_shap_heatmap.png (包含 12 个管理特征)
✓ 已保存: save\NO_shap_heatmap.png (包含 12 个管理特征)

所有管理特征的SHAP热力图生成完成！


In [38]:
"""
Figure 2
"""
import matplotlib as mpl

# Figure 2: visible SHAP-value window (x limits) per target. Samples with no management SHAP
# value inside the window are dropped before plotting.
RANGE_DICT = dict(
    CH4=[-250, 500],
    N2O=[-3, 3],
    SOCSR=[-0.5, 2],
    Yield=[-3000, 3000],
    NH3=[-40, 60],
    NL=[-10, 10],
    NR=[-12, 12],
    NO=[-0.2, 0.2],
)

# Spacing between x ticks inside each target's window.
STEP_DICT = dict(
    CH4=250,
    N2O=1.5,
    SOCSR=0.5,
    Yield=1500,
    NH3=20,
    NL=5,
    NR=4,
    NO=0.1,
)

# x-axis units for the summary plots. This rebinds the UNIT_DICT name from the heatmap cell;
# Yield is in kg ha^-1 here because the summary plots do not rescale SHAP values.
UNIT_DICT = dict(
    CH4=f"(kg ha$^{{-1}}$)",
    N2O=f"(kg ha$^{{-1}}$)",
    SOCSR=f"(t C ha$^{{-1}}$)",
    Yield=f"(kg ha$^{{-1}}$)",
    NH3=f"(kg N ha$^{{-1}}$)",
    NL=f"(kg N ha$^{{-1}}$)",
    NR=f"(kg N ha$^{{-1}}$)",
    NO=f"(kg N ha$^{{-1}}$)",
)

import warnings

# Feature tables for the summary plots. Uses its own name so the `data_dir` configured for
# the heatmap cell is not overwritten.
shap_data_dir = "../../../data/ml_dataset/shap"

# Colour ramp for the summary-plot dots; it does not depend on the target, so build it once.
# `custom_cmap` (and `norm`/`vmin`/`vmax` assigned in the loop) remain module-level names.
colors = ["#c1533d", "#d46d2b", "#e89316", "#e8b70b", "#b0ee00", "#3fdf03", "#02cb23", "#1cb16f", "#1a9693", "#155c87"]
custom_cmap = mpl.colors.LinearSegmentedColormap.from_list('custom_purple_yellow', colors)

# Warnings are silenced for this loop only, not for the rest of the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")

    for tar in TARGETS:
        model_path, scaler_path = find_model_and_scaler_paths(model_dir, tar)

        # Targets without a saved model/scaler are skipped (the scaler is needed by the loader).
        if model_path is None or scaler_path is None:
            print(f"✗ {tar}: No model or scaler")
            continue
        data_path = os.path.join(shap_data_dir, f'{tar}.csv')

        if not os.path.exists(data_path):
            print(f"✗ {tar}: File not existed: {data_path}")
            continue

        # Only the raw (unstandardized) feature table is used for colouring the dots.
        features = load_and_prepare_data(data_path, scaler_path)[0]

        concern_cols = [col for col in MANAGMENTS_DICTS if col in features.columns]

        if not concern_cols:
            print(f"✗ {tar}: not found in MANAGMENTS_DICTS")
            continue

        # SHAP values are precomputed per target, so no model is loaded or evaluated here.
        shap_df = pd.read_csv(f"data/{tar}_shap_all.csv")

        shap_values_filtered = shap_df[concern_cols].values
        feature_values_real = features[concern_cols].values
        feature_names_filtered = [MANAGMENTS_DICTS[col] for col in concern_cols]

        x_min, x_max = RANGE_DICT[tar]

        # Keep samples with at least one management SHAP value inside the display window;
        # remaining out-of-window points are clipped by the x limits set below.
        in_window = (shap_values_filtered >= x_min) & (shap_values_filtered <= x_max)
        mask_row = in_window.any(axis=1)
        shap_values_filtered = shap_values_filtered[mask_row]
        feature_values_real = feature_values_real[mask_row]

        explanation = shap.Explanation(
            values=shap_values_filtered,
            feature_names=feature_names_filtered,
            data=feature_values_real,
        )

        shap.summary_plot(
            explanation.values,
            explanation.data,
            feature_names=explanation.feature_names,
            plot_type="dot",
            show=False,
            max_display=len(concern_cols),
        )

        fig = plt.gcf()
        fig.set_size_inches(10, 7.5)
        ax = plt.gca()

        # Restyle thin light-grey dashed lines (presumably summary_plot's per-feature guides).
        for line in ax.get_lines():
            if (line.get_color() == '#cccccc' and
                    line.get_linestyle() == '--' and
                    abs(line.get_linewidth() - 0.5) < 1e-3):
                line.set_linewidth(0.5)
                line.set_dashes((2, 2))
                line.set_color('#666666')
                line.set_zorder(-1)

        # Recolour the scatter dots and drop summary_plot's own colour bar (the extra axes);
        # the shared colour legend is produced by a separate cell.
        for artist in ax.get_children():
            if isinstance(artist, mpl.collections.PathCollection):
                artist.set_cmap(custom_cmap)
                artist.set_sizes([50])
                artist.set_edgecolor('white')
                artist.set_linewidth(0.2)
                array = artist.get_array()
                if array is not None:
                    if len(fig.axes) > 1:
                        fig.axes[-1].remove()

                    vmin, vmax = array.min(), array.max()
                    norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)

        step = STEP_DICT[tar]
        if step > 1:
            x_ticks = np.arange(np.floor(x_min / step) * step, np.ceil(x_max / step) * step + 1, step)
        else:
            # small steps: small epsilon so x_max itself is included despite float rounding
            x_ticks = np.arange(x_min, x_max + 0.0001, step)
        ax.set_xticks(x_ticks)
        ax.set_xlabel(f"SHAP value {UNIT_DICT[tar]}", fontsize=32, fontname='Arial')
        ax.tick_params(axis='both', which='major', labelsize=32)

        ax.spines['bottom'].set_linewidth(2)
        ax.spines['bottom'].set_color('#000000')

        ax.tick_params(
            axis='x',
            which='major',
            width=2,
            color='#000000',
        )

        # Emphasise the vertical zero line (a two-point line at x = 0).
        for line in ax.get_lines():
            xdata = line.get_xdata()
            if len(xdata) == 2 and np.allclose(xdata, [0, 0]):
                line.set_linewidth(3)
                line.set_linestyle('--')

        ax.set_xlim(x_min, x_max)

        # Clip every artist to the axes box so out-of-window dots are not drawn.
        ax.set_clip_on(True)
        for artist in ax.get_children():
            if hasattr(artist, 'set_clip_on'):
                artist.set_clip_on(True)
            if hasattr(artist, 'set_clip_box'):
                artist.set_clip_box(ax.bbox)

        for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
                     ax.get_xticklabels() + ax.get_yticklabels()):
            item.set_fontname('Arial')

        # Bold panel letter (top-left) and target name (centred) in axes coordinates.
        ax.text(-0.25, 1.035, NUMBER_DICT[tar], transform=ax.transAxes,
                fontsize=32, fontweight='bold',
                verticalalignment='center', horizontalalignment='center')
        ax.text(0.5, 1.035, SHOW_DICT[tar], transform=ax.transAxes,
                fontsize=32, verticalalignment='center', horizontalalignment='center')

        save_path = os.path.join(save_dir, f"{tar}_shap_summary.png")
        fig.tight_layout()
        fig.savefig(save_path, dpi=300, facecolor='white', pad_inches=0.01, bbox_inches='tight')
        plt.close(fig)
        print(f"✓ Saved summary_plot: {save_path}")

print("Done")

✓ 已保存summary_plot: save\CH4_shap_summary.png
✓ 已保存summary_plot: save\N2O_shap_summary.png
✓ 已保存summary_plot: save\SOCSR_shap_summary.png
✓ 已保存summary_plot: save\Yield_shap_summary.png
✓ 已保存summary_plot: save\NH3_shap_summary.png
✓ 已保存summary_plot: save\NL_shap_summary.png
✓ 已保存summary_plot: save\NR_shap_summary.png
✓ 已保存summary_plot: save\NO_shap_summary.png

所有管理特征的SHAP summary_plot生成完成！


In [36]:
from matplotlib.cm import ScalarMappable
import matplotlib as mpl

# Same colour ramp as the summary-plot dots (Figure 2 cell), rebuilt here so this cell does not
# depend on variables left over from that loop. Keep the two colour lists in sync.
legend_colors = ["#c1533d", "#d46d2b", "#e89316", "#e8b70b", "#b0ee00", "#3fdf03", "#02cb23", "#1cb16f", "#1a9693", "#155c87"]
legend_cmap = mpl.colors.LinearSegmentedColormap.from_list('custom_purple_yellow', legend_colors)
# Only the two ends are labelled ('Low'/'High'), so a fixed [0, 1] scale draws the same bar
# without depending on the last target's data range.
legend_norm = mpl.colors.Normalize(vmin=0, vmax=1)

fig_legend, ax_legend = plt.subplots(figsize=(16, 0.4))
sm = ScalarMappable(cmap=legend_cmap, norm=legend_norm)
sm.set_array([])

# With an explicit `cax`, layout options such as aspect/pad do not apply, so none are passed.
cb = fig_legend.colorbar(sm, cax=ax_legend, orientation='horizontal')
cb.set_ticks([legend_norm.vmin, legend_norm.vmax])
cb.set_ticklabels(['Low', 'High'])

cb.ax.tick_params(labelsize=36, length=0)
for label in cb.ax.get_xticklabels():
    label.set_fontname('Arial')
    label.set_fontsize(36)

cb.outline.set_visible(False)

legend_title = "Standardized feature value"
cb.set_label(legend_title, rotation=0, labelpad=-25, fontsize=36, fontname='Arial',
             loc='center', y=1.2)
fig_legend.tight_layout()
fig_legend.savefig("save/summary_plot_legend.png", dpi=100, bbox_inches='tight', pad_inches=0.01)
plt.close(fig_legend)