In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import helper_mate as h
import helper as hd
from colorscheme import *

from gt_learner import GT_learner
import gr_em_learner as gr_em

In [None]:
# Data parameters
BLOCK_SIZE = 4        # passed to h.generate_batch_data; also used for axis limits (BLOCK_SIZE * N_BATCHES * 2)
N_BATCHES = 1         # passed to h.generate_batch_data
ALPHA_LIST = [0,90]   # passed to h.generate_batch_data and to the GR/EM learners as task_angles_in_data
N_RUNS = 20           # number of independently generated datasets (one learner run of each kind per dataset)

# Agent parameters
SIGMA_R = 3.0         # passed to every learner (GT, GR, EM)
PP_THRESHOLD = 1.0    # passed as pp_thr to the episodic learner; the GR learner gets PP_THRESHOLD*100
D = 5                 # passed as D to the GR/EM learners
EM_SIZE = 8           # passed as EM_size_limit to the episodic learner (the GR learner uses 0)
#N_PARTICLES = 256

# Reproducibility: seed NumPy's global RNG so the scatter-plot jitter is
# identical after "Restart Kernel -> Run All".
# NOTE(review): whether h.generate_batch_data and the learners draw from the
# NumPy global RNG is not visible here -- confirm, and seed any other RNG they
# use (e.g. TensorFlow) if the datasets must be reproducible as well.
SEED = 0
np.random.seed(SEED)

In [None]:
# One freshly generated dataset per run
datasets = [h.generate_batch_data(ALPHA_LIST, BLOCK_SIZE, N_BATCHES) for _ in range(N_RUNS)]

# Candidate models handed to every learner
model_set = ['x', 'y', '1x2D', '2x1D_bg']

# Keyword arguments common to both GR_EM_learner configurations
shared_kwargs = dict(verbose=False, D=D, task_angles_in_data=ALPHA_LIST)

# Run the three learners on every dataset; results[agent][k] is the output
# obtained on datasets[k].
results = {'gt': [], 'gr': [], 'em': []}
pbar = tf.keras.utils.Progbar(N_RUNS)
for data in datasets:
    result = {
        # Ground truth learner
        'gt': GT_learner(data, SIGMA_R, model_set),
        # Generative replay learner: memory size limit 0 and a 100x larger threshold
        # (presumably so that nothing is ever stored episodically -- confirm in gr_em_learner)
        'gr': gr_em.GR_EM_learner(data, SIGMA_R, model_set,
                                  EM_size_limit=0, pp_thr=PP_THRESHOLD*100, **shared_kwargs),
        # Episodic learner: memory limited to EM_SIZE, unscaled threshold
        'em': gr_em.GR_EM_learner(data, SIGMA_R, model_set,
                                  EM_size_limit=EM_SIZE, pp_thr=PP_THRESHOLD, **shared_kwargs),
    }
    for agent, outcome in result.items():
        results[agent].append(outcome)

    pbar.add(1)

In [None]:
# Compute switching times.
# switch_times_wnan keeps whatever hd.model_change_time returns (NaN included);
# switch_times replaces every NaN by the placeholder below so the values can be plotted.
nan_placeholder = N_BATCHES*BLOCK_SIZE*2
target_model = model_set[-1]

switch_times_wnan = {}
switch_times = {}
for agent in ('gt', 'gr', 'em'):
    raw_times = [hd.model_change_time(results[agent][run], target_model) for run in range(N_RUNS)]
    switch_times_wnan[agent] = raw_times
    switch_times[agent] = [nan_placeholder if np.isnan(t) else t for t in raw_times]

In [None]:
# For each agent, print how many runs have NaN as their raw switching time
for agent in ('gt', 'gr', 'em'):
    nan_mask = np.isnan(switch_times_wnan[agent])
    print(agent, nan_mask.sum())



In [None]:
# Histogram of switching times, one panel per agent (shared y axis).
fig, axs = plt.subplots(1, 3, sharey=True, tight_layout=True)
fig.set_size_inches(12, 4)
# Iterate over the axes directly: using `i` as a loop index here would overwrite
# the global run selector `i` that the single-run plotting cells rely on.
for ax, agent in zip(axs, ['gt', 'gr', 'em']):
    ax.hist(switch_times[agent], label = agent)
    ax.set_title(agent)
    ax.set_xlabel('Switching time')
    ax.set_ylabel('Frequency')
    ax.set_xlim(0, N_BATCHES*BLOCK_SIZE*2)
    # A bin can hold at most N_RUNS counts; tie the limit to N_RUNS instead of a literal 20
    ax.set_ylim(0, N_RUNS)
    ax.legend()
# add SIGMA_R, EM_SIZE, PP_THRESHOLD in plot title
fig.suptitle('Switching times for different agents, $\\sigma_r$ = {}, EM size = {}, PP threshold = {}'.format(SIGMA_R, EM_SIZE, PP_THRESHOLD))
plt.show()


In [None]:
# GR and EM switching times plotted against the GT switching time.
# Gaussian jitter is added to every coordinate so coinciding points stay visible.
T = BLOCK_SIZE * N_BATCHES * 2
epsilon = 0.3  # standard deviation of the jitter


def jittered(times):
    """Return `times` as an array with fresh N(0, epsilon^2) noise added to each of the N_RUNS entries."""
    return np.asarray(times) + epsilon*np.random.randn(N_RUNS)


plt.figure(figsize=(6,6))
for agent, legend_label, marker_kwargs in [('gr', 'GR vs GT', {}),
                                           ('em', 'EM vs GT', {'color': 'r'})]:
    # x jitter is drawn before y jitter, independently for each agent
    gt_x = jittered(switch_times['gt'])
    agent_y = jittered(switch_times[agent])
    plt.scatter(gt_x, agent_y, label=legend_label, **marker_kwargs)

plt.xlabel('GT switching time')
plt.ylabel('GR/EM switching time')
plt.xlim([0, T])
plt.ylim([0, T])
# Legend sits outside the axes, anchored to the upper-right corner
plt.legend(loc='upper left', bbox_to_anchor=(1, 1))

# Identity line: points on it switched at the same time as GT
plt.plot([0, T], [0, T], '--')

# Parameter values written to the right of the axes (positions are in data coordinates)
for height, name, value in [(0.1, 'SIGMA_R', SIGMA_R),
                            (0.2, 'EM_SIZE', EM_SIZE),
                            (0.3, 'PP_THRESHOLD', PP_THRESHOLD)]:
    plt.text(T*1.05, T*height, name + ' = ' + str(value))

plt.show()

In [None]:
# Scatter plot of switching times (no jitter): GR and EM against GT.
# The axis limit must equal the placeholder that replaces NaN switching times
# (N_BATCHES*BLOCK_SIZE*2). The former BLOCK_SIZE * (N_BATCHES + 1) only agrees
# with it for N_BATCHES == 1 and would push never-switched runs off the plot otherwise.
T = BLOCK_SIZE * N_BATCHES * 2
plt.figure(figsize=(6,6))
plt.scatter(switch_times['gt'], switch_times['gr'], label = 'GR vs GT')
plt.scatter(switch_times['gt'], switch_times['em'], color = 'r', label = 'EM vs GT')
# Identity line: points on it switched at the same time as GT
plt.plot([0, T], [0, T], '--')
plt.xlabel('GT switching time')
plt.ylabel('GR/EM switching time')
plt.xlim([0, T])
plt.ylim([0, T])
# Join corresponding points with lines: the GR and EM points of a run share the
# same x (its GT switching time), so each connector is a vertical segment.
for gt_time, gr_time, em_time in zip(switch_times['gt'], switch_times['gr'], switch_times['em']):
    plt.plot([gt_time, gt_time], [gr_time, em_time], color = 'gray', linewidth = 0.5)
plt.legend()

plt.show()

In [None]:
# GR vs EM switching time, with Gaussian jitter so coinciding points stay visible.
# `epsilon` (jitter standard deviation) is defined in the jittered GT scatter cell above.
# The axis limit must equal the placeholder that replaces NaN switching times
# (N_BATCHES*BLOCK_SIZE*2). The former BLOCK_SIZE * (N_BATCHES + 1) only agrees
# with it for N_BATCHES == 1 and would push never-switched runs off the plot otherwise.
T = BLOCK_SIZE * N_BATCHES * 2
plt.figure(figsize=(6,6))
plt.scatter(switch_times['gr'] + epsilon*np.random.randn(N_RUNS), switch_times['em'] + epsilon*np.random.randn(N_RUNS))
# Identity line: points on it switched at the same time under GR and EM
plt.plot([0, T], [0, T], '--')
plt.xlabel('GR switching time')
plt.ylabel('EM switching time')
plt.xlim([0, T])
plt.ylim([0, T])
plt.show()

In [None]:
# Index of the run to inspect in detail
i = 5
# results[agent][i] was computed on datasets[i], so plot against that dataset.
# Previously T came from the leaked loop variable `data` (the last dataset) and
# the `data` argument was datasets[0], neither of which is run i.
run_data = datasets[i]
learning_dicts = [results[agent][i] for agent in ('gt', 'gr', 'em')]
hd.plot_mllh_curves_subpanels(learning_dicts, model_set, T=len(run_data['c']), color_dict=model_colors_gergo, figsize=(15,4), data=run_data, markersize=5)

In [None]:
# Plot mllhs of run `i` (selected in the preceding cell), one figure per agent.
# results[agent][i] was computed on datasets[i], so that dataset is passed,
# rather than datasets[0] or the leaked loop variable `data`.
run_data = datasets[i]
for agent in ('gt', 'gr', 'em'):
    hd.plot_mmllh_curves(results[agent][i], model_set, T=len(run_data['c']), color_dict=model_colors_gergo, figsize=(15,4), data=run_data)