In [31]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt
import json

# Experiment selection. `directory` is the folder that holds results.json and
# receives plot.png, `objective` is the per-run series key that gets plotted,
# and `y_axis` is the text used for both the y label and the figure title.
# Alternative configuration (uncomment to use instead of the line below):
#directory = 'ksa'; objective = 'explored'; y_axis = 'Exploration (%)'
directory = 'dogmatic'; objective = 'rewards'; y_axis = 'Average reward'

# The context manager closes the handle even if json.load raises, and the
# handle no longer shadows the Python 2 builtin `file`.
with open(directory + '/results.json') as results_file:
    data = json.load(results_file)
# Palette indexed by the position of each series in `data`.
colors = ['green','blue','red','orange','black','pink','grey','purple']
# Legend labels keyed by the series names used in the results file.
# Names that are not listed here are shown in the legend unchanged.
agents = {
    # Named agents (several raw spellings share one label).
    'BayesAgent': r'AI$\xi$',
    'Knowledge-seeking agent': 'Kullback-Leibler',
    'KullbackLeiblerKSA': 'Kullback-Leibler',
    'ShannonKSA': 'Shannon',
    'SquareKSA': 'Square',
    'Shannon KSA': 'Shannon',
    'Square KSA': 'Square',
    'ThompsonAgent': 'Thompson',
    'QLearn': 'Q-Learning',
    # MC-AIMU series, labelled by their kappa value.
    'MC-AIMU': r'$\kappa = 50$',
    'MC-AIMU-1': r'$\kappa = 100$',
    'MC-AIMU-2': r'$\kappa = 400$',
    'MC-AIMU-3': r'$\kappa = 800$',
    'MC-AIMU-4': r'$\kappa = 2000$',
    'MC-AIMU-5': r'$\kappa = 1$',
    'MC-AIMU-6': r'$\kappa = 10$',
    'MC-AIMU-7': r'$\kappa = 20$',
    # Time-inconsistency series, labelled by their C value.
    'Time inconsistency': r'$C = 0.01$',
    'Time inconsistency-1': r'$C = 0.05$',
    'Time inconsistency-2': r'$C = 0.1$',
    'Time inconsistency-3': r'$C = 0.5$',
    'Time inconsistency-4': r'$C = 1$',
    'Time inconsistency-5': r'$C = 5$',
    'Time inconsistency-6': r'$C = 10$',
}

# Draw one mean curve per entry of `data`. Each entry is indexed as a list of
# runs, where every run is a mapping holding 'cycles' and the `objective` series.
fig = plt.figure(figsize=(12,8),dpi=200)
show_band = False   # set to True to shade the spread of the runs around each mean
band_alpha = 0.2    # opacity of that shaded band
for i, (k, d) in enumerate(data.items()):
    if not d:
        # No runs recorded under this name: nothing to average or plot.
        continue
    # Wrap around the palette so more series than colors cannot raise IndexError.
    color = colors[i % len(colors)]
    # One column per run; the row count comes from the first run's 'cycles'.
    A = np.zeros((d[0]['cycles'], len(d)))
    for j in range(len(d)):
        A[:, j] = np.array(d[j][objective])

    mu = np.mean(A, 1)
    # Use the display label when one is defined, otherwise the raw name.
    label = agents.get(k, k)
    plt.plot(mu, label=label, color=color, lw=3)
    if show_band:
        sigma = np.std(A, 1)
        # Lower edge is mean - std, but never below the smallest observed value.
        lower = np.max(np.vstack((mu - sigma, np.min(A, 1))), 0)
        upper = mu + sigma
        plt.fill_between(np.arange(len(mu)), lower, upper, alpha=band_alpha, color=color)

    
# Reference curve: running average of a per-step reward that is 0 for the
# first 7 steps and 166.7 on every step after that, over 200 steps.
# NOTE(review): the 7-step delay and the 166.7 reward are taken as given here;
# confirm they match the environment this notebook is plotting.
horizon = 200
step_rewards = np.zeros(horizon)
step_rewards[7:] = 166.7
# Cumulative reward divided by the number of steps taken so far (1-based).
# `range`-free on purpose: the old `xrange` loops fail on Python 3.
kek = np.cumsum(step_rewards) / np.arange(1, horizon + 1)

# Overlay the reference curve as a thick dashed black line labelled 'Optimal'.
plt.plot(kek, label='Optimal', color='black', linestyle='dashed', linewidth=3)

def f(st):
    """Return the numeric sort key embedded in a legend label.

    The label is split on single spaces and the third token, with any '$'
    removed, is parsed as a float (e.g. '$C = 0.5$' gives 0.5). Labels with
    fewer than three tokens, or whose third token is not a number, get the
    sentinel 99999 so that they sort after the numeric ones.
    """
    tokens = st.split(' ')
    if len(tokens) < 3:
        return 99999
    try:
        return float(tokens[2].replace('$', ''))
    except ValueError:
        return 99999

# Label the current axes; the y-axis text doubles as the figure title.
ax = plt.gca()
ax.set_title(y_axis)
ax.set_xlabel('Cycles')
ax.set_ylabel(y_axis)

# Order legend entries by the number embedded in each label (see f);
# labels without one share the same key and keep their plotting order.
handles, labels = ax.get_legend_handles_labels()
ordered_pairs = sorted(zip(labels, handles), key=lambda pair: f(pair[0]))
labels, handles = zip(*ordered_pairs)
ax.legend(handles, labels, loc='lower right')

# Write the figure next to the results file, then release it.
plt.savefig(directory + '/plot.png', bbox_inches='tight')
plt.close()

In [27]:
# Sample the two curves on N evenly spaced points in [1/N, 1].
# N must be an integer: np.linspace rejects a float sample count.
N = 1000
# 1.0/N keeps the lower bound at 0.001 even under Python 2 integer division,
# which would otherwise give 0 and make -log(x) infinite at the first point.
x = np.linspace(1.0/N, 1, N)
y = -np.log(x)   # curve plotted as -log(xi)
z = -x           # curve plotted as -xi

# Draw both utility curves on one figure and save it under the thesis figures path.
fig = plt.figure(figsize=(10, 5), dpi=200)
axes = plt.gca()
axes.set_title(r'Utility functions for Square and Shannon KSA, as a function of $\xi$')
axes.set_xlabel(r'$\xi$')
axes.set_ylabel(r'$u(\xi)$')
for values, curve_label in ((y, r'$-log(\xi)$'), (z, r'$-\xi$')):
    axes.plot(x, values, label=curve_label, lw=3)
axes.legend(loc='upper right')
plt.savefig('../../../thesis/figures/square-shannon-utility')
plt.close()

In [6]:
# Running average of a per-step reward that is -1 on each of the first 8
# steps and +75 on every step after that, over 200 steps.
# NOTE(review): the 8-step cost phase and the 75 reward are taken as given;
# confirm they match the environment being modelled.
horizon = 200
step_rewards = np.full(horizon, 75.0)
step_rewards[:8] = -1
# Cumulative reward divided by the number of steps taken so far (1-based).
# Vectorised so the cell no longer depends on the Python 2-only `xrange`.
kek = np.cumsum(step_rewards) / np.arange(1, horizon + 1)

[ -1.          -1.          -1.          -1.          -1.          -1.          -1.
  -1.           7.44444444  14.2         19.72727273  24.33333333
  28.23076923  31.57142857  34.46666667  37.          39.23529412
  41.22222222  43.          44.6         46.04761905  47.36363636
  48.56521739  49.66666667  50.68        51.61538462  52.48148148
  53.28571429  54.03448276  54.73333333  55.38709677  56.          56.57575758
  57.11764706  57.62857143  58.11111111  58.56756757  59.          59.41025641
  59.8         60.17073171  60.52380952  60.86046512  61.18181818
  61.48888889  61.7826087   62.06382979  62.33333333  62.59183673  62.84
  63.07843137  63.30769231  63.52830189  63.74074074  63.94545455
  64.14285714  64.33333333  64.51724138  64.69491525  64.86666667
  65.03278689  65.19354839  65.34920635  65.5         65.64615385
  65.78787879  65.92537313  66.05882353  66.1884058   66.31428571
  66.43661972  66.55555556  66.67123288  66.78378378  66.89333333  67.
  67.1038961   67.20

In [24]:
# Scratch check: float() parses the decimal string '0.03'.
float('0.03')

0.03