# Hierarchical Prototypes for Deep Learning
In this notebook, we briefly show the results listed in our paper. 

## Importing and getting everything ready

In [None]:
import os
import sys

# Put the local `src` directory on the import path (a relative entry, so it is
# resolved against the working directory the notebook runs in).
sys.path.append('src')

In [None]:
import numpy as np
import warnings
import time
import torch 
import argparse
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from src.preprocessing import batch_elastic_transform
from src.train import train_MNIST, load_and_test

In [None]:
# Silence all warnings (they are noisy when the saved models are loaded)
warnings.filterwarnings('ignore')
# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
# Locations of the fully trained model files
hierarchy_model_path = './Pre-trainedModels/hierarchy_model.pth'
standard_model_path = './Pre-trainedModels/standard_model.pth'

# torch.load unpickles the saved objects, so only use model files from a trusted source.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True, which
# may refuse files that store a whole model object - confirm against the installed version.
# map_location places the loaded tensors on `device`.
hierarchy_model, standard_model = (
    torch.load(model_path, map_location=device)
    for model_path in (hierarchy_model_path, standard_model_path)
)

# Switch both models to evaluation mode
hierarchy_model.eval()
standard_model.eval()

### Define function for visualization

In [None]:
def show_prototypes(prototype_set, **kwargs):
    """
    input: numpy set of prototypes; every element is indexed with [0] to get
           the image that is drawn (presumably shape (N, 1, H, W) - confirm)
    **kwargs: accepted for backward compatibility, currently not used
    displays all the prototypes in the input in a grid with 5 columns

    The grid gets enough rows for every prototype, also when their number is
    5 or fewer or not a multiple of 5; cells without a prototype are hidden.
    Nothing is drawn for an empty input.
    """
    n_cols = 5
    n_prototypes = len(prototype_set)
    if n_prototypes == 0:
        # plt.subplots cannot build a grid with zero rows
        return
    # Ceiling division so that a partially filled last row is still created
    n_rows = (n_prototypes + n_cols - 1) // n_cols
    # squeeze=False keeps `ax` two-dimensional even when there is a single row
    fig, ax = plt.subplots(n_rows, n_cols, squeeze=False)
    for i, img in enumerate(prototype_set):
        ax[i // n_cols, i % n_cols].imshow(img[0], cmap='gray')
    # Hide the unused cells of the last row
    for j in range(n_prototypes, n_rows * n_cols):
        ax[j // n_cols, j % n_cols].axis('off')

# Nonhierarchical model
This section reproduces the original authors' (non-hierarchical) model.

### Prototype results

In [None]:
# Take the learned prototypes of the standard model and decode them into images.
# The decoder is fed the prototypes reshaped to (-1, 10, 2, 2).
standard_prototypes = standard_model.prototype.get_prototypes()
standard_prototype_img = standard_model.decoder(
    standard_prototypes.view(-1, 10, 2, 2)
)

In [None]:
# Feed the decoded prototype images back through the standard model and use the
# index of the largest value in `c` as the numerical label of each prototype
_, decoding, (r_1, r_2, c) = standard_model(standard_prototype_img)
pred = torch.argmax(c, dim=1)
# Pair every image with its label and order the pairs by ascending label
merged = zip(pred, standard_prototype_img.detach().cpu().numpy())
mergedlist = sorted(merged, key=lambda pair: pair[0])
imgs = np.array([image for _label, image in mergedlist])
# Display the sorted prototypes
show_prototypes(imgs)

### Weight matrix
Below, the prototype images are shown, followed by the argmin of each row of the weight matrix (the maximum activation) and the actual weights.

In [None]:
# Prototypes in their unsorted order (presumably the row order of the weight matrix - confirm)
show_prototypes(standard_prototype_img.detach().cpu().numpy())

In [None]:
# Weight matrix of the standard model's linear1 layer, transposed
# (presumably one row per prototype and one column per class - confirm)
learned_weights = standard_model.prototype.linear1.weight.detach().cpu().numpy().T
print("Maximum weights per column correspond to qualitative prototypes!")
# Index of the smallest weight in every row, printed 5 per line like the prototype grid.
# reshape(-1, 5) infers the number of lines instead of hard-coding 3 (i.e. 15 prototypes).
# NOTE(review): the label above says "Maximum ... per column" while the code takes the
# argmin along each row - confirm the intended wording.
print(learned_weights.argmin(axis=1).reshape(-1, 5))

print("---")
print("Actual weights")
print(np.array_str(learned_weights, precision=3, suppress_small=True))


### Test model

In [None]:
# Load the saved standard (non-hierarchical) model from disk and test it
load_and_test(standard_model_path, hierarchical=False)

# Hierarchical model

### Prototype results

In [None]:
# Fetch the prototypes and the subprototypes of the hierarchical model
prototypes, sub_prototypes = (
    hierarchy_model.prototype.get_prototypes(),
    hierarchy_model.prototype.get_sub_prototypes(),
)

In [None]:
# Decode both prototype sets into images; the decoder receives them reshaped to (-1, 10, 2, 2)
prototype_img, sub_prototype_img = (
    hierarchy_model.decoder(latent.view(-1, 10, 2, 2))
    for latent in (prototypes, sub_prototypes)
)

In [None]:
# Display the decoded prototypes of the hierarchical model
show_prototypes(prototype_img.detach().cpu().numpy())

In [None]:
# Feed the decoded subprototype images back through the hierarchical model and use
# the index of the largest value in `sub_c` as the numerical label of each subprototype
_, decoding, (sub_c, sup_c, r1, r2, r3, r4) = hierarchy_model(sub_prototype_img)
pred = torch.argmax(sub_c, dim=1)
# Pair every image with its label and order the pairs by ascending label
merged = zip(pred, sub_prototype_img.detach().cpu().numpy())
mergedlist = sorted(merged, key=lambda pair: pair[0])
imgs = np.array([image for _label, image in mergedlist])
# Display the sorted subprototypes
show_prototypes(imgs)

### Weight matrix for subprototypes
This is the weight matrix that is also shown in the appendix of the paper.

In [None]:
print("These prototypes correspond to the weight matrix below!")
# Subprototypes in their unsorted order
show_prototypes(sub_prototype_img.detach().cpu().numpy())

In [None]:
# Weight matrix of the hierarchical model's linear2 layer, transposed
# (presumably one row per subprototype and one column per class - confirm)
learned_weights = hierarchy_model.prototype.linear2.weight.detach().cpu().numpy().T
print("Maximum weights per column (roughly) correspond to qualitative prototypes")
# Index of the smallest weight in every row, printed 5 per line like the prototype grid.
# reshape(-1, 5) infers the number of lines instead of hard-coding 4 (i.e. 20 subprototypes).
# NOTE(review): the label above says "Maximum ... per column" while the code takes the
# argmin along each row - confirm the intended wording.
print(learned_weights.argmin(axis=1).reshape(-1, 5))

print("---")
print("Actual weights")
print(np.array_str(learned_weights, precision=3, suppress_small=True))


### Test model

In [None]:
# Load the saved hierarchical model from disk and test it
load_and_test(hierarchy_model_path, hierarchical=True)

# Training example
This code trains the hierarchical prototype network with the default parameters used in our paper. Prototype image results are saved in a separate directory (`result_directory` below).

In [None]:
# --- Reproducibility and device ---
seed = 42
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# --- Model ---
hierarchical = True
n_prototypes = 10
n_sub_prototypes = 20
latent_size = 40
n_classes = 10

# Weights of the individual terms of the loss
lambda_dict = dict(
    lambda_class=20,
    lambda_class_sup=20,
    lambda_class_sub=20,
    lambda_ae=1,
    lambda_r1=1,
    lambda_r2=1,
    lambda_r3=1,
    lambda_r4=1,
)

# --- Training ---
learning_rate = 1e-4
training_epochs = 1500
batch_size = 250
# How often images and models are saved
save_every = 1

# --- Data ---
# Parameters of the elastic deformation
sigma = 4
alpha = 20
# Class to downsample; -1 means no class is downsampled
underrepresented_class = -1

# --- Output ---
result_directory = './notebook_results'

In [None]:
"""
Args:
    Input:
      Model parameters
        hierarchical : Boolean: Is the model hierarchical?
        n_prototypes : The amount of prototypes. When hierarchical is set to true, this is the amount of superprototypes.
        n_sub_prototypes : The amount of subprototypes. Will be ignored if hierarchical is set to false.
        latent_size : Size of the latent space
        n_classes : Amount of classes 
        lambda_dict : Dictionary containing all necessary lambda's for the weighted loss equation
      Training parameters
        learning_rate : 
        training_epochs : 
        batch_size : 
        save_every : how often to save images and models?
      Miscellaneous
        sigma, alpha : Parameters for elastic deformation. Only used for train data
        directory : Directory to save results, prototype images and final model.
        underrepresented  : The class that is to be downsampled (0.25 to 1 for all other classes)
                    When it is set to -1, no class is downsampled.
"""
train_MNIST(
    hierarchical, 
    n_prototypes, 
    n_sub_prototypes, 
    latent_size, 
    n_classes, 
    lambda_dict, 
    learning_rate, 
    training_epochs, 
    batch_size, 
    save_every, 
    sigma, 
    alpha, 
    seed, 
    result_directory,
    underrepresented_class)