In [1]:
import pandas as pd
import numpy as np
import os
import matplotlib.pyplot as plt
import seaborn as sns
from os.path import join as oj
import pandas as pd
import ipywidgets as widgets
from IPython.display import display, clear_output
import plotly.express as px
import plotly.graph_objects as go

from get_pose_mat_wide import get_pose_mat_wide

In [3]:
root_folder = r'/Users/yang/Documents/Wilbrecht_Lab/data/processed_tracks'
combined_df = pd.read_csv(oj(root_folder, 'combined_df.csv'))

In [4]:
print(combined_df.shape)
combined_df.head()

(1726631, 31)


Unnamed: 0,Head x,Head y,Neck x,Neck y,Torso x,Torso y,Tailhead x,Tailhead y,warped Head x,warped Head y,...,lapIndex,trial,decision,Head_vx,Head_vy,Head_v,angular_velocity,dPhi,animal,session
0,188.623566,28.475878,191.675903,23.529375,195.564896,11.922485,199.491653,6.846012,579.176095,299.356675,...,0,4,,33.297947,-492.902488,494.025926,-8.542783,0.007309,RRM028,Day141
1,188.060455,35.737679,189.072708,28.70846,192.960205,18.967358,196.653687,7.442016,580.286027,282.926592,...,0,4,,29.372719,-487.785702,488.669262,-5.358881,0.006488,RRM028,Day141
2,187.637146,43.289207,188.460831,36.384815,192.20723,24.570004,195.903427,12.627603,581.134276,266.837628,...,0,4,,26.170808,-479.931514,480.644535,-0.88836,0.014482,RRM028,Day141
3,187.621353,48.790691,188.307617,43.796303,191.519562,32.192448,195.331573,20.492044,582.030747,250.931157,...,0,4,,44.096649,-493.561166,495.527133,1.670589,0.045977,RRM028,Day141
4,187.166351,56.842167,187.897903,51.815838,191.288147,39.605297,195.342529,27.651142,584.074053,233.93355,...,0,4,,77.330604,-524.327621,529.999506,2.48968,0.025462,RRM028,Day141


In [9]:
def plot_assigned_tracks(df, condition_func):
    '''
    Plot the trials that satisfy the given condition
    
    PART I: Iterate each row in df, extract trial-wise features and record them into result_df
    '''
    current_trial_num = np.nan
    current_session = np.nan
    current_animal = np.nan
    current_bonsai_decision = np.nan
    current_decision = np.nan
    current_restaurant = np.nan
    bonsai_decision = np.nan
    decision = np.nan
    sum_of_speed = 0
    speed_count = 0
    
    #bonsai_decision_list = []
    decision_list = []
    straight_walking_speed_list =[]
    animal_list = []
    session_list = []
    trial_list = []

    for index, row in df.iterrows():
        bonsai_decision = row['label']
        decision = row['decision']
        trial_num = row['trial']
        animal = row['animal']
        session = row['session']
        speed = np.sqrt(row['Head vx']**2 + row['Head vy']**2)

        restaurant = row['restaurant']
        
    
        if trial_num != current_trial_num: # Start of a new trial
            if speed_count != 0:
                average_speed = sum_of_speed/speed_count
            else:
                average_speed = np.nan

            if current_bonsai_decision == 'collection':
                current_bonsai_decision = 'ACC'
            
            if not np.isnan(current_trial_num):
                if condition_func(current_decision, current_bonsai_decision, average_speed, current_restaurant, current_animal, current_session, current_trial_num):
                    #bonsai_decision_list.append(current_bonsai_decision)
                    decision_list.append(current_decision)
                    straight_walking_speed_list.append(average_speed)
                    animal_list.append(current_animal)
                    session_list.append(current_session)
                    trial_list.append(current_trial_num)

            current_restaurant = restaurant
            current_bonsai_decision = np.nan    
            current_decision = np.nan
            current_trial_num = trial_num 
            current_session = session
            current_animal = animal
            sum_of_speed = 0
            speed_count = 0
        
        if not pd.isna(decision):
            current_decision = decision    
        if not pd.isna(bonsai_decision):
            current_bonsai_decision = bonsai_decision
        
        if pd.isna(current_decision) and (speed < 300):
            sum_of_speed += speed
            speed_count += 1
    

    result_df = pd.DataFrame({
        #'bonsai decision': bonsai_decision_list,
        'decision': decision_list,
        'straight_walking_speed': straight_walking_speed_list,
        'animal': animal_list,
        'session': session_list,
        'trial': trial_list
    })
    plt.hist(straight_walking_speed_list, bins=50)
    plt.show()
    
    '''
    PART II: Iterate through each row in result_df, locate them in df and then plot them.
    '''
    # Extract animal, session, and trial info from the decision dataframe
    decision_info = result_df[['animal', 'session', 'trial']]
    filtered_df = df[
        df[['animal', 'session', 'trial']].apply(tuple, axis=1).isin(decision_info.apply(tuple, axis=1))
    ]
    
    # plot each trial using plotly
    plt.figure(figsize=(10, 10))

    trial_palette = sns.color_palette('deep', n_colors=len(filtered_df['trial'].unique()))
    trial_palette_dict = {trial_num: color for trial_num, color in zip(filtered_df['trial'].unique(), trial_palette)}
    
    decision_palette = sns.color_palette('deep', n_colors=len(result_df['decision'].unique()))
    decision_palette_dict = {decision: color for decision, color in zip(filtered_df['decision'].unique(), decision_palette)}
    decision_palette_dict = {
        "ACC": (0, 1, 0),  # Green with transparency
        "REJ": (1, 0, 0),  # Red with transparency
        "quit": (0, 0, 1),  # Blue with transparency
        "T-Entry": (1, 1, 1), # black
        np.nan: (1, 1, 1) # black
    }
    fig = go.Figure()
    
    # Iterate through each row in result_df
    for _, row in result_df.iterrows():
        # Extract animal, session, and trial information
        decision, animal, session, trial = row['decision'], row['animal'], row['session'], row['trial']
        
        # Filter df for the current trial
        filtered_trial = df[(df['animal'] == animal) & (df['session'] == session) & (df['trial'] == trial)]
        
        # Determine the color for the current trial
        color = 'rgba' + str(tuple(int(c * 255) for c in decision_palette_dict[decision]) + (0.2,))
        
        # Add the trial data to the plot
        fig.add_trace(go.Scatter(
            x=filtered_trial['warped Head x'],
            y=filtered_trial['warped Head y'],
            mode='lines',
            line=dict(color=color, width=0.5),
            showlegend=False
        ))
    
    fig.update_layout(
        xaxis_title='warped Head x',
        yaxis_title='warped Head y',
        yaxis=dict(autorange='reversed'),
        #width = 600,
        #height = 1000
    )
    
    # Adding vertical and horizontal lines and text annotations
    fig.add_vline(x=282, line=dict(color='red', dash='dash', width=1))
    fig.add_annotation(x=285, y=142, text='REJ', showarrow=False, font=dict(color='red'), xanchor='right', yanchor='top')

    fig.add_vline(x=309, line=dict(color='red', dash='dash', width=1))
    fig.add_annotation(x=314, y=142, text='ACC', showarrow=False, font=dict(color='red'), xanchor='right', yanchor='top')

    fig.add_hline(y=46, line=dict(color='red', dash='dash', width=1))
    fig.add_annotation(x=333, y=46, text='T_Entry', showarrow=False, font=dict(color='red'), xanchor='right', yanchor='bottom')

    fig.show()

In [10]:
def interactive_analysis(combined_df, condition_func):
    """
    Use the slider to adjust the start and length of the range of combined_df.
    Then plot the trials that satisfy the given condition
    
    Parameters:
    - combined_df: The DataFrame to analyze.
    - condition_func: A lambda function to determine the condition for plotting.
    """
    
    # Function to update the selected range based on slider values
    def update_range(range_value, interval):
        global selected_range
        global df_subset
        selected_range = [range_value, range_value + interval]
        df_subset = combined_df.iloc[selected_range[0]:selected_range[1]]
        print(f"Selected range: {selected_range}")

    # Create slider widgets
    range_slider = widgets.IntSlider(value=0, min=0, max=len(combined_df), step=100, description='Start:')
    interval_slider = widgets.IntSlider(value=10000, min=0, max=len(combined_df), step=100, description='Length:')
    interactive_widget = widgets.interactive(update_range, range_value=range_slider, interval=interval_slider)
    display(interactive_widget)

    # Create an output widget for the plots
    output = widgets.Output()

    # Button to trigger further analysis
    analyze_button = widgets.Button(description="Plot")

    # Function to handle button click
    def on_analyze_button_clicked(b):
        with output:
            clear_output(wait=True)  # Clear previous plots
            plot_assigned_tracks(df_subset, condition_func)
    

    # Link the button to the handler function
    analyze_button.on_click(on_analyze_button_clicked)

    # Display the button and output widget
    display(analyze_button)
    display(output)
    

In [11]:
condition_func = lambda sleap_decision, bonsai_decision, average_speed, restaurant, animal, session, trial: (
    True
    )

plot_assigned_tracks(combined_df[1:100000], condition_func)

KeyError: 'Head vx'

Straight walking average speed is not affecting bonsai vs sleap decision.
Some implication from the trajectories. see in notion

In [None]:
condition_func = lambda sleap_decision, bonsai_decision, average_speed, restaurant, animal, session, trial: (
    (bonsai_decision == 'quit' and sleap_decision == 'REJ')
    )

interactive_analysis(combined_df, condition_func)

interactive(children=(IntSlider(value=0, description='Start:', max=1643207, step=100), IntSlider(value=10000, …

Button(description='Plot', style=ButtonStyle())

Output()

In [None]:
condition_func = lambda sleap_decision, bonsai_decision, average_speed, restaurant, animal, session, trial: (
    (bonsai_decision == 'ACC' and sleap_decision == 'REJ')
    )

interactive_analysis(combined_df, condition_func)

interactive(children=(IntSlider(value=0, description='Start:', max=1643207, step=100), IntSlider(value=10000, …

Button(description='Plot', style=ButtonStyle())

Output()

In [None]:
condition_func = lambda sleap_decision, bonsai_decision, average_speed, restaurant, animal, session, trial: (
    (bonsai_decision == 'REJ' and sleap_decision == 'ACC')
    )

interactive_analysis(combined_df, condition_func)

interactive(children=(IntSlider(value=0, description='Start:', max=1643207, step=100), IntSlider(value=10000, …

Button(description='Plot', style=ButtonStyle())

Output()