# Analysis of assortative shoaling in quartets consisting of 2 age groups each

## 16, 17, 18 dpf and 23, 24, 25 dpf animals

## Summary of experiments 1–3

Each of the experiments has the following structure:

Each group of 4 animals forms one quartet. Each experiment contains three quartets (animals 0–11) plus 3 additional animals (12–14) that are not analyzed.

The stimulus protocol sequentially links the animals of each quartet, first for quartet interactions and then for pair-wise interactions:

1. `00pairs`: quartet interactions
2. `01pairs`–`03pairs`: pair-wise interactions

In [None]:
%config InteractiveShellApp.pylab_import_all = False
%matplotlib inline
%pylab inline
%reload_ext autoreload
%autoreload 2

import sys
import os
import fnmatch

import numpy as np
import math
import matplotlib.pyplot as plt
import pandas as pd
from pandas import DataFrame, Series
import seaborn as sns
import glob
from datetime import datetime

propsFn='props.csv'
props=pd.read_csv(propsFn, header=None, index_col=0, squeeze=True,delim_whitespace=True).to_dict()

base=props['BaseDir']
expFile=props['allExpFn']

RawDataDir = os.path.join(base,props['RawDataDir'])+'\\'
ProcessingDir = os.path.join(base,props['ProcessingDir'])+'\\Fig4I_ageSorting\\'
outputDir = os.path.join(base,props['outputDir'])+'\\'

if not os.path.isdir(ProcessingDir):
    os.makedirs(ProcessingDir)
if not os.path.isdir(outputDir):
    os.makedirs(outputDir)

os.chdir('..\\')
import functions.matrixUtilities_joh as mu
import matplotlib.pyplot as plt
import models.experiment as xp
import models.experiment_set as es
import functions.paperFigureProps as pfp
pfp.paper()
inToCm=2.54


In [None]:
info=pd.read_csv(expFile, sep=',')
info=info[info.stimulusProtocol=='pc']
info

In [None]:
# collect meta information and save to new csv file for batch processing

aviPath=[]
posPath=[]
PLPath=[]
expTime = []
    
for index,row in info.iterrows():
    startDir=RawDataDir+row.path+'\\'
    #startDir='D:\\data\\b\\2017\\'+row.path+'\\'
    #if not os.path.isdir(startDir):
    #    startDir='E:\\b\\2017\\'+row.path+'\\'
        
    posPath.append(glob.glob(startDir+'PositionTxt*.txt')[0])
    PLPath.append(glob.glob(startDir+'PL*.csv')[0])
    
    head, tail = os.path.split(posPath[-1])
    currTime=datetime.strptime(tail[-23:-4], '%Y-%m-%dT%H_%M_%S')
    expTime.append(currTime)
    
info['txtPath']=posPath
info['pairList']=PLPath

info['epiDur'] = 5      # duration of individual episodes (default: 5 minutes)
info['episodes'] = -1   # number of episodes to process: -1 to load all episodes (default: -1)
info['inDish'] = 10#np.arange(len(posPath))*120     # time in dish before experiments started (default: 10)
info['arenaDiameter_mm'] = 100 # arena diameter (default: 100 mm)
info['minShift'] = 60 # minimum number of seconds to shift for control IAD
info['episodePLcode'] = 1 # flag if first two characters of episode name encode animal pair matrix (default: 0)
info['recomputeAnimalSize'] = 1 # flag to compute animals size from avi file (takes time, default: 1)
info['SaveNeighborhoodMaps'] = 0 # flag to save neighborhood maps for subsequent analysis (takes time, default: 1)
info['computeLeadership'] = 0 # flag to compute leadership index (takes time, default: 1)
info['ComputeBouts'] = 0 # flag to compute swim bout frequency (takes time, default: 1)
info['set'] = np.arange(len(posPath))   # experiment set: can label groups of experiments (default: 0)
info['ProcessingDir']=ProcessingDir
info['outputDir']=outputDir
info['allowEpisodeSwitch']=1

info['expTime']=expTime

csvFile=os.path.join(ProcessingDir,'Fig4I_ageSorting.csv')
info.to_csv(csvFile,encoding='utf-8')
info

In [None]:
def readExperiment(keepData=False):
    """Construct the experiment set described by the batch file ``csvFile``.

    Parameters
    ----------
    keepData : bool
        When true, the constructed ``es.experiment_set`` is returned;
        otherwise the object is discarded and the integer 1 is returned.
    """
    experimentSet = es.experiment_set(csvFile=csvFile)
    return experimentSet if keepData else 1

expSet=readExperiment(keepData=True)

## Structure of the summary file
### The `episode` column indicates the type of interaction

`00pairs` is the quartet interaction.

For each `00pairs` episode there are 3 rows per animal, one for its pairing with each of the other three animals, identified by the columns `animalIndex` and `CurrentPartner`.

`01pairs`–`03pairs` are pair-wise (2-way) interactions.

For each of these there is only one row per animal, since each animal has a single partner.


In [None]:
# One shoaling-index summary per experiment, located by the position file
# name without its 4-character extension ('.txt').
csvPath = []
for f in [mu.splitall(x)[-1][:-4] for x in info.txtPath]:
    csvPath.append(sorted(glob.glob(ProcessingDir+f+'*siSummary*.csv'))[0])

# Each experiment holds 15 animals (3 quartets = animals 0-11, plus 12-14);
# offsetting ids by 15 per experiment makes them unique across experiments.
animalsPerExp=15

frames=[]
for i,fn in enumerate(csvPath):
    print(fn)
    tmp=pd.read_csv(fn,index_col=0,sep=',')
    # item assignment creates the column if needed; attribute assignment
    # would only set a Python attribute when the column does not exist
    tmp['animalSet']=i
    tmp['animalIndexCont']=tmp.animalIndex+i*animalsPerExp
    tmp['CurrentPartnerCont']=tmp.CurrentPartner+i*animalsPerExp
    frames.append(tmp)
# a single concat avoids the quadratic cost of concatenating inside the loop
df=pd.concat(frames)
df['episode']=[x.strip().replace('_','') for x in df['episode']]

print('df shape',df.shape)

# t2: timestamp moved to 1 January (time of day kept);
# t3: time of day in hours since midnight, independent of the recording year
stamps=[pd.to_datetime(x) for x in df.time]
t2=[x.replace(day=1,month=1) for x in stamps]
t3=[(x-x.normalize())/pd.Timedelta('1 hour') for x in stamps]
df['t2']=t2
df['t3']=t3
df.head()

## Shoaling preferences during 4-way interactions

In [None]:
#first, analyze 4-way interactions (only use animals 0-11, 12-14 are not a group of 4)
iQuad=(df.episode=='00pairs')&(df.animalIndex<12)&(df.inDishTime<355)
dfq=df[iQuad]

#adjust 'animalIndex' and 'currentPartner' such that animalIndex is always the lower number
#this is legit because 0-1 designates the same pair as 1-0
aa=dfq[['animalIndexCont','CurrentPartnerCont']].values
dfq.loc[:,['animalIndexCont','CurrentPartnerCont']]=np.sort(aa,axis=1)


In [None]:
#get the data for 
dfqSI=dfq.groupby(['animalIndexCont','CurrentPartnerCont','inDishTime']).mean().si.reset_index()

In [None]:
# during 4-way interactions, calculate the average of the 4 possible small-large pairings
comboGroups=np.array([0,1,1,1,1,2]) #sm-sm, sm-lg, lg-lg

#cg 0: small small
#cg 1: small large
#cg 2: large large
dfqSI['cg']=np.repeat(np.tile(comboGroups,9),dfq.inDishTime.unique().shape[0])


dfqSI['gi']=np.repeat(np.arange(9),6*dfq.inDishTime.unique().shape[0])

dfqSI.head()

In [None]:
dfqSI=dfqSI.groupby(['animalIndexCont','CurrentPartnerCont','inDishTime','cg','gi']).mean().si.reset_index()

In [None]:
sns.tsplot(data=dfqSI.groupby(['inDishTime','cg','gi']).mean().si.reset_index(),time='inDishTime',unit='gi',value='si',condition='cg')

## Shoaling preferences during pair-wise interactions: 01pairs (small-small and large-large) and 02pairs/03pairs (small-large combinations)

In [None]:
iSameDub=(df.episode.isin(['01pairs']))&(df.animalIndex.isin(np.arange(0,12,2)))&(df.inDishTime<355)
dfsd=df[iSameDub]
dfsdSI=dfsd.groupby(['animalIndexCont','CurrentPartnerCont','age','inDishTime']).mean().si.reset_index()
dfsdSI['cg']=3
dfsdSI['gi']=np.repeat(np.repeat(np.arange(9),2),dfsdSI.inDishTime.unique().shape[0])

ix=(dfsdSI.age>20)
dfsdSI.loc[ix,'cg']=5
dfsdSI=dfsdSI.drop('age',axis=1)
sns.tsplot(data=dfsdSI,time='inDishTime',unit='animalIndexCont',value='si',condition='cg')
plt.ylim([-.1,.5])

In [None]:
iCrossDub=(df.episode.isin(['03pairs','02pairs']))&(df.animalIndex<12)&(df.inDishTime<355)


dfcd=df[iCrossDub]

#adjust 'animalIndex' and 'currentPartner' such that animalIndex is always the lower number
#this is legit because 0-1 designates the same pair as 1-0
aa=dfcd[['animalIndexCont','CurrentPartnerCont']].values
dfcd.loc[:,['animalIndexCont','CurrentPartnerCont']]=np.sort(aa,axis=1)

dfcdSI=dfcd.groupby(['animalIndexCont','CurrentPartnerCont','inDishTime']).mean().si.reset_index()
dfcdSI['cg']=4

dfcdSI['gi']=np.repeat(np.repeat(np.arange(9),4),dfcdSI.inDishTime.unique().shape[0]/2)
sns.tsplot(data=dfcdSI,time='inDishTime',unit='animalIndexCont',value='si',condition='cg')
plt.ylim([-.1,.5])

In [None]:
dfAllT=pd.concat([dfqSI,dfsdSI,dfcdSI],sort=True)

In [None]:
sns.tsplot(data=dfAllT.groupby(['inDishTime','gi','cg']).mean().reset_index(),time='inDishTime',unit='gi',value='si',condition='cg')


In [None]:
tidx=(dfAllT.inDishTime>60)&(dfAllT.inDishTime<350)
dfAll=dfAllT[tidx].groupby(['cg','gi']).mean().si.reset_index()
sns.swarmplot(data=dfAll,x='cg',y='si')
plt.axhline(0)

In [None]:
fig, axes = plt.subplots(figsize=(5, 7))
sns.pointplot(data=dfAll[dfAll.cg<3],x='cg',y='si',hue='gi',ax=axes)
axes.set_xticklabels(['sm-sm','sm-lg','lg-lg'])
axes.set_xlabel('')
axes.set_ylabel('attraction')
sns.despine()

In [None]:
quadOnly=dfAll[dfAll.cg<3]
quadOnly.head()

In [None]:
q=dfAll.pivot_table(index='gi',columns='cg',values='si').reset_index()
q

In [None]:
q['6']=(q[q.columns[1]]+q[q.columns[3]])/2.
q['7']=(q[q.columns[4]]+q[q.columns[6]])/2.
qq=q.drop('gi',axis=1).stack().reset_index()
qq.columns=['gi','cg','si']
qq.head()

In [None]:
fig, axes = plt.subplots(figsize=(4, 7))
#sns.boxplot(data=qq,x='cg',y='si',notch=True)

#sns.swarmplot(data=qq,x='cg',y='si',ax=axes,zorder=1)
sns.pointplot(data=qq,x='cg',y='si',ax=axes,zorder=1,hue='gi')
              
sns.pointplot(data=qq,x='cg',y='si',ax=axes,join=False,
              palette=['k'],
              zorder=100,
              legend=False,
              ci='sd',
              errwidth=1,
              capsize=.3)

In [None]:
fig, axes = plt.subplots(figsize=(4, 7))
sns.boxplot(data=qq,x='cg',y='si',notch=True)

In [None]:
q.columns

In [None]:
from scipy import stats
a=q[q.columns[2]]
b=q[q.columns[7]]
c=q[q.columns[3]]
print(a,b)
print('related samples',stats.ttest_rel(a,b),stats.ttest_rel(b,c))
print('independent samples',stats.ttest_ind(a,b),stats.ttest_ind(b,c))

In [None]:
dyadOnly=dfAll[dfAll.cg>2]
dyadOnly.head()

In [None]:
fig, axes = plt.subplots(figsize=(5, 7))
sns.pointplot(data=qq,x='cg',y='si',hue='gi',ax=axes)
axes.set_xticklabels(['sm-sm','sm-lg','lg-lg'])
axes.set_xlabel('')
axes.set_ylabel('attraction')
sns.despine()

In [None]:
fig, axes = plt.subplots(figsize=(5, 7))
sns.boxplot(data=qq,x='cg',y='si',ax=axes)
sns.pointplot(data=qq,x='cg',y='si',ax=axes)
axes.set_xticklabels(['sm-sm','sm-lg','lg-lg'])
axes.set_xlabel('')
axes.set_ylabel('attraction')
sns.despine()



In [None]:
pfp.paper()
inToCm=2.54
plt.figure(figsize=(9/inToCm,4.5/inToCm))
ax = plt.gca()

col=['gray','k',[0,0.6,0],[0,.35,0]]
sns.set_palette(col)
# cg indexes plotGroups below, so it must be integer (the model columns
# were added as the strings '6' and '7')
qq.cg=qq.cg.astype('int')
# legend group of each cg: 0-2 quartet, 3-5 pair, 6 quartet model, 7 pair model
lab=np.array(['Quartet','Pair','Model','Model2'])
plotGroups=np.array([0,0,0,1,1,1,2,3])
hueGroup=lab[plotGroups[qq.cg.values]]
# x order: quartet s-s, s-L, s-L model, L-L | pair s-s, s-L, s-L model, L-L
cgOrder=[0,1,6,2,3,4,7,5]

# thin gray line per group (gi)
sns.pointplot(data=qq,x='cg',y='si',ax=ax,hue='gi',zorder=0,order=cgOrder,
             palette=['gray'],
             scale=0.2)
# individual values, coloured by legend group
sns.stripplot(data=qq,x='cg',y='si',ax=ax,hue=hueGroup,zorder=1,order=cgOrder)

# mean +- SD error bars in black
sns.pointplot(data=qq,x='cg',y='si',ax=ax,join=False,hue=hueGroup,
              palette=['k'],
              zorder=1000,
              legend=False,
              ci='sd',
              errwidth=1,
              capsize=.3,order=cgOrder)

# wide red '_' marker at the mean
sns.pointplot(data=qq,x='cg',y='si',ax=ax,join=False,hue=hueGroup,
              palette=['r'],
              zorder=1000,
              ci=None,
              legend=False,
              markers='_',
              scale=3,order=cgOrder)


ax.set_xticklabels(['s-s','s-L','s-L \n Model','L-L','s-s','s-L','s-L \n Model','L-L'])
ax.set_yticks(np.arange(0,.5,.2))
ax.set_xlabel('')
ax.set_ylabel('Attraction')

# Keep only three of the collected legend entries (indices 9, 11, 10, chosen
# by hand for this figure; re-check them if the plotting calls change).
handles, labels = ax.get_legend_handles_labels()
labels=np.array(labels)
handles=np.array(handles)
li=np.array([9,11,10])
leg = plt.legend(handles[li], labels[li], ncol=3, loc='upper center', borderaxespad=0.)

# separator between quartet (left) and pair (right) conditions
ax.axvline(3.5,ls=':',color='gray')

# hard-coded significance brackets: observed vs model s-L attraction for
# quartets (x 1-2) and pairs (x 5-6)
barOffset=0.025
y, h, barColor = .32 + barOffset, barOffset, 'k'
for x1, x2 in [(1, 2), (5, 6)]:
    plt.plot([x1, x1, x2, x2], [y, y+h, y+h, y], lw=1.5, c=barColor)
    plt.text((x1+x2)*.5, y+h, "***", ha='center', va='bottom', color=barColor,size=10)

sns.despine()
figPath=os.path.join(outputDir,'4I_PairChoice.svg')
plt.savefig(figPath)

In [None]:
def normGroup(x):
    """Divide a group's 'si' values by the group's second 'si' entry.

    Applied per group (gi) to rows ordered by cg, where the second entry is
    the small-large attraction (cg 1). Returns a Series aligned with ``x``.
    """
    attraction = x['si']
    return attraction / attraction.iloc[1]

# Normalize each group's quartet attractions (cg 0-2) by its small-large
# attraction via normGroup; result columns: gi, original row index, si.
dfNorm=dfAll[dfAll.cg<3].groupby(['gi']).apply(normGroup).reset_index()
# cg is re-attached positionally: assumes rows come out grouped by gi with
# cg 0,1,2 in order within each group -- TODO confirm if the grouping changes
dfNorm['cg']=np.tile(np.arange(3),dfAll[dfAll.cg<3].gi.unique().shape[0])
dfNorm.head()

In [None]:
fig, axes = plt.subplots(figsize=(5, 7))
colors=np.repeat('k',dfAll.gi.unique().shape[0])   # NOTE(review): not used in this cell
# one line per group (hue=gi) of relative attraction
sns.pointplot(data=dfNorm,x='cg',y='si',hue='gi',ax=axes)
axes.set_xticklabels(['sm-sm','sm-lg','lg-lg'])
axes.set_xlabel('')
axes.set_ylabel('relative attraction')
axes.set_ylim([-1,5])
plt.legend(title='quartet number',loc='best')
sns.despine()

In [None]:
pfp.paper()
inToCm=2.54
# 4.5 x 4.5 cm panel of relative (normalized) quartet attraction
fig, axes = plt.subplots(figsize=(4.5/inToCm,4.5/inToCm))
#sns.boxplot(data=dfNorm,x='cg',y='si',ax=axes,color='gray',notch=True)

# one line per group (hue=gi), pushed behind the summary marker
sns.pointplot(data=dfNorm,x='cg',y='si',hue='gi',ax=axes,linewidth=0.5,scale=0.5,zorder=-100)
plt.setp(axes.collections, sizes=[50],zorder=-100)
plt.setp(axes.lines, zorder=-100)

# summary across groups: median (estimator=np.median) with SD error bars
# (ci='sd'), drawn as a wide black '_' marker
sns.pointplot(data=dfNorm,x='cg',y='si',
                color='k',
                ax=axes,
                estimator=np.median,
                ci ='sd',
                zorder=100,
                edgecolor='k',
                join=False,
                markers=['_'],
                scale=4,
             errwidth=1)

#sns.swarmplot(data=dfNorm,x='cg',y='si',ax=axes,size=10,edgecolor='k',linewidth=1,color='k')
axes.set_xticklabels(['sm-sm','sm-lg','lg-lg'])
axes.set_xlabel('')
axes.set_ylabel('relative attraction')
axes.set_ylim([0,4])
#plt.legend(title='quartet number',loc='best')
# hide the per-group legend created by hue='gi'
axes.legend_.remove()
sns.despine()

In [None]:
# NetLogo model output for 4 agents (two types with social parameters
# t1-Social / t2-Social); skiprows=6 skips the file's metadata header lines.
csvFileNetLogo=os.path.join(RawDataDir,'20180316_netLogo4Animals','Shoaling_4an 4An_socia_Paper-table.csv')

df_NL=pd.read_csv(csvFileNetLogo,skiprows=6)
df_NL=df_NL.groupby(['[run number]','t1-Social','t2-Social']).mean().reset_index() #drop other columns

#calculate shoaling index, using IAD when both agents have Ps<0.01
idx=(df_NL['t1-Social']<0.01) & (df_NL['t2-Social']<0.01)
IADs=df_NL[idx].IAD11
iadNoAttraction=IADs.mean()   # reference IAD of (nearly) non-social agents
print('mean IAD without attraction: ',iadNoAttraction)
# SI = relative reduction of inter-agent distance vs. the reference, for
# type1-type1 (11), type1-type2 (12) and type2-type2 (22) pairings
for pairing in ['11','12','22']:
    df_NL['SI'+pairing]=(iadNoAttraction-df_NL['IAD'+pairing])/iadNoAttraction
df_NL.head()

In [None]:
# Model SI of type1-type1 pairings (SI11) versus the type-1 social parameter
# (x axis), one line per type-2 social parameter (hue), averaged over runs.
pfp.paper()
inToCm=2.54
fig, ax = plt.subplots(figsize=(10/inToCm,10/inToCm))

sns.pointplot(data=df_NL,x='t1-Social',hue='t2-Social',y='SI11',ax=ax,estimator=np.mean)

In [None]:
# same as above for type1-type2 pairings (SI12)
pfp.paper()
inToCm=2.54
fig, ax = plt.subplots(figsize=(10/inToCm,10/inToCm))

sns.pointplot(data=df_NL,x='t1-Social',hue='t2-Social',y='SI12',ax=ax,estimator=np.mean)

In [None]:
# same as above for type2-type2 pairings (SI22)
pfp.paper()
inToCm=2.54
fig, ax = plt.subplots(figsize=(10/inToCm,10/inToCm))

sns.pointplot(data=df_NL,x='t1-Social',hue='t2-Social',y='SI22',ax=ax,estimator=np.mean)

In [None]:
pfp.paper()
sns.set_palette('viridis',6)

# social parameter of the type-2 agents held fixed in this panel
psType2=.10

fig, ax = plt.subplots(figsize=(3/inToCm,4.5/inToCm))
# one line per type-1 social parameter (0, 0.05, ..., 0.4): mean model
# attraction for type1-type1, type1-type2 and type2-type2 pairings
for i in range(9):
    Sa1=i/20.
    # tolerant comparison: parameter values parsed from the csv need not be
    # bit-identical to i/20.
    idx=np.isclose(df_NL['t1-Social'],Sa1) & np.isclose(df_NL['t2-Social'],psType2)
    print(Sa1, np.sum(idx))
    if np.sum(idx)>0:

        a=df_NL[idx].SI11.mean()
        b=df_NL[idx].SI12.mean()
        c=df_NL[idx].SI22.mean()
        plt.plot([a,b,c],'.-',label=Sa1)



plt.xticks([0,1,2],['t1-t1','t1-t2','t2-t2'])
plt.xlim([-.1,2.3])
plt.ylim([-.1,1])
sns.despine()
plt.ylabel('Model attraction')
# The title is derived from the filter value so it describes the plotted data.
# NOTE(review): the title was hard-coded as 'Ps type 2: 0.15' while the data
# were selected at t2-Social == 0.10 -- confirm which value the figure needs.
plt.title('Ps type 2: {:g}'.format(psType2))
plt.legend(title='Ps type 1:',ncol=1,
          bbox_to_anchor=(0.75, .9, 1., .102),
          handletextpad=0.1);
figPath=os.path.join(outputDir,'4H_Model_QuartetCombinations.svg')
plt.savefig(figPath)

from shutil import copy2

def splitall(path):
    """Split ``path`` into all of its components.

    With POSIX separators: 'a/b/c' -> ['a', 'b', 'c'], '/a/b' -> ['/', 'a', 'b'],
    and a trailing separator yields a final empty component ('a/' -> ['a', '']).
    """
    components = []
    remaining = path
    while True:
        head, tail = os.path.split(remaining)
        if head == remaining:      # root of an absolute path (or empty path)
            components.append(head)
            break
        if tail == remaining:      # first component of a relative path
            components.append(tail)
            break
        components.append(tail)
        remaining = head
    components.reverse()
    return components



# Copy the raw files of every experiment into the data-release folder,
# keeping the trailing part of each source directory path.
releaseDir="e:\\b\\LarschAndBaier2018\\RawData\\"

for _,row in info.iterrows():
    head = os.path.dirname(row.txtPath)

    copyList=[]
    for pattern in ['ROI*.csv','PositionTxt*.txt','PL*.csv','*anSize.csv']:
        fullPattern=os.path.join(head,pattern)
        matches=sorted(glob.glob(fullPattern))
        if not matches:
            raise IOError('no file matches '+fullPattern)
        copyList.append(matches[0])

    for f in copyList:
        print(f)
        # number of leading path components to drop: sources on drive E sit
        # one level less deep; drive letters are case-insensitive on Windows
        if f[0].upper()=='E':
            keepSlash=3
        else:
            keepSlash=4
        # joining onto releaseDir also works when no sub-path components remain
        toDirectory=os.path.join(releaseDir,*(splitall(f)[keepSlash:-1]+['']))
        if not os.path.isdir(toDirectory):
            os.makedirs(toDirectory)

        copy2(f, toDirectory)
