In [1]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
from ipywidgets import interactive
import ipywidgets as widgets
import plotly as ply

In [2]:
class SinusoidTrace(object):
        
    def __init__(self,theta_0,a,omega,n_points=9000,t_start=0,t_stop=3):
        self.t = np.linspace(t_start,t_stop,n_points)
        self.dt = self.t[1]-self.t[0]
        envelope = a*np.exp(self.t) 
        self.theta_0 = theta_0
        #Want this to still work when there are multiple rows (trials) of traces    
        self.path = (np.pi/16)*np.sin((2*np.pi*omega)*self.t)*envelope+self.theta_0
        #Then store a version of this array where the bottom is 0 and angles are degrees
        self.path_origin_shifted_degrees = (self.path - theta_0)*(180/np.pi)
        self.angle_bin_size = np.pi/18 #10 degrees in radians

    def find_reversal_inds(self):
        reversal_bool = (np.abs(np.diff(self.path))<1e-3)
        reversal_inds = np.where(reversal_bool)[0]
        non_duplicate_inds = np.diff(np.hstack((np.array(0),reversal_inds)))>1
        reversal_inds = reversal_inds[non_duplicate_inds]
        #Want this to still work when there are multiple rows (trials) of traces    
        return reversal_inds
    
    def reversal_loc_hist(self,shifted=False):
        
        reversal_inds = self.find_reversal_inds()
        if shifted:
            bins = np.arange(-180,180,np.degrees(self.angle_bin_size))
            n,bins = np.histogram(self.path_origin_shifted_degrees[reversal_inds],bins=bins)
        else:
            bins = np.arange(0,2*np.pi,self.angle_bin_size)
            n,bins = np.histogram(self.path[reversal_inds],bins=bins)
        return n,bins
    
    def count_double_crosses(self,lower_bound,upper_bound):
        
        #Returns a count of the total number of times the (1D) path 
        #either {passes through upper_bound decreasing  and
        #immediately passes through lower_bound decreasing}
        #or {passes through lower_bound increasing  and
        #immediately passes through upper_bound increasing}

        epsilon = np.abs(np.diff(np.hstack((self.path[0],self.path))))/2
#         print(epsilon)
        upper_bound_crosses = np.abs(self.path-upper_bound)<epsilon
        lower_bound_crosses = np.abs(self.path-lower_bound)<epsilon


        return (upper_bound_crosses|lower_bound_crosses)

        #return double_cross_count        
    
    
    def compute_transit_counts(self):
        #input shape: (trials x timestamps) 
        #Return the binned transit counts
        #output shape: (trials x angle bins)
        return n
    
    def draw_trials(num_trials):
        #For the already inputted trace parameter values, draw a trace
        #num_trials times (drawing noise anew each trial),
        #and compute the transit vector for each trial
        #to return a 2d array of draws x theta (used for heatmap)

        return transits 


In [5]:
def slider(start,stop,step,init):
    """Return an ipywidgets FloatSlider over [start, stop] with the given step and initial value.

    The slider is enabled, horizontal, shows its value with two decimals, and
    is created with continuous_update=False (so the bound callback is not
    re-run for every intermediate position while dragging).
    """
    slider_options = {
        'value': init,
        'min': start,
        'max': stop,
        'step': step,
        'disabled': False,
        'continuous_update': False,
        'orientation': 'horizontal',
        'readout': True,
        'readout_format': '.2f',
    }
    return widgets.FloatSlider(**slider_options)



In [4]:
def _plot_trace_panel(ax,sinusoidTrace,reversal_inds,theta_0,phi_rad):
    """Draw the wrapped angle vs time with reversals, centre/food lines and food-angle crossings."""
    two_pi = 2*np.pi
    x = sinusoidTrace.t
    y_wrapped = sinusoidTrace.path%two_pi
    t_span = [np.min(x),np.max(x)]
    ax.plot(x,y_wrapped,'o',markersize=1)
    # Cyan dots: reversal points.
    ax.scatter(x[reversal_inds],y_wrapped[reversal_inds],color='cyan')
    # Green: centre (origin) angle; red: food angles at theta_0 +/- phi.
    # Wrapped like the trace so they stay inside the [0, 2*pi] axis.
    ax.plot(t_span,[theta_0%two_pi]*2,'g')
    for food_angle in (theta_0+phi_rad,theta_0-phi_rad):
        ax.plot(t_span,[food_angle%two_pi]*2,'r')
    # Orange: samples where the (unwrapped) path passes through a food angle.
    # Plotted at the wrapped y so they land on the curve drawn above.
    transit_mask = sinusoidTrace.count_double_crosses(theta_0-phi_rad,theta_0+phi_rad)
    ax.scatter(x[transit_mask],y_wrapped[transit_mask],color='orange')
    ax.set_ylim([0,two_pi])
    ax.set_aspect(1./3)
    ax.spines['right'].set_visible(False)
    ax.spines['top'].set_visible(False)
    ax.set_xlabel('Time post activation period (min)')
    ax.set_ylabel('Angular position (radians)')


def _plot_polar_panel(ax,sinusoidTrace,reversal_inds,theta_0,phi_rad,t_stop):
    """Draw the trace on polar axes: angle = path, radius = time."""
    x = sinusoidTrace.t
    y = sinusoidTrace.path
    ax.plot(y,x)
    ax.plot([theta_0,theta_0],[np.min(x),np.max(x)],'g')
    ax.scatter(y[reversal_inds],x[reversal_inds],color='cyan')
    for food_angle in (theta_0+phi_rad,theta_0-phi_rad):
        ax.plot([food_angle,food_angle],[np.min(x),t_stop],'r')
    ax.set_ylim([0,t_stop*1.1])
    ax.grid(False)
    ax.set_yticks([])


def _plot_reversal_hist_panel(ax,sinusoidTrace,phi_rad):
    """Plot the proportion of reversals per angle bin, relative to theta_0 (degrees)."""
    n,bins = sinusoidTrace.reversal_loc_hist(shifted=True)
    total = np.sum(n)
    # Avoid 0/0 (NaN + RuntimeWarning) when the trace has no reversals.
    proportions = n/total if total else np.zeros(len(n))
    ax.plot(bins[:-1],proportions,'-o')
    top = np.max(proportions) if len(proportions) else 0
    phi_deg = np.degrees(phi_rad)
    # Green: centre (origin) angle; red: food angles at +/- phi.
    ax.plot([0,0],[0,top],'g')
    ax.plot([phi_deg,phi_deg],[0,top],'r')
    ax.plot([-phi_deg,-phi_deg],[0,top],'r')
    ax.set_xlabel('Angle (deg)')
    ax.set_ylabel('Proportion of Reversals')


def f(theta_0,a,omega,t_stop,phi):
    """Render the three-panel figure for one set of slider values.

    Panels: (1) angle (wrapped to [0, 2*pi)) vs time with reversals (cyan),
    centre angle (green), food angles theta_0 +/- phi (red) and samples
    where the path passes a food angle (orange); (2) the same trace on polar
    axes (angle = path, radius = time); (3) proportion of reversals per
    angle bin relative to theta_0.

    Parameters
    ----------
    theta_0 : float
        Centre of oscillation (radians).
    a : float
        Amplitude scale factor of the trace.
    omega : float
        Oscillation frequency (cycles per unit time).
    t_stop : float
        End of the simulated time window.
    phi : float
        Offset of the food angles from theta_0, in degrees.
    """
    phi_rad = np.radians(phi)
    sinusoidTrace = SinusoidTrace(theta_0,a,omega,t_stop=t_stop)
    reversal_inds = sinusoidTrace.find_reversal_inds()

    fig = plt.figure(1,figsize=(15,15),dpi=200)
    # Figure 1 is reused across slider updates; clear it so panels don't pile up.
    fig.clf()
    gs = GridSpec(7,7)
    _plot_trace_panel(fig.add_subplot(gs[0:3,0:3]),sinusoidTrace,reversal_inds,theta_0,phi_rad)
    _plot_polar_panel(fig.add_subplot(gs[0:2,4:6],polar=True),sinusoidTrace,reversal_inds,theta_0,phi_rad,t_stop)
    _plot_reversal_hist_panel(fig.add_subplot(gs[3:7,1:5]),sinusoidTrace,phi_rad)
    fig.subplots_adjust(left=0.25, bottom=0.25)

# One slider per parameter of f:
#   theta_0 : centre of oscillation (radians)
#   a       : amplitude of oscillation (scaling factor on the exp(t) envelope)
#   omega   : frequency of oscillation (cycles per unit time)
#   t_stop  : end of the simulated time window
#   phi     : offset of the food angles from theta_0 (degrees)
slider_controls = {
    'theta_0': slider(0,2*np.pi,0.1,3*np.pi/2),
    'a': slider(0,1,0.05,0.5),
    'omega': slider(0,8,0.05,3),
    't_stop': slider(1,10,0.5,3.),
    'phi': slider(0,90,5,10),
}
interactive(f, **slider_controls)

interactive(children=(FloatSlider(value=4.71238898038469, continuous_update=False, description='theta_0', max=…