Code used to produce random walk trajectories in the new gridworld map, fit to the mouse data.

In [None]:
import random
import numpy as np
import matplotlib.pyplot as plt

## Define the space

#### Define the actions that can be taken and the movement they result in

In [None]:
# Seed Python's built-in RNG so that the generated walks are reproducible.
random.seed(0)

# Movement table: row k is the (dx, dy) displacement produced by action k.
delta = np.array([
    [-1, 0],    # 0: LEFT
    [0, 1],     # 1: UP
    [1, 0],     # 2: RIGHT
    [0, -1],    # 3: DOWN
    [-1, 1],    # 4: LEFT-UP
    [1, 1],     # 5: RIGHT-UP
    [1, -1],    # 6: RIGHT-DOWN
    [-1, -1],   # 7: LEFT-DOWN
])

# The selectable actions are the row indices of the movement table: 0..7.
action_space = np.arange(len(delta))

#### Create array of invalid points (border cells at coordinates 0 and 14 enclose a 13x13 grid)

In [None]:
# Obstacle: a horizontal wall at y = 7 covering x = 5..9
obstacle = np.array([[x, 7] for x in range(5, 10)])

# Border cells. They are float arrays (x, y): the top/bottom rows cover
# x = 2..12 at y = 14 / y = 0, the left/right columns cover y = 3..11 at
# x = 0 / x = 14.
border_xs = np.arange(2, 13, dtype=float)
border_ys = np.arange(3, 12, dtype=float)
top_border = np.column_stack((border_xs, np.full(11, 14.0)))
bottom_border = np.column_stack((border_xs, np.full(11, 0.0)))
left_border = np.column_stack((np.full(9, 0.0), border_ys))
right_border = np.column_stack((np.full(9, 14.0), border_ys))

# Invalid cells in the bottom-left corner; the other three corners are its
# mirror images about the lines x = 7 and/or y = 7 (coordinate c -> 14 - c).
invalid_bl = np.array([[1, 1], [1, 2], [1, 3], [2, 1]])
invalid_tl = invalid_bl * [1, -1] + [0, 14]      # flip y
invalid_tr = invalid_bl * [-1, -1] + [14, 14]    # flip x and y
invalid_br = invalid_bl * [-1, 1] + [14, 0]      # flip x

# `invalid` holds the borders and corners only; `invalid_all` is the obstacle
# followed by those same cells, in the same order.
invalid = np.concatenate(
    (
        top_border,
        bottom_border,
        left_border,
        right_border,
        invalid_bl,
        invalid_tl,
        invalid_tr,
        invalid_br,
    ),
    axis=0,
)
invalid_all = np.concatenate((obstacle, invalid), axis=0)

#### Create array of invalid transitions for conditions 3 & 4

In [None]:
# Named action ids, using the same 0..7 encoding as the movement table:
# LEFT, UP, RIGHT, DOWN, then the four diagonals.
(_LEFT, _UP, _RIGHT, _DOWN,
 _LEFT_UP, _RIGHT_UP, _RIGHT_DOWN, _LEFT_DOWN) = range(8)

# One array of action ids per letter label.
# NOTE(review): the section heading describes these as invalid transitions for
# conditions 3 & 4; what each letter refers to is not shown in this file.
F_actions = np.array([_RIGHT, _DOWN, _RIGHT_DOWN, _LEFT_DOWN])
A_actions = np.array([_LEFT, _DOWN, _RIGHT_DOWN, _LEFT_DOWN])
Z_actions = np.array([_LEFT_DOWN])
H_actions = np.array([_LEFT, _DOWN, _LEFT_DOWN])
K_actions = np.array([_RIGHT, _DOWN, _RIGHT_DOWN])
Y_actions = np.array([_RIGHT_DOWN])
V_actions = np.array([_LEFT_DOWN, _LEFT])
U_actions = np.array([_RIGHT, _RIGHT_DOWN])
X_actions = np.array([_LEFT, _DOWN, _LEFT_UP, _LEFT_DOWN])
W_actions = np.array([_RIGHT, _DOWN, _RIGHT_UP, _RIGHT_DOWN])
B_actions = np.array([_DOWN, _RIGHT_DOWN, _LEFT_DOWN])
B_exit_actions = np.array([_UP, _LEFT_UP, _RIGHT_UP])

## Generate the data

### Conditions 1 & 2

#### V1: attempts to make invalid moves are not recorded
-> more similar to the mouse data

In [None]:
# Random walk for conditions 1 & 2 (V1): every trial starts at (7, 12) and
# takes uniformly random steps among the 8 actions. A step that would land on
# an invalid cell is not recorded; that action is discarded and another one is
# drawn from the remaining actions until a valid step is found.
num_trials = 8
num_steps = 1420
# data[step, :, trial] is the (x, y) position at `step` of `trial`
data = np.zeros([num_steps, 2, num_trials])

# Loop-invariant lookups, built once instead of once per candidate move.
# Tuples of floats and ints with equal values compare and hash equal, so the
# integer positions of the walk can be looked up in this set directly.
blocked_cells = {tuple(cell) for cell in invalid_all.tolist()}
all_actions = action_space.tolist()

for trial_count in range(num_trials):
    current_position = np.array([7, 12])
    data[0, :, trial_count] = current_position
    # Working copy of the actions still untried from the current cell; the
    # module-level `action_space` itself is never modified.
    candidate_actions = list(all_actions)
    step_count = 1
    while step_count < num_steps:
        action = random.choice(candidate_actions)
        next_position = current_position + delta[action]
        if tuple(next_position.tolist()) in blocked_cells:
            # Invalid move: stay put and do not draw this action again here
            candidate_actions.remove(action)
        else:
            current_position = next_position
            data[step_count, :, trial_count] = current_position
            step_count += 1
            # A step was taken, so every action is available again
            candidate_actions = list(all_actions)

In [None]:
# With practice runs
#
# Same rejection-sampled random walk as the plain version (start at (7, 12),
# invalid moves are discarded and not recorded), except that arriving at one
# of a few trigger cells by a random step may start a scripted "practice run":
# a fixed sequence of actions that is written straight into the trajectory.
num_trials = 8
num_steps = 1420
# data[step, :, trial] is the (x, y) position at `step` of `trial`
data = np.zeros([num_steps, 2, num_trials])

# Chance that a scripted run is started when it is offered
practice_run_prob = 0.2

# Trigger cell -> scripted action sequences, offered in the listed order.
# One random number is drawn per offer and the first accepted script is run,
# so a second script is only offered when the first was not taken.
# NOTE(review): this makes the second (7, 12) script fire with probability
# 0.8 * 0.2 = 0.16 rather than 0.2 - confirm that the asymmetry is intended.
practice_runs = {
    (8, 12): [[6, 3, 3, 6, 3]],                       # ends at (10, 7)
    (6, 12): [[7, 3, 3, 7, 3]],                       # ends at (4, 7)
    (7, 12): [[3, 7, 7, 7, 3], [3, 6, 6, 6, 3]],      # end at (4, 7) / (10, 7)
    (4, 7): [[3, 6, 6, 6, 3, 3]],                     # ends at (7, 1)
    (10, 7): [[3, 7, 7, 7, 3, 3]],                    # ends at (7, 1)
}

# Loop-invariant lookups, built once instead of once per candidate move.
# Tuples of floats and ints with equal values compare and hash equal, so the
# integer positions of the walk can be looked up in this set directly.
blocked_cells = {tuple(cell) for cell in invalid_all.tolist()}
all_actions = action_space.tolist()

for trial_count in range(num_trials):
    current_position = np.array([7, 12])
    data[0, :, trial_count] = current_position
    # Working copy of the actions still untried from the current cell; the
    # module-level `action_space` itself is never modified.
    candidate_actions = list(all_actions)
    step_count = 1
    while step_count < num_steps:
        action = random.choice(candidate_actions)
        next_position = current_position + delta[action]
        cell = tuple(next_position.tolist())

        # Trigger cells take precedence over the invalid-cell test
        if cell not in practice_runs and cell in blocked_cells:
            # Invalid move: stay put and do not draw this action again here
            candidate_actions.remove(action)
            continue

        # Valid random step: record it and make every action available again
        current_position = next_position
        data[step_count, :, trial_count] = current_position
        step_count += 1
        candidate_actions = list(all_actions)

        for scripted_actions in practice_runs.get(cell, ()):
            if random.random() < practice_run_prob:
                # Scripted steps are recorded without validity or trigger
                # checks and are cut off when the trial runs out of steps.
                for scripted_action in scripted_actions:
                    if step_count >= num_steps:
                        break
                    current_position = current_position + delta[scripted_action]
                    data[step_count, :, trial_count] = current_position
                    step_count += 1
                break

In [None]:
# Extend the dataset by concatenating repeats of it
num_repeats = 10
new_num_trials = num_trials * num_repeats

# Allocate the extended array, then view its trial axis as (repeat, trial) so
# that a single broadcast assignment writes a copy of `data` into every repeat.
data_extended = np.zeros((num_steps, 2, new_num_trials))
data_extended.reshape(num_steps, 2, num_repeats, num_trials)[...] = data[:, :, np.newaxis, :]

In [None]:
# Write the extended array to disk (pass `data` instead to save the original trials).
# NOTE(review): the file name says "x5" while the extension cell uses
# num_repeats = 10 - confirm which one is intended.
output_path = 'random_walk_data_1_2_extendedx5.npy'
np.save(output_path, data_extended)

## Plot a trial

In [None]:
# Reload the saved array and draw the (x, y) path of the trial at index 2
data = np.load('random_walk_data_1_2_extendedx5.npy')
trial1_x, trial1_y = data[:, :, 2].T
plt.plot(trial1_x, trial1_y)

# Uncomment to plot dots corresponding to the position of the condition 3 tripwires
# plt.scatter([4,9,5,10,5,6,8,9],[10,10,10,10,9,9,9,9])
# plt.xlim([0,14])
# plt.ylim([0,14])

## Check for invalid points

In [None]:
# Load up the data if you've already saved something
saved_path = 'random_walk_data_1_2_extendedx5.npy'
data = np.load(saved_path)

In [None]:
# Scan every trial stored in `data` for visits to an invalid cell.
# The trial count is read from the array itself, so the check covers the whole
# loaded file and also works in a session where only the load cell was run.
# NOTE(review): `invalid` holds the borders and corner cells but not the
# obstacle; use `invalid_all` if obstacle cells should be reported too.
for i in range(data.shape[2]):
    trial_path = data[:, :, i]
    for point in invalid:
        # Row indices (steps) at which this trial sits exactly on `point`
        indexes = np.where(np.all(point == trial_path, axis=1))
        if len(indexes[0]) > 0:
            print('trial', i, 'invalid point', point.tolist(), 'steps', indexes[0])

# If nothing prints, all is well!