In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import copy

import helper_mate as h
import helper as hd
from colorscheme import *

from gt_learner import GT_learner

In [None]:
# Data parameters
# BLOCK_SIZE and N_BATCHES are forwarded unchanged to h.generate_batch_data
# (2nd and 3rd positional argument); their exact meaning is defined there.
# NOTE(review): presumably trials per block and number of blocks -- confirm in helper_mate.
BLOCK_SIZE = 5
N_BATCHES = 3
# Number of independent datasets generated and analysed; also the number of
# entries in every per-condition switch-time list.
N_RUNS = 100

# Agent parameters
# Passed as the 2nd positional argument to GT_learner and shown as sigma_r in
# the histogram title.
# NOTE(review): presumably the learner's noise standard deviation -- confirm in gt_learner.
SIGMA_R = 1.0

In [None]:
# Define rotation operator: the standard 2-D rotation matrix for ROTATION_ANGLE,
#     R = [[cos a, -sin a],
#          [sin a,  cos a]]
# so that np.dot(R, v) rotates a 2-vector v counter-clockwise by a.
# The top-right entry has to be -sin(a), not -cos(a); the two only happen to
# coincide (up to rounding) at a = 45 degrees.
ROTATION_ANGLE = np.pi/4
rotation_operator = [[np.cos(ROTATION_ANGLE), -np.sin(ROTATION_ANGLE)],
                     [np.sin(ROTATION_ANGLE), np.cos(ROTATION_ANGLE)]]

# Define conditions: the keys under which each dataset stores its two variants
conditions = ['cardinal', 'diagonal']

def generate_dataset():
    '''Generate one dataset holding a cardinal and a diagonal (rotated) condition.

    The raw batch data comes from ``h.generate_batch_data([0,90], BLOCK_SIZE,
    N_BATCHES)`` and is expected to be a dict with at least a ``"z"`` entry
    whose items can be multiplied by the 2x2 ``rotation_operator``.

    Returns:
        The raw data dict, extended with two independent deep copies:
        ``data['cardinal']`` (unmodified raw data) and ``data['diagonal']``
        (raw data with every ``"z"`` item rotated by ``rotation_operator``).
        The raw top-level keys are kept for backward compatibility.
    '''
    # Generate initial data
    data = h.generate_batch_data([0,90], BLOCK_SIZE, N_BATCHES)

    # Take both per-condition copies *before* adding any condition key to
    # `data`; otherwise the second copy would contain the first one nested
    # inside it (diagonal would carry a stray 'cardinal' entry).
    cardinal = copy.deepcopy(data)
    diagonal = copy.deepcopy(data)

    # Rotate to get diagonal data (item by item, in place on the copy)
    for i in range(np.size(diagonal["z"], 0)):
        diagonal["z"][i] = np.dot(rotation_operator, diagonal["z"][i])

    data['cardinal'] = cardinal
    data['diagonal'] = diagonal

    return data

def generate_datasets(n_datasets):
    '''Return a list of `n_datasets` datasets, each from its own generate_dataset() call'''
    return [generate_dataset() for _ in range(n_datasets)]

# Build all N_RUNS datasets up front; every later cell reads from this list
datasets = generate_datasets(n_datasets=N_RUNS)

In [None]:
# Plot first dataset
# Visual check of a single example run: one figure per condition.
# `i` and `data` are notebook-level names; later cells rebind them before use.
i = 0
data = datasets[i]

for condition in conditions:
    # hd.plot_data is a project helper; the keyword arguments are passed through as-is
    hd.plot_data(data[condition], labels=False, limit=2.5, figsize=(4,4))
    # Show inside the loop so each condition is displayed separately
    plt.show()


In [None]:
# Define model sets: candidate model names handed to GT_learner per condition.
# The LAST entry of each list is the model whose change time is extracted below.
model_set = {'cardinal': ['x', 'y', '1x2D', '2x1D_bg'], 'diagonal': ['x', 'y', '1x2D', '2x2D_bg']}

# For each condition, compute mllhs
# One dict per dataset, keyed by condition, holding whatever GT_learner returns.
results = []
for data in datasets:
    result = {}
    for condition in conditions:
        result[condition] = GT_learner(data[condition], SIGMA_R, model_set[condition])
    results.append(result)

# Compute switching times
# Iterate over `results` itself rather than range(N_RUNS): the list length is
# the source of truth, so a stale or edited N_RUNS can neither raise an
# IndexError nor silently drop runs.
switch_times = {}
for condition in conditions:
    switch_times[condition] = [hd.model_change_time(run_result[condition], model_set[condition][-1])
                               for run_result in results]


In [None]:
# Plot evolution of mllh for each condition
# Uses the first run only; `data` and `result` are rebound here because the
# analysis loop above leaves them pointing at the last run.
i = 0
data = datasets[i]
result = results[i]
for condition in conditions:
    # T is the length of this condition's 'c' entry.
    # NOTE(review): presumably the number of trials -- confirm against hd.plot_mmllh_curves.
    hd.plot_mmllh_curves(result[condition], model_set[condition],
                            T=len(data[condition]['c']), color_dict=model_colors_gergo, figsize=(15,4))

In [None]:
# Plot histogram of switching times
# Both conditions are overlaid on the same axes (alpha=0.5), each coloured by
# the last model of its model set.
# NOTE(review): bin edges are fixed at 0..29 in unit steps; switch times outside
# that range would not be drawn -- confirm this matches the number of trials.
for condition in conditions:
    plt.hist(switch_times[condition], bins=np.arange(0, 30, 1), alpha=0.5,
             label=condition, color=modelColors[model_set[condition][-1]])
plt.legend()
# Raw string: "\s" is not a valid Python escape sequence, so the non-raw form
# triggers an invalid-escape warning; the resulting title text is unchanged.
plt.title(r"$\sigma_r$=" + str(SIGMA_R) + ", block size =" + str(BLOCK_SIZE))
plt.ylabel('count')
plt.xlabel('time of model discovery')
plt.show()