In [6]:
from transform_coordinates import transform_coordinates
from smooth_data import smooth_data
from append_time_info import append_time_info
import os
import pandas as pd

In [7]:
def label_decision(df, t_entry_y=46, acc_x=309, rej_x=282):
    """
    Label T-maze decision transitions on a per-trial basis.

    Walks the frames in row order and runs a small state machine on the
    smoothed head position, restarting it whenever the 'trial' value changes:

        None      --(y < t_entry_y)-->  'T-Entry'
        'T-Entry' --(x > acc_x)-->      'ACC'
        'T-Entry' --(x < rej_x)-->      'REJ'
        'ACC'     --(x < rej_x)-->      'quit'

    'REJ' and 'quit' are final until the next trial. Several transitions can
    fire on the same frame (e.g. None -> 'T-Entry' -> 'ACC'); only the state
    reached at the end of that frame is recorded. Consecutive rows with a
    missing trial value are treated as belonging to the same trial.

    The added 'decision' column holds the new state only on the frame where the
    state changes; all other frames hold None.

    Parameters:
    df (pandas.DataFrame): Smoothed tracking data. Must contain 'warped Head x',
                           'warped Head y' and 'trial' columns. A 'decision'
                           column is added to this frame in place.
    t_entry_y (float): Head y below which the animal counts as having entered
                       the T. Same units as the 'warped Head' columns. Default 46.
    acc_x (float): Head x above which a T-entry becomes 'ACC'. Default 309.
    rej_x (float): Head x below which a T-entry becomes 'REJ', and below which
                   an 'ACC' becomes 'quit'. Default 282.

    Returns:
    pandas.DataFrame: The same DataFrame object passed in, with the 'decision'
                      column added.
    """
    no_trial = object()  # sentinel that never equals a real trial value
    current_trial = no_trial
    decision = None
    last_decision = None
    decisions = []

    # Iterating column values directly is much cheaper than iterrows(), which
    # builds a Series for every row.
    for x, y, trial in zip(df['warped Head x'], df['warped Head y'], df['trial']):
        # NaN != NaN, so missing trial ids must be compared explicitly;
        # otherwise every row with a missing trial would restart the machine
        # and re-emit the same label.
        same_trial = trial == current_trial or (
            pd.isna(trial) and current_trial is not no_trial and pd.isna(current_trial)
        )
        if not same_trial:  # start of a new trial: reset the state machine
            current_trial = trial
            decision = None
            last_decision = None

        if decision is None and y < t_entry_y:  # entering the T
            decision = 'T-Entry'

        if decision == 'T-Entry':
            if x > acc_x:
                decision = 'ACC'
            elif x < rej_x:
                decision = 'REJ'

        if decision == 'ACC' and x < rej_x:
            decision = 'quit'

        # Record the state only when it changes; otherwise leave the frame empty.
        if decision != last_decision:
            decisions.append(decision)
            last_decision = decision
        else:
            decisions.append(None)

    df['decision'] = decisions

    return df


In [10]:
# Directory holding the raw SLEAP track export. Set SLEAP_SESSION_DIR to point
# at another machine's copy; the default keeps the original author's path.
folder = os.environ.get(
    'SLEAP_SESSION_DIR',
    '/Users/yang/Documents/Wilbrecht_Lab/sleap_video/RRM026',
)
filename = 'RRM026_Day151_R1_tracks_raw.csv'

input_file = os.path.join(folder, filename)

# Warped body-part coordinates passed to smooth_data.
columns_to_smooth = [
    'warped Head x', 'warped Head y',
    'warped Neck x', 'warped Neck y',
    'warped Torso x', 'warped Torso y',
    'warped Tailhead x', 'warped Tailhead y',
]

df = pd.read_csv(input_file)

smoothed_df = smooth_data(df, columns_to_smooth)

# label_decision adds a 'decision' column to smoothed_df in place and returns
# that same object, so smoothed_df and labeled_t_df refer to one DataFrame.
labeled_t_df = label_decision(smoothed_df)


In [9]:
# First 50 labelled frames; 'decision' is non-null only where the state changed.
labeled_t_df.iloc[:50]

Unnamed: 0.1,Unnamed: 0,Head x,Head y,Neck x,Neck y,Torso x,Torso y,Tailhead x,Tailhead y,warped Head x,...,warped Tailhead y,time,idx,label,rel_time,restaurant,lapIndex,trial,Elapsed Time,decision
0,0,101.959839,122.953583,94.958862,123.054947,84.930008,125.813225,73.409424,126.324501,293.86868,...,148.526876,45.882906,1375,tone_onset,0.022618,1.0,0.0,1.0,0.0,
1,1,114.973434,122.849068,109.808601,122.760101,97.413643,123.139313,85.575119,123.485611,294.423644,...,135.943004,45.91689,1376,tone_onset,0.056602,1.0,0.0,1.0,0.033333,
2,2,130.262146,123.335648,122.963493,122.90432,110.59391,122.420662,100.986809,122.104706,295.250193,...,123.146176,45.949568,1377,tone_onset,0.08928,1.0,0.0,1.0,0.066667,
3,3,143.189209,125.935005,138.174225,123.315308,123.582375,122.37439,113.895775,121.928734,296.379266,...,110.348295,45.983027,1378,tone_onset,0.122739,1.0,0.0,1.0,0.1,
4,4,156.328369,126.135452,151.443024,123.489243,139.144608,122.373863,127.17543,122.015121,297.582014,...,97.590186,46.017216,1379,tone_onset,0.156928,1.0,0.0,1.0,0.133333,
5,5,171.61795,127.38398,166.592575,126.077316,152.113693,122.410805,142.473862,119.335487,299.004127,...,85.138623,46.049523,1380,T_Entry,0.189235,1.0,0.0,1.0,0.166667,
6,6,183.328171,130.264313,176.391953,126.872452,164.281128,122.800102,154.687271,119.088852,300.798158,...,73.676221,46.082765,1381,T_Entry,0.222477,1.0,0.0,1.0,0.2,
7,7,195.879059,130.831009,188.598709,127.361786,175.99057,123.492805,164.319016,121.924728,303.047907,...,63.377211,46.117274,1382,ACC,0.256986,1.0,0.0,1.0,0.233333,T-Entry
8,8,203.957916,134.714218,197.253082,130.631683,187.51506,125.784271,175.645569,122.173195,306.123551,...,53.999973,46.149722,1383,ACC,0.289434,1.0,0.0,1.0,0.266667,
9,9,208.48555,138.93042,204.221268,134.658875,195.585831,126.750298,184.639282,122.287834,310.026565,...,45.682272,46.182682,1384,ACC,0.322394,1.0,0.0,1.0,0.3,


In [6]:
df = labeled_t_df
# Distinct trial ids present in the labelled data.
{trial_id for trial_id in df['trial']}

{1.0,
 4.0,
 8.0,
 12.0,
 16.0,
 20.0,
 24.0,
 28.0,
 32.0,
 36.0,
 40.0,
 43.0,
 46.0,
 50.0,
 54.0,
 58.0,
 61.0,
 65.0,
 69.0,
 73.0,
 77.0,
 81.0,
 85.0,
 88.0,
 91.0,
 95.0,
 99.0,
 103.0,
 107.0,
 111.0,
 115.0,
 119.0,
 123.0,
 127.0,
 131.0,
 135.0,
 138.0,
 142.0,
 146.0,
 150.0,
 154.0,
 157.0,
 160.0,
 164.0,
 168.0,
 172.0,
 176.0}