In [None]:

#Copyright (c) 2019 Primoz Ravbar UCSB Licensed under BSD 2-Clause [see LICENSE for details] Written by Primoz Ravbar

# This script trains an LDA model on ST-features and human labels. It then extracts behaviors (produces ethograms) from all
# ST-feature files in the specified folder.

import numpy as np
import scipy
from scipy import ndimage
from scipy import misc
import pickle
import pandas as pd
import time
import matplotlib.pyplot as plt
import cv2
import os

from ABRS_modules import discrete_radon_transform
from ABRS_modules import etho2ethoAP
from ABRS_modules import smooth_1d
from ABRS_modules import create_LDA_training_dataset
from ABRS_modules import removeZeroLabelsFromTrainingData
from ABRS_modules import computeSpeedFromPosXY 


from sklearn.discriminant_analysis import LinearDiscriminantAnalysis

from ABRS_data_vis import create_colorMat
from ABRS_data_vis import cmapG
from ABRS_data_vis import cmapAP


# ---------------------------------------------------------------------------
# User configuration: edit the two 'INSERT ...' placeholders before running.
# All derived paths are built by string concatenation with '\\' separators.
# ---------------------------------------------------------------------------

pathToABRS_GH_folder = 'INSERT PATH TO ABRS MAIN FOLDER HERE'

dirPathFeatures = 'INSERT PATH TO FOLDER CONTAINING ST-FEATURES HERE'

dirPathLabel = pathToABRS_GH_folder + '\\Labels'  # PATH TO FOLDER CONTAINING THE LABEL FILES

idxLabelDirPathFileName = dirPathLabel + '\\' + 'combLabelLiplusCS7fb3'  # NAME OF THE LABEL FILE

# PATH TO ST-FEATURES FOR TRAINING (MUST CORRESPOND TO THE LABEL).
# Both backslashes are escaped explicitly: a bare '\S' is an invalid escape
# sequence that newer Python versions warn about. The resulting string is
# unchanged.
dirPathLabelFeatures = pathToABRS_GH_folder + '\\LabelsFeatures\\STF_comb'

numbFilesFeatureLabel = 7  # number of ST feature files used for LDA training

outputFolderEtho = pathToABRS_GH_folder + '\\Etho'

smoothingWindow = 89  # 89 works for fly grooming behavior at 30 Hz
halfWindowSpeed = 15  # 16 or 15 for 30 Hz; adjust the time-window to extract speed of body displacement

# If the label is not fully aligned with the ST-features, shift it; this is a patch.
labelShift = 10
#labelShift = 0; # 10 works with auto

# Same kind of shift, applied to the second label file (idxLabelVal).
labelShiftVal = 10

# ---- Load the two pickled label rows --------------------------------------
# NOTE(review): pickle.load can execute arbitrary code; only open label files
# that come from a trusted source.
with open(idxLabelDirPathFileName, "rb") as f:
    idxLabel = pickle.load(f)

with open(dirPathLabel + '\\' + 'labelCS1fb1_SS', "rb") as f:
    idxLabelVal = pickle.load(f)

# ---- Shift each label row to the right ------------------------------------
# The first `shift` columns become zeros and the last `shift` columns are
# dropped, so the number of columns stays the same.
# (works with janelia data 11/16/2018; shift 15 works too)
shLVal = np.shape(idxLabelVal)
labelShftRightVal = np.concatenate(
    (np.zeros((1, labelShiftVal)), idxLabelVal[:, :shLVal[1] - labelShiftVal]),
    axis=1,
)
idxLabelVal = labelShftRightVal

shL = np.shape(idxLabel)
labelShftRight = np.concatenate(
    (np.zeros((1, labelShift)), idxLabel[:, :shL[1] - labelShift]),
    axis=1,
)
idxLabel = labelShftRight

# Every 0 in the training labels (including the zeros padded in by the shift)
# is relabelled as 7. Done in place, so labelShftRight sees the same change.
idxLabel[idxLabel == 0] = 7

# ---- Build the training data ----------------------------------------------

# Second label row, derived from the human labels by etho2ethoAP; it is the
# target of the "AP" model below.
idxLabelAP = etho2ethoAP(idxLabel)

featureMat, posMat, maxMovementMat = create_LDA_training_dataset(
    dirPathLabelFeatures, numbFilesFeatureLabel
)

speedVect = computeSpeedFromPosXY(posMat, halfWindowSpeed)

# Smoothed copies (along axis 1) are only used by the AP model.
featureMatSm = smooth_1d(featureMat, smoothingWindow, axis=1)
speedVectSm = smooth_1d(speedVect, smoothingWindow, axis=1)

# Speed is appended as extra row(s) under the ST-features.
inputMatComb = np.vstack((featureMat, speedVect))
inputMatAP = np.vstack((featureMatSm, speedVectSm))

# ---- Model 1: raw features + speed -> idxLabel ----------------------------
# The input matrices are transposed so that their columns become the rows
# (samples) passed to the classifier.
x_train = inputMatComb.T
y_train = idxLabel[0, :]

model = LinearDiscriminantAnalysis().fit(x_train, y_train)

#prX = model.transform(x_train);
prX = model.predict_proba(x_train)

# ---- Model 2: smoothed features + smoothed speed -> idxLabelAP ------------
x_trainAP = inputMatAP.T
y_trainAP = idxLabelAP[0, :]

modelAP = LinearDiscriminantAnalysis().fit(x_trainAP, y_trainAP)

prAP = modelAP.predict_proba(x_trainAP)

# ---- Model 3: stacked model on the smoothed outputs of models 2 and 1 -----
# The class probabilities of both models are placed side by side (AP first)
# and smoothed along axis 0 with a window of 9 before being used as features.
x_trainComb = smooth_1d(np.hstack((prAP, prX)), 9, axis=0)
y_trainComb = idxLabel[0, :]

modelComb = LinearDiscriminantAnalysis().fit(x_trainComb, y_trainComb)

prComb = modelComb.predict_proba(x_trainComb)

##############################################

# Run the three trained LDA models on every ST-feature file found in
# dirPathFeatures and concatenate the results of all files, in sorted
# file-name order.

# Only regular files are kept: a sub-directory inside dirPathFeatures cannot
# be opened as a pickle and would abort the whole run.
fileList = sorted(
    entry for entry in os.listdir(dirPathFeatures)
    if os.path.isfile(os.path.join(dirPathFeatures, entry))
)
sh = np.shape(fileList)
numbFiles = sh[0]  # LOOP AROUND ALL ST-FEATURE FILES (NUMBER OF ETHOGRAMS)
#numbFiles = 1; # just the first STF file
#numbFiles = -5; # just run the training

if numbFiles <= 0:
    # With nothing to process there are no predictions to concatenate; stop
    # here with an explicit message instead of a NameError further down.
    raise ValueError('No ST-feature files to process in: ' + dirPathFeatures)

# Per-file results are collected in lists and concatenated once after the
# loop; growing the arrays with hstack/vstack on every iteration re-copies
# all previously accumulated data each time.
speedVectList = []
maxMovementList = []
predictionsList = []
predictionsAPList = []
predictionsCombList = []

for fl in range(0, numbFiles, 1):

    inputFileName = fileList[fl]
    print(inputFileName)

    # os.path.join uses the separator of the current platform.
    featureMatDirPathFileName = os.path.join(dirPathFeatures, inputFileName)

    # NOTE(review): pickle.load can execute arbitrary code; only process
    # feature files from a trusted source.
    with open(featureMatDirPathFileName, "rb") as f:
        STF_30_posXY_dict = pickle.load(f)

    featureMatCurrent = STF_30_posXY_dict["featureMat"]

    # NOTE: the names below are the same ones used for the training data
    # earlier in this cell, so the training arrays are overwritten here.
    posMat = STF_30_posXY_dict["posMat"]
    maxMovementMat = STF_30_posXY_dict["maxMovementMat"]
    featureMatSm = smooth_1d(featureMatCurrent, smoothingWindow, axis=1)

    speedVect = computeSpeedFromPosXY(posMat, halfWindowSpeed)
    speedVectSm = smooth_1d(speedVect, smoothingWindow, axis=1)

    # Same input layout as at training time: features with speed underneath.
    inputMatCombCurrent = np.vstack((featureMatCurrent, speedVect))
    inputMatAPCurrent = np.vstack((featureMatSm, speedVectSm))

    inputMatComb = inputMatCombCurrent
    inputMatAP = inputMatAPCurrent

    x_pred = np.transpose(inputMatComb)
    x_predAP = np.transpose(inputMatAP)

    predictionsCurrent = model.predict_proba(x_pred)
    predictionsAPCurrent = modelAP.predict_proba(x_predAP)

    # The stacked model is fed per file (AP probabilities first), smoothed
    # along axis 0 with window 9, exactly as it was trained.
    x_predComb = smooth_1d(np.hstack((predictionsAPCurrent, predictionsCurrent)), 9, axis=0)

    predictionsCombCurrent = modelComb.predict_proba(x_predComb)

    speedVectList.append(speedVect)
    maxMovementList.append(maxMovementMat)
    predictionsList.append(predictionsCurrent)
    predictionsAPList.append(predictionsAPCurrent)
    predictionsCombList.append(predictionsCombCurrent)

# Join the per-file pieces: speed / movement side by side, predictions stacked
# by rows and then transposed so that each row holds one class.
speedVectRec = np.hstack(speedVectList)
maxMovementRec = np.hstack(maxMovementList)

predictions = np.transpose(np.vstack(predictionsList))
predictionsAP = np.transpose(np.vstack(predictionsAPList))
predictionsComb = np.transpose(np.vstack(predictionsCombList))

# Smoothed versions (window 89, along axis 1) of the class probabilities;
# thrComb serves as the threshold trace in the next cell.
thrX = smooth_1d(predictions, 89, axis=1)
thrComb = smooth_1d(predictionsComb, 89, axis=1)

# Shift both smoothed traces 30 columns to the left; the 30 columns vacated at
# the right edge are filled with zeros.
thrX_left = np.hstack((thrX[:, 30:np.shape(thrX)[1]], np.zeros((np.shape(thrX)[0], 30))))
thrX = thrX_left

thrComb_left = np.hstack((thrComb[:, 30:np.shape(thrComb)[1]], np.zeros((np.shape(thrComb)[0], 30))))
thrComb = thrComb_left

     

In [None]:
# This cell applies user-specified thresholds to the LDA outputs to produce
# ethograms. It needs predictions, predictionsAP, predictionsComb, thrComb,
# speedVectRec, maxMovementRec and numbFiles from the previous cell.

protocolName = 7

shPred = np.shape(predictions)
lengthEtho = shPred[1]

etho = np.zeros((1, lengthEtho))
# ethoAP is not filled in anywhere in this cell, so it stays all zeros.
ethoAP = np.zeros((1, lengthEtho))

if protocolName == 7:

    # Column-aligned views of the inputs, one entry per ethogram column.
    pComb = predictionsComb[:, :lengthEtho]
    pAP = predictionsAP[:, :lengthEtho]
    tComb = thrComb[:, :lengthEtho]
    movement = maxMovementRec[0, :lengthEtho]
    speed = speedVectRec[0, :lengthEtho]

    # Differences that the rules below share.
    comb10 = pComb[1] - pComb[0]
    thr10 = tComb[1] - tComb[0]
    comb32 = pComb[3] - pComb[2]
    thr32 = tComb[3] - tComb[2]
    apDiff = pAP[1] - pAP[2]

    # (label, mask) pairs, applied in this order. The rules are not mutually
    # exclusive: a later rule overwrites the label written by an earlier one,
    # which is what the original chain of independent `if` statements did
    # column by column.
    ethoRules = [
        (1, (comb10 < thr10) & (apDiff > 0)),
        (2, (comb10 > thr10) & (apDiff > 0)),
        (3, (comb32 < thr32) & (apDiff < 0)),
        (4, (comb32 > thr32) & (apDiff < 0)),
        (5, (pComb[4] > 0.01) & ((pComb[4] - pComb[3]) > 0) & (apDiff < 0) & (movement > 100)),
        # ethoAP is all zeros here, so the last term is currently always True.
        (6, ((pAP[3] + pComb[5] > 1.0) | (speed > 25)) & (ethoAP[0, :] == 0)),
        (7, movement < 100),
    ]

    for ethoLabel, ethoMask in ethoRules:
        etho[0, ethoMask] = ethoLabel

# ethoZ is a *copy* with two columns forced to 7 and 0. It used to be an alias
# of etho, so these two writes silently overwrote columns 5 and 6 of the
# ethogram itself.
ethoZ = etho.copy()
ethoZ[0, 5] = 7
ethoZ[0, 6] = 0

# One row per file. The row length is derived from the data instead of a fixed
# 50000, so recordings of any length can be displayed.
# NOTE(review): this assumes every file contributed the same number of columns;
# confirm for data sets with recordings of unequal length.
if numbFiles <= 0 or lengthEtho % numbFiles != 0:
    raise ValueError(
        'Cannot arrange ' + str(lengthEtho) + ' ethogram columns into '
        + str(numbFiles) + ' rows of equal length'
    )
sizeEtho = lengthEtho // numbFiles

ethoMat = np.reshape(etho[0, :], (numbFiles, sizeEtho))

# vmin/vmax pin the colour scale to the full label range 0..7 -- presumably
# what forcing a 7 and a 0 into the data was meant to achieve -- without
# touching the ethogram values.
plt.matshow(ethoMat, interpolation=None, aspect='auto', cmap=cmapG, vmin=0, vmax=7)
plt.show()