In [3]:
import tifffile
import numpy as np
import glob
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import os
import cv2
import pickle
import json
import re
from datetime import datetime
from operator import itemgetter
from itertools import groupby
import csv
from unidip import UniDip

import scipy
from scipy.signal import butter, filtfilt
from scipy.ndimage import find_objects, gaussian_filter, label
from skimage.filters import apply_hysteresis_threshold

from sklearn import linear_model
from sklearn.metrics import r2_score 
from sklearn.neighbors import KernelDensity


from Brady_Ana_AIO_Functions import *


# Silence matplotlib's "too many open figures" warning (a value below 1 disables the check).
plt.rcParams['figure.max_open_warning'] = 0

# 1. Heart Rate analysis

## 1.1 Draw ROIs in Fiji that enclose the entire heart
## 1.2 Create Bradyinfo and get HR trace

In [None]:
# Provide the batch name and one ROI per fish. ROI_list holds one inner list per date
# folder, and each inner list holds one ROI per fish folder, in the (sorted) order of
# the fish folders inside that date's `behavior` directory.
# NOTE(review): ROI tuples are presumably (x, y, width, height) as drawn in Fiji -- confirm
# against create_brady_info_pixelwise.
batch = 'test_batch'
ROI_list = [[(369, 298, 117, 71)],
            [(356, 346, 122, 118),(383, 336, 98, 70),(301, 316, 146, 84),(345, 303, 155, 93),
             (371, 311, 116, 90),(345, 314, 130, 100)],
            [(358, 349, 106, 77), (378, 334, 106, 85),(381, 336, 111, 79),(285, 303, 170, 69),
             (336, 357, 157, 64),(371, 341, 127, 65),(422, 330, 101, 65)]]

# glob returns entries in arbitrary filesystem order; sort so the folder order is
# deterministic and matches the order in which ROI_list was written.
date_list = sorted(glob.glob('F:/'+batch + '/*'))
fish_dir_list = [sorted(glob.glob(date + '/behavior/F*')) for date in date_list]
print(fish_dir_list)

# Fail before any processing starts if the ROI table does not line up with the folders;
# otherwise ROIs would silently be applied to the wrong fish.
assert len(ROI_list) == len(fish_dir_list), (
    f'{len(ROI_list)} ROI groups for {len(fish_dir_list)} date folders: {date_list}')
for date_rois, date_fish_dirs in zip(ROI_list, fish_dir_list):
    assert len(date_rois) == len(date_fish_dirs), (
        f'{len(date_rois)} ROIs for {len(date_fish_dirs)} fish folders: {date_fish_dirs}')

# Run a single fish:
# create_brady_info_pixelwise(fish_dir_list[0][0]+'/',ROI_list[0][0])
# Run the entire dataset:
for date_fish_dirs, date_rois in zip(fish_dir_list, ROI_list):
    for fish_dir, roi in zip(date_fish_dirs, date_rois):
        create_brady_info_pixelwise(fish_dir + '/', roi)

## 1.3 HR to ceiling calculation and bradycardia identification

In [None]:
# Window lengths handed straight to calculate_HR_to_ceiling.
median_window_len, max_window_len = 100, 300

main_dir = 'F:/' + batch
bradyinfo_pattern = main_dir + '/*/behavior/*/Bradyinfo*'
bradyinfo_list = glob.glob(bradyinfo_pattern)
print(bradyinfo_list)

# Calculate HR to ceiling and write it into each Bradyinfo workbook.
calculate_HR_to_ceiling(bradyinfo_list, median_window_len, max_window_len)

# Identify bradycardia episodes for each fish and write them into Bradyinfo
# (the 3 is passed through unchanged; see identify_bradycardia for its meaning).
identify_bradycardia(bradyinfo_list, 3)


# 2. PC Detection

## 2.1 Use Ztrack to track the eye and tail angles
## 2.2 Fit a bimodal distribution to each fish's eye angles and find the threshold

In [None]:
# Find all fish folders of the batch (sorted, so the date/fish indices are reproducible
# when thresholds are corrected by hand in the next cell).
batch = 'test_batch'
date_list = sorted(glob.glob('F:/'+batch + '/*'))
fish_dir_list = [sorted(glob.glob(date + '/behavior/F*')) for date in date_list]
print(fish_dir_list)

# First check the bimodal distribution of each fish's eye angles and find a threshold.
# Thresholds are kept nested per date ([[fish0, fish1, ...], ...]) because the next cell
# looks them up as Threshold_list[date][fish]; a flat list would break that lookup.
Threshold_list = []
for date_fish_dirs in fish_dir_list:
    date_thresholds = []
    for fish_dir in date_fish_dirs:
        # The eye-angle DataFrame is not needed here, only the threshold.
        _, threshold = get_bimodel_distribution(fish_dir)
        date_thresholds.append(threshold)
    Threshold_list.append(date_thresholds)
print(Threshold_list)

In [None]:
## Second step: manually check every threshold plot and correct the wrong ones here.
# Threshold_list must be nested per date, matching fish_dir_list: [[fish0, fish1, ...], ...]
# Threshold_list = [[30,14,16],[17,12,19,8,15]]

for date_idx, date_fish_dirs in enumerate(fish_dir_list):
    for fish_idx, fish_dir in enumerate(date_fish_dirs):
        bradyinfo_path = glob.glob(fish_dir + '/Brady*')[0]
        # All three sheets are read so the workbook can be rewritten in full below.
        bradyinfo = pd.read_excel(bradyinfo_path)
        heart_rate_dataframe = pd.read_excel(bradyinfo_path, sheet_name='heart_rate_trace')
        HR_to_ceiling = pd.read_excel(bradyinfo_path, sheet_name='HR_to_ceiling')

        threshold = Threshold_list[date_idx][fish_idx]
        draw_threshold(fish_dir, threshold, bandwidth=2)

        # The same threshold is stored on every row of the fish's Bradyinfo sheet.
        bradyinfo['PC_Threshold'] = [threshold] * bradyinfo.shape[0]

        # The context manager writes and closes the workbook; ExcelWriter.save() was
        # deprecated in pandas 1.5 and removed in 2.0.
        with pd.ExcelWriter(bradyinfo_path, engine='xlsxwriter') as writer:
            bradyinfo.to_excel(writer, sheet_name='Bradyinfo', index=False)
            heart_rate_dataframe.to_excel(writer, sheet_name='heart_rate_trace', index=False)
            HR_to_ceiling.to_excel(writer, sheet_name='HR_to_ceiling', index=False)

## 2.3 Detect PC in each trial

In [None]:
# Third step: use each fish's stored PC_Threshold to find PC trials.
for date_fish_dirs in fish_dir_list:
    for fish_dir in date_fish_dirs:
        bradyinfo_path = glob.glob(fish_dir + '/Brady*')[0]
        # All three sheets are read so the workbook can be rewritten in full below.
        bradyinfo = pd.read_excel(bradyinfo_path)
        heart_rate_dataframe = pd.read_excel(bradyinfo_path, sheet_name='heart_rate_trace')
        HR_to_ceiling = pd.read_excel(bradyinfo_path, sheet_name='HR_to_ceiling')

        # get_pc_trial also returns the list of PC trials, which is not stored here.
        _, PC_interval = get_pc_trial(fish_dir, bradyinfo['PC_Threshold'][0])
        bradyinfo['Eye_Convergence_bouts'] = PC_interval

        # Store the PC result in Bradyinfo. The context manager writes and closes the
        # workbook once (ExcelWriter.save() was removed in pandas 2.0).
        with pd.ExcelWriter(bradyinfo_path, engine='xlsxwriter') as writer:
            bradyinfo.to_excel(writer, sheet_name='Bradyinfo', index=False)
            heart_rate_dataframe.to_excel(writer, sheet_name='heart_rate_trace', index=False)
            HR_to_ceiling.to_excel(writer, sheet_name='HR_to_ceiling', index=False)

# 3. Tail Curvature Calculation and Bout Detection (results written into Boutsinfo)

In [None]:
# Batch whose fish folders are processed.
batch = 'test_batch'
date_list = glob.glob('F:/' + batch + '/*')
fish_dir_list = [glob.glob(date_dir + '/behavior/F*') for date_dir in date_list]

# Bout_Reader settings used for the full-dataset run below.
bout_reader_kwargs = {
    'duration': 4400,
    'low_thresh': 0.008,
    'high_thresh': 0.04,
    'sigma_angles': 0,
    'sigma': 1,
    'num_points': 10,
    'bout_sigma': 1.2,
    'bout_threshold': 1,
}

# Run a single fish (note: this commented call uses different settings than the full run):
# Bout_Reader(fish_dir_list[0][0]+ '\\',duration=2200, low_thresh=0.01, high_thresh=0.2, sigma_angles=0, sigma=1, num_points=5,bout_sigma=0.5,bout_threshold=4)
# Run all fish (not recommended):
for date_fish_dirs in fish_dir_list:
    for fish_dir in date_fish_dirs:
        Bout_Reader(fish_dir + '\\', **bout_reader_kwargs)


# 4. Escape Classification

In [None]:
# Load the pre-fitted scaler, PCA and k-means models used to classify bouts.
# `pickle` is imported at the top of the notebook (the former `pk` alias was never bound there).
# NOTE: pickle.load can execute arbitrary code -- only load these model files from a
# trusted source; they also need a scikit-learn version compatible with the one that saved them.
def _load_pickle(path):
    """Unpickle and return the object stored at `path`, closing the file afterwards."""
    with open(path, 'rb') as fh:
        return pickle.load(fh)

pca = _load_pickle('bout_clustering_pca.pkl')
kmeans = _load_pickle('bout_clustering_kmeans.pkl')
scaler = _load_pickle('bout_clustering_scaler.pkl')

## 4.1 Identify escape bouts by PCA and k-means clustering

In [None]:
# Collect every detected swim bout of every fish into one DataFrame (one row per bout).
batch = 'test_batch'
# Sorted so the processing order (and the printed progress indices) is deterministic.
date_list = sorted(glob.glob('F:\\'+batch + '\\*'))
fish_dir_list = [sorted(glob.glob(date + '\\behavior\\F*')) for date in date_list]
print(fish_dir_list)

# Output column -> Boutsinfo column. Each Boutsinfo cell holds the text of a list with one
# entry per swim bout of that trial, aligned with the 'Swim_Bouts' list.
BOUT_FEATURE_SOURCES = {
    'Amplitude_Tip_Angle': 'Swim_Bouts_Amplitude_Tip_Angle',
    'Amplitude_Middle_Angle': 'Swim_Bouts_Amplitude_Middle_Angle',
    'Amplitude_Curvature': 'Swim_Bouts_Amplitude_Curvature',
    'Avg_Velocity': 'Swim_Bouts_Avg_Velocity',
    'Max_Velocity': 'Swim_Bouts_Max_Velocity',
    'Frequency': 'Swim_Bouts_Frequency',
    'Time': 'Swim_Bouts_Time',
    'Integral': 'Swim_Bouts_Integral',
}
# Alphabetical column order: this is the order the former row-by-row DataFrame.append
# produced, so any positional column selection downstream keeps working.
BOUT_TABLE_COLUMNS = sorted(['Fish_Index', 'Trial_Index', 'Trial_Type', 'Swim_Bout', 'Date',
                             *BOUT_FEATURE_SOURCES])


def classify_trial(vsinfo, t):
    """Return the stimulus category of row `t` of a vsinfo sheet.

    Conditions are checked in order: 'ns' if both stimulus types are 'n'; 'looming' if either
    stimulus type is 'l'; 'dot4' if either color is 'UV'; 'dot15' if the left and right
    stimulus sizes sum to 15. Rows matching none of these return 'unknown' (previously such a
    trial silently inherited the type of the preceding trial).
    """
    left_type = vsinfo['Left_Stimulus_Type'][t]
    right_type = vsinfo['Right_Stimulus_Type'][t]
    if left_type == 'n' and right_type == 'n':
        return 'ns'
    if left_type == 'l' or right_type == 'l':
        return 'looming'
    if vsinfo['Left_Color'][t] == 'UV' or vsinfo['Right_Color'][t] == 'UV':
        return 'dot4'
    if vsinfo['Left_Stimulus_Size'][t] + vsinfo['Right_Stimulus_Size'][t] == 15:
        return 'dot15'
    return 'unknown'


# Rows are accumulated as dicts and turned into a DataFrame once: DataFrame.append was
# removed in pandas 2.0 and re-copied the whole frame on every call.
bout_records = []
for date_idx, date_fish_dirs in enumerate(fish_dir_list):
    for fish_idx, fish_dir in enumerate(date_fish_dirs):
        # may need to be modified when directory name changed
        path_parts = fish_dir.split('\\')
        date_no = path_parts[-3]
        fish_str = path_parts[-1]
        print(date_idx, fish_idx)

        vsinfo_dir = glob.glob(fish_dir + '/vsinfo*')
        vsinfo = pd.read_excel(vsinfo_dir[0])
        boutinfo_path = glob.glob(fish_dir + '/Boutsinfo*')
        boutinfo = pd.read_excel(boutinfo_path[0], sheet_name='boutinfo')

        for t in range(boutinfo.shape[0]):
            # NOTE(review): eval executes whatever text is stored in the workbook; only run
            # this on Boutsinfo files produced by this pipeline.
            swim_bouts = eval(boutinfo['Swim_Bouts'][t])
            if len(swim_bouts) == 0:
                continue
            # Parse each per-trial list once instead of once per bout.
            features = {name: eval(boutinfo[source][t])
                        for name, source in BOUT_FEATURE_SOURCES.items()}
            trial_fields = {'Fish_Index': fish_str,
                            'Trial_Index': int(boutinfo['Trial'][t]),
                            'Trial_Type': classify_trial(vsinfo, t),
                            'Date': date_no}
            for b, swim_bout in enumerate(swim_bouts):
                record = dict(trial_fields, Swim_Bout=swim_bout)
                record.update({name: values[b] for name, values in features.items()})
                bout_records.append(record)

bout_collection_df = pd.DataFrame(bout_records, columns=BOUT_TABLE_COLUMNS)
bout_collection_df['Log_Time'] = np.log10(bout_collection_df['Time'])
bout_collection_df.head()

In [None]:
# Features the scaler/PCA were fitted on, selected by name rather than by position.
# These are exactly the columns at positions [0, 1, 2, 3, 6, 7, 8, 13] of the
# alphabetically ordered bout table (Log_Time is used instead of the raw Time).
BOUT_FEATURES = ['Amplitude_Curvature', 'Amplitude_Middle_Angle', 'Amplitude_Tip_Angle',
                 'Avg_Velocity', 'Frequency', 'Integral', 'Max_Velocity', 'Log_Time']
# Number of leading principal components passed to k-means.
N_CLUSTER_PCS = 5
# Set to True to plot the bouts in PC space, coloured by cluster label.
SHOW_CLUSTER_PLOTS = False

df_feature = bout_collection_df[BOUT_FEATURES]
df_norm_feature = scaler.transform(df_feature)
principalComponents = pca.transform(df_norm_feature)
print(pca.explained_variance_ratio_)

labels = kmeans.predict(principalComponents[:, :N_CLUSTER_PCS])

if SHOW_CLUSTER_PLOTS:
    fig, (ax1, ax2, ax3) = plt.subplots(1, 3, figsize=(16, 6))

    scatter = ax1.scatter(principalComponents[:, 0], principalComponents[:, 1], c=labels)
    ax1.set_xlabel('PC1', fontsize=12)
    ax1.set_ylabel('PC2', fontsize=12)
    ax1.legend(*scatter.legend_elements(), loc="lower left")

    ax2.scatter(principalComponents[:, 0], principalComponents[:, 2], c=labels)
    ax2.set_xlabel('PC1', fontsize=12)
    ax2.set_ylabel('PC3', fontsize=12)

    ax3.scatter(principalComponents[:, 1], principalComponents[:, 2], c=labels)
    ax3.set_xlabel('PC2', fontsize=12)
    ax3.set_ylabel('PC3', fontsize=12)
    plt.show()

bout_collection_df['Label'] = labels

# The context manager writes and closes the workbook (ExcelWriter.save() was removed in pandas 2.0).
with pd.ExcelWriter('F:\\'+batch+'\\all_bouts_info.xlsx', engine='xlsxwriter') as writer:
    bout_collection_df.to_excel(writer, index=False)

## 4.2 Annotate escape bouts and spontaneous bouts, and write them into Bradyinfo

In [None]:
# k-means cluster label treated as escape bouts; every other label counts as spontaneous.
ESCAPE_LABEL = 1

main_dir = 'F:\\'+batch
bradyinfo_list = glob.glob(main_dir+'/*/behavior/*/Bradyinfo*')
bout_collection_df_dir = main_dir+'\\all_bouts_info.xlsx'
bout_collection_df = pd.read_excel(bout_collection_df_dir)
# 'Date' was collected from the folder name as text and may be read back from Excel as text
# or as a number; comparing a text column with int(date) never matches, so convert it first.
bout_dates = pd.to_numeric(bout_collection_df['Date'], errors='coerce')

for bradyinfo_dir in bradyinfo_list:
    # may need to be modified when directory name changed
    path_parts = bradyinfo_dir.split('\\')
    fish = path_parts[-2]
    date = path_parts[-4]
    print(date+' '+fish)
    bradyinfo = pd.read_excel(bradyinfo_dir, sheet_name='Bradyinfo')
    HR_to_ceiling = pd.read_excel(bradyinfo_dir, sheet_name='HR_to_ceiling')
    heart_rate_trace = pd.read_excel(bradyinfo_dir, sheet_name='heart_rate_trace')

    fish_bouts = bout_collection_df[(bout_dates == int(date)) &
                                    (bout_collection_df['Fish_Index'] == fish)]
    is_escape = fish_bouts['Label'] == ESCAPE_LABEL

    escape_bouts = []
    spontaneous_bouts = []
    for t in range(bradyinfo.shape[0]):
        # Bradyinfo rows are 0-based; Trial_Index in the bout table is 1-based.
        in_trial = fish_bouts['Trial_Index'] == t + 1
        # NOTE(review): eval executes the text stored in the workbook; trusted files only.
        escape_bouts.append([eval(b) for b in fish_bouts.loc[in_trial & is_escape, 'Swim_Bout']])
        spontaneous_bouts.append([eval(b) for b in fish_bouts.loc[in_trial & ~is_escape, 'Swim_Bout']])

    # Assign whole object columns (one list per trial) instead of the chained
    # `bradyinfo[col][t] = ...`, which does not modify the frame under pandas Copy-on-Write.
    bradyinfo['Escape_bouts'] = pd.Series(escape_bouts, index=bradyinfo.index, dtype=object)
    bradyinfo['Spontaneous_bouts'] = pd.Series(spontaneous_bouts, index=bradyinfo.index, dtype=object)

    # The context manager writes and closes the workbook (ExcelWriter.save() was removed in pandas 2.0).
    with pd.ExcelWriter(bradyinfo_dir, engine='xlsxwriter') as writer:
        bradyinfo.to_excel(writer, sheet_name='Bradyinfo', index=False)
        heart_rate_trace.to_excel(writer, sheet_name='heart_rate_trace', index=False)
        HR_to_ceiling.to_excel(writer, sheet_name='HR_to_ceiling', index=False)