In [1]:
import sys
sys.path.append('/Users/omarschall/vanilla-rtrl/')
import numpy as np
from network import *
from simulation import *
from gen_data import *
try:
    import matplotlib.pyplot as plt
except ModuleNotFoundError:
    pass
from optimizers import *
from analysis_funcs import *
from learning_algorithms import *
from functions import *
from itertools import product
import os
import pickle
from copy import deepcopy
from scipy.ndimage.filters import uniform_filter1d
from sklearn import linear_model
from state_space import *
from dynamics import *
import multiprocessing as mp
from functools import partial
from sklearn.cluster import DBSCAN

In [2]:
from pyemd import emd
from pdb import set_trace

In [3]:
%matplotlib notebook

In [4]:
# Seed NumPy's global RNG first so that any np.random draws made below are
# repeatable (presumably gen_data draws from it — TODO confirm).
np.random.seed(0)
# Flip_Flop_Task is provided by one of the star imports above; the meaning of
# the positional arguments (3, 0.05) is not visible here — TODO confirm.
task = Flip_Flop_Task(3, 0.05, tau_task=1)
# Sizes passed to gen_data for the training and test sets.
N_train = 100000
N_test = 10000
# `data` is passed to State_Space_Analysis / plot_checkpoint_results in later cells.
data = task.gen_data(N_train, N_test)

In [5]:
def wasserstein_distance(checkpoint_1, checkpoint_2):
    """Earth mover's distance between the cluster structure of two checkpoints.

    Each checkpoint is treated as a discrete distribution: its support is the
    rows of 'cluster_means' and the mass on row j is the number of entries of
    'cluster_labels' equal to j. The two distributions are embedded in one
    combined support (checkpoint_1's clusters first, then checkpoint_2's) and
    handed to `emd` together with the matrix of pairwise distances between
    the combined means.

    Args:
        checkpoint_1: mapping with keys 'cluster_means' (indexable by row,
            exposing `.shape[0]`) and 'cluster_labels'.
        checkpoint_2: mapping with the same two keys.

    Returns:
        Whatever `emd(hist1, hist2, distances)` returns for the two float64
        histograms and the float64 ground-distance matrix.

    Raises:
        KeyError: if either checkpoint lacks one of the two keys above.
    """

    def _cluster_weights(checkpoint):
        # Mass of cluster j = number of labels equal to j. Labels outside
        # range(n_clusters) are simply not counted.
        n_clusters = checkpoint['cluster_means'].shape[0]
        labels = checkpoint['cluster_labels']
        return np.array([np.count_nonzero(labels == j) for j in range(n_clusters)],
                        dtype=np.float64)

    cluster_weights_1 = _cluster_weights(checkpoint_1)
    cluster_weights_2 = _cluster_weights(checkpoint_2)

    # Zero-pad each histogram over the other checkpoint's clusters so both
    # live on the same combined support.
    hist1 = np.concatenate([cluster_weights_1, np.zeros_like(cluster_weights_2)], axis=0)
    hist2 = np.concatenate([np.zeros_like(cluster_weights_1), cluster_weights_2], axis=0)
    N = len(hist1)

    combined_means = np.concatenate([checkpoint_1['cluster_means'],
                                     checkpoint_2['cluster_means']], axis=0)

    # Ground distances are symmetric with a zero diagonal, so only the upper
    # triangle is evaluated and then mirrored.
    distances = np.zeros((N, N))
    for i in range(N):
        for j in range(i + 1, N):
            d_ij = norm(combined_means[i] - combined_means[j])
            distances[i, j] = d_ij
            distances[j, i] = d_ij

    return emd(hist1, hist2, distances)

In [6]:
# Directory that holds the pickled 'result_<i>' files read by later cells.
# NOTE(review): absolute, machine-specific path — must be edited to run elsewhere.
data_path = '/Users/omarschall/cluster_results/vanilla-rtrl/bptt_rflo_graph/'

In [7]:
def process_results(data_path, N_jobs, rflo_checkpoints=None, bptt_checkpoints=None):
    """Merge pickled per-job results into one dict per learning algorithm.

    Reads 'result_0' ... 'result_{N_jobs - 1}' from `data_path`. Each file is
    expected to unpickle to a dict containing result['config']['algorithm'].
    Results tagged 'RFLO' are merged into `rflo_checkpoints`, results tagged
    'E-BPTT' into `bptt_checkpoints`; any other tag is ignored. Missing files
    are skipped silently.

    Merging uses dict.update on the whole result, so every top-level key is
    copied and keys shared between files (such as 'config') end up holding
    the value from the last file processed.

    Args:
        data_path: directory containing the result files.
        N_jobs: number of job indices to probe.
        rflo_checkpoints: optional dict to update in place; a new dict is
            created per call when omitted.
        bptt_checkpoints: optional dict to update in place; a new dict is
            created per call when omitted.

    Returns:
        Tuple (rflo_checkpoints, bptt_checkpoints).
    """
    # Fresh dicts per call: a `{}` default would be created once and shared by
    # every call that relies on it, leaking results between calls.
    if rflo_checkpoints is None:
        rflo_checkpoints = {}
    if bptt_checkpoints is None:
        bptt_checkpoints = {}

    for i in range(N_jobs):

        file_path = os.path.join(data_path, 'result_{}'.format(i))

        # NOTE: pickle.load can execute arbitrary code; only point this at
        # result files from a trusted source.
        try:
            with open(file_path, 'rb') as f:
                result = pickle.load(f)
        except FileNotFoundError:
            continue

        algorithm = result['config']['algorithm']
        if algorithm == 'RFLO':
            rflo_checkpoints.update(result)
        elif algorithm == 'E-BPTT':
            bptt_checkpoints.update(result)

    return rflo_checkpoints, bptt_checkpoints

In [8]:
# Probe result_0 ... result_201 under data_path. Empty dicts are passed
# explicitly so this call starts from a clean slate.
rflo_checkpoints, bptt_checkpoints = process_results(data_path, 202, rflo_checkpoints={}, bptt_checkpoints={})

In [None]:
# Spot-check one RFLO checkpoint: shape of its cluster means and its adjacency matrix.
checkpoint = rflo_checkpoints['checkpoint_13500']
print(checkpoint['cluster_means'].shape)
print(checkpoint['adjacency_matrix'])

In [None]:
# Show which top-level keys were merged into the RFLO dict.
rflo_checkpoints.keys()

In [None]:
# Sorted integer suffixes of every key containing 'checkpoint', one list per algorithm.
i_rflo = sorted(int(key.rsplit('_', 1)[-1])
                for key in rflo_checkpoints
                if 'checkpoint' in key)
i_bptt = sorted(int(key.rsplit('_', 1)[-1])
                for key in bptt_checkpoints
                if 'checkpoint' in key)

In [None]:
# EMD between each RFLO checkpoint and the one 100 steps later; a pair is
# skipped when either key is missing from rflo_checkpoints.
rflo_distances = []
for i in i_rflo[:-1]:
    key_now = 'checkpoint_{}'.format(i)
    key_next = 'checkpoint_{}'.format(i + 100)

    try:
        checkpoint_1 = rflo_checkpoints[key_now]
        checkpoint_2 = rflo_checkpoints[key_next]
    except KeyError:
        continue
    else:
        rflo_distances.append(wasserstein_distance(checkpoint_1, checkpoint_2))
    

In [None]:
# EMD between each BPTT checkpoint and the one 100 steps later; a pair is
# skipped when either key is missing from bptt_checkpoints.
bptt_distances = []
for i in i_bptt[:-1]:
    key_now = 'checkpoint_{}'.format(i)
    key_next = 'checkpoint_{}'.format(i + 100)

    try:
        checkpoint_1 = bptt_checkpoints[key_now]
        checkpoint_2 = bptt_checkpoints[key_next]
    except KeyError:
        continue
    else:
        bptt_distances.append(wasserstein_distance(checkpoint_1, checkpoint_2))

In [None]:
# Per checkpoint: norm of every entry of optimizer.vel, then the sum of the first three norms.
# NOTE(review): the next cell assigns rflo_norms / bptt_norms again, so whichever
# of the two cells ran last determines what the histogram cell shows.
rflo_norms = [sum([norm(v) for v in rflo_checkpoints['checkpoint_{}'.format(i)]['optimizer'].vel][:3]) for i in i_rflo]
bptt_norms = [sum([norm(v) for v in bptt_checkpoints['checkpoint_{}'.format(i)]['optimizer'].vel][:3]) for i in i_bptt]

In [None]:
# Per checkpoint: norm of learn_alg.rec_grads.
# NOTE(review): reuses the names rflo_norms / bptt_norms from the previous cell,
# replacing the optimizer-velocity values computed there.
rflo_norms = [norm(rflo_checkpoints['checkpoint_{}'.format(i)]['learn_alg'].rec_grads) for i in i_rflo]
bptt_norms = [norm(bptt_checkpoints['checkpoint_{}'.format(i)]['learn_alg'].rec_grads) for i in i_bptt]

In [None]:
# Norm of the change in rnn.W_rec between each checkpoint and the one 100 steps
# earlier, for every index after the first.
rflo_dW = []
for i_curr in i_rflo[1:]:
    W_curr = rflo_checkpoints['checkpoint_{}'.format(i_curr)]['rnn'].W_rec
    W_prev = rflo_checkpoints['checkpoint_{}'.format(i_curr - 100)]['rnn'].W_rec
    rflo_dW.append(norm(W_curr - W_prev))

bptt_dW = []
for i_curr in i_bptt[1:]:
    W_curr = bptt_checkpoints['checkpoint_{}'.format(i_curr)]['rnn'].W_rec
    W_prev = bptt_checkpoints['checkpoint_{}'.format(i_curr - 100)]['rnn'].W_rec
    bptt_dW.append(norm(W_curr - W_prev))

In [None]:
# Stored 'test_loss' entry of every checkpoint, in index order.
rflo_test_loss = [rflo_checkpoints['checkpoint_{}'.format(i)]['test_loss'] for i in i_rflo]
bptt_test_loss = [bptt_checkpoints['checkpoint_{}'.format(i)]['test_loss'] for i in i_bptt]

In [None]:
# Overlaid histograms of rflo_norms and bptt_norms (raw counts, 20 bins each).
# NOTE(review): two earlier cells both assign these names; the plot reflects the last one run.
plt.figure()
#plt.plot([2, 3])
plt.hist(rflo_norms, bins=20, alpha=0.5, density=False)
plt.hist(bptt_norms, bins=20, alpha=0.5, density=False)

In [None]:
# Smoothed checkpoint-to-checkpoint EMD (solid) and rescaled recurrent weight
# change (dashed) for both algorithms on one axis.
plt.figure()
# NOTE(review): x values are i_rflo[1:] / i_bptt[1:], but the distance lists omit any
# pair whose key was missing, so the lengths only match if no pair was skipped.
plt.plot(i_rflo[1:], uniform_filter1d(rflo_distances, 200))
plt.plot(i_bptt[1:], uniform_filter1d(bptt_distances, 200))
# Peak of the smoothed BPTT EMD curve, used as the target height for the dW curves.
y_max = np.amax(uniform_filter1d(bptt_distances, 200))
rflo_dW_scaled = rflo_dW / np.amax(rflo_dW) * y_max
# Both dW series are divided by the RFLO maximum, which keeps them on one common scale.
bptt_dW_scaled = bptt_dW / np.amax(rflo_dW) * y_max
plt.plot(i_rflo[1:], uniform_filter1d(rflo_dW_scaled, 200), 'C0', linestyle='--')
plt.plot(i_bptt[1:], uniform_filter1d(bptt_dW_scaled, 200), 'C1', linestyle='--')
#rflo_test_loss_scaled = rflo_test_loss / np.amax(rflo_test_loss) * y_max
#bptt_test_loss_scaled = bptt_test_loss / np.amax(rflo_test_loss) * y_max
#plt.plot(i_rflo, uniform_filter1d(rflo_test_loss_scaled, 200), 'C0', linestyle='--')
#plt.plot(i_bptt, uniform_filter1d(bptt_test_loss_scaled, 200), 'C1', linestyle='--')
# Legend entries follow the order in which the four lines were plotted above.
plt.legend(['RFLO EMD', 'BPTT EMD', 'RFLO dW', 'BPTT dW'])
plt.ylabel('arbitrary units')

In [12]:
### RFLO ###

# Reference checkpoint: its 'V' entry defines the projection used for both figures.
i_checkpoint = 40000
checkpoint = rflo_checkpoints['checkpoint_{}'.format(i_checkpoint)]
# transform(x) evaluates np.dot(x, b=checkpoint['V']).
transform = partial(np.dot, b=checkpoint['V'])
# Two analysis objects built from the same reference, one per plotted checkpoint.
ssa = State_Space_Analysis(checkpoint, data, transform=transform)
ssa_2 = State_Space_Analysis(checkpoint, data, transform=transform)

# First checkpoint to plot (drawn on ssa).
i_checkpoint = 40000
checkpoint = rflo_checkpoints['checkpoint_{}'.format(i_checkpoint)]
plot_checkpoint_results(checkpoint, data, ssa=ssa, plot_cluster_means=True, eig_norm_color=False,
                        plot_test_points=False,
                        plot_fixed_points=True,
                        plot_graph_structure=True)

# Checkpoint 100 steps later (drawn on ssa_2) for side-by-side comparison.
i_checkpoint = i_checkpoint + 100
checkpoint = rflo_checkpoints['checkpoint_{}'.format(i_checkpoint)]
plot_checkpoint_results(checkpoint, data, ssa=ssa_2, plot_cluster_means=True, eig_norm_color=False,
                        plot_test_points=False,
                        plot_fixed_points=True,
                        plot_graph_structure=True)

#rnn = checkpoint['rnn']
#test_sim = Simulation(rnn)
#test_sim.run(data,
#              mode='test',
#              monitors=['rnn.loss_', 'rnn.y_hat'],
#              verbose=False)

#plt.figure()
#plt.plot(test_sim.mons['rnn.y_hat'][:, 0])
#plt.plot(data['test']['Y'][:, 0])
#plt.xlim([0, 1000])

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<state_space.State_Space_Analysis at 0x1a8734dd90>

In [None]:
# Image of the adjacency matrix of whichever checkpoint the name `checkpoint` currently holds.
plt.figure()
plt.imshow(checkpoint['adjacency_matrix'])

In [None]:
### BPTT ###

# Reference checkpoint: its 'V' entry defines the projection used for both figures.
i_checkpoint = 20000
checkpoint = bptt_checkpoints['checkpoint_{}'.format(i_checkpoint)]
# transform(x) evaluates np.dot(x, b=checkpoint['V']).
transform = partial(np.dot, b=checkpoint['V'])
# Two analysis objects built from the same reference, one per plotted checkpoint.
ssa = State_Space_Analysis(checkpoint, data, transform=transform)
ssa_2 = State_Space_Analysis(checkpoint, data, transform=transform)

# First checkpoint to plot (drawn on ssa).
i_checkpoint = 20000
checkpoint = bptt_checkpoints['checkpoint_{}'.format(i_checkpoint)]
plot_checkpoint_results(checkpoint, data, ssa=ssa, plot_cluster_means=True, eig_norm_color=False,
                        plot_test_points=True,
                        plot_fixed_points=True)

# Checkpoint 100 steps later (drawn on ssa_2) for side-by-side comparison.
i_checkpoint = i_checkpoint + 100
checkpoint = bptt_checkpoints['checkpoint_{}'.format(i_checkpoint)]
plot_checkpoint_results(checkpoint, data, ssa=ssa_2, plot_cluster_means=True, eig_norm_color=False,
                        plot_test_points=True,
                        plot_fixed_points=True)

#rnn = checkpoint['rnn']
#test_sim = Simulation(rnn)
#test_sim.run(data,
#              mode='test',
#              monitors=['rnn.loss_', 'rnn.y_hat'],
#              verbose=False)

#plt.figure()
#plt.plot(test_sim.mons['rnn.y_hat'][:, 0])
#plt.plot(data['test']['Y'][:, 0])
#plt.xlim([0, 1000])

In [None]:
# Walk result_0 ... result_199 and, for every tenth checkpoint in each file's
# 100-step segment, collect the checkpoint itself, its cluster means, and the
# number of labels assigned to each cluster.
checkpoints = []
all_cluster_means = []
all_cluster_weights = []
for i in range(200):

    file_path = os.path.join(data_path, 'result_{}'.format(i))

    # NOTE: pickle.load can execute arbitrary code; only load trusted result files.
    try:
        with open(file_path, 'rb') as f:
            result = pickle.load(f)
    except FileNotFoundError:
        # Jobs that produced no output file are skipped.
        continue

    segment = result['config']['i_start']
    for i_checkpoint in range(segment, segment + 100, 10):
        result_checkpoint = result['checkpoint_{}'.format(i_checkpoint)]
        cluster_means = result_checkpoint['cluster_means']
        fixed_points = result_checkpoint['fixed_points']
        cluster_labels = result_checkpoint['cluster_labels']
        # One count per cluster: how many labels equal that cluster's index.
        # (Previously this overwrote the list with len() of np.where's tuple,
        # which is always 1.)
        cluster_weights = np.array([np.count_nonzero(cluster_labels == j)
                                    for j in range(cluster_means.shape[0])])
        checkpoints.append(result_checkpoint)
        all_cluster_means.append(cluster_means)
        # Append this checkpoint's weights (previously the list appended itself).
        all_cluster_weights.append(cluster_weights)

In [None]:
# Stack the per-checkpoint cluster-mean arrays along axis 0.
# NOTE(review): rebinds all_cluster_means from a list to an array, so re-running
# this cell without re-running the loading cell concatenates an already-stacked array.
all_cluster_means = np.concatenate(all_cluster_means, axis=0)

In [None]:
# Same stacking as above, rebuilt directly from the stored checkpoints.
clusters = np.concatenate([c['cluster_means'] for c in checkpoints], axis=0)

In [None]:
# Shape of each checkpoint's cluster_means array; displayed as the cell output.
cluster_shapes = [c['cluster_means'].shape for c in checkpoints]
cluster_shapes

In [None]:
# Number of checkpoints collected by the loading cell.
len(checkpoints)

In [None]:
#n_points = np.array([c[0] for c in cluster_shapes])
# Norm of learn_alg.rec_grads for every collected checkpoint.
grad_norms = [norm(c['learn_alg'].rec_grads) for c in checkpoints]
# Sum of the norms of all optimizer.vel entries for every collected checkpoint.
# NOTE(review): the loading cell does not filter by algorithm, so the "_bptt" suffix
# is only accurate if data_path held BPTT results when it ran — confirm.
vel_norms_bptt = [sum([norm(v) for v in c['optimizer'].vel]) for c in checkpoints]

In [None]:
### FIX PCA AXES BASED ON A REFERENCE CHECKPOINT ###
i_checkpoint = 10100
# Result files are indexed by 100-step segment, counted from step 10000.
i_file = (i_checkpoint - 10000) // 100
file_path = os.path.join(data_path, 'result_{}'.format(i_file))

# A missing file must stop the cell: silently passing would leave `result`
# bound to whatever an earlier cell loaded (or undefined), and the lookup
# below would then fail or use the wrong data.
# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
with open(file_path, 'rb') as f:
    result = pickle.load(f)

checkpoint = result['checkpoint_{}'.format(i_checkpoint)]
# transform(x) evaluates np.dot(x, b=checkpoint['V']).
transform = partial(np.dot, b=checkpoint['V'])
ssa = State_Space_Analysis(checkpoint, data, transform=transform)

In [None]:
# Load one checkpoint from its segment's result file and plot it on the
# axes fixed by the existing `ssa` object.
i_checkpoint = 22180
# Result files are indexed by 100-step segment, counted from step 10000.
i_file = (i_checkpoint - 10000) // 100
file_path = os.path.join(data_path, 'result_{}'.format(i_file))

# A missing file must stop the cell: silently passing would leave `result`
# bound to a previously loaded file, and the lookup below would then fail
# or use the wrong data.
# NOTE: pickle.load can execute arbitrary code; only load trusted result files.
with open(file_path, 'rb') as f:
    result = pickle.load(f)

checkpoint = result['checkpoint_{}'.format(i_checkpoint)]
plot_checkpoint_results(checkpoint, data, ssa=ssa, plot_cluster_means=True, eig_norm_color=True,
                        plot_test_points=True,
                        plot_fixed_points=True)
#ssa.fig.suptitle('Test Loss {}'.format(checkpoint['test_loss']))
#rnn = checkpoint['rnn']
#test_sim = Simulation(rnn)
#test_sim.run(data,
#              mode='test',
#              monitors=['rnn.loss_', 'rnn.y_hat', 'rnn.a'],
#              verbose=False)

#plt.figure()
#plt.plot(test_sim.mons['rnn.y_hat'][:, 0])
#plt.plot(data['test']['Y'][:, 0])
#plt.xlim([0, 1000])
#plt.title('Test loss: {}'.format(test_sim.mons['rnn.loss_'].mean()))

In [None]:
# EMD between each consecutive pair in `checkpoints`, plotted against a second series.
# NOTE(review): `distances_` is not assigned in any cell visible here; on a fresh kernel
# this cell raises NameError unless it is defined elsewhere — confirm where it comes from.
distances_2 = [wasserstein_distance(checkpoints[i], checkpoints[i+1]) for i in range(len(checkpoints) - 1)]
plt.figure()
plt.plot(distances_, '.', alpha=0.4)
plt.plot(distances_2, '.', alpha=0.4)
# Legend labels follow plot order: distances_ is labelled 'RFLO', distances_2 'BPTT'.
plt.legend(['RFLO', 'BPTT'])
#n_points_derive = n_points[1:] - n_points[:-1]
#plt.plot(np.abs(n_points_derive), grad_norms[:-1], '.', alpha=0.4)

In [None]:
# Scatter of optimizer-velocity norms for the two algorithms.
# NOTE(review): `vel_norms_rflo` is not assigned in any cell visible here; on a fresh
# kernel this cell raises NameError unless it is defined elsewhere — confirm.
plt.figure()
plt.plot(vel_norms_rflo, '.', alpha=0.4)
plt.plot(vel_norms_bptt, '.', alpha=0.4)
plt.legend(['RFLO', 'BPTT'])

In [None]:
# Summed optimizer-velocity norm of one collected checkpoint (list position 1007, not a step number).
sum([norm(v) for v in checkpoints[1007]['optimizer'].vel])

In [None]:
# Gradient norm of every collected checkpoint, in collection order.
plt.figure()
plt.plot(grad_norms)

In [None]:
def cluster_distance_metric(X, Y, penalty=0.1):
    """Score two sets of cluster vectors by their best dot-product matches.

    Computes all pairwise dot products between rows of X and rows of Y,
    averages the min(n_x, n_y) largest of them, and adds `penalty` for every
    cluster by which the two sets differ in size.

    NOTE(review): larger dot products raise the returned value, which is the
    opposite of how a distance normally behaves — confirm the intended sign
    before relying on this as a distance.

    Args:
        X: 2-D array with one cluster vector per row.
        Y: 2-D array with one cluster vector per row and the same number of
            columns as X.
        penalty: amount added per unit difference in cluster count.

    Returns:
        Mean of the top-matching dot products plus penalty * |n_x - n_y|.
        When either set is empty there is nothing to match and only the
        count penalty is returned.
    """
    n_x = X.shape[0]
    n_y = Y.shape[0]

    n_match_max = min(n_x, n_y)
    n_diff_min = abs(n_x - n_y)

    # Averaging over zero matches would give NaN; fall back to the penalty alone.
    if n_match_max == 0:
        return penalty * n_diff_min

    similarity_matrix = X.dot(Y.T)
    # Sort the similarity VALUES in descending order. Using argsort here would
    # average flat indices into the matrix instead of the similarities.
    top_similarities = np.sort(similarity_matrix, axis=None)[::-1][:n_match_max]
    best_match_distance = top_similarities.mean()

    return best_match_distance + penalty * n_diff_min
    
def wasserstein_distance(checkpoint_1, checkpoint_2):
    """Earth mover's distance between the cluster structure of two checkpoints.

    Each checkpoint is treated as a discrete distribution: its support is the
    rows of 'cluster_means' and the mass on row j is the number of entries of
    'cluster_labels' equal to j. The two distributions are embedded in one
    combined support (checkpoint_1's clusters first, then checkpoint_2's) and
    handed to `emd` together with the matrix of pairwise distances between
    the combined means.

    NOTE(review): a function with this name is also defined in an earlier
    cell; whichever cell ran last is the one in effect.

    Args:
        checkpoint_1: mapping with keys 'cluster_means' (indexable by row,
            exposing `.shape[0]`) and 'cluster_labels'.
        checkpoint_2: mapping with the same two keys.

    Returns:
        Whatever `emd(hist1, hist2, distances)` returns for the two float64
        histograms and the float64 ground-distance matrix.

    Raises:
        KeyError: if either checkpoint lacks one of the two keys above.
    """

    def _cluster_weights(checkpoint):
        # Mass of cluster j = number of labels equal to j. Labels outside
        # range(n_clusters) are simply not counted.
        n_clusters = checkpoint['cluster_means'].shape[0]
        labels = checkpoint['cluster_labels']
        return np.array([np.count_nonzero(labels == j) for j in range(n_clusters)],
                        dtype=np.float64)

    cluster_weights_1 = _cluster_weights(checkpoint_1)
    cluster_weights_2 = _cluster_weights(checkpoint_2)

    # Zero-pad each histogram over the other checkpoint's clusters so both
    # live on the same combined support.
    hist1 = np.concatenate([cluster_weights_1, np.zeros_like(cluster_weights_2)], axis=0)
    hist2 = np.concatenate([np.zeros_like(cluster_weights_1), cluster_weights_2], axis=0)
    N = len(hist1)

    combined_means = np.concatenate([checkpoint_1['cluster_means'],
                                     checkpoint_2['cluster_means']], axis=0)

    # Ground distances are symmetric with a zero diagonal, so only the upper
    # triangle is evaluated and then mirrored.
    distances = np.zeros((N, N))
    for i in range(N):
        for j in range(i + 1, N):
            d_ij = norm(combined_means[i] - combined_means[j])
            distances[i, j] = d_ij
            distances[j, i] = d_ij

    return emd(hist1, hist2, distances)

In [None]:
# Banded matrix of EMDs: entry [i, i + j] compares checkpoint i with its
# neighbours at offsets j = -10 ... 9; everything outside the band stays 0.
n_checkpoints = len(checkpoints)
distances = np.zeros((n_checkpoints, n_checkpoints))
for i in range(n_checkpoints):
    for j in range(-10, 10):

        # Explicit bounds check: a negative i + j does not raise IndexError,
        # it wraps around to the end of the list (and of the matrix row), so
        # catching IndexError alone let early rows be compared against the
        # last checkpoints.
        i_other = i + j
        if i_other < 0 or i_other >= n_checkpoints:
            continue

        checkpoint_1 = checkpoints[i]
        checkpoint_2 = checkpoints[i_other]

    #    rnn = checkpoint_1['rnn']
    #    test_sim = Simulation(rnn)
    #    test_sim.run(data,
    #                  mode='test',
    #                  monitors=['rnn.loss_'],
    #                  verbose=False)
    #    
    #    losses.append(test_sim.mons['rnn.loss_'].mean())
        distances[i, i_other] = wasserstein_distance(checkpoint_1, checkpoint_2)
plt.figure()
plt.imshow(distances)