In [None]:
import tensorflow as tf
from tensorflow_probability import distributions as tfd
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
import helper

import matplotlib as mpl
import pickle
from copy import deepcopy
import time
#from graphviz import Digraph
import itertools
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import pandas as pd
import seaborn as sns
from scipy.stats import multivariate_normal

In [None]:
# Low reward-noise std for generating ground-truth data.
SIGMA_R = 1e-3

# Reward-noise std assumed by the models; sig_r_model is the reward noise used
# when generating dream data.
# NOTE (original author): these are not implemented correctly.
sigma_r = 0.3
sig_r_model = sigma_r

# Half-width of the uniform z prior.
sigma_z_prior = 1

# Number of most recent points held in episodic memory (EM); task y has exactly
# that many points.
EM_SIZE = 2
Tx, Ty = 5, EM_SIZE

# z prior for the true data, either "uniform" or "informative"; dreams reuse it.
Z_PRIOR = "uniform"
DREAM_DATA_Z_PRIOR = Z_PRIOR

# Load a previously saved dataset instead of sampling one
# (alternative file name used before: "demo_data.npy").
USE_SAVED_DATA = False
DATA_FILENAME = "demo_data_sigr_07.npy"

In [None]:
def _sokatmondo_z_prior():
  '''Return the z-prior distribution selected by the global Z_PRIOR setting.

  Raises:
    ValueError: if Z_PRIOR is neither "informative" nor "uniform".
  '''
  if Z_PRIOR == "informative":
    return tfd.MultivariateNormalDiag(loc=[-1,1], scale_diag=[.3,.3])
  if Z_PRIOR == "uniform":
    return tfd.Uniform([-sigma_z_prior,-sigma_z_prior],[sigma_z_prior,sigma_z_prior])
  raise ValueError(f"Unknown Z_PRIOR {Z_PRIOR!r}; expected 'informative' or 'uniform'")


def _sokatmondo_task_data(alpha, T, sigma_reward):
  '''Sample T points of one task: z from the Z_PRIOR prior, r = sum(gamma(alpha) * z) + noise.'''
  gamma = helper.gamma_from_alpha(alpha)
  z = np.array(_sokatmondo_z_prior().sample(T))
  r_noise = tfd.Normal(0, sigma_reward).sample(T)
  r_mean = tf.reduce_sum(tf.multiply(gamma,z),1)
  return {'z': z, 'r': r_mean + r_noise}


def sokatmondo_adat(Tx=Tx, Ty=Ty, alpha_x=90, alpha_y=0, sigma_reward=SIGMA_R):
  '''Generate the demo dataset: Tx points of task "x" followed by Ty points of task "y".

  Args:
    Tx: number of task-x points.
    Ty: number of task-y points.
    alpha_x: angle passed to helper.gamma_from_alpha for task x
      (the Flesch 2018 variant used 0).
    alpha_y: angle passed to helper.gamma_from_alpha for task y
      (the Flesch 2018 variant used 270).
    sigma_reward: std of the Gaussian reward noise (default SIGMA_R, i.e. .001).

  Returns:
    z: stacked z points (task x first, then task y).
    r: numpy array of the corresponding rewards.
    xylabels: list of 'x' / 'y' task labels aligned with z and r.

  Raises:
    ValueError: if the global Z_PRIOR is not "informative" or "uniform".
  '''
  datax = _sokatmondo_task_data(alpha_x, Tx, sigma_reward)
  datay = _sokatmondo_task_data(alpha_y, Ty, sigma_reward)

  data = concatenate_data(datax, datay)
  z = data['z']
  r = np.array(data['r'])

  xylabels = ['x']*Tx + ['y']*Ty
  return z, r, xylabels

def concatenate_data(data1, data2):
  '''Join two {'z', 'r'} datasets end to end.

  The 'z' arrays are stacked along axis 0. The 'r' values are first converted to
  numpy arrays (callers may pass TensorFlow tensors) and then concatenated.

  Returns:
    dict with keys 'z' and 'r'.
  '''
  stacked_z = np.concatenate([data1['z'], data2['z']], axis=0)
  stacked_r = np.concatenate([np.array(part['r']) for part in (data1, data2)])
  return dict(z=stacked_z, r=stacked_r)

def plot_data_xy_labels(data, labels):
    '''Scatter-plot data['z'] coloured by data['r'] and annotate every point with a label.

    Args:
      data: dict with 'z' (rows of (z_1, z_2) points) and 'r' (one value per
        point, used as the scatter colour).
      labels: one annotation string per point (e.g. the 'x' / 'y' task labels);
        annotation stops at the shorter of labels and the points.
    '''
    points = data['z']
    plt.scatter(*points.T, c=data['r'])
    plt.gca().set_aspect('equal')
    plt.colorbar()
    plt.xlabel('z_1')
    plt.ylabel('z_2')
    # reference lines through the origin
    plt.axhline(y = 0)
    plt.axvline(x = 0)
    for text, px, py in zip(labels, points[:, 0], points[:, 1]):
        plt.annotate(
            text,
            xy=(px, py),
            xytext=(-20, 20),
            textcoords='offset points',
            ha='right',
            va='bottom',
            bbox=dict(boxstyle='round,pad=0.5', fc='yellow', alpha=0.5),
            arrowprops=dict(arrowstyle = '->', connectionstyle='arc3,rad=0'))

In [None]:
# Disabled cell: the plotting demo below is wrapped in a string literal
# (terminated by ''';) so none of it runs. When enabled, it would plot 1000
# generated points for each of the two tasks (alpha=90 and alpha=0) side by side.
'''#helper.generate_data(Tx, alpha=90, z_prior_type='uniform', sigma_reward=SIGMA_R)
#helper.generate_data(Tx, alpha=0, z_prior_type='uniform', sigma_reward=SIGMA_R)

plt.figure(figsize=(3,2))
plt.subplot(1,2,1)
true_data = helper.generate_data(1000, alpha=90, z_prior_type='uniform', sigma_reward=SIGMA_R)
helper.plot_data(true_data,limit=1, axislabels=False)

plt.tick_params(
    axis='both',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    top=False,         # ticks along the top edge are off
    labelbottom=False,
    left=False,
    labelleft=False) # labels along the bottom edge are off
plt.title("task 1")

plt.subplot(1,2,2)
true_data = helper.generate_data(1000, alpha=0, z_prior_type='uniform', sigma_reward=SIGMA_R)
helper.plot_data(true_data,limit=1, axislabels=False)

plt.tick_params(
    axis='both',          # changes apply to the x-axis
    which='both',      # both major and minor ticks are affected
    bottom=False,      # ticks along the bottom edge are off
    top=False,         # ticks along the top edge are off
    labelbottom=False,
    left=False,
    labelleft=False) # labels along the bottom edge are off
plt.title("task 2")
plt.show()''';

In [None]:
# Build the "true" dataset: either load a previously saved one or sample a fresh one.
if USE_SAVED_DATA:
    # NOTE: allow_pickle=True lets np.load unpickle arbitrary objects (which can run
    # code); only load files you created yourself.
    true_data = np.load(DATA_FILENAME,allow_pickle=True).item()
    # keep the z / r globals consistent with the loaded data
    z, r = true_data['z'], true_data['r']
    # labels for the saved demo file: 5 task-x points followed by 5 task-y points
    xylabels = ['x', 'x', 'x', 'x', 'x', 'y', 'y', 'y', 'y', 'y']
else:
    z, r, xylabels = sokatmondo_adat(Tx,Ty)
    true_data = {'z':z, 'r':r}

In [None]:
#plot_data_xy_labels(true_data, xylabels)

In [None]:
def generate_data_from_posterior_samples(N=100, gamma=0, z_prior_type='uniform', sigma_z_prior=1, r_bias=0, sigma_reward=0.1, sigma_bias=0):
    '''Sample z points and rewards r = sum(gamma * z) + r_bias + noise; return only the FIRST point.

    N points are drawn but only (z[0], r[0]) is returned, which is why callers
    such as generate_dream_data_from_posterior_samples pass N=1.

    Args:
      N: number of points drawn.
      gamma: weights multiplied elementwise with each z and summed over the last axis.
      z_prior_type: 'normal' (zero-mean diagonal Gaussian with std sigma_z_prior),
        'uniform' (box [-sigma_z_prior, sigma_z_prior] in both dimensions) or
        'informative' (diagonal Gaussian centred on [-1, 1] with std .3).
      sigma_z_prior: scale of the 'normal' and 'uniform' priors.
      r_bias: constant added to every reward mean.
      sigma_reward: std of the Gaussian reward noise.
      sigma_bias: accepted for signature compatibility; currently unused.

    Returns:
      (z[0], r[0]): the first sampled point and its reward.

    Raises:
      ValueError: if z_prior_type is not one of the three supported values.
    '''
    if z_prior_type == 'normal':
        z_prior = tfd.MultivariateNormalDiag(loc=[0,0], scale_diag=[sigma_z_prior,sigma_z_prior])
    elif z_prior_type == 'uniform':
        z_prior = tfd.Uniform([-sigma_z_prior,-sigma_z_prior],[sigma_z_prior,sigma_z_prior])
    elif z_prior_type == 'informative':
        z_prior = tfd.MultivariateNormalDiag(loc=[-1,1], scale_diag=[.3,.3])
    else:
        raise ValueError(
            f"Unknown z_prior_type {z_prior_type!r}; expected 'normal', 'uniform' or 'informative'")

    z = np.array(z_prior.sample(N))

    r_noise = tfd.Normal(0, sigma_reward).sample(N)
    # cast to float32 (presumably to match the noise sample; gamma may be float64)
    r_mean = tf.cast(tf.reduce_sum(tf.multiply(gamma,z),1),dtype=tf.float32) + r_bias
    r = r_mean + r_noise

    return z[0],r[0]
  
def generate_dream_data_from_posterior_samples(samples,z_prior_type):
    '''Build one dream dataset containing a single (z, r) point per gamma sample.

    Each gamma in ``samples`` is passed to generate_data_from_posterior_samples
    with N=1. The reward-noise std is taken from the global ``sig_r_model``.

    Returns:
      dict with 'z' (stacked z points) and 'r' (their rewards) as numpy arrays.
    '''
    points = [
        generate_data_from_posterior_samples(N=1, gamma=gamma, z_prior_type=z_prior_type, sigma_reward=sig_r_model)
        for gamma in samples
    ]
    dream_z = np.array([point[0] for point in points])
    dream_r = np.array([point[1] for point in points])
    return {'z': dream_z, 'r': dream_r}

def generate_dream_data_set(posterior, T=10, N=2, z_prior_type='uniform'):
    '''Generate N independent dream datasets of T points each.

    Every dataset gets its own fresh batch of T gamma samples, drawn with
    posterior.sample(T). Each batch is turned into points by
    generate_dream_data_from_posterior_samples.

    Returns:
      list of N dicts with keys 'z' and 'r'.
    '''
    return [
        generate_dream_data_from_posterior_samples(posterior.sample(T), z_prior_type)
        for _ in range(N)
    ]

In [None]:
# mllh on 2D 2 task model
# these are from particle_filter.ipynb but is it correct to just mllh?

from itertools import chain, combinations
def powerset(iterable):
    '''Lazily yield every subset of ``iterable`` as a tuple, ordered by size then position.

    Example: powerset([1, 2]) -> (), (1,), (2,), (1, 2)
    '''
    pool = list(iterable)
    sizes = range(len(pool) + 1)
    return chain.from_iterable(map(lambda k: combinations(pool, k), sizes))

def _log_mllh_linear_gaussian(Z, rs, sigma_r, prior_cov):
    '''Log marginal likelihood of rs under r_t = gamma . Z[t] + N(0, sigma_r**2), gamma ~ N(0, prior_cov).

    Bayes' rule holds for any gamma; evaluating it at gamma = 0 gives
        log p(r) = log p(gamma=0) + sum_t log N(r_t; 0, sigma_r**2) - log p(gamma=0 | r).
    The Gaussian posterior N(mu, cov) is built by sequential conjugate updates.
    Working in log-space keeps the ratio finite when the individual densities
    underflow (e.g. very small sigma_r).
    '''
    origin = np.zeros(prior_cov.shape[0])
    log_y = multivariate_normal.logpdf(origin, mean=origin, cov=prior_cov)
    precision = np.linalg.inv(prior_cov)
    mu = origin
    cov = prior_cov
    for z, r in zip(Z, rs):
        # each observation adds z z^T / sigma_r^2 to the posterior precision
        new_precision = precision + np.outer(z, z) / sigma_r**2
        cov = np.linalg.inv(new_precision)
        mu = cov.dot(z * r / sigma_r**2 + precision.dot(mu))
        precision = new_precision
        log_y += multivariate_normal.logpdf(r, mean=0, cov=sigma_r**2)
    log_y -= multivariate_normal.logpdf(mu, mean=origin, cov=cov)
    return log_y


def model_marginal_llh_analytic(zs, rs, sigma_r, Sigma_0 = np.array([[1., 0.], [0., 1.]]), model = '2d', return_log = False):
    '''
    Analytic marginal likelihood of a single linear-Gaussian reward task,
        r_t = gamma . z_t + N(0, sigma_r**2),   gamma ~ N(0, prior),
    with gamma integrated out.

    The derivation was validated through 'trial_nonorm_posterior_set_transformed';
    gamma is marginalized out via its Gaussian posterior.

    Args:
      zs: array of z points (one row per observation).
      rs: one reward per row of zs.
      sigma_r: reward-noise standard deviation.
      Sigma_0: for model '2d', the 2x2 prior covariance of gamma. For the 1-D
        models, the scalar prior standard deviation of the single gamma component.
      model: '2d' uses both columns of zs; 'x' uses only zs[:, 1]; any other value
        (e.g. 'y') uses only zs[:, 0].
      return_log: if True, return the log marginal likelihood instead.

    Returns:
      The marginal likelihood (or its log). An empty zs gives 1. (log: 0.).

    Raises:
      AssertionError: if Sigma_0 is scalar for '2d', or non-scalar for the 1-D models.
    '''
    zs = np.asarray(zs, dtype=float)
    if zs.size == 0:
        return 0. if return_log else 1.
    rs = np.asarray(rs, dtype=float)

    if model == '2d':
        assert not np.isscalar(Sigma_0), 'Sigma_0 must be a 2-dimensional array'
        Z = zs
        prior_cov = np.asarray(Sigma_0, dtype=float)
    else:
        # For the 1-D models Sigma_0 is the standard deviation of the gamma prior.
        assert np.isscalar(Sigma_0), 'Sigma_0 must be scalar'
        column = 1 if model == 'x' else 0
        Z = zs[:, [column]]
        prior_cov = np.array([[float(Sigma_0)**2]])

    log_y = _log_mllh_linear_gaussian(Z, rs, sigma_r, prior_cov)
    return log_y if return_log else np.exp(log_y)

def model_marginal_llh_analytic_2x2D(z, r, sigma_r, Sigma_0_2D = np.array([[1., 0.], [0., 1.]]), verbose = True):
    '''Marginal likelihood of a two-task model in which each point belongs to one of two 2-D tasks.

    For every one of the 2**T ways to split the T points into (task 1, task 2),
    multiply the two single-task marginal likelihoods. These come from
    model_marginal_llh_analytic with model='2d' and prior covariance Sigma_0_2D.
    The products are summed and divided by 2**T. The cost therefore grows as 2**T.

    Args:
      z: numpy array of points, one row per observation (indexed with integer lists).
      r: numpy array of the corresponding rewards.
      sigma_r: reward-noise standard deviation.
      Sigma_0_2D: prior covariance used for each task's gamma.
      verbose: if True, show a tf.keras progress bar over the splits.

    Returns:
      The averaged marginal likelihood as a float.
    '''
    n_points = z.shape[0]
    all_indices = np.arange(n_points)
    splits = list(powerset(all_indices))

    progress = tf.keras.utils.Progbar(len(splits)) if verbose else None

    total = 0.
    for task1 in splits:
        task1_idx = list(task1)
        task2_idx = [idx for idx in all_indices if idx not in task1]
        llh_task1 = model_marginal_llh_analytic(z[task1_idx], r[task1_idx], sigma_r, Sigma_0 = Sigma_0_2D, model = '2d')
        llh_task2 = model_marginal_llh_analytic(z[task2_idx], r[task2_idx], sigma_r, Sigma_0 = Sigma_0_2D, model = '2d')
        total += llh_task1 * llh_task2
        if progress is not None:
            progress.add(1)

    return total / 2**n_points

In [None]:
# Split the generated data into its two tasks: the first Tx points belong to
# task x, the remaining points to task y.
data_x = {'z': true_data['z'][:Tx], 'r': true_data['r'][:Tx]}
data_y = {'z': true_data['z'][Tx:], 'r': true_data['r'][Tx:]}

# Two presentation orders over the Tx+Ty points.
# Blocked: all task-x points first, then all task-y points.
blocked_indices = np.arange(Tx+Ty)
# Interleaved: task-x and task-y indices combined by helper.riffle
# (presumably alternating them; see helper).
interleaved_indices = helper.riffle(np.arange(Tx), np.arange(Tx, Tx+Ty))

data_blocked = {key: true_data[key][blocked_indices] for key in ('z', 'r')}
data_interleaved = {key: true_data[key][interleaved_indices] for key in ('z', 'r')}

In [None]:
# Show the blocked dataset (all task-x points first, then task-y points).
data_blocked

In [None]:
# Show the interleaved dataset (indices combined by helper.riffle).
data_interleaved

In [None]:
# Exploratory view of the blocked order. The first Tx+EM_SIZE points are
# "presented"; of those, the last EM_SIZE form the episodic memory (EM) and the
# rest are past data. deepcopy keeps these slices independent of data_blocked.
data_presented = deepcopy(data_blocked)

data_presented = {'z':data_presented['z'][:Tx+EM_SIZE], 'r':data_presented['r'][:Tx+EM_SIZE]}
data_past = {'z':data_presented['z'][:-EM_SIZE], 'r':data_presented['r'][:-EM_SIZE]}
data_EM = {'z':data_presented['z'][-EM_SIZE:], 'r':data_presented['r'][-EM_SIZE:]}

In [None]:
data_presented

In [None]:
# rewards in presentation order
plt.plot(data_presented["r"])

In [None]:
# z points coloured by presentation order (0 = first presented)
plt.scatter(data_presented["z"].T[0],data_presented["z"].T[1],c=range(len(data_presented["r"])))

In [None]:
# Same exploratory split, now for the interleaved order.
data_presented = deepcopy(data_interleaved)

data_presented = {'z':data_presented['z'][:Tx+EM_SIZE], 'r':data_presented['r'][:Tx+EM_SIZE]}
data_past = {'z':data_presented['z'][:-EM_SIZE], 'r':data_presented['r'][:-EM_SIZE]}
data_EM = {'z':data_presented['z'][-EM_SIZE:], 'r':data_presented['r'][-EM_SIZE:]}

In [None]:
data_presented

In [None]:
# rewards in presentation order
plt.plot(data_presented["r"])

In [None]:
# z points coloured by presentation order (0 = first presented)
plt.scatter(data_presented["z"].T[0],data_presented["z"].T[1],c=range(len(data_presented["r"])))

In [None]:
# BLOCKED condition:
# 1. fit the 1-task model on past data;
# 2. dream from its posterior;
# 3. compare 1-task vs 2-task marginal likelihoods on dreams (+EM) and on the true data.

# Presented data = first Tx+EM_SIZE points. The last EM_SIZE of them are the
# episodic-memory (EM) contents; the points before them are the "past" data.
data_presented = {'z':data_blocked['z'][:Tx+EM_SIZE], 'r':data_blocked['r'][:Tx+EM_SIZE]}
data_past = {'z':data_presented['z'][:-EM_SIZE], 'r':data_presented['r'][:-EM_SIZE]}
data_EM = {'z':data_presented['z'][-EM_SIZE:], 'r':data_presented['r'][-EM_SIZE:]}

# Fit 2D 1 task model on the past data (prior covariance 10*I).
# NOTE(review): the marginal-likelihood comparisons below use a unit prior (np.eye(2)); confirm the mismatch is intended.
posterior_params = helper.gamma_posterior_analytic(data_past['z'], data_past['r'], sigma_r, Sigma_0=10*np.eye(2))
posterior = tfd.MultivariateNormalFullCovariance(loc=posterior_params[0], covariance_matrix=posterior_params[1])

# Generate N dream data sets. generate_dream_data_set draws a fresh batch of T
# gamma samples from the posterior for every dream set, so no separate sampling
# is needed here.
dreams = generate_dream_data_set(posterior, T=5, N=10, z_prior_type=DREAM_DATA_Z_PRIOR)

# append the episodic memories to every dream
dreams = [
    {'z': np.concatenate([dream['z'], data_EM['z']]),
     'r': np.concatenate([dream['r'], data_EM['r']])}
    for dream in dreams
]

# mllh on 2D 1 task model
onetask_mllhs_dream = [helper.model_marginal_llh_analytic(dream['z'], dream['r'], sigma_r, Sigma_0=np.eye(2)) for dream in dreams]

# mllh on 2D 2 task model
twotask_mllhs_dream = [model_marginal_llh_analytic_2x2D(dream['z'], dream['r'], sigma_r, Sigma_0_2D = np.array([[1., 0.], [0., 1.]])) for dream in dreams]

# mllhs on ground truth data
onetask_mllh = helper.model_marginal_llh_analytic(data_presented['z'], data_presented['r'], sigma_r, Sigma_0=np.eye(2))
twotask_mllh = model_marginal_llh_analytic_2x2D(data_presented['z'], data_presented['r'], sigma_r, Sigma_0_2D = np.array([[1., 0.], [0., 1.]]))

# debug: presented rewards and both true-data marginal likelihoods
print(data_presented['r'])
print(onetask_mllh)
print(twotask_mllh)

In [None]:
# BLOCKED results.
# Figure 1: presented data, past data, and the gamma posterior.
# Figure 2: histograms / summaries of the dream marginal likelihoods against the true-data values.
plt.figure(figsize=(10,4))
plt.subplots_adjust(wspace=0.2, hspace=0.4)
plt.suptitle("BLOCKED")

plt.subplot(1,3,1)
#plot_data_xy_labels(true_data, xylabels)
helper.plot_data(data_presented, axislabels=False)
plt.title("true data")

plt.subplot(1,3,2)
#plot_data_xy_labels(true_data, xylabels)
helper.plot_data(data_past, axislabels=False)
plt.title("past data (without EM)")

plt.subplot(1,3,3)
# NOTE: overwrites the global `samples` with 300 posterior draws; the next
# gamma-sample grid cell indexes the first 20 of them.
samples = posterior.sample(300)
sns.kdeplot(x=samples[:, 0], y=samples[:, 1])
plt.scatter(x=samples[:, 0], y=samples[:, 1])
plt.xlim([-3,3])
plt.ylim([-3,3])
plt.gca().set_aspect('equal')
plt.title("gamma posterior")
plt.show()

plt.figure(figsize=(10,10))
plt.subplots_adjust(wspace=0.4, hspace=0.4)

# Top row: both models overlaid. One-task lines are gray (true) / lightblue
# (dream avg); two-task lines are black (true) / darkblue (dream avg).
plt.subplot(3,1,1)
plt.hist(onetask_mllhs_dream,100)
plt.axvline(onetask_mllh, color='gray', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(onetask_mllhs_dream), color='lightblue', linestyle='dashed', linewidth=1)
plt.xlabel("mllh value, true (black dashed), avg (blue dashed)")
plt.ylabel("occurences")

plt.hist(twotask_mllhs_dream,100)
plt.axvline(twotask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(twotask_mllhs_dream), color='darkblue', linestyle='dashed', linewidth=1)
plt.title("both tasks, 2 task (dark), 1 task (light)")

# Middle row: each model separately.
plt.subplot(3,2,3)
plt.hist(onetask_mllhs_dream,100)
plt.axvline(onetask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(onetask_mllhs_dream), color='b', linestyle='dashed', linewidth=1)
plt.xlabel("mllh value, true (black dashed), avg (blue dashed)")
plt.ylabel("occurences")
plt.title("1 x 2D task")

plt.subplot(3,2,4)
plt.hist(twotask_mllhs_dream,100)
plt.axvline(twotask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(twotask_mllhs_dream), color='b', linestyle='dashed', linewidth=1)
plt.title("2 x 2D task")

# Bottom row: true vs average-dream mllh for the two models.
plt.subplot(3,3,7)
plt.plot(["1x2D","2x2D"],[onetask_mllh, twotask_mllh])
plt.plot(["1x2D","2x2D"],[np.mean(onetask_mllhs_dream), np.mean(twotask_mllhs_dream)])
plt.legend(["true","avg dream"])

plt.subplot(3,3,8)
plt.bar(["1x2D","2x2D"],[onetask_mllh, twotask_mllh])
plt.title("mllh on true data")

plt.subplot(3,3,9)
plt.bar(["1x2D","2x2D"],[np.mean(onetask_mllhs_dream), np.mean(twotask_mllhs_dream)])
plt.title("avg mllh on dream data")
plt.show()


In [None]:
# Grid of datasets simulated from individual gamma posterior samples: one panel
# per sample, using the first n_vertical*n_horizontal entries of `samples`.
n_vertical = 4
n_horizontal = 5

plt.figure(figsize=(6,5))
for panel in range(n_vertical * n_horizontal):
    plt.subplot(n_vertical, n_horizontal, panel + 1)
    simulated = helper.generate_data_from_gamma(
        N=600, gamma=samples[panel], z_prior_type='uniform',
        sigma_z_prior=1.5, r_bias=0, sigma_reward=0.1, sigma_bias=0)
    helper.plot_data(simulated, limit=1.75, axislabels=False)
    # hide all ticks and tick labels on both axes
    plt.tick_params(axis='both', which='both',
                    bottom=False, top=False, labelbottom=False,
                    left=False, labelleft=False)



In [None]:
# INTERLEAVED condition:
# 1. fit the 1-task model on past data;
# 2. dream from its posterior;
# 3. compare 1-task vs 2-task marginal likelihoods on dreams (+EM) and on the true data.

# Presented data = first Tx+EM_SIZE points. The last EM_SIZE of them are the
# episodic-memory (EM) contents; the points before them are the "past" data.
data_presented = {'z':data_interleaved['z'][:Tx+EM_SIZE], 'r':data_interleaved['r'][:Tx+EM_SIZE]}
data_past = {'z':data_presented['z'][:-EM_SIZE], 'r':data_presented['r'][:-EM_SIZE]}
data_EM = {'z':data_presented['z'][-EM_SIZE:], 'r':data_presented['r'][-EM_SIZE:]}

# Fit 2D 1 task model on the past data (prior covariance 10*I).
# NOTE(review): the marginal-likelihood comparisons below use a unit prior (np.eye(2)); confirm the mismatch is intended.
posterior_params = helper.gamma_posterior_analytic(data_past['z'], data_past['r'], sigma_r, Sigma_0=10*np.eye(2))
posterior = tfd.MultivariateNormalFullCovariance(loc=posterior_params[0], covariance_matrix=posterior_params[1])

# Generate N dream data sets. generate_dream_data_set draws a fresh batch of T
# gamma samples from the posterior for every dream set, so no separate sampling
# is needed here.
dreams = generate_dream_data_set(posterior, T=5, N=10, z_prior_type=DREAM_DATA_Z_PRIOR)

# append the episodic memories to every dream
dreams = [
    {'z': np.concatenate([dream['z'], data_EM['z']]),
     'r': np.concatenate([dream['r'], data_EM['r']])}
    for dream in dreams
]

# mllh on 2D 1 task model
onetask_mllhs_dream = [helper.model_marginal_llh_analytic(dream['z'], dream['r'], sigma_r, Sigma_0=np.eye(2)) for dream in dreams]

# mllh on 2D 2 task model
twotask_mllhs_dream = [model_marginal_llh_analytic_2x2D(dream['z'], dream['r'], sigma_r, Sigma_0_2D = np.array([[1., 0.], [0., 1.]])) for dream in dreams]

# mllhs on ground truth data
onetask_mllh = helper.model_marginal_llh_analytic(data_presented['z'], data_presented['r'], sigma_r, Sigma_0=np.eye(2))
twotask_mllh = model_marginal_llh_analytic_2x2D(data_presented['z'], data_presented['r'], sigma_r, Sigma_0_2D = np.array([[1., 0.], [0., 1.]]))

# debug: presented rewards and both true-data marginal likelihoods
print(data_presented['r'])
print(onetask_mllh)
print(twotask_mllh)

In [None]:
# INTERLEAVED results.
# Figure 1: presented data, past data, and the gamma posterior.
# Figure 2: histograms / summaries of the dream marginal likelihoods against the true-data values.
plt.figure(figsize=(10,4))
plt.subplots_adjust(wspace=0.2, hspace=0.4)
plt.suptitle("INTERLEAVED")

plt.subplot(1,3,1)
#plot_data_xy_labels(true_data, xylabels)
helper.plot_data(data_presented, axislabels=False)
plt.title("true data")

plt.subplot(1,3,2)
#plot_data_xy_labels(true_data, xylabels)
helper.plot_data(data_past, axislabels=False)
plt.title("past data (without EM)")

plt.subplot(1,3,3)
# NOTE: overwrites the global `samples` with 300 posterior draws; later grid and
# arrow cells index into them.
samples = posterior.sample(300)
sns.kdeplot(x=samples[:, 0], y=samples[:, 1])
plt.scatter(x=samples[:, 0], y=samples[:, 1])
plt.xlim([-3,3])
plt.ylim([-3,3])
plt.gca().set_aspect('equal')
plt.title("gamma posterior")
plt.show()

plt.figure(figsize=(10,10))
plt.subplots_adjust(wspace=0.4, hspace=0.4)

# Top row: both models overlaid. One-task lines are gray (true) / lightblue
# (dream avg); two-task lines are black (true) / darkblue (dream avg).
plt.subplot(3,1,1)
plt.hist(onetask_mllhs_dream,100)
plt.axvline(onetask_mllh, color='gray', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(onetask_mllhs_dream), color='lightblue', linestyle='dashed', linewidth=1)
plt.xlabel("mllh value, true (black dashed), avg (blue dashed)")
plt.ylabel("occurences")

plt.hist(twotask_mllhs_dream,100)
plt.axvline(twotask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(twotask_mllhs_dream), color='darkblue', linestyle='dashed', linewidth=1)
plt.title("both tasks, 2 task (dark), 1 task (light)")

# Middle row: each model separately.
plt.subplot(3,2,3)
plt.hist(onetask_mllhs_dream,100)
plt.axvline(onetask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(onetask_mllhs_dream), color='b', linestyle='dashed', linewidth=1)
plt.xlabel("mllh value, true (black dashed), avg (blue dashed)")
plt.ylabel("occurences")
plt.title("1 x 2D task")

plt.subplot(3,2,4)
plt.hist(twotask_mllhs_dream,100)
plt.axvline(twotask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(twotask_mllhs_dream), color='b', linestyle='dashed', linewidth=1)
plt.title("2 x 2D task")

# Bottom row: true vs average-dream mllh for the two models.
plt.subplot(3,3,7)
plt.plot(["1x2D","2x2D"],[onetask_mllh, twotask_mllh])
plt.plot(["1x2D","2x2D"],[np.mean(onetask_mllhs_dream), np.mean(twotask_mllhs_dream)])
plt.legend(["true","avg dream"])

plt.subplot(3,3,8)
plt.bar(["1x2D","2x2D"],[onetask_mllh, twotask_mllh])
plt.title("mllh on true data")

plt.subplot(3,3,9)
plt.bar(["1x2D","2x2D"],[np.mean(onetask_mllhs_dream), np.mean(twotask_mllhs_dream)])
plt.title("avg mllh on dream data")
plt.show()


In [None]:
# Grid of datasets simulated from individual gamma posterior samples: one panel
# per sample, using the first n_vertical*n_horizontal entries of `samples`.
n_vertical = 4
n_horizontal = 5

plt.figure(figsize=(6,5))
for panel in range(n_vertical * n_horizontal):
    plt.subplot(n_vertical, n_horizontal, panel + 1)
    simulated = helper.generate_data_from_gamma(
        N=600, gamma=samples[panel], z_prior_type='uniform',
        sigma_z_prior=1.5, r_bias=0, sigma_reward=0.1, sigma_bias=0)
    helper.plot_data(simulated, limit=1.75, axislabels=False)
    # hide all ticks and tick labels on both axes
    plt.tick_params(axis='both', which='both',
                    bottom=False, top=False, labelbottom=False,
                    left=False, labelleft=False)



In [None]:
# One panel per dream dataset (needs len(dreams) >= n_vertical * n_horizontal).
n_vertical = 2
n_horizontal = 5

# open exactly one figure for the grid
plt.figure(figsize=(15,6))
for j in range(n_vertical):
    for i in range(n_horizontal):
        plt.subplot(n_vertical, n_horizontal,j*n_horizontal+i+1)
        helper.plot_data(dreams[j*n_horizontal+i], labels=False, limit=1.7, climit=1,axislabels=False, marker='+')

In [None]:
# Overlay the episodic-memory points (>) on each dream dataset (+), one panel per dream.
n_vertical = 2
n_horizontal = 5

plt.figure(figsize=(15,6))
for panel in range(n_vertical * n_horizontal):
    plt.subplot(n_vertical, n_horizontal, panel + 1)
    # optionally also show the past data:
    #helper.plot_data(data_past,marker='^', axislabels=False, limit=1.75)
    helper.plot_data(data_EM, marker='>', axislabels=False, limit=1.75)
    helper.plot_data(dreams[panel], marker='+', axislabels=False, limit=1.75)
    #plt.legend(["past data", "EM contents", "dreamed"])

In [None]:
# Single overlay of past data (^), EM contents (>) and the first dream dataset (+).
plt.figure(figsize=(10,5))
helper.plot_data(data_past,marker='^', axislabels=False, limit=1.75)
helper.plot_data(data_EM,marker='>', axislabels=False, limit=1.75)
helper.plot_data(dreams[0],marker='+', axislabels=False, limit=1.75)
# NOTE(review): legend labels are matched to artists in drawing order; confirm
# helper.plot_data adds exactly one legend-able artist per call.
plt.legend(["past data", "EM contents", "dreamed"])

In [None]:
#np.save("demo_data_sigr_07",true_data)

In [None]:
# Draw the first 100 gamma posterior samples as arrows from the origin.
ax = plt.axes()
for gamma_sample in samples[:100]:
    ax.arrow(0, 0, gamma_sample[0], gamma_sample[1], head_width=0.02, head_length=0.05, fc='k', ec='k')
ax.set_xlim([-1.75,1.75])
ax.set_ylim([-1.75,1.75])
ax.set_aspect('equal')
plt.show()