# In [2]:
import time
import scipy
import matplotlib.pyplot as plt
import matplotlib
import sys
import itertools
import json
import numpy as np
from odor_tracking_sim import utility
from pompy import models
# from matplotlib.widgets import Slider,Button
# from matplotlib.transforms import Bbox
from extras import UpdatingVPatch,plot_wedges
from core_functions import f0,f1,f1_wedge,f2,f3,f4,f5
from ipywidgets import interactive
import ipywidgets as widgets
from matplotlib.gridspec import GridSpec
import warnings
warnings.filterwarnings('ignore')

# Constants that don't change with drag bars
num_flies = 20000
fly_speed = 1.6

number_sources = 8
radius_sources = 1000.0
source_locations, _ = utility.create_circle_of_sources(number_sources,
                radius_sources, None)
# One row per source location; each location unpacks as an (x, y) pair.
# np.array replaces the deprecated scipy.array alias (removed in recent SciPy).
source_pos = np.array([np.array(tup) for tup in source_locations])
release_location = np.zeros(2)

# Evenly spaced intended headings over the full circle.  endpoint=False keeps
# 0 and 2*pi (the same direction) from both being included.
# For random headings use np.random.uniform(0, 2*np.pi, num_flies) instead.
intended_heading_angles = np.linspace(0, 2*np.pi, num_flies, endpoint=False)

initial_cone_angle = np.radians(10.)


def _make_float_slider(value, lo, hi, step):
    """Return a horizontal FloatSlider with a one-decimal readout.

    continuous_update=False means the interactive callback only receives a
    new value when the user releases the slider, not on every drag step.
    """
    return widgets.FloatSlider(
        value=value,
        min=lo,
        max=hi,
        step=step,
        disabled=False,
        continuous_update=False,
        orientation='horizontal',
        readout=True,
        readout_format='.1f',
    )


windmag_slider = _make_float_slider(1., 0, 4.0, 0.1)
cone_angle_slider = _make_float_slider(np.degrees(initial_cone_angle), 0., 40.0, 1.)
K_slider = _make_float_slider(0.4, 0., 1.0, 0.1)
x_0_slider = _make_float_slider(300., 0., 1000.0, 10.)

toggle = widgets.ToggleButton(
    value=False,
    description='Headings | Intersections',
    disabled=False,
    button_style='',  # 'success', 'info', 'warning', 'danger' or ''
    tooltip='Description')


def _plot_dispersal_panel(fig, gs, toggle, track_heading_angles,
                          dispersing_speeds, release_to_chosen_plume_distances,
                          wind_angle, cone_angle):
    """Draw panel 1: per-fly points along each track heading, plume wedges and sources.

    toggle False -> radius is 5*60 time units of dispersal at dispersing_speeds.
    toggle True  -> radius is release_to_chosen_plume_distances (from f4).
    """
    ax = fig.add_subplot(gs[0:4, 0:4])
    ax.set_xlim(-1500., 1500.)
    ax.set_ylim(-1500., 1500.)

    if toggle:
        radii = release_to_chosen_plume_distances
    else:
        # 5 minutes, assuming speeds are per second (arrival times elsewhere
        # in this file are converted from seconds to minutes).
        elapsed = 5 * 60.
        radii = elapsed * dispersing_speeds
    ax.scatter(radii * np.cos(track_heading_angles),
               radii * np.sin(track_heading_angles),
               alpha=0.02, color='r')

    # NOTE(review): wedge_points is indexed as [:, source, :]; the helper's
    # exact return shape is defined elsewhere -- confirm.
    wedge_points = plot_wedges(source_pos, wind_angle, cone_angle)
    for i in range(number_sources):
        ax.add_patch(matplotlib.patches.Polygon(
            wedge_points[:, i, :], color='black', alpha=0.2))

    ax.scatter(source_pos[:, 0], source_pos[:, 1], marker='x', s=50, c='k')

    ax.set_aspect('equal')
    ax.set_xticks([])
    ax.set_yticks([])


def _plot_arrival_cdfs(fig, gs, arrival_times, which_traps, num_bins=50):
    """Draw panel 2: one cumulative-arrival step plot per trap (right-hand column).

    Returns (trap_counts, cdf_steepnesses, first_hit_times, y_top), where
    y_top is the shared y-limit used by every CDF axis.  Traps with no
    arrivals keep count 0, steepness 0 and first-hit time NaN.
    """
    labels = ['N', 'NE', 'E', 'SE', 'S', 'SW', 'W', 'NW']
    # 1-based grid row for simulation trap i; the row also indexes `labels`.
    sim_reorder = np.array([3, 2, 1, 8, 7, 6, 5, 4])
    n_traps = len(labels)
    bottom_row = n_traps - 1

    trap_counts = np.zeros(n_traps)
    cdf_steepnesses = np.zeros(n_traps)
    first_hit_times = np.full(n_traps, np.nan)
    new_maxes = 400 * np.ones(n_traps)  # y-limit is at least 400
    axes = []
    cdf_patches = []  # kept aligned with `axes`; None for empty traps

    for i in range(n_traps):
        row = sim_reorder[i] - 1
        ax = fig.add_subplot(gs[row, 5:])
        t_sim = arrival_times[which_traps == i]
        patch_object = None

        if len(t_sim) == 0:
            ax.set_xticks([0, 10, 20, 30, 40, 50])
        else:
            t_sim = t_sim / 60.  # seconds -> minutes (axis label is minutes)
            t_first, t_last = np.min(t_sim), np.max(t_sim)
            n, bins = np.histogram(t_sim, bins=num_bins, range=(0, t_last))
            cum_n = np.cumsum(n)
            ax.step(bins, np.hstack((np.array([0]), cum_n)))

            # Shaded span from first to last arrival.
            patch_object = UpdatingVPatch(t_first, t_last - t_first)
            ax.add_patch(patch_object.rectangle)

            trap_counts[i] = cum_n[-1]  # cumsum of counts: last == max
            spread = t_last - t_first
            # All arrivals at one instant -> infinitely steep CDF.
            cdf_steepnesses[i] = trap_counts[i] / spread if spread > 0 else np.inf
            first_hit_times[i] = t_first
            new_maxes[i] = max(400., 50 * np.ceil(cum_n[-1] / 50.))

        if row == 0:
            ax.set_title('Cumulative Trap Arrivals')
        ax.set_xlim([0, 50])
        # Only the bottom axis shows x tick labels.
        ax.tick_params(axis='x', which='both', bottom=True, top=False,
                       labelbottom=bool(row == bottom_row))
        ax.text(1.1, 0.5, labels[row], transform=ax.transAxes, fontsize=20,
                horizontalalignment='center', verticalalignment='center')
        if row == bottom_row:
            ax.set_xlabel('Time (min)', x=0.5, horizontalalignment='center', fontsize=20)
            ax.tick_params(axis='x', which='major', labelsize=15)
        axes.append(ax)
        cdf_patches.append(patch_object)

    # Share one y-scale across all traps; ticks every 200 up to the limit.
    y_top = np.max(new_maxes)
    for ax, patch_object in zip(axes, cdf_patches):
        ax.set_ylim([0, y_top])
        ax.set_yticks(np.arange(0, y_top + 1, 200))
        if patch_object is not None:
            patch_object.rectangle.set_height(ax.get_ylim()[1])

    return trap_counts, cdf_steepnesses, first_hit_times, y_top


def _plot_trap_summary(fig, gs, trap_counts, cdf_steepnesses, first_hit_times,
                       y_top, wind_angle):
    """Draw panel 3: one circle per trap around the unit circle.

    Radius is proportional to the trap count.  Opacity is proportional to CDF
    steepness (capped at 1).  Colour encodes first arrival time on a colorbar.
    """
    ax = fig.add_subplot(gs[4:8, 0:4])
    ax.set_aspect('equal')
    steepness_max = 300.
    radius_scale = 0.3
    plot_size = 1.5
    vmin, vmax = 5., 20.

    num_traps = len(source_pos)
    trap_angles = (2 * np.pi / num_traps) * np.arange(num_traps)
    # Give traps with zero arrivals a count of 0.5 so the circle has a non-zero radius.
    display_counts = np.where(trap_counts == 0, .5, trap_counts)
    radii = radius_scale * display_counts / y_top

    # Use the same normalisation for the patches and the colorbar so their
    # colours agree.  (The old (t - vmin) / vmax did not match.)
    norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax)
    trap_cmap = plt.get_cmap('plasma_r')  # matplotlib.cm.get_cmap removed in 3.9
    trap_rgba = trap_cmap(norm(first_hit_times))

    trap_patches = []
    for angle, radius, steepness, rgba in zip(trap_angles, radii,
                                              cdf_steepnesses, trap_rgba):
        trap_patch = plt.Circle((np.cos(angle), np.sin(angle)), radius,
                                alpha=min(steepness / steepness_max, 1.))
        # Pass RGB only, so the steepness-based patch alpha above is kept.
        trap_patch.set_color(tuple(float(c) for c in rgba[:3]))
        ax.add_patch(trap_patch)
        trap_patches.append(trap_patch)

    ax.set_ylim([-plot_size, plot_size])
    ax.set_xlim([-plot_size, plot_size])

    # Colorbar driven by a data-less collection with the same cmap and limits.
    coll = matplotlib.collections.PatchCollection(trap_patches)
    coll.set(cmap=trap_cmap, array=[])
    coll.set_clim(vmin=vmin, vmax=vmax)
    fig.colorbar(coll, ax=ax, pad=0.2)
    ax.text(2.1, .1, 'First Arrival Time (min)', horizontalalignment='center',
            rotation=-90, verticalalignment='center', fontsize=15)

    # Wind arrow, in axes coordinates from the panel centre.
    ax.arrow(0.5, 0.5, 0.1 * np.cos(wind_angle), 0.1 * np.sin(wind_angle),
             transform=ax.transAxes, color='b', width=0.01, head_width=0.05)
    fontsize = 15
    for label, (tx, ty) in (('N', (0, 1.5)), ('S', (0, -1.5)),
                            ('E', (1.5, 0)), ('W', (-1.5, 0))):
        ax.text(tx, ty, label, horizontalalignment='center',
                verticalalignment='center', fontsize=fontsize)
    ax.axis('off')


def f(wind_mag, toggle, cone_angle, K, x_0):
    """Recompute the dispersal/trap-arrival model and draw a three-panel figure.

    This is the callback wired to the sliders by ``interactive``.

    Parameters
    ----------
    wind_mag : float
        Wind magnitude, passed to f0.
    toggle : bool
        Panel 1 mode (see _plot_dispersal_panel).
    cone_angle : float
        Plume wedge half/full angle in DEGREES; converted to radians here.
    K : float
        Slider value; f2 receives its negation.
    x_0 : float
        Passed straight to f2.

    NOTE(review): the figure is created with plt.figure and never shown or
    closed explicitly.  It relies on the notebook backend / ipywidgets to
    display it.
    """
    wind_angle = 7 * np.pi / 8.  # fixed wind direction (radians)
    cone_angle = np.radians(cone_angle)
    release_times = 0.
    K = -1 * K  # f2 is called with the negated slider value -- sign convention of f2, confirm

    # ---- Recompute model outputs ----
    track_heading_angles, dispersing_speeds = f0(
        intended_heading_angles, wind_mag, wind_angle)
    intersection_distances, dispersal_distances = f1_wedge(
        track_heading_angles, source_pos, wind_angle, cone_angle)
    success_probabilities = f2(intersection_distances, K, x_0, source_pos, wind_angle)
    plume_assignments = f3(success_probabilities, dispersal_distances)
    dispersal_travel_times, release_to_chosen_plume_distances = f4(
        plume_assignments, dispersal_distances, dispersing_speeds)
    arrival_times, _chasing_times, _which_flies, which_traps = f5(
        plume_assignments, dispersal_travel_times, intersection_distances,
        fly_speed, release_times)

    # ---- Plot ----
    gs = GridSpec(9, 9)
    fig = plt.figure(figsize=(10, 10))

    _plot_dispersal_panel(fig, gs, toggle, track_heading_angles,
                          dispersing_speeds, release_to_chosen_plume_distances,
                          wind_angle, cone_angle)
    trap_counts, cdf_steepnesses, first_hit_times, y_top = _plot_arrival_cdfs(
        fig, gs, arrival_times, which_traps)
    _plot_trap_summary(fig, gs, trap_counts, cdf_steepnesses, first_hit_times,
                       y_top, wind_angle)

    fig.subplots_adjust(bottom=0.1, right=0.8, top=0.9, wspace=0.5, hspace=0.5)
    
    
# Bind each parameter of f to its widget.  The resulting interactive object
# is the cell's last expression, so the notebook displays it.
interactive(
    f,
    wind_mag=windmag_slider,
    toggle=toggle,
    cone_angle=cone_angle_slider,
    K=K_slider,
    x_0=x_0_slider,
)
    

# Out: interactive(children=(FloatSlider(value=1.0, continuous_update=False, description='wind_mag', max=4.0, readout…