In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
import pandas as pd

## Load data

In [None]:
# Read the recording for condition 1.
# NOTE(review): allow_pickle=True lets np.load unpickle the file, and
# unpickling can execute arbitrary code -- only load a trusted 'condition_1.pkl'.
all_data = np.load('condition_1.pkl', allow_pickle=True)

# Copy the 'position' entry into an array and rotate its axes so that they
# unpack as (timepoints, coordinates, trials) below.
data = np.array(all_data['position']).transpose(1, 2, 0)
num_timepts, num_coords, num_trials = data.shape
print(num_timepts, num_coords, num_trials)

## Visualise the data

#### Plot a single trial

In [None]:
# Trajectory of the trial stored at index 1, mirrored in x and y about fixed offsets.
# NOTE(review): the offsets look like the arena extent in raw units -- confirm.
offset_x, offset_y = 994.05279541, 1002.065979
trial1_x = offset_x - data[:,0,1]
trial1_y = offset_y - data[:,1,1]
plt.plot(trial1_x, trial1_y)

#### Plot heatmap of all trials

In [None]:
# Pool the x and y samples of every trial into two flat vectors
x = data[:,0,:].ravel()
y = data[:,1,:].ravel()

# 500 x 500 occupancy counts; the outermost bin edges give the image extent
heatmap, xedges, yedges = np.histogram2d(x, y, bins=500)
extent = [*xedges[[0, -1]], *yedges[[0, -1]]]
major = np.linspace(0,1000,num=21)  # tick positions for the optional gridlines below

plt.clf()
# Transposed so that x runs along the horizontal axis; counts above 50 saturate
plt.imshow(heatmap.T, extent=extent, cmap ='gist_heat_r', vmax=50)
ax = plt.gca()
# Removes the frame and ticks for a clean figure
ax.set(frame_on=False, xticks=[], yticks=[])
# Uncomment the lines below for gridlines
# ax.set_xticks(major)
# ax.set_yticks(major)
# ax.grid(which='major',color='black', linestyle='-', linewidth=1)
plt.colorbar()

#### Plot a histogram of position coordinates

In [None]:
# Distribution of the x coordinates pooled over all trials and timepoints
plt.hist(x)

In [None]:
# Distribution of the y coordinates pooled over all trials and timepoints
plt.hist(y)

## Replace coordinates with gridworld coordinates

#### Assign each mouse coordinate to a gridworld position

In [None]:
# Sanity check on the coordinate extremes: shows every value in the dataset that
# lies below (or, with the comparison reversed, above) the chosen threshold.
# Switch between data[:,0,:] for x coordinates and data[:,1,:] for y coordinates
coord_values = data[:,1,:]
coord_values[coord_values < 35.1]

In [None]:
# Side length of one grid square and the raw coordinates where the grid starts
square_size = 965/13
x0 = 33.5
y0 = 37.5

# Views into data, so the assignments below overwrite the raw coordinates in place
x_view = data[:,0,:]
y_view = data[:,1,:]
for square in range(13):
    lower = square*square_size
    upper = (square+1)*square_size
    label = square+1  # Add 1 so that grid coordinates are between 1 and 13
    # Labels (1..13) are smaller than x0 and y0, so a sample that has already
    # been relabelled can never fall inside a later square's raw interval.
    # Samples outside [origin, origin + 13*square_size) are left untouched.
    x_view[(x_view >= x0+lower) & (x_view < x0+upper)] = label
    y_view[(y_view >= y0+lower) & (y_view < y0+upper)] = label

In [None]:
# Lists, first for x and then for y, the continuous interval covered by each
# gridworld square as: square number, lower edge, upper edge
for origin in (x0, y0):
    for i in range(13):
        print(i+1, origin+i*square_size, origin+(i+1)*square_size)

#### Scatter plot of the new coordinates

In [None]:
# Re-pool the coordinates now that data holds gridworld indices
x = data[:,0,:].ravel()
y = data[:,1,:].ravel()

fig, ax = plt.subplots(figsize=(5,5))
plt.scatter(x, y, c='black')
# Minor ticks on the half-integers (-0.5, 0.5, ..., 14.5) so that the minor
# grid lines outline each unit square
square_edges = np.linspace(-0.5,14.5,16)
ax.set_yticks(square_edges, minor=True)
ax.set_xticks(square_edges, minor=True)
ax.grid(True, which='minor')
plt.show()

#### Flip the data 
(because the origin is actually the upper left corner)

In [None]:
# Mirror both grid coordinates (c -> 14 - c, so 1 <-> 13) in place.
# Note: running this cell a second time undoes the flip.
data[:,:2,:] = 14 - data[:,:2,:]

#### Correct invalid moves

In [None]:
# Hard coding to correct discontinuity errors
# In the trial at index 5, find every timepoint at (8,4) whose successor is (8,6).
# np.roll(..., -1) lines each row up with the next one (the last row wraps to the first).
idx = np.where((data[:,:,5] == [8,4]).all(axis=1) & (np.roll(data[:,:,5],-1,axis=0) == [8,6]).all(axis=1))[0]
# Rewrite y at the match and the two following timepoints as 5, 6, 7 so that
# y changes by one square per timepoint instead of jumping from 4 to 6
data[idx,1,5]=5
data[idx+1,1,5]=6
data[idx+2,1,5]=7

In [None]:
# In the trial at index 5, find every (7,8) timepoint whose successor is (7,6)
trial = data[:,:,5]
successor = np.roll(trial,-1,axis=0)  # row t holds the position at t+1 (last row wraps)
idx = np.flatnonzero((trial == [7,8]).all(axis=1) & (successor == [7,6]).all(axis=1))
print(idx)
# Hold y at 7 for the five timepoints after each match
for offset in range(1,6):
    data[idx+offset,1,5]=7

In [None]:
# In the trial at index 6, find every (9,5) timepoint whose successor is (8,3)
trial = data[:,:,6]
successor = np.roll(trial,-1,axis=0)  # row t holds the position at t+1 (last row wraps)
idx = np.flatnonzero((trial == [9,5]).all(axis=1) & (successor == [8,3]).all(axis=1))
print(idx)
# Keep the position at (9,5) for the five timepoints after each match
for offset in range(1,6):
    data[idx+offset,:,6]=[9,5]

In [None]:
# In the trial at index 7, find every (9,5) timepoint whose successor is (8,4)
trial = data[:,:,7]
successor = np.roll(trial,-1,axis=0)  # row t holds the position at t+1 (last row wraps)
idx = np.flatnonzero((trial == [9,5]).all(axis=1) & (successor == [8,4]).all(axis=1))
print(idx)
# Overwrite the timepoints 2..5 steps after each match with a fixed replacement path
replacement_path = ([8,3], [8,4], [9,5], [9,5])
for offset, position in enumerate(replacement_path, start=2):
    data[idx+offset,:,7]=position

In [None]:
# Corrects moves DOWN into the obstacle (from y=8 to y=7)
# For every trial and every column x=5..9, find each step from (x,8) to (x,7),
# then push the following run of y=7 samples (while 4 < x < 10) back up to y=8.
for i in range(num_trials):
    for j in range(5,10):
        # Pair each timepoint with its true successor. Slicing is used instead of
        # np.roll because np.roll wraps the last timepoint onto the first, which
        # could report a match at the final index and then index past the end.
        current = data[:-1,:,i]
        following = data[1:,:,i]
        idx = np.where((current == [j,8]).all(axis=1) & (following == [j,7]).all(axis=1))[0]
        for start in idx:
            n = start+1
            # The bounds check stops the walk cleanly at the end of the trial
            while n<data.shape[0] and 4<data[n,0,i]<10 and data[n,1,i]==7:
                data[n,1,i]=8
                n=n+1

In [None]:
# Corrects moves UP into the obstacle (from y=7 to y=8)
# For every trial and every column x=5..9, find each step from (x,7) to (x,8),
# then push the following run of y=8 samples (while 4 < x < 10) back down to y=7.
for i in range(num_trials):
    for j in range(5,10):
        # Pair each timepoint with its true successor. Slicing is used instead of
        # np.roll because np.roll wraps the last timepoint onto the first, which
        # could report a match at the final index and then index past the end.
        current = data[:-1,:,i]
        following = data[1:,:,i]
        idx = np.where((current == [j,7]).all(axis=1) & (following == [j,8]).all(axis=1))[0]
        for start in idx:
            n = start+1
            # The bounds check stops the walk cleanly at the end of the trial
            while n<data.shape[0] and 4<data[n,0,i]<10 and data[n,1,i]==8:
                data[n,1,i]=7
                n=n+1

In [None]:
# Hard coding to correct issues arising from move corrections
# In the trial at index 4, find every (6,7) timepoint whose successor is (6,9)
trial = data[:,:,4]
successor = np.roll(trial,-1,axis=0)  # row t holds the position at t+1 (last row wraps)
idx = np.flatnonzero((trial == [6,7]).all(axis=1) & (successor == [6,9]).all(axis=1))
print(idx)
# Overwrite the four timepoints after each match with a fixed replacement path
replacement_path = ([5,7], [4,7], [4,8], [5,8])
for offset, position in enumerate(replacement_path, start=1):
    data[idx+offset,:,4]=position

In [None]:
# In the trial at index 5, find every (9,8) timepoint whose successor is (8,6)
trial = data[:,:,5]
successor = np.roll(trial,-1,axis=0)  # row t holds the position at t+1 (last row wraps)
idx = np.flatnonzero((trial == [9,8]).all(axis=1) & (successor == [8,6]).all(axis=1))
print(idx)
# Overwrite the six timepoints after each match with a fixed replacement path
replacement_path = ([10,8], [10,7], [9,7], [9,7], [9,7], [9,7])
for offset, position in enumerate(replacement_path, start=1):
    data[idx+offset,:,5]=position

In [None]:
# In the trial at index 5, find every (5,7) timepoint whose successor is (4,8).
# np.roll(..., -1) lines each row up with the next one (the last row wraps to the first).
idx = np.where((data[:,:,5] == [5,7]).all(axis=1) & (np.roll(data[:,:,5],-1,axis=0) == [4,8]).all(axis=1))[0]
print(idx)
# Replace the successor with (4,7), turning the diagonal step into a horizontal one
data[idx+1,:,5]=[4,7]

#### Ensure all the gridworld positions are valid

In [None]:
# Invalid squares of the bottom left corner, one [x, y] pair per row
invalid_bl = np.array([[1,1],[1,2],[1,3],[2,1]])

# The other three corners are mirror images: reflecting a coordinate c about
# the middle of the grid maps it to 14 - c.
# Top left: reflect y
invalid_tl = invalid_bl * [1, -1] + [0, 14]
# Top right: reflect x of the top left points
invalid_tr = invalid_tl * [-1, 1] + [14, 0]
# Bottom right: reflect x of the bottom left points
invalid_br = invalid_bl * [-1, 1] + [14, 0]

# Full array of invalid points, stacked corner by corner
invalid = np.vstack([invalid_bl, invalid_tl, invalid_tr, invalid_br])

# For every trial, report each invalid point that the trajectory visits,
# together with the number of timepoints spent on it
for trial in range(num_trials):
    positions = data[:,0:2,trial]
    for point in invalid:
        n_hits = np.count_nonzero((positions == point).all(axis=1))
        if n_hits>0:
            print("Trial number:", trial)
            print("Point:", point)
            print(n_hits)

# If nothing prints, all is well!

#### Plot all trials in the new gridworld coordinate system

In [None]:
# Overlay the gridworld trajectory of every trial on one square figure
plt.figure(figsize=(5,5))
for trial in range(num_trials):
    plt.plot(data[:,0,trial], data[:,1,trial], 'black')
# Grid coordinates run from 1 to 13; leave a one-square margin on each side
plt.xlim([0,14])
plt.ylim([0,14])

## Save the transformed data

In [None]:
# Write the gridworld-converted, corrected trajectories to disk
np.save('mouse_data_1.npy', data)

In [None]:
# Extend the dataset by concatenating repeats of it along the trial axis
# Note: this cell rebinds data to the extended array but leaves num_trials
# unchanged, so num_trials no longer equals data.shape[2] afterwards, and
# re-running the cell without reloading data fails on the shape mismatch.
num_repeats = 25
new_num_trials = num_trials*num_repeats
temp = np.zeros((num_timepts, num_coords, new_num_trials))

for i in range(num_repeats):
    start = i*num_trials
    stop = start+num_trials
    # Report the trial slice actually being filled (the bounds were previously
    # computed from num_timepts, which did not match the slice below)
    print(start, stop)
    temp[:,:,start:stop] = data

data = temp

In [None]:
# Shuffle the extended dataset along the trial axis (in place)
# A fixed seed makes the trial order, and therefore the saved file,
# identical on every Restart & Run All; change the seed for a different order.
SHUFFLE_SEED = 0
rng = np.random.default_rng(SHUFFLE_SEED)
rng.shuffle(data, axis=2)

In [None]:
# Write the 25x-extended, shuffled dataset to disk
np.save('mouse_data_1_extendedx25_shuffled.npy', data)

In [None]:
# Remove consecutively identical coordinates (the github code does this normally, this code here is just for checks/testing)
# NOTE(review): by this point data holds the extended, shuffled dataset while
# num_trials still has its pre-extension value, so only the first num_trials
# trials of the shuffled array are processed -- confirm this is intended.
temp = np.zeros((num_timepts, num_coords, num_trials))
lengths = np.zeros((num_trials))
for trial in range(num_trials):
    path = data[:,:,trial]
    # True wherever the position differs from the previous timepoint
    moved = np.any(np.diff(path, axis=0) != 0, axis=1)
    # Always keep the first timepoint, then only the ones that moved
    path = path[np.concatenate(([True], moved))]
    # Shortened paths are left-aligned; the remainder stays zero-padded
    temp[:len(path),:,trial] = path
    lengths[trial] = len(path)

data = temp
print(lengths)
print(np.sum(lengths))

In [None]:
# Write the de-duplicated check/test array to disk
np.save('mouse_data_1_TEST.npy', data)