In [2]:
import matplotlib.pyplot as plt
import numpy as np
import os 

path = "saved_model/character_trajectories/char_l3_p3_interpretable"
# allow_pickle=True is presumably needed because the prototypes are stored as
# an object array (they may differ in length) — TODO confirm. Unpickling can
# execute arbitrary code, so only load prototype files from a trusted source.
prototypes_file = os.path.join(path, 'prototypes.npy')
prototypes = np.load(prototypes_file, allow_pickle=True)

In [3]:
def compare_all_plots(path, prototypes):
    """Produce one comparison figure per prototype.

    Each prototype index in ``0 .. prototypes.shape[0] - 1`` is passed in turn
    to ``compare_plots`` as the base that all other prototypes are compared to.

    Parameters
    ----------
    path : str or None
        Forwarded unchanged to ``compare_plots``.
    prototypes : np.ndarray
        Forwarded unchanged to ``compare_plots``.
    """
    n_prototypes = prototypes.shape[0]
    for base_idx in range(n_prototypes):
        compare_plots(path, prototypes, base_idx)

def trim_to_length(series, l):
    """Return `series` cut or zero-padded along axis 0 to exactly `l` entries.

    Parameters
    ----------
    series : np.ndarray
        Array whose first axis is trimmed or padded; any trailing axes are
        preserved (works for 1-D and N-D arrays).
    l : int
        Target length of axis 0. Must be non-negative.

    Returns
    -------
    np.ndarray
        `series` itself if it already has length `l`, its first `l` entries if
        it is longer, or a new array with trailing zero entries appended if it
        is shorter.

    Raises
    ------
    ValueError
        If `l` is negative.
    """
    if l < 0:
        raise ValueError('target length must be non-negative, got %d' % l)
    n = series.shape[0]
    if n == l:
        return series
    if n > l:
        return series[:l]
    # `l - n` rows are missing (computing `n - l` here would be negative and
    # make np.zeros raise). shape[1:] keeps any trailing axes, including none.
    padding = np.zeros((l - n,) + series.shape[1:])
    return np.concatenate([series, padding], axis=0)
        
def compare_plots(path, prototypes, base):
    """Plot prototype `base` against every other prototype.

    The figure has one row per other prototype. Column 0 shows prototype
    `base`, column 1 the other prototype, column 2 their difference, and
    column 3 the difference between `base` and the mean of all other
    prototypes (the same curve in every row). Every prototype is min-max
    scaled to [0, 1] over all of its values before comparison. Where lengths
    differ, series are compared over their common prefix.

    Parameters
    ----------
    path : str or None
        If not None, the figure is saved to
        ``<path>/prototypes_compare/prototype_<base>.png`` (the directory is
        created if needed) and closed. If None, it is shown with ``plt.show()``.
    prototypes : sequence of np.ndarray
        Integer-indexable collection of prototypes. Each is assumed to be a
        (time, channels) array — TODO confirm against the saved model.
    base : int
        Index of the prototype to compare against all others.

    Raises
    ------
    ValueError
        If fewer than two prototypes are given (nothing to compare against).
    """
    def _minmax(series):
        # Scale to [0, 1] using the global min/max. A constant series has zero
        # range; return zeros instead of NaNs from a 0/0 division.
        series = np.asarray(series, dtype=float)
        lo, hi = series.min(), series.max()
        if hi == lo:
            return np.zeros_like(series)
        return (series - lo) / (hi - lo)

    n_prototypes = len(prototypes)
    if n_prototypes < 2:
        raise ValueError(
            'compare_plots needs at least 2 prototypes, got %d' % n_prototypes)

    compare = np.delete(np.arange(n_prototypes), base)
    n_rows = len(compare)
    # squeeze=False keeps `axes` 2-D even with a single row (two prototypes);
    # otherwise axes[i][0] would index into an Axes object and fail.
    fig, axes = plt.subplots(n_rows, 4, figsize=(10, 2 * n_rows), squeeze=False)

    norm_base = _minmax(prototypes[base])
    norm_sum = None  # running sum of the other normalised prototypes
    for i, comp in enumerate(compare):
        norm_compare = _minmax(prototypes[comp])

        # Prototypes may differ in length: compare over the shared prefix.
        len_bc = min(norm_base.shape[0], norm_compare.shape[0])
        diff = trim_to_length(norm_base, len_bc) - trim_to_length(norm_compare, len_bc)

        if norm_sum is None:
            norm_sum = norm_compare
        else:
            len_sc = min(norm_sum.shape[0], norm_compare.shape[0])
            norm_sum = trim_to_length(norm_sum, len_sc) + trim_to_length(norm_compare, len_sc)

        axes[i][0].plot(norm_base)
        axes[i][1].plot(norm_compare)
        axes[i][2].plot(diff)

        axes[i][0].set_title('Prototype: %s' % (base))
        axes[i][1].set_title('Prototype: %s' % (comp))
        axes[i][2].set_title('Diff to sample')

    # Difference between base and the mean of the other prototypes.
    len_bs = min(norm_base.shape[0], norm_sum.shape[0])
    diff_mean = trim_to_length(norm_base, len_bs) - trim_to_length(norm_sum, len_bs) / len(compare)
    # The same base-vs-mean curve is repeated in the last column of every row.
    for i in range(n_rows):
        axes[i][3].set_title('Diff to all')
        axes[i][3].plot(diff_mean)

    fig.subplots_adjust(wspace=0.4, hspace=0.5)
    if path is not None:
        img_path = os.path.join(path, 'prototypes_compare')
        os.makedirs(img_path, exist_ok=True)
        fig.savefig(os.path.join(img_path, 'prototype_' + str(base) + '.png'),
                    dpi=90, bbox_inches='tight', pad_inches=0.1)
        plt.close(fig)
    else:
        plt.show()

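Compare each prototype with every other prototype to understand how they differ. In each figure, every row shows the base prototype, one other prototype, and their difference. The last column shows the difference between the base and the mean of the remaining prototypes. All series are min-max normalised first.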

In [35]:
# One figure per prototype. Because `path` is set, the figures are written to
# <path>/prototypes_compare/ instead of being displayed inline.
compare_all_plots(path=path, prototypes=prototypes)