In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

import helper_mate as h
import helper as hd
from colorscheme import *

from gt_learner import GT_learner
import gr_em_learner as gr_em

In [None]:
# Data parameters
BLOCK_SIZE = 10        # passed to h.generate_batch_data
N_BATCHES = 1          # passed to h.generate_batch_data
ALPHA_LIST = [0, 90]   # task angles; also given to the GR learner as task_angles_in_data
N_RUNS = 10            # number of independently generated datasets

# Agent parameters
SIGMA_R = 1.0          # passed to both learners
PP_THRESHOLD = 0.4     # given to the GR learner as pp_thr
D = 10                 # given to the GR learner as D

# Seed NumPy's global RNG so the generated datasets are reproducible under Restart & Run All.
# NOTE(review): only effective if helper_mate / the learners draw from np.random — confirm;
# TensorFlow-based sampling would additionally need tf.random.set_seed.
SEED = 42
np.random.seed(SEED)

In [None]:
# Generate N_RUNS datasets with identical parameters.
# (presumably the generator is stochastic — otherwise every run would be the same dataset)
datasets = [h.generate_batch_data(ALPHA_LIST, BLOCK_SIZE, N_BATCHES) for _ in range(N_RUNS)]

# Models compared by both learners; the last entry is used as the target model
# when computing switching times.
model_set = ['x', 'y', '1x2D', '2x1D_bg']

# Fit both learners (GT and GR-EM) to every dataset; results[agent][run] holds one learner output.
results = {'gt': [], 'gr': []}
for data in datasets:
    results['gt'].append(GT_learner(data, SIGMA_R, model_set))
    results['gr'].append(gr_em.GR_EM_learner(
        data, SIGMA_R, model_set,
        verbose=True,
        EM_size_limit=0,
        pp_thr=PP_THRESHOLD,
        D=D,
        task_angles_in_data=ALPHA_LIST,
    ))

In [None]:
# Compute switching times: for every agent and run, hd.model_change_time evaluated
# against the target model (last entry of model_set).
# Iterating the stored results (rather than range(N_RUNS)) keeps this cell correct even if
# N_RUNS is edited without re-running the fitting cell.
target_model = model_set[-1]
switch_times = {
    agent: [hd.model_change_time(run_result, target_model) for run_result in results[agent]]
    for agent in ['gt', 'gr']
}

In [None]:
# Scatter plot of per-run switching times, GT learner vs GR learner.
# Points on the dashed diagonal are runs where both agents switched at the same time.
# NOTE(review): assumes each dataset spans BLOCK_SIZE * (N_BATCHES + 1) trials —
# confirm against len(datasets[0]['c']).
T = BLOCK_SIZE * (N_BATCHES + 1)
fig, ax = plt.subplots(figsize=(6, 6))
ax.scatter(switch_times['gt'], switch_times['gr'])
ax.plot([0, T], [0, T], '--')  # identity line
ax.set(
    xlabel='GT switching time',
    ylabel='GR switching time',
    xlim=(0, T),
    ylim=(0, T),
    title=f'Switching time to {model_set[-1]} per run',
)
plt.show()

In [None]:
# Plot the mmllh curves of both agents for one example run.
# T is taken from the dataset of that same run — the previous version used `data`,
# which had leaked from the fitting loop and therefore always referred to the LAST dataset.
i = 1
T_run = len(datasets[i]['c'])
for agent in ['gt', 'gr']:
    hd.plot_mmllh_curves(results[agent][i], model_set, T=T_run,
                         color_dict=model_colors_gergo, figsize=(15, 4))