In [297]:
import pandas as pd
import numpy as np


In [298]:
# read all freeze_frame data set 

freeze_frame = pd.read_csv('freeze_frame_index_all.csv')
freeze_frame_1_bouts = freeze_frame[freeze_frame['bouts'] == 1]
freeze_frame_125_bouts = freeze_frame[freeze_frame['bouts'] == 1.25]
freeze_frame_15_bouts = freeze_frame[freeze_frame['bouts'] == 1.5]
freeze_frame_175_bouts = freeze_frame[freeze_frame['bouts'] == 1.75]
freeze_frame_2_bouts = freeze_frame[freeze_frame['bouts'] == 2]

In [299]:
# # so same mice: ptsd2_84_recall1, let's do 1s as bout first 

# ptsd2_84_recall1_1 = freeze_frame_1_bouts[(freeze_frame_1_bouts['cohort_id'] == 'ptsd2_84') & (freeze_frame_1_bouts['day'] == 'recall1')]
# ptsd2_84_recall1_125 = freeze_frame_125_bouts[(freeze_frame_125_bouts['cohort_id'] == 'ptsd2_84') & (freeze_frame_125_bouts['day'] == 'recall1')]
# ptsd2_84_recall1_15 = freeze_frame_15_bouts[(freeze_frame_15_bouts['cohort_id'] == 'ptsd2_84') & (freeze_frame_15_bouts['day'] == 'recall1')]
# ptsd2_84_recall1_175 = freeze_frame_175_bouts[(freeze_frame_175_bouts['cohort_id'] == 'ptsd2_84') & (freeze_frame_175_bouts['day'] == 'recall1')]
# ptsd2_84_recall1_2 = freeze_frame_2_bouts[(freeze_frame_2_bouts['cohort_id'] == 'ptsd2_84') & (freeze_frame_2_bouts['day'] == 'recall1')]


In [300]:
import pandas as pd
import re
import os


def dlc_to_long(file_path):
    """
    Transforms wide DLC data into a long format, using metadata rows to structure columns
    and extracts cohort_id from the file name.
    
    Parameters:
    - file_path: String representing the file path to the wide-format positional data CSV.
    
    Returns:
    - long_data: DataFrame in long format with columns ['x', 'y', 'likelihood', 'body_part', 'coords', 'cohort_id'].
    """
    
    # Extract the file name from the file path
    file_name = os.path.basename(file_path)
    
    # Attempt to match both formats
    match1 = re.match(r'(\w+)_([a-zA-Z]+\d*)_(\d+)', file_name)  # Format 1
    match2 = re.match(r'(\w+)_([a-zA-Z]+\d*)_(\d+)-(\d+)', file_name)  # Format 2

    if match2:  # only for ptsd9 group (ptsd9 group have longer cohort_id such as ptsd9_31_2)
        cohort_prefix = match2.group(1)  # 'ptsd9'
        day = match2.group(2)            # 'recall4'
        cohort_number1 = match2.group(3)  # '31'
        cohort_number2 = match2.group(4)  # '2'
        cohort_id = f"{cohort_prefix}_{cohort_number1}_{cohort_number2}"  # 'ptsd9_31_2'
    elif match1:  # for all other groups 
        cohort_prefix = match1.group(1)  # 'ptsd2'
        day = match1.group(2)            # 'recall1'
        cohort_number = match1.group(3)  # '81'
        cohort_id = f"{cohort_prefix}_{cohort_number}"  # 'ptsd2_81'
    else:
        cohort_id = 'unknown'
        day = 'unknown'
    
    raw_data = pd.read_csv(file_path)
    
    # Extract body parts from the first metadata row and coordinate types from the second row
    body_parts = raw_data.iloc[0, 1::3].values
    coordinates = ["x", "y", "likelihood"]
    
    # Generate column names using body parts and coordinates
    column_names = [f"{body}_{coord}" for body in body_parts for coord in coordinates]
    
    # Reload the dataset, skipping metadata rows and assigning correct column names
    data = pd.read_csv(file_path, skiprows=[0, 1])
    data = data.iloc[:, 1:]  # Drop the first unnecessary 'coords' column
    data.columns = column_names
    
    # Insert 'cohort_id' and 'day' column
    data['cohort_id'] = cohort_id
    data['day'] = day
    
    # Convert to long format
    long_data = pd.DataFrame()
    for part in body_parts:
        part_data = data[[f"{part}_x", f"{part}_y", f"{part}_likelihood", 'cohort_id', 'day']].copy()
        part_data.columns = ['x', 'y', 'likelihood', 'cohort_id', 'day']
        part_data['body_part'] = part
        part_data['index'] = part_data.index
        part_data['t(sec)'] = (part_data['index'] / len(part_data) * 300).round(2)
        long_data = pd.concat([long_data, part_data], ignore_index=True)
    
    return long_data


In [301]:
def process_dlc_folder(folder_path):
    """
    Processes all relevant DLC CSV files in a folder, converts them to long format, and
    concatenates them into a single DataFrame.
    
    Parameters:
    - folder_path: String representing the path to the folder containing DLC CSV files.
    
    Returns:
    - combined_data: DataFrame containing all processed data in long format.
    """
    combined_data = pd.DataFrame()
    
    # List all files in the folder
    all_files = os.listdir(folder_path)
    
    # Filter for relevant DLC CSV files
    dlc_files = [f for f in all_files if f.endswith('.csv') and 'DLC' in f]
    
    successful_count = 0

    # Process each file
    for file_name in dlc_files:
        file_path = os.path.join(folder_path, file_name)
        # Process the file using dlc_to_long
        try:
            long_data = dlc_to_long(file_path)
            combined_data = pd.concat([combined_data, long_data], ignore_index=True)
            successful_count += 1
        except Exception as e:
            print(f"Error processing {file_name}: {e}")
            continue
    
    print(f"Successfully processed {successful_count} files.")

    return combined_data

In [302]:
def get_shape(folder_path):
    """

    """    
    # List all files in the folder
    all_files = os.listdir(folder_path)
    
    # Filter for relevant DLC CSV files
    dlc_files = [f for f in all_files if f.endswith('.csv') and 'DLC' in f]
    
    successful_count = 0

    # Process each file
    for file_name in dlc_files:
        file_path = os.path.join(folder_path, file_name)
        # Process the file using dlc_to_long
        try:
            long_data = dlc_to_long(file_path)
            combined_data = pd.concat([combined_data, long_data], ignore_index=True)
            successful_count += 1
        except Exception as e:
            print(f"Error processing {file_name}: {e}")
            continue
    
    print(f"Successfully processed {successful_count} files.")

    return combined_data

In [303]:
dlc_frames = process_dlc_folder('/Users/novak/Documents/Columbia/mentored_research/Turi/dlc_csv')

Successfully processed 140 files.


In [304]:
import os
import pandas as pd

def check_dlc_shapes(folder_path):
    """
    Checks all relevant DLC CSV files in a folder for the expected shape of (1124, 8).
    If a file does not match this shape, prints the file name and its shape.
    
    Parameters:
    - folder_path: String representing the path to the folder containing DLC CSV files.
    """
    # List all files in the folder
    all_files = os.listdir(folder_path)
    
    # Filter for relevant DLC CSV files
    dlc_files = [f for f in all_files if f.endswith('.csv') and 'DLC' in f]
    
    successful_count = 0
    outlier_files = []

    # Process each file
    for file_name in dlc_files:
        file_path = os.path.join(folder_path, file_name)
        try:
            long_data = dlc_to_long(file_path)
            if long_data.shape == (13488, 8):
                successful_count += 1
            else:
                outlier_files.append((file_name, long_data.shape))
        except Exception as e:
            print(f"Error processing {file_name}: {e}")
            continue
    
    print(f"Successfully processed {successful_count} files with the expected shape (1124, 8).")
    if outlier_files:
        print("Files with unexpected shapes:")
        for file_name, shape in outlier_files:
            print(f"  - {file_name}: {shape}")
    else:
        print("All files have the expected shape.")



In [305]:
# # load in DLC file and pass it through the function 

# file_path = '/Users/novak/Documents/Columbia/mentored_research/Turi/DLC_1107/ptsd2_recall1_84DLC_resnet50_phi_cfc_boxNov7shuffle1_388000.csv'
# ptsd2_84_recall1_dlc = dlc_to_long(file_path)
# # ptsd2_84_recall1_dlc['cohort_id'] = 'ptsd2_84_recall1'

# # Display the first few rows
# ptsd2_84_recall1_dlc.head()

In [306]:
def find_freeze_to_non_freeze(freeze_frame_data):
    """
    Finds the timestamps where the 'freeze' column changes from 1 to 0.
    
    Parameters:
    - freeze_frame_data: DataFrame containing freeze frame data with 'freeze' and 't(sec)' columns.
    
    Returns:
    - List of timestamps in seconds where freeze transitions from 1 to 0.
    """
    # Ensure 't(sec)' is numeric
    # freeze_frame_data['t(sec)'] = pd.to_numeric(freeze_frame_data['t(sec)'], errors='coerce')
    
    # Shift the freeze column to compare current and previous rows
    freeze_frame_data.loc[:, 'freeze_shift'] = freeze_frame_data['freeze'].shift(1).fillna(0)
    
    # Identify points where freeze changes from 1 to 0
    transitions = freeze_frame_data[(freeze_frame_data['freeze_shift'] == 1) & (freeze_frame_data['freeze'] == 0)]
    
    # Extract the timestamps for these transitions
    return transitions['t(sec)'].tolist()


In [307]:
def find_closest_timestamps(freeze_timestamps, dlc_data, body_part, margin=0.113):
    """
    Finds the closest matching timestamps in the DLC data for each freeze frame timestamp.
    
    Parameters:
    - freeze_timestamps: List of timestamps from freeze frame data (in seconds).
    - dlc_data: DataFrame containing DLC data with a 't(sec)' column representing time in seconds.
    - margin: Float value specifying the maximum allowable difference in seconds (default is 0.113).
    
    Returns:
    - List of tuples (freeze_timestamp, closest_dlc_timestamp) where each DLC timestamp
      is the closest match to a freeze frame timestamp within the given margin.
    """
    # Ensure 't(sec)' in DLC data is numeric
    # dlc_data['t(sec)'] = pd.to_numeric(dlc_data['t(sec)'], errors='coerce')
    
    dlc_data = dlc_data[dlc_data['body_part'] == body_part]

    matched_timestamps = []
    for t1 in freeze_timestamps:
        # Find the DLC timestamps within the margin range
        within_margin = dlc_data[(dlc_data['t(sec)'] >= t1 - margin) & (dlc_data['t(sec)'] <= t1 + margin)]
        
        if not within_margin.empty:
            # Find the closest timestamp within the margin
            closest_dlc_time = within_margin.iloc[(within_margin['t(sec)'] - t1).abs().argmin()]['t(sec)']
            matched_timestamps.append((t1, closest_dlc_time))
        else:
            # If no match is found within the margin, append None for the DLC timestamp
            matched_timestamps.append((t1, None))

    return matched_timestamps


In [308]:
import numpy as np
import matplotlib.pyplot as plt

def analyze_dlc_points_long_format(matched_timestamps, dlc_data, body_part, window_size=5, epsilon=0.5):
    """
    Analyzes DLC data points in long format by extracting and plotting x and y coordinates around matched timestamps.
    Calculates a normalized loss based on positional changes around each transition.
    
    Parameters:
    - matched_timestamps: List of tuples (freeze_timestamp, closest_dlc_timestamp) from find_closest_timestamps.
    - dlc_data: DataFrame containing DLC data in long format with 't(sec)', 'x', 'y', 'body_part' columns.
    - body_part: The body part to analyze (e.g., 'nose'); used to filter specific x and y values in the long format.
    - window_size: The number of points before and after the matched point to consider in the loss calculation.
    - epsilon: Threshold for minimal change to define "stationary" behavior before the transition.
    
    Returns:
    - Average loss for the parameter setting.
    """
    total_loss = 0
    num_matched_points = 0
    
    for freeze_time, dlc_time in matched_timestamps:
        if dlc_time is not None:
            # Filter for the specified body part and find the matched time point
            body_part_data = dlc_data[(dlc_data['body_part'] == body_part)]
            matched_index = body_part_data[body_part_data['t(sec)'] == dlc_time]['index'].astype(int).iloc[0]
            
            # Extract a window of points before and after the matched point
            start_index = max(matched_index - window_size, 0)
            end_index = min(matched_index + window_size, len(body_part_data) - 1)
            window_data = body_part_data[(body_part_data['index'] >= start_index) & (body_part_data['index'] <= end_index)]

            # Extract x and y coordinates
            x_coords = window_data['x'].values
            y_coords = window_data['y'].values
            
            # Center the coordinates based on the matched point (middle of the window)
            center_x = x_coords[window_size]
            center_y = y_coords[window_size]
            x_centered = x_coords - center_x
            y_centered = y_coords - center_y
            
            # Calculate absolute change in position around the transition point
            abs_change_x = np.abs(x_centered)
            abs_change_y = np.abs(y_centered)
            
            # Define the before and after windows
            before_x = abs_change_x[:window_size]
            after_x = abs_change_x[window_size:]
            before_y = abs_change_y[:window_size]
            after_y = abs_change_y[window_size:]
            
            # Calculate penalties for 'before' and 'after' based on the absolute changes
            # Penalize deviation from zero (stationary) before the transition
            loss_before = np.sum(np.maximum(before_x - epsilon, 0)) + np.sum(np.maximum(before_y - epsilon, 0))
            
            # Penalize stationary behavior after the transition
            loss_after = np.sum(np.maximum(epsilon - after_x, 0)) + np.sum(np.maximum(epsilon - after_y, 0))
            
            # Total loss for this transition
            transition_loss = loss_before + loss_after
            total_loss += transition_loss
            num_matched_points += 1
            
            # # Plotting if desired
            # time_relative = window_data['t(sec)'].values - dlc_time
            # fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(6, 2))
            
            # # Plot centered x coordinates
            # ax1.plot(time_relative, x_centered, label='Centered X', linewidth=1)
            # ax1.set_xlabel('Time Relative (s)')
            # ax1.set_ylabel('Centered X')
            # ax1.axvline(0, color='r', linestyle='--', label='Matched Time')
            
            # # Plot centered y coordinates
            # ax2.plot(time_relative, y_centered, label='Centered Y', color='orange', linewidth=1)
            # ax2.set_xlabel('Time Relative (s)')
            # ax2.set_ylabel('Centered Y')
            # ax2.axvline(0, color='r', linestyle='--', label='Matched Time')
            
            # # Set the overall title for the figure
            # fig.suptitle(f"Centered X and Y Coordinates Around Transition at {freeze_time:.2f} sec")
            # plt.tight_layout()
            # plt.show()
    
    # Normalize total loss by the number of matched points to get the average loss
    average_loss = total_loss / num_matched_points if num_matched_points > 0 else 0
    return average_loss


In [309]:
def run_full_analysis(freeze_frame_data, dlc_data, body_part, margin=0.113, window_size=5, epsilon=0.5):
    """
    Runs the full analysis pipeline, including finding freeze transitions,
    matching timestamps, and calculating the average loss for a parameter setting.
    
    Parameters:
    - freeze_frame_data: DataFrame containing freeze frame data with 'freeze' and 't(sec)' columns.
    - dlc_data: DataFrame containing DLC data in long format with 't(sec)', 'x', 'y', 'body_part' columns.
    - body_part: The body part to analyze (e.g., 'nose'); used to filter specific x and y values in the long format.
    - margin: Time margin within which to match freeze frame timestamps to DLC timestamps.
    - window_size: The number of points before and after the matched point to consider in the loss calculation.
    - epsilon: Threshold for minimal change to define "stationary" behavior before the transition.
    
    Returns:
    - average_loss: Average loss for the given freeze frame parameter setting.
    """
    # Step 1: Find freeze-to-non-freeze transition timestamps
    freeze_timestamps = find_freeze_to_non_freeze(freeze_frame_data)
    
    # Step 2: Match timestamps in DLC data
    matched_timestamps = find_closest_timestamps(freeze_timestamps, dlc_data, body_part, margin)
    
    # Step 3: Calculate average loss for this parameter setting
    average_loss = analyze_dlc_points_long_format(matched_timestamps, dlc_data, body_part, window_size, epsilon)
    
    return average_loss


In [310]:
import pandas as pd

def average_loss_analysis(freeze_frame_data, dlc_data):
    """
    Analyzes average loss for each bout and body part across all cohort_id and day combinations.
    
    Parameters:
    - freeze_frame_data: DataFrame containing freeze frame data with columns 'cohort_id', 'day', 'bouts', and other relevant columns.
    - dlc_data: DataFrame containing DLC data in long format with 'cohort_id', 'day', 'body_part', 't(sec)', 'x', 'y', etc.
    
    Returns:
    - result_df: DataFrame with averaged loss values across all cohort_id and day for each bout and body part.
    """
    # Extract unique values from data
    bouts = freeze_frame_data['bouts'].unique()
    cohort_ids = freeze_frame_data['cohort_id'].unique()
    days = freeze_frame_data['day'].unique()
    body_parts = dlc_data['body_part'].unique()
    
    # Initialize an empty dictionary to store results
    results = {body_part: [] for body_part in body_parts}
    results['bout'] = []
    results['Total'] = []
    
    for bout in bouts:
        # Filter the freeze frame data for the specific bout
        bout_freeze_data = freeze_frame_data[freeze_frame_data['bouts'] == bout]
        
        # Initialize a dictionary to store sum of losses for each body part for averaging
        bout_loss_sum = {body_part: 0 for body_part in body_parts}
        count = 0
        
        # Loop through each unique cohort_id and day combination
        for cohort_id in cohort_ids:
            for day in days:
                # Filter data for the current cohort_id and day
                cohort_day_freeze_data = bout_freeze_data[(bout_freeze_data['cohort_id'] == cohort_id) & 
                                                          (bout_freeze_data['day'] == day)]
                cohort_day_dlc_data = dlc_data[(dlc_data['cohort_id'] == cohort_id) & 
                                               (dlc_data['day'] == day)]
                
                # Ensure data exists for the combination
                if cohort_day_freeze_data.empty or cohort_day_dlc_data.empty:
                    print(f"No data found for {cohort_id}, {day}.")    
                    continue
                
                # Run full analysis for each body part and accumulate the loss
                for body_part in body_parts:
                    try:
                        loss = run_full_analysis(cohort_day_freeze_data, cohort_day_dlc_data, body_part=body_part)
                        bout_loss_sum[body_part] += loss
                    except Exception as e:
                        print(f"Error processing {cohort_id}, {day}, {body_part}: {e}")
                        continue
                count += 1
        
        # Calculate average loss for each body part for the current bout
        bout_average_losses = {body_part: (bout_loss_sum[body_part] / count if count > 0 else 0) for body_part in body_parts}
        bout_total = sum(bout_average_losses.values())  # Total loss for the bout
        
        # Append results for the current bout
        for body_part, avg_loss in bout_average_losses.items():
            results[body_part].append(avg_loss)
        results['bout'].append(bout)
        results['Total'].append(bout_total)
    
    # Create a DataFrame from results dictionary
    result_df = pd.DataFrame(results)
    result_df.set_index('bout', inplace=False)
    return result_df




In [311]:
# Example usage
# Assuming freeze_frame_data and dlc_data are already defined and contain the expected columns
result_df = average_loss_analysis(freeze_frame, dlc_frames)

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  freeze_frame_data.loc[:, 'freeze_shift'] = freeze_frame_data['freeze'].shift(1).fillna(0)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  freeze_frame_data.loc[:, 'freeze_shift'] = freeze_frame_data['freeze'].shift(1).fillna(0)
A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  freeze_frame_data.loc[:, 

In [313]:
result_df

Unnamed: 0,nose,head,right_ear,left_ear,neck,back2,back1,back3,back4,tail_base,tail1,tail2,bout,Total
0,67.023886,37.198312,28.169135,32.02622,25.512827,22.101091,15.998696,18.246771,19.079607,26.286063,42.849269,94.548138,1.0,429.040015
1,64.303298,36.207856,27.314291,30.965567,24.036672,20.619159,14.951254,17.374351,17.644341,26.33054,43.425672,91.938714,1.25,415.111714
2,66.234405,36.033979,26.407355,30.364024,23.075254,18.998498,13.813336,16.048777,16.29492,24.999549,42.88591,87.415093,1.5,402.5711
3,63.481166,35.918605,25.527351,30.856444,22.986909,19.429881,13.910743,16.646257,16.583168,25.281777,44.314248,87.243253,1.75,402.179801
4,63.481166,35.918605,25.527351,30.856444,22.986909,19.429881,13.910743,16.646257,16.583168,25.281777,44.314248,87.243253,2.0,402.179801


In [315]:
result_df['Total']-result_df['tail1']-result_df['tail2']

0    291.642608
1    279.747328
2    272.270097
3    270.622300
4    270.622300
dtype: float64