# Experiment analysis: Attraction towards dots
## can group animals by age or treatment

In [None]:
import os
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import scipy.stats as stats
import glob
from datetime import datetime

#
# load custom modules for social behavior analysis
# enter path to your local repository here
#os.chdir(r'C:\Users\Alex Tallafuss\Documents\social_repository')
# The working directory is switched before the project-local `models` and
# `functions` packages are imported below (presumably so that those imports
# resolve relative to the repository root - adjust the path to your checkout).
os.chdir(r'C:\Users\johannes\Dropbox\python\zFishBehavior\dishGroupBehavior')
# ----------------------------


import models.experiment as xp
import models.experiment_set as es
import functions.matrixUtilities_joh as mu
import functions.paperFigureProps as pfp

#
# notebook configuration
%config InteractiveShellApp.pylab_import_all = False
%matplotlib inline
%pylab inline
%reload_ext autoreload
%autoreload 2

#
# custom paper style plotting
pfp.paper()

Pre-analyze all experiments only if necessary; this takes a couple of minutes. Experiment summary CSV files are saved to disk.

In [None]:
  
# ENTER PATH to allExp.xlsx and allAn.xlsx HERE
# (both workbooks are expected directly inside `base`)

base = os.path.normpath(r'e:\testData\social\\')
expFile, anFile = (
    os.path.join(base, workbook) for workbook in ('AllExp.xlsx', 'Allan.xlsx')
)

# ENTER DATA PATH HERE
# (raw data, processing results and figure output all share one directory)

RawDataDir = os.path.normpath(r'e:\testData\social\\')
ProcessingDir = outputDir = RawDataDir
# -------------------------------------

In [None]:
# echo the configured base directory as a quick sanity check
base

In [None]:
# Load the experiment table and keep only the rows whose stimulusProtocol is 'a'.
# .copy() turns the filtered selection into an independent frame, so the
# columns that are added to `info` further down in this notebook do not
# trigger pandas' SettingWithCopyWarning.
info=pd.read_excel(expFile)
info=info[info.stimulusProtocol=='a'].copy()
info.head()

In [None]:
# Load the animal table; its `line` and `bd` columns are looked up per animal
# in the metadata cell below.
infoAn=pd.read_excel(anFile)
infoAn.head()

In [None]:
# collect meta information and save to new csv file for batch processing
#
# For every experiment row in `info`: locate its position and pair-list files,
# parse the recording time from the position file name, and look up `line`
# and `bd` of each animal of that experiment in `infoAn`.

aviPath=[]      # not filled in this cell; kept because it was defined here
posPath=[]      # first 'PositionTxt*.txt' match per experiment
PLPath=[]       # first 'PL*.txt' match per experiment
expTime = []    # datetime parsed from the position file name
condition=[]    # infoAn.line of every animal, concatenated over experiments
birthDayAll=[]  # space-joined infoAn.bd values, one string per experiment


def _firstMatch(folder, pattern):
    """Return the first file in *folder* matching *pattern*, or fail with a clear message."""
    hits = glob.glob(os.path.join(folder, pattern))
    if not hits:
        raise FileNotFoundError('no file matching %r found in %r' % (pattern, folder))
    return hits[0]


def _parseAnimalNumbers(spec):
    """Turn an anNr cell such as '1:15' (inclusive range) or '3 7 9' into an integer array."""
    # a cell holding a single number may arrive as int rather than str
    spec = str(spec)
    if ':' in spec:
        first, last = spec.split(sep=':')
        return np.arange(int(first), int(last) + 1)
    # dtype=int is required: a string array cannot be shifted by 1 below
    return np.array(spec.split(), dtype=int)


for index,row in info.iterrows():
    startDir=os.path.join(RawDataDir,row.path)

    posPath.append(_firstMatch(startDir,'PositionTxt*.txt'))
    PLPath.append(_firstMatch(startDir,'PL*.txt'))

    # the 19 characters before the file extension hold the time stamp
    head, tail = os.path.split(posPath[-1])
    currTime=datetime.strptime(tail[-23:-4], '%Y-%m-%dT%H_%M_%S')
    expTime.append(currTime)

    # animal numbers are shifted by one to index rows of infoAn
    anNrs=_parseAnimalNumbers(row.anNr)
    anIDs=anNrs-1

    gt=infoAn.line.values[anIDs]
    bd=infoAn.bd.values[anIDs]
    condition.extend(list(gt))
    # str() leaves text birthdays untouched and also tolerates non-text cells
    birthDayAll.append(' '.join(str(day) for day in bd))

    
# Batch-processing settings; each entry becomes one column of `info`,
# added in exactly this order.
batchSettings = {
    'txtPath': posPath,
    'pairList': PLPath,
    'birthDayAll': birthDayAll,
    'epiDur': 5,                    # duration of individual episodes (default: 5 minutes)
    'episodes': -1,                 # number of episodes to process: -1 to load all episodes (default: -1)
    'inDish': 10,                   # time in dish before experiments started (default: 10)
    'arenaDiameter_mm': 100,        # arena diameter (default: 100 mm)
    'minShift': 60,                 # minimum number of seconds to shift for control IAD
    'episodePLcode': 0,             # flag if first two characters of episode name encode animal pair matrix (default: 0)
    'recomputeAnimalSize': 0,       # flag to compute animals size from avi file (takes time, default: 1)
    'SaveNeighborhoodMaps': 0,      # flag to save neighborhood maps for subsequent analysis (takes time, default: 1)
    'computeLeadership': 0,         # flag to compute leadership index (takes time, default: 1)
    'ComputeBouts': 0,              # flag to compute swim bout frequency (takes time, default: 1)
    'set': np.arange(len(posPath)), # experiment set: can label groups of experiments (default: 0)
    'ProcessingDir': ProcessingDir,
    'outputDir': outputDir,
    'expTime': expTime,
}
for column, setting in batchSettings.items():
    info[column] = setting

# per-animal labels as an array, so they can be indexed by animalIndex later
condition = np.array(condition)

# write the settings table that the batch processing reads
csvFile = os.path.join(ProcessingDir, 'processingSettings.csv')
info.to_csv(csvFile, encoding='utf-8')
info

In [None]:
# set to 0 to skip re-reading the experiments
rereadData = 1
if rereadData:
    def readExperiment(keepData=False):
        """Build the experiment set from csvFile; hand it back only when keepData is set, otherwise return 1."""
        loadedSet = es.experiment_set(csvFile=csvFile)
        return loadedSet if keepData else 1

    expSet = readExperiment(keepData=True)

In [None]:
# Locate the '*siSummary*.csv' file of every experiment and stack them into
# one long table `df`.

# animalIndex offset between consecutive experiments
# (hard-coded 15; presumably the number of animals per experiment - confirm)
ANIMALS_PER_SET = 15

csvPath = []
for f in [mu.splitall(x)[-1][:-4] for x in info.txtPath]:
    pattern = os.path.join(ProcessingDir, f) + '*siSummary*.csv'
    matches = glob.glob(pattern)
    if not matches:
        raise FileNotFoundError('no summary csv found for pattern ' + pattern)
    csvPath.append(matches[0])

frames = []
for i, fn in enumerate(csvPath):
    print(fn)
    tmp = pd.read_csv(fn, index_col=0, sep=',')
    # item assignment creates the column if it is missing; attribute
    # assignment (tmp.animalSet = i) would not
    tmp['animalSet'] = i
    # make animal indices unique across experiments
    tmp['animalIndex'] = tmp['animalIndex'] + i * ANIMALS_PER_SET
    frames.append(tmp)

# concatenate once instead of growing the frame inside the loop
df = pd.concat(frames) if frames else pd.DataFrame()
df['episode'] = [x.strip().replace('_', '') for x in df['episode']]

print('df shape', df.shape)

In [None]:
# Derive two time columns from the text column df.time:
#   t2 - the time stamp with month and day both forced to 1
#   t3 - t2 expressed in hours relative to January 1st of the year of the first row
d = df.time
r = datetime(int(df.time.values[0][:4]), 1, 1)

t2 = []
for stamp in df.time:
    t2.append(pd.to_datetime(stamp).replace(day=1, month=1))

oneHour = pd.Timedelta('1 hour')
t3 = [(stamp - r) / oneHour for stamp in t2]

df['t2'] = t2
df['t3'] = t3
df.head()

## Habituation or Fatigue within 20 hours?

Plot shoaling index during closed loop skype episodes over time.

In [None]:
# Time course of 'si' against 't3' (hours, as computed from df.time above):
# one series per episode type, nan-aware mean over animals with CI bars.
# NOTE(review): sns.tsplot is not available in recent seaborn releases -
# confirm the installed version still provides it before running this cell.
sns.tsplot(data=df, time="t3",value="si",unit="animalIndex",condition="episode",estimator=np.nanmean,interpolate=False,err_style="ci_bars");
plt.xlim([0,24])
# dotted reference line at si = 0
plt.axhline(0,ls=':',color='gray')

# Plot individual 5 minute segments for all animals
## here, using 'genotype' to sub-divide the data in each category

In [None]:
# Label every row with the entry of the `condition` array (built in the
# metadata cell) found at position animalIndex.
df['condition']=condition[df.animalIndex]

In [None]:
# Limit analysis to a time window (typically ignore first 45 minutes and times later than 350 minutes)
# Both bounds are exclusive.
tStart=45
tEnd=350
idx=(df['inDishTime']>tStart) & (df['inDishTime']<tEnd)
dfDR=df[idx]

In [None]:
# One point per row of the time-windowed table: 'si' for each episode type,
# with the conditions drawn side by side (dodge).
sns.swarmplot(data=dfDR,
              x='episode',
              y='si',
              hue='condition',
              dodge=1)
# dotted reference line at si = 0
plt.axhline(0,ls=':')

# average episodes over animals first
### generally, using n = number of animals whenever possible 

In [None]:
# Collapse the time-windowed table to one row per (episode, animal, condition).
# numeric_only=True restricts the averaging to numeric columns: the table also
# carries text columns (e.g. 'time'), which recent pandas versions refuse to
# average instead of silently dropping them.
dfAnimalAverage=dfDR.groupby(['episode','animalIndex','condition'],sort=True).mean(numeric_only=True).reset_index()
# one point per animal and episode type, conditions side by side
sns.swarmplot(data=dfAnimalAverage,
              x='episode',
              y='si',
              hue='condition',
              dodge=1)
# dotted reference line at si = 0
plt.axhline(0,ls=':')

In [None]:
# Point estimate of the per-animal 'si' averages for each episode type,
# one line per condition.
sns.pointplot(data=dfAnimalAverage,
              x='episode',
              y='si',
              hue='condition')
# dotted reference line at si = 0
plt.axhline(0,ls=':')


In [None]:
#individual animals
# mean 'si' per (episode, animal); unstacking gives one line per animal
siPerAnimal=dfDR.groupby(['episode','animalIndex'],sort=True)['si'].mean()
siPerAnimal.unstack().plot()

In [None]:
# Overlay on the same axes: individual animal averages (swarm) ...
sns.swarmplot(data=dfAnimalAverage,
              x='episode',
              y='si',
              hue='condition',
              dodge=1)

# ... and the per-condition point estimates
sns.pointplot(data=dfAnimalAverage,
              x='episode',
              y='si',
              hue='condition')

# dotted reference line at si = 0
plt.axhline(0,ls=':')

# compare speed between groups

In [None]:
# Per-animal average of 'avgSpeed' for each episode type, one line per condition.
sns.pointplot(data=dfAnimalAverage,
              x='episode',
              y='avgSpeed',
              hue='condition')
# dotted reference line at 0
plt.axhline(0,ls=':')

# optional axis label (disabled)
#plt.ylabel('average Speed [mm/sec]')

# thigmotaxis index

In [None]:
# Per-animal average of 'thigmoIndex' for each episode type, one line per condition.
sns.pointplot(data=dfAnimalAverage,
              x='episode',
              y='thigmoIndex',
              hue='condition')
# dotted reference line at 0
plt.axhline(0,ls=':')

# bout duration

In [None]:
# Median (rather than mean) of the per-animal 'boutDur' averages for each
# episode type, one line per condition; no confidence interval is drawn.
sns.pointplot(data=dfAnimalAverage,
              x='episode',
              y='boutDur',
              hue='condition',
              estimator=np.median, # bout duration can be heavily influenced by outliers
             ci=None)
# dotted reference line at 0
plt.axhline(0,ls=':')

# correlate size with attraction

In [None]:
# this only works if size was calculated from video
episodeNames=dfAnimalAverage.episode.unique()
ep=episodeNames[3] #which episode data to plot
print('using ',ep,' from: ',episodeNames)
# keep rows of the chosen episode that have a finite animal size
hasSize=np.isfinite(dfAnimalAverage.anSize)
inEpisode=dfAnimalAverage.episode==ep
sns.pairplot(dfAnimalAverage[hasSize&inEpisode],vars=["anSize", "si"])

# Plot average neighborhood maps for all animals

In [None]:
# get all the maps from the expSet data structure
# (this data is also stored in a .npy file)
# One map per entry of expSet.experiments[0].pair (first animal of each entry),
# stacked along axis 0.
# NOTE(review): only experiments[0] is read, while df may hold rows from
# several experiments - confirm the row order still matches when more than
# one experiment is loaded.
nmatAll=np.array([y.animals[0].ts.neighborMat() for y in expSet.experiments[0].pair])

In [None]:
# episode types and animal indices present in df
levels=df['episode'].unique()
ans=df['animalIndex'].unique()
# result buffer: (animal, episode type, map rows, map columns)
avg=np.zeros((len(ans),len(levels),nmatAll.shape[1],nmatAll.shape[2]))


In [None]:
# Average the neighborhood maps of every animal within every episode type,
# using only rows inside the analysis time window (idx).
# The animal index doubles as the row position in `avg`.
for an in ans:
    isAnimal=df['animalIndex']==an
    for i,level in enumerate(levels):
        ix=np.where((df['episode']==level) & isAnimal & idx)[0]
        avg[an,i,:,:]=nmatAll[ix,:,:].mean(axis=0)


In [None]:
# One panel per (animal, episode type) showing the averaged neighborhood map.
# The grid is sized from the data instead of a fixed 15 x 7 layout, so it also
# works for other animal / episode counts; squeeze=False keeps `axes` 2-D even
# for a single row or column. 15 animals still give the original 10 x 30 figure.
# NOTE(review): `treatName` and `treatment` are not defined in the visible
# cells - they must exist in the session before this cell is run.
nRows=max(len(ans),1)
nCols=max(len(levels),1)
fig, axes = plt.subplots(nrows=nRows, ncols=nCols, sharex='col', sharey=True,
                         figsize=(10, 2*nRows), squeeze=False)
# common color limit: 95th percentile over all averaged maps, ignoring NaN
m=np.nanpercentile(avg,95)
trLab=treatName
for an in ans:
    for i in range(len(levels)):
        axes[an,i].imshow(avg[an,i,:,:],clim=[0,m],extent=[-31,31,-31,31])
        axes[an,i].set_title('a:'+str(an)+trLab[treatment[an]]+ 's:'+levels[i][-2:],fontsize=10)

In [None]:
# episode types present in df
levels=df['episode'].unique()
# NOTE(review): `treatName` is not defined in the visible cells - it must
# exist in the session before this cell is run.
treat=treatName
# result buffer: (treatment, episode type, map rows, map columns)
avgT=np.zeros((len(treat),len(levels),nmatAll.shape[1],nmatAll.shape[2]))


In [None]:
# Average the neighborhood maps of every treatment within every episode type,
# using only rows inside the analysis time window (idx).
for an in range(len(treat)):
    inTreatment=df['treatment']==treat[an]
    for i,level in enumerate(levels):
        ix=np.where((df['episode']==level) & inTreatment & idx)[0]
        avgT[an,i,:,:]=nmatAll[ix,:,:].mean(axis=0)
        

In [None]:
from mpl_toolkits.axes_grid1 import AxesGrid
import matplotlib
def shiftedColorMap(cmap, start=0, midpoint=0.5, stop=1.0, name='shiftedcmap'):
    '''
    Function to offset the "center" of a colormap. Useful for
    data with a negative min and positive max and you want the
    middle of the colormap's dynamic range to be at zero

    Input
    -----
      cmap : The matplotlib colormap to be altered
      start : Offset from lowest point in the colormap's range.
          Defaults to 0.0 (no lower offset). Should be between
          0.0 and `midpoint`.
      midpoint : The new center of the colormap. Defaults to
          0.5 (no shift). Should be between 0.0 and 1.0. In
          general, this should be  1 - vmax/(vmax + abs(vmin))
          For example if your data range from -15.0 to +5.0 and
          you want the center of the colormap at 0.0, `midpoint`
          should be set to  1 - 5/(5 + 15)) or 0.75
      stop : Offset from highest point in the colormap's range.
          Defaults to 1.0 (no upper offset). Should be between
          `midpoint` and 1.0.
      name : Name under which the new colormap is registered.

    Output
    ------
      The new LinearSegmentedColormap. It is also registered with
      matplotlib under `name`; calling the function again with the
      same name replaces the earlier registration.
    '''
    cdict = {
        'red': [],
        'green': [],
        'blue': [],
        'alpha': []
    }

    # regular index to compute the colors
    reg_index = np.linspace(start, stop, 257)

    # shifted index to match the data
    shift_index = np.hstack([
        np.linspace(0.0, midpoint, 128, endpoint=False),
        np.linspace(midpoint, 1.0, 129, endpoint=True)
    ])

    # sample the source map at the regular positions and place each sample
    # at the corresponding shifted position of the new map
    for ri, si in zip(reg_index, shift_index):
        r, g, b, a = cmap(ri)

        cdict['red'].append((si, r, r))
        cdict['green'].append((si, g, g))
        cdict['blue'].append((si, b, b))
        cdict['alpha'].append((si, a, a))

    newcmap = matplotlib.colors.LinearSegmentedColormap(name, cdict)

    # Newer matplotlib exposes a colormap registry and no longer provides
    # plt.register_cmap; older versions only have the pyplot function.
    # force=True lets the cell be re-run without a "already registered" error.
    registry = getattr(matplotlib, 'colormaps', None)
    if registry is not None and hasattr(registry, 'register'):
        registry.register(newcmap, force=True)
    else:
        plt.register_cmap(cmap=newcmap)

    return newcmap

In [None]:
import matplotlib.gridspec as gridspec

# Composite figure: a grid of treatment-averaged neighborhood maps
# (rows = treatments, columns = episode types) with a color bar on the right,
# and below it a swarm + point plot of 'si' per animal.
# Relies on `treatName`, `treat`, `levels`, `avgT` and `idx` from earlier cells.
pfp.paper()
inToCm=2.54

ncols=len(df.episode.unique())
nrows=len(df.treatment.unique())

# outer 2 x 2 layout: [0] map grid, [1] narrow color bar column, [2] lower-left plot
outer = gridspec.GridSpec(2, 2, width_ratios = [5,.1], wspace = 0.05) 
#make nested gridspecs
gs2 = gridspec.GridSpecFromSubplotSpec(nrows, ncols, subplot_spec = outer[0])
gs3 = gridspec.GridSpecFromSubplotSpec(1, 1, subplot_spec = outer[1])

# figure size is given in cm and converted to inches
fig = plt.figure(figsize=(11/inToCm,11/inToCm))
axes = [fig.add_subplot(gs2[i]) for i in range(ncols*nrows)]
axesCB=[fig.add_subplot(gs3[i]) for i in range(1)]

axesSP=fig.add_subplot(outer[2])

# common color limit: 99th percentile over all treatment maps, ignoring NaN
m=np.nanpercentile(avgT,99)
orig_cmap = matplotlib.cm.bwr
# NOTE: this shifted colormap is built but not applied - the cmap argument
# of imshow below is commented out
cmap=shiftedColorMap(orig_cmap,midpoint=1-(m/(m+1)))

trLab=treatName
pal=['gray','r','g','m']
for an in range(len(treat)):
    for i in range(len(levels)):
        # flat position of panel (treatment an, episode type i) in `axes`
        ind=i+(ncols*an)
        im = axes[ind].imshow(avgT[an,i,:,:],clim=[0,m],extent=[-31,31,-31,31],origin='lower')#,cmap=cmap)
        # strip ticks and hide the panel frame
        axes[ind].tick_params(axis='y', which='both',length=0)
        axes[ind].tick_params(axis='x', which='both',length=0)
        axes[ind].set_xticks([])
        axes[ind].set_yticks([])
        axes[ind].spines['top'].set_color('white')
        axes[ind].spines['bottom'].set_color('white')
        axes[ind].spines['left'].set_color('white')
        axes[ind].spines['right'].set_color('white')

        # first column carries the treatment label in the treatment's color
        if i==0:
            axes[ind].set_title(trLab[an],fontsize=8,color=pal[an])
            
        # sixth panel of the first row carries the overall heading
        if (i==5)&(an==0):
            axes[ind].set_title('neighbor density',fontsize=9)

# the color bar is taken from the last panel drawn; all panels share clim
cbar=plt.colorbar(im,cax=axesCB[0],ticks=np.round([0,1,m-0.1]))

plt.subplots_adjust(wspace=0, hspace=0.1)

# mean 'si' per (treatment, episode type, animal) inside the time window
social=df[idx].groupby(['treatment','episode','animalIndex']).si.mean().reset_index()
# x position: last two characters of the episode name as a number, halved
social['xpretty']=[int(ss[-2:])/2. for ss in social.episode]
sns.swarmplot(data=social,
              x='xpretty',
              hue='treatment',
              y='si',
              zorder=1,
              linewidth=1,
              edgecolor='gray',
              ax=axesSP,
              palette=pal,
              alpha=0.7)

sns.pointplot(x="xpretty", y="si", hue='treatment',data=social,ci=None,zorder=100,scale=2,ax=axesSP,palette=pal,
              linewidth=1,edgecolor='gray')
# hide all spines except the left one
axesSP.spines['top'].set_color('white')
axesSP.spines['bottom'].set_color('white')
axesSP.spines['right'].set_color('white')
axesSP.tick_params(axis='x', which='both',length=0)

axesSP.yaxis.tick_left()
axesSP.set_xlabel('dot diameter [mm]')
axesSP.set_ylabel('attraction')
# both plots add legend entries; keep only the first four
handles, labels = axesSP.get_legend_handles_labels()
axesSP.legend(handles[:4], labels[:4])

# dotted reference line at si = 0
axesSP.axhline(0,ls=':',color='k')


In [None]:
# Box plot with overlaid swarm of the per-animal mean 'avgSpeed' for each
# treatment, inside the analysis time window.
# Relies on `inToCm` and `pal` defined in the composite-figure cell.
pfp.paper()
fig, ax = plt.subplots(figsize=(2/inToCm,4.5/inToCm))

social=df[idx].groupby(['treatment','animalIndex']).avgSpeed.mean().reset_index()
sns.boxplot(y="avgSpeed", x='treatment',data=social,ax=ax,palette=pal,linewidth=2)


# Select which box you want to change    
# NOTE(review): this relies on the boxes being listed in ax.artists with six
# lines each in ax.lines - confirm for the installed seaborn/matplotlib
# versions; if ax.artists is empty the loop silently does nothing.
for i,artist in enumerate(ax.artists):
# Change the appearance of that box
    artist.set_edgecolor('k')
    for j in range(i*6,i*6+6):
        line = ax.lines[j]
        line.set_color('k')
        line.set_mfc('k')
        line.set_mec('k')

sns.swarmplot(data=social,x='treatment',y='avgSpeed',zorder=100,linewidth=1,ax=ax,palette=pal,alpha=0.7,edgecolor='k')

# no x label or ticks
plt.xlabel('')
plt.xticks([])
plt.ylabel('average Speed \n [mm/sec]')

# fixed y range
plt.ylim([0,7])
sns.despine()
plt.subplots_adjust(wspace=0, hspace=0)


In [None]:
# number of treatments = number of rows in the profile figures below
# NOTE(review): `treatName` is not defined in the visible cells.
treatNum=len(treatName)

In [None]:
# Profile through the centre of each treatment-averaged map: columns 29-30
# are averaged and plotted along the other map axis (shifted by 30 so the
# centre sits at 0). One panel per (treatment, episode type).
# The number of columns follows the number of episode types, matching the
# companion profile figure in the next cell, instead of a fixed 7;
# squeeze=False keeps `axes` 2-D even for a single row or column.
fig, axes = plt.subplots(nrows=treatNum, ncols=len(levels), sharex=True, sharey=True,figsize=(10,10),squeeze=False)
m=np.nanpercentile(avg,95)
trLab=treatName
for an in range(len(treat)):
    for i in range(len(levels)):
        profile=avgT[an,i,:,29:31].mean(axis=1)
        axes[an,i].plot(profile,np.arange(profile.shape[0])-30)
        axes[an,i].set_title('a:'+str(an)+trLab[an]+ 's:'+levels[i][-2:],fontsize=10)
        # dotted line marks the map centre
        axes[an,i].axhline(0,ls=':',color='gray')


In [None]:
# Profile through the centre of each treatment-averaged map along the other
# direction: rows 29-30 are averaged and plotted against the column index.
# One panel per (treatment, episode type); the vertical line marks index 30.
fig, axes = plt.subplots(nrows=treatNum, ncols=len(levels), sharex='col', sharey=True,figsize=(10, 10))
m=np.nanpercentile(avg,95)
trLab=treatName
for an in range(len(treat)):
    for i,level in enumerate(levels):
        panel=axes[an,i]
        panel.plot(avgT[an,i,29:31,:].mean(axis=0))
        panel.set_title('a:'+str(an)+trLab[an]+ 's:'+level[-2:],fontsize=10)
        panel.axvline(30)