# 4.1. SHAP Values: Origins and Applications 
### Alex Gagliano (gaglian2@mit.edu)
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/alexandergagliano/InterpretabilityDemos/blob/main/Tutorial%204.1%20Shapley.ipynb)

References and resources for additional reading:
* [A Unified Approach to Interpreting Model Predictions](https://github.com/shap/shap) (Lundberg & Lee, 2017)
* [Explaining Predictive Uncertainty with Information Theoretic Shapley Values](https://arxiv.org/pdf/2306.05724.pdf) (Watson et al., 2023)
* [Interpretable Machine Learning](https://christophm.github.io/interpretable-ml-book/), Christoph Molnar

## Interpretability is deeply subjective and problem-dependent.
Is understanding how the model is constructed enough, or do we want the construction to be intuitive (e.g., through regularization of the latent space)? Is it sufficient that a model is _robust_ (generalizes well to unseen data), or would we like to make causal statements linking features to the model's predictions?   

However you define it, interpretability often decreases as model complexity increases; we build a deep neural network for its expressivity, but then fail to understand how input leads to its output. In these cases, it can be useful to build a simpler _explanation_ model, a highly interpretable approximation of our original model. There are many ways to build explanation models; we can build a random forest and plot individual trees, or a linear regression model and examine its coefficients. 

Additive Feature Attribution Models are explanation models that take the following functional form:
$$
g(z') = \phi_0 + \sum_{i=1}^N \phi_i z'_i
$$

Here $\phi_i$ are the coefficients of the linear model we'd like to find, and $z' \in \{0, 1\}^N$ is a binary vector indicating whether each of the $N$ features is present ($z'_i = 1$) or absent ($z'_i = 0$). Because the inputs are binarized, if $g$ approximates our more complex model, each coefficient $\phi_i$ (and its sign) becomes a proxy for how much feature $i$ contributes to the final output. $\phi_0$ is the output when all features are absent.

Explanation models of this form appear throughout the literature. Interpretability tools built on them include LIME [(Ribeiro et al., 2016)](https://arxiv.org/abs/1602.04938), DeepLIFT [(Shrikumar et al., 2017)](https://arxiv.org/abs/1704.02685), and Layer-Wise Relevance Propagation [(Binder et al., 2016)](https://arxiv.org/abs/1604.00825). Cooperative game theory supplies another attribution of this form: the _Shapley value_. This tutorial briefly describes Shapley values below and then dives into some concrete applications.

## Question: Let's say we're designing an additive feature attribution model $g(z')$. What properties should the model satisfy?

1. *Local accuracy*: At the specific input $x$ being explained (all features present, $z' = \mathbf{1}$), the explanation model should reproduce the original model's output:
$$
f(x) = g(z') = \phi_0 + \sum_{i=1}^N \phi_i
$$
We only require accuracy locally. If $g$ could approximate our more complex $f$ globally, we wouldn't need the complex model to begin with.

2. *Missingness*: A feature that is missing from the input should receive no attribution:
$$
z_i' = 0 \rightarrow \phi_i = 0
$$
Technically, if $z_i' = 0$ then _any_ value of $\phi_i$ would be possible, because the term $\phi_i z'_i$ has no impact on $g(z')$. Constraining $\phi_i = 0$ in these cases matches our intuition: if we don't include feature $i$, it has no bearing on our model output.

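3. *Consistency*: Suppose adding feature $i$ raises model $f_1$'s output at least as much as it raises model $f_2$'s output, for every coalition $S$ of the other features. Then $f_1$'s coefficient for feature $i$ should be at least as large as $f_2$'s. Formally,
$$f_1(x_S \cup \{i\}) - f_1(x_S) \geq f_2(x_S \cup \{i\}) - f_2(x_S) \quad \text{for all } S \subseteq F \setminus \{i\}$$
$$\rightarrow \phi_{i, f_1} \geq \phi_{i, f_2}$$
where $F$ is the set of all $N$ features.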

Lundberg & Lee (2017) show that only a *single* set of attributions satisfies all three properties:
$$
\phi_i = \frac{1}{N!}\sum_{R} \left[f(x_R \cup \{i\}) - f(x_R) \right]
$$
Here the sum runs over all $N!$ orderings $R$ of the $N$ features, and $x_R$ denotes the _coalition_ of features that precede feature $i$ in ordering $R$.


This equation defines the Shapley value, first introduced by Lloyd Shapley in 1951. A feature's Shapley value is its marginal contribution, averaged over every possible order in which the features can be added. The marginal contribution is the change in output when that feature joins a coalition; in a cooperative game, the features play the role of players.
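We can regroup the $N!$ orderings by the set $S$ of features that precede $i$. Each subset $S \subseteq F \setminus \{i\}$ precedes $i$ in exactly $|S|!\,(N-|S|-1)!$ orderings. Here $F$ is the full set of $N$ features. Grouping this way gives the weighted-subset form of the same equation:

$$
\phi_i = \sum_{S \subseteq F \setminus \{i\}} \frac{|S|!\,(N-|S|-1)!}{N!}\left[f(x_S \cup \{i\}) - f(x_S)\right]
$$

For $N = 3$, the weights are:

* $\tfrac{0!\,2!}{3!} = \tfrac{1}{3}$ for the empty set,
* $\tfrac{1!\,1!}{3!} = \tfrac{1}{6}$ for each of the two single-feature subsets, and
* $\tfrac{2!\,0!}{3!} = \tfrac{1}{3}$ for the pair.

These sum to $1$, as the weights of an average should. This form needs $2^{N-1}$ marginal contributions per feature instead of $N!$, but the cost still grows exponentially with $N$.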

Confusingly, the paper that unified these additive feature attribution models under the Shapley framework called its approach SHAP (SHapley Additive exPlanations; Lundberg & Lee, 2017). Shapley values are the exact marginal contributions, and SHAP methods are various ways of estimating them. Computing them exactly requires evaluating the model over every ordering (or every subset) of features, which quickly becomes intractable. One popular technique is Monte Carlo sampling of feature permutations.
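To make these definitions concrete, here is a minimal NumPy sketch of two approaches:

* the exact average over all $N!$ orderings, and
* its Monte Carlo estimate from randomly sampled orderings.

It assumes that an "absent" feature is replaced by a fixed baseline value. This is only one of several choices SHAP explainers make. The function names and `toy_model` are hypothetical illustrations, not part of the `shap` API.

```python
import itertools
import math
import numpy as np

def marginal_contributions(f, x, baseline, order):
    """Add features one at a time in `order`; return each feature's change in f.

    Absent features are set to their baseline value (an assumption: SHAP
    explainers define "absence" in several different ways).
    """
    z = np.array(baseline, dtype=float)
    contrib = np.zeros(len(x))
    prev = f(z)
    for i in order:
        z[i] = x[i]
        curr = f(z)
        contrib[i] = curr - prev
        prev = curr
    return contrib

def shapley_exact(f, x, baseline):
    """Average the marginal contributions over all N! orderings."""
    n = len(x)
    total = sum(marginal_contributions(f, x, baseline, order)
                for order in itertools.permutations(range(n)))
    return total / math.factorial(n)

def shapley_monte_carlo(f, x, baseline, n_permutations=2000, seed=0):
    """Estimate the same average from randomly sampled orderings."""
    rng = np.random.default_rng(seed)
    total = sum(marginal_contributions(f, x, baseline, rng.permutation(len(x)))
                for _ in range(n_permutations))
    return total / n_permutations

def toy_model(v):
    # additive effect of feature 0 plus an interaction between features 1 and 2
    return 2.0 * v[0] + 3.0 * v[1] * v[2]

x, baseline = np.ones(3), np.zeros(3)
phi_exact = shapley_exact(toy_model, x, baseline)   # [2.0, 1.5, 1.5]
phi_mc = shapley_monte_carlo(toy_model, x, baseline)
print(phi_exact, phi_mc)
```

The interaction term is split equally between features 1 and 2. Both estimates satisfy local accuracy: `phi.sum()` equals `toy_model(x) - toy_model(baseline)`, so `toy_model(baseline)` plays the role of $\phi_0$. The Monte Carlo estimate is exact for feature 0, because its effect does not depend on order. It is noisy only for the interacting features.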

Let's start by installing and importing some necessary packages.

In [None]:
!pip install xgboost shap torchvision torch seaborn pyarrow gdown rfpimp

In [None]:
import os
import pandas as pd
import xgboost
import numpy as np 
import shap
import seaborn as sns
import matplotlib.pyplot as plt
import subprocess
from bisect import bisect
import itertools 
import rfpimp
import gdown

In [None]:
#make our plots pretty 
sns.set_context("notebook")

## 4.1.1. Photometric redshifts.
Photometric redshift estimation is a field in which machine learning methods are relatively mature. It nonetheless remains challenging for two reasons. The features carry observational uncertainties, and there are significant, unavoidable degeneracies: a red, dim galaxy at low redshift can look like a blue, bright galaxy at high redshift. In this example, we're going to look at simulated galaxy photometry from the [CosmoDC2](https://arxiv.org/abs/1907.06530) catalog. The photometry is generated using an image simulation of the upcoming Vera Rubin Observatory. The simulation includes realistic models for dust in both the Milky Way and the host galaxy, atmospheric distortion, and the optical response of the telescope system.

First, let's download the data we'll need for all the tutorials and open the redshift example. We'll also download some archival models, in case we can't get anything to load in real-time.

In [None]:
!gdown 1s42ri7tpvBHN-kneC0MHXHsd2nB2BNTE
!gdown 1CC07axEZravXcIkq0w2gnkFGGhO4vbnx

for file in ['data', 'models']:
    subprocess.run(["tar", "-xf", "%s.tar.gz" % file]);

In [None]:
traindf = pd.read_parquet('data/redshift/dc2_gold_training_9816.pq')
testdf = pd.read_parquet('data/redshift/dc2_gold_test_9816.pq')

Let's look at the features in the dataset:

In [None]:
traindf.drop(columns=['redshift'])

Now let's train a gradient boosted decision tree model with XGBoost:

In [None]:
X_train = traindf[['ug', 'gr', 'ri', 'iz', 'zy']]
y_train = traindf['redshift']
model = xgboost.XGBRegressor().fit(X_train, y_train)

In [None]:
#is it a good model? 
X_test = testdf[['ug', 'gr', 'ri', 'iz', 'zy']]
y_test = testdf['redshift']
y_pred = model.predict(X_test)
plt.plot(y_test, y_pred, 'o', ms=1);
plt.ylabel(r"$z_{\rm{phot}}$");
plt.xlabel(r"$z_{\rm{true}}$");
plt.plot(np.linspace(0, 3), np.linspace(0, 3), c='k', ls='--', lw=3);

## Question: How do I estimate the uncertainties on my predicted redshifts? 

One might guess that we can simply Monte Carlo sample our estimates, assuming our errors are normally distributed and independent.

## Challenge: Translate the photometric uncertainties on our observations into color uncertainties, and generate a basic noise model for the first 100 galaxies.
Assume the color errors follow a multivariate normal distribution, and draw 1000 samples of the feature set for each galaxy. Then apply our XGBoost model to generate 1000 redshift estimates for each galaxy.

In [None]:
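# Colors are differences of magnitudes, so (assuming independent magnitude errors)
# each color error is the quadrature sum, e.g. sigma_(u-g) = sqrt(sigma_u^2 + sigma_g^2).
# Assumes the mag_err_<band>_lsst column pattern (mag_err_g_lsst, mag_err_r_lsst, etc.).
bands = ['u', 'g', 'r', 'i', 'z', 'y']
for b1, b2 in zip(bands[:-1], bands[1:]):
    testdf[f'{b1}{b2}_err'] = np.sqrt(testdf[f'mag_err_{b1}_lsst']**2 + testdf[f'mag_err_{b2}_lsst']**2)

colors = ['ug', 'gr', 'ri', 'iz', 'zy']
color_errs = [c + '_err' for c in colors]
num_samples = 1000
rng = np.random.default_rng(42)

true_vals = []
redshift_samples = []
for idx in np.arange(100):  # the first 100 galaxies
    means = testdf.iloc[idx][colors].to_numpy(dtype=float)
    errs = testdf.iloc[idx][color_errs].to_numpy(dtype=float)

    # Diagonal covariance = independent color errors. (Adjacent colors share a band,
    # so their errors are really anti-correlated, e.g. Cov(u-g, g-r) = -sigma_g^2.)
    draws = rng.multivariate_normal(means, np.diag(errs**2), size=num_samples)
    feature_samples = pd.DataFrame(draws, columns=colors)  # keep the feature names XGBoost saw
    # Store the predicted redshifts and the true redshifts
    redshift_samples.append(model.predict(feature_samples))
    true_vals.append(testdf.iloc[idx]['redshift'])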

That would give us the following prediction for a single galaxy:

In [None]:
point = model.predict(testdf.iloc[[idx]][['ug', 'gr', 'ri', 'iz', 'zy']])[0]  # predict returns an array
err = np.std(redshift_samples[idx])
print("Prediction: z = %.2f +/- %.2f." % (point, err))

But is our guess at the noise model reasonable? Let's plot the CDF for a single galaxy and see where our true redshift falls along the function.

In [None]:
CDF_rank = bisect(np.sort(redshift_samples[0]), true_vals[0])

plt.plot(np.sort(redshift_samples[0]), np.linspace(0, 1, len(redshift_samples[0]), endpoint=False))
plt.plot(true_vals[0], CDF_rank/len(redshift_samples[0]), 'o')
plt.xlabel("Redshift");
plt.ylabel("Empirical CDF");

Clearly our guess for the first galaxy is a bad one. If our noise model were well calibrated _globally_, the ranks for all of our galaxies would be uniformly distributed. Even so, the model could still be wrong locally! Let's look at a larger sample:

In [None]:
# TODO: Generalize the above routine to store the CDF ranks of all galaxies 
# for which we generated samples, and plot a histogram.

plt.xlabel("True Redshift Rank");
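The rank check above is often called the probability integral transform (PIT). The minimal sketch below uses synthetic Gaussian "posteriors", hypothetical data rather than the CosmoDC2 sample. It shows what a well-calibrated and an overconfident noise model look like. The helper names `pit_values` and `tail_fraction` are illustrative.

```python
import numpy as np

def pit_values(samples, truths):
    """Fraction of each object's predictive samples at or below its true value.

    Equivalent to bisect(np.sort(s), t) / len(s) for each row s and truth t.
    """
    samples = np.sort(np.asarray(samples), axis=1)
    counts = np.array([np.searchsorted(row, t, side="right")
                       for row, t in zip(samples, truths)])
    return counts / samples.shape[1]

def tail_fraction(pit, q=0.05):
    """Fraction of PIT values in the outer q tails; about 2q if calibrated."""
    return np.mean((pit < q) | (pit > 1 - q))

rng = np.random.default_rng(42)
n_gal, n_samp, sigma = 2000, 500, 0.1
z_mean = rng.uniform(0.1, 2.0, n_gal)        # hypothetical predicted means
z_true = rng.normal(z_mean, sigma)           # truths scattered by the true error

calibrated = rng.normal(z_mean[:, None], sigma, size=(n_gal, n_samp))
overconfident = rng.normal(z_mean[:, None], 0.3 * sigma, size=(n_gal, n_samp))

pit_cal = pit_values(calibrated, z_true)
pit_over = pit_values(overconfident, z_true)
print(tail_fraction(pit_cal), tail_fraction(pit_over))
```

A histogram of `pit_cal` is roughly flat. A histogram of `pit_over` is U-shaped, with too many truths in the tails, which is the signature of underestimated errors. Overestimated errors instead pile up in the middle, and a biased noise model produces a skewed histogram.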

## Question: Does the result look like a uniform distribution? If not, what features do you notice? Is your noise model biased? Overestimated?

Propagating the uncertainties directly through a model is fine in the first-order (Taylor expansion) limit, where errors are small relative to the measurements. Unfortunately, our XGBoost regressor is trained with a squared-error objective to predict a point estimate. Its training photometry already contains measurement noise. Perturbing the inputs again therefore partly "double-counts" the measurement error when generating the variance of that estimate. Sometimes this gets close to the right variance, but sometimes it doesn't. Be careful!
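A quick way to see when the Taylor-expansion limit holds is to compare first-order error propagation, $\sigma_f \approx |f'(x_0)|\,\sigma$, with brute-force Monte Carlo. In the sketch below, `np.exp` stands in for a generic nonlinear model; this is an assumption for illustration, not the XGBoost model.

```python
import numpy as np

def linear_std(f_prime, x0, sigma):
    """First-order (Taylor) propagation: sigma_f ~ |f'(x0)| * sigma."""
    return abs(f_prime(x0)) * sigma

def monte_carlo_std(f, x0, sigma, n=200_000, seed=0):
    """Spread of f(x) for x ~ Normal(x0, sigma), estimated by sampling."""
    rng = np.random.default_rng(seed)
    return np.std(f(rng.normal(x0, sigma, n)))

# exp stands in for a generic nonlinear model (its derivative is also exp)
for sigma in (0.01, 0.5):
    lin = linear_std(np.exp, 0.0, sigma)
    mc = monte_carlo_std(np.exp, 0.0, sigma)
    print(f"sigma = {sigma}: linear = {lin:.4f}, Monte Carlo = {mc:.4f}")
```

For the small error the two agree closely. For the large error, the Monte Carlo spread is noticeably larger (about 0.60 versus 0.50), because the linear approximation ignores the curvature of $f$.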

In practice, it's common to fit a single scale factor for all variances that pushes this rank distribution closer to uniform, but an individual noise estimate can still be wrong. Another way to investigate these errors is to generate a full posterior from a well-trained normalizing flow model and compare the uncertainties between approaches.
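The sketch below shows one simple way to find such a scale factor. It assumes Gaussian predictive distributions.

1. Rescale every reported $\sigma$ by the spread of the standardized residuals $(z_{\rm true} - z_{\rm pred})/\sigma$.
2. Check with a Kolmogorov-Smirnov distance that the rank (probability integral transform, PIT) distribution has moved toward uniform.

Scaling $\sigma$ by $s$ scales the variance by $s^2$. The data are synthetic and the helper names are hypothetical. This approach matches the spread of the residuals rather than directly optimizing uniformity.

```python
import numpy as np
from math import erf, sqrt

_erf = np.vectorize(erf)

def gaussian_pit(z_true, z_mean, z_sigma):
    """PIT value of each truth under a Normal(z_mean, z_sigma) predictive distribution."""
    u = (z_true - z_mean) / z_sigma
    return 0.5 * (1.0 + _erf(u / sqrt(2.0)))

def ks_to_uniform(pit):
    """Kolmogorov-Smirnov distance between the PIT values and Uniform(0, 1)."""
    p = np.sort(pit)
    n = len(p)
    return max((np.arange(1, n + 1) / n - p).max(), (p - np.arange(n) / n).max())

def sigma_scale_factor(z_true, z_mean, z_sigma):
    """One global factor s so that the standardized residuals have unit spread."""
    return np.std((z_true - z_mean) / z_sigma)

rng = np.random.default_rng(1)
n = 5000
z_mean = rng.uniform(0.1, 2.0, n)            # hypothetical point estimates
z_true = rng.normal(z_mean, 0.1)             # true scatter of 0.1
z_sigma = np.full(n, 0.05)                   # hypothetical noise model, 2x too small

s = sigma_scale_factor(z_true, z_mean, z_sigma)
ks_before = ks_to_uniform(gaussian_pit(z_true, z_mean, z_sigma))
ks_after = ks_to_uniform(gaussian_pit(z_true, z_mean, s * z_sigma))
print(f"s = {s:.2f} (variance factor {s**2:.2f}); KS {ks_before:.3f} -> {ks_after:.3f}")
```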

Next, let's explain the model's predictions using SHAP.

How informative are our colors? Let's check it out for a low-redshift and then a high-redshift event:

In [None]:
# Note that this same syntax works for LightGBM, CatBoost, scikit-learn, transformers, Spark, etc!
shap.initjs()

explainer = shap.Explainer(model)
shap_values = explainer(X_test)

shap.plots.waterfall(shap_values[0])

## Question: Which features have the largest marginal contribution toward a final prediction for this low-redshift event, and which have the lowest?

What about for a high-redshift event?

In [None]:
shap.initjs()

shap.plots.waterfall(shap_values[-1])

Do you notice any trends between the high-redshift events and the low-redshift ones?

How do we visualize the Shapley values for an ensemble of predictions, and not just one?

In [None]:
shap.initjs()

shap.plots.force(shap_values[0:1000])

Here we see that, for the nearby (low-redshift) galaxies, most colors push the prediction below the model's mean prediction. It looks like the feature with the most leverage at low $z$ (on average) is $g-r$.

How does the entire distribution of SHAP values look for each parameter? 

In [None]:
shap.initjs()

shap.plots.beeswarm(shap_values);

## Question: What features do we observe? 
Remember what SHAP values indicate: the direction (and amount) by which each feature pushes an individual prediction away from the model's mean prediction. Can you think of a physically motivated reason for these particular color features? Consider what happens to a galaxy's position in color-color space as it gets redshifted:

<img src="images/ColorZ.gif" alt="ColorZGif" width="800"/>


Before we move on, we should note that our model is built from gradient boosted trees, which come with some useful model-specific interpretability tools of their own. These tools are _global_: they describe the model's behavior across the whole dataset rather than for a single prediction.

First, because the model constructs a series of decision trees, we can visualize the trees directly to better understand how the model's decision-making process takes place. 

In [None]:
#on macOS, run "brew install graphviz" before executing this cell
fig, ax = plt.subplots(figsize=(20, 20))
xgboost.plot_tree(model, num_trees=1, ax=ax, rankdir='LR')
plt.savefig("DecisionTree.png", dpi=500, bbox_inches='tight')

This tells us _exactly_ what the regressor is doing. Further, the fact that $g-r$ is the first split in this tree suggests it might be very important in the model overall. Unfortunately (at least for interpretability), there can be many trees constructed by the model (XGBoost uses 100 by default, albeit small by neural network standards), making it a bit too onerous to inspect them one-by-one. 

We can also calculate the _feature importances_ of the model. For XGBoost's regressor the default is the "gain": the average reduction in the training loss from the splits that use the feature in question. Because this is averaged over all splits in all trees (ensembling), it is a more robust estimate than what you might get from a single decision tree.
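The XGBoost paper (Chen & Guestrin, 2016) scores a candidate split by its loss reduction:

$$
\frac{1}{2}\left[\frac{G_L^2}{H_L+\lambda} + \frac{G_R^2}{H_R+\lambda} - \frac{(G_L+G_R)^2}{H_L+H_R+\lambda}\right] - \gamma
$$

Here $G$ and $H$ are sums of the first and second derivatives of the loss over the samples in each child, and $\lambda$ and $\gamma$ are regularization parameters. The "gain" importance averages this quantity over every split on a feature. Below is a minimal NumPy sketch for squared-error loss. The helper `split_gain` is hypothetical, not an XGBoost function.

```python
import numpy as np

def split_gain(grad, hess, left_mask, reg_lambda=1.0, gamma=0.0):
    """Loss reduction of one split, following the XGBoost paper's formula.

    grad, hess: first and second derivatives of the loss for each sample.
    left_mask: boolean array, True for samples sent to the left child.
    """
    def score(g, h):
        return g.sum() ** 2 / (h.sum() + reg_lambda)
    g_l, h_l = grad[left_mask], hess[left_mask]
    g_r, h_r = grad[~left_mask], hess[~left_mask]
    return 0.5 * (score(g_l, h_l) + score(g_r, h_r) - score(grad, hess)) - gamma

# Squared-error loss 0.5*(pred - y)^2: grad = pred - y, hess = 1 for every sample
y = np.array([0.0, 0.0, 1.0, 1.0])
pred = np.full_like(y, 0.5)
grad, hess = pred - y, np.ones_like(y)

good_split = np.array([True, True, False, False])   # separates the two groups
bad_split = np.array([True, False, True, False])    # mixes them
print(split_gain(grad, hess, good_split), split_gain(grad, hess, bad_split))
```

The clean split earns a gain of 1/3, while the mixed split earns 0, so only the first would raise its feature's gain importance.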

In [None]:
sorted_idx = model.feature_importances_.argsort()
plt.barh(X_test.columns[sorted_idx], model.feature_importances_[sorted_idx])
plt.xlabel("Xgboost Feature Importance");

Our guess from the first tree was right - it looks like $g-r$ is the most important feature in our model. It is important to note that there are _many_ ways to calculate feature importance, and [some of them produce biased results](https://explained.ai/rf-importance/#7). To be safe, always explore multiple feature importance estimators:

In [None]:
fig, axs = plt.subplots(nrows=1, ncols=6, figsize=(20, 8))
axs = axs.ravel()
measures = ['gain','total_gain', 'cover',  'total_cover', 'weight']
palette = itertools.cycle(sns.color_palette('Dark2'))

for i, measure in enumerate(measures):
    importances = model.get_booster().get_score(importance_type=measure)
    axs[i].barh(list(importances.keys()), list(importances.values()), color=next(palette))
    axs[i].set_xlabel("Feat. Importance");
    axs[i].set_title("Importance = %s"%measure)

importance_df = rfpimp.importances(model, X_test, y_test).reindex(importances.keys())
axs[5].barh(importance_df.index, importance_df['Importance'])
axs[5].set_xlabel("Feat. Importance");
axs[5].set_title("Permutation Importance");

The estimators differ as follows:

* **Total gain** is the loss reduction summed (instead of averaged) over all splits that use a feature.
* **Cover** (and **total cover**) is the average (and total) number of training observations affected by those splits. Strictly, it is the summed second derivative of the loss, which equals the number of observations for squared-error loss.
* **Weight** is the number of times a feature is used to split the data.
* **Permutation importance** is the decrease in model score caused by randomly shuffling the values of that feature. It is thought to avoid several biases that affect the other estimators.
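Below is a minimal NumPy sketch of permutation importance. It uses $R^2$ as the score and a hypothetical toy "model" that only uses the first column. `rfpimp` differs in details, such as which metric it uses and how it subsamples. Treat this as an illustration of the idea rather than a reimplementation.

```python
import numpy as np

def r2_score(y_true, y_pred):
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot

def permutation_importance(predict, X, y, seed=0):
    """Drop in R^2 when each column is shuffled independently (one shuffle per column)."""
    rng = np.random.default_rng(seed)
    baseline = r2_score(y, predict(X))
    drops = []
    for j in range(X.shape[1]):
        X_perm = X.copy()
        X_perm[:, j] = rng.permutation(X_perm[:, j])
        drops.append(baseline - r2_score(y, predict(X_perm)))
    return np.array(drops)

def predict(A):
    # hypothetical "model" that ignores column 1 entirely
    return 3.0 * A[:, 0]

rng = np.random.default_rng(0)
X = rng.normal(size=(1000, 2))
y = 3.0 * X[:, 0] + rng.normal(scale=0.1, size=1000)
drops = permutation_importance(predict, X, y)
print(drops)
```

Shuffling the unused column leaves the predictions unchanged, so its importance is exactly zero. Shuffling the informative column destroys most of the score.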

The importance estimates vary considerably. Nonetheless, some trends emerge. In most estimators, $z-y$ and $u-g$ are of low importance to the model, likely because the $u$ and $y$ bands have low overall transmission and their photometry is typically noisy.

One final point: Look at the SHAP values for a single galaxy versus the gain importance for the model overall:

In [None]:
fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(10, 8))
axs = axs.ravel()
palette = itertools.cycle(sns.color_palette('Dark2'))

axs[0].barh(X_test.columns, model.feature_importances_)
axs[0].set_title("Gain Importance");

shap_df = pd.DataFrame({'Shap':shap_values[-1].values}, index=X_test.columns).reindex(importances.keys())
axs[1].barh(shap_df.index, shap_df['Shap'])
axs[1].set_title("SHAP values for z = %.2f galaxy"%y_test.values[-1]);

We see some similarities, but also some important differences - a macroscopic and microscopic view of the same model!

## 4.1.2. Now let's use Shapley values to better understand a more complex model.
(Modified from https://github.com/pmocz/artificialneuralnetwork-python) 

Here we're going to attempt to classify galaxies according to their morphology. We consider three classes: spiral, elliptical, and irregular. These are reasonably straightforward to distinguish by eye: Spiral galaxies have a wound spiral structure, elliptical galaxies have a smooth light profile with a bright core, and irregular galaxies have some messy structure (this is a catchall category for the oddballs - irregular galaxies likely formed from some merger, close encounter, or chaotic internal activity):

![](Images/GalaxyTypes.png)

The images we're using are from the Sloan Digital Sky Survey. Let's start off by loading some required packages:

In [None]:
import torch, torchvision
from torchvision import datasets, transforms
from torch import nn, optim
import subprocess
import os
from torch.utils.data import DataLoader
from tqdm import trange

In [None]:
galaxyPath = "./data/galaxy/"

We specify a batch size for our training:

In [None]:
batch_size = 8

Now we create datasets for the train and test sets, converting each image to grayscale and to a tensor so that PyTorch can use it. We also use a PyTorch generator object so that our random batches are reproducible.

In [None]:
import random

def seed_worker(worker_id):
    # seed NumPy and Python's RNG in each DataLoader worker from PyTorch's per-worker seed
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)

g = torch.Generator()
g.manual_seed(0)

data_transform = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                     transforms.ToTensor()])

train = torchvision.datasets.ImageFolder(os.path.join(galaxyPath, 'train'), transform = data_transform)
train_loader = DataLoader(train, shuffle=True, batch_size=batch_size, worker_init_fn=seed_worker, generator=g)

test = torchvision.datasets.ImageFolder(os.path.join(galaxyPath, 'test'), transform = data_transform)
test_loader = DataLoader(test, shuffle=True, batch_size=batch_size, worker_init_fn=seed_worker, generator=g)

Our classification scheme is encoded as follows:

In [None]:
GalaxyClasses = {0:'Ellip.', 1:'Spiral', 2:'Irreg.'}

Next, we create a basic convolutional neural network for classification.

In [None]:
device = torch.device('cpu')

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size = 3, stride = 1, padding = 1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(3, 20, kernel_size=5),
            nn.Dropout(0.3), #TODO: Modify the dropout fraction here.
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(720, 50),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(50, 3),
            nn.Softmax(dim=1)
        )

    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 720)  # flatten: 20 feature maps x 36 spatial positions = 720
        x = self.fc_layers(x)
        return x
        
model = Net().to(device)

Our network includes _dropout_, in which a random subset of neurons is deactivated at each training step. The number in parentheses sets the dropout probability, the chance that any given neuron is zeroed. Dropout is an important regularization tool for neural networks. It prevents the network from leaning on a few neurons, which can then overfit while the rest learn little. This is similar to ensembling: training many random sub-networks and effectively averaging them is often more robust than training the full network at once.
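Below is a minimal NumPy sketch of "inverted" dropout, the variant PyTorch's `nn.Dropout` uses. During training, each unit is zeroed with probability $p$ and the survivors are scaled by $1/(1-p)$, so the expected activation is unchanged. In evaluation mode, the layer is the identity. This is why it matters to call `model.eval()` before evaluating a network with dropout. The `dropout` function here is illustrative, not a PyTorch API.

```python
import numpy as np

def dropout(x, p, training, rng):
    """Inverted dropout: zero each unit with probability p, rescale survivors by 1/(1-p).

    At evaluation time (training=False) the input passes through unchanged.
    """
    if not training or p == 0.0:
        return x
    keep = rng.random(x.shape) >= p
    return x * keep / (1.0 - p)

rng = np.random.default_rng(0)
activations = np.ones((10_000, 50))
out_train = dropout(activations, p=0.3, training=True, rng=rng)
out_eval = dropout(activations, p=0.3, training=False, rng=rng)

frac_zero = (out_train == 0).mean()      # close to p = 0.3
mean_train = out_train.mean()            # close to 1.0, the mean input
print(frac_zero, mean_train, np.array_equal(out_eval, activations))
```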

We now train for 50 epochs, evaluating the training loss every epoch to see how we're improving.

## Challenge: Change the dropout fraction, and evaluate its impact on the training routine.

In [None]:
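model.train()  # enable dropout during training
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Net ends in a Softmax, so it outputs probabilities. nn.CrossEntropyLoss expects raw
# logits, so we instead take the log of the probabilities and use nn.NLLLoss.
criterion = nn.NLLLoss()

pbar = trange(50)
for epoch in pbar:  # loop over the dataset multiple times
    running_loss = 0.0
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize (clamp avoids log(0))
        outputs = model(inputs)
        loss = criterion(torch.log(outputs.clamp_min(1e-12)), labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # show the mean training loss for this epoch in the progress bar
    pbar.set_postfix(loss=running_loss / len(train_loader))
print('Finished Training')

#save the model 
torch.save(model.state_dict(), './models/galaxyCNN.pth')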

In [None]:
# If you're having trouble training, pre-load the weights from one that I trained earlier:
#model.load_state_dict(torch.load('./models/galaxyCNN_legacy.pth'))

How accurately do we distinguish between our classes after this training? 

In [None]:
model.eval()  # disable dropout so evaluation is deterministic
correct_pred = {name: 0 for name in GalaxyClasses.values()}
total_pred = {name: 0 for name in GalaxyClasses.values()}

with torch.no_grad():  # no gradients are needed for evaluation
    for x, y in test_loader:
        images, labels = x.to(device), y.to(device)
        outputs = model(images)
        _, predictions = torch.max(outputs, 1)
        for label, prediction in zip(labels, predictions):
            if label == prediction:
                correct_pred[GalaxyClasses[label.item()]] += 1
            total_pred[GalaxyClasses[label.item()]] += 1

for name, correct_count in correct_pred.items():
    accuracy = 100 * float(correct_count) / total_pred[name]
    print("Accuracy for class {} is: {:.1f} %".format(name, accuracy))

## Question: How do we add additional regularization? 
The stochastic gradient descent (SGD) optimizer's `weight_decay` parameter implements L2 regularization on the network weights, which we learned on Monday can improve generalizability. Try re-training the network with the SGD optimizer and the `weight_decay` parameter set. Compute the L2 norm of the weights in each case and compare them. Do the same for multiple values of `weight_decay`.
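For plain SGD, PyTorch's `weight_decay` adds $\lambda w$ to each parameter's gradient. This is equivalent to adding $\frac{\lambda}{2}\lVert w \rVert^2$ to the loss. The NumPy sketch below applies that update rule to linear regression on synthetic data, not to the CNN. It shows the L2 norm of the learned weights shrinking as `weight_decay` grows. At convergence, the result matches the closed-form ridge-regression solution. For the CNN itself, the analogous norm is the square root of the summed squares of every tensor in `model.parameters()`.

```python
import numpy as np

def gradient_descent(X, y, lr=0.05, weight_decay=0.0, steps=2000):
    """Gradient descent on 0.5*mean((Xw - y)^2) with an SGD-style weight_decay term.

    The update w <- w - lr * (grad + weight_decay * w) is the coupled L2 penalty.
    """
    n, d = X.shape
    w = np.zeros(d)
    for _ in range(steps):
        grad = X.T @ (X @ w - y) / n
        w -= lr * (grad + weight_decay * w)
    return w

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 5))
y = X @ np.array([3.0, -2.0, 1.0, 0.5, 0.0]) + rng.normal(scale=0.5, size=200)

weights = {wd: gradient_descent(X, y, weight_decay=wd) for wd in (0.0, 0.1, 1.0)}
for wd, w in weights.items():
    print(f"weight_decay={wd}: L2 norm of weights = {np.linalg.norm(w):.3f}")
```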

In [None]:
# TODO: compute L2 norm of earlier trained weights.

In [None]:
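model = Net().to(device)  # re-initialize so we compare against training from scratch
model.train()             # enable dropout during training
# TODO: vary weight_decay (e.g. 0, 1e-4, 1e-3, 1e-2); lr and momentum are example values
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-3)
# Net ends in a Softmax, so use NLLLoss on log-probabilities rather than CrossEntropyLoss
criterion = nn.NLLLoss()

for epoch in trange(50):  # loop over the dataset multiple times
    for i, data in enumerate(train_loader, 0):
        # get the inputs; data is a list of [inputs, labels]
        inputs, labels = data

        # zero the parameter gradients
        optimizer.zero_grad()

        # forward + backward + optimize (clamp avoids log(0))
        outputs = model(inputs)
        loss = criterion(torch.log(outputs.clamp_min(1e-12)), labels)
        loss.backward()
        optimizer.step()
print('Finished Training')

#save the model 
torch.save(model.state_dict(), './models/galaxyCNN.pth')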

In [None]:
# TODO: compute L2 norm of the new model's trained weights. How does it compare?

Our results aren't terrible for a quick training process. Now let's evaluate our model on a subset of the test set:

In [None]:
# snag one batch of the test set
# TODO: Re-evaluate this cell multiple times to see how the results change. Where does the model struggle?
dataiter = iter(test_loader)
images, labels = next(dataiter)

_, axs = plt.subplots(1, batch_size, figsize=(12, 12))
axs = axs.flatten()

# show all images in greyscale (drop the single channel dimension)
for i, ax in enumerate(axs):
    ax.imshow(images[i].squeeze(0), cmap='gray')

# apply the model to the images (dropout off, no gradients), and convert probability scores to classifications
model.eval()
with torch.no_grad():
    outputs = model(images)
_, predicted = torch.max(outputs, 1)

# print the results alongside the images
print('Ground Truth:     ', '      '.join(f'{GalaxyClasses[labels.numpy()[j]]:5s}' for j in range(batch_size)))
print('Model Prediction: ', '      '.join(f'{GalaxyClasses[predicted.numpy()[j]]:5s}' for j in range(batch_size)))

## Challenge: Print a few batches of test data and the corresponding model predictions. Which images does the model correctly classify? Do they have anything in common (either classes or features)? 
The first tip for interpreting your model is _looking at your data._

How can we use SHAP values to understand how the individual pixels in an image contribute to the final classification? If we treat each pixel as a feature, we can use our SHAP framework as before. Here we use SHAP's `DeepExplainer`, which builds on the Deep Learning Important FeaTures (DeepLIFT) method (Shrikumar et al., 2017).

In [None]:
# since shuffle=True, this is a random sample of test data. We use a larger batch size so that we have more images in our background
background_size = 100  # TODO: vary the background batch size (e.g. 10, 50, 200)
background_loader = DataLoader(test, shuffle=True, batch_size=background_size)
background_batch = next(iter(background_loader))
background_images, background_labels = background_batch

test_batch = next(iter(test_loader))
test_images, test_labels = test_batch

model.eval()  # explain the deterministic (dropout-free) model
e = shap.DeepExplainer(model, background_images)
shap_values = e.shap_values(test_images)

Note that `DeepExplainer` requires a set of "background" images. As before, SHAP values describe how individual features (here, pixels) push a prediction _away from the model's expected output_. That expectation is approximated by averaging over a finite sample of background images. Using more background images approximates the expectation, and hence each pixel's Shapley value, more closely, at a higher computational cost.
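For a linear model with independent features, SHAP values have a closed form (Lundberg & Lee, 2017): $\phi_i = w_i\,(x_i - \mathbb{E}[x_i])$. The sketch below uses synthetic data and a hypothetical helper `linear_shap`; it is not the `DeepExplainer` algorithm. It estimates $\mathbb{E}[x_i]$ from background samples of increasing size. The output shows that the attributions converge as the background grows. It also shows that they always sum to $f(x)$ minus the mean prediction over the background.

```python
import numpy as np

def linear_shap(w, x, background):
    """SHAP values of f(x) = w . x + b, assuming independent features:
    phi_i = w_i * (x_i - E[x_i]), with E[x_i] estimated from the background set."""
    return w * (x - background.mean(axis=0))

rng = np.random.default_rng(0)
w = np.array([1.0, -2.0, 0.5, 3.0, 1.5])   # hypothetical linear model weights
x = np.ones(5)                              # the input being explained
exact_phi = w * (x - 0.0)                   # features are drawn with true mean 0

errors = {}
for size in (10, 100, 10_000):
    background = rng.normal(size=(size, 5))  # hypothetical background sample
    phi = linear_shap(w, x, background)
    errors[size] = np.abs(phi - exact_phi).max()
    # local accuracy holds for any background: sum(phi) = f(x) - mean f(background)
    print(size, errors[size], phi.sum(), w @ x - (background @ w).mean())
```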

In [None]:
# We re-arrange our grid of shapley values and test images so that we can show them side-by-side
shap_numpy = [np.swapaxes(np.swapaxes(s, 1, -1), 1, 2) for s in shap_values]
test_numpy = np.swapaxes(np.swapaxes(test_images.numpy(), 1, -1), 1, 2)

In [None]:
# Next, plot them
title_row = list(GalaxyClasses.values())
titles = np.vstack([title_row]*len(test_images))

shap.image_plot(shap_numpy, test_numpy, cmap='coolwarm', labels=titles)

_, predicted = torch.max(model(test_images), 1)
print('Ground Truth:     ', '    '.join(f'{GalaxyClasses[test_labels.numpy()[j]]:5s}' for j in range(len(test_images))))
print('Model Prediction: ', '    '.join(f'{GalaxyClasses[predicted.numpy()[j]]:5s}' for j in range(len(test_images))))

## Challenge: Play around with different background data sizes (and different images) and see how this changes the resultant SHAP values for each image. How many background images do you need for reasonable Shapley estimates?
Think for a bit, then check the solutions notebook for the answer.

### Does this have anything to do with saliency maps? 


Both saliency maps and pixel-based SHAP are 'pixel attribution methods'. Saliency maps use the gradient of the output with respect to the input pixels. SHAP instead attributes the difference between the model's output for an image and its expected output over a set of background (reference) images. We'll talk more about saliency maps later on in the tutorial.

## Conclusion: There are pros and cons to SHAP values. 
On the one hand, it can be useful to estimate the contribution of _a specific feature_ by comparing the model's output with that feature included against _all possible coalitions_ of features that exclude it. On the other hand, it is difficult (but not impossible) to use SHAP values to explore correlated features, and SHAP values are always measured relative to the average prediction of the model. Computing them can also be slow when there are many features, since the number of coalitions grows exponentially.

An open question is how to adapt the SHAP framework for interpreting features of temporal datasets, since with e.g., Recurrent Neural Networks, each recurrent unit uses information from elsewhere in the time-series sample. This could be a useful summer school hack!

### Does this have anything to do with uncertainty quantification?
Let's return for a second to our equation for Shapley values:


$$
\phi_i = \frac{1}{N!}\sum_{R} \left[f(x_R \cup \{i\}) - f(x_R) \right]
$$

This function calculates the average marginal contribution of each feature (or player) $i$ to a value function $f$, which in our case was our model. What if we replaced $f$ with some other value function $v$? 

In 2023, [a paper](https://arxiv.org/pdf/2306.05724.pdf) proposed defining $v$ as the conditional entropy of a model's predictive distribution. Under this framework, one could estimate the contribution of each feature not to the output of the model but to _its uncertainty_. Let's define a new network with the same architecture as our CNN above, but whose output is the entropy of its predicted class probabilities. We reuse the trained weights rather than training it anew:

In [None]:
class EntropyNet(nn.Module):
    def __init__(self):
        super(EntropyNet, self).__init__()

        self.conv_layers = nn.Sequential(
            nn.Conv2d(1, 3, kernel_size = 3, stride = 1, padding = 1),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
            nn.Conv2d(3, 20, kernel_size=5),
            nn.Dropout(0.2),
            nn.MaxPool2d(kernel_size=2),
            nn.ReLU(),
        )
        self.fc_layers = nn.Sequential(
            nn.Linear(720, 50),
            nn.ReLU(),
            nn.Dropout(0.2),
            nn.Linear(50, 3),
            nn.Softmax(dim=1)
        )

    # we now define the (Shannon) entropy of the predicted class probabilities:
    def entropy(self, x):
        p = x.clamp_min(1e-12)  # avoid log(0) = -inf, which would turn 0 * log(0) into NaN
        return -torch.sum(p * torch.log(p), dim=1, keepdim=True)
    
    def forward(self, x):
        x = self.conv_layers(x)
        x = x.view(-1, 720)  # flatten: 20 feature maps x 36 spatial positions = 720
        x = self.fc_layers(x)
        #the only change - calculate the entropy of the prediction:
        x = self.entropy(x)
        return x
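The `entropy` method computes the Shannon entropy $H(p) = -\sum_c p_c \log p_c$ of the predicted class probabilities. The self-contained toy below computes exact Shapley values with entropy as the value function. It makes two simplifying assumptions:

* the classifier is a hypothetical 3-class model with two inputs, and
* "absent" features take a fixed baseline value, rather than using the estimator of Watson et al.

A negative value means the feature reduces the model's uncertainty relative to the baseline input.

```python
import itertools
import math
import numpy as np

def softmax(logits):
    e = np.exp(logits - logits.max())
    return e / e.sum()

def entropy(p, eps=1e-12):
    """Shannon entropy (in nats) of a probability vector."""
    p = np.clip(p, eps, 1.0)
    return -np.sum(p * np.log(p))

# Hypothetical 3-class classifier whose logits are linear in 2 input features
W = np.array([[2.0, 0.0],
              [0.0, 2.0],
              [0.0, 0.0]])

def predictive_entropy(v):
    """Value function: entropy of the predicted class probabilities."""
    return entropy(softmax(W @ v))

def shapley_exact(value_fn, x, baseline):
    """Exact Shapley values of value_fn; absent features take their baseline value."""
    n = len(x)
    phi = np.zeros(n)
    for order in itertools.permutations(range(n)):
        z = np.array(baseline, dtype=float)
        prev = value_fn(z)
        for i in order:
            z[i] = x[i]
            curr = value_fn(z)
            phi[i] += curr - prev
            prev = curr
    return phi / math.factorial(n)

x, baseline = np.array([3.0, 0.0]), np.zeros(2)
phi_H = shapley_exact(predictive_entropy, x, baseline)
print(phi_H, predictive_entropy(x) - predictive_entropy(baseline))
```

At the baseline, all logits are zero, so the prediction is uniform and the entropy is $\log 3$. Feature 0 pushes the model toward class 0 and receives a negative attribution. Feature 1 equals its baseline value and receives none.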

In [None]:
entropyModel = EntropyNet().to(device)
entropyModel.load_state_dict(torch.load('./models/galaxyCNN.pth'))
#to instead use the pre-trained weights:
#entropyModel.load_state_dict(torch.load('./models/galaxyCNN_legacy.pth')) 
entropyModel.eval();

In [None]:
test_batch = next(iter(test_loader))
test_images, test_labels = test_batch

entropies = entropyModel(test_images)
entropy_explainer = shap.DeepExplainer(entropyModel, background_images)
entropy_shap = entropy_explainer.shap_values(test_images)

In [None]:
# We re-arrange our grid of shapley values and test images so that we can show them side-by-side
entropy_numpy = [np.swapaxes(np.swapaxes(entropy_shap, 1, -1), 1, 2)]
test_numpy = np.swapaxes(np.swapaxes(test_images.numpy(), 1, -1), 1, 2)

In [None]:
# Next, plot them
title_row = 'Uncertainty'
titles = np.vstack([title_row]*len(test_images))
print([GalaxyClasses[x.item()] for x in test_labels])
shap.image_plot(entropy_numpy, test_numpy, cmap='coolwarm', labels=titles)