In [None]:
# Re-import edited modules from the repo automatically before each cell runs,
# so changes to the project source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import os, sys
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, mannwhitneyu, wilcoxon, ttest_rel, ttest_ind
import seaborn as sns
import tkinter as tk
from scipy import signal
from scipy import ndimage


from tkinter import filedialog
import time
from pathlib import Path
import matplotlib.patches as patches

# Resolve repository paths relative to the directory the notebook was launched
# from. That directory is captured only on the first run because this cell
# calls os.chdir(); otherwise every re-run would walk two more levels up.
NOTEBOOK_DIR = globals().get("NOTEBOOK_DIR") or os.getcwd()

PROJECT_PATH = NOTEBOOK_DIR
unit_matcher_path = NOTEBOOK_DIR
prototype_path = os.path.abspath(os.path.join(unit_matcher_path, os.pardir))
project_path = os.path.abspath(os.path.join(prototype_path, os.pardir))
lab_path = os.path.abspath(os.path.join(project_path, os.pardir))

# Make project packages importable without piling up duplicate sys.path
# entries when the cell is re-executed.
for _path in (PROJECT_PATH, project_path):
    if _path not in sys.path:
        sys.path.append(_path)

os.chdir(project_path)
print(project_path)

In [None]:
# Build one place field: a 2-D Gaussian kernel is stamped into a padded
# (3*64, 3*64) canvas. The middle 64x64 window (red box) is the arena; the
# padding lets fields be positioned on or past the arena walls.
N = 63  # kernel size
# scipy.signal.gaussian was removed in SciPy 1.13; the window lives in
# scipy.signal.windows.
k1d = signal.windows.gaussian(N, std=10).reshape(N, 1)
kernel = np.outer(k1d, k1d)
single_field = np.zeros((64*3, 64*3))
single_field[64+32, 64+32] = 1
row, col = np.where(single_field == 1)
single_field[row[0]-(N//2):row[0]+(N//2)+1, col[0]-(N//2):col[0]+(N//2)+1] = kernel


fig = plt.figure(figsize=(8,5))

ax = plt.subplot(1,2,1)
ax.imshow(kernel, cmap='jet')

ax = plt.subplot(1,2,2)
ax.imshow(single_field, cmap='jet')
rect = patches.Rectangle((64, 64), 64, 64, linewidth=1, edgecolor='r', facecolor='none')
ax.add_patch(rect)

fig.tight_layout()
fig.suptitle('Place field (left) and (3*64,3*64) map for easy manipulation of place field')
plt.show()


In [None]:
# Place fields at the arena centre and at each wall midpoint / corner.
# Each field is stamped into its own padded (3*64, 3*64) canvas. The arena is
# the middle 64x64 window, so fields centred on a wall or corner are cut off
# by the arena edge.
N = 63  # kernel size
half = N // 2
# scipy.signal.gaussian was removed in SciPy 1.13; use scipy.signal.windows.
k1d = signal.windows.gaussian(N, std=10).reshape(N, 1)
kernel = np.outer(k1d, k1d)
blank_map = np.zeros((64*3, 64*3))
arena = (slice(64, 64+64), slice(64, 64+64))  # arena window inside the canvas

# name -> ((row, col) of the field centre in canvas coordinates, 3x3 subplot slot)
field_layout = {
    'top_left':      ((64,      64),      1),
    'top_middle':    ((64,      64 + 32), 2),
    'top_right':     ((64,      64 + 64), 3),
    'right_middle':  ((64 + 32, 64 + 64), 6),
    'bottom_right':  ((64 + 64, 64 + 64), 9),
    'bottom_middle': ((64 + 64, 64 + 32), 8),
    'bottom_left':   ((64 + 64, 64),      7),
    'left_middle':   ((64 + 32, 64),      4),
    'center':        ((64 + 32, 64 + 32), 5),
}

padded_fields = {}
for name, ((row, col), _) in field_layout.items():
    canvas = np.copy(blank_map)
    canvas[row - half:row + half + 1, col - half:col + half + 1] = kernel
    padded_fields[name] = canvas

# Keep the individual padded canvases available under their original names.
top_left = padded_fields['top_left']
top_middle = padded_fields['top_middle']
top_right = padded_fields['top_right']
right_middle = padded_fields['right_middle']
bottom_right = padded_fields['bottom_right']
bottom_middle = padded_fields['bottom_middle']
bottom_left = padded_fields['bottom_left']
left_middle = padded_fields['left_middle']
single_field = padded_fields['center']

fig = plt.figure(figsize=(9,9))
axs = []
for name, (_, slot) in field_layout.items():
    ax = plt.subplot(3, 3, slot)
    ax.imshow(padded_fields[name][arena], cmap='jet')
    axs.append(ax)

for ax in axs:
    # Outline the arena. The images are already cropped to the arena, so the
    # box spans the full image extent (pixel edges at -0.5 and 63.5). The
    # (64, 64) origin used on the padded canvas would put it beside the image.
    rect = patches.Rectangle((-0.5, -0.5), 64, 64, linewidth=1, edgecolor='k', facecolor='none')
    ax.add_patch(rect)

fig.suptitle('Object and corner place fields')
fig.tight_layout()
plt.show()

# Arena-cropped views, keyed in the same order as field_layout.
object_corner_fields = {name: canvas[arena] for name, canvas in padded_fields.items()}


In [None]:
# Centred place fields of increasing width (Gaussian std 1..16 bins), cropped
# to the 64x64 arena and collected in scaling_fields[std].
N = 63  # kernel size
half = N // 2
centre = 64 + 32  # arena centre in padded-canvas coordinates
blank_map = np.zeros((64*3, 64*3))

fig = plt.figure(figsize=(9,9))
axs = []
scaling_fields = {}

for std in range(1, 17):
    # scipy.signal.gaussian was removed in SciPy 1.13; use scipy.signal.windows.
    k1d = signal.windows.gaussian(N, std=std).reshape(N, 1)
    kernel = np.outer(k1d, k1d)
    single_field = np.copy(blank_map)
    single_field[centre - half:centre + half + 1, centre - half:centre + half + 1] = kernel

    ax = plt.subplot(4, 4, std)
    ax.imshow(single_field[64:64+64, 64:64+64], cmap='jet')
    ax.set_title('STD: ' + str(std))
    axs.append(ax)

    scaling_fields[std] = single_field[64:64+64, 64:64+64]

fig.suptitle('Place field scaling')
fig.tight_layout()
plt.show()


In [None]:
# Same layout as the object/corner fields, but each off-centre field sits
# halfway between the arena centre and the corresponding wall/corner.
N = 63  # kernel size
half = N // 2
# scipy.signal.gaussian was removed in SciPy 1.13; use scipy.signal.windows.
k1d = signal.windows.gaussian(N, std=10).reshape(N, 1)
kernel = np.outer(k1d, k1d)
blank_map = np.zeros((64*3, 64*3))
arena = (slice(64, 64+64), slice(64, 64+64))  # arena window inside the canvas
quarter = 32 // 2  # offset from the arena centre, half the centre-to-wall distance

# name -> ((row, col) of the field centre in canvas coordinates, 3x3 subplot slot)
field_layout = {
    'top_left':      ((64 + quarter,      64 + quarter),      1),
    'top_middle':    ((64 + quarter,      64 + 32),           2),
    'top_right':     ((64 + quarter,      64 + 32 + quarter), 3),
    'right_middle':  ((64 + 32,           64 + 32 + quarter), 6),
    'bottom_right':  ((64 + 32 + quarter, 64 + 32 + quarter), 9),
    'bottom_middle': ((64 + 32 + quarter, 64 + 32),           8),
    'bottom_left':   ((64 + 32 + quarter, 64 + quarter),      7),
    'left_middle':   ((64 + 32,           64 + quarter),      4),
    'center':        ((64 + 32,           64 + 32),           5),
}

padded_fields = {}
for name, ((row, col), _) in field_layout.items():
    canvas = np.copy(blank_map)
    canvas[row - half:row + half + 1, col - half:col + half + 1] = kernel
    padded_fields[name] = canvas

# Keep the individual padded canvases available under their original names.
top_left = padded_fields['top_left']
top_middle = padded_fields['top_middle']
top_right = padded_fields['top_right']
right_middle = padded_fields['right_middle']
bottom_right = padded_fields['bottom_right']
bottom_middle = padded_fields['bottom_middle']
bottom_left = padded_fields['bottom_left']
left_middle = padded_fields['left_middle']
single_field = padded_fields['center']

fig = plt.figure(figsize=(9,9))
axs = []
for name, (_, slot) in field_layout.items():
    ax = plt.subplot(3, 3, slot)
    ax.imshow(padded_fields[name][arena], cmap='jet')
    axs.append(ax)

for ax in axs:
    # Outline the arena. The images are already cropped to the arena, so the
    # box spans the full image extent (pixel edges at -0.5 and 63.5). The
    # (64, 64) origin used on the padded canvas would put it beside the image.
    rect = patches.Rectangle((-0.5, -0.5), 64, 64, linewidth=1, edgecolor='k', facecolor='none')
    ax.add_patch(rect)

fig.suptitle('Object and corner place fields 2')
fig.tight_layout()
plt.show()

# Arena-cropped views, keyed in the same order as field_layout.
object_corner_fields2 = {name: canvas[arena] for name, canvas in padded_fields.items()}


In [None]:
# scipy.integrate.simps was removed in SciPy 1.14; simpson is the replacement.
from scipy.integrate import simpson

def integrate_simps(mesh, func):
    """Integrate a 2-D array over the grid described by ``mesh``.

    Simpson's rule is applied along the last axis (x, the columns) first and
    then along the first axis (y, the rows).

    Parameters
    ----------
    mesh : sequence of two 2-D arrays
        ``(X, Y)`` grids as produced by ``np.meshgrid(px, py)``. Each row of
        ``X`` holds the x samples, and each column of ``Y`` holds the y samples.
    func : 2-D array
        Samples on that grid, same shape as ``X`` and ``Y``.

    Returns
    -------
    float
        The approximate double integral.
    """
    n_rows, n_cols = func.shape
    px = mesh[0][n_rows // 2, :]
    py = mesh[1][:, n_cols // 2]
    # ``x`` is passed by keyword: it is keyword-only in recent SciPy.
    return simpson(simpson(func, x=px), x=py)

def normalize_integrate(mesh, func):
    """Rescale ``func`` so its integral over ``mesh`` (per integrate_simps) is one."""
    total = integrate_simps(mesh, func)
    return func / total

def moment(mesh, func, index):
    """Raw moment of ``func`` treated as a 2-D density on ``mesh``.

    ``index`` is ``(ix, iy)``. The result is the integral of
    ``density * X**ix * Y**iy``, where ``density`` is ``func`` normalised to
    unit integral.
    """
    power_x, power_y = index[0], index[1]
    density = normalize_integrate(mesh, func)
    weighted = density * mesh[0] ** power_x * mesh[1] ** power_y
    return integrate_simps(mesh, weighted)

def make_gauss(mesh, sxy, rxy, rot):
    """Unnormalised, rotated 2-D Gaussian evaluated on ``mesh``.

    Parameters
    ----------
    mesh : (X, Y) coordinate grids.
    sxy : (x, y) centre of the Gaussian.
    rxy : (sigma_u, sigma_v) widths along the rotated axes.
    rot : rotation angle in radians.

    Returns
    -------
    Array of the mesh's shape; the peak value is 1 at the centre.
    """
    dx = mesh[0] - sxy[0]
    dy = mesh[1] - sxy[1]
    c, s = np.cos(rot), np.sin(rot)
    u = dx * c - dy * s
    v = dy * c + dx * s
    return np.exp(-0.5 * (u / rxy[0]) ** 2) * np.exp(-0.5 * (v / rxy[1]) ** 2)

def get_centroid(mesh, func):
    """Return the first moments ``(mean_x, mean_y)`` of ``func`` as a density on ``mesh``."""
    return tuple(moment(mesh, func, order) for order in ((1, 0), (0, 1)))

def get_covariance(mesh, func, dxy):
    """2x2 second-moment matrix of ``func`` about the point ``dxy``.

    With ``dxy`` set to the centroid, this is the covariance matrix
    ``[[Mxx, Mxy], [Mxy, Myy]]``.
    """
    shifted = [mesh[0] - dxy[0], mesh[1] - dxy[1]]
    var_x = moment(shifted, func, (2, 0))
    var_y = moment(shifted, func, (0, 2))
    cov_xy = moment(shifted, func, (1, 1))
    row_x = [var_x, cov_xy]
    row_y = [cov_xy, var_y]
    return np.array([row_x, row_y])

In [None]:
from scipy.stats import multivariate_normal

# Coordinate grid for the 64x64 arena. np.meshgrid(px, py) gives mesh[0] = x
# (varies along columns) and mesh[1] = y (varies along rows).
px = np.arange(0,64,1)
py = np.arange(0,64,1)
mesh = np.meshgrid(px, py)

# x positions (bins) of the translated target fields; all targets sit at y = 32.
sxs = [32, 38, 42, 48, 52, 58]

# Ellipse: widths (3, 7) bins with a 45-degree tilt.
ellipse_radii, ellipse_rot = [3, 7], np.deg2rad(-45)
# Circle: equal widths (5, 5) bins, no rotation.
circle_radii, circle_rot = [5, 5], np.deg2rad(0)

# Ellipse translated along x; the first entry coincides with the reference.
aggregate_elliptic_right_ex1 = [
    make_gauss(mesh, [sx, 32], ellipse_radii, ellipse_rot) for sx in sxs
]

# Reference ellipse centred in the arena.
elliptic_left_ex1 = make_gauss(mesh, [32, 32], ellipse_radii, ellipse_rot)

# Circle translated along the same x positions (a different shape from the reference).
aggregate_circle_right_ex1 = [
    make_gauss(mesh, [sx, 32], circle_radii, circle_rot) for sx in sxs
]


In [None]:
from _prototypes.cell_remapping.src.remapping import pot_sliced_wasserstein
from _prototypes.cell_remapping.src.wasserstein_distance import _get_ratemap_bucket_midpoints

import itertools

# ---------------------------------------------------------------- EXAMPLE 1
# Same shape, increasing separation: the reference ellipse (elliptic_left_ex1)
# is compared with copies of it translated along x (aggregate_elliptic_right_ex1).
# Each panel title reads "{EMD} vs {Pearson r}".

# The source distribution and the bucket coordinates do not change between
# panels, so compute them once.
y, x = elliptic_left_ex1.shape
height_bucket_midpoints, width_bucket_midpoints = _get_ratemap_bucket_midpoints(([1],[1]), y, x)
coord_buckets = np.array(list(itertools.product(height_bucket_midpoints, width_bucket_midpoints)))
# Row-major ravel() matches the (row, col) order of itertools.product above.
source_flat = elliptic_left_ex1.ravel()
source_weights = source_flat / np.sum(source_flat)

fig = plt.figure(figsize=(15,3))

for c, target_field in enumerate(aggregate_elliptic_right_ex1):
    ax = plt.subplot(1, 6, c + 1)
    ax.imshow(target_field, cmap='jet')
    ax.imshow(elliptic_left_ex1, cmap='gist_stern', alpha=0.3)

    target_flat = target_field.ravel()
    target_weights = target_flat / np.sum(target_flat)

    emd = pot_sliced_wasserstein(coord_buckets, coord_buckets, source_weights, target_weights)
    r, p = pearsonr(source_weights, target_weights)

    ax.set_title(str(round(emd, 2)) + ' vs ' + str(round(r, 2)))

fig.suptitle('Ellipse translation, step-wise comparison with first ellipse - {EMD} vs {Pearson}')
fig.tight_layout()
plt.show()

# ------------------------------------------------ EXAMPLE 1 - POPULATION FIGURES
# Compare the reference ellipse with the same ellipse centred on every bin of
# the arena. Results end up as pop_shifted_*[y (row), x (col)] of the target
# centre.

y, x = elliptic_left_ex1.shape
height_bucket_midpoints, width_bucket_midpoints = _get_ratemap_bucket_midpoints(([1],[1]), y, x)
coord_buckets = np.array(list(itertools.product(height_bucket_midpoints, width_bucket_midpoints)))
# Row-major ravel() matches the (row, col) order of itertools.product above.
source_flat = elliptic_left_ex1.ravel()
source_weights = source_flat / np.sum(source_flat)

def _sub1(i, j):
    """EMD and Pearson r between the reference and an ellipse centred at x=i, y=j."""
    elliptic_right_ex1 = make_gauss(mesh, [i, j], [3, 7], np.deg2rad(-45))
    target_flat = elliptic_right_ex1.ravel()
    target_weights = target_flat / np.sum(target_flat)
    emd = pot_sliced_wasserstein(coord_buckets, coord_buckets, source_weights, target_weights)
    r, p = pearsonr(source_weights, target_weights)
    return emd, r

def _sub2(i):
    """EMD and Pearson r arrays over every y position for a target at x = i."""
    emd_and_r = np.array([_sub1(i, j) for j in range(y)])
    return emd_and_r[:, 0], emd_and_r[:, 1]

# Before the transpose, combined[i, k, j] holds (k=0: EMD, k=1: r) for a
# target at x=i, y=j. The transpose reorders it to [y, k, x], giving
# [row, col] maps.
combined = np.array([_sub2(i) for i in range(x)]).T
pop_shifted_emd, pop_shifted_pearson = combined[:, 0, :], combined[:, 1, :]


In [None]:
# EMD where the target's x and y bins coincide (main diagonal of the map).
# The call parentheses were missing, which displayed a bound-method repr.
pop_shifted_emd.diagonal()

In [None]:
# Population summary: EMD (red, left axis) vs Pearson r (black, right axis)
# along rows, near-diagonals and rows of the 90-degree-rotated score maps,
# plus heat maps of both scores. A single 1x5 grid is used; creating a 1x3 grid
# and then drawing 1x5 plt.subplot axes on top left stray empty axes behind.
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(26, 8))

# Panel 1: each row of the score maps is one horizontal sweep of the target.
for emd_row in pop_shifted_emd:
    p1 = ax1.plot(emd_row, color='r', label='EMD')
ax1.set_xlabel('Field centre bin')
ax1.set_ylabel('EMD (cm)')
axtwin = ax1.twinx()
axtwin.set_ylabel('Pearson r')
for pearson_row in pop_shifted_pearson:
    ptwin = axtwin.plot(pearson_row, color='k', label='Pearson')
lns = p1 + ptwin
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc='upper left')
ax1.set_title('Horizontal translation for diff. field positions')

def diag_matrix(matrix):
    """Return 14 diagonals of ``matrix`` as lists.

    The first seven are anti-diagonals (diagonals of the row-flipped matrix)
    at offsets -3..3. The last seven are main-direction diagonals at offsets
    3..-3.
    """
    diags = [matrix[::-1,:].diagonal(i) for i in range(-3,4)]
    diags.extend(matrix.diagonal(i) for i in range(3,-4,-1))
    return [n.tolist() for n in diags]

# Panel 2: sweeps along diagonals near the main and anti-diagonal.
diag_emd = diag_matrix(pop_shifted_emd)
for emd_diag in diag_emd:
    p2 = ax2.plot(emd_diag, color='r', label='EMD')
ax2.set_xlabel('Field centre bin')
ax2.set_ylabel('EMD (cm)')
axtwin2 = ax2.twinx()
axtwin2.set_ylabel('Pearson r')
diag_pearson = diag_matrix(pop_shifted_pearson)
for pearson_diag in diag_pearson:
    ptwin2 = axtwin2.plot(pearson_diag, color='k', label='Pearson')
lns = p2 + ptwin2
labs = [l.get_label() for l in lns]
ax2.legend(lns, labs, loc='upper left')
ax2.set_title('Diagonal translation for diff. field positions')

# Panel 3: rows of the score maps rotated by 90 degrees (the columns of the
# original maps).
rotated_emd = ndimage.rotate(pop_shifted_emd, 90, reshape=False, mode='constant', cval=np.nan)
for emd_row in rotated_emd:
    p3 = ax3.plot(emd_row, color='r', label='EMD')
axtwin3 = ax3.twinx()
axtwin3.set_ylabel('Pearson r')
rotated_pearson = ndimage.rotate(pop_shifted_pearson, 90, reshape=False, mode='constant', cval=np.nan)
for pearson_row in rotated_pearson:
    ptwin3 = axtwin3.plot(pearson_row, color='k', label='Pearson')
lns = p3 + ptwin3
labs = [l.get_label() for l in lns]
ax3.legend(lns, labs, loc='upper left')
ax3.set_title('Angular translation for diff. field positions')
ax3.set_xlabel('Field centre bin')

# Panels 4-5: full score maps over all target centres.
im = ax4.imshow(pop_shifted_emd, cmap='jet', aspect='auto')
ax4.set_title('Normalized EMD score for diff. field positions')
ax4.set_xlabel('Field centre bin')
fig.colorbar(im, ax=ax4, fraction=0.046, pad=0.04)

ax5.set_xlabel('Field centre bin')
im = ax5.imshow(pop_shifted_pearson, cmap='jet', aspect='auto')
ax5.set_title('Normalized pearson-r score for diff. field centre positions')
fig.colorbar(im, ax=ax5, fraction=0.046, pad=0.04)

fig.tight_layout()
plt.show()

In [None]:
from _prototypes.cell_remapping.src.remapping import pot_sliced_wasserstein
from _prototypes.cell_remapping.src.wasserstein_distance import _get_ratemap_bucket_midpoints

import itertools

# ---------------------------------------------------------------- EXAMPLE 2
# Different shape, increasing separation: the reference ellipse
# (elliptic_left_ex1) is compared with circles translated along x
# (aggregate_circle_right_ex1). Each panel title reads "{EMD} vs {Pearson r}".

# The source distribution and the bucket coordinates do not change between
# panels, so compute them once.
y, x = elliptic_left_ex1.shape
height_bucket_midpoints, width_bucket_midpoints = _get_ratemap_bucket_midpoints(([1],[1]), y, x)
coord_buckets = np.array(list(itertools.product(height_bucket_midpoints, width_bucket_midpoints)))
# Row-major ravel() matches the (row, col) order of itertools.product above.
source_flat = elliptic_left_ex1.ravel()
source_weights = source_flat / np.sum(source_flat)

fig = plt.figure(figsize=(15,3))

for c, target_field in enumerate(aggregate_circle_right_ex1):
    ax = plt.subplot(1, 6, c + 1)
    ax.imshow(target_field, cmap='jet')
    ax.imshow(elliptic_left_ex1, cmap='gist_stern', alpha=0.3)

    target_flat = target_field.ravel()
    target_weights = target_flat / np.sum(target_flat)

    emd = pot_sliced_wasserstein(coord_buckets, coord_buckets, source_weights, target_weights)
    r, p = pearsonr(source_weights, target_weights)

    ax.set_title(str(round(emd, 2)) + ' vs ' + str(round(r, 2)))

fig.suptitle('Ellipse translation and transformation, step-wise comparison with first ellipse - {EMD} vs {Pearson}')
fig.tight_layout()
plt.show()

# ------------------------------------------------ EXAMPLE 2 - POPULATION FIGURES
# Compare the reference ellipse with a circle centred on every bin of the
# arena. The target must use the same circle parameters (5x5, no rotation) as
# aggregate_circle_right_ex1. The previous ellipse parameters (3x7, 45 deg)
# merely repeated Example 1. Results end up as pop_shifted_*[y (row), x (col)].

y, x = elliptic_left_ex1.shape
height_bucket_midpoints, width_bucket_midpoints = _get_ratemap_bucket_midpoints(([1],[1]), y, x)
coord_buckets = np.array(list(itertools.product(height_bucket_midpoints, width_bucket_midpoints)))
# Row-major ravel() matches the (row, col) order of itertools.product above.
source_flat = elliptic_left_ex1.ravel()
source_weights = source_flat / np.sum(source_flat)

def _sub1(i, j):
    """EMD and Pearson r between the reference ellipse and a circle centred at x=i, y=j."""
    circle_right_ex1 = make_gauss(mesh, [i, j], [5, 5], np.deg2rad(0))
    target_flat = circle_right_ex1.ravel()
    target_weights = target_flat / np.sum(target_flat)
    emd = pot_sliced_wasserstein(coord_buckets, coord_buckets, source_weights, target_weights)
    r, p = pearsonr(source_weights, target_weights)
    return emd, r

def _sub2(i):
    """EMD and Pearson r arrays over every y position for a target at x = i."""
    emd_and_r = np.array([_sub1(i, j) for j in range(y)])
    return emd_and_r[:, 0], emd_and_r[:, 1]

# Before the transpose, combined[i, k, j] holds (k=0: EMD, k=1: r) for a
# target at x=i, y=j. The transpose reorders it to [y, k, x], giving
# [row, col] maps.
combined = np.array([_sub2(i) for i in range(x)]).T
pop_shifted_emd, pop_shifted_pearson = combined[:, 0, :], combined[:, 1, :]



In [None]:
# Population summary: EMD (red, left axis) vs Pearson r (black, right axis)
# along rows, near-diagonals and rows of the 90-degree-rotated score maps,
# plus heat maps of both scores. A single 1x5 grid is used; creating a 1x3 grid
# and then drawing 1x5 plt.subplot axes on top left stray empty axes behind.
fig, (ax1, ax2, ax3, ax4, ax5) = plt.subplots(1, 5, figsize=(26, 8))

# Panel 1: each row of the score maps is one horizontal sweep of the target.
for emd_row in pop_shifted_emd:
    p1 = ax1.plot(emd_row, color='r', label='EMD')
ax1.set_xlabel('Field centre bin')
ax1.set_ylabel('EMD (cm)')
axtwin = ax1.twinx()
axtwin.set_ylabel('Pearson r')
for pearson_row in pop_shifted_pearson:
    ptwin = axtwin.plot(pearson_row, color='k', label='Pearson')
lns = p1 + ptwin
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc='upper left')
ax1.set_title('Horizontal translation for diff. field positions')

def diag_matrix(matrix):
    """Return 14 diagonals of ``matrix`` as lists.

    The first seven are anti-diagonals (diagonals of the row-flipped matrix)
    at offsets -3..3. The last seven are main-direction diagonals at offsets
    3..-3.
    """
    diags = [matrix[::-1,:].diagonal(i) for i in range(-3,4)]
    diags.extend(matrix.diagonal(i) for i in range(3,-4,-1))
    return [n.tolist() for n in diags]

# Panel 2: sweeps along diagonals near the main and anti-diagonal.
diag_emd = diag_matrix(pop_shifted_emd)
for emd_diag in diag_emd:
    p2 = ax2.plot(emd_diag, color='r', label='EMD')
ax2.set_xlabel('Field centre bin')
ax2.set_ylabel('EMD (cm)')
axtwin2 = ax2.twinx()
axtwin2.set_ylabel('Pearson r')
diag_pearson = diag_matrix(pop_shifted_pearson)
for pearson_diag in diag_pearson:
    ptwin2 = axtwin2.plot(pearson_diag, color='k', label='Pearson')
lns = p2 + ptwin2
labs = [l.get_label() for l in lns]
ax2.legend(lns, labs, loc='upper left')
ax2.set_title('Diagonal translation for diff. field positions')

# Panel 3: rows of the score maps rotated by 90 degrees (the columns of the
# original maps).
rotated_emd = ndimage.rotate(pop_shifted_emd, 90, reshape=False, mode='constant', cval=np.nan)
for emd_row in rotated_emd:
    p3 = ax3.plot(emd_row, color='r', label='EMD')
axtwin3 = ax3.twinx()
axtwin3.set_ylabel('Pearson r')
rotated_pearson = ndimage.rotate(pop_shifted_pearson, 90, reshape=False, mode='constant', cval=np.nan)
for pearson_row in rotated_pearson:
    ptwin3 = axtwin3.plot(pearson_row, color='k', label='Pearson')
lns = p3 + ptwin3
labs = [l.get_label() for l in lns]
ax3.legend(lns, labs, loc='upper left')
ax3.set_title('Angular translation for diff. field positions')
ax3.set_xlabel('Field centre bin')

# Panels 4-5: full score maps over all target centres.
im = ax4.imshow(pop_shifted_emd, cmap='jet', aspect='auto')
ax4.set_title('Normalized EMD score for diff. field positions')
ax4.set_xlabel('Field centre bin')
fig.colorbar(im, ax=ax4, fraction=0.046, pad=0.04)

ax5.set_xlabel('Field centre bin')
im = ax5.imshow(pop_shifted_pearson, cmap='jet', aspect='auto')
ax5.set_title('Normalized pearson-r score for diff. field centre positions')
fig.colorbar(im, ax=ax5, fraction=0.046, pad=0.04)

fig.tight_layout()
plt.show()

In [None]:
# Narrow circular fields (Gaussian std 3.5 bins) built on the padded canvas and
# cropped to the 64x64 arena: one centred reference and copies translated
# along x. N, kernel and blank_map are reused by the Example 3 population cell.
N = 63  # kernel size
half = N // 2
# scipy.signal.gaussian was removed in SciPy 1.13; use scipy.signal.windows.
k1d = signal.windows.gaussian(N, std=3.5).reshape(N, 1)
kernel = np.outer(k1d, k1d)
blank_map = np.zeros((64*3, 64*3))
centre_row = 64 + 32  # all fields sit on the arena's middle row

reference_canvas = np.copy(blank_map)
reference_canvas[centre_row - half:centre_row + half + 1, 64 + 32 - half:64 + 32 + half + 1] = kernel
circle_left_ex1 = reference_canvas[64:64+64, 64:64+64]

sxs = [32, 38, 42, 48, 52, 58]  # target x positions (arena bins)
aggregate_circle_right_ex2 = []
for sx in sxs:
    canvas = np.copy(blank_map)
    col = 64 + sx
    canvas[centre_row - half:centre_row + half + 1, col - half:col + half + 1] = kernel
    aggregate_circle_right_ex2.append(canvas[64:64+64, 64:64+64])


In [None]:
from _prototypes.cell_remapping.src.remapping import pot_sliced_wasserstein
from _prototypes.cell_remapping.src.wasserstein_distance import _get_ratemap_bucket_midpoints

import itertools

# ---------------------------------------------------------------- EXAMPLE 3
# Same shape, increasing separation, circular fields: the centred circle
# (circle_left_ex1) is compared with copies translated along x
# (aggregate_circle_right_ex2). Each panel title reads "{EMD} vs {Pearson r}".

# The source distribution and the bucket coordinates do not change between
# panels, so compute them once.
y, x = circle_left_ex1.shape
height_bucket_midpoints, width_bucket_midpoints = _get_ratemap_bucket_midpoints(([1],[1]), y, x)
coord_buckets = np.array(list(itertools.product(height_bucket_midpoints, width_bucket_midpoints)))
# Row-major ravel() matches the (row, col) order of itertools.product above.
source_flat = circle_left_ex1.ravel()
source_weights = source_flat / np.sum(source_flat)

fig = plt.figure(figsize=(15,3))

for c, target_field in enumerate(aggregate_circle_right_ex2):
    ax = plt.subplot(1, 6, c + 1)
    ax.imshow(target_field, cmap='jet')
    ax.imshow(circle_left_ex1, cmap='gist_stern', alpha=0.3)

    target_flat = target_field.ravel()
    target_weights = target_flat / np.sum(target_flat)

    emd = pot_sliced_wasserstein(coord_buckets, coord_buckets, source_weights, target_weights)
    r, p = pearsonr(source_weights, target_weights)

    ax.set_title(str(round(emd, 2)) + ' vs ' + str(round(r, 2)))

fig.suptitle('Circle translation, step-wise comparison with first circle - {EMD} vs {Pearson}')
fig.tight_layout()
plt.show()

# ------------------------------------------------ EXAMPLE 3 - POPULATION FIGURES
# Compare the centred circle with the same circle centred on every bin of the
# arena. Results end up as pop_shifted_*[row, col] of the target centre, the
# same orientation as in Examples 1 and 2.

y, x = circle_left_ex1.shape
height_bucket_midpoints, width_bucket_midpoints = _get_ratemap_bucket_midpoints(([1],[1]), y, x)
coord_buckets = np.array(list(itertools.product(height_bucket_midpoints, width_bucket_midpoints)))
# Row-major ravel() matches the (row, col) order of itertools.product above.
source_flat = circle_left_ex1.ravel()
source_weights = source_flat / np.sum(source_flat)

def _sub1(i, j):
    """EMD and Pearson r between the reference and a circle centred at row i, col j."""
    half = N // 2
    canvas = np.copy(blank_map)
    canvas[64 + i - half:64 + i + half + 1, 64 + j - half:64 + j + half + 1] = kernel
    target_flat = canvas[64:64+64, 64:64+64].ravel()
    target_weights = target_flat / np.sum(target_flat)
    emd = pot_sliced_wasserstein(coord_buckets, coord_buckets, source_weights, target_weights)
    r, p = pearsonr(source_weights, target_weights)
    return emd, r

def _sub2(i):
    """EMD and Pearson r arrays over every column for a target on row i."""
    emd_and_r = np.array([_sub1(i, j) for j in range(x)])
    return emd_and_r[:, 0], emd_and_r[:, 1]

# combined[i, k, j]: k=0 EMD, k=1 Pearson r for a target at row i, col j.
# No transpose here: _sub1 already takes (row, col). Examples 1 and 2 take
# (x, y) and need one, and transposing this result would produce a [col, row]
# map.
combined = np.array([_sub2(i) for i in range(y)])
pop_shifted_emd, pop_shifted_pearson = combined[:, 0, :], combined[:, 1, :]



In [None]:
fig, (ax1, ax2, ax4) = plt.subplots(1, 3, figsize=(26, 8))
# fig = plt.figure(figsize=(15,6))
ax1 = plt.subplot(1,5,1)
# plt.axis('square')
for i in range(64):
    p1 = ax1.plot(pop_shifted_emd[i],color='r',label='EMD')
ax1.set_xlabel('Field centre bin')
ax1.set_ylabel('EMD (cm)')
axtwin = ax1.twinx()
axtwin.set_ylabel('Pearson r')
for i in range(64):
    ptwin = axtwin.plot(pop_shifted_pearson[i],color='k',label='Pearson')
lns = p1+ptwin
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc='upper left')
# ax1.set_aspect('equal')
ax1.set_title('Horizontal translation for diff. field positions')

def diag_matrix(matrix):
    """Return 14 diagonals of a 2-D array as plain Python lists.

    The first seven are anti-diagonals (diagonals of the row-flipped matrix) at
    offsets -3 through 3. The last seven are ordinary diagonals at offsets 3 down
    to -3.
    """
    flipped = matrix[::-1, :]
    anti_diagonals = [flipped.diagonal(offset).tolist() for offset in range(-3, 4)]
    main_diagonals = [matrix.diagonal(offset).tolist() for offset in reversed(range(-3, 4))]
    return anti_diagonals + main_diagonals

ax2 = plt.subplot(1,5,2)

# Diagonal profiles: the 7 anti-diagonals and 7 diagonals nearest the centre lines.
diag_emd = diag_matrix(pop_shifted_emd)
for profile in diag_emd:
    p2 = ax2.plot(profile, color='r', label='EMD')
ax2.set_xlabel('Field centre bin')
ax2.set_ylabel('EMD (cm)')
axtwin2 = ax2.twinx()
axtwin2.set_ylabel('Pearson r')
diag_pearson = diag_matrix(pop_shifted_pearson)
for profile in diag_pearson:
    ptwin2 = axtwin2.plot(profile, color='k', label='Pearson')
lns = p2 + ptwin2
labs = [l.get_label() for l in lns]
ax2.legend(lns, labs, loc='upper left')
ax2.set_title('Diagonal translation for diff. field positions')

ax3 = plt.subplot(1,5,3)
# Rotate each score map (NaN-filled outside the rotated support) and plot every
# row of every rotated copy. Only a single 90-degree rotation is requested.
rotation_angles = np.arange(90, 91, 1)
rotated_emd = np.array([
    ndimage.rotate(pop_shifted_emd, angle, reshape=False, mode='constant', cval=np.nan)
    for angle in rotation_angles
])
rotated_pearson = np.array([
    ndimage.rotate(pop_shifted_pearson, angle, reshape=False, mode='constant', cval=np.nan)
    for angle in rotation_angles
])
for rotated_map in rotated_emd:
    for profile in rotated_map:
        p3 = ax3.plot(profile, color='r', label='EMD')
ax3.set_xlabel('Field centre bin')
ax3.set_ylabel('EMD (cm)')
axtwin3 = ax3.twinx()
axtwin3.set_ylabel('Pearson r')
for rotated_map in rotated_pearson:
    for profile in rotated_map:
        ptwin3 = axtwin3.plot(profile, color='k', label='Pearson')
lns = p3 + ptwin3
labs = [l.get_label() for l in lns]
ax3.legend(lns, labs, loc='upper left')
ax3.set_title('Angular translation for diff. field positions')

# Full score maps.
ax4 = plt.subplot(1,5,4)
im = ax4.imshow(pop_shifted_emd, cmap='jet', aspect='auto')
ax4.set_title('Normalized EMD score for diff. field positions')
ax4.set_xlabel('Field centre bin')
fig.colorbar(im, ax=ax4, fraction=0.046, pad=0.04)

ax5 = plt.subplot(1,5,5)
ax5.set_xlabel('Field centre bin')
im = ax5.imshow(pop_shifted_pearson, cmap='jet', aspect='auto')
ax5.set_title('Normalized pearson-r score for diff. field centre positions')
fig.colorbar(im, ax=ax5, fraction=0.046, pad=0.04)

fig.tight_layout()
plt.show()

In [None]:
# Grid-cell test maps: a 25x25 lattice of small bumps (rows every 16 bins, odd
# rows offset by 8 bins) on a 512x512 canvas, smoothed and then cropped to a
# 64x64 window well away from the canvas edges.
N = 3  # kernel size
# scipy.signal.gaussian was deprecated in SciPy 1.1 and removed in 1.13; the
# window function lives in scipy.signal.windows.
k1d = signal.windows.gaussian(N, std=3).reshape(N, 1)
kernel = np.outer(k1d, k1d)
blank_map = np.zeros((64*8, 64*8))

GRID_FIELDS = 25     # lattice rows and columns
GRID_SPACING = 16    # bins between neighbouring lattice points
GRID_ROW_OFFSET = 8  # extra column offset applied to odd lattice rows
GRID_SMOOTHING = 2   # sigma passed to ndimage.gaussian_filter


def _make_grid_map(col_shift=0):
    """Build one 64x64 grid-cell rate map with the lattice shifted right by ``col_shift`` bins.

    ``kernel`` is stamped at every lattice point of a copy of ``blank_map``. The
    full canvas is smoothed with ``ndimage.gaussian_filter`` and only then
    cropped, so the returned window is unaffected by the filter's edge handling.
    """
    half = N // 2
    canvas = np.copy(blank_map)
    for lattice_row in range(GRID_FIELDS):
        row = N + lattice_row * GRID_SPACING
        row_offset = GRID_ROW_OFFSET if lattice_row % 2 else 0
        for lattice_col in range(GRID_FIELDS):
            col = N + lattice_col * GRID_SPACING + row_offset + col_shift
            canvas[row - half:row + half + 1, col - half:col + half + 1] = kernel
    smoothed = ndimage.gaussian_filter(canvas, GRID_SMOOTHING)
    return smoothed[N + 64:N + 64 + 64, N + 64:N + 64 + 64]


# Unshifted reference map.
circle_left_ex1 = _make_grid_map()

# The same lattice shifted right by each offset in `sxs` (0 reproduces the reference).
sxs = [0, 1, 2, 3, 4, 5]
aggregate_circle_right_ex2 = [_make_grid_map(shift) for shift in sxs]


In [None]:
# for i in range(len(aggregate_circle_right_ex2)):
#     fig = plt.figure(figsize=(5,5))
#     ax = plt.subplot(1,1,1)
#     ax.imshow(aggregate_circle_right_ex2[i], cmap='jet')
#     rect = patches.Rectangle((64, 64), 64, 64, linewidth=1, edgecolor='r', facecolor='none')
#     ax.add_patch(rect)
#     ax.set_title('Grid Cell')
#     plt.show()

In [None]:
from _prototypes.cell_remapping.src.remapping import pot_sliced_wasserstein
from _prototypes.cell_remapping.src.wasserstein_distance import _get_ratemap_bucket_midpoints

import itertools

fig = plt.figure(figsize=(15,3))

# The reference map, its bucket grid and its weights are the same for every
# panel, so compute them once outside the loop.
y, x = circle_left_ex1.shape
height_bucket_midpoints, width_bucket_midpoints = _get_ratemap_bucket_midpoints(([1],[1]), y, x)
buckets = np.array(list(itertools.product(np.arange(0,y,1),np.arange(0,x,1))))
coord_buckets = np.array(list(itertools.product(height_bucket_midpoints,width_bucket_midpoints)))
source_weights = circle_left_ex1[buckets[:, 0], buckets[:, 1]]
source_weights = source_weights / np.sum(source_weights)

n_panels = len(aggregate_circle_right_ex2)
for panel_idx, circle_right_ex1 in enumerate(aggregate_circle_right_ex2):
    ax = fig.add_subplot(1, n_panels, panel_idx + 1)
    # Shifted map underneath, reference map overlaid semi-transparently.
    ax.imshow(circle_right_ex1, cmap='jet')
    ax.imshow(circle_left_ex1, cmap='gist_stern', alpha=0.3)

    target_weights = circle_right_ex1[buckets[:, 0], buckets[:, 1]]
    target_weights = target_weights / np.sum(target_weights)

    emd = pot_sliced_wasserstein(coord_buckets, coord_buckets, source_weights, target_weights)
    r, _p = pearsonr(source_weights.flatten(), target_weights.flatten())

    # Panel title format: "{EMD} vs {Pearson r}".
    ax.set_title(str(round(emd, 2)) + ' vs ' + str(round(r, 2)))

fig.suptitle('Grid translation, step-wise comparison with unshifted grid - {EMD} vs {Pearson}')
fig.tight_layout()
plt.show()

# ---------------------------------------------------------------------------
# EXAMPLE 2 - POPULATION FIGURES
# Score every (row, column) shift of the grid lattice against the unshifted
# reference map `circle_left_ex1`.
# ---------------------------------------------------------------------------

y, x = circle_left_ex1.shape
height_bucket_midpoints, width_bucket_midpoints = _get_ratemap_bucket_midpoints(([1], [1]), y, x)

# Every (row, col) bin index in row-major order, plus the matching midpoint coordinates.
buckets = np.array(list(itertools.product(np.arange(y), np.arange(x))))
coord_buckets = np.array(list(itertools.product(height_bucket_midpoints, width_bucket_midpoints)))

# Reference weights: map values in `buckets` order, normalised to sum to 1.
source_weights = circle_left_ex1[buckets[:, 0], buckets[:, 1]]
source_weights = source_weights / np.sum(source_weights)

# Placeholders; their length sizes the shift scan below, which then replaces them.
pop_shifted_emd = np.zeros(circle_left_ex1.shape)
pop_shifted_pearson = np.zeros(circle_left_ex1.shape)

def _sub1(a, b):
    """Score the grid lattice shifted by (a, b) bins against the reference map.

    The lattice is rebuilt on a copy of ``blank_map`` with every field moved ``a``
    rows down and ``b`` columns right, then stamped with ``kernel``. The full
    canvas is smoothed and only then cropped, in the same order used to build the
    reference grid map. The window is read out in ``buckets`` order, normalised
    to sum to 1, and compared with ``source_weights``.

    Returns:
        tuple: (emd, r) -- the ``pot_sliced_wasserstein`` distance and the Pearson r
        between the reference and target weight vectors.
    """
    half = N // 2
    canvas = np.copy(blank_map)
    for lattice_row in range(25):
        row = N + lattice_row * 16 + a
        row_offset = 8 if lattice_row % 2 else 0
        for lattice_col in range(25):
            col = N + lattice_col * 16 + row_offset + b
            canvas[row - half:row + half + 1, col - half:col + half + 1] = kernel
    # Smooth BEFORE cropping. If the 64x64 window is cropped first, the filter's
    # boundary handling distorts its edges, and even _sub1(0, 0) no longer
    # reproduces the reference map.
    canvas = ndimage.gaussian_filter(canvas, 2)
    window = canvas[N + 64:N + 64 + 64, N + 64:N + 64 + 64]

    target_weights = np.array([window[r, c] for r, c in buckets])
    target_weights = target_weights / np.sum(target_weights)
    emd = pot_sliced_wasserstein(coord_buckets, coord_buckets, source_weights, target_weights)
    r, _p = pearsonr(source_weights.flatten(), target_weights.flatten())
    return emd, r

def _sub2(i):
    """Evaluate ``_sub1(i, j)`` for every column shift ``j`` at row shift ``i``.

    Returns:
        tuple: (emd_row, r_row) -- 1-D arrays indexed by ``j``.
    """
    n_cols = pop_shifted_emd.shape[1]
    emd_and_r = np.array([_sub1(i, j) for j in range(n_cols)])
    return emd_and_r[:, 0], emd_and_r[:, 1]

# `combined` has shape (n_rows, 2, n_cols): axis 0 is the row shift, axis 1 is
# (EMD, Pearson r), and axis 2 is the column shift. Select axis 1 directly.
# A `.T` here would reverse all three axes and transpose both score maps.
# This way pop_shifted_emd[i, j] is the score for shift (i, j).
combined = np.array([_sub2(i) for i in range(pop_shifted_emd.shape[0])])
pop_shifted_emd, pop_shifted_pearson = combined[:, 0, :], combined[:, 1, :]



In [None]:
# Line-profile figure (three panels). The axes array gets a real name instead of
# `_`, which IPython uses for the last cell output.
fig, profile_axes = plt.subplots(1, 3, figsize=(20, 8))
ax1 = profile_axes[0]

# Row profiles of the score maps: one line per row, EMD on the left axis and
# Pearson r on a twin right axis.
for emd_row in pop_shifted_emd:
    p1 = ax1.plot(emd_row, color='r', label='EMD')
ax1.set_xlabel('Field centre bin')
ax1.set_ylabel('EMD (cm)')
axtwin = ax1.twinx()
axtwin.set_ylabel('Pearson r')
for pearson_row in pop_shifted_pearson:
    ptwin = axtwin.plot(pearson_row, color='k', label='Pearson')
# One legend entry per metric (the last line drawn for each).
lns = p1 + ptwin
labs = [l.get_label() for l in lns]
ax1.legend(lns, labs, loc='upper left')
ax1.set_title('Horizontal translation for diff. field positions')

def diag_matrix(matrix):
    """Collect the 7 anti-diagonals and 7 diagonals closest to the centre lines of a 2-D array.

    Returns a list of 14 plain Python lists. The first seven are diagonals of
    the up-down flipped matrix at offsets -3..3; the rest are ordinary diagonals
    at offsets 3..-3.
    """
    offsets = range(-3, 4)
    upside_down = matrix[::-1, :]
    collected = []
    for k in offsets:
        collected.append(upside_down.diagonal(k).tolist())
    for k in offsets[::-1]:
        collected.append(matrix.diagonal(k).tolist())
    return collected

ax2 = plt.subplot(1,3,2)

# Diagonal profiles: the 7 anti-diagonals and 7 diagonals nearest the centre lines.
diag_emd = diag_matrix(pop_shifted_emd)
for profile in diag_emd:
    p2 = ax2.plot(profile, color='r', label='EMD')
ax2.set_xlabel('Field centre bin')
ax2.set_ylabel('EMD (cm)')
axtwin2 = ax2.twinx()
axtwin2.set_ylabel('Pearson r')
diag_pearson = diag_matrix(pop_shifted_pearson)
for profile in diag_pearson:
    ptwin2 = axtwin2.plot(profile, color='k', label='Pearson')
lns = p2 + ptwin2
labs = [l.get_label() for l in lns]
ax2.legend(lns, labs, loc='upper left')
ax2.set_title('Diagonal translation for diff. field positions')

ax3 = plt.subplot(1,3,3)
# Rotate each score map (NaN-filled outside the rotated support) and plot every
# row of every rotated copy. Only a single 90-degree rotation is requested.
rotation_angles = np.arange(90, 91, 1)
rotated_emd = np.array([
    ndimage.rotate(pop_shifted_emd, angle, reshape=False, mode='constant', cval=np.nan)
    for angle in rotation_angles
])
rotated_pearson = np.array([
    ndimage.rotate(pop_shifted_pearson, angle, reshape=False, mode='constant', cval=np.nan)
    for angle in rotation_angles
])
for rotated_map in rotated_emd:
    for profile in rotated_map:
        p3 = ax3.plot(profile, color='r', label='EMD')
ax3.set_xlabel('Field centre bin')
ax3.set_ylabel('EMD (cm)')
axtwin3 = ax3.twinx()
axtwin3.set_ylabel('Pearson r')
for rotated_map in rotated_pearson:
    for profile in rotated_map:
        ptwin3 = axtwin3.plot(profile, color='k', label='Pearson')
lns = p3 + ptwin3
labs = [l.get_label() for l in lns]
ax3.legend(lns, labs, loc='upper left')
ax3.set_title('Angular translation for diff. field positions')
# Lay out the line-profile figure before `fig` is rebound to the next figure;
# otherwise only the second figure gets tight_layout.
fig.tight_layout()

# Second figure: the full score maps.
fig, (ax4, ax5) = plt.subplots(1, 2, figsize=(20, 8))

im = ax4.imshow(pop_shifted_emd, cmap='jet', aspect='auto')
ax4.set_title('Normalized EMD score for diff. field positions')
ax4.set_xlabel('Field centre bin')
fig.colorbar(im, ax=ax4, fraction=0.046, pad=0.04)

ax5.set_xlabel('Field centre bin')
im = ax5.imshow(pop_shifted_pearson, cmap='jet', aspect='auto')
ax5.set_title('Normalized pearson-r score for diff. field centre positions')
fig.colorbar(im, ax=ax5, fraction=0.046, pad=0.04)

fig.tight_layout()
plt.show()