# Behavior classification

Classify fly behavior using pretrained models

### Brandon Pratt, 11/17/2024

In [1]:
# python libraries
import pathlib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn
import pickle 
from sklearn import preprocessing

In [23]:
# load data and model
filename = 'grant--rv17-HP-R48A07AD;R20C06DBD-gal4xUAS-gtACR1_processed_v2.pq'
data_path = pathlib.Path.cwd().parent.joinpath('data', filename)
data = pd.read_parquet(data_path, engine='pyarrow')

# Get behavior prediction column to dataset
beh_col = pd.DataFrame(np.nan, index=np.arange(len(data)).tolist(), columns=['behavior_predictions'])
data = data.join(beh_col) # add a behavior column initialized with NaNs to data 

In [10]:
# features for prediction
features = [
 'fictrac_delta_rot_lab_x_mms',
 'fictrac_delta_rot_lab_y_mms',
 'fictrac_delta_rot_lab_z_mms',
 'L1A_rot',
 'R1A_rot',
 'L2A_rot',
 'R2A_rot',
 'L3A_rot',
 'R3A_rot',
 'L1B_rot',
 'R1B_rot',
 'L2B_rot',
 'R2B_rot',
 'L3B_rot',
 'R3B_rot',
 'L1C_rot',
 'R1C_rot',
 'L2C_rot',
 'R2C_rot',
 'L3C_rot',
 'R3C_rot',
 'L1A_flex',
 'R1A_flex',
 'L2A_flex',
 'R2A_flex',
 'L3A_flex',
 'R3A_flex',
 'L1B_flex',
 'R1B_flex',
 'L2B_flex',
 'R2B_flex',
 'L3B_flex',
 'R3B_flex',
 'L1C_flex',
 'R1C_flex',
 'L2C_flex',
 'R2C_flex',
 'L3C_flex',
 'R3C_flex',
 'L1D_flex',
 'R1D_flex',
 'L2D_flex',
 'R2D_flex',
 'L3D_flex',
 'R3D_flex',
 'L1A_abduct',
 'R1A_abduct',
 'L2A_abduct',
 'R2A_abduct',
 'L3A_abduct',
 'R3A_abduct',
 'L1A_rot_d1',
 'L1A_rot_d2',
 'R1A_rot_d1',
 'R1A_rot_d2',
 'L2A_rot_d1',
 'L2A_rot_d2',
 'R2A_rot_d1',
 'R2A_rot_d2',
 'L3A_rot_d1',
 'L3A_rot_d2',
 'R3A_rot_d1',
 'R3A_rot_d2',
 'L1B_rot_d1',
 'L1B_rot_d2',
 'R1B_rot_d1',
 'R1B_rot_d2',
 'L2B_rot_d1',
 'L2B_rot_d2',
 'R2B_rot_d1',
 'R2B_rot_d2',
 'L3B_rot_d1',
 'L3B_rot_d2',
 'R3B_rot_d1',
 'R3B_rot_d2',
 'L1C_rot_d1',
 'L1C_rot_d2',
 'R1C_rot_d1',
 'R1C_rot_d2',
 'L2C_rot_d1',
 'L2C_rot_d2',
 'R2C_rot_d1',
 'R2C_rot_d2',
 'L3C_rot_d1',
 'L3C_rot_d2',
 'R3C_rot_d1',
 'R3C_rot_d2',
 'L1A_flex_d1',
 'L1A_flex_d2',
 'R1A_flex_d1',
 'R1A_flex_d2',
 'L2A_flex_d1',
 'L2A_flex_d2',
 'R2A_flex_d1',
 'R2A_flex_d2',
 'L3A_flex_d1',
 'L3A_flex_d2',
 'R3A_flex_d1',
 'R3A_flex_d2',
 'L1B_flex_d1',
 'L1B_flex_d2',
 'R1B_flex_d1',
 'R1B_flex_d2',
 'L2B_flex_d1',
 'L2B_flex_d2',
 'R2B_flex_d1',
 'R2B_flex_d2',
 'L3B_flex_d1',
 'L3B_flex_d2',
 'R3B_flex_d1',
 'R3B_flex_d2',
 'L1C_flex_d1',
 'L1C_flex_d2',
 'R1C_flex_d1',
 'R1C_flex_d2',
 'L2C_flex_d1',
 'L2C_flex_d2',
 'R2C_flex_d1',
 'R2C_flex_d2',
 'L3C_flex_d1',
 'L3C_flex_d2',
 'R3C_flex_d1',
 'R3C_flex_d2',
 'L1D_flex_d1',
 'L1D_flex_d2',
 'R1D_flex_d1',
 'R1D_flex_d2',
 'L2D_flex_d1',
 'L2D_flex_d2',
 'R2D_flex_d1',
 'R2D_flex_d2',
 'L3D_flex_d1',
 'L3D_flex_d2',
 'R3D_flex_d1',
 'R3D_flex_d2',
 'L1A_abduct_d1',
 'L1A_abduct_d2',
 'R1A_abduct_d1',
 'R1A_abduct_d2',
 'L2A_abduct_d1',
 'L2A_abduct_d2',
 'R2A_abduct_d1',
 'R2A_abduct_d2',
 'L3A_abduct_d1',
 'L3A_abduct_d2',
 'R3A_abduct_d1',
 'R3A_abduct_d2',
 'L1_smoothed_velo',
 'L2_smoothed_velo',
 'L3_smoothed_velo',
 'R1_smoothed_velo',
 'R2_smoothed_velo',
 'R3_smoothed_velo',
]

In [25]:
# load best performing model
model_name = '20241116_random_forest.pkl'
model_path = pathlib.Path.cwd().parent.joinpath('models', model_name)
model = pickle.load(open(model_path, 'rb')) 

In [46]:
# Decipher behavior labels and names
label_file = 'CxHP8_gtACR1_grooming_trials_updated_annotations.pq'
label_path = pathlib.Path.cwd().parent.joinpath('data')
label_data = pd.read_parquet(label_path.joinpath(label_file), engine='pyarrow')

# create an other category (need this for truly assessing the dataset)
label_data.loc[np.isnan(label_data['behavior_id'].values), 'behavior_annotations'] = 'other'
label_data.loc[np.isnan(label_data['behavior_id'].values), 'behavior_id'] = 7

# label and behavior name
beh_ids, idx = np.unique(label_data['behavior_id'], return_index=True)
behavior_labels = label_data['behavior_annotations'][idx].values

In [47]:
behavior_labels

array(['abdomen_grooming', 'antennae_grooming', 'ball_push', 'standing',
       't1_grooming', 't3_grooming', 'walking', 'other'], dtype=object)

In [56]:
# iterate through each trial and predict behaviors
trials = data['fullfile'].unique()
for i, trial in enumerate(trials):
    trial_id = data['fullfile'] == trial
    trial_data = data[features][trial_id]
    
    # Only perform classification if all feature elements are real numbers
    if ~np.any(np.isnan(trial_data)):
        scaler = preprocessing.StandardScaler().fit(trial_data)
        transform_data = scaler.transform(trial_data)
        beh_preds = model.predict(transform_data)  # behavior predictions

        # add the behavior predictions to the dataset
        data.loc[trial_id, 'behavior_predictions'] = behavior_labels[beh_preds]

### Save dataset with behavior predictions

In [63]:
new_filename = f'{filename[0:-3]}_behavior_predictions.pq'
data.to_parquet(data_path.parent.joinpath(new_filename))

### Load dataset and assess behavioral predictions

In [4]:
filename = 'grant--rv17-HP-R48A07AD;R20C06DBD-gal4xUAS-gtACR1_processed_v2.pq'
new_filename = f'{filename[0:-3]}_behavior_predictions.pq'
data_path = pathlib.Path.cwd().parent.joinpath('data', filename)
data = pd.read_parquet(data_path.parent.joinpath(new_filename), engine='pyarrow')

In [208]:
data['behavior_predictions'].unique()

array(['other', None, 't1_grooming', 'walking', 'antennae_grooming',
       't3_grooming', 'abdomen_grooming', 'standing', 'ball_push'],
      dtype=object)

In [232]:
# filter for trials with t1 grooming
grooming_data = data[data['behavior_predictions'] == 'walking']

In [211]:
def behavior_filter(df, threshold):
    df_indices = df.index # df['fnum'].values
    frame_0 = df_indices[0]
    curr_start = frame_0
    for i, frame in enumerate(df_indices[1::]):
        if frame != (df_indices[i] + 1):
            n_frames = df_indices[i] - curr_start + 1
            if n_frames <= threshold:
                f_end = df_indices[i]
                f_start = curr_start
                df.loc[np.arange(f_start, f_end + 1).astype(int), 'behavior_predictions'] = 'other'
            curr_start = frame

        if i == (len(df_indices[1::]) - 1):  # reached the end
            n_frames = frame - curr_start + 1
            if n_frames <= threshold:
                f_start = curr_start
                f_end = frame
                df.loc[np.arange(f_start, f_end + 1).astype(int), 'behavior_predictions'] = 'other'
    return df

In [200]:
df = grooming_data[grooming_data['fullfile'] == '4.11.24|Fly 1_0|04112024_fly1_0 R1C1  str-cw-0 sec']
threshold = frame_threshold
df_indices = df.index # df['fnum'].values
frame_0 = df_indices[0]
curr_start = frame_0
for i, frame in enumerate(df_indices[1::]):
    if frame != (df_indices[i] + 1):
        n_frames = df_indices[i] - curr_start + 1
        print(n_frames)
        if n_frames <= threshold:
            f_end = df_indices[i]
            f_start = curr_start
            df.loc[np.arange(f_start, f_end + 1).astype(int), 'behavior_predictions'] = 'other'
        curr_start = frame

    if i == (len(df_indices[1::]) - 1):  # reached the end
        n_frames = frame - curr_start + 1
        print(n_frames)
        if n_frames <= threshold:
            f_start = curr_start
            f_end = frame
            df.loc[np.arange(f_start, f_end + 1).astype(int), 'behavior_predictions'] = 'other'

15
15
4
1


In [233]:
# filter out short grooming periods
frame_threshold = 0.2 * 300  # 200ms of grooming
grooming_trials = grooming_data['fullfile'].unique()
for trial in grooming_trials:
    trial_df = grooming_data[grooming_data['fullfile'] == trial]
    filt_df = behavior_filter(trial_df, frame_threshold)
    
    # incorporate adjusted labels
    grooming_data.iloc[grooming_data['fullfile'] == trial] = filt_df

KeyboardInterrupt: 

In [230]:
filt_grooming_data = grooming_data[grooming_data['behavior_predictions'] == 'walking']

In [231]:
filt_grooming_data['fullfile'].unique()

array(['4.9.24|Fly 2_0|04092024_fly2_0 R5C12  rot-cw-1 sec',
       '4.5.24|Fly 3_0|04052024_fly3_0 R1C18  rot-ccw-1 sec',
       '4.11.24|Fly 3_0|04112024_fly3_0 R3C15  rot-cw-1 sec',
       '4.11.24|Fly 5_0|04112024_fly5_0 R2C9  str-ccw-1 sec'],
      dtype=object)

In [191]:
grooming_data[grooming_data['fullfile'] == '4.11.24|Fly 1_0|04112024_fly1_0 R1C1  str-cw-0 sec']

Unnamed: 0,L1A_rot,R1A_rot,L2A_rot,R2A_rot,L3A_rot,R3A_rot,L1B_rot,R1B_rot,L2B_rot,R2B_rot,...,R2E_y_phase,L3E_y_phase,R3E_y_phase,L1E_z_phase,R1E_z_phase,L2E_z_phase,R2E_z_phase,L3E_z_phase,R3E_z_phase,behavior_predictions
32,230.636664,306.884555,178.949858,25.781308,132.396937,49.594156,122.701879,201.338548,136.011967,232.415274,...,1.955305,-0.051379,0.930594,0.051219,1.090811,-2.93978,0.9465,-2.548955,-1.985314,other
33,234.025804,306.190198,178.900343,24.816731,132.47955,49.807845,123.755034,201.572384,136.140203,230.659286,...,2.03007,-0.044754,0.989606,0.283889,1.097944,-3.066811,1.093775,-2.552144,-2.00655,other
34,238.915865,305.732664,178.766504,24.245223,132.578222,49.964326,125.35274,202.420384,135.954882,228.793715,...,2.083408,-0.026444,1.053365,0.480951,1.089422,3.131005,1.206511,-2.530851,-2.083631,other
35,244.883495,305.817052,178.499883,23.863651,132.621034,50.058445,126.874913,203.887949,135.625213,227.235131,...,2.125325,-0.008348,1.09825,0.588181,1.108223,3.059649,1.283993,-2.538456,-2.305299,other
36,251.51631,306.332993,178.204962,23.458004,132.534314,50.091395,129.222384,205.585233,135.014423,225.846523,...,2.154878,0.065129,1.137369,0.624004,1.161272,3.052623,1.339439,-2.563711,-2.557141,other
37,257.837386,307.047268,177.829867,23.047962,132.223238,50.172691,131.799755,207.347155,134.304493,224.845649,...,2.182678,0.137844,1.144394,0.634815,1.21491,3.05757,1.370963,-2.663661,-2.780818,other
38,263.143253,307.873003,177.312264,22.640841,131.745668,50.323264,134.277466,209.053036,133.439919,223.823615,...,2.200301,0.217743,1.122754,0.710405,1.201416,3.101568,1.379483,-2.772904,-2.78761,other
39,267.448724,308.734608,176.878367,22.383436,131.30244,50.465411,136.630368,210.455571,132.993068,222.760411,...,2.214432,0.248228,1.049137,0.836655,1.110208,-3.136053,1.356947,-2.8912,-2.758993,other
40,271.196885,309.284478,176.713811,22.235639,130.889719,50.602937,139.470757,211.183968,133.366286,221.94291,...,2.213387,0.258547,0.951837,1.012453,1.012404,-3.065756,1.31647,-2.962222,-2.640331,other
41,275.251664,309.370094,176.768954,22.174264,130.645906,50.651608,143.755246,211.203276,134.346495,221.483364,...,2.206858,0.224093,0.832284,1.209813,0.972703,-3.005038,1.265611,-3.038344,-2.575984,other
