<a href="https://colab.research.google.com/github/thunderhoser/cira_ml_short_course/blob/master/lecture05_notebook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Convolutional neural networks (CNNs)

This notebook was adapted from Lagerquist, Ryan, and David John Gagne II, 2020: "Lecture 5: Convolutional neural networks".

 <font color='red'>Note: Make sure to use a GPU (graphics processing unit) to run this notebook!</font>
The CNNs in this notebook run $\sim$100 times faster with a GPU. 

# <font color='red'>Clone the Git repository (required)</font>

- **Please note**: when a section title is in <font color='red'>red</font>, that means the code cell below is required.
- In other words, if you don't run the code cell below, subsequent code cells might not work.

In [None]:
import os
import shutil

if os.path.isdir('course_repository'):
    shutil.rmtree('course_repository')

!git clone https://github.com/thunderhoser/cira_ml_short_course course_repository
!cd course_repository; python setup.py install

!pip uninstall -y netCDF4
!pip install netCDF4

!pip uninstall -y cftime
!pip install cftime

# <font color='red'>Import packages (required)</font>

The next two cells import all packages used in the notebook.

In [None]:
import sys

sys.path.append('/content/data/')
sys.path.append('/content/course_repository/')
sys.path.append('/content/course_repository/cira_ml_short_course/')
sys.path.append('/content/course_repository/cira_ml_short_course/utils/')

In [None]:
%matplotlib inline
import copy
import random
import os.path
import warnings
import numpy
import keras
from matplotlib import pyplot
import tensorflow.compat.v1 as tf
from cira_ml_short_course.utils import utils, image_utils, \
    image_normalization, image_thresholding, cnn, upconvnet, saliency, \
    class_activation, novelty_detection
from cira_ml_short_course.utils import backwards_optimization as backwards_opt
from cira_ml_short_course.plotting import image_plotting, permutation_plotting

tf.disable_v2_behavior()
warnings.filterwarnings('ignore')

SEPARATOR_STRING = '\n\n' + '*' * 50 + '\n\n'
MINOR_SEPARATOR_STRING = '\n\n' + '-' * 50 + '\n\n'

DATA_DIRECTORY_NAME = '/content/data/track_data_ncar_ams_3km_nc_small'

BEST_HIT_MATRIX_KEY = 'best_hits_predictor_matrix'
WORST_FALSE_ALARM_MATRIX_KEY = 'worst_false_alarms_predictor_matrix'
WORST_MISS_MATRIX_KEY = 'worst_misses_predictor_matrix'
BEST_CORRECT_NULLS_MATRIX_KEY = 'best_correct_nulls_predictor_matrix'
PREDICTOR_NAMES_KEY = 'predictor_names'

# <font color='red'>Prevent auto-scrolling (required)</font>

In [None]:
%%javascript
IPython.OutputArea.prototype._should_scroll = function(lines) {
    return false;
}

# <font color='red'>Download input data (required)</font>

The next cell downloads all input data used in this notebook.

In [None]:
!python /content/course_repository/download_image_data.py

# <font color='red'>Read input data (required)</font>

The next cell reads all input data for this notebook into memory.

In [None]:
training_file_names = image_utils.find_many_files(
    first_date_string='20100101', last_date_string='20141224',
    directory_name=DATA_DIRECTORY_NAME
)

validation_file_names = image_utils.find_many_files(
    first_date_string='20150101', last_date_string='20151231',
    directory_name=DATA_DIRECTORY_NAME
)

testing_file_names = image_utils.find_many_files(
    first_date_string='20160101', last_date_string='20171231',
    directory_name=DATA_DIRECTORY_NAME
)

training_image_dict = image_utils.read_many_files(training_file_names)
print('\n')

validation_image_dict = image_utils.read_many_files(validation_file_names)
print('\n')

testing_image_dict = image_utils.read_many_files(testing_file_names)

# Data exploration

## Plot random example with wind barbs

 - The next cell plots a random example with wind barbs.
 - One example = one storm object = one storm cell at one time.

In [None]:
predictor_matrix = (
    validation_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
)
predictor_names = validation_image_dict[image_utils.PREDICTOR_NAMES_KEY]

num_examples = predictor_matrix.shape[0]
random_index = random.randint(0, num_examples - 1)
predictor_matrix = predictor_matrix[random_index, ...]

temperature_matrix_kelvins = predictor_matrix[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

image_plotting.plot_many_predictors_with_barbs(
    predictor_matrix=predictor_matrix, predictor_names=predictor_names,
    min_colour_temp_kelvins=min_temp_kelvins,
    max_colour_temp_kelvins=max_temp_kelvins
)

pyplot.show()

## Plot strong example with wind barbs

The next cell plots the strongest example in the validation data (that with the greatest max future vorticity), using wind barbs.

In [None]:
target_matrix_s01 = validation_image_dict[image_utils.TARGET_MATRIX_KEY]
example_index = numpy.unravel_index(
    numpy.argmax(target_matrix_s01), target_matrix_s01.shape
)[0]

predictor_matrix = validation_image_dict[image_utils.PREDICTOR_MATRIX_KEY][
    example_index, ...
]
predictor_names = validation_image_dict[image_utils.PREDICTOR_NAMES_KEY]

temperature_matrix_kelvins = predictor_matrix[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

image_plotting.plot_many_predictors_with_barbs(
    predictor_matrix=predictor_matrix, predictor_names=predictor_names,
    min_colour_temp_kelvins=min_temp_kelvins,
    max_colour_temp_kelvins=max_temp_kelvins
)

pyplot.show()

## Plot random example without wind barbs

 - The next cell plots a random example without wind barbs.
 - In this case, the wind field is plotted as two scalar fields ($u$-wind and $v$-wind).
 - This plotting format will be used when interpretation quantities (represented by line contours) are overlain on the predictors.

In [None]:
predictor_matrix = (
    validation_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
)
predictor_names = validation_image_dict[image_utils.PREDICTOR_NAMES_KEY]

num_examples = predictor_matrix.shape[0]
random_index = random.randint(0, num_examples - 1)
predictor_matrix = predictor_matrix[random_index, ...]

temperature_matrix_kelvins = predictor_matrix[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

wind_speed_matrix_m_s01 = numpy.sqrt(
    predictor_matrix[..., predictor_names.index(image_utils.U_WIND_NAME)] ** 2 +
    predictor_matrix[..., predictor_names.index(image_utils.V_WIND_NAME)] ** 2
)
max_speed_m_s01 = numpy.percentile(
    numpy.absolute(wind_speed_matrix_m_s01), 99
)

image_plotting.plot_many_predictors_sans_barbs(
    predictor_matrix=predictor_matrix, predictor_names=predictor_names,
    min_colour_temp_kelvins=min_temp_kelvins,
    max_colour_temp_kelvins=max_temp_kelvins,
    max_colour_wind_speed_m_s01=max_speed_m_s01
)

pyplot.show()

## Plot strong example without wind barbs

The next cell plots the strongest example in the validation data (that with the greatest max future vorticity), using scalar fields instead of wind barbs.

In [None]:
target_matrix_s01 = validation_image_dict[image_utils.TARGET_MATRIX_KEY]
example_index = numpy.unravel_index(
    numpy.argmax(target_matrix_s01), target_matrix_s01.shape
)[0]

predictor_matrix = validation_image_dict[image_utils.PREDICTOR_MATRIX_KEY][
    example_index, ...
]
predictor_names = validation_image_dict[image_utils.PREDICTOR_NAMES_KEY]

temperature_matrix_kelvins = predictor_matrix[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

wind_speed_matrix_m_s01 = numpy.sqrt(
    predictor_matrix[..., predictor_names.index(image_utils.U_WIND_NAME)] ** 2 +
    predictor_matrix[..., predictor_names.index(image_utils.V_WIND_NAME)] ** 2
)
max_speed_m_s01 = numpy.percentile(
    numpy.absolute(wind_speed_matrix_m_s01), 99
)

image_plotting.plot_many_predictors_sans_barbs(
    predictor_matrix=predictor_matrix, predictor_names=predictor_names,
    min_colour_temp_kelvins=min_temp_kelvins,
    max_colour_temp_kelvins=max_temp_kelvins,
    max_colour_wind_speed_m_s01=max_speed_m_s01
)

pyplot.show()

# <font color='red'>Find normalization parameters (required)</font>

- The next cell finds normalization parameters for each predictor variable.
- Recall from Lecture 1 that normalization parameters are based on the training data only.
- For $z$-score normalization, the parameters are mean and standard deviation.
<br><br>

- We still compute one mean and standard deviation for each predictor, rather than one for each predictor and grid cell.
- Thus, the mean and standard deviation for each predictor are based on data from all training examples and all 1024 grid cells.
<br><br>

- We will not normalize the data right away.
- We will plot the data frequently throughout this notebook, and we want to plot variables in physical space rather than normalized space.
- However, we will store the normalization parameters and use them when feeding data into a CNN.
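
The pooling described above can be sketched with plain NumPy. This is only an illustration on synthetic data, not the implementation in `image_normalization`; the array names and the toy values are assumptions made for the example.

```python
import numpy

# Toy training data (hypothetical): 10 examples on a 32 x 32 grid
# (1024 grid cells) with 4 predictors.
rng = numpy.random.default_rng(0)
training_matrix = rng.normal(loc=280., scale=5., size=(10, 32, 32, 4))

# One mean and one standard deviation per predictor (last axis),
# pooled over all examples and all grid cells.
predictor_means = numpy.mean(training_matrix, axis=(0, 1, 2))
predictor_stdevs = numpy.std(training_matrix, axis=(0, 1, 2), ddof=1)

# z-score normalization, followed by the inverse operation.
normalized_matrix = (training_matrix - predictor_means) / predictor_stdevs
denormalized_matrix = normalized_matrix * predictor_stdevs + predictor_means
```

Because the parameters have one value per predictor, NumPy broadcasting applies them along the last axis of the 4-D matrix.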

In [None]:
normalization_dict = image_normalization.get_normalization_params(
    image_dict=training_image_dict
)

## Sanity check

- The next cell normalizes, and then denormalizes, a small amount of testing data.
- Normalized values should be small positive or negative numbers (mostly from $-3\ldots+3$).
- Denormalized values must equal original (physical) values.

In [None]:
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]
original_values = (
    testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY][0, :5, :5, 0]
)

print('\nOriginal values of {0:s} for first storm object:\n{1:s}'.format(
    predictor_names[0], str(original_values)
))

testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY], _ = (
    image_normalization.normalize_data(
        predictor_matrix=testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY],
        predictor_names=predictor_names,
        normalization_dict=normalization_dict
    )
)

normalized_values = (
    testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY][0, :5, :5, 0]
)
print((
    '\nNormalized values of {0:s} for first storm object:\n{1:s}'
).format(
    predictor_names[0], str(normalized_values)
))

testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY] = (
    image_normalization.denormalize_data(
        predictor_matrix=testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY],
        predictor_names=predictor_names,
        normalization_dict=normalization_dict
    )
)

denormalized_values = (
    testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY][0, :5, :5, 0]
)
print((
    '\nDenormalized values of {0:s} for first storm object:\n{1:s}'
).format(
    predictor_names[0], str(denormalized_values)
))

# <font color='red'>Binarization (required)</font>

- The next cell binarizes the target variable (max future vorticity in s$^{-1}$).
- **However, CNNs can also perform regression.**
<br><br>

- We will maximize over the 1024 grid cells to get a scalar target variable again.
<br><br>

- We will not binarize the target variable right away.
- However, we will store the binarization threshold and use it when feeding data into a CNN.
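
As a rough sketch of the two steps above (spatial maximum, then thresholding), the following uses synthetic vorticity fields. The variable names, the toy data, and the use of `>=` at the threshold are assumptions; `image_thresholding` may differ in detail.

```python
import numpy

# Toy target fields (hypothetical): 200 training examples on a 32 x 32 grid.
rng = numpy.random.default_rng(0)
training_target_matrix = rng.gamma(shape=2., scale=0.001, size=(200, 32, 32))

# Step 1: maximize over the grid cells to get one scalar per example.
training_maxima = numpy.max(training_target_matrix, axis=(1, 2))

# Step 2: the threshold is a percentile of the training maxima.
threshold = numpy.percentile(training_maxima, 90.)

# Apply the stored threshold to new data (assumption: ">=" means class 1).
new_target_matrix = rng.gamma(shape=2., scale=0.001, size=(5, 32, 32))
new_classes = (
    numpy.max(new_target_matrix, axis=(1, 2)) >= threshold
).astype(int)
```

With a 90th-percentile threshold, roughly 10% of training examples fall in the positive class.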

In [None]:
binarization_threshold = image_thresholding.get_binarization_threshold(
    image_dict=training_image_dict, percentile_level=90.
)

## Sanity check

- The next cell binarizes the target variable for a small amount of testing data.

In [None]:
target_matrix_s01 = copy.deepcopy(
    testing_image_dict[image_utils.TARGET_MATRIX_KEY]
)
spatial_maxima_s01 = numpy.max(target_matrix_s01, axis=(1, 2))

print((
    '\nSpatial maxima of {0:s} for the first few storm objects:\n{1:s}'
).format(
    testing_image_dict[image_utils.TARGET_NAME_KEY],
    str(spatial_maxima_s01[:20])
))

target_classes = image_thresholding.binarize_target_images(
    target_matrix=testing_image_dict[image_utils.TARGET_MATRIX_KEY],
    binarization_threshold=binarization_threshold
)

print((
    '\nBinarized target values (classes) for the first few storm objects:'
    '\n{0:s}'
).format(
    str(target_classes[:20])
))

# Basic example

## Architecture

The next cell creates a CNN with the following hyperparameters:

 - 4 convolutional blocks
 - 2 convolutional layers per block
 - Dropout rate for conv layers and output layer = 0
 - Dropout rate for non-terminal dense layers = 0.5
 - Activation function for conv layers and non-terminal dense layers = leaky ReLU with slope of 0.2
 - Activation function for output layer = sigmoid (binary classification)
 - No L$_1$ regularization
 - L$_2$ regularization (strength of 0.001) for all convolutional layers
<br><br>

The other hyperparameters, which are all architectural (*i.e.*, layer types and sizes), are shown in the table printed below.
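
For readers who want to see the two activation functions named above, here is a minimal NumPy sketch. The function names are made up for this example; the notebook itself relies on the Keras implementations.

```python
import numpy

def leaky_relu(input_values, slope=0.2):
    """Leaky ReLU: identity for x >= 0, `slope` times x for x < 0."""
    return numpy.where(input_values >= 0, input_values, slope * input_values)

def sigmoid(input_values):
    """Sigmoid: maps any real number to a value between 0 and 1."""
    return 1. / (1. + numpy.exp(-input_values))

example_inputs = numpy.array([-2., -0.5, 0., 0.5, 2.])
hidden_activations = leaky_relu(example_inputs)
output_probabilities = sigmoid(example_inputs)
```

The sigmoid output can be read as a probability, which is why it suits the output layer in binary classification.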

In [None]:
default_model_object = cnn.setup_cnn()

In [None]:
# DEFAULT_INPUT_DIMENSIONS = numpy.array([32, 32, 4], dtype=int)
# DEFAULT_CONV_BLOCK_LAYER_COUNTS = numpy.array([2, 2, 2, 2], dtype=int)
# DEFAULT_CONV_CHANNEL_COUNTS = numpy.array(
#     [32, 32, 64, 64, 128, 128, 256, 256], dtype=int
# )
# DEFAULT_CONV_DROPOUT_RATES = numpy.full(8, 0.)
# DEFAULT_CONV_FILTER_SIZES = numpy.full(8, 3, dtype=int)
# DEFAULT_DENSE_NEURON_COUNTS = numpy.array([776, 147, 28, 5, 1], dtype=int)
# DEFAULT_DENSE_DROPOUT_RATES = numpy.array([0.5, 0.5, 0.5, 0.5, 0])
# DEFAULT_INNER_ACTIV_FUNCTION_NAME = copy.deepcopy(utils.RELU_FUNCTION_NAME)
# DEFAULT_INNER_ACTIV_FUNCTION_ALPHA = 0.2
# DEFAULT_OUTPUT_ACTIV_FUNCTION_NAME = copy.deepcopy(utils.SIGMOID_FUNCTION_NAME)
# DEFAULT_OUTPUT_ACTIV_FUNCTION_ALPHA = 0.
# DEFAULT_L1_WEIGHT = 0.
# DEFAULT_L2_WEIGHT = 0.001

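# Set up a shallower CNN: one conv layer per block (four conv layers in total),
# with batch normalization.  The remaining arguments use the module defaults.
cnn.setup_cnn(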
    input_dimensions=cnn.DEFAULT_INPUT_DIMENSIONS,
    conv_block_layer_counts=numpy.array([1, 1, 1, 1], dtype=int),
    conv_layer_channel_counts=numpy.array([32, 64, 128, 256], dtype=int),
    conv_layer_dropout_rates=numpy.full(4, 0.),
    conv_layer_filter_sizes=numpy.full(4, 3, dtype=int),
    dense_layer_neuron_counts=cnn.DEFAULT_DENSE_NEURON_COUNTS,
    dense_layer_dropout_rates=cnn.DEFAULT_DENSE_DROPOUT_RATES,
    inner_activ_function_name=cnn.DEFAULT_INNER_ACTIV_FUNCTION_NAME,
    inner_activ_function_alpha=cnn.DEFAULT_INNER_ACTIV_FUNCTION_ALPHA,
    output_activ_function_name=cnn.DEFAULT_OUTPUT_ACTIV_FUNCTION_NAME,
    output_activ_function_alpha=cnn.DEFAULT_OUTPUT_ACTIV_FUNCTION_ALPHA,
    l1_weight=cnn.DEFAULT_L1_WEIGHT, l2_weight=cnn.DEFAULT_L2_WEIGHT,
    use_batch_normalization=True
)

## Training

The next cell trains the CNN we just created.

In [None]:
cnn.train_model_sans_generator(
    model_object=default_model_object,
    training_file_names=training_file_names,
    validation_file_names=validation_file_names,
    num_examples_per_batch=1024,
    normalization_dict=normalization_dict,
    binarization_threshold=binarization_threshold,
    num_epochs=100,
    output_dir_name='/content/models/default_cnn'
)

## Evaluation

The next cell evaluates the CNN we just trained.

In [None]:
predictor_names = training_image_dict[image_utils.PREDICTOR_NAMES_KEY]

training_norm_predictor_matrix, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(
        training_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
    ),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

training_classes = image_thresholding.binarize_target_images(
    target_matrix=training_image_dict[image_utils.TARGET_MATRIX_KEY],
    binarization_threshold=binarization_threshold
)

training_probs = cnn.apply_model(
    model_object=default_model_object,
    predictor_matrix=training_norm_predictor_matrix
)

_ = utils.eval_binary_classifn(
    observed_labels=training_classes,
    forecast_probabilities=training_probs,
    training_event_frequency=numpy.mean(training_classes),
    dataset_name='training'
)

validation_norm_predictor_matrix, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(
        validation_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
    ),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

validation_classes = image_thresholding.binarize_target_images(
    target_matrix=validation_image_dict[image_utils.TARGET_MATRIX_KEY],
    binarization_threshold=binarization_threshold
)

validation_probs = cnn.apply_model(
    model_object=default_model_object,
    predictor_matrix=validation_norm_predictor_matrix
)

_ = utils.eval_binary_classifn(
    observed_labels=validation_classes,
    forecast_probabilities=validation_probs,
    training_event_frequency=numpy.mean(training_classes),
    dataset_name='validation'
)

# <font color='red'>Read pre-trained CNN (required)</font>

The next cell reads a pre-trained CNN, to which interpretation methods will be applied.

In [None]:
pretrained_model_file_name = (
    '/content/course_repository/pretrained_cnn/model.h5'
)

pretrained_model_object = utils.read_dense_net(pretrained_model_file_name)
pretrained_model_object.summary()

# Interpretation method 1: Permutation importance test (PIT)

## Theory
 - **The PIT measures the importance of each predictor variable, averaged over all examples in a dataset.**
 - We apply the PIT (and all other interpretation methods) to the testing set.
<br><br>

 - **The "importance" of predictor $x_j$ is determined by how much model performance declines when $x_j$ is permuted.**
 - For scalar predictors, this means randomly shuffling values of $x_j$ over the examples.
 - For spatial predictors, this means randomly shuffling entire spatial maps of $x_j$ over the examples.
 - In other words, **spatial maps are kept intact, but the order of spatial maps is permuted.**
<br><br>

 - **There are four versions of the PIT:**
   - Single-pass forward test
   - Multi-pass forward test
   - Single-pass backwards test
   - Multi-pass backwards test
 - The four versions handle correlated predictors in different ways.
 - The more correlated (interdependent) the predictors are, the more the results of the four versions differ.
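
To make the permutation step concrete, here is a minimal sketch of the single-pass forward test on synthetic data. Everything in it is an assumption for illustration (the toy "model", the predictor names, the rank-based AUC helper); it is not the code behind `utils.run_forward_test`.

```python
import numpy

rng = numpy.random.default_rng(0)
num_examples = 500
predictor_names = ['predictor_a', 'predictor_b', 'predictor_c']

# Toy data (hypothetical): each spatial map is a per-example offset plus noise.
offset_matrix = rng.normal(size=(num_examples, 3))
predictor_matrix = (
    offset_matrix[:, numpy.newaxis, numpy.newaxis, :] +
    0.1 * rng.normal(size=(num_examples, 8, 8, 3))
)
channel_weights = numpy.array([3., 1., 0.])
target_classes = (
    numpy.dot(offset_matrix, channel_weights) +
    rng.normal(size=num_examples) > 0
).astype(int)

def toy_model(matrix):
    """Stand-in for a trained CNN: returns one probability per example."""
    spatial_means = numpy.mean(matrix, axis=(1, 2))
    return 1. / (1. + numpy.exp(-numpy.dot(spatial_means, channel_weights)))

def negative_auc(labels, probabilities):
    """Negative area under ROC curve, from ranks (no tie correction)."""
    ranks = numpy.empty(len(probabilities))
    ranks[numpy.argsort(probabilities)] = numpy.arange(
        1, len(probabilities) + 1
    )
    num_positive = numpy.sum(labels == 1)
    num_negative = len(labels) - num_positive
    auc = (
        numpy.sum(ranks[labels == 1]) - num_positive * (num_positive + 1) / 2
    ) / (num_positive * num_negative)
    return -auc

original_loss = negative_auc(target_classes, toy_model(predictor_matrix))
loss_by_predictor = {}

for j, this_name in enumerate(predictor_names):
    # Shuffle whole spatial maps of predictor j over the examples.
    # Each map stays intact; only the example order changes.
    permuted_matrix = predictor_matrix.copy()
    permuted_matrix[..., j] = (
        predictor_matrix[..., j][rng.permutation(num_examples)]
    )
    loss_by_predictor[this_name] = negative_auc(
        target_classes, toy_model(permuted_matrix)
    )

# Most important predictor = largest loss after permutation.
ranked_names = sorted(
    predictor_names, key=lambda n: loss_by_predictor[n], reverse=True
)
```

In this toy setup the model ignores `predictor_c`, so permuting it leaves the loss unchanged, while permuting `predictor_a` hurts the most.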

## Run forward versions of test

 - The next cell runs both forward versions (single-pass and multi-pass) of the permutation importance test.
 - **The loss function is negative AUC (area under ROC curve).**
 - At each step, the code computes a 95% confidence interval for the loss, using 1000 bootstrap replicates.
 - **In each figure, the most (least) important predictor is at the top (bottom)**.
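
The bootstrapped confidence interval can be sketched as follows. This assumes the simple percentile method and uses synthetic labels and probabilities; the course code may compute the interval differently.

```python
import numpy

def negative_auc(labels, probabilities):
    """Negative area under ROC curve, from ranks (no tie correction)."""
    ranks = numpy.empty(len(probabilities))
    ranks[numpy.argsort(probabilities)] = numpy.arange(
        1, len(probabilities) + 1
    )
    num_positive = numpy.sum(labels == 1)
    num_negative = len(labels) - num_positive
    auc = (
        numpy.sum(ranks[labels == 1]) - num_positive * (num_positive + 1) / 2
    ) / (num_positive * num_negative)
    return -auc

# Toy labels and forecast probabilities (hypothetical).
rng = numpy.random.default_rng(0)
num_examples = 300
labels = rng.integers(0, 2, size=num_examples)
probabilities = 1. / (
    1. + numpy.exp(-(2. * labels - 1. + rng.normal(size=num_examples)))
)

num_replicates = 1000
bootstrapped_losses = numpy.empty(num_replicates)

for k in range(num_replicates):
    # Resample examples with replacement, then recompute the loss.
    these_indices = rng.integers(0, num_examples, size=num_examples)
    bootstrapped_losses[k] = negative_auc(
        labels[these_indices], probabilities[these_indices]
    )

# 95% confidence interval = 2.5th to 97.5th percentile of the replicates.
min_loss, max_loss = numpy.percentile(bootstrapped_losses, [2.5, 97.5])
point_estimate = negative_auc(labels, probabilities)
```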

In [None]:
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

testing_norm_predictor_matrix, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(
        testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
    ),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

testing_classes = image_thresholding.binarize_target_images(
    target_matrix=testing_image_dict[image_utils.TARGET_MATRIX_KEY],
    binarization_threshold=binarization_threshold
)

forward_result_dict = utils.run_forward_test(
    predictor_matrix=testing_norm_predictor_matrix,
    predictor_names=predictor_names,
    target_classes=testing_classes,
    model_object=pretrained_model_object
)

## Plot results

In [None]:
axes_object = permutation_plotting.plot_single_pass_test(
    result_dict=forward_result_dict
)
axes_object.set_title('Single-pass forward test')
axes_object.set_xlabel('Testing AUC')
pyplot.show()
print('\n\n')

axes_object = permutation_plotting.plot_multipass_test(
    result_dict=forward_result_dict
)
axes_object.set_title('Multi-pass forward test')
axes_object.set_xlabel('Testing AUC')
pyplot.show()

## Run backwards versions of test

 - The next cell runs both backwards versions (single-pass and multi-pass) of the permutation importance test.
 - **In each figure, the most (least) important predictor is at the top (bottom)**.

In [None]:
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

testing_norm_predictor_matrix, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(
        testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
    ),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

testing_classes = image_thresholding.binarize_target_images(
    target_matrix=testing_image_dict[image_utils.TARGET_MATRIX_KEY],
    binarization_threshold=binarization_threshold
)

backwards_result_dict = utils.run_backwards_test(
    predictor_matrix=testing_norm_predictor_matrix,
    predictor_names=predictor_names,
    target_classes=testing_classes,
    model_object=pretrained_model_object
)

## Plot results

In [None]:
axes_object = permutation_plotting.plot_single_pass_test(
    result_dict=backwards_result_dict, num_predictors_to_plot=20
)
axes_object.set_title('Single-pass backwards test')
axes_object.set_xlabel('Testing AUC')
pyplot.show()
print('\n\n')

axes_object = permutation_plotting.plot_multipass_test(
    result_dict=backwards_result_dict, num_predictors_to_plot=20
)
axes_object.set_title('Multi-pass backwards test')
axes_object.set_xlabel('Testing AUC')
pyplot.show()

# Extreme cases for pre-trained CNN

There are four types of extreme cases, defined below.  A positive (negative) example is a storm that does (not) develop strong rotation in the future.

 - **Best hits**: the 100 positive examples with highest forecast probability
 - **Worst false alarms**: the 100 negative examples with highest forecast probability
 - **Worst misses**: the 100 positive examples with lowest forecast probability
 - **Best correct nulls**: the 100 negative examples with lowest forecast probability
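
The four definitions above translate into a few lines of NumPy. The helper below is a hypothetical sketch (its name and dictionary keys are invented here) and is not the implementation of `utils.find_extreme_examples`.

```python
import numpy

def find_extremes(observed_labels, forecast_probs, num_per_set):
    """Returns example indices for the four sets of extreme cases."""
    positive_indices = numpy.where(observed_labels == 1)[0]
    negative_indices = numpy.where(observed_labels == 0)[0]

    # Sort each class by forecast probability, lowest first.
    sorted_positives = positive_indices[
        numpy.argsort(forecast_probs[positive_indices])
    ]
    sorted_negatives = negative_indices[
        numpy.argsort(forecast_probs[negative_indices])
    ]

    return {
        'best_hits': sorted_positives[::-1][:num_per_set],
        'worst_misses': sorted_positives[:num_per_set],
        'worst_false_alarms': sorted_negatives[::-1][:num_per_set],
        'best_correct_nulls': sorted_negatives[:num_per_set]
    }

# Tiny example: three positive and three negative storms.
toy_labels = numpy.array([1, 1, 1, 0, 0, 0], dtype=int)
toy_probs = numpy.array([0.9, 0.2, 0.6, 0.8, 0.1, 0.4])
toy_extreme_dict = find_extremes(toy_labels, toy_probs, num_per_set=2)
```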

## <font color='red'>Find extreme cases (required)</font>

The next cell finds extreme cases for the pre-trained CNN.

In [None]:
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

testing_predictor_matrix_denorm = (
    testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
)

testing_predictor_matrix_norm, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(testing_predictor_matrix_denorm),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

testing_classes = image_thresholding.binarize_target_images(
    target_matrix=testing_image_dict[image_utils.TARGET_MATRIX_KEY],
    binarization_threshold=binarization_threshold
)

testing_probs = cnn.apply_model(
    model_object=pretrained_model_object,
    predictor_matrix=testing_predictor_matrix_norm
)

this_dict = utils.find_extreme_examples(
    observed_labels=testing_classes, forecast_probabilities=testing_probs,
    num_examples_per_set=100
)
best_hit_indices = this_dict[utils.HIT_INDICES_KEY]
worst_false_alarm_indices = this_dict[utils.FALSE_ALARM_INDICES_KEY]
worst_miss_indices = this_dict[utils.MISS_INDICES_KEY]
best_correct_null_indices = this_dict[utils.CORRECT_NULL_INDICES_KEY]

extreme_example_dict_denorm = {
    BEST_HIT_MATRIX_KEY:
        testing_predictor_matrix_denorm[best_hit_indices, ...],
    WORST_FALSE_ALARM_MATRIX_KEY:
        testing_predictor_matrix_denorm[worst_false_alarm_indices, ...],
    WORST_MISS_MATRIX_KEY:
        testing_predictor_matrix_denorm[worst_miss_indices, ...],
    BEST_CORRECT_NULLS_MATRIX_KEY:
        testing_predictor_matrix_denorm[best_correct_null_indices, ...],
    PREDICTOR_NAMES_KEY: predictor_names
}

extreme_example_dict_norm = {
    BEST_HIT_MATRIX_KEY:
        testing_predictor_matrix_norm[best_hit_indices, ...],
    WORST_FALSE_ALARM_MATRIX_KEY:
        testing_predictor_matrix_norm[worst_false_alarm_indices, ...],
    WORST_MISS_MATRIX_KEY:
        testing_predictor_matrix_norm[worst_miss_indices, ...],
    BEST_CORRECT_NULLS_MATRIX_KEY:
        testing_predictor_matrix_norm[best_correct_null_indices, ...],
    PREDICTOR_NAMES_KEY: predictor_names
}

this_bh_matrix = utils.run_pmm_many_variables(
    field_matrix=extreme_example_dict_denorm[BEST_HIT_MATRIX_KEY]
)
this_wfa_matrix = utils.run_pmm_many_variables(
    field_matrix=extreme_example_dict_denorm[WORST_FALSE_ALARM_MATRIX_KEY]
)
this_wm_matrix = utils.run_pmm_many_variables(
    field_matrix=extreme_example_dict_denorm[WORST_MISS_MATRIX_KEY]
)
this_bcn_matrix = utils.run_pmm_many_variables(
    field_matrix=extreme_example_dict_denorm[BEST_CORRECT_NULLS_MATRIX_KEY]
)

extreme_example_dict_denorm_pmm = {
    BEST_HIT_MATRIX_KEY: this_bh_matrix,
    WORST_FALSE_ALARM_MATRIX_KEY: this_wfa_matrix,
    WORST_MISS_MATRIX_KEY: this_wm_matrix,
    BEST_CORRECT_NULLS_MATRIX_KEY: this_bcn_matrix,
    PREDICTOR_NAMES_KEY: predictor_names
}

## Plot extreme cases

 - The next cell plots a composite (average storm) for each set of extreme cases.
 - Specifically, we plot the PMM (probability-matched means; Ebert 2001) composite.
 - PMM is similar to taking the arithmetic mean at each grid cell, but it preserves spatial structure better.
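
A minimal sketch of PMM for a single variable is given below. It follows the general idea of the method: the spatial pattern (ranking of grid cells) comes from the arithmetic mean, and the values come from the pooled distribution of all examples. The function name and the choice of which pooled values to keep are assumptions; `utils.run_pmm_many_variables` may differ in detail.

```python
import numpy

def run_pmm_one_variable(field_matrix):
    """Probability-matched mean for one variable.

    field_matrix: numpy array (num_examples x num_rows x num_columns).
    Returns: numpy array (num_rows x num_columns).
    """
    num_examples = field_matrix.shape[0]
    mean_field = numpy.mean(field_matrix, axis=0)

    # Pool values from all examples and grid cells, sorted ascending.
    pooled_values = numpy.sort(field_matrix.ravel())

    # Keep one pooled value per grid cell (the middle of each block of
    # `num_examples` consecutive values).
    kept_values = pooled_values[num_examples // 2::num_examples]

    # Rank of each grid cell in the mean field (0 = smallest mean).
    ranks = numpy.argsort(numpy.argsort(mean_field.ravel()))

    # The grid cell with the k-th smallest mean gets the k-th smallest value.
    return kept_values[ranks].reshape(mean_field.shape)

# Toy example (hypothetical): 100 storms on an 8 x 8 grid.
rng = numpy.random.default_rng(0)
field_matrix = rng.normal(size=(100, 8, 8))
composite_matrix = run_pmm_one_variable(field_matrix)
```

For several variables, the same function could be applied to each channel in turn.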

In [None]:
best_hits_matrix_denorm_pmm = extreme_example_dict_denorm_pmm[
    BEST_HIT_MATRIX_KEY
]
worst_fa_matrix_denorm_pmm = extreme_example_dict_denorm_pmm[
    WORST_FALSE_ALARM_MATRIX_KEY
]
worst_misses_matrix_denorm_pmm = extreme_example_dict_denorm_pmm[
    WORST_MISS_MATRIX_KEY
]
best_nulls_matrix_denorm_pmm = extreme_example_dict_denorm_pmm[
    BEST_CORRECT_NULLS_MATRIX_KEY
]
predictor_names = extreme_example_dict_denorm_pmm[PREDICTOR_NAMES_KEY]

concat_predictor_matrix = numpy.stack((
    best_hits_matrix_denorm_pmm, worst_fa_matrix_denorm_pmm,
    worst_misses_matrix_denorm_pmm, best_nulls_matrix_denorm_pmm,
), axis=0)

temperature_matrix_kelvins = concat_predictor_matrix[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_with_barbs(
        predictor_matrix=best_hits_matrix_denorm_pmm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins)
)

for i in range(axes_object_matrix.shape[0]):
    for j in range(axes_object_matrix.shape[1]):
        axes_object_matrix[i, j].set_title('Best hits')

pyplot.show()
print('\n\n')

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_with_barbs(
        predictor_matrix=worst_fa_matrix_denorm_pmm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins)
)

for i in range(axes_object_matrix.shape[0]):
    for j in range(axes_object_matrix.shape[1]):
        axes_object_matrix[i, j].set_title('Worst false alarms')

pyplot.show()
print('\n\n')

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_with_barbs(
        predictor_matrix=worst_misses_matrix_denorm_pmm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins)
)

for i in range(axes_object_matrix.shape[0]):
    for j in range(axes_object_matrix.shape[1]):
        axes_object_matrix[i, j].set_title('Worst misses')

pyplot.show()
print('\n\n')

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_with_barbs(
        predictor_matrix=best_nulls_matrix_denorm_pmm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins)
)

for i in range(axes_object_matrix.shape[0]):
    for j in range(axes_object_matrix.shape[1]):
        axes_object_matrix[i, j].set_title('Best correct nulls')

pyplot.show()

# Interpretation method 2: Saliency maps

## Theory

Most of the theory is explained in Lecture 4 (dense neural networks), but a brief recap is provided below.
<br><br>

**Saliency (Simonyan *et al.* 2014) is defined as:**

<center>$s = \frac{\partial a}{\partial x} \bigg \rvert_{x = x_0}$</center>

 - $a$ is the activation of a neuron in the model
 - **$x$ is one scalar predictor (in Lecture 5, one variable at one grid cell)**
 - $x_0$ is the value of $x$ in a real example
<br><br>

 - Thus, saliency is the gradient of $a$ with respect to $x$, evaluated at the $x$-value that occurs in the example; it gives a linear approximation of how $a$ responds to small changes in $x$ around that value.
 - **This can be computed for all scalar predictors $x$, resulting in a map.**
<br><br>

 - **Here, as in Lecture 4, we will compute saliency for the output neuron.**
 - Thus, we will compute the following equation for each scalar predictor $x$, where $p$ is the probability of strong future rotation:
<center>$s = \frac{\partial p}{\partial x} \bigg \rvert_{x = x_0}$</center>
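
To illustrate the definition, the sketch below computes a saliency map for a toy model, $p = \sigma(\sum w x)$, in two ways: analytically and by finite differences. The toy model and all names are assumptions for this example; for a real CNN the gradient is obtained by backpropagation, as in `utils.get_saliency_one_neuron`.

```python
import numpy

# Toy "model" (hypothetical): sigmoid of a weighted sum over a 4 x 4 grid
# with 2 predictor variables.
rng = numpy.random.default_rng(0)
weight_matrix = rng.normal(scale=0.1, size=(4, 4, 2))

def toy_model(predictor_matrix):
    """Returns probability p for one example."""
    return 1. / (1. + numpy.exp(-numpy.sum(weight_matrix * predictor_matrix)))

# x_0: predictor values in one example.
x0_matrix = rng.normal(size=(4, 4, 2))
p0 = toy_model(x0_matrix)

# Analytic saliency: dp/dx = p * (1 - p) * w, evaluated at x = x_0.
analytic_saliency = p0 * (1. - p0) * weight_matrix

# Numerical saliency: perturb one scalar predictor at a time.
epsilon = 1e-6
numerical_saliency = numpy.empty_like(x0_matrix)

for this_index in numpy.ndindex(x0_matrix.shape):
    perturbed_matrix = x0_matrix.copy()
    perturbed_matrix[this_index] += epsilon
    numerical_saliency[this_index] = (
        toy_model(perturbed_matrix) - p0
    ) / epsilon
```

The saliency map has the same dimensions as the predictor matrix: one value per variable and grid cell.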

## Saliency for random example

The next cell computes and plots the saliency map for a random example in the testing data.  To interpret the plot:

 - Remember that saliency is $\frac{\partial p}{\partial x}$, where $p$ is probability of strong future rotation and $x$ is a predictor.
 - Solid contours mean positive saliency ($p$ increases when $x$ increases).
 - Dashed contours mean negative saliency ($p$ decreases when $x$ increases).
 - Positive $u$-wind is westerly (towards the east/right).
 - Positive $v$-wind is southerly (towards the north/top).

In [None]:
predictor_matrix_denorm = testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

num_examples = predictor_matrix_denorm.shape[0]
random_index = random.randint(0, num_examples - 1)
predictor_matrix_denorm = predictor_matrix_denorm[random_index, ...]

predictor_matrix_norm, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(predictor_matrix_denorm),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

saliency_matrix = utils.get_saliency_one_neuron(
    model_object=pretrained_model_object,
    predictor_matrix=numpy.expand_dims(predictor_matrix_norm, axis=0),
    layer_name=pretrained_model_object.layers[-1].name,
    neuron_indices=numpy.array([0], dtype=int),
    ideal_activation=1.
)[0, ...]

saliency_matrix = saliency.smooth_saliency_maps(
    saliency_matrices=[saliency_matrix], smoothing_radius_grid_cells=1
)[0]

temperature_matrix_kelvins = predictor_matrix_denorm[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

wind_indices = numpy.array([
    predictor_names.index(image_utils.U_WIND_NAME),
    predictor_names.index(image_utils.V_WIND_NAME)
], dtype=int)

max_speed_m_s01 = numpy.percentile(
    numpy.absolute(predictor_matrix_denorm[..., wind_indices]), 99
)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=predictor_matrix_denorm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

max_saliency = numpy.percentile(numpy.absolute(saliency_matrix), 99)

saliency.plot_saliency_maps(
    saliency_matrix_3d=saliency_matrix,
    axes_object_matrix=axes_object_matrix,
    colour_map_object=pyplot.get_cmap('Greys'),
    max_contour_value=max_saliency,
    contour_interval=max_saliency / 8
)

## Saliency for strong example

The next cell computes and plots the saliency map for the strongest example in the testing data (that with the greatest max future vorticity).

In [None]:
target_matrix_s01 = testing_image_dict[image_utils.TARGET_MATRIX_KEY]
example_index = numpy.unravel_index(
    numpy.argmax(target_matrix_s01), target_matrix_s01.shape
)[0]

predictor_matrix_denorm = (
    testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY][example_index, ...]
)
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

predictor_matrix_norm, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(predictor_matrix_denorm),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

saliency_matrix = utils.get_saliency_one_neuron(
    model_object=pretrained_model_object,
    predictor_matrix=numpy.expand_dims(predictor_matrix_norm, axis=0),
    layer_name=pretrained_model_object.layers[-1].name,
    neuron_indices=numpy.array([0], dtype=int),
    ideal_activation=1.
)[0, ...]

saliency_matrix = saliency.smooth_saliency_maps(
    saliency_matrices=[saliency_matrix], smoothing_radius_grid_cells=1
)[0]

temperature_matrix_kelvins = predictor_matrix_denorm[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

wind_indices = numpy.array([
    predictor_names.index(image_utils.U_WIND_NAME),
    predictor_names.index(image_utils.V_WIND_NAME)
], dtype=int)

max_speed_m_s01 = numpy.percentile(
    numpy.absolute(predictor_matrix_denorm[..., wind_indices]), 99
)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=predictor_matrix_denorm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

max_saliency = numpy.percentile(numpy.absolute(saliency_matrix), 99)

saliency.plot_saliency_maps(
    saliency_matrix_3d=saliency_matrix,
    axes_object_matrix=axes_object_matrix,
    colour_map_object=pyplot.get_cmap('Greys'),
    max_contour_value=max_saliency,
    contour_interval=max_saliency / 8
)

## Saliency for extreme cases

The next two cells compute, then plot, the composite (PMM) saliency map for each set of extreme cases.

In [None]:
best_hits_saliency_matrix = utils.get_saliency_one_neuron(
    model_object=pretrained_model_object,
    predictor_matrix=extreme_example_dict_norm[BEST_HIT_MATRIX_KEY],
    layer_name=pretrained_model_object.layers[-1].name,
    neuron_indices=numpy.array([0], dtype=int),
    ideal_activation=1.
)
best_hits_saliency_matrix = utils.run_pmm_many_variables(
    field_matrix=best_hits_saliency_matrix
)
best_hits_saliency_matrix = saliency.smooth_saliency_maps(
    saliency_matrices=[best_hits_saliency_matrix],
    smoothing_radius_grid_cells=1
)[0]

worst_false_alarms_saliency_matrix = utils.get_saliency_one_neuron(
    model_object=pretrained_model_object,
    predictor_matrix=extreme_example_dict_norm[WORST_FALSE_ALARM_MATRIX_KEY],
    layer_name=pretrained_model_object.layers[-1].name,
    neuron_indices=numpy.array([0], dtype=int),
    ideal_activation=1.
)
worst_false_alarms_saliency_matrix = utils.run_pmm_many_variables(
    field_matrix=worst_false_alarms_saliency_matrix
)
worst_false_alarms_saliency_matrix = saliency.smooth_saliency_maps(
    saliency_matrices=[worst_false_alarms_saliency_matrix],
    smoothing_radius_grid_cells=1
)[0]

worst_misses_saliency_matrix = utils.get_saliency_one_neuron(
    model_object=pretrained_model_object,
    predictor_matrix=extreme_example_dict_norm[WORST_MISS_MATRIX_KEY],
    layer_name=pretrained_model_object.layers[-1].name,
    neuron_indices=numpy.array([0], dtype=int),
    ideal_activation=1.
)
worst_misses_saliency_matrix = utils.run_pmm_many_variables(
    field_matrix=worst_misses_saliency_matrix
)
worst_misses_saliency_matrix = saliency.smooth_saliency_maps(
    saliency_matrices=[worst_misses_saliency_matrix],
    smoothing_radius_grid_cells=1
)[0]

best_correct_nulls_saliency_matrix = utils.get_saliency_one_neuron(
    model_object=pretrained_model_object,
    predictor_matrix=extreme_example_dict_norm[BEST_CORRECT_NULLS_MATRIX_KEY],
    layer_name=pretrained_model_object.layers[-1].name,
    neuron_indices=numpy.array([0], dtype=int),
    ideal_activation=1.
)
best_correct_nulls_saliency_matrix = utils.run_pmm_many_variables(
    field_matrix=best_correct_nulls_saliency_matrix
)
best_correct_nulls_saliency_matrix = saliency.smooth_saliency_maps(
    saliency_matrices=[best_correct_nulls_saliency_matrix],
    smoothing_radius_grid_cells=1
)[0]

In [None]:
predictor_names = extreme_example_dict_denorm_pmm[PREDICTOR_NAMES_KEY]

concat_predictor_matrix = numpy.stack((
    extreme_example_dict_denorm_pmm[BEST_HIT_MATRIX_KEY],
    extreme_example_dict_denorm_pmm[WORST_FALSE_ALARM_MATRIX_KEY],
    extreme_example_dict_denorm_pmm[WORST_MISS_MATRIX_KEY],
    extreme_example_dict_denorm_pmm[BEST_CORRECT_NULLS_MATRIX_KEY]
), axis=0)

temperature_matrix_kelvins = concat_predictor_matrix[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

wind_indices = numpy.array([
    predictor_names.index(image_utils.U_WIND_NAME),
    predictor_names.index(image_utils.V_WIND_NAME)
], dtype=int)

max_speed_m_s01 = numpy.percentile(
    numpy.absolute(concat_predictor_matrix[..., wind_indices]), 99
)

this_max_saliency = numpy.percentile(
    numpy.absolute(best_hits_saliency_matrix), 99
)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=extreme_example_dict_denorm_pmm[BEST_HIT_MATRIX_KEY],
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

saliency.plot_saliency_maps(
    saliency_matrix_3d=best_hits_saliency_matrix,
    axes_object_matrix=axes_object_matrix,
    colour_map_object=pyplot.get_cmap('Greys'),
    max_contour_value=this_max_saliency,
    contour_interval=this_max_saliency / 8
)

figure_object.suptitle(
    'Best hits (max absolute saliency = {0:.2g})'.format(this_max_saliency)
)
pyplot.show()
print('\n\n')

this_max_saliency = numpy.percentile(
    numpy.absolute(worst_false_alarms_saliency_matrix), 99
)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=
        extreme_example_dict_denorm_pmm[WORST_FALSE_ALARM_MATRIX_KEY],
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

saliency.plot_saliency_maps(
    saliency_matrix_3d=worst_false_alarms_saliency_matrix,
    axes_object_matrix=axes_object_matrix,
    colour_map_object=pyplot.get_cmap('Greys'),
    max_contour_value=this_max_saliency,
    contour_interval=this_max_saliency / 8
)

figure_object.suptitle(
    'Worst false alarms (max absolute saliency = {0:.2g})'.format(
        this_max_saliency
    )
)
pyplot.show()
print('\n\n')

this_max_saliency = numpy.percentile(
    numpy.absolute(worst_misses_saliency_matrix), 99
)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=
        extreme_example_dict_denorm_pmm[WORST_MISS_MATRIX_KEY],
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

saliency.plot_saliency_maps(
    saliency_matrix_3d=worst_misses_saliency_matrix,
    axes_object_matrix=axes_object_matrix,
    colour_map_object=pyplot.get_cmap('Greys'),
    max_contour_value=this_max_saliency,
    contour_interval=this_max_saliency / 8
)

figure_object.suptitle(
    'Worst misses (max absolute saliency = {0:.2g})'.format(this_max_saliency)
)
pyplot.show()
print('\n\n')

this_max_saliency = numpy.percentile(
    numpy.absolute(best_correct_nulls_saliency_matrix), 99
)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=
        extreme_example_dict_denorm_pmm[BEST_CORRECT_NULLS_MATRIX_KEY],
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

saliency.plot_saliency_maps(
    saliency_matrix_3d=best_correct_nulls_saliency_matrix,
    axes_object_matrix=axes_object_matrix,
    colour_map_object=pyplot.get_cmap('Greys'),
    max_contour_value=this_max_saliency,
    contour_interval=this_max_saliency / 8
)

figure_object.suptitle(
    'Best correct nulls (max absolute saliency = {0:.2g})'.format(
        this_max_saliency
    )
)
pyplot.show()
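
The composites in this section are probability-matched means (PMM; Ebert 2001). As a rough illustration of the idea, the sketch below builds a PMM composite for one variable: the spatial pattern comes from the ordinary mean, while the values come from the pooled distribution of all examples. The function name `pmm_one_variable_sketch` is hypothetical, and the course's `utils.run_pmm_one_variable` and `utils.run_pmm_many_variables` may differ in detail (for example, in which pooled values they keep).

```python
import numpy


def pmm_one_variable_sketch(field_matrix):
    """Probability-matched mean for one variable (illustrative sketch only).

    :param field_matrix: numpy array (num_examples x num_rows x num_columns).
    :return: pmm_matrix: numpy array (num_rows x num_columns).
    """

    num_examples = field_matrix.shape[0]
    mean_matrix = numpy.mean(field_matrix, axis=0)

    # Pool values from all examples, sort from largest to smallest, and keep
    # every [num_examples]th value, leaving one value per grid point.
    pooled_values = numpy.sort(field_matrix.ravel())[::-1]
    pooled_values = pooled_values[::num_examples]

    # Rank the grid points of the ordinary mean from largest to smallest, then
    # give the largest pooled value to the top-ranked grid point, and so on.
    sort_indices = numpy.argsort(-mean_matrix.ravel())
    pmm_values = numpy.empty(mean_matrix.size)
    pmm_values[sort_indices] = pooled_values

    return pmm_values.reshape(mean_matrix.shape)


# Toy input: two examples on a 2 x 2 grid.
toy_field_matrix = numpy.array([
    [[1., 2.], [3., 4.]],
    [[10., 20.], [30., 40.]]
])
toy_pmm_matrix = pmm_one_variable_sketch(toy_field_matrix)
print(toy_pmm_matrix)
```

Unlike a plain mean, the composite keeps the largest value found in any example, which is why PMM composites tend to look less smoothed. For multi-channel fields such as saliency, one option (an assumption here) is to apply the same procedure to each channel separately.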

# Interpretation method 3: Class-activation maps (CAM)

## Theory (high-level)

 - **Class activation is the amount of evidence for a particular class, defined at each grid point.**
 - In this notebook we consider only the positive class (strong future rotation).
 - However, we could also compute class activation for the negative class.
<br><br>

**Conceptually, there are four key differences between CAMs and saliency maps:**

 1. CAMs yield one value per grid point, whereas saliency maps yield one value per scalar predictor.
 2. CAMs yield only non-negative values (in this context there is no such thing as negative evidence), whereas saliency can be positive or negative.
 3. CAMs highlight the most important values for generating the model's actual prediction, whereas saliency maps highlight the most important values for changing the prediction.
 4. CAMs are specific to a single convolutional layer, while saliency maps are not.
    - CAMs produced by deeper layers are smoother, with large values confined to a smaller region.
    - This is because deeper layers learn higher-level abstractions, which allows them to focus more selectively on important parts of the image.

## Theory (nitty-gritty)

 - The original CAM method (Zhou *et al.* 2016) works only for a specific type of CNN architecture, so we use a generalized version called gradient-weighted CAM (Grad-CAM; Selvaraju *et al.* 2017).
 - The original CAM method requires the CNN to end with a global-average-pooling layer, whereas Grad-CAM allows for a wider variety of architectures, including those ending with dense layers.
<br><br>

Under Grad-CAM, class activation is defined separately for each data example, at each convolutional layer $\mathcal{L}$, at each grid point $(i, j)$ in the feature maps output by $\mathcal{L}$.  Specifically:
<center>$E_{ij, k} = \textrm{max}(\sum\limits_{c = 1}^{C} \alpha_k^c A_{ij}^{c}, 0)$</center>
<center>$\textrm{where }\alpha_k^c = \frac{1}{MN} \sum\limits_{i = 1}^{M} \sum\limits_{j = 1}^{N} \frac{\partial \tilde{p}_k}{\partial A_{ij}^{c}}$</center>
<br><br>

 - $A_{ij}^{c}$ is the value for channel $c$ at grid point $(i, j)$ of the feature map output by layer $\mathcal{L}$.
 - $\tilde{p}_k$ is the pseudo-probability (before the activation function, which here is sigmoid) of class $k$.
 - $M$ and $N$ are the number of spatial rows and columns, respectively, in the feature map output by layer $\mathcal{L}$.
 - $C$ is the number of channels in the feature map output by layer $\mathcal{L}$.
 - $E_{ij, k}$ is the class activation (or evidence) for class $k$ at grid point $(i, j)$.
<br><br>
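
The two equations above can be sketched in NumPy, assuming the feature maps $A$ and the gradients $\partial \tilde{p}_k / \partial A$ have already been extracted from the network for one example. The function name `gradcam_from_gradients` and the toy arrays are hypothetical; this is not the implementation of `class_activation.run_gradcam`.

```python
import numpy


def gradcam_from_gradients(feature_matrix, gradient_matrix):
    """Applies the Grad-CAM equations to precomputed arrays (sketch only).

    :param feature_matrix: M x N x C numpy array of feature maps (A).
    :param gradient_matrix: M x N x C numpy array of gradients of the
        pseudo-probability with respect to A.
    :return: class_activation_matrix: M x N numpy array (E).
    """

    # alpha_k^c: gradient averaged over the M x N grid (one weight per channel).
    channel_weights = numpy.mean(gradient_matrix, axis=(0, 1))

    # Weighted sum over channels at each grid point, then max(..., 0).
    weighted_sum_matrix = numpy.sum(feature_matrix * channel_weights, axis=-1)
    return numpy.maximum(weighted_sum_matrix, 0.)


# Toy input: 2 x 2 grid with two channels.  The first channel has gradient +1
# everywhere and the second has gradient -1 everywhere.
toy_feature_matrix = numpy.stack(
    (numpy.array([[1., 2.], [3., 4.]]), numpy.full((2, 2), 2.)), axis=-1
)
toy_gradient_matrix = numpy.stack(
    (numpy.ones((2, 2)), -numpy.ones((2, 2))), axis=-1
)

toy_cam_matrix = gradcam_from_gradients(
    feature_matrix=toy_feature_matrix, gradient_matrix=toy_gradient_matrix
)
print(toy_cam_matrix)
```

Grid points where the weighted sum is negative are set to zero, which is why CAMs contain only non-negative values.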

 - By applying the above equations to every grid point for a given convolutional layer and data example (iterating over $i$ and $j$), one can create a class-activation *map* (CAM).
 - These maps are usually overlain on the input data (predictors).
 - However, the spatial dimensions of the feature map output by layer $\mathcal{L}$ may not match the spatial dimensions of the input, due to pooling layers in between.
 - Thus, before plotting the CAM, it is common practice to upsample the CAM to the dimensions of the input data.
 - In this notebook we use cubic interpolation to do this upsampling.
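
As a minimal sketch of the upsampling step, the example below uses `scipy.ndimage.zoom` with `order=3` (cubic spline) to bring a coarse CAM up to the input grid. The array sizes are made up, and clipping negative values after interpolation is an assumption of this sketch (cubic splines can overshoot below zero); the course's own Grad-CAM code may handle this differently.

```python
import numpy
from scipy.ndimage import zoom

# Hypothetical coarse CAM (4 x 4) from a deep conv layer, to be plotted on a
# hypothetical 16 x 16 input grid.
coarse_cam_matrix = numpy.array([
    [0., 0., 0., 0.],
    [0., 1., 2., 0.],
    [0., 2., 4., 0.],
    [0., 0., 0., 0.]
])
input_dimensions = (16, 16)

zoom_factors = tuple(
    float(n_out) / n_in
    for n_out, n_in in zip(input_dimensions, coarse_cam_matrix.shape)
)

# order=3 requests cubic-spline interpolation.
upsampled_cam_matrix = zoom(coarse_cam_matrix, zoom=zoom_factors, order=3)

# Assumption: remove small negative values created by spline overshoot, so
# that class activation stays non-negative.
upsampled_cam_matrix = numpy.maximum(upsampled_cam_matrix, 0.)

print(upsampled_cam_matrix.shape)
```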

## CAM for random example

The next cell computes and plots the CAM for a random example in the testing data.  To interpret the plot:

 - Remember that class activation = amount of evidence for the positive class (strong future rotation).
 - Darker contours mean higher class activation.
 - Absence of contours means zero class activation.
 - The same CAM is overlain on all predictor variables.

In [None]:
predictor_matrix_denorm = testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

num_examples = predictor_matrix_denorm.shape[0]
random_index = random.randint(0, num_examples - 1)
predictor_matrix_denorm = predictor_matrix_denorm[random_index, ...]

predictor_matrix_norm, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(predictor_matrix_denorm),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

temperature_matrix_kelvins = predictor_matrix_denorm[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

wind_indices = numpy.array([
    predictor_names.index(image_utils.U_WIND_NAME),
    predictor_names.index(image_utils.V_WIND_NAME)
], dtype=int)

max_speed_m_s01 = numpy.percentile(
    numpy.absolute(predictor_matrix_denorm[..., wind_indices]), 99
)

conv_layer_names = [
    'batch_normalization_1', 'batch_normalization_3',
    'batch_normalization_5', 'batch_normalization_7'
]
conv_layer_indices = numpy.array([2, 4, 6, 8], dtype=int)

for i in range(len(conv_layer_names)):
    class_activation_matrix = class_activation.run_gradcam(
        model_object=pretrained_model_object, input_matrix=predictor_matrix_norm,
        target_class=1, target_layer_name=conv_layer_names[i]
    )

    figure_object, axes_object_matrix = (
        image_plotting.plot_many_predictors_sans_barbs(
            predictor_matrix=predictor_matrix_denorm,
            predictor_names=predictor_names,
            min_colour_temp_kelvins=min_temp_kelvins,
            max_colour_temp_kelvins=max_temp_kelvins,
            max_colour_wind_speed_m_s01=max_speed_m_s01
        )
    )

    max_activation = numpy.percentile(class_activation_matrix, 99)

    class_activation.plot_2d_cam(
        class_activation_matrix_2d=class_activation_matrix,
        axes_object_matrix=axes_object_matrix,
        num_channels=predictor_matrix_norm.shape[-1],
        colour_map_object=pyplot.get_cmap('Greys'),
        min_contour_value=max_activation / 15,
        max_contour_value=max_activation,
        contour_interval=max_activation / 15
    )

    figure_object.suptitle(
        'CAM for conv layer {0:d} of 8 (max class activation = {1:.2g})'.format(
            conv_layer_indices[i], max_activation
        )
    )
    pyplot.show()

    if i != len(conv_layer_names) - 1:
        print('\n\n')

## CAM for strong example

The next cell computes and plots the CAM for the strongest example in the testing data (the one with the greatest maximum future vorticity).

In [None]:
target_matrix_s01 = testing_image_dict[image_utils.TARGET_MATRIX_KEY]
example_index = numpy.unravel_index(
    numpy.argmax(target_matrix_s01), target_matrix_s01.shape
)[0]

predictor_matrix_denorm = (
    testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY][example_index, ...]
)
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

predictor_matrix_norm, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(predictor_matrix_denorm),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

temperature_matrix_kelvins = predictor_matrix_denorm[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

wind_indices = numpy.array([
    predictor_names.index(image_utils.U_WIND_NAME),
    predictor_names.index(image_utils.V_WIND_NAME)
], dtype=int)

max_speed_m_s01 = numpy.percentile(
    numpy.absolute(predictor_matrix_denorm[..., wind_indices]), 99
)

conv_layer_names = [
    'batch_normalization_1', 'batch_normalization_3',
    'batch_normalization_5', 'batch_normalization_7'
]
conv_layer_indices = numpy.array([2, 4, 6, 8], dtype=int)

for i in range(len(conv_layer_names)):
    class_activation_matrix = class_activation.run_gradcam(
        model_object=pretrained_model_object, input_matrix=predictor_matrix_norm,
        target_class=1, target_layer_name=conv_layer_names[i]
    )

    figure_object, axes_object_matrix = (
        image_plotting.plot_many_predictors_sans_barbs(
            predictor_matrix=predictor_matrix_denorm,
            predictor_names=predictor_names,
            min_colour_temp_kelvins=min_temp_kelvins,
            max_colour_temp_kelvins=max_temp_kelvins,
            max_colour_wind_speed_m_s01=max_speed_m_s01
        )
    )

    max_activation = numpy.percentile(class_activation_matrix, 99)

    class_activation.plot_2d_cam(
        class_activation_matrix_2d=class_activation_matrix,
        axes_object_matrix=axes_object_matrix,
        num_channels=predictor_matrix_norm.shape[-1],
        colour_map_object=pyplot.get_cmap('Greys'),
        min_contour_value=max_activation / 15,
        max_contour_value=max_activation,
        contour_interval=max_activation / 15
    )

    figure_object.suptitle(
        'CAM for conv layer {0:d} of 8 (max class activation = {1:.2g})'.format(
            conv_layer_indices[i], max_activation
        )
    )
    pyplot.show()

    if i != len(conv_layer_names) - 1:
        print('\n\n')

## CAM for extreme cases

The next two cells compute, then plot, the composite (PMM) class-activation map for each set of extreme cases.

In [None]:
num_examples_per_set = extreme_example_dict_norm[BEST_HIT_MATRIX_KEY].shape[0]

dimensions = extreme_example_dict_norm[BEST_HIT_MATRIX_KEY].shape[:-1]
best_hits_cam_matrix = numpy.full(dimensions, numpy.nan)
worst_false_alarms_cam_matrix = numpy.full(dimensions, numpy.nan)
worst_misses_cam_matrix = numpy.full(dimensions, numpy.nan)
best_correct_nulls_cam_matrix = numpy.full(dimensions, numpy.nan)

conv_layer_name = 'batch_normalization_3'
conv_layer_index = 4

for i in range(num_examples_per_set):
    print('Have computed CAM for {0:d} of {1:d} extreme examples...'.format(
        4 * i, 4 * num_examples_per_set
    ))

    best_hits_cam_matrix[i, ...] = class_activation.run_gradcam(
        model_object=pretrained_model_object,
        input_matrix=extreme_example_dict_norm[BEST_HIT_MATRIX_KEY][i, ...],
        target_class=1, target_layer_name=conv_layer_name
    )

    worst_false_alarms_cam_matrix[i, ...] = class_activation.run_gradcam(
        model_object=pretrained_model_object,
        input_matrix=
        extreme_example_dict_norm[WORST_FALSE_ALARM_MATRIX_KEY][i, ...],
        target_class=1, target_layer_name=conv_layer_name
    )

    worst_misses_cam_matrix[i, ...] = class_activation.run_gradcam(
        model_object=pretrained_model_object,
        input_matrix=
        extreme_example_dict_norm[WORST_MISS_MATRIX_KEY][i, ...],
        target_class=1, target_layer_name=conv_layer_name
    )

    best_correct_nulls_cam_matrix[i, ...] = class_activation.run_gradcam(
        model_object=pretrained_model_object,
        input_matrix=
        extreme_example_dict_norm[BEST_CORRECT_NULLS_MATRIX_KEY][i, ...],
        target_class=1, target_layer_name=conv_layer_name
    )

print('Have computed CAM for all {0:d} extreme examples!'.format(
    4 * num_examples_per_set
))

best_hits_cam_matrix = utils.run_pmm_one_variable(
    field_matrix=best_hits_cam_matrix
)
worst_false_alarms_cam_matrix = utils.run_pmm_one_variable(
    field_matrix=worst_false_alarms_cam_matrix
)
worst_misses_cam_matrix = utils.run_pmm_one_variable(
    field_matrix=worst_misses_cam_matrix
)
best_correct_nulls_cam_matrix = utils.run_pmm_one_variable(
    field_matrix=best_correct_nulls_cam_matrix
)

In [None]:
predictor_names = extreme_example_dict_denorm_pmm[PREDICTOR_NAMES_KEY]

concat_predictor_matrix = numpy.stack((
    extreme_example_dict_denorm_pmm[BEST_HIT_MATRIX_KEY],
    extreme_example_dict_denorm_pmm[WORST_FALSE_ALARM_MATRIX_KEY],
    extreme_example_dict_denorm_pmm[WORST_MISS_MATRIX_KEY],
    extreme_example_dict_denorm_pmm[BEST_CORRECT_NULLS_MATRIX_KEY]
), axis=0)

temperature_matrix_kelvins = concat_predictor_matrix[
    ..., predictor_names.index(image_utils.TEMPERATURE_NAME)
]
min_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(temperature_matrix_kelvins, 99)

wind_indices = numpy.array([
    predictor_names.index(image_utils.U_WIND_NAME),
    predictor_names.index(image_utils.V_WIND_NAME)
], dtype=int)

max_speed_m_s01 = numpy.percentile(
    numpy.absolute(concat_predictor_matrix[..., wind_indices]), 99
)

this_max_activation = numpy.percentile(best_hits_cam_matrix, 99)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=extreme_example_dict_denorm_pmm[BEST_HIT_MATRIX_KEY],
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

class_activation.plot_2d_cam(
    class_activation_matrix_2d=best_hits_cam_matrix,
    axes_object_matrix=axes_object_matrix,
    num_channels=len(predictor_names),
    colour_map_object=pyplot.get_cmap('Greys'),
    min_contour_value=this_max_activation / 15,
    max_contour_value=this_max_activation,
    contour_interval=this_max_activation / 15
)

figure_object.suptitle(
    'Best hits (max class activation = {0:.2g})'.format(this_max_activation)
)
pyplot.show()
print('\n\n')

this_max_activation = numpy.percentile(worst_false_alarms_cam_matrix, 99)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=
        extreme_example_dict_denorm_pmm[WORST_FALSE_ALARM_MATRIX_KEY],
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

class_activation.plot_2d_cam(
    class_activation_matrix_2d=worst_false_alarms_cam_matrix,
    axes_object_matrix=axes_object_matrix,
    num_channels=len(predictor_names),
    colour_map_object=pyplot.get_cmap('Greys'),
    min_contour_value=this_max_activation / 15,
    max_contour_value=this_max_activation,
    contour_interval=this_max_activation / 15
)

figure_object.suptitle(
    'Worst false alarms (max class activation = {0:.2g})'.format(
        this_max_activation
    )
)
pyplot.show()
print('\n\n')

this_max_activation = numpy.percentile(worst_misses_cam_matrix, 99)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=extreme_example_dict_denorm_pmm[WORST_MISS_MATRIX_KEY],
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

class_activation.plot_2d_cam(
    class_activation_matrix_2d=worst_misses_cam_matrix,
    axes_object_matrix=axes_object_matrix,
    num_channels=len(predictor_names),
    colour_map_object=pyplot.get_cmap('Greys'),
    min_contour_value=this_max_activation / 15,
    max_contour_value=this_max_activation,
    contour_interval=this_max_activation / 15
)

figure_object.suptitle(
    'Worst misses (max class activation = {0:.2g})'.format(this_max_activation)
)
pyplot.show()
print('\n\n')

this_max_activation = numpy.percentile(best_correct_nulls_cam_matrix, 99)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_sans_barbs(
        predictor_matrix=
        extreme_example_dict_denorm_pmm[BEST_CORRECT_NULLS_MATRIX_KEY],
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins,
        max_colour_wind_speed_m_s01=max_speed_m_s01
    )
)

class_activation.plot_2d_cam(
    class_activation_matrix_2d=best_correct_nulls_cam_matrix,
    axes_object_matrix=axes_object_matrix,
    num_channels=len(predictor_names),
    colour_map_object=pyplot.get_cmap('Greys'),
    min_contour_value=this_max_activation / 15,
    max_contour_value=this_max_activation,
    contour_interval=this_max_activation / 15
)

figure_object.suptitle(
    'Best correct nulls (max class activation = {0:.2g})'.format(
        this_max_activation
    )
)
pyplot.show()

# Interpretation method 4: Backwards optimization (BWO)

## Theory (high-level)

 - **BWO (Erhan *et al.* 2009) creates a synthetic input that extremizes (minimizes or maximizes) the activation of a particular neuron in the model.**
 - BWO is sometimes called "activation maximization," "feature optimization," or "optimal input".
<br><br>

 - **The BWO procedure is basically training in reverse.**
 - During training, gradient descent is used to adjust weights in a way that minimizes the loss function.
 - During BWO, gradient descent is used to adjust predictor values in a way that extremizes the neuron activation.
<br><br>

 - **In this notebook we focus on the output neuron, whose activation is the probability of strong future rotation.**
 - For example, if the goal is to maximize probability, BWO creates a prototypical strongly rotating storm (supercell).
 - If the goal is to minimize probability, BWO creates a prototypical weakly rotating storm.

## Theory (nitty-gritty)

 - The BWO procedure involves many iterations.  At each iteration the synthetic example is updated via the rule:
<center>$\mathbf{X} \leftarrow \mathbf{X} - \alpha \frac{\partial J}{\partial \mathbf{X}}$</center>
<br><br>

 - $\mathbf{X}$ is the tensor of predictor values.
 - $J$ is the loss function.
 - $\frac{\partial J}{\partial \mathbf{X}}$ is a gradient tensor with the same dimensions as $\mathbf{X}$.
 - $\alpha$ is the learning rate, usually a positive number $\ll$ 1.
 - Both $\alpha$ and the number of iterations are hyperparameters.
<br><br>

 - In this notebook we set $J = (p - p^*)^2$, where $p$ is the CNN-generated class probability and $p^*$ is the desired probability (0.0 or 1.0).
 - Thus, the above equation can be written as:
<center>$\mathbf{X} \leftarrow \mathbf{X} - 2\alpha (p - p^*) \frac{\partial p}{\partial \mathbf{X}}$</center>
<br><br>

 - Note that $\frac{\partial p}{\partial \mathbf{X}}$ is the saliency map.
 - **Thus, each BWO iteration nudges the predictors along the saliency map: a small multiple of the saliency map is added when $p^* = 1$ and subtracted when $p^* = 0$.**
 - Although saliency is a linear approximation, a new saliency map is created at each iteration of BWO, linearized around the new synthetic storm.
 - **Thus, BWO overcomes the linear limitation of saliency maps.**
<br><br>
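
To make the update rule concrete, here is a minimal sketch of BWO on a toy model. A one-layer logistic model (with made-up weights) stands in for the CNN so that the saliency $\frac{\partial p}{\partial \mathbf{X}}$ has a closed form. All names below are hypothetical, and the sketch omits extras such as the `l2_weight` penalty passed to `backwards_opt.optimize_example_for_class`.

```python
import numpy

# Toy stand-in for the CNN: p = sigmoid(w . x + b), with made-up parameters.
WEIGHT_VECTOR = numpy.array([0.5, -1., 2., 0., 1.5])
BIAS = -1.


def predict_probability(predictor_vector):
    """Returns p for the toy model."""
    logit = numpy.dot(WEIGHT_VECTOR, predictor_vector) + BIAS
    return 1. / (1. + numpy.exp(-logit))


def get_toy_saliency(predictor_vector):
    """Returns dp/dX for the toy model, which is p * (1 - p) * w."""
    probability = predict_probability(predictor_vector)
    return probability * (1. - probability) * WEIGHT_VECTOR


def run_toy_bwo(initial_vector, target_probability, num_iterations,
                learning_rate):
    """Applies X <- X - 2 * alpha * (p - p*) * dp/dX repeatedly."""
    predictor_vector = initial_vector.copy()

    for _ in range(num_iterations):
        probability = predict_probability(predictor_vector)

        # The saliency map is recomputed at every iteration, linearized
        # around the current synthetic example.
        predictor_vector = predictor_vector - (
            2 * learning_rate * (probability - target_probability) *
            get_toy_saliency(predictor_vector)
        )

    return predictor_vector


initial_vector = numpy.zeros(5)
initial_probability = predict_probability(initial_vector)

increased_vector = run_toy_bwo(
    initial_vector=initial_vector, target_probability=1.,
    num_iterations=200, learning_rate=0.1
)
decreased_vector = run_toy_bwo(
    initial_vector=initial_vector, target_probability=0.,
    num_iterations=200, learning_rate=0.1
)

print('Initial probability = {0:.3f}'.format(initial_probability))
print('After BWO with p* = 1: {0:.3f}'.format(
    predict_probability(increased_vector)
))
print('After BWO with p* = 0: {0:.3f}'.format(
    predict_probability(decreased_vector)
))
```

In this toy model the predictor with zero weight is never changed, because its saliency is zero at every iteration.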

 - **Gradient descent** adjusts but does not initialize values, so it **requires a starting point or "initial seed".**
 - Some options are all-zeros, random noise, or a real data example.
 - The advantage of all-zeros and random noise is that the initial seed does not look like a real example, so the synthetic example ultimately produced is more novel (different from the initial seed).
 - The disadvantage is that, because the initial seed is unrealistic, the synthetic example is often unrealistic as well.
 - **This problem is alleviated by using a real data example, which we do in this notebook.**
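
The three kinds of initial seed can be sketched as follows. The grid dimensions (32 x 32 with 4 channels) and the noise scale are assumptions made for illustration, and the "real" example here is a random placeholder rather than an actual normalized storm image.

```python
import numpy

random_generator = numpy.random.default_rng(42)

# Placeholder for one normalized real example (hypothetical dimensions).
real_example_matrix = random_generator.normal(size=(32, 32, 4))

# Option 1: all zeros.
zero_seed_matrix = numpy.zeros_like(real_example_matrix)

# Option 2: random noise (the scale of 0.1 is an arbitrary choice).
noise_seed_matrix = random_generator.normal(
    loc=0., scale=0.1, size=real_example_matrix.shape
)

# Option 3: a real data example.  Copy it so that later updates do not modify
# the original array.
real_seed_matrix = real_example_matrix.copy()

# Each seed needs a leading example axis before being passed to a model.
batched_seed_matrix = numpy.expand_dims(real_seed_matrix, axis=0)
print(batched_seed_matrix.shape)
```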

## BWO for random example

 - The next cell runs and plots BWO for a random example in the testing data.
 - **The goal is to increase strong-rotation probability (create a supercell).**

In [None]:
orig_predictor_matrix_denorm = (
    testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY]
)
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

num_examples = orig_predictor_matrix_denorm.shape[0]
random_index = random.randint(0, num_examples - 1)
orig_predictor_matrix_denorm = orig_predictor_matrix_denorm[random_index, ...]

orig_predictor_matrix_norm, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(orig_predictor_matrix_denorm),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

new_predictor_matrix_norm = backwards_opt.optimize_example_for_class(
    model_object=pretrained_model_object,
    input_matrix=numpy.expand_dims(orig_predictor_matrix_norm, axis=0),
    target_class=1, num_iterations=1200, learning_rate=2e-4,
    l2_weight=2e-5
)[0][0, ...]

new_predictor_matrix_denorm = image_normalization.denormalize_data(
    predictor_matrix=new_predictor_matrix_norm,
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

temperature_index = predictor_names.index(image_utils.TEMPERATURE_NAME)
combined_temp_matrix_kelvins = numpy.concatenate((
    orig_predictor_matrix_denorm[..., temperature_index],
    new_predictor_matrix_denorm[..., temperature_index]
), axis=0)

min_temp_kelvins = numpy.percentile(combined_temp_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(combined_temp_matrix_kelvins, 99)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_with_barbs(
        predictor_matrix=orig_predictor_matrix_denorm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins
    )
)

for i in range(axes_object_matrix.shape[0]):
    for j in range(axes_object_matrix.shape[1]):
        axes_object_matrix[i, j].set_title(
            'Real example\n(before optimization)'
        )

pyplot.show()
print('\n\n')

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_with_barbs(
        predictor_matrix=new_predictor_matrix_denorm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins
    )
)

for i in range(axes_object_matrix.shape[0]):
    for j in range(axes_object_matrix.shape[1]):
        axes_object_matrix[i, j].set_title(
            'Synthetic example\n(after optimization)'
        )

pyplot.show()

## BWO for strong example

 - The next cell runs and plots BWO for the strongest example in the testing data (the one with the greatest maximum future vorticity).
 - **The goal is to decrease strong-rotation probability (create a non-supercell).**

In [None]:
target_matrix_s01 = testing_image_dict[image_utils.TARGET_MATRIX_KEY]
example_index = numpy.unravel_index(
    numpy.argmax(target_matrix_s01), target_matrix_s01.shape
)[0]

orig_predictor_matrix_denorm = (
    testing_image_dict[image_utils.PREDICTOR_MATRIX_KEY][example_index, ...]
)
predictor_names = testing_image_dict[image_utils.PREDICTOR_NAMES_KEY]

orig_predictor_matrix_norm, _ = image_normalization.normalize_data(
    predictor_matrix=copy.deepcopy(orig_predictor_matrix_denorm),
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

new_predictor_matrix_norm = backwards_opt.optimize_example_for_class(
    model_object=pretrained_model_object,
    input_matrix=numpy.expand_dims(orig_predictor_matrix_norm, axis=0),
    target_class=0, num_iterations=1200, learning_rate=1e-3,
    l2_weight=1e-5
)[0][0, ...]

new_predictor_matrix_denorm = image_normalization.denormalize_data(
    predictor_matrix=new_predictor_matrix_norm,
    predictor_names=predictor_names, normalization_dict=normalization_dict
)

temperature_index = predictor_names.index(image_utils.TEMPERATURE_NAME)
combined_temp_matrix_kelvins = numpy.concatenate((
    orig_predictor_matrix_denorm[..., temperature_index],
    new_predictor_matrix_denorm[..., temperature_index]
), axis=0)

min_temp_kelvins = numpy.percentile(combined_temp_matrix_kelvins, 1)
max_temp_kelvins = numpy.percentile(combined_temp_matrix_kelvins, 99)

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_with_barbs(
        predictor_matrix=orig_predictor_matrix_denorm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins
    )
)

for i in range(axes_object_matrix.shape[0]):
    for j in range(axes_object_matrix.shape[1]):
        axes_object_matrix[i, j].set_title(
            'Real example\n(before optimization)'
        )

pyplot.show()
print('\n\n')

figure_object, axes_object_matrix = (
    image_plotting.plot_many_predictors_with_barbs(
        predictor_matrix=new_predictor_matrix_denorm,
        predictor_names=predictor_names,
        min_colour_temp_kelvins=min_temp_kelvins,
        max_colour_temp_kelvins=max_temp_kelvins
    )
)

for i in range(axes_object_matrix.shape[0]):
    for j in range(axes_object_matrix.shape[1]):
        axes_object_matrix[i, j].set_title(
            'Synthetic example\n(after optimization)'
        )

pyplot.show()
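
The example-selection step at the top of the cell above can be checked on a toy array. This is a minimal sketch, assuming the target matrix has dimensions (example, row, column); `toy_target_matrix` and its values are made up for illustration.

```python
import numpy

# Hypothetical target matrix: 3 examples, each a 2 x 2 grid.
toy_target_matrix = numpy.array([
    [[1., 2.], [3., 4.]],
    [[5., 9.], [0., 1.]],
    [[2., 2.], [8., 7.]]
])

# argmax returns a position in the flattened array; unravel_index converts
# it back to (example, row, column), and [0] keeps only the example index.
flat_index = numpy.argmax(toy_target_matrix)
example_index = numpy.unravel_index(flat_index, toy_target_matrix.shape)[0]

# Equivalent approach: take the max over the grid first, then argmax over
# examples.
same_index = numpy.argmax(numpy.max(toy_target_matrix, axis=(1, 2)))

print(example_index, same_index)
```

Both approaches pick the example containing the single largest grid value, which is what "strongest example" means in this section.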

# References

This notebook refers to the publications listed below. Schwartz *et al.* (2015) document the dataset used.

Ebert, E., 2001: "Ability of a poor man’s ensemble to predict the probability and distribution of precipitation." *Monthly Weather Review*, **129 (10)**, 2461-2480, https://journals.ametsoc.org/doi/full/10.1175/1520-0493%282001%29129%3C2461%3AAOAPMS%3E2.0.CO%3B2.

Erhan, D., Y. Bengio, A. Courville, and P. Vincent, 2009: "Visualizing higher-layer features of a deep network." Technical report, University of Montréal, https://www.researchgate.net/profile/Aaron_Courville/publication/265022827_Visualizing_Higher-Layer_Features_of_a_Deep_Network/links/53ff82b00cf24c81027da530.pdf.

Schwartz, C., G. Romine, M. Weisman, R. Sobash, K. Fossell, K. Manning, and S. Trier, 2015: "A real-time convection-allowing ensemble prediction system initialized by mesoscale ensemble Kalman filter analyses." *Weather and Forecasting*, **30 (5)**, 1158-1181, https://doi.org/10.1175/WAF-D-15-0013.1.

Selvaraju, R., M. Cogswell, A. Das, R. Vedantam, D. Parikh, and D. Batra, 2017: "Grad-CAM: Visual explanations from deep networks via gradient-based localization." *International Conference on Computer Vision*, Venice, Italy, IEEE, http://openaccess.thecvf.com/content_ICCV_2017/papers/Selvaraju_Grad-CAM_Visual_Explanations_ICCV_2017_paper.pdf.

Wagstaff, K., and J. Lee, 2018: "Interpretable discovery in large image data sets." *arXiv e-prints*, **1806**, https://arxiv.org/abs/1806.08340.

Zhou, B., A. Khosla, A. Lapedriza, A. Oliva, and A. Torralba, 2016: "Learning deep features for discriminative localization." *Conference on Computer Vision and Pattern Recognition*, Las Vegas, Nevada, IEEE, https://www.cv-foundation.org/openaccess/content_cvpr_2016/papers/Zhou_Learning_Deep_Features_CVPR_2016_paper.pdf.