# SHAP analysis of RANDM thermospheric density model
This notebook is one of the companion notebooks to Bard, Murphy, Halford (2025): "Elucidating the Grey Atmosphere: SHAP Value Analysis of a Random Forest Atmospheric Neutral Density Model".

## Part A: Generating SHAP Explanations

This notebook contains the code for generating and saving SHAP explanations for the RANDM model. Because this process takes a long time to run, the saved SHAP explanations are hosted in `./SHAP_values/`. The other notebooks use these saved explanations to generate the figures and analysis.

If you wish to recreate the SHAP values, use a computer with many cores and set `NJOBS` to the largest number of cores available for parallel processing. The feature data required for the Explainer can be downloaded from https://zenodo.org/uploads/17201537.
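As a sketch that is not part of the original workflow, `NJOBS` could be derived from the machine's logical core count instead of being hardcoded. The helper name `suggest_njobs` and its `reserve` parameter are hypothetical; leaving one core free is just one reasonable choice for keeping the machine responsive.

```python
import os


def suggest_njobs(reserve: int = 1) -> int:
    """Hypothetical helper: logical CPU count minus `reserve`, but at least 1."""
    # os.cpu_count() may return None on some platforms, so fall back to 1
    total = os.cpu_count() or 1
    return max(1, total - reserve)


NJOBS = suggest_njobs()
print(f"Using NJOBS = {NJOBS}")
```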

#### Requirements
- numpy
- pandas
- matplotlib
- fasttreeshap
- shap
- mltdm

#### References
- SHAP paper (arxiv link)
- saved SHAP explanations (`./SHAP_values/`)
- Murphy+2025 (RANDM model paper): https://doi.org/10.1029/2024SW003928
- MLTDM repository: https://github.com/kylermurphy/mltdm/

In [5]:
NJOBS = 8 # change based on your computer

In [1]:
import numpy as np
import numpy.random as rand
import pandas as pd
import shap_analysis as shpa
import time as pytime
import matplotlib.pyplot as plt
from matplotlib import gridspec

In [2]:
import fasttreeshap as fts
from fasttreeshap.plots import beeswarm, waterfall, bar
import shap

In [3]:
import mltdm
from mltdm.den_fx import fx_den

In [None]:
# assuming mltdm model has already been set up
# if not, run:
# from mltdm import den_fx
# den_fx.setup()

### Utility functions

In [6]:
# workaround to save and load a minimal Explanation object,
# since the SHAP values take a long time to compute
def save_explanation(fname, expln):
    # hack to avoid pickling and associated errors
    dta = expln.data.astype(float)
    np.savez(fname, values=expln.values, 
             base_values=expln.base_values, data=dta,
             feature_names=expln.feature_names, 
             output_names=expln.output_names, compute_time=expln.compute_time)

def load_explanation(fname):
    dat_dict = np.load(fname)
    return fts.Explanation(values = dat_dict["values"], base_values=dat_dict["base_values"], 
                           data=dat_dict["data"], feature_names=dat_dict["feature_names"], 
                           output_names=dat_dict["output_names"], 
                           compute_time=dat_dict["compute_time"])

# workaround to combine two explanations (rows of expln2 appended to expln1);
# np.concatenate keeps 1-D base_values 1-D, whereas np.vstack would make them 2-D
def stack_expln(expln1, expln2):
    return fts.Explanation(values=np.concatenate([expln1.values, expln2.values]),
                           base_values=np.concatenate([expln1.base_values, expln2.base_values]),
                           data=np.concatenate([expln1.data, expln2.data]),
                           feature_names=expln1.feature_names)
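The `.npz` approach in `save_explanation`/`load_explanation` and the row stacking in `stack_expln` can be checked on toy arrays without fasttreeshap. This numpy-only sketch uses made-up shapes (4 samples, 3 features, single-output model so `base_values` is 1-D). It shows that plain numeric and string arrays round-trip without pickling, that a `None` field such as `output_names` does not, and how `np.vstack` and `np.concatenate` treat 1-D `base_values`.

```python
import io

import numpy as np

# Toy stand-ins for an Explanation's arrays (hypothetical shapes)
rng = np.random.default_rng(0)
values = rng.normal(size=(4, 3))
base_values = np.full(4, 2.5)
data = rng.normal(size=(4, 3))
feature_names = ["f0", "f1", "f2"]

# Round trip through an in-memory .npz file: numeric and string arrays
# load back without allow_pickle
buf = io.BytesIO()
np.savez(buf, values=values, base_values=base_values, data=data,
         feature_names=feature_names)
buf.seek(0)
loaded = np.load(buf)
assert np.array_equal(loaded["values"], values)
print(loaded["feature_names"])  # comes back as a NumPy string array, not a list

# A None field is stored as an object array, which np.load refuses
# to read unless allow_pickle=True
buf2 = io.BytesIO()
np.savez(buf2, output_names=None)
buf2.seek(0)
try:
    np.load(buf2)["output_names"]
except ValueError as err:
    print("needs allow_pickle:", err)

# Stacking 1-D base_values: vstack adds a new axis, concatenate extends the existing one
print(np.vstack([base_values, base_values]).shape)       # (2, 4)
print(np.concatenate([base_values, base_values]).shape)  # (8,)
```

If a saved file does contain an object-valued field, passing `allow_pickle=True` to `np.load` is one way to read it, at the cost of the pickling the helper is meant to avoid.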

In [None]:
# functions to load feature data from hardcoded file
# Download file "FI_GEO_RF_data.h5" from Zenodo: https://zenodo.org/uploads/17201537

# data columns used in best-performing Murphy RF model
fgeo_col = ['1300_02', '43000_09', '85550_13', '94400_18', 'SYM_H index', 
            'AE', 'SatLat', 'cos_SatMagLT', 'sin_SatMagLT']

def _load_datafile(option: str = "test_d", path='./'):
    """
    Loads random forest training/test data from CB's checkpoint of
    the train/test split in KM's notebook (https://github.com/kylermurphy/mltdm/blob/main/Notebooks/RF_model.ipynb)

    option : str
        1. "test_d" : GRACE-B test data
        2. "train_d" : GRACE-B training data
        3. "oos_d" : GRACE-A
        4. "oos2_d" : CHAMP

    path : str
        directory containing FI_GEO_RF_data.h5, including the trailing '/'
    """

    return pd.read_hdf(path+"FI_GEO_RF_data.h5", option)


def load_data(storm: bool = None, option: str = "test_d", all_cols: bool = False, path='./'):
    """
    Loads a pandas DataFrame from the file read by `_load_datafile`

    storm : bool | None
        1. None : load full data
        2. True : load storm data ("storm" == 1)
        3. False : load non-storm data ("storm" == -1)

    option : str
        1. "test_d" : GRACE-B test data
        2. "train_d" : GRACE-B training data
        3. "oos_d" : GRACE-A
        4. "oos2_d" : CHAMP

    all_cols : bool
        If True, load all data columns
        If False, load only the `fgeo_col` columns used in the best RF model

    path : str
        directory containing the FI_GEO_RF_data.h5 datafile, by default './'
    """

    dat = _load_datafile(option, path=path)

    if storm is not None:
        # the storm flag is 1 for storm-time points and -1 for non-storm points
        dat = dat[dat['storm'] == (1 if storm else -1)]

    return dat if all_cols else dat[fgeo_col]
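The `storm` and `storm phase` selections used throughout this notebook can be tried on a small synthetic frame before downloading the data file. In this sketch the rows are made up, `filter_storm` is a hypothetical mirror of the filter inside `load_data`, and the `storm phase` values for non-storm rows are placeholders. The phase codes (1 = main phase, 2 = recovery) follow the selections in Data Part 1.

```python
import pandas as pd

# Synthetic rows; only the column names match the real data file
toy = pd.DataFrame({
    "storm": [1, -1, 1, -1, 1],
    "storm phase": [1, 0, 2, 0, 2],  # 0 for non-storm rows is a placeholder
    "SYM_H index": [-80.0, -5.0, -40.0, 3.0, -60.0],
})


def filter_storm(df: pd.DataFrame, storm=None) -> pd.DataFrame:
    """Hypothetical mirror of load_data's storm filter: None keeps every row."""
    if storm is None:
        return df
    return df[df["storm"] == (1 if storm else -1)]


print(len(filter_storm(toy)))               # 5
print(len(filter_storm(toy, storm=True)))   # 3
print(len(filter_storm(toy, storm=False)))  # 2

# Recovery-phase rows among the storm-time rows
storm_rows = filter_storm(toy, storm=True)
print(storm_rows[storm_rows["storm phase"] == 2].index.tolist())  # [2, 4]
```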

In [9]:
NUM_SAMP = 2000
def sample(gen, df, num_samp=NUM_SAMP):
    """
    Randomly samples num_samp rows of df without replacement.

    gen : numpy.random.Generator
        Used to generate random indices for sample selection
    df : pandas.DataFrame
        data to sample
    num_samp : int
        number of points in sample (capped at the number of rows in df)
    """
    num_samp = min(num_samp, df.shape[0])
    ids = gen.choice(range(df.shape[0]), size=num_samp, replace=False)
    return df.iloc[ids]
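The cells below re-seed the generator before each group of `sample` calls so that the paper's exact rows are reproduced. This sketch, with a hypothetical `sample_indices` that performs the same draw as `sample` on a row count instead of a DataFrame, shows why both the seed and the call order matter: re-seeding reproduces the whole sequence of draws, but every call advances the generator, so reordering the calls changes which rows each dataset receives.

```python
import numpy as np


def sample_indices(gen: np.random.Generator, n_rows: int, num_samp: int = 2000) -> np.ndarray:
    """Same draw as sample(): num_samp distinct row positions out of n_rows."""
    num_samp = min(num_samp, n_rows)
    return gen.choice(range(n_rows), size=num_samp, replace=False)


SEED = 693993

# Two generators with the same seed produce the same sequence of draws
gen_a = np.random.default_rng(SEED)
first_a = sample_indices(gen_a, 100, num_samp=5)
second_a = sample_indices(gen_a, 100, num_samp=5)

gen_b = np.random.default_rng(SEED)
first_b = sample_indices(gen_b, 100, num_samp=5)
second_b = sample_indices(gen_b, 100, num_samp=5)

assert np.array_equal(first_a, first_b)
assert np.array_equal(second_a, second_b)

# Each call advances the generator, so the second draw generally differs from the first
print(first_a, second_a)

# Sampling without replacement gives unique indices, and num_samp is capped at n_rows
assert len(set(first_a.tolist())) == 5
assert len(sample_indices(np.random.default_rng(SEED), 3)) == 3
```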

### Setup

In [None]:
# assuming mltdm model has already been set up
# if not, run:
# from mltdm import den_fx
# den_fx.setup()

In [10]:
mltdm_wrapper = fx_den(dropAE=False)
rf = mltdm_wrapper.rfmod # this is RANDM, as created in Murphy+25

Loading fx_den_bz2.skops, this will take a few minutes.


In [18]:
rf

In [14]:
# used to generate SHAP values for the events in the data
explainer = fts.TreeExplainer(rf, n_jobs = NJOBS)
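`TreeExplainer` computes Shapley values efficiently by exploiting the tree structure. For intuition only, this brute-force sketch computes exact Shapley values for a toy three-feature function, filling "missing" features from a single baseline vector. That is a simplification: TreeSHAP's default path-dependent algorithm defines missing features differently, and brute force is exponential in the number of features. The function `f` and the helper `shapley_values` are hypothetical and unrelated to RANDM.

```python
import itertools
import math

import numpy as np


def shapley_values(f, x, baseline):
    """Exact Shapley values of f at x; features outside a coalition take baseline values."""
    n = len(x)
    phi = np.zeros(n)
    for i in range(n):
        others = [j for j in range(n) if j != i]
        for k in range(n):
            # Shapley weight for a coalition of size k that excludes feature i
            weight = math.factorial(k) * math.factorial(n - k - 1) / math.factorial(n)
            for subset in itertools.combinations(others, k):
                idx = np.array(subset, dtype=int)
                without_i = baseline.copy()
                without_i[idx] = x[idx]
                with_i = without_i.copy()
                with_i[i] = x[i]
                phi[i] += weight * (f(with_i) - f(without_i))
    return phi


def f(z):
    # toy model: one additive term plus one pairwise interaction
    return 2.0 * z[0] + z[1] * z[2]


x = np.array([1.0, 2.0, 3.0])
baseline = np.zeros(3)
phi = shapley_values(f, x, baseline)
print(phi)  # ≈ [2, 3, 3]: the interaction term is split equally between z[1] and z[2]

# Local accuracy: the base value plus the sum of the attributions recovers the prediction
assert np.isclose(f(baseline) + phi.sum(), f(x))
```

The same local-accuracy property is what allows an Explanation's `base_values` plus the row sums of its `values` to be compared against the model's predictions as a sanity check.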

In [None]:
full_data = load_data(all_cols=True, option="test_d", path='./') # loads Grace B test data

### Data Part 1
Samples for storm-time, quiet-time, main-phase, and recovery-phase points, as classified by Murphy+18 and Murphy+20.

#### References
1. Murphy+18, https://doi.org/10.1002/2017GL076674
2. Murphy+20, https://doi.org/10.1029/2020SW002477

In [10]:
# sort into more specific datasets
storm_data = full_data[full_data['storm'] == 1]
quiet_data = full_data[full_data['storm'] == -1]

In [11]:
# storm phases: 1 = main phase, 2 = recovery phase
recovery_data = storm_data[storm_data["storm phase"] == 2]
mainphase_data = storm_data[storm_data["storm phase"] == 1]

In [14]:
# NOTE: we reset the generator here in order to match exact indices 
# for sampling used in paper
SEED = 693993
gen = rand.default_rng(SEED)

storm_samp_data = sample(gen, storm_data)
quiet_samp_data = sample(gen, quiet_data)
mainphase_samp_data = sample(gen, mainphase_data)
recovery_samp_data = sample(gen, recovery_data)

#### SHAP values

WARNING: this will take a while on a laptop, depending on `NUM_SAMP`. For 2000 points on my laptop (NJOBS=8), each SHAP run took about 25-30 minutes, depending on what else was running.

To run the SHAP value code, remove the triple quotes that wrap each of the cells below. Alternatively, the saved `.npz` files produced by these cells are in `./SHAP_values/`.
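Each SHAP cell below repeats the same `perf_counter` bookkeeping. A small context manager could factor it out while printing the same "start ..." and "Time taken: ..." messages. `timed` is a hypothetical name, not part of the notebook, and the explainer call is shown only as a comment because it needs RANDM and the data file.

```python
import time
from contextlib import contextmanager


@contextmanager
def timed(label: str):
    """Print a start message, then the elapsed wall-clock seconds when the block exits."""
    print(f"start {label}")
    tfirst = time.perf_counter()
    try:
        yield
    finally:
        print(f"Time taken: {time.perf_counter() - tfirst}")


# Intended usage (needs the explainer and sampled data from this notebook):
# with timed("shap storm"):
#     storm_shap = explainer(storm_samp_data[fgeo_col])
# save_explanation("storm1", storm_shap)

# Self-contained demonstration with a toy workload
with timed("toy workload"):
    total = sum(i * i for i in range(100_000))
```

Using `try`/`finally` means the elapsed time is still printed if the explainer raises partway through a long run.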

In [12]:
"""
print("start shap storm")
tfirst = pytime.perf_counter()
storm_shap = explainer(storm_samp_data[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("storm1", storm_shap)
"""

start shap storm
Time taken: 1739.8983653000032


In [19]:
"""
print("start shap quiet")
tfirst = pytime.perf_counter()
quiet_shap = explainer(quiet_samp_data[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("quiet1", quiet_shap)
"""

start shap quiet
Time taken: 1506.1464909000206


In [154]:
"""
print("start shap main")
tfirst = pytime.perf_counter()
mainphase_shap = explainer(mainphase_samp_data[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("mainphase1", mainphase_shap)
"""

start shap main
Time taken: 1556.1542722999584


In [155]:
"""
print("start shap recovery")
tfirst = pytime.perf_counter()
recovery_shap = explainer(recovery_samp_data[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("recovery1", recovery_shap)
"""

start shap main
Time taken: 1481.5337197000626


### Data Part 2
Samples in three broad SYM-H bins: small (-50 to 0 nT), moderate (-100 to -50 nT), and large (below -100 nT).

In [13]:
small_symH = full_data[(full_data["SYM_H index"] < 0) & (full_data["SYM_H index"] > -50)]
mod_symH = full_data[(full_data["SYM_H index"] < -50) & (full_data["SYM_H index"] > -100)]
large_symH = full_data[full_data["SYM_H index"] < -100]
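Because the three broad-bin masks use strict inequalities, points with SYM-H exactly equal to 0, -50, or -100 nT fall into none of the bins (positive SYM-H is excluded by design). This toy check with made-up values shows which edge values fall through; whether that matters depends on how often those exact values occur in the data.

```python
import numpy as np

# Made-up SYM-H values in nT, including values exactly on the bin edges
sym_h = np.array([-10.0, -50.0, -75.0, -100.0, -150.0, 0.0])

# Same conditions as small_symH, mod_symH and large_symH
small = (sym_h < 0) & (sym_h > -50)
moderate = (sym_h < -50) & (sym_h > -100)
large = sym_h < -100

assigned = small | moderate | large
# The edge values 0, -50 and -100 are not assigned to any bin
print(sym_h[~assigned])
```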

In [11]:
small_symH.shape, mod_symH.shape, large_symH.shape

((198948, 14), (6051, 14), (1137, 14))

In [18]:
# NOTE: we reset the generator here in order to match exact indices 
# for sampling used in paper
SEED = 693993
gen = rand.default_rng(SEED)

small_symH_samp = sample(gen, small_symH)
mod_symH_samp = sample(gen, mod_symH)
large_symH_samp = sample(gen, large_symH)

#### SHAP values

SHAP `.npz` files are provided in `./SHAP_values/`.

In [56]:
"""
print("start sym-h smol")
tfirst = pytime.perf_counter()
small_symH_shap = explainer(small_symH_samp[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("symHsml", small_symH_shap)
"""

start sym-h smol
Time taken: 1551.2525466999505


In [57]:
"""
print("start sym-h moderate")
tfirst = pytime.perf_counter()
mod_symH_shap = explainer(mod_symH_samp[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("symHmod", mod_symH_shap)
"""

start sym-h moderate
Time taken: 1492.3360663999338


In [58]:
"""
print("start sym-h large")
tfirst = pytime.perf_counter()
large_symH_shap = explainer(large_symH_samp[fgeo_col])
tnow = pytime.perf_counter()
print(f"Time taken: {tnow - tfirst}")
save_explanation("symHlarge", large_symH_shap)
"""

start sym-h large
Time taken: 844.8028948999709


### Data Part 3
Samples in 5-nT bins of SYM-H, from -20 to -75 nT.

In [9]:
# function to select a specific SYM-H bin
def SymH_bin(df, low=0., high=-100.):
    # low and high are the less negative and more negative bin edges;
    # the bin keeps points with high < SYM-H <= low
    assert low > high, "SYM-H is defined in the negative direction; high must be more negative than low"
    return df[(df["SYM_H index"] <= low) & (df["SYM_H index"] > high)]
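Unlike the broad bins in Data Part 2, the condition in `SymH_bin` is half-open (`high < SYM-H <= low`), so consecutive 5-nT bins share edges without overlapping or leaving gaps. This sketch applies the same condition to a toy 1-nT grid of values; `symh_bin_mask` is a hypothetical numpy-array version of `SymH_bin`, and the edges are built the same way as `sym_h_bins`.

```python
import numpy as np


def symh_bin_mask(values, low, high):
    """Same condition as SymH_bin: high < value <= low (SYM-H bins run negative)."""
    assert low > high
    return (values <= low) & (values > high)


edges = list(range(-20, -76, -5))      # [-20, -25, ..., -75], as in sym_h_bins
values = np.arange(-80.0, -14.0, 1.0)  # toy grid from -80 to -15 nT

masks = [symh_bin_mask(values, edges[i], edges[i + 1]) for i in range(len(edges) - 1)]
counts = np.sum(masks, axis=0)  # number of bins that claim each value

# Inside (-75, -20] every value lands in exactly one bin; outside, in none
inside = (values <= -20) & (values > -75)
assert np.all(counts[inside] == 1)
assert np.all(counts[~inside] == 0)
print(len(masks))  # 11
```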

In [11]:
sym_h_bins = list(range(-20,-76,-5))
sym_h_binned_samps = []

In [12]:
sym_h_bins

[-20, -25, -30, -35, -40, -45, -50, -55, -60, -65, -70, -75]

In [13]:
# Need to reset RNG here
SEED = 693993
gen = rand.default_rng(SEED)
for i in range(len(sym_h_bins)-1):
    print(f"binning in [{sym_h_bins[i]}, {sym_h_bins[i+1]})")
    sym_h_binned_samps.append(sample(gen, SymH_bin(full_data, sym_h_bins[i], sym_h_bins[i+1]), num_samp=500))

binning in [-20, -25)
binning in [-25, -30)
binning in [-30, -35)
binning in [-35, -40)
binning in [-40, -45)
binning in [-45, -50)
binning in [-50, -55)
binning in [-55, -60)
binning in [-60, -65)
binning in [-65, -70)
binning in [-70, -75)


#### SHAP values
`.npz` files are found in `./SHAP_values/`, under names `sym_h_bin-##.npz`, where `-##` is the less negative edge of each bin (e.g. `sym_h_bin-20.npz` holds the [-20, -25) bin).
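The file names come from `save_explanation(f"sym_h_bin{sym_h_bins[i]}", ...)`, with `np.savez` appending the `.npz` extension. As a sketch, the expected names can be rebuilt from the bin edges, and a hypothetical helper `bin_edges_from_name` can map a name back to its bin, assuming the 5-nT width used here.

```python
import re

sym_h_bins = list(range(-20, -76, -5))

# One file per bin, keyed by the bin's less negative edge
fnames = [f"sym_h_bin{sym_h_bins[i]}.npz" for i in range(len(sym_h_bins) - 1)]
print(fnames[0], fnames[-1])  # sym_h_bin-20.npz sym_h_bin-70.npz


def bin_edges_from_name(fname: str, width: int = 5) -> tuple:
    """Hypothetical helper: recover (less negative, more negative) edges from a file name."""
    match = re.fullmatch(r"sym_h_bin(-?\d+)\.npz", fname)
    if match is None:
        raise ValueError(f"unexpected file name: {fname}")
    upper = int(match.group(1))
    return upper, upper - width


print(bin_edges_from_name("sym_h_bin-45.npz"))  # (-45, -50)
```

Combined with `load_explanation`, for example `load_explanation(os.path.join("SHAP_values", fname))`, this would reload each bin in order, assuming the saved files keep these names.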

In [16]:
"""
bn_shap = []
for i,bn_samp in enumerate(sym_h_binned_samps):
    print(f"start bin {i}")
    tfirst = pytime.perf_counter()
    bn_shap.append(explainer(bn_samp[fgeo_col]))
    tnow = pytime.perf_counter()
    print(f"Time taken: {tnow - tfirst}")
    save_explanation(f"sym_h_bin{sym_h_bins[i]}", bn_shap[-1])
"""

start bin 0
Time taken: 378.9885622999991
start bin 1
Time taken: 367.1852671000015
start bin 2
Time taken: 351.637211199999
start bin 3
Time taken: 340.28905210000084
start bin 4
Time taken: 342.40822300000036
start bin 5
Time taken: 339.72353420000036
start bin 6
Time taken: 336.50638190000063
start bin 7
Time taken: 338.3492058000011
start bin 8
Time taken: 345.77513050000016
start bin 9
Time taken: 348.92111379999915
start bin 10
Time taken: 337.58933689999867


##### end

The next notebook, Part B, presents a simple analysis of the global input features.