# (Tabular) MNIST

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import functools

import pred_diff.preddiff as preddiff
from pred_diff.imputers import vae_impute, general_imputers
from pred_diff.tools import utils_mnist as ut_mnist
from pred_diff.tools import init_plt

In [None]:
# paper style: hand a figure width in points (2 * 234.88 pt) to init_plt.update_rcParams
init_plt.update_rcParams(fig_width_pt=234.88*2)

# default
# uncomment the next line to switch to matplotlib's built-in 'default' style sheet
# plt.style.use('default')


Load MNIST data and train a model

In [None]:
# Fetch the model together with the train/test data and the test targets.
# NOTE(review): retrain=False presumably reuses a stored model instead of
# training from scratch — confirm in ut_mnist.get_model_and_data.
model, df_train, df_test, target_test = ut_mnist.get_model_and_data(
    max_epochs=1,
    retrain=False,
)

# Report the model's temperature rescaling factor (model.T), two decimals.
print(
    "How much overconfident is the model?\n"
    f"temperature rescaling factor: model.T = {model.T:.2f}"
)

Select parameters for *PredDiff*

In [None]:
# Passed to ut_mnist.ImgParams as block_size.
# possible values: 1, 2, 4, 7, 14
filter_size = 4

# Passed to preddiff.PredDiff as n_imputations.
n_imputations = 600

# Imputer choice; the setup cell handles 'TrainSet' and 'VAE'.
imputer_selection = 'TrainSet'
# imputer_selection = 'VAE'

In [None]:
# Image parameters: n_pixel=28, with filter_size used as the block size.
iparam = ut_mnist.ImgParams(n_pixel=28, block_size=filter_size)

# Build the imputer requested by imputer_selection.
if imputer_selection == 'TrainSet':
    imputer = general_imputers.TrainSetImputer(train_data=df_train.to_numpy())
elif imputer_selection == 'VAE':
    imputer = vae_impute.VAEImputer(df_train=df_train, epochs=20, gpus=0)
else:
    # Raise instead of `assert False`: assert statements are stripped under
    # `python -O`, which would leave `imputer` undefined (or stale from an
    # earlier run of this cell) and only fail later, far from the real cause.
    raise ValueError(f'please enter a valid imputer_selection = {imputer_selection}')

# PredDiff explainer for the (non-regression) model, using the imputer chosen above.
pd_explainer = preddiff.PredDiff(model, df_train, n_imputations=n_imputations, regression=False,
                                 imputer=imputer, fast_evaluation=True, n_group=200, unified_integral=False)

Select data

In [None]:
# Which test rows to explain; the selection cell handles 'PaperSelection'
# (fixed row positions) and 'RandomSelection' (randomly drawn row positions).
data_selection = 'PaperSelection'
# data_selection = 'RandomSelection'

In [None]:
# Pick the test rows to explain according to data_selection.
if data_selection == 'PaperSelection':
    data = df_test.iloc[[4, 15, 84, 9]]         # one digit each: 4, 5, 8, 9
elif data_selection == 'RandomSelection':
    # Two row positions are drawn independently, so the same row may be picked twice.
    data = df_test.iloc[np.random.randint(low=0, high=df_test.shape[0], size=2)]
else:
    # Raise instead of `assert False`: assert statements are stripped under
    # `python -O`, which would leave `data` undefined (or stale from an
    # earlier run of this cell).
    raise ValueError(f'please enter a valid data_selection = {data_selection}')

# Reshape every selected row into an (n_pixel, n_pixel) array.
data_np = data.to_numpy().reshape(-1, iparam.n_pixel, iparam.n_pixel)

In [None]:
# calculate relevances
m_relevance, prediction_prob, m_list = ut_mnist.get_relevances(explainer=pd_explainer,
                                                               data=data_np, img_params=iparam)
# m_relevance, prediction_prob, m_list = ut_mnist.get_relevances(explainer=pd_explainer, data=data, img_params=iparam)

In [None]:
# What to plot per image; the plotting loop handles 'PredictedClass'
# (ut_mnist.plot_predicted_digit) and 'FourClasses' (ut_mnist.plot_comparison).
plot_selection = 'PredictedClass'
# plot_selection = 'FourClasses'


In [None]:
# Fail fast on an unsupported plot_selection. Without this check the loop
# below would compute the interactions for every image and then silently
# plot nothing.
if plot_selection not in ('PredictedClass', 'FourClasses'):
    raise ValueError(f'please enter a valid plot_selection = {plot_selection}')

# Passed to get_reference_pixel as n_importance; constant for every image,
# so it is set once outside the loop.
n_importance = 1

# One figure per selected image.
for img_id in range(data.shape[0]):
    # Reference pixel index for this image, derived from the relevances.
    i_reference = ut_mnist.get_reference_pixel(m_relevance=m_relevance, prediction_prob=prediction_prob,
                                               img_id=img_id, n_importance=n_importance)
    # Interactions with respect to the reference pixel.
    m_interaction = ut_mnist.get_interaction(explainer=pd_explainer, data=data_np, iparam=iparam, m_list=m_list,
                                             i_reference=i_reference)

    # plot_rect with the reference pixel and image parameters pre-bound.
    rect = functools.partial(ut_mnist.plot_rect, i_reference=i_reference, iparam=iparam)

    if plot_selection == 'PredictedClass':
        ut_mnist.plot_predicted_digit(relevance=m_relevance, interaction=m_interaction, prob_classes=prediction_prob,
                                      data_digit=data, rect=rect, img_params=iparam, image_id=img_id)

    elif plot_selection == 'FourClasses':
        ut_mnist.plot_comparison(m_list_collected=m_relevance, prob_classes=prediction_prob, data_digit=data,
                                 img_params=iparam, image_id=img_id)
