In [None]:
import tensorflow as tf
from tensorflow_probability import distributions as tfd
import tensorflow_probability as tfp
import numpy as np
import matplotlib.pyplot as plt
import helper

import matplotlib as mpl
import pickle
from copy import deepcopy
import time
#from graphviz import Digraph
import itertools
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
import pandas as pd
import seaborn as sns
from scipy.stats import multivariate_normal

In [None]:
# --- Reward noise levels -----------------------------------------------------
SIGMA_R = 0.001  # sigma_reward used when generating the ground-truth task data
SIGMA_R_MODEL = 0.3  # sigma handed to the analytic posterior / marginal-likelihood helpers
# sigma_reward for dream data; may be set separately from SIGMA_R_MODEL
# to decrease the variance in dreams.
SIGMA_R_DREAM = SIGMA_R_MODEL

# --- Dataset sizes -----------------------------------------------------------
EM_SIZE = 2  # number of trailing trials split off as episodic memory (EM)
Tx = 5  # number of task 1 trials generated
Ty = EM_SIZE  # number of task 2 trials generated

# --- Priors over z -----------------------------------------------------------
Z_PRIOR = "uniform"  # z_prior_type for the task data: "uniform" / "informative"
DREAM_DATA_Z_PRIOR = Z_PRIOR  # z_prior_type used when generating dream data
sigma_z_prior = 1

# --- Data source -------------------------------------------------------------
USE_SAVED_DATA = False  # True: load true_data from DATA_FILENAME instead of generating it
DATA_FILENAME = "demo_data_sigr_07.npy"  # alternative file: "demo_data.npy"

In [None]:
# Lower-case aliases of SIGMA_R_MODEL. The original author flagged these as
# "not implemented correctly"; the analysis cells below pass SIGMA_R_MODEL
# directly rather than reading these names.
sigma_r = sig_r_model = SIGMA_R_MODEL

In [None]:
def generate_dream_data_from_posterior_samples(samples, z_prior_type, sigma_reward=0.1):
    '''
    Generate one datapoint per gamma sample using the 1x2D model.

    samples: iterable of gamma samples; each one is passed as `gamma` to
        helper.generate_data_from_gamma.
    z_prior_type: forwarded unchanged to helper.generate_data_from_gamma.
    sigma_reward: forwarded unchanged to helper.generate_data_from_gamma.

    Returns a dict {'z': ndarray, 'r': ndarray} holding, in sample order, the
    first 'z' and 'r' entry of every generated single-point dataset.
    '''
    # One helper call (N=1) per gamma sample, in iteration order.
    points = [
        helper.generate_data_from_gamma(N=1, gamma=gamma_sample, z_prior_type=z_prior_type, sigma_reward=sigma_reward)
        for gamma_sample in samples
    ]

    return {
        'z': np.array([point['z'][0] for point in points]),
        'r': np.array([point['r'][0] for point in points]),
    }

def generate_dream_data_set(posterior, T=10, N=2, z_prior_type='uniform', sigma_reward=0.1):
    '''
    Generates multiple dream data sets given a posterior distribution.

    posterior: object exposing `.sample(T)`; the returned samples are used as
        gamma samples.
    T: length of each dataset (number of posterior samples drawn per dataset).
    N: number of datasets.
    z_prior_type: forwarded to generate_dream_data_from_posterior_samples.
    sigma_reward: reward noise forwarded to
        generate_dream_data_from_posterior_samples. (Previously this argument
        was ignored and the global SIGMA_R_DREAM was always used.)

    Returns a list of N dicts as produced by
    generate_dream_data_from_posterior_samples.
    '''
    dream_data_sets = []
    for _ in range(N):
        # Fresh posterior draw for every dataset.
        samples = posterior.sample(T)
        dream_data_sets.append(
            generate_dream_data_from_posterior_samples(samples, z_prior_type, sigma_reward)
        )
    return dream_data_sets

In [None]:
# Obtain the ground-truth dataset: either reload a saved one or generate it.
if USE_SAVED_DATA:
    # NOTE: allow_pickle=True unpickles arbitrary objects - only load files
    # from a trusted source.
    true_data = np.load(DATA_FILENAME, allow_pickle=True).item()
else:
    # Task 1: Tx trials with alpha=90, context 0; task 2: Ty trials with
    # alpha=0, context 1. data1/data2 exist only in this branch.
    data1 = helper.generate_data(Tx, alpha=90, context_value=0, z_prior_type=Z_PRIOR, sigma_reward=SIGMA_R)
    data2 = helper.generate_data(Ty, alpha=0, context_value=1, z_prior_type=Z_PRIOR, sigma_reward=SIGMA_R)
    true_data = helper.concatenate_data(data1, data2)

In [None]:
# Scatter plot of the ground-truth data used in the analysis.
if USE_SAVED_DATA:
    # data1/data2 are only created when the data is generated in this session;
    # with saved data only the combined true_data exists, so plot that instead
    # of failing with a NameError.
    helper.plot_data(true_data)
    plt.legend(["saved data"])
else:
    helper.plot_data(data1,marker='^',colorbar=False)
    helper.plot_data(data2,marker='s')
    plt.legend(["task 1","task 2"])
plt.show()

# Side-by-side illustration of the two tasks, each drawn from a fresh
# 1000-point dataset generated with the same settings as the real data.
plt.figure(figsize=(3,2))
for panel_idx, (task_alpha, task_context, task_title) in enumerate(
        [(90, 0, "task 1"), (0, 1, "task 2")], start=1):
    plt.subplot(1, 2, panel_idx)
    data = helper.generate_data(1000, alpha=task_alpha, context_value=task_context, z_prior_type=Z_PRIOR, sigma_reward=SIGMA_R)
    helper.plot_data(data, limit=1, axislabels=False, colorbar=False, ticks=False)
    plt.title(task_title)
plt.show()

In [None]:
# Build two presentation orders of the same Tx + Ty trials:
#   blocked     - original order (indices 0 .. Tx+Ty-1)
#   interleaved - helper.riffle of the first Tx indices with the remaining Ty
n_trials = Tx + Ty
blocked_indices = np.arange(n_trials)
interleaved_indices = helper.riffle(np.arange(Tx), np.arange(Tx, n_trials))

data_blocked = helper.reorder_data(true_data, blocked_indices)
data_interleaved = helper.reorder_data(true_data, interleaved_indices)

In [None]:
# ---- Blocked presentation: fit, dream, and score both models ----
data_presented = data_blocked

# Keep the first Tx + EM_SIZE trials, then split them into the "past" part
# (everything but the last EM_SIZE trials) and the episodic memory (EM, the
# last EM_SIZE trials).
data_presented = {'z':data_presented['z'][:Tx+EM_SIZE], 'r':data_presented['r'][:Tx+EM_SIZE]}
data_past = {'z':data_presented['z'][:-EM_SIZE], 'r':data_presented['r'][:-EM_SIZE]}
data_EM = {'z':data_presented['z'][-EM_SIZE:], 'r':data_presented['r'][-EM_SIZE:]}

# Fit 1x2D task model on the past data only.
# posterior_params[0] is used as the mean, posterior_params[1] as the covariance.
# NOTE(review): the posterior fit uses Sigma_0=10*np.eye(2) while the marginal
# likelihoods below use np.eye(2) - confirm this difference is intended.
# NOTE(review): recent TFP releases mark MultivariateNormalFullCovariance as
# deprecated in favour of MultivariateNormalTriL - confirm against the
# installed version.
posterior_params = helper.gamma_posterior_analytic(data_past['z'], data_past['r'], SIGMA_R_MODEL, Sigma_0=10*np.eye(2))
posterior = tfd.MultivariateNormalFullCovariance(loc=posterior_params[0], covariance_matrix=posterior_params[1])

# generate N dream data sets
dreams = generate_dream_data_set(posterior, T=5, N=10, z_prior_type=DREAM_DATA_Z_PRIOR, sigma_reward=SIGMA_R_DREAM)

# append episodic memories to dreams
# (list.append returns None, so its result must not be assigned to anything)
dreams_plus_EM = []
for dream in dreams:
    dreams_plus_EM.append({
        'z': np.concatenate([dream['z'], data_EM['z']]),
        'r': np.concatenate([dream['r'], data_EM['r']]),
    })
dreams = dreams_plus_EM

# mllh on 2D 1 task model
onetask_mllhs_dream = [helper.model_marginal_llh_analytic(dream['z'], dream['r'], SIGMA_R_MODEL, Sigma_0=np.eye(2)) for dream in dreams]

# mllh on 2D 2 task model
twotask_mllhs_dream = [helper.model_marginal_llh_analytic_2x2D(dream['z'], dream['r'], SIGMA_R_MODEL, Sigma_0_2D = np.array([[1., 0.], [0., 1.]])) for dream in dreams]

# mllhs on ground truth data
onetask_mllh = helper.model_marginal_llh_analytic(data_presented['z'], data_presented['r'], SIGMA_R_MODEL, Sigma_0=np.eye(2))
twotask_mllh = helper.model_marginal_llh_analytic_2x2D(data_presented['z'], data_presented['r'], SIGMA_R_MODEL, Sigma_0_2D = np.array([[1., 0.], [0., 1.]]))

In [None]:
# Overview figure for the blocked condition: three panels in one row.
# The pyplot calls act on the implicit "current axes", so statement order matters.
plt.figure(figsize=(10,4))
plt.subplots_adjust(wspace=0.2, hspace=0.4)
plt.suptitle("BLOCKED")

# Panel 1: the presented data (past trials plus episodic memory).
plt.subplot(1,3,1)
#plot_data_xy_labels(true_data, xylabels)
helper.plot_data(data_presented, colorbar=False)
plt.title("true data")

# Panel 2: only the past trials, i.e. what the posterior was fitted on.
plt.subplot(1,3,2)
#plot_data_xy_labels(true_data, xylabels)
helper.plot_data(data_past, colorbar=False, axislabels=False)
plt.title("past data (without EM)")

# Panel 3: 300 draws from the gamma posterior as a KDE with the raw samples
# scattered on top. `samples` is reused further down (the grid of generated
# datasets in this cell and the arrow plot in the last cell).
plt.subplot(1,3,3)
samples = posterior.sample(300)
sns.kdeplot(x=samples[:, 0], y=samples[:, 1])
plt.scatter(x=samples[:, 0], y=samples[:, 1])
plt.xlim([-3,3])
plt.ylim([-3,3])
plt.gca().set_aspect('equal')
plt.title("gamma posterior")
plt.show()

# mllh summary figure: dream-data mllh histograms versus the mllh of the
# presented data, for the 1x2D and 2x2D models. Three rows of panels.
plt.figure(figsize=(10,10))
plt.subplots_adjust(wspace=0.4, hspace=0.4)

# Row 1: both models overlaid. Dashed lines mark the value on the presented
# data (gray / black) and the dream average (light / dark blue).
plt.subplot(3,1,1)
plt.hist(onetask_mllhs_dream,100)
plt.axvline(onetask_mllh, color='gray', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(onetask_mllhs_dream), color='lightblue', linestyle='dashed', linewidth=1)
plt.xlabel("mllh value, true (black dashed), avg (blue dashed)")
plt.ylabel("occurrences")

plt.hist(twotask_mllhs_dream,100)
plt.axvline(twotask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(twotask_mllhs_dream), color='darkblue', linestyle='dashed', linewidth=1)
plt.title("both tasks, 2 task (dark), 1 task (light)")

# Row 2, left: 1x2D model only.
plt.subplot(3,2,3)
plt.hist(onetask_mllhs_dream,100)
plt.axvline(onetask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(onetask_mllhs_dream), color='b', linestyle='dashed', linewidth=1)
plt.xlabel("mllh value, true (black dashed), avg (blue dashed)")
plt.ylabel("occurrences")
plt.title("1 x 2D task")

# Row 2, right: 2x2D model only.
plt.subplot(3,2,4)
plt.hist(twotask_mllhs_dream,100)
plt.axvline(twotask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(twotask_mllhs_dream), color='b', linestyle='dashed', linewidth=1)
plt.title("2 x 2D task")

# Row 3: model comparison as a line plot and as two bar charts.
plt.subplot(3,3,7)
plt.plot(["1x2D","2x2D"],[onetask_mllh, twotask_mllh])
plt.plot(["1x2D","2x2D"],[np.mean(onetask_mllhs_dream), np.mean(twotask_mllhs_dream)])
plt.legend(["true","avg dream"])

plt.subplot(3,3,8)
plt.bar(["1x2D","2x2D"],[onetask_mllh, twotask_mllh])
plt.title("mllh on true data")

plt.subplot(3,3,9)
plt.bar(["1x2D","2x2D"],[np.mean(onetask_mllhs_dream), np.mean(twotask_mllhs_dream)])
plt.title("avg mllh on dream data")
plt.show()

# Grid of example datasets: panel k shows 300 points generated from the k-th
# posterior gamma sample (first n_vertical * n_horizontal entries of `samples`).
n_vertical = 4
n_horizontal = 5

plt.figure(figsize=(6,5))
for j, i in itertools.product(range(n_vertical), range(n_horizontal)):
    flat_idx = j * n_horizontal + i  # row-major panel index, 0-based
    plt.subplot(n_vertical, n_horizontal, flat_idx + 1)
    gamma_dataset = helper.generate_data_from_gamma(
        N=300, gamma=samples[flat_idx],
        z_prior_type='uniform', sigma_z_prior=1.5, r_bias=0, sigma_reward=0.1, sigma_bias=0)
    helper.plot_data(gamma_dataset, limit=1.75, axislabels=False, colorbar=False, ticks=False)
plt.show()


In [None]:
# Plot a single large dream dataset (T=1000 points) drawn from the current posterior.
large_dream = generate_dream_data_set(posterior, T=1000, N=1, z_prior_type=DREAM_DATA_Z_PRIOR, sigma_reward=SIGMA_R_DREAM)[0]
helper.plot_data(large_dream)

In [None]:
# ---- Interleaved presentation: fit, dream, and score both models ----
data_presented = data_interleaved

# Keep the first Tx + EM_SIZE trials, then split them into the "past" part
# (everything but the last EM_SIZE trials) and the episodic memory (EM, the
# last EM_SIZE trials).
data_presented = {'z':data_presented['z'][:Tx+EM_SIZE], 'r':data_presented['r'][:Tx+EM_SIZE]}
data_past = {'z':data_presented['z'][:-EM_SIZE], 'r':data_presented['r'][:-EM_SIZE]}
data_EM = {'z':data_presented['z'][-EM_SIZE:], 'r':data_presented['r'][-EM_SIZE:]}

# Fit 2D 1 task model on the past data only.
# posterior_params[0] is used as the mean, posterior_params[1] as the covariance.
# NOTE(review): the posterior fit uses Sigma_0=10*np.eye(2) while the marginal
# likelihoods below use np.eye(2) - confirm this difference is intended.
# NOTE(review): recent TFP releases mark MultivariateNormalFullCovariance as
# deprecated in favour of MultivariateNormalTriL - confirm against the
# installed version.
posterior_params = helper.gamma_posterior_analytic(data_past['z'], data_past['r'], SIGMA_R_MODEL, Sigma_0=10*np.eye(2))
posterior = tfd.MultivariateNormalFullCovariance(loc=posterior_params[0], covariance_matrix=posterior_params[1])

# generate N dream data sets
dreams = generate_dream_data_set(posterior, T=5, N=10, z_prior_type=DREAM_DATA_Z_PRIOR, sigma_reward=SIGMA_R_DREAM)

# append episodic memories to dreams
# (list.append returns None, so its result must not be assigned to anything)
dreams_plus_EM = []
for dream in dreams:
    dreams_plus_EM.append({
        'z': np.concatenate([dream['z'], data_EM['z']]),
        'r': np.concatenate([dream['r'], data_EM['r']]),
    })
dreams = dreams_plus_EM

# mllh on 2D 1 task model
onetask_mllhs_dream = [helper.model_marginal_llh_analytic(dream['z'], dream['r'], SIGMA_R_MODEL, Sigma_0=np.eye(2)) for dream in dreams]

# mllh on 2D 2 task model
twotask_mllhs_dream = [helper.model_marginal_llh_analytic_2x2D(dream['z'], dream['r'], SIGMA_R_MODEL, Sigma_0_2D = np.array([[1., 0.], [0., 1.]])) for dream in dreams]

# mllhs on ground truth data
onetask_mllh = helper.model_marginal_llh_analytic(data_presented['z'], data_presented['r'], SIGMA_R_MODEL, Sigma_0=np.eye(2))
twotask_mllh = helper.model_marginal_llh_analytic_2x2D(data_presented['z'], data_presented['r'], SIGMA_R_MODEL, Sigma_0_2D = np.array([[1., 0.], [0., 1.]]))

In [None]:
# Overview figure for the interleaved condition: three panels in one row.
# The pyplot calls act on the implicit "current axes", so statement order matters.
plt.figure(figsize=(10,4))
plt.subplots_adjust(wspace=0.2, hspace=0.4)
plt.suptitle("INTERLEAVED")

# Panel 1: the presented data (past trials plus episodic memory).
plt.subplot(1,3,1)
#plot_data_xy_labels(true_data, xylabels)
helper.plot_data(data_presented, colorbar=False)
plt.title("true data")

# Panel 2: only the past trials, i.e. what the posterior was fitted on.
plt.subplot(1,3,2)
#plot_data_xy_labels(true_data, xylabels)
helper.plot_data(data_past, colorbar=False, axislabels=False)
plt.title("past data (without EM)")

# Panel 3: 300 draws from the gamma posterior as a KDE with the raw samples
# scattered on top. `samples` is reused further down (the grid of generated
# datasets in this cell and the arrow plot in the last cell).
plt.subplot(1,3,3)
samples = posterior.sample(300)
sns.kdeplot(x=samples[:, 0], y=samples[:, 1])
plt.scatter(x=samples[:, 0], y=samples[:, 1])
plt.xlim([-3,3])
plt.ylim([-3,3])
plt.gca().set_aspect('equal')
plt.title("gamma posterior")
plt.show()

# mllh summary figure: dream-data mllh histograms versus the mllh of the
# presented data, for the 1x2D and 2x2D models. Three rows of panels.
plt.figure(figsize=(10,10))
plt.subplots_adjust(wspace=0.4, hspace=0.4)

# Row 1: both models overlaid. Dashed lines mark the value on the presented
# data (gray / black) and the dream average (light / dark blue).
plt.subplot(3,1,1)
plt.hist(onetask_mllhs_dream,100)
plt.axvline(onetask_mllh, color='gray', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(onetask_mllhs_dream), color='lightblue', linestyle='dashed', linewidth=1)
plt.xlabel("mllh value, true (black dashed), avg (blue dashed)")
plt.ylabel("occurrences")

plt.hist(twotask_mllhs_dream,100)
plt.axvline(twotask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(twotask_mllhs_dream), color='darkblue', linestyle='dashed', linewidth=1)
plt.title("both tasks, 2 task (dark), 1 task (light)")

# Row 2, left: 1x2D model only.
plt.subplot(3,2,3)
plt.hist(onetask_mllhs_dream,100)
plt.axvline(onetask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(onetask_mllhs_dream), color='b', linestyle='dashed', linewidth=1)
plt.xlabel("mllh value, true (black dashed), avg (blue dashed)")
plt.ylabel("occurrences")
plt.title("1 x 2D task")

# Row 2, right: 2x2D model only.
plt.subplot(3,2,4)
plt.hist(twotask_mllhs_dream,100)
plt.axvline(twotask_mllh, color='k', linestyle='dashed', linewidth=1)
plt.axvline(np.mean(twotask_mllhs_dream), color='b', linestyle='dashed', linewidth=1)
plt.title("2 x 2D task")

# Row 3: model comparison as a line plot and as two bar charts.
plt.subplot(3,3,7)
plt.plot(["1x2D","2x2D"],[onetask_mllh, twotask_mllh])
plt.plot(["1x2D","2x2D"],[np.mean(onetask_mllhs_dream), np.mean(twotask_mllhs_dream)])
plt.legend(["true","avg dream"])

plt.subplot(3,3,8)
plt.bar(["1x2D","2x2D"],[onetask_mllh, twotask_mllh])
plt.title("mllh on true data")

plt.subplot(3,3,9)
plt.bar(["1x2D","2x2D"],[np.mean(onetask_mllhs_dream), np.mean(twotask_mllhs_dream)])
plt.title("avg mllh on dream data")
plt.show()

# Grid of example datasets: panel k shows 300 points generated from the k-th
# posterior gamma sample (first n_vertical * n_horizontal entries of `samples`).
n_vertical = 4
n_horizontal = 5

plt.figure(figsize=(6,5))
for j, i in itertools.product(range(n_vertical), range(n_horizontal)):
    flat_idx = j * n_horizontal + i  # row-major panel index, 0-based
    plt.subplot(n_vertical, n_horizontal, flat_idx + 1)
    gamma_dataset = helper.generate_data_from_gamma(
        N=300, gamma=samples[flat_idx],
        z_prior_type='uniform', sigma_z_prior=1.5, r_bias=0, sigma_reward=0.1, sigma_bias=0)
    helper.plot_data(gamma_dataset, limit=1.75, axislabels=False, colorbar=False, ticks=False)
plt.show()

In [None]:
# Plot a single large dream dataset (T=1000 points) drawn from the current posterior.
large_dream = generate_dream_data_set(posterior, T=1000, N=1, z_prior_type=DREAM_DATA_Z_PRIOR, sigma_reward=SIGMA_R_DREAM)[0]
helper.plot_data(large_dream)

In [None]:
# One panel per dream dataset (dream samples with the EM trials appended),
# laid out row-major on an n_vertical x n_horizontal grid.
n_vertical = 2
n_horizontal = 5

# A single figure only: an extra plt.figure(figsize=(10,4)) call used to sit
# here and left a stray empty figure, since all subplots are drawn into the
# most recently created one.
plt.figure(figsize=(15,6))
for j in range(n_vertical):
    for i in range(n_horizontal):
        plt.subplot(n_vertical, n_horizontal,j*n_horizontal+i+1)
        helper.plot_data(dreams[j*n_horizontal+i], labels=False, limit=1.7, climit=1,axislabels=False, marker='+', colorbar=False)

In [None]:
# Same grid as above, but each panel first draws the EM trials ('>') and then
# the dream dataset ('+') on top of them.
n_vertical = 2
n_horizontal = 5

plt.figure(figsize=(15,6))
for j, i in itertools.product(range(n_vertical), range(n_horizontal)):
    flat_idx = j * n_horizontal + i  # row-major panel index, 0-based
    plt.subplot(n_vertical, n_horizontal, flat_idx + 1)
    # optional extra layer: helper.plot_data(data_past,marker='^', axislabels=False, limit=1.75)
    helper.plot_data(data_EM, marker='>', axislabels=False, limit=1.75, colorbar=False)
    helper.plot_data(dreams[flat_idx], marker='+', axislabels=False, limit=1.75, colorbar=False)
    # optional legend: plt.legend(["past data", "EM contents", "dreamed"])

In [None]:
# Overlay past data, EM contents and the first dream dataset in one axes,
# distinguished by marker; the legend order follows the plotting order.
plt.figure(figsize=(10,5))
for overlay_data, overlay_marker in ((data_past, '^'), (data_EM, '>'), (dreams[0], '+')):
    helper.plot_data(overlay_data, marker=overlay_marker, axislabels=False, limit=1.75, colorbar=False)
plt.legend(["past data", "EM contents", "dreamed"])

In [None]:
#np.save("demo_data_sigr_07",true_data)

In [None]:
# Draw the first 100 posterior gamma samples as arrows from the origin.
ax = plt.axes()
for sample in samples[:100]:
    ax.arrow(0, 0, sample[0], sample[1], head_width=0.02, head_length=0.05, fc='k', ec='k')
plt.xlim([-1.75, 1.75])
plt.ylim([-1.75, 1.75])
ax.set_aspect('equal')
plt.show()