In [None]:
import numpy as np
import math
import matplotlib.pyplot as plt
from scipy.optimize import curve_fit
import h5py
import os
from tqdm import trange
import scipy.ndimage as ndimage

In [None]:
from numba import njit


@njit
def normpdf(x, mean, sd):
    """Evaluate a Gaussian probability density at a single point.

    Parameters
    ----------
    x : scalar
        Point at which the density is evaluated.
    mean : scalar
        Centre of the Gaussian.
    sd : scalar
        Standard deviation. A value of zero makes the variance zero and
        therefore leads to a division by zero.

    Returns
    -------
    float
        ``exp(-(x - mean)**2 / (2*sd**2)) / sqrt(2*pi*sd**2)``.
    """
    variance = float(sd)**2
    offset = float(x) - float(mean)
    gaussian = math.exp(-offset**2/(2*variance))
    normaliser = (2*math.pi*variance)**.5
    return gaussian/normaliser


@njit
def normpdf_mean0(x, sd):
    """Evaluate a zero-mean Gaussian probability density at a single point.

    Parameters
    ----------
    x : scalar
        Point at which the density is evaluated.
    sd : scalar
        Standard deviation. A value of zero makes the variance zero and
        therefore leads to a division by zero.

    Returns
    -------
    float
        ``exp(-x**2 / (2*sd**2)) / sqrt(2*pi*sd**2)``.
    """
    variance = float(sd)**2
    value = float(x)
    gaussian = math.exp(-value**2/(2*variance))
    normaliser = (2*math.pi*variance)**.5
    return gaussian/normaliser


@njit
def distance_from(points, target):
    """Euclidean distance from ``target`` to every row of ``points``.

    Only components 0 and 1 of each point and of ``target`` are read.

    Parameters
    ----------
    points : 2-D array
        One point per row, accessed as ``points[k, 0]`` and ``points[k, 1]``.
    target : indexable
        Reference position, accessed as ``target[0]`` and ``target[1]``.

    Returns
    -------
    ndarray
        1-D array with ``points.shape[0]`` distances.
    """
    n_rows = points.shape[0]
    out = np.zeros(n_rows)

    for k in range(n_rows):
        dx = points[k, 0] - target[0]
        dy = points[k, 1] - target[1]
        out[k] = norm_2(dx, dy)

    return out


@njit
def angle_from(points, target):
    """Direction of every row of ``points`` as seen from ``target``.

    For each point the offset ``point - target`` is turned into the complex
    number ``offset[0] + 1j*offset[1]`` and its ``np.angle`` is stored.

    Parameters
    ----------
    points : 2-D array
        One point per row.
    target : array
        Reference position subtracted from every row.

    Returns
    -------
    ndarray
        1-D array with ``points.shape[0]`` angles.
    """
    n_rows = points.shape[0]
    bearings = np.zeros(n_rows)

    for k in range(n_rows):
        offset = points[k, :] - target
        # Encode the 2-D offset as x + i*y so np.angle yields its argument.
        as_complex = offset[0] + offset[1]*1j
        bearings[k] = np.angle(as_complex)

    return bearings


@njit
def BVC_single(r, θ, d, Φ, σ_rad, σ_ang):
    """Response contribution of one boundary point at distance r, direction θ.

    The value is the product of two Gaussians: one in distance, centred on
    ``d`` with width ``σ_rad``, and a zero-mean one in the angular difference
    ``θ - Φ`` with width ``σ_ang``.

    Parameters
    ----------
    r : scalar
        Distance to the boundary point.
    θ : scalar
        Direction of the boundary point.
    d : scalar
        Distance at which the radial Gaussian peaks.
    Φ : scalar
        Direction at which the angular Gaussian peaks.
    σ_rad, σ_ang : scalar
        Standard deviations of the radial and angular Gaussians.

    Returns
    -------
    float
        ``normpdf(r, d, σ_rad) * normpdf_mean0(wrapped(θ - Φ), σ_ang)``.
    """
    radial_term = normpdf(r, d, σ_rad)

    # Shift the angular difference by one full turn if it lies outside
    # [-π, π]; at most a single 2π correction is applied in each direction.
    delta = θ - Φ
    if delta > np.pi:
        delta = delta - 2*np.pi
    if delta < -np.pi:
        delta = delta + 2*np.pi

    angular_term = normpdf_mean0(delta, σ_ang)
    return radial_term*angular_term

@njit
def norm_2(x, y):
    """Return the Euclidean length ``sqrt(x**2 + y**2)`` of the vector (x, y)."""
    squared_length = x**2 + y**2
    return np.sqrt(squared_length)



@njit
def BVC(which_pixels, d, Φ, σ_rad, σ_ang, A, c, boundary_array):
    """Summed boundary response at each requested pixel.

    For every row of ``which_pixels`` the distances and directions to all
    rows of ``boundary_array`` are computed, ``BVC_single`` is evaluated for
    each boundary point, and the contributions are summed and scaled by ``A``.
    The constant ``c`` is added to every pixel at the end.

    Parameters
    ----------
    which_pixels : 2-D array
        One pixel position per row.
    d, Φ, σ_rad, σ_ang : scalar
        Tuning parameters forwarded unchanged to ``BVC_single``.
    A : scalar
        Multiplicative gain applied to the summed contributions.
    c : scalar
        Additive offset applied to the result.
    boundary_array : 2-D array
        One boundary point per row.

    Returns
    -------
    ndarray
        1-D array with ``which_pixels.shape[0]`` values.
    """
    n_pixels = which_pixels.shape[0]
    rates = np.zeros(n_pixels)

    for p in range(n_pixels):
        position = which_pixels[p, :]
        dists = distance_from(boundary_array, position)
        bearings = angle_from(boundary_array, position)
        contributions = [
            BVC_single(dists[k], bearings[k], d, Φ, σ_rad, σ_ang)
            for k in range(len(dists))
        ]
        rates[p] = A*sum(contributions)

    return rates + c

In [None]:
# Sample axes: 1000 radii in [0, 5] and 1000 angles in [-π/2, π/2].
r_all = np.linspace(0, 5, 1000)
θ_all = np.linspace(-np.pi/2, np.pi/2, 1000)

# Matching 2-D coordinate grids, converted to single precision.
rv, θv = (np.float32(grid) for grid in np.meshgrid(r_all, θ_all))

In [None]:
# Zero-filled response map: one row per entry of r_all, one column per entry of θ_all.
response = np.zeros(shape=(len(r_all), len(θ_all)))

In [None]:
# Tuning parameters of the example response evaluated below.
d, Φ = 3, 0
σ_rad, σ_ang = 0.5, 0.1

# Evaluate the single-point response on the full grid:
# rows of `response` follow r_all, columns follow θ_all.
for i, r in enumerate(r_all):
    for j, θ in enumerate(θ_all):
        response[i, j] = BVC_single(r, θ, d, Φ, σ_rad, σ_ang)

In [None]:
# Replace every response value below 0.5 with NaN (in place).
# `np.nan` is used because the `np.NaN` alias was removed in NumPy 2.0.
response[response < 0.5] = np.nan

In [None]:
# Quick look at the response array as an image (rows follow r_all, columns follow θ_all).
plt.imshow(response, cmap="jet")

In [None]:
# Draw the response on polar axes: θ_all along the angular axis, r_all along
# the radial axis, with the axes decoration switched off.
fig = plt.figure(figsize=(3, 3), dpi=250)
ax = fig.add_axes([0.1, 0.1, 0.8, 0.8], projection="polar")

ax.pcolormesh(θ_all, r_all, response, cmap="jet")
ax.axis("off")

# Limit the view to angles in [-π/4, π/4] and radii in [0, 5].
ax.set_xlim(-np.pi/4, np.pi/4)
ax.set_ylim(0, 5)


In [None]:
def arrowed_spines(fig, ax):
    """Replace the frame of ``ax`` with two arrows along the x and y directions.

    All four spines are hidden, ticks are removed, and two arrows are drawn
    in data coordinates: one from ``(xmin, 0)`` to ``(xmax, 0)`` and one from
    ``(0, ymin)`` to ``(0, ymax)``, where the limits are the current axis
    limits of ``ax``. Set the limits before calling this function.

    Parameters
    ----------
    fig : figure
        Figure that owns ``ax``; its ``dpi_scale_trans`` is used to measure
        the axes in inches.
    ax : axes
        Axes to restyle. All changes are applied to this object, not to
        pyplot's current axes.
    """
    xmin, xmax = ax.get_xlim()
    ymin, ymax = ax.get_ylim()

    # Hide the default frame on all four sides.
    for side in ('bottom', 'right', 'top', 'left'):
        ax.spines[side].set_visible(False)

    # Remove ticks and labels on the axes that was passed in. The pyplot
    # helpers would act on whichever axes happens to be current instead.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.xaxis.set_ticks_position('none')  # tick markers
    ax.yaxis.set_ticks_position('none')

    # Width and height of the axes in inches, used to give both arrowheads
    # matching proportions despite different data ranges and aspect ratio.
    dps = fig.dpi_scale_trans.inverted()
    bbox = ax.get_window_extent().transformed(dps)
    width, height = bbox.width, bbox.height

    # Arrowhead size for the x arrow: 1/20 of the respective data range.
    hw = 1./20.*(ymax-ymin)
    hl = 1./20.*(xmax-xmin)
    lw = 1.    # axis line width
    ohg = 0.3  # arrow overhang

    # Equivalent arrowhead size for the y arrow.
    yhw = hw/(ymax-ymin)*(xmax-xmin)*height/width
    yhl = hl/(xmax-xmin)*(ymax-ymin)*width/height

    # x arrow along y = 0.
    ax.arrow(xmin, 0, xmax-xmin, 0., fc='k', ec='k', lw=lw,
             head_width=hw, head_length=hl, overhang=ohg,
             length_includes_head=True, clip_on=False)

    # y arrow along x = 0.
    ax.arrow(0, ymin, 0., ymax-ymin, fc='k', ec='k', lw=lw,
             head_width=yhw, head_length=yhl, overhang=ohg,
             length_includes_head=True, clip_on=False)





In [None]:
# Directory that receives the exported figure PDFs. The original absolute
# location remains the default; set the BVC_FIG_PATH environment variable to
# redirect the output on another machine without editing the notebook.
fig_path = os.environ.get(
    "BVC_FIG_PATH",
    "/home/chuyu/Notebooks/project_place_cell/figures/output/sfigure2/bvc",
)

In [None]:

# Single red bar of height 0.5 on arrow-style axes, saved as bvc_response_1.pdf.
fig = plt.figure(figsize=(1, 1))
fig.set_facecolor('white')
ax = fig.gca()

ax.bar(1, 0.5, width=1, color="r")
ax.set_xlim(0, 2)
ax.set_ylim(0, 3)
ax.set_xticks([])
ax.set_yticks([])

# Limits are fixed above because arrowed_spines reads them to place the arrows.
arrowed_spines(fig, ax)
fig.tight_layout()
fig.savefig(os.path.join(fig_path, "bvc_response_1.pdf"), transparent=True)

In [None]:

# Single red bar of height 1.5 on arrow-style axes, saved as bvc_response_2.pdf.
fig = plt.figure(figsize=(1, 1))
fig.set_facecolor('white')
ax = fig.gca()

ax.bar(1, 1.5, width=1, color="r")
ax.set_xlim(0, 2)
ax.set_ylim(0, 3)
ax.set_xticks([])
ax.set_yticks([])

# Limits are fixed above because arrowed_spines reads them to place the arrows.
arrowed_spines(fig, ax)
fig.tight_layout()
fig.savefig(os.path.join(fig_path, "bvc_response_2.pdf"), transparent=True)

In [None]:

# Single red bar of height 3 on arrow-style axes, saved as bvc_response_3.pdf.
fig = plt.figure(figsize=(1, 1))
fig.set_facecolor('white')
ax = fig.gca()

ax.bar(1, 3, width=1, color="r")
ax.set_xlim(0, 2)
ax.set_ylim(0, 3)
ax.set_xticks([])
ax.set_yticks([])

# Limits are fixed above because arrowed_spines reads them to place the arrows.
arrowed_spines(fig, ax)
fig.tight_layout()
fig.savefig(os.path.join(fig_path, "bvc_response_3.pdf"), transparent=True)