# Visualize warping over training

In [1]:
import pickle
import os
import imageio
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize

#### Load results

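MLP filenames (run either this cell or the RNN cell below; both set the same variables, so whichever runs last determines the model that is visualized)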

In [24]:
# Result files for the MLP model (alternative to the RNN cell below).
model_name = 'mlp'
vis_results_fn, acc_results_fn, ratio_results_fn = (
    '../../results/get_vis_parameters_results_mlp.P',
    '../../results/analyze_accs_results_mlp.P',
    '../../results/calc_ratio_results_mlp.P',
)

RNN filenames (alternative to the MLP cell above; run only one of the two)

In [None]:
# Result files for the RNN model (alternative to the MLP cell above).
model_name = 'rnn'
vis_results_fn, acc_results_fn, ratio_results_fn = (
    '../../results/get_vis_parameters_ctxF_results_rnn.P',
    '../../results/analyze_accs_ctxF_results_rnn.P',
    '../../results/calc_ratio_ctxF_results_rnn.P',
)

Open the result files and load the data

In [25]:
# Load the three pickled result objects, in the same order as before.
# NOTE(review): pickle.load can execute arbitrary code -- only use this on
# trusted, locally generated result files.
loaded_results = []
for results_fn in (vis_results_fn, acc_results_fn, ratio_results_fn):
    with open(results_fn, 'rb') as f:
        loaded_results.append(pickle.load(f))
vis_results, acc_results, ratio_results = loaded_results

In [26]:
run_id = 0  # index of the training run to visualize
params = vis_results['get_vis_parameters'][run_id]  # one parameter dict per step
# Static layout information (state count, labels, group membership) is read
# from the first step's entry.
n_states = params[0]['n_states']
locs = params[0]['locs']
idx2loc = params[0]['idx2loc']
G_idxs = params[0]['G_idxs']
H_idxs = params[0]['H_idxs']
n_steps = len(params)

# Per-step accuracy records for this run. Bound to a new name rather than
# overwriting `acc_results`, so re-running this cell does not fail on an
# already-indexed object.
run_accs = acc_results['analyze_accs'][run_id]
warping = [r['ratio_hidd'] for r in ratio_results['calc_ratio'][run_id]]

# Every frame t indexes accs[t] and warping[t]; fail early with a clear
# message instead of silently zero-filling accuracies or crashing mid-render.
if len(run_accs) != n_steps or len(warping) < n_steps:
    raise ValueError(
        'Step count mismatch: {} parameter steps, {} accuracy steps, '
        '{} warping steps'.format(n_steps, len(run_accs), len(warping)))

# Column 0: congruent accuracy, column 1: incongruent accuracy.
accs = np.array([[r['cong_train_acc'], r['incong_train_acc']]
                 for r in run_accs], dtype=float)

# Congruent / incongruent accuracy ratio. An incongruent accuracy of 0 gives
# inf (nan for 0/0); the corresponding numpy warnings are silenced here.
with np.errstate(divide='ignore', invalid='ignore'):
    acc_ratio = list(accs[:, 0] / accs[:, 1])

#### Use the parameters to reconstruct the grid

In [27]:
# Map each state index (0..n_states-1) to the index of the G group and the
# H group that contain it. A state listed in several groups keeps the last
# one; a state in no group gets no entry.
idx2g = {idx: g
         for idx in range(n_states)
         for g, group in enumerate(G_idxs)
         if idx in group}

idx2h = {idx: h
         for idx in range(n_states)
         for h, group in enumerate(H_idxs)
         if idx in group}

In [28]:
def generate_grid(alpha, beta, g_map=None, h_map=None):
    """Reconstruct 2-D state coordinates from cumulative spacing parameters.

    Each state is placed at x = (sum of the first g entries of ``alpha``) and
    y = (sum of the first h entries of ``beta``), where g / h are its G / H
    group indices. The points are then multiplied by a fixed 45-degree
    rotation matrix and mean-centered.

    Args:
        alpha: Spacings between consecutive G groups (one fewer than the
            number of G groups).
        beta: Spacings between consecutive H groups (one fewer than the
            number of H groups).
        g_map: Optional dict state index -> G group index. Defaults to the
            module-level ``idx2g``.
        h_map: Optional dict state index -> H group index. Defaults to the
            module-level ``idx2h``.

    Returns:
        ndarray of shape (n, 2) with one row per state index 0..n-1, where
        n is the number of entries in ``g_map``.
    """
    if g_map is None:
        g_map = idx2g
    if h_map is None:
        h_map = idx2h

    # Group k sits at the sum of the first k spacings (group 0 at 0). Sized
    # from the inputs instead of a hard-coded 7 groups.
    cum_alpha = np.concatenate([[0.0], np.cumsum(alpha)])
    cum_beta = np.concatenate([[0.0], np.cumsum(beta)])

    # Coordinates in the rotated basis; sized from the mapping instead of a
    # hard-coded 16 states.
    n = len(g_map)
    X = np.zeros([n, 2])
    for idx in range(n):
        X[idx, 0] = cum_alpha[g_map[idx]]  # x from the G group
        X[idx, 1] = cum_beta[h_map[idx]]   # y from the H group

    # Undo the 45-degree rotation of the basis.
    angle = -np.pi / 4
    unrotate = np.array([[np.cos(angle), -np.sin(angle)],
                         [np.sin(angle), np.cos(angle)]])
    X = X @ unrotate

    # Mean-center so every frame is centered on the origin.
    X = X - np.mean(X, axis=0, keepdims=True)

    return X

Get the reconstructed grid for each time step

In [29]:
# One reconstructed grid per logged step: shape (n_steps, n_states, 2).
reconstruction = np.zeros([n_steps, n_states, 2])
for step, step_params in enumerate(params):
    reconstruction[step] = generate_grid(step_params['alpha'],
                                         step_params['beta'])

### Make .gif

In [30]:
# Optional: uncomment to render only the first 50 steps (e.g. for a quick
# preview); the frame loop below iterates over `reconstruction`.
#reconstruction = reconstruction[:50]

In [31]:
# Shared axis limits for the reconstructed-grid panel (identical for every
# frame so the grid does not jump), padded by 10% of the larger extent.
xmin = np.min(reconstruction[:,:,0])
xmax = np.max(reconstruction[:,:,0])
ymin = np.min(reconstruction[:,:,1])
ymax = np.max(reconstruction[:,:,1])
eps = 0.1*(np.max([xmax-xmin, ymax-ymin]))

# Shared y-limits for the ratio panel, covering both curves plus a 0.1
# margin. Non-finite accuracy ratios (incongruent accuracy of 0) are skipped
# because matplotlib rejects inf/nan axis limits.
warping_max = np.max(warping)
warping_min = np.min(warping)
finite_acc_ratio = np.asarray(acc_ratio, dtype=float)
finite_acc_ratio = finite_acc_ratio[np.isfinite(finite_acc_ratio)]
if finite_acc_ratio.size > 0:
    acc_ratio_max = np.max(finite_acc_ratio)
    acc_ratio_min = np.min(finite_acc_ratio)
else:
    acc_ratio_max, acc_ratio_min = warping_max, warping_min
ratio_max = np.max([warping_max, acc_ratio_max])+0.1
ratio_min = np.min([warping_min, acc_ratio_min])-0.1

# Map warping values to colors with the same Normalize that the colorbar
# uses, so each frame's point color matches the colorbar scale. (The former
# w/(warping_max+1) scaling did not correspond to the colorbar's range.)
cmap = plt.get_cmap('hot')
norm = Normalize(vmin=warping_min,
                 vmax=warping_max,
                 clip=True)
normalized_warping = [norm(w) for w in warping]
colors = [cmap(nw) for nw in normalized_warping]

In [32]:
# Render one PNG per step. `filenames` lists the frames in GIF order;
# repeating a filename holds that frame longer in the animation.
HOLD_FRAMES = 20  # extra copies used to pause on selected frames

filenames = []
for t, M in enumerate(reconstruction):
    fig, ax = plt.subplots(3, 1,
                           figsize=[8, 12],
                           gridspec_kw={'height_ratios': [1, 1, 3]})

    # All panels show step t: curves up to and including t, markers at t.
    # (Using t-1 put the step-0 marker at x=-1 with the *last* value and
    # never drew the final step.)

    # Congruent vs. incongruent accuracies over time
    ax[0].plot(accs[:t+1])
    ax[0].plot(t, accs[t, 0], marker='o', c='tab:blue')
    ax[0].plot(t, accs[t, 1], marker='o', c='tab:orange')
    ax[0].set_title("Congruent vs. incongruent accuracy")
    ax[0].set_xlim([0, n_steps])
    ax[0].set_ylim([-0.05, 1.05])
    ax[0].set_xlabel("Steps")
    ax[0].set_ylabel("Accuracy")
    ax[0].legend(['Congruent', 'Incongruent'], loc='lower right')

    # Warping vs. accuracy ratio
    ax[1].plot(warping[:t+1], c='tab:green')
    ax[1].plot(acc_ratio[:t+1], c='tab:purple')
    ax[1].plot(t, warping[t], marker='o', c='tab:green')
    ax[1].plot(t, acc_ratio[t], marker='o', c='tab:purple')
    ax[1].set_title("Warping")
    ax[1].set_xlim([0, n_steps])
    ax[1].set_ylim([ratio_min, ratio_max])
    ax[1].set_xlabel("Steps")
    ax[1].set_ylabel("Ratio")
    ax[1].legend(['Distance', 'Accuracy'], loc='upper right')

    # Reconstructed grid, colored by this step's warping value
    ax[2].scatter(M[:, 0], M[:, 1], color=colors[t])
    for loc, m in zip(locs, M):
        ax[2].annotate(loc, m)
    main_title = "{} Representations (reconstructed)".format(model_name.upper())
    ax[2].set_title(main_title)
    ax[2].set_xlim([xmin-eps, xmax+eps])
    ax[2].set_ylim([ymin-eps, ymax+eps])
    ax[2].set_xticks([])
    ax[2].set_yticks([])
    colorbar = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap),
                            ax=ax[2], aspect=40, fraction=0.03, pad=0.02)
    colorbar.ax.set_ylabel('Warping', rotation=270, labelpad=15)

    fig.tight_layout()
    filename = '../../results/visualize_warping_{}{}.png'.format(model_name, t)
    fig.savefig(filename, dpi=100)
    plt.close(fig)

    filenames.append(filename)
    # Hold the first frame, the last frame, and the frame of maximum warping.
    if t in (0, n_steps - 1) or warping[t] == warping_max:
        filenames.extend([filename] * HOLD_FRAMES)

In [33]:
# Write the .gif from the saved frames, in `filenames` order. Held frames
# appear as consecutive duplicates, so the last decoded image is reused
# instead of re-reading the same PNG from disk for every copy. Only one
# image is kept in memory at a time.
gif_name = '../../results/visualize_reconstructed_warping_{}.gif'.format(model_name)
with imageio.get_writer(gif_name, mode='I') as writer:
    prev_filename, image = None, None
    for filename in filenames:
        if filename != prev_filename:
            image = imageio.imread(filename)
            prev_filename = filename
        writer.append_data(image)

In [34]:
# Remove the temporary per-step PNGs. Each path is visited once, even if it
# was listed several times to hold a frame.
for stale_png in dict.fromkeys(filenames):
    if not os.path.isfile(stale_png):
        continue
    os.remove(stale_png)

<img src="visualize_reconstructed_warping_rnn.gif" width="750" align="center">

<img src="visualize_reconstructed_warping_mlp.gif" width="750" align="center">