In [34]:
import numpy as np
import matplotlib.pyplot as plt
import math
from simulation_functions import create_watermaze, create_actions, bounce_off, goal_not_reached
from create_figure3 import figure3, plot_Cp
from actorcritic_functions import place_cells, critic, actor, delta
np.random.seed(40)


In [None]:
# create a water maze and plot:
x_coords, y_coords, goal_x, goal_y = create_watermaze()
plt.figure(0)
plt.plot(x_coords, y_coords, c='black', ls='-')
plt.plot(goal_x, goal_y, c='black', fillstyle='none', marker='o')
actionsx, actionsy = create_actions()
N = 493

# initialize parameters:
dt = 0.1
T = 120.0
max_trials = 1 # 5 # 25
discount = 0.9

# initialize weights to train and vectors of activities:
w = np.zeros((max_trials + 1, N))
z = np.zeros((max_trials + 1, 8, N))
f, places = place_cells([0.9, 0.001])
C = np.zeros(max_trials) # N?
places_x, places_y = np.transpose(places)
x_path = np.zeros((max_trials, int(T / dt)))
y_path = np.zeros((max_trials, int(T / dt)))

for trial in range(max_trials):
    # initial location and direction:
    x = [x_coords[int(209 / np.random.randint(1, 4))]]
    y = [y_coords[int(209 / np.random.randint(1, 4))]]
    direction = np.random.randint(0, 8)
    print trial
    for t in range(int(T / dt)):
        if goal_not_reached(x[-1], y[-1], goal_x, goal_y):
            # compute current C:
            f_p, _ = place_cells([x[-1], y[-1]])
            C_current = critic(w[trial, :], f_p)
            
            # compute action probabilities and perform best action:
            a = actor(z[trial, :, :], f_p) # 8x1
            print "a: ", a
            P = np.exp(2.0 * a) / np.sum(np.exp(2.0 * a))
            
            new_direction = np.sum(np.random.random() > np.cumsum(P)) # TODO: 1:3 ratio of changing directions
            if np.random.randint(0, high=4) > 0: # keep current direction
                new_direction = direction
            direction = new_direction
            dx = actionsx[direction] * 0.03
            dy = actionsy[direction] * 0.03

            # check for boundaries:
            if np.linalg.norm([x[-1] + dx, y[-1] + dy], ord=2) > 1.0:
                # bounce off:
                xnew, ynew = bounce_off(x[-1], y[-1], dx, dy)
                x.append(xnew)
                y.append(ynew)
            else:
                x.append(x[-1] + dx)
                y.append(y[-1] + dy)
            x_path[trial, t] = x[-1]
            y_path[trial, t] = y[-1]

            # the critic (update w):
            f_p_new, _ = place_cells([x[-1], y[-1]])
            C_new = critic(w[trial, :], f_p_new)
            dw = np.repeat(delta(C_current, C_new, discount, goal_not_reached(x[-2], y[-2], goal_x, goal_y)), N) * f_p
            w[trial, :] = w[trial, :] + dw

            # the actor (update z):
            dz = np.repeat(delta(C_current, C_new, discount, goal_not_reached(x[-2], y[-2], goal_x, goal_y)), N) * f_p 
            z[trial, direction, :] = z[trial, direction, :] + dz
            
        else: # goal reached
            print "goal reached!"
            break
    w[trial + 1, :] = w[trial, :]
    z[trial + 1, :, :] = z[trial, :, :]
#
plt.plot(x_path[0, :], y_path[0, :])
plt.show()

0
a:  [[0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]
 [0.]]
a:  [[0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [1.63414597e-15]]
a:  [[2.89810804e-15]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [1.92374301e-15]]
a:  [[1.91020758e-15]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [3.51161479e-15]]
a:  [[2.49431712e-15]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [6.51823145e-15]]
a:  [[2.48083299e-15]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [9.06487163e-15]]
a:  [[2.46835691e-15]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [0.00000000e+00]
 [1.16780200e-14]]
a:  [[2.44765566e-15]
 [0.00000000e+00]
 [0.0000000

In [47]:
print z[0, 7, :]

[4.16332595e-01 4.00193507e-01 4.25502555e-01 4.01343082e-01
 3.21479607e-01 3.10267333e-01 3.93098217e-01 3.94871118e-01
 2.65158195e-01 2.32343882e-01 3.26525192e-01 4.06077393e-01
 2.35914608e-01 1.75822335e-01 2.45644697e-01 4.12091001e-01
 2.27048235e-01 1.38456650e-01 1.75288002e-01 3.72459933e-01
 2.35822830e-01 1.15443142e-01 1.24154177e-01 2.81513683e-01
 2.61730388e-01 1.02925648e-01 9.01017813e-02 1.91293048e-01
 2.99521783e-01 9.86296190e-02 6.83762512e-02 1.25581856e-01
 3.11024837e-01 1.01684591e-01 5.49341417e-02 8.28009499e-02
 2.44873384e-01 1.12296797e-01 4.70749595e-02 5.63264450e-02
 1.68578396e-01 1.31139582e-01 4.32456320e-02 4.02546661e-02
 1.07323773e-01 1.53586757e-01 4.27246431e-02 3.05309205e-02
 6.64727190e-02 1.51085340e-01 4.53611579e-02 2.46745604e-02
 4.21621517e-02 1.20801383e-01 5.12487365e-02 2.12950678e-02
 2.80751207e-02 9.29384339e-02 5.97390769e-02 1.97048387e-02
 1.99308409e-02 6.03970301e-02 6.77726677e-02 1.96488512e-02
 1.51440506e-02 3.671632

In [37]:
print w[0, :]

[4.16332595e-01 4.00193507e-01 4.25502555e-01 4.01343082e-01
 3.21479607e-01 3.10267333e-01 3.93098217e-01 3.94871118e-01
 2.65158195e-01 2.32343882e-01 3.26525192e-01 4.06077393e-01
 2.35914608e-01 1.75822335e-01 2.45644697e-01 4.12091001e-01
 2.27048235e-01 1.38456650e-01 1.75288002e-01 3.72459933e-01
 2.35822830e-01 1.15443142e-01 1.24154177e-01 2.81513683e-01
 2.61730388e-01 1.02925648e-01 9.01017813e-02 1.91293048e-01
 2.99521783e-01 9.86296190e-02 6.83762512e-02 1.25581856e-01
 3.11024837e-01 1.01684591e-01 5.49341417e-02 8.28009499e-02
 2.44873384e-01 1.12296797e-01 4.70749595e-02 5.63264450e-02
 1.68578396e-01 1.31139582e-01 4.32456320e-02 4.02546661e-02
 1.07323773e-01 1.53586757e-01 4.27246431e-02 3.05309205e-02
 6.64727190e-02 1.51085340e-01 4.53611579e-02 2.46745604e-02
 4.21621517e-02 1.20801383e-01 5.12487365e-02 2.12950678e-02
 2.80751207e-02 9.29384339e-02 5.97390769e-02 1.97048387e-02
 1.99308409e-02 6.03970301e-02 6.77726677e-02 1.96488512e-02
 1.51440506e-02 3.671632