# Larsch & Baier 2018
## Figure 1 and supplementary figure 3


In [None]:
%config InteractiveShellApp.pylab_import_all = False
%matplotlib inline
%pylab inline
%reload_ext autoreload
%autoreload 2

import sys
import os
import fnmatch

import numpy as np
import math
import matplotlib.pyplot as plt
import pandas as pd
from pandas import DataFrame, Series
import seaborn as sns
import glob
from datetime import datetime
from scipy import stats

# Remember the directory the notebook was started from exactly once, so that
# re-running this cell (after the working directory has been changed below)
# still locates props.csv.
if 'startDirMaster' not in locals():
    startDirMaster=os.getcwd()

# props.csv is read as a headerless, whitespace-delimited table whose first
# column becomes the key of the resulting dict.
propsFn=startDirMaster+'\\props.csv'
props=pd.read_csv(propsFn, header=None, index_col=0, squeeze=True,delim_whitespace=True).to_dict()

base=props['BaseDir']
expFile=props['allExpFn']

# NOTE(review): paths are assembled with Windows separators; this notebook
# will not run unchanged on POSIX systems.
RawDataDir = os.path.join(base,props['RawDataDir'])+'\\'
ProcessingDir = os.path.join(base,props['ProcessingDir'])+'\\Fig1SkypeAge\\'
ProcessingDir2 = os.path.join(base,props['ProcessingDir'])+'\\Fig1TruePairVsSkype\\'

outputDir = os.path.join(base,props['outputDir'])+'\\'

if not os.path.isdir(ProcessingDir):
    os.makedirs(ProcessingDir)
if not os.path.isdir(outputDir):
    os.makedirs(outputDir)

# Step one level up from the start directory (presumably the repository root,
# where the `functions` and `models` packages live). Anchoring on
# startDirMaster instead of the current directory keeps the cell idempotent:
# a relative chdir('..') would climb one more level on every re-run.
os.chdir(os.path.join(startDirMaster,os.pardir))
import functions.matrixUtilities_joh as mu
import matplotlib.pyplot as plt
import models.experiment as xp
import models.experiment_set as es
import functions.paperFigureProps as pfp
pfp.paper()
inToCm=2.54  # centimetres per inch; figure sizes below are given in cm


In [None]:
# Load the experiment overview table and keep only the rows whose
# stimulusProtocol equals the string '7'.
info = pd.read_csv(expFile, sep=',')
isProtocol7 = info['stimulusProtocol'] == '7'
info = info[isProtocol7]
info.head()

Pre-analyze all experiments only if necessary; this takes a couple of minutes! Experiment summary csv files are saved to disk.

In [None]:
#
# Collect meta information for every experiment and save it to a new csv file
# that drives the batch processing.

posPath=[]
PLPath=[]
expTime = []

for index,row in info.iterrows():
    startDir=os.path.join(RawDataDir,row.path)+'\\'
    posMatches=glob.glob(startDir+'PositionTxt*.txt')
    plMatches=glob.glob(startDir+'PL*.txt')
    # Fail with the offending directory in the message instead of a bare
    # IndexError from indexing an empty glob result.
    if not posMatches or not plMatches:
        raise IOError('Missing PositionTxt*.txt or PL*.txt in '+startDir)
    posPath.append(posMatches[0])
    PLPath.append(plMatches[0])

    # The recording time is parsed from the 19 characters that precede the
    # 4-character extension of the position file name.
    tail = os.path.basename(posPath[-1])
    currTime=datetime.strptime(tail[-23:-4], '%Y-%m-%dT%H_%M_%S')
    expTime.append(currTime)

info['txtPath']=posPath
info['pairList']=PLPath

info['epiDur'] = 5      # duration of individual episodes (default: 5 minutes)
info['episodes'] = -1   # number of episodes to process: -1 to load all episodes (default: -1)
info['inDish'] = 10#np.arange(len(posPath))*120     # time in dish before experiments started (default: 10)
info['arenaDiameter_mm'] = 100 # arena diameter (default: 100 mm)
info['minShift'] = 60 # minimum number of seconds to shift for control IAD
info['episodePLcode'] = 0 # flag if first two characters of episode name encode animal pair matrix (default: 0)
info['recomputeAnimalSize'] = 1 # flag to compute animals size from avi file (takes time, default: 1)
info['SaveNeighborhoodMaps'] = 1 # flag to save neighborhood maps for subsequent analysis (takes time, default: 1)
info['computeLeadership'] = 0 # flag to compute leadership index (takes time, default: 1)
info['ComputeBouts'] = 1 # flag to compute swim bout frequency (takes time, default: 1)
info['set'] = np.arange(len(posPath))   # experiment set: can label groups of experiments (default: 0)
info['ProcessingDir']=ProcessingDir
info['outputDir']=outputDir

info['expTime']=expTime


csvFile=os.path.join(ProcessingDir,'Fig1_skypeAge_process.csv')
info.to_csv(csvFile,encoding='utf-8')

info.head()

In [None]:
#recompute all experiments - this takes a long time (15 minutes?)
# NOTE(review): the shifted-pair cell further down reads `expSet`, so this cell
# has to be run even when the summary csv files already exist on disk.
expSet=es.experiment_set(csvFile=csvFile)

In [None]:
#read pre-analyzed experiment summaries
#inputFiles=pd.read_csv(csvFile,sep=',')

summaryFrames=[]
anTotal=0
for f in glob.glob(ProcessingDir+'\\*_siSummary_epi*.csv'):
    dfNew=pd.read_csv(f,index_col=0,sep=',')

    # Offset the per-file animalIndex by the running total so that
    # animalIdRun is unique across all summary files.
    dfNew['animalIdRun']=dfNew.animalIndex.values + anTotal
    anTotal+=dfNew['animalIndex'].max()+1

    summaryFrames.append(dfNew)

df=pd.concat(summaryFrames, axis=0)
df['episode']=[x.strip().replace('_','') for x in df['episode']]

# t2: time stamp with the calendar date collapsed to January 1st (year kept).
# t3: hours since midnight of that time stamp. Measuring against the stamp's
# own midnight, rather than a hard-coded 2017-01-01, gives the same values for
# 2017 recordings and stays correct for any other year.
t2=[pd.to_datetime(x).replace(day=1,month=1)for x in df.time]
t3=[(x-x.normalize())/pd.Timedelta('1 hour') for x in t2]
df['t2']=t2
df['t3']=t3

print(df.shape)
df.head()

In [None]:
# Display the distinct animalSet values present in the loaded summaries.
df.animalSet.unique()

In [None]:
# Fill IADsArray[animal, episode, k] with control "attraction" values derived
# from each pair's IADs() output.
# NOTE(review): requires `expSet` from the slow recompute cell to be in memory.
anList=df.animalIdRun.unique()
# NaN-initialised; allocated with twice as many rows as unique animal IDs and
# 10 slots on the last axis (9 per-row values + 1 pooled value, see below).
IADsArray=np.zeros((2*anList.shape[0],df.epiNr.unique().shape[0],10))*np.nan
expList=df.animalSet.unique()

ai=0  # running row index into IADsArray across all experiments
for e in expList:
    dfTmp=df[df.animalSet==e]
    anList=dfTmp.animalIdRun.unique()
    

    for a in anList:
        idx3=dfTmp.animalIdRun==a
        # positional indices of this animal's rows within dfTmp; used as pair index
        pl=np.where(idx3)[0]
        print('computing: ',e,a,idx3.sum())
        ep=0
        for p in pl:
            expIAD=np.array(expSet.experiments[e].pair[p].IADs())
            expIADm=np.nanmean(expIAD,axis=1)

            # slots 0-8: relative difference between each of the first 9 rows
            # and the last row of expIADm
            for i in range(9):
                ShiftAttract=(expIADm[i]-expIADm[-1])/expIADm[i]
                IADsArray[ai,ep,i]=ShiftAttract
            # slot 9: same ratio using the mean of all rows except the last
            ShiftAttract=(expIADm[:-1].mean()-expIADm[-1])/expIADm[:-1].mean()
            IADsArray[ai,ep,9]=ShiftAttract
            ep+=1
        ai+=1
                


In [None]:
# Separate the pooled value (slot 9) from the 9 per-row values.
# Guarded on the last-axis length so that re-running this cell is harmless:
# without the guard a second run would index slot 9 of the already trimmed
# 9-slot array and raise IndexError.
if IADsArray.shape[2]==10:
    IADsMean=IADsArray[:,:,9]
    IADsArray=IADsArray[:,:,:9]
IADsArray.shape

In [None]:
# Per row of IADsArray: number of episodes whose mean over the last axis is finite.
np.isfinite(np.nanmean(IADsArray,axis=2)).sum(axis=1)

In [None]:
# Histogram of all individual shifted-pair control values with their mean and
# a 95% t-interval of the distribution.
pfp.paper()
sns.set_palette('viridis',3)

fig, ax = plt.subplots(figsize=(6/inToCm,4.5/inToCm))

# IADsArray is NaN-padded; keep the finite samples only.
rav=IADsArray.ravel()
rav=rav[np.isfinite(rav)]

h=ax.hist(rav,bins=100)
ax.set_xlabel('Control attraction in shifted pairs')
ax.set_ylabel('Count')
ax.set_xticks([-.4,0,.4])
ax.set_xlim([-.8,.8])
ax.axvline(np.nanmean(IADsArray),linestyle=':',color='r')

# Degrees of freedom are based on the number of finite samples; counting the
# NaN padding (as len(IADsArray.ravel()) did) overstates the sample size.
l,u=stats.t.interval(0.95, len(rav)-1, loc=np.nanmean(IADsArray), scale=np.nanstd(IADsArray))
ax.axvline(u,linestyle=':',color='k')
ax.axvline(l,linestyle=':',color='k')
print(np.nanmean(IADsArray),np.nanstd(IADsArray),l,u)
sns.despine()

In [None]:
# Histogram of the pooled shifted-pair control value (IADsMean, one value per
# animal and episode) with its mean and a 95% t-interval.
pfp.paper()
sns.set_palette('viridis',3)

fig, ax = plt.subplots(figsize=(6/inToCm,4.5/inToCm))

# IADsMean is NaN-padded; keep the finite samples only.
rav=IADsMean.ravel()
rav=rav[np.isfinite(rav)]

h=ax.hist(rav,bins=20)
ax.set_xlabel('Control attraction in shifted pairs \n 5 minutes of data each')
ax.set_ylabel('Count')
ax.set_xticks([-.4,0,.4])
ax.set_xlim([-.5,.5])
# Mean of the plotted data. The previous np.mean(IADsArray) referred to the
# other array and evaluates to NaN whenever the array contains NaN padding.
ax.axvline(np.nanmean(IADsMean),linestyle=':',color='r')

# Degrees of freedom from the finite IADsMean samples (was taken from the
# length of IADsArray, which is a different and larger array).
l,u=stats.t.interval(0.95, len(rav)-1, loc=np.nanmean(IADsMean), scale=np.nanstd(IADsMean))
ax.axvline(u,linestyle=':',color='k')
ax.axvline(l,linestyle=':',color='k')
print(np.nanmean(IADsMean),np.nanstd(IADsMean),l,u)
sns.despine()

In [None]:
# Compare the distribution of si in 'lead00' episodes (labelled NullStimulus)
# with the pooled shifted-pair control values, and save as S1c.
si=df[df.episode=='lead00'].si
si=si[np.isfinite(si)]

pfp.paper()
sns.set_palette('viridis',3)

fig, ax = plt.subplots(figsize=(6/inToCm,4.5/inToCm))

H,bins=np.histogram(si,bins=20,density=True)
bincentres = [(bins[i]+bins[i+1])/2. for i in range(len(bins)-1)]
ax.step(bincentres,H,
        where='mid',
        color='b',
        linestyle='-',
       label='NullStimulus')

rav=IADsMean.ravel()
rav=rav[np.isfinite(rav)]
# reuse the bin edges of the first histogram so both curves share an x grid
H2,bins=np.histogram(rav,bins=bins,density=True)
ax.step(bincentres,H2,
        where='mid',
        color='g',
        linestyle='-',
       label='ShiftPair')


print(df[df.episode=='lead00'].si.std()*2)
ax.set_xlabel('Control attraction in shifted pairs \n 5 minutes data chunks')
# NOTE(review): both histograms use density=True, so the y axis shows a
# density although it is labelled 'Count'.
ax.set_ylabel('Count')
ax.set_xticks([-.4,0,.4])
ax.set_xlim([-.5,.5])
#ax.text(.35,750,'Mean: '+"{:1.3f}".format(np.mean(IADsAll)),color='r')
#ax.text(.35,650,'Std: '+"{:1.3f}".format(np.std(IADsAll)))
#ax.axvline(0,linestyle=':',color='gray')
ax.axvline(np.mean(si),linestyle=':',color='r')

# 95% t-interval of the si distribution (scale is the standard deviation)
l,u=stats.t.interval(0.95, len(si)-1, loc=np.nanmean(si), scale=np.nanstd(si))
ax.axvline(u,linestyle=':',color='k')
ax.axvline(l,linestyle=':',color='k')
#ax.text(.35,550,'CI95: '+"{:0.3f}".format(l)+'-'+"{:1.3f}".format(u))
print(np.nanmean(si),np.nanstd(si),l,u)
sns.despine()
ax.legend(loc='right',bbox_to_anchor=(1.3, 0.7))
figPath=outputDir+'\\S1c_SI5minNullStimVsShift.svg'
plt.savefig(figPath)

In [None]:
# Spread (std across animals) of the control value when averaging over the
# first k episodes, for k = 1..100. Entry k-1 of each array belongs to k episodes.
episodeCounts = range(1, 101)

# CI: average episodes first, then the 9 per-row values of IADsArray
CI = np.array([
    np.nanstd(np.nanmean(IADsArray[:, :k, :], axis=1).mean(axis=1))
    for k in episodeCounts
])
# CIan: same using the pooled value in IADsMean
CIan = np.array([
    np.nanstd(np.nanmean(IADsMean[:, :k], axis=1))
    for k in episodeCounts
])

plt.plot(CI*2,'.')
plt.plot(2*CIan,'.')

In [None]:
# Compare the measured 2*std of the shifted-pair control (CI) against the
# 1/sqrt(n) decay expected for independent samples, as a function of
# recording duration. Saved as S1d.
pfp.paper()
sns.set_palette('viridis',3)
fig, ax = plt.subplots(figsize=(6/inToCm,4.5/inToCm))

# x counts 5-minute episodes; x*5 converts to minutes
x=np.arange(1,400/5.)
ax.plot(x*5,
        2*np.nanstd(np.nanmean(IADsArray,axis=2))/np.sqrt(x),
        'r',
        label='Normal distribution',
       linewidth=2)

# The Line2D property is spelled `markersize`; matplotlib keyword names are
# case-sensitive, so `markerSize` is rejected as an unknown property.
ax.plot(x[:50]*5,
        CI[:50]*2,
        '.k-',
        label='ShiftPair data',
        markersize=5,
       alpha=1)

ax.legend()
ax.set_xlabel('Recording duration (Minutes)')
ax.set_ylabel('Attraction CI95 +/-')
ax.set_xlim([0,240])
ax.set_ylim([0,.3])
sns.despine()
ax.set_xticks(np.arange(0,250,60));

figPath=outputDir+'\\S1d_SI5minShiftPair.svg'
plt.savefig(figPath)

In [None]:
# 2*std of the control value when averaging the first 13 episodes (index 12).
CI[12]*2

## Habituation or Fatigue within 20 hours?

Plot shoaling index during closed loop skype episodes over time.

In [None]:
# Mean si per episode type over t3, with per-animal (animalIdRun) units.
# NOTE(review): sns.tsplot was removed from later seaborn releases; this cell
# needs an old seaborn version to run.
sns.tsplot(data=df, time="t3",value="si",unit="animalIdRun",condition="episode",estimator=np.nanmean,interpolate=False,err_style="ci_bars");
plt.xlim([0,24])

# size tuning per age group

In [None]:
# Scale factor used below to convert the number at the end of an episode name
# into a dot diameter — presumably projector pixels per mm; TODO confirm units.
pxPmm_project=1280/600.
print(pxPmm_project)

In [None]:
# Scratch calculation using the same 25 and 240 bounds as the inDishTime filter below.
(240-25)/25

In [None]:
# Mean si per episode (dot size) and age, plotted as one tuning curve per age,
# with a lower panel that draws discs at the corresponding x positions.
# Row filter: inDishTime strictly between 25 and 240, animalIndex below 14,
# and age 21 excluded.
idx=(df['inDishTime']<240) & (df['inDishTime']>25) & (df.animalIndex<14) & (df.age!=21)
episodeNames=df['episode'].unique()
dfDR=df[idx]
tmp=dfDR.groupby(['episode','age'],sort=True)['si']
# x axis: last two characters of each episode name as a number, divided by pxPmm_project
xax=np.array([x[-2:] for x in episodeNames]).astype('int')/pxPmm_project
#xax=xax*25*10
xax.sort()

err=tmp.std().unstack().values.T

co=sns.color_palette("Dark2", dfDR['age'].unique().shape[0])

fig, axes = plt.subplots(nrows=2, ncols=1, sharex='col', sharey=False,
                               gridspec_kw={'height_ratios': [2, 1]},
                               figsize=(10, 7))


# tick positions: same as xax but with the first tick forced to 0
xt=xax.copy()
xt[0]=0

dfx=tmp.mean().unstack().reset_index()
# assumes the rows (episodes sorted by name) line up with the sorted sizes in xax — TODO confirm
dfx['xax']=xax
lab=dfx.columns

axes[0]=dfx.plot(x='xax',kind='line',marker='o',yerr=err,
                                  linestyle=':',ax=axes[0],color=co,legend=False,
                                 xticks=xt)

for s in xax:
    c=plt.Circle((s,0),s/20,color='k')
    axes[1].add_artist(c)
    
#axes[1].set_aspect('equal')
axes[1].set_xlabel('disc diameter [mm]')
axes[0].set_ylabel('attraction index')
plt.xlim([0,xax.max()+xax.max()*0.1])
plt.setp(axes[1].get_yticklabels(), visible=False)
axes[1].set_ylim([-1,1]);
lines, labels = axes[0].get_legend_handles_labels()
# NOTE(review): labels are passed both positionally and via labels=lab, and
# lab (dfx.columns) also contains 'episode' and 'xax'; only the first four
# handles are used. The resulting legend text should be checked.
axes[0].legend([lines[x] for x in[0,1,2,3]], [labels[x] for x in[0,1,2,3]], labels=lab,loc='center left', bbox_to_anchor=(1, 0.5),title='age')
axes[0].axhline(0,ls=':',color='k')
#axes[0].xaxis.set_major_formatter(FormatStrFormatter('%0.1f'))
axes[0].set_title('Mean Disc Size Tuning per age group');


In [None]:
# Age-group code per row: 0 for age <= 13, 1 for 14-16, 2 for 17-21, 3 for > 21.
# dfDR was obtained by boolean filtering of df; work on an explicit copy so the
# new column is written to dfDR itself (no SettingWithCopyWarning, and no
# dependence on whether pandas handed out a view or a copy).
dfDR=dfDR.copy()
# Each exceeded threshold adds one, which reproduces the successive
# overwrites ag=1 (>13), ag=2 (>16), ag=3 (>21); NaN ages stay in group 0.
dfDR['ag']=(dfDR.age>13).astype(int)+(dfDR.age>16).astype(int)+(dfDR.age>21).astype(int)

In [None]:
# Figure 1Fa: mean si per dot size, one curve per age group (column 'ag').
sns.set_palette('viridis',3)
inToCm=2.54

# Keep a handle on the new figure: `fig` is used below for subplots_adjust and
# would otherwise still point at the figure created by an earlier cell.
fig=plt.figure(figsize=(4.5/inToCm,4.5/inToCm))
ax = plt.gca()

tmp_g=dfDR.groupby(['episode','ag'],sort=True)['si']

err=tmp_g.std().unstack().values.T

co=sns.color_palette("viridis", dfDR['ag'].unique().shape[0])


dfx=tmp_g.mean().unstack().reset_index()
# assumes rows (episodes sorted by name) line up with the sorted sizes in xax — TODO confirm
dfx['xax']=xax

dfx.plot(x='xax',
         kind='line',
         marker='o',
         linestyle=':',
         ax=ax,color=co,
         legend=False)

ax.set_ylabel('Virtual attraction')
ax.set_xlabel('Dot diameter (mm)')
ax.set_yticks([0,.5,1]);
ax.set_ylim([-.1,1]);

plt.xlim([0,xax.max()+xax.max()*0.1])
ax.axhline(0,ls=':',color='k')
# leave room above the axes for the legend
fig.subplots_adjust(top=0.75)
L=plt.legend(bbox_to_anchor=(0.1, 1.1), loc=2, borderaxespad=0.,handletextpad=0,title='Age (dpf)',ncol=2)

# Replace the numeric group codes 0..3 by readable age ranges.
L.get_texts()[0].set_text('<14')
L.get_texts()[1].set_text('14-16')
L.get_texts()[2].set_text('17-21')
L.get_texts()[3].set_text('>21')
sns.despine()
figPath=outputDir+'\\1Fa_DotSizePrefCurves.svg'
plt.savefig(figPath)

In [None]:
# Mean per (age, episode, animal), then pivot to episodes x animals.
tt=dfDR.groupby(['age','episode','animalIdRun'],sort=True).mean().reset_index()
# column key: age string immediately followed by the animal id string
tt['ind']=tt.age.astype('str')+tt.animalIdRun.astype('str')
ttP=tt.pivot_table(index='episode',columns='ind',values='si').reset_index()

In [None]:
# Tick positions for the per-animal heat map below.
# Age of every second animal column, read from the first two characters of the
# column key (assumes two-digit ages — TODO confirm).
ageInd=np.array([x[:2] for x in ttP.columns[1::2]]).astype(int)
print(ageInd)
# indices where the age changes between neighbours, padded with first and last index
ageChng=np.hstack([0,np.where(ageInd[:-1]!=ageInd[1:])[0],ageInd.shape[0]-1])
print(ageChng)
# midpoint of each run of equal age, used for the minor (label) ticks
minTick= ageChng[1:]-np.hstack([(ageChng[1:]-ageChng[:-1])])/2
minTick
print(ageInd[ageChng][1:])

In [None]:
# Heat map of si: one row per (every second) animal column of ttP, one column per dot size.
from mpl_toolkits.axes_grid1 import make_axes_locatable

fig, ax = plt.subplots(figsize=(4.5/inToCm,4.5/inToCm))

ax = plt.gca()
im=ax.imshow(ttP.values[:,1::2].astype('float').T,interpolation='nearest',cmap='viridis',aspect='auto')

# create an axes on the right side of ax. The width of cax will be 5%
# of ax and the padding between cax and ax will be fixed at 0.05 inch.


# dot-size tick labels: first three characters of each value in xt
yt=[x[:3] for x in xt.astype('str')]

# NOTE(review): hard-codes five dot sizes
plt.xticks( np.arange(5), yt )

# unlabeled major ticks at the age boundaries ...
plt.yticks( ageChng, '' )

# ... and tick-less minor ticks carrying the age labels at the run midpoints
ax.set_yticks(minTick,minor=True)
ax.set_yticklabels(ageInd[ageChng][1:].astype('str'), minor=True,fontsize=7 )
ax.tick_params(axis='y', which='minor',length=0)
    
ax.set_ylabel('age dpf')
ax.set_xlabel('dot diameter [mm]');
sns.despine()

divider = make_axes_locatable(ax)
cax = divider.append_axes("right", size="5%", pad=0.05)

cb1=plt.colorbar(im, cax=cax,ticks=[0,0.5])
cb1.set_label('attraction index')


In [None]:
# Same per-animal pivot as above, but without the 'episode' column so that
# every remaining column is one animal's si per dot size.
tt=dfDR.groupby(['age','episode','animalIdRun'],sort=True).mean().reset_index()
tt['ind']=tt.age.astype('str')+tt.animalIdRun.astype('str')
ttP=tt.pivot_table(index='episode',columns='ind',values='si').reset_index()
ttP=ttP.drop('episode',axis=1)
ttP

In [None]:
# Fit a parabola to each animal's size tuning curve and take the location of
# its maximum on [0, 7.5] mm as that animal's preferred dot size.
from scipy.interpolate import interp1d

# NOTE(review): the 10x14 grid limits this cell to 140 animal columns.
fig, ax = plt.subplots(nrows=10, ncols=14, sharex=True, sharey=True,figsize=(25/inToCm,25/inToCm))
ax=ax.ravel()
maxPosAll=[]
ageAll=[]
x=np.array([0,0.9,1.8,3.7,7.5])   # dot diameters (mm)
xp = np.linspace(0, 7.5, 1000)    # evaluation grid, identical for every animal
for i in range(ttP.shape[1]):
    y=ttP.values[:,i]
    z=np.polyfit(x,y,2)
    p = np.poly1d(z)
    fit=p(xp)
    if np.max(fit)>0:
        # Read the position directly off the grid. Scaling the index by
        # 7.5/1000 does not match the linspace step (7.5/999) and slightly
        # underestimates the position.
        maxPosAll.append(xp[np.argmax(fit)])
        ax[i].plot(x, y, '.', xp, fit, '-')
        ax[i].set_ylim([-.20,.8])
        ax[i].axvline(maxPosAll[-1])
        ax[i].axis('off')
    else:
        # no positive attraction anywhere on the fit: no preferred size
        maxPosAll.append(np.nan)
    # age from the first two characters of the column key
    ageAll.append(int(ttP.columns[i][:2]))

In [None]:
# One row per animal: age and fitted preferred dot size (mp).
mpa=pd.DataFrame({'age':ageAll,'mp':maxPosAll})
mpa.head()

In [None]:
# Mean, median and standard deviation of the preferred dot size per age.
mpByAge = mpa.groupby(['age'])['mp']
maxPosIndMn = mpByAge.mean()
maxPosIndMd = mpByAge.median()
maxPosIndSTD = mpByAge.std()
maxPosIndMn

In [None]:
# Quick look: median and mean preferred dot size against age.
plt.plot(maxPosIndMd)
plt.plot(maxPosIndMn)

In [None]:
# Figure 1Fb: linear regression of the mean preferred dot size on age,
# with per-age standard deviation as error bars and a fish image inset.
import scipy.stats
pfp.paper()
inToCm=2.54
#plt.figure(figsize=(4.5/inToCm,4.5/inToCm))
#ax = plt.gca()
fig, ax = plt.subplots(figsize=(3.5/inToCm,4.5/inToCm))

smaxCol='gray'

xs=maxPosIndMn.index.values
ys=maxPosIndMn.values
# slope, intercept, r value, p value, standard error
s,i,r,p,std=scipy.stats.linregress(xs,ys)
t=np.linspace(5,30,100)
l=i+s*t
ax.plot(t,l,'--',xs,ys,'.',color='k',markersize=15)
(_, caps, _)=ax.errorbar(xs,ys,maxPosIndSTD.values,ls='',color=smaxCol,zorder=100)

for cap in caps:
    cap.set_markeredgewidth(1)
    
# r is shown truncated to its first four characters
ax.text(15,2,'R: '+str(r)[:4],color=smaxCol)
ax.text(15,1,"p = {:.3f}".format(p),color=smaxCol)
ax.set_ylabel('Best diameter (mm)')
ax.set_xlabel('Age (dpf)')
#plt.yticks(np.arange(0,1.8,.4))
ax.set_xlim([9,30])
ax.set_ylim([0,7])
sns.despine()

# inset: transposed fish image drawn in an extra axes overlapping the right part of the figure
fishPNGf=RawDataDir+'artwork\\fish_2mmScaleBar.png'
import imageio
im = imageio.imread(fishPNGf)
newax = fig.add_axes([0.55, 0.10, 0.8, 0.8], anchor='NE', zorder=1)
newax.imshow(im[:,:].T, cmap='gray')
newax.axis('off')


figPath=outputDir+'\\1Fb_dotSizePrefCorr.svg'
plt.savefig(figPath)

# same analysis on age group means

### (because individual animal attraction is somewhat noisy, compare fit to mean of all animals)

In [None]:
# Mean si per (age, episode), pivoted to episodes x ages; dfx keeps only the age columns.
tt=dfDR.groupby(['age','episode'],sort=True).mean().reset_index()
ttP=tt.pivot_table(index='episode',columns='age',values='si').reset_index()
dfx=ttP.drop('episode',axis=1)
dfx

In [None]:

# Preferred dot size per age from a 4th-order polynomial through the five
# age-mean data points, evaluated on [0, 8] mm.
maxPosAll=[]
ageAll=[]
x=np.array([0,0.9,1.8,3.7,7.5])   # dot diameters (mm)
xp = np.linspace(0, 8, 1000)      # evaluation grid, identical for every age
for i in range(dfx.shape[1]):
    y=dfx.values[:,i]
    z=np.polyfit(x,y,4)
    p = np.poly1d(z)
    fit=p(xp)
    if np.max(fit)>.01:
        # Read the position directly off the grid; scaling the index by 8/1000
        # does not match the linspace step (8/999).
        maxPosAll.append(xp[np.argmax(fit)])
    else:
        # fitted attraction never exceeds 0.01: no preferred size
        maxPosAll.append(np.nan)
    ageAll.append(dfx.columns[i])

In [None]:
# Display the episodes x ages table of mean si.
dfx

In [None]:
# Heat map of mean si: one row per age (column of dfx), one column per dot size.
inToCm=2.54


fig, axes = plt.subplots(sharex=True, sharey=True,figsize=(4.5/inToCm,4.5/inToCm))

# Label the rows with the ages that are actually plotted (the columns of dfx).
# The previous labels came from df.age with a hard-coded count of 11, which
# still contains ages that were filtered out of dfDR/dfx and therefore shifts
# the labels against the rows.
sns.heatmap(dfx.values.astype('float').T,center=0,cmap='seismic',ax=axes,
            yticklabels=list(dfx.columns));


In [None]:
# Preferred dot size per age from a cubic interpolation of the five age-mean
# data points; one panel per age.
from scipy.interpolate import interp1d
maxPos=[]
# One panel per age column instead of a fixed 10, so that maxPos always has
# the same length as dfx.columns (both are fed to linregress afterwards).
nAges=dfx.shape[1]
fig, ax = plt.subplots(nrows=nAges, ncols=1, sharex=True, sharey=True,figsize=(4.5/inToCm,25/inToCm),squeeze=False)
ax=ax.ravel()
x=np.array([0,0.9,1.8,3.7,7.5])   # dot diameters (mm)
xp = np.linspace(0, 7.5, 1000)    # evaluation grid, identical for every age
for i in range(nAges):
    y=dfx.values[:,i]
    p = interp1d(x, y, kind='cubic')
    fit=p(xp)
    # Read the position directly off the grid; scaling the index by 7.5/1000
    # does not match the linspace step (7.5/999).
    maxPos.append(xp[np.argmax(fit)])
    ax[i].plot(x, y, '.', xp, fit, '-')
    ax[i].set_ylim([-.10,.8])
    ax[i].axvline(maxPos[-1])

In [None]:
# Linear regression of the interpolated preferred dot size (maxPos) on age.
import scipy.stats
pfp.paper()
inToCm=2.54
plt.figure(figsize=(4.5/inToCm,4.5/inToCm))
ax = plt.gca()

xs=dfx.columns.values.astype('int')
ys=np.array(maxPos)
# slope, intercept, r value, p value, standard error
s,i,r,p,std=scipy.stats.linregress(xs,ys)
t=np.linspace(10,30,100)
l=i+s*t
ax.plot(t,l,'--',xs,ys,'.',color='gray',markersize=20)

# r is shown truncated to its first four characters
ax.text(20,2.0,'R: '+str(r)[:4])
ax.text(20,1.8,"p = {:.3f}".format(p))
ax.set_ylabel('best diameter (mm)')
ax.set_xlabel('age (dpf)')
#ax.set_title('dot size preference')
sns.despine()



# Increase of attraction with age
### for paper, show dot size = 4 mm (condition lead08)

In [None]:
# Distinct group sizes of whatever groupby `tmp` currently holds.
# NOTE(review): depends on which earlier cell last assigned `tmp`.
tmp.count().unique()

In [None]:
# Figure 1E: si in 'lead08' episodes per animal against age, with per-age mean
# and standard deviation, and the shifted-pair 2*std band as reference.
tmp=dfDR[dfDR.episode=='lead08'].groupby(['animalIdRun','age'],sort=True)['si']
LowestConditionCount=tmp.count().values.min()
# CI[k] was computed from the first k+1 episodes, so the entry that belongs to
# N episodes is CI[N-1]. (max(..., 1) keeps the index valid if a group has no
# finite si at all.)
CI95=CI[max(LowestConditionCount,1)-1]*2
print(LowestConditionCount, ' episode repeats per animal')
print('CI95:',CI95)

# keep every second row of the per-animal means
dfx=tmp.mean().reset_index()[::2]




pfp.paper()
inToCm=2.54

fig, axes = plt.subplots(figsize=(4.5/inToCm,4.5/inToCm))

#axes.fill_between([0,30],[CI95],[-CI95],color='gray',alpha=0.4)


perCon=dfx.groupby('age').si.mean().reset_index()
perCon['std']=dfx.groupby('age').si.std().values


axes.scatter(dfx.age,dfx.si,color='gray',s=10,edgecolor='gray')
(_, caps, _) =axes.errorbar(perCon.age,perCon.si,yerr=perCon['std'],color='k',zorder=1,ls='')

# per-age mean drawn as a horizontal red dash
axes.scatter(perCon.age,perCon.si,s=100,color='r',marker='_',linewidths=3)

for cap in caps:
    cap.set_markeredgewidth(1)
    
sns.despine()
#plt.axhline(0,ls=':',color='k')
plt.axhline(0,ls='-',color='k',linewidth=0.5)
plt.axhline(CI95,ls='--',color='gray',linewidth=0.5)
plt.axhline(-CI95,ls='--',color='gray',linewidth=0.5)
plt.xlabel('Age (dpf)')
plt.ylabel('Virtual attraction')
axes.set_yticks([0,.5,1]);
axes.set_ylim([-0.15,1])
axes.set_xlim([9,28])
#plt.title('si of individual pairs at various ages')
figPath=outputDir+'\\1E_physical_SI_overAge.svg'
plt.savefig(figPath)

### compare at other sizes

In [None]:
# Per-age mean si for every episode type (dot size), one line per size.
pfp.paper()
inToCm=2.54

co=sns.color_palette("Dark2", 8)

# one colour per dot size; the fourth is drawn in black
col=[co[0],co[1],co[2],'k',co[4]]


fig, axes = plt.subplots(figsize=(4.5/inToCm,4.5/inToCm))

# collects the per-age means, one column per episode
dfSiSize=pd.DataFrame()

i=0
#axes.scatter(dfx.age,dfx.si,color='gray',s=15,edgecolor='k',label=None)

for ep in np.sort(dfDR.episode.unique()):
    tmp=dfDR[dfDR.episode==ep].groupby(['animalIdRun','age'],sort=True)['si']
    # keep every second row of the per-animal means
    dfx2=tmp.mean().reset_index()[::2]
    perCon=dfx2.groupby('age').si.mean().reset_index()
    dfSiSize[ep]=perCon.si
    # yt (dot-size labels) comes from the per-animal heat map cell
    axes.plot(np.sort(dfDR.age.unique()),perCon.si,'o-',ms=3,label=yt[i],color=col[i])
    i+=1

axes.set_ylim([-.1,.7])
axes.set_xlabel('age')
axes.set_ylabel('attraction')

plt.legend(title='dot size [mm]',loc='center left', bbox_to_anchor=(1, 0.5))
sns.despine()

# Figure S3 - interdependence of growth and movement

In [None]:
# Figure S3: pairwise regressions between age, body size, speed, bout duration
# and attraction, using per-animal means of the 'lead08' episodes.
import scipy.stats as sps

tmp=dfDR[dfDR.episode=='lead08'].groupby(['animalIdRun','age'],sort=True)
sizeSi=tmp.mean().reset_index()
# Replace boutDur values above 2 by the column mean. Written with .loc so the
# assignment reliably lands in sizeSi; the chained form
# `sizeSi.boutDur[mask] = ...` may operate on a temporary copy.
sizeSi.loc[sizeSi.boutDur>2,'boutDur']=sizeSi.boutDur.mean()

g=sns.pairplot(sizeSi[['age','anSize','avgSpeed','boutDur','si']],kind="reg",size=1.5,markers=".")

def corrfunc(x, y, **kws):
    """Annotate the current axes with R and p of a linear regression of y on x.

    Panels with r >= 0.99 are left without annotation.
    """
    s,i,r,p,std=sps.linregress(x,y)
    if r<0.99:
        ax = plt.gca()
        ax.annotate("R = {:.2f}".format(r),
                    xy=(.1, .9), xycoords=ax.transAxes)
        ax.annotate("p = {:.1e}".format(p),
                    xy=(.1, .1), xycoords=ax.transAxes)

g=g.map(corrfunc)
g.fig.text(0,1,'Figure S3', fontsize=12,fontweight='bold')
xlabels=np.array(['Age (dpf)','Body length (mm)','Speed (mm/s)','Bout Interval (Hz)','Attraction'])

# same labels on the bottom row (x) and the left column (y)
for i in range(len(xlabels)):
        g.axes[-1,i].xaxis.set_label_text(xlabels[i])
        g.axes[i,0].yaxis.set_label_text(xlabels[i])
        
        
figPath=outputDir+'S3_pairPlot.svg'
plt.savefig(figPath)

In [None]:
# Load every *MapData.npy file from the processing directory.
# NOTE(review): later cells index the concatenated maps with row positions of
# df, so this relies on glob returning these files in the same order as the
# summary csv files — confirm.
lAll=[]
for f in glob.glob(ProcessingDir+'\\*MapData.npy'):
    print('loading:',f)
    lAll.append(np.load(f))

In [None]:
# Row filter on df (same conditions as used for dfDR) and all maps stacked along axis 0.
idx=(df['inDishTime']<240) & (df['inDishTime']>25) & (df.animalIndex<14) & (df.age!=21)
allMaps=np.concatenate(tuple(lAll),axis=0)#[idx]
allMaps.shape

In [None]:
# Age ranges [lower, upper] of the four age groups.
ageGroups=[[0,13],[14,16],[17,21],[22,27]]

levels=df['episode'].unique()
treat=ageGroups
# Average maps per (age group, episode): avgT is filled from index 0 and avgTC
# from index 1 of the third axis of allMaps.
avgT=np.zeros((len(treat),len(levels),allMaps.shape[3],allMaps.shape[4]))
avgTC=np.zeros((len(treat),len(levels),allMaps.shape[3],allMaps.shape[4]))

In [None]:
# Average the maps of all selected rows per age group and episode type.
for iag in range(len(treat)):
    ag=treat[iag]
    a1=ag[0]
    a2=ag[1]
    print(ag,a1,a2)
    # Both bounds are inclusive, matching the group labels ('14-16', ...) and
    # the 'ag' grouping used for the tuning curves. A strict `age>a1` silently
    # drops animals aged exactly 14, 17 or 22.
    inAgeGroup=(df['age']>=a1)&(df['age']<=a2)&idx
    for i in range(len(levels)):
        ix=np.where((df['episode']==levels[i]) & inAgeGroup)[0]
        avgT[iag,i,:,:]=allMaps[ix,0,0,:,:].mean(axis=0)
        avgTC[iag,i,:,:]=allMaps[ix,0,1,:,:].mean(axis=0)

# NOTE(review): absolute, machine-specific paths; physAll is computed again
# from the processing directory in a later cell.
ff=['E:\\b\\2017\\20171106_TruePairVsSkype_11dpf\\a_truePair\\meanMap.npy',
    'E:\\b\\2017\\20171026_TruePairVsSkype_16dpf\\a_truePair\\meanMap.npy',
   'D:\\data\\b\\2017\\20171026_TruePairVsSkype_21dpf\\a_truePair\\meanMap.npy',
   'D:\\data\\b\\2017\\20171013_TruePairVsSkype\\a_truePair\\meanMap.npy']

physAll=np.stack([np.load(x) for x in ff])
physAll.shape

In [None]:

# Load the true-pair map files from the second processing directory.
matches=glob.glob(ProcessingDir2+'*noga*MapD*')
# NOTE(review): drops the last 8 matches; which files these are depends on the
# order glob returns them in — confirm this is intended.
matches=matches[:-8]
lTrueAll=[]

for m in matches: 
    print('loading:',m)
    lTrueAll.append(np.load(m))


In [None]:
# One mean map per loaded file, then one mean map per age group.
physAll=np.array([np.nanmean(x[:,0,0,:,:],axis=0) for x in lTrueAll])
# Age-group code per loaded file: 6 files in group 0, then 8 each in groups 1-3
# (assumes the files are ordered by age group — TODO confirm).
TruePairAgeExperiments=np.concatenate([np.zeros(6),np.ones(8),np.ones(8)*2,np.ones(8)*3])
physAll=np.array([np.nanmean(physAll[TruePairAgeExperiments==x],axis=0) for x in np.unique(TruePairAgeExperiments)])
physAll.shape

In [None]:
# Age-group code per entry (12, 16, 16, 16 entries for groups 0-3); the plots
# below treat each entry as one block of 24 consecutive maps.
TruePairAgeGroups=np.concatenate([np.zeros(12),np.ones(16),np.ones(16)*2,np.ones(16)*3])
TruePairAgeGroups.shape[0]*24

In [None]:
# All true-pair maps stacked along axis 0.
allMapsPhysical=np.concatenate(tuple(lTrueAll),axis=0)#[idx]
allMapsPhysical.shape

In [None]:
# Figure S4: neighbor-density profiles (mean over map rows pS:pE) per age group
# (rows) and stimulus (columns; column 0 shows the true-pair data).
inToCm=2.54
fig, axes = plt.subplots(nrows=4, ncols=6, sharex='col', sharey=True,figsize=(15/inToCm, 12/inToCm))
pfp.paper()

# rows of the map that are averaged into a profile
pS=29
pE=31

for iag in range(len(treat)):
    ag=treat[iag]
    a1=ag[0]
    a2=ag[1]
    print(ag,a1,a2)
    # Both bounds inclusive, matching the row labels ('14-16', ...). A strict
    # `age>a1` silently drops animals aged exactly 14, 17 or 22.
    inAgeGroup=(df['age']>=a1)&(df['age']<=a2)&idx
    for i in range(len(levels)):
        sel=(df['episode']==levels[i])&inAgeGroup
        ix=np.where(sel)[0]
        currAnimals=df.iloc[ix].animalIdRun.unique()
        # thin gray: individual animals, thick black: group mean
        for j in currAnimals:
            ixb=np.where(sel&(df.animalIdRun==j))[0]
            axes[iag,i+1].plot(allMaps[ixb,0,0,pS:pE,:].mean(axis=0).mean(axis=0),color='gray',lw=1)

        axes[iag,i+1].plot(allMaps[ix,0,0,pS:pE,:].mean(axis=0).mean(axis=0),color='k',lw=2)

    #physical interaction profiles
    # each entry of TruePairAgeGroups owns a block of 24 consecutive maps
    numInAgeGroup=np.where(TruePairAgeGroups==iag)[0]
    for j in numInAgeGroup:

        axes[iag,0].plot(np.nanmean(allMapsPhysical[j*24:(j+1)*24,0,0,pS:pE,:],axis=0).mean(axis=0),color='gray',lw=1)
    axes[iag,0].set_ylabel(['10-13','14-16','17-21','22-27'][iag],rotation='horizontal',ha='right',fontsize=10)
    # Group mean over ALL blocks of the group: the slice has to end after the
    # last entry's block, i.e. at (last+1)*24; ending at last*24 leaves the
    # final entry out of the mean.
    groupStart=numInAgeGroup[0]*24
    groupStop=(numInAgeGroup[-1]+1)*24
    axes[iag,0].plot(np.nanmean(allMapsPhysical[groupStart:groupStop,0,0,pS:pE,:],axis=0).mean(axis=0),color='k',lw=2)

for i in range(6):
    axes[-1,i].set_xlabel(['Fish','0.0','0.9','1.8','3.7','7.5'][i],fontsize=10)

for a in axes.ravel():
    #a.set_axis_off()
    a.tick_params(axis='y', which='both',length=0)
    a.tick_params(axis='x', which='both',length=0)
    a.set_xticklabels('')
    a.set_yticklabels('')
    a.set_ylim([-1,19])
    a.set_xlim([0,60])
    a.axhline(1,ls=':',color='r')
    #a.set_yscale('log')

# short lines in the corner panels (presumably scale bars — confirm)
axes[-1,-1].axhline(15,color='k')
axes[-1,0].axvline(1,color='k')
#axes[-1,-1].axvline(59,color='k')
sns.despine(left=True,bottom=True)
fig.text(0.6, -0.05, 'Dot diameter (mm)', ha='center',fontsize=10)
fig.text(-.03, 0.5, 'Age (dpf)', va='center', rotation='vertical',fontsize=10)

bbox_args = dict(boxstyle="round", fc="0.8")
arrow_args = dict(arrowstyle="-")

# Horizontal line under the dot-diameter columns, attached to the bottom-right
# panel (addressed explicitly instead of through a leftover loop index).
axes[-1,-1].annotate('', xy=(.95, 0.1), xycoords='figure fraction',
             xytext=(-280, 0), textcoords='offset points',
             ha="right", va="top",
             arrowprops=arrow_args)

figPath=outputDir+'\\S4_Age_NN_profiles.svg'
plt.savefig(figPath,bbox_inches='tight')

In [None]:
# Overview grid: mean true-pair map of every entry, one row per age group.
fig, axes = plt.subplots(nrows=4, ncols=17, sharex='col', sharey=True,figsize=(20, 5))

for iag in range(len(treat)):
    # physical interaction maps of this age group, one panel per entry
    numInAgeGroup=np.where(TruePairAgeGroups==iag)[0]
    for pan, j in enumerate(numInAgeGroup):
        # each entry owns a block of 24 consecutive maps
        entryMaps=allMapsPhysical[j*24:(j+1)*24,0,0,:,:]
        axes[iag,pan].imshow(np.nanmean(entryMaps,axis=0))

sns.despine()

In [None]:
from mpl_toolkits.axes_grid1 import AxesGrid
import matplotlib
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Build a copy of `cmap` whose centre colour is moved to `midpoint`.

    Intended for data whose neutral value is not in the middle of the
    displayed range, e.g. a negative minimum and a larger positive maximum
    with the colormap centre wanted at zero.

    Input
    -----
      cmap : The matplotlib colormap to be altered
      start : Lower end of the part of `cmap` that is sampled.
          Defaults to 0 (use the colormap from its lowest point).
          Should be between 0.0 and `midpoint`.
      midpoint : Position of the colormap centre in the new map.
          Defaults to 0.5 (no shift). Should be between 0.0 and 1.0;
          in general 1 - vmax/(vmax + abs(vmin)). Example: for data
          from -15.0 to +5.0 with the centre wanted at 0.0, use
          1 - 5/(5 + 15) = 0.75.
      stop : Upper end of the part of `cmap` that is sampled.
          Defaults to 1.0 (use the colormap up to its highest point).
          Should be between `midpoint` and 1.0.
      name : Name under which the new colormap is registered.

    Returns the new LinearSegmentedColormap, which is also registered with
    pyplot under `name`.
    '''
    channels = ('red', 'green', 'blue', 'alpha')
    cdict = {channel: [] for channel in channels}

    # positions at which the original colormap is sampled
    sample_positions = np.linspace(start, stop, 257)

    # positions the samples take in the new map: 128 points below the
    # midpoint and 129 from the midpoint up to 1.0
    lower_half = np.linspace(0.0, midpoint, 128, endpoint=False)
    upper_half = np.linspace(midpoint, 1.0, 129, endpoint=True)
    target_positions = np.hstack([lower_half, upper_half])

    for source, target in zip(sample_positions, target_positions):
        rgba = cmap(source)
        for channel, value in zip(channels, rgba):
            cdict[channel].append((target, value, value))

    shifted = matplotlib.colors.LinearSegmentedColormap(name, cdict)
    plt.register_cmap(cmap=shifted)

    return shifted

In [None]:
# Grid of average maps: one row per age group, one column per episode type,
# drawn with a seismic colormap whose centre is shifted to the value 1.
fig, axes = plt.subplots(nrows=4, ncols=5, sharex='col', sharey=True,figsize=(10, 8))
m=np.nanpercentile(avgT,99.9)   # upper colour limit
orig_cmap = matplotlib.cm.seismic
cmap=shiftedColorMap(orig_cmap,midpoint=1-(m/(m+1)))
trLab=['a','b','c','d','e']
for iag in range(len(treat)):
    for i, level in enumerate(levels):
        panel = axes[iag,i]
        im = panel.imshow(avgT[iag,i,:,:],clim=[0,m],extent=[-31,31,-31,31],origin='lower',cmap=cmap)
        # title: age-group index and letter, then the last two characters of the episode name
        panel.set_title('a:'+str(iag)+trLab[iag]+ 's:'+level[-2:],fontsize=10)
#plt.colorbar(im)
sns.set_style("ticks")
sns.set_context("paper")
plt.tight_layout()
sns.despine();

In [None]:
# Figure 1G: average neighbor-density maps. Left column: true pairs (physAll),
# middle 4x5 grid: virtual stimuli per age group and dot size (avgT),
# right: shared colorbar.
import matplotlib.gridspec as gridspec

pfp.paper()
# outer layout: true-pair column, 4x5 map grid, narrow colorbar column
outer = gridspec.GridSpec(1, 3, width_ratios = [1, 5,.2], wspace = 0.05) 
#make nested gridspecs
gs1 = gridspec.GridSpecFromSubplotSpec(4, 1, subplot_spec = outer[0])
gs2 = gridspec.GridSpecFromSubplotSpec(4, 5, subplot_spec = outer[1])
gs3 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec = outer[2])

# NOTE(review): inToCm is used here before it is (re)assigned further down in
# this cell; this works only because an earlier cell already defined it.
fig = plt.figure(figsize=(7/inToCm,4.5/inToCm))
axes = [fig.add_subplot(gs2[i]) for i in range(20)]
#ax = fig.add_subplot(gs[0])

axesP = [fig.add_subplot(gs1[i]) for i in range(4)]

axesCB=[fig.add_subplot(gs3[i]) for i in range(1)]

inToCm=2.54
#fig, axes = plt.subplots(nrows=4, ncols=6, sharex=True, sharey=True,figsize=(6/inToCm, 4.5/inToCm))
# fixed upper colour limit; the colormap centre is shifted to the value 1
m=14#np.nanpercentile(physAll,99.9)
orig_cmap = matplotlib.cm.bwr
cmap=shiftedColorMap(orig_cmap,midpoint=1-(m/(m+1)))
#trLab=['a','b','c','d','e']

# true-pair maps, one per age group
for i in range(4):
    im = axesP[i].imshow(physAll[i,:,:],clim=[0,m],extent=[-31,31,-31,31],origin='lower',cmap=cmap)

for i in range(1):
    #axesCB[i].set_axis_off()
    axesCB[i].tick_params(axis='y', which='both',length=0)
    axesCB[i].tick_params(axis='x', which='both',length=0)
    axesCB[i].set_xticks([])
    axesCB[i].set_yticks([])
# strip ticks and hide the frame of the true-pair panels; label rows with age ranges
for i in range(4):
    axesP[i].tick_params(axis='y', which='both',length=0)
    axesP[i].tick_params(axis='x', which='both',length=0)
    axesP[i].set_xticks([])
    axesP[i].set_yticks([])
    axesP[i].set_ylabel(['10-13','14-16','17-21','22-27'][i],rotation='horizontal',ha='right',fontsize=10)
    axesP[i].yaxis.set_label_coords(-0.1,0.2)
    axesP[i].spines['top'].set_color('white')
    axesP[i].spines['bottom'].set_color('white')
    axesP[i].spines['left'].set_color('white')
    axesP[i].spines['right'].set_color('white')
    
# i still holds 3 from the loop above, so this labels the bottom true-pair panel
axesP[i].set_xlabel('Fish',fontsize=10)

# virtual-stimulus maps; panels are stored row by row (5 per age group)
for iag in range(len(treat)):
    for i in range(len(levels)):
        ind=i+iag*5
        im = axes[ind].imshow(avgT[iag,i,:,:],clim=[0,m],extent=[-31,31,-31,31],origin='lower',cmap=cmap)
        axes[ind].tick_params(axis='y', which='both',length=0)
        axes[ind].tick_params(axis='x', which='both',length=0)
        axes[ind].set_xticks([])
        axes[ind].set_yticks([])
        axes[ind].spines['top'].set_color('white')
        axes[ind].spines['bottom'].set_color('white')
        axes[ind].spines['left'].set_color('white')
        axes[ind].spines['right'].set_color('white')
        # dot-size labels (yt from the heat map cell) under the bottom row only
        if iag==3:
            axes[ind].set_xlabel(yt[i],fontsize=10)
            
# colorbar for the last drawn image; all images share clim and cmap
cbar=plt.colorbar(im,cax=axesCB[0],ticks=[0,1,m])
axesCB[0].yaxis.set_ticks_position('right')
axesCB[0].yaxis.set_label_position('right')
for i in range(1):
    axesCB[i].spines['top'].set_color('white')
    axesCB[i].spines['bottom'].set_color('white')
    axesCB[i].spines['left'].set_color('white')
    axesCB[i].spines['right'].set_color('white')

cbar.ax.set_yticklabels(np.array([0,1,m]).astype('int'))
cbar.ax.tick_params(axis='y', which='both',length=2)
cbar.ax.yaxis.set_tick_params(pad=0)
cbar.set_label('Neighbor density',labelpad=0,fontsize=10)
plt.subplots_adjust(wspace=0, hspace=0)

#sns.despine(left=True,bottom=True);

fig.text(0.6, -0.11, 'Dot diameter (mm)', ha='center',fontsize=10)
fig.text(-0.11, 0.5, 'Age (dpf)', va='center', rotation='vertical',fontsize=10)

bbox_args = dict(boxstyle="round", fc="0.8")
arrow_args = dict(arrowstyle="-")

# thin horizontal line positioned in figure coordinates (below the dot-size labels)
axes[19].annotate('', xy=(.81, 0.15), xycoords='figure fraction',
             xytext=(-110, 0), textcoords='offset points',
             ha="right", va="top",
             arrowprops=arrow_args)

# short thick horizontal line near the top right (presumably a scale bar — confirm)
arrow_args = dict(arrowstyle="-",linewidth=2)
axes[19].annotate('', xy=(.845, 0.94), xycoords='figure fraction',
             xytext=(-26, 0), textcoords='offset points',
             ha="right", va="top",
             arrowprops=arrow_args)

figPath=outputDir+'\\1G_physicalVsVirtual_maps.svg'
plt.savefig(figPath,bbox_inches='tight')