# Visualize warping over training

In [1]:
import pickle
import os
import imageio.v2 as imageio
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from matplotlib.colors import Normalize

In [2]:
def _index_to_group(n_states, groups):
    """Map each state index in range(n_states) to the index of the group containing it (last match wins)."""
    return {idx: k
            for idx in range(n_states)
            for k, group in enumerate(groups)
            if idx in group}


def load_results(results_fn, rep_name, averaged=True, results_dir='../results/', run_idx=0):
    """
        Load a pickled results file and extract what the warping visualization needs.
            - results_fn: name of the results file inside `results_dir`
            - rep_name: which representation to use (e.g. 'average' for the averaged representations)
            - averaged: if True, average everything over runs; otherwise use the single run `run_idx`
            - results_dir: folder containing `results_fn` (default '../results/')
            - run_idx: run to use when `averaged` is False (default 0)

        Returns a dict with:
            - 'n_states', 'locs': read from the first checkpoint of the first run
            - 'idx2g', 'idx2h': state index -> index of its group in G_idxs / H_idxs
            - 'alpha', 'beta': arrays [n_checkpoints, n_params]
            - 'cong_accs', 'incong_accs', 'acc_ratios' (= cong / incong): 1D arrays,
              one entry per record in 'train_accs'
            - 'dist_ratios': 1D array [n_checkpoints]
    """
    # Open file
    # NOTE: pickle.load can execute arbitrary code -- only load trusted result files.
    results_path = os.path.join(results_dir, results_fn)
    with open(results_path, 'rb') as f:
        data = pickle.load(f)
    analysis = data['analysis']
    train_results = data['results']

    # Visualization parameters for every [run][checkpoint]
    params = [[s['get_diag_vis_params'][rep_name] for s in run] for run in analysis]

    # Variables taken from the first checkpoint of the first run (assumed fixed across checkpoints/runs)
    first = params[0][0]
    n_states = first['n_states']
    locs = first['locs']
    idx2g = _index_to_group(n_states, first['G_idxs'])
    idx2h = _index_to_group(n_states, first['H_idxs'])

    # Visualization parameters
    alpha = np.array([[p['alpha'] for p in run] for run in params])  # [n_runs, n_checkpoints, n_params]
    beta = np.array([[p['beta'] for p in run] for run in params])    # [n_runs, n_checkpoints, n_params]

    # Congruent vs. incongruent accuracies (kept as lists per run until reduced)
    cong_accs = [[s['cong_acc'] for s in run['train_accs']] for run in train_results]
    incong_accs = [[s['incong_acc'] for s in run['train_accs']] for run in train_results]

    # Distance ratios
    dist_ratios = np.array([[s['distance_ratio'][rep_name]['ratio'] for s in run]
                            for run in analysis])  # [n_runs, n_checkpoints]

    if averaged:
        # Average over runs
        alpha = np.mean(alpha, axis=0)  # [n_checkpoints, n_params]
        beta = np.mean(beta, axis=0)    # [n_checkpoints, n_params]
        cong_accs = np.mean(cong_accs, axis=0)
        incong_accs = np.mean(incong_accs, axis=0)
        dist_ratios = np.mean(dist_ratios, axis=0)
    else:
        # Single run
        alpha = alpha[run_idx]  # [n_checkpoints, n_params]
        beta = beta[run_idx]    # [n_checkpoints, n_params]
        cong_accs = np.array(cong_accs[run_idx])
        incong_accs = np.array(incong_accs[run_idx])
        dist_ratios = dist_ratios[run_idx]  # [n_checkpoints]

    # NOTE: an incongruent accuracy of 0 yields inf/nan here (numpy division semantics).
    acc_ratios = cong_accs/incong_accs

    return {'n_states': n_states,
            'locs': locs,
            'idx2g': idx2g,
            'idx2h': idx2h,
            'alpha': alpha,
            'beta': beta,
            'cong_accs': cong_accs,
            'incong_accs': incong_accs,
            'acc_ratios': acc_ratios,
            'dist_ratios': dist_ratios}

In [3]:
def reconstruct_grid(alpha, beta, n_states, idx2g, idx2h):
    """
        Reconstruct 2D coordinates of the grid states from the visualization parameters.

        Coordinates are first built in a rotated basis: a state's x coordinate is
        the sum of the first g entries of `alpha` (g = its G-group index) and its
        y coordinate the sum of the first h entries of `beta` (h = its H-group
        index), so group 0 sits at 0. The points are then right-multiplied by the
        rotation matrix for -pi/4 ("unrotate") and mean-centred.

            - alpha: 1D sequence of steps between consecutive G groups
            - beta: 1D sequence of steps between consecutive H groups
              (may have a different length than `alpha`)
            - n_states: number of states (e.g. 16 in the 4x4 grid)
            - idx2g: state index -> G-group index (0..len(alpha))
            - idx2h: state index -> H-group index (0..len(beta))

        Returns an array of shape [n_states, 2].
    """
    # Cumulative sums with a leading 0, sized independently for alpha and beta
    cum_alpha = np.concatenate(([0.0], np.cumsum(alpha)))
    cum_beta = np.concatenate(([0.0], np.cumsum(beta)))

    # x and y coordinates in the rotated basis
    g = np.array([idx2g[idx] for idx in range(n_states)], dtype=int)
    h = np.array([idx2h[idx] for idx in range(n_states)], dtype=int)
    X = np.column_stack([cum_alpha[g], cum_beta[h]])

    # Unrotate
    theta = -np.pi/4
    unrotate = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta), np.cos(theta)]])
    X = X @ unrotate

    # Mean-center
    return X - np.mean(X, axis=0, keepdims=True)

In [4]:
def build_gif(results, model_name):
    """
        Build an animated .gif of the reconstructed grid warping over training.

        One frame is drawn per checkpoint t, with three panels:
            1. congruent vs. incongruent accuracy up to and including t,
            2. distance ratio vs. accuracy ratio up to and including t,
            3. the grid reconstructed from alpha[t]/beta[t] (see reconstruct_grid),
               coloured by dist_ratios[t].
        The first frame, the last frame and any frame at the maximum distance
        ratio appear 21 times instead of once. Frames are written as temporary
        PNGs under '../figures/', assembled into
        '../figures/visualize_reconstructed_warping_<model_name>.gif', and the
        PNGs are removed afterwards (also when an error occurs).

            - results: dict as returned by load_results
            - model_name: name of the model (used in the title and file names)

        Raises ValueError if `results` contains no checkpoints.
    """
    # Unpack results
    n_states = results['n_states']
    locs = results['locs']
    idx2g = results['idx2g']
    idx2h = results['idx2h']
    alpha = results['alpha']
    beta = results['beta']
    cong_accs = results['cong_accs']
    incong_accs = results['incong_accs']
    acc_ratios = results['acc_ratios']
    dist_ratios = results['dist_ratios']

    n_steps = len(alpha)
    if n_steps == 0:
        raise ValueError("results contain no checkpoints to animate")

    # Reconstruct the grid at every checkpoint: [n_steps, n_states, 2]
    reconstruction = np.stack([reconstruct_grid(a, b, n_states, idx2g, idx2h)
                               for a, b in zip(alpha, beta)])

    # Fixed axis limits across frames so the grid does not jump around
    xmin = np.min(reconstruction[:, :, 0])
    xmax = np.max(reconstruction[:, :, 0])
    ymin = np.min(reconstruction[:, :, 1])
    ymax = np.max(reconstruction[:, :, 1])
    eps = 0.1*(np.max([xmax-xmin, ymax-ymin]))

    dist_ratios_max = np.max(dist_ratios)
    ratio_max = np.max([dist_ratios_max, np.max(acc_ratios)])+0.1
    ratio_min = np.min([np.min(dist_ratios), np.min(acc_ratios)])-0.1

    # Dots are coloured with cmap(d / (max + 1)); dividing by max + 1 keeps the
    # largest ratio below the white end of the 'hot' colormap. The colorbar uses
    # exactly the same mapping (0 -> 0, max + 1 -> 1) so that it matches the dots.
    cmap = plt.get_cmap('hot')
    color_scale = dist_ratios_max+1
    colors = [cmap(d/color_scale) for d in dist_ratios]
    norm = Normalize(vmin=0, vmax=color_scale, clip=True)

    main_title = "{} Representations (reconstructed)".format(model_name.upper())
    os.makedirs('../figures', exist_ok=True)

    # NOTE(review): frame t indexes cong_accs/incong_accs/acc_ratios at t, i.e. it
    # assumes they have one entry per checkpoint aligned with alpha/beta -- confirm.
    frame_files = []  # every PNG written, so all of them get removed at the end
    frames = []       # (PNG path, number of times it appears in the .gif)
    try:
        for t, M in enumerate(reconstruction):
            fig, ax = plt.subplots(3, 1,
                                   figsize=[8,12],
                                   gridspec_kw={'height_ratios': [1,1,3]})
            try:
                # Congruent vs. incongruent accuracies up to and including t
                ax[0].plot(cong_accs[:t+1])
                ax[0].plot(incong_accs[:t+1])
                ax[0].plot(t, cong_accs[t], marker='o', c='tab:blue')
                ax[0].plot(t, incong_accs[t], marker='o', c='tab:orange')
                ax[0].set_title("Congruent vs. incongruent accuracy")
                ax[0].set_xlim([0,n_steps])
                ax[0].set_ylim([-0.05,1.05])
                ax[0].set_xlabel("Steps")
                ax[0].set_ylabel("Accuracy")
                ax[0].legend(['Congruent', 'Incongruent'], loc='lower right')

                # Distance ratio vs. accuracy ratio up to and including t
                ax[1].plot(dist_ratios[:t+1], c='tab:green')
                ax[1].plot(acc_ratios[:t+1], c='tab:purple')
                ax[1].plot(t, dist_ratios[t], marker='o', c='tab:green')
                ax[1].plot(t, acc_ratios[t], marker='o', c='tab:purple')
                ax[1].set_title("Warping")
                ax[1].set_xlim([0,n_steps])
                ax[1].set_ylim([ratio_min,ratio_max])
                ax[1].set_xlabel("Steps")
                ax[1].set_ylabel("Ratio")
                ax[1].legend(['Distance', 'Accuracy'], loc='upper right')

                # Reconstructed grid at checkpoint t, coloured by its own distance ratio
                ax[2].scatter(M[:,0], M[:,1], color=colors[t])
                for loc, m in zip(locs, M):
                    ax[2].annotate(loc, m)
                ax[2].set_title(main_title)
                ax[2].set_xlim([xmin-eps, xmax+eps])
                ax[2].set_ylim([ymin-eps, ymax+eps])
                ax[2].set_xticks([])
                ax[2].set_yticks([])
                colorbar = fig.colorbar(cm.ScalarMappable(norm=norm, cmap=cmap),
                                        ax=ax[2], aspect=40, fraction=0.03, pad=0.02)
                colorbar.ax.set_ylabel('Warping', rotation=270, labelpad=15)

                fig.tight_layout()
                filename = '../figures/visualize_dist_ratios_{}{}.png'.format(model_name, t)
                frame_files.append(filename)
                fig.savefig(filename, dpi=100)
            finally:
                plt.close(fig)

            # Hold the first/last frames and the frame(s) of maximal warping longer
            hold = t in (0, n_steps-1) or dist_ratios[t] == dist_ratios_max
            frames.append((filename, 21 if hold else 1))

        # Write .gif (each PNG is read once and repeated as needed)
        gif_name = '../figures/visualize_reconstructed_warping_{}.gif'.format(model_name)
        with imageio.get_writer(gif_name, mode='I') as writer:
            for filename, repeats in frames:
                image = imageio.imread(filename)
                for _ in range(repeats):
                    writer.append_data(image)
    finally:
        # Remove temporary frame images
        for filename in frame_files:
            if os.path.isfile(filename):
                os.remove(filename)

Load the RNN results, averaged over all runs

In [5]:
# Settings: RNN results, averaged over runs
results_fn, rep_name, model_name, averaged = 'rnn.P', 'average', 'RNN', True

In [6]:
results = load_results(results_fn, rep_name, averaged=averaged)  # averaged over runs

Make the .gif and save it as 'visualize_reconstructed_warping_RNN.gif' in the 'figures' folder

In [7]:
build_gif(results, model_name=model_name)  # writes the .gif under ../figures/

Load the RNN results for a single run (run 0), without averaging

In [8]:
# Settings: RNN results for a single run (no averaging over runs)
results_fn, rep_name, model_name, averaged = 'rnn.P', 'average', 'RNN_run0', False

In [9]:
results = load_results(results_fn, rep_name, averaged=averaged)  # single run

Make the .gif and save it as 'visualize_reconstructed_warping_RNN_run0.gif' in the 'figures' folder

In [10]:
build_gif(results, model_name=model_name)  # writes the .gif under ../figures/