# In [4]:
import pandas as pd
import numpy as np
import csv
import re
import matplotlib.pyplot as plt
import imageio
import os

# In [3]:

class FPDataCapture:
    """Load a pair of force-plate TSV exports and derive foot events and plots.

    The constructor derives two file names from ``base_file_path`` by replacing
    ``".tsv"`` with ``"_f_1.tsv"`` and ``"_f_2.tsv"`` and loads both files.

    Which plate a method reads is decided by the ``MNTRL`` / ``MNTRR`` tag in
    ``base_file_path``: ``identify_foot_lift`` reads plate 1 for ``MNTRL`` and
    plate 2 for ``MNTRR``; every other analysis/plot method reads the opposite
    plate.
    NOTE(review): presumably the event plate is under the foot that gets
    lifted and the other plate is under the stance foot -- confirm against the
    lab setup.
    """

    # Output locations. They are class attributes so a caller (or a test) can
    # override them per class or per instance without touching the methods.
    csv_output_dir = "/Users/danielcopeland/Library/Mobile Documents/com~apple~CloudDocs/MIT Masters/DRL/LABx/RADARTreePose/data/csvs/"
    gif_output_dir = "/Users/danielcopeland/Library/Mobile Documents/com~apple~CloudDocs/MIT Masters/DRL/LABx/RADARTreePose/data/gifs"

    def __init__(self, base_file_path):
        """Load both force-plate files that belong to ``base_file_path``.

        Parameters:
        base_file_path (str): Path ending in ``.tsv``; the two plate files are
            expected next to it with ``_f_1`` / ``_f_2`` inserted before the
            extension.
        """
        self.headers_f_1 = {}
        self.headers_f_2 = {}
        # Used as samples per second throughout (delta_t = 1 / sample_frequency).
        self.sample_frequency = 1200
        self.base_file_path = base_file_path
        self.radar_to_mocap_conversion_table_path = '/Users/danielcopeland/Library/Mobile Documents/com~apple~CloudDocs/MIT Masters/DRL/LABx/RADARTreePose/data/csvs/radar_seconds_per_frame_t0.csv'
        self.data_f_1 = self.import_data(self.base_file_path.replace(".tsv", "_f_1.tsv"), self.headers_f_1)
        self.data_f_2 = self.import_data(self.base_file_path.replace(".tsv", "_f_2.tsv"), self.headers_f_2)
        self.foot_lift_times = None
        self.foot_down_times = None
        self.foot_lift_frames_after_actuator = None
        self.foot_down_frames_after_actuator = None
        # Must exist before plot_rolling_std() runs its "is None" guard.
        self.rolling_std = None
        # Capture id used to look up the conversion table: the file name
        # without directory or extension, with every "MC" replaced by "RR".
        capture_name = os.path.splitext(os.path.basename(self.base_file_path))[0]
        self.RADAR_Capture = capture_name.replace("MC", "RR")
        # Default; overwritten by convert_force_plate_time_to_frames().
        self.seconds_per_frame = 0.036352

    def import_data(self, file_path, headers_dict, output_dir=None, num_metadata_lines=26):
        """Read one force-plate TSV, realign its columns and save a CSV copy.

        Parameters:
        file_path (str): Tab-separated file to read.
        headers_dict (dict): Accepted for interface compatibility; it is
            neither read nor written here.
        output_dir (str | None): Directory for the CSV copy. Defaults to
            ``self.csv_output_dir``.
        num_metadata_lines (int): Number of lines preceding the header row.

        Returns:
        pd.DataFrame: Numeric data with the ``TIME`` column renamed ``time``.
        """
        # The header row is the first line after the metadata lines.
        data = pd.read_csv(file_path, delimiter='\t', header=num_metadata_lines)

        # Anything that is not a number becomes NaN.
        data = data.apply(pd.to_numeric, errors='coerce')
        data = data.rename(columns={"TIME": "time"})

        # Move the index into the first column, drop the last column and put
        # the original header names back, i.e. every value shifts one column
        # to the right under the same names.
        # NOTE(review): this looks like compensation for data rows that carry
        # one field more than the header (pandas then uses the first field as
        # the index) -- confirm against a real export.
        data = data.reset_index(drop=False)
        column_names = list(data.columns)[1:]
        data = data.drop(columns=data.columns[-1])
        data.columns = column_names

        if output_dir is None:
            output_dir = self.csv_output_dir

        # Save the DataFrame next to the other CSVs under the same base name.
        csv_name = os.path.basename(file_path).replace('.tsv', '.csv')
        csv_file_path = os.path.join(output_dir, csv_name)
        data.to_csv(csv_file_path, index=False)
        print(f"Data saved to {csv_file_path}")

        return data

    def _plate_data(self, event_plate):
        """Return the plate DataFrame selected by the filename tag, or None.

        ``event_plate=True`` gives plate 1 for ``MNTRL`` and plate 2 for
        ``MNTRR``; ``event_plate=False`` gives the opposite plate. ``None`` is
        returned when the file name carries neither tag.
        """
        if "MNTRL" in self.base_file_path:
            return self.data_f_1 if event_plate else self.data_f_2
        if "MNTRR" in self.base_file_path:
            return self.data_f_2 if event_plate else self.data_f_1
        return None

    def _require_plate_data(self, event_plate):
        """Like ``_plate_data`` but raise ValueError when the tag is missing."""
        data = self._plate_data(event_plate)
        if data is None:
            raise ValueError("Filename must contain 'MNTRL' or 'MNTRR'")
        return data

    def _require_foot_events(self):
        """Raise ValueError unless identify_foot_lift() has stored its results."""
        if self.foot_lift_times is None or self.foot_down_times is None:
            raise ValueError("Foot lift/down times have not been calculated. Please run identify_foot_lift() first.")

    @staticmethod
    def _debounce(times, min_gap):
        """Keep a time only if it is more than ``min_gap`` after the last kept one."""
        kept = []
        for t in times:
            if not kept or t - kept[-1] > min_gap:
                kept.append(t)
        return kept

    def identify_foot_lift(self):
        """Detect foot-lift and foot-down times from the COP signal.

        A lift is a sample where COP_X and COP_Y are both zero while both were
        non-zero in the previous sample; a down is the reverse transition.
        Candidates are then thinned (7 s debounce, downs within 1 s of a kept
        lift discarded) and passed through ``filter_lift_and_down_times``.

        Returns:
        tuple[list, list]: ``(foot_lift_times, foot_down_times)``; both are
            also stored on the instance.

        Raises:
        ValueError: If the file name contains neither 'MNTRL' nor 'MNTRR'.
        """
        data = self._require_plate_data(event_plate=True)

        # Convert 'time' column to numeric for comparison
        data['time'] = pd.to_numeric(data['time'], errors='coerce')

        prev_x = data['COP_X'].shift(1)
        prev_y = data['COP_Y'].shift(1)
        off_now = (data['COP_X'] == 0) & (data['COP_Y'] == 0)
        on_now = (data['COP_X'] != 0) & (data['COP_Y'] != 0)
        # The first row has no predecessor (NaN): NaN != 0 is True and
        # NaN == 0 is False, so row 0 can start a lift but never a down.
        was_on = (prev_x != 0) & (prev_y != 0)
        was_off = (prev_x == 0) & (prev_y == 0)

        lift_candidates = data.loc[was_on & off_now, 'time']
        down_candidates = data.loc[was_off & on_now, 'time']

        # Lifts closer than 7 seconds to the previously kept lift are dropped.
        lift_times = self._debounce(lift_candidates, 7)

        # Downs within 1 second of any kept lift are dropped, then debounced.
        downs_clear_of_lifts = [
            t for t in down_candidates
            if not any(abs(t - lift_time) <= 1 for lift_time in lift_times)
        ]
        down_times = self._debounce(downs_clear_of_lifts, 7)

        lift_times, down_times = self.filter_lift_and_down_times(lift_times, down_times)

        self.foot_lift_times = lift_times
        self.foot_down_times = down_times
        return self.foot_lift_times, self.foot_down_times

    def filter_lift_and_down_times(self, filtered_lift_times, final_filtered_down_times,
                                   max_lifts=3, max_downs=2):
        """Interleave lifts and downs into lift, down, lift, ... and cap the counts.

        Events are paired by position (first lift, first down, second lift,
        ...); the time values themselves are not compared. When one list runs
        out, further events of the other kind would repeat the previous label
        and are dropped. At most ``max_lifts`` lifts and ``max_downs`` downs
        are kept.

        Returns:
        tuple[list, list]: ``(lift_times, down_times)`` after filtering.
        """
        lifts = []
        downs = []
        last_label = None
        lift_idx = 0
        down_idx = 0

        def lift_pending():
            return lift_idx < len(filtered_lift_times) and len(lifts) < max_lifts

        def down_pending():
            return down_idx < len(final_filtered_down_times) and len(downs) < max_downs

        # Loop only while an append is still possible; looping on
        # "items remain" alone never terminates once a cap blocks the rest.
        while lift_pending() or down_pending():
            if lift_pending():
                # Two lifts in a row collapse to the first one.
                if last_label != "lift":
                    lifts.append(filtered_lift_times[lift_idx])
                    last_label = "lift"
                lift_idx += 1

            if down_pending():
                if last_label != "down":
                    downs.append(final_filtered_down_times[down_idx])
                    last_label = "down"
                down_idx += 1

            if len(lifts) >= max_lifts and len(downs) >= max_downs:
                break

        return lifts, downs

    def convert_force_plate_time_to_frames(self):
        """Convert the stored lift/down times to frame numbers.

        Looks up the row of the conversion table whose ``RADAR_capture`` equals
        ``self.RADAR_Capture`` and stores, for every event time ``t``,
        ``round((t - MOCAP_Start_Time) / Seconds_per_Frame) + RADAR_Start_Frame``
        as integer arrays in ``foot_lift_frames_after_actuator`` and
        ``foot_down_frames_after_actuator``. Also updates
        ``self.seconds_per_frame`` from the table.

        Raises:
        ValueError: If identify_foot_lift() has not been run.
        IndexError: If the table has no row for this capture.
        """
        self._require_foot_events()

        df = pd.read_csv(self.radar_to_mocap_conversion_table_path, delimiter=',')

        print(df)

        print(self.RADAR_Capture)

        matches = df[df['RADAR_capture'] == self.RADAR_Capture]
        if matches.empty:
            raise IndexError(
                f"No row for RADAR capture {self.RADAR_Capture!r} in "
                f"{self.radar_to_mocap_conversion_table_path}"
            )
        row = matches.iloc[0]

        print(row)
        self.seconds_per_frame = row['Seconds_per_Frame']

        def to_frames(times):
            # Explicit array conversion: the event times are plain lists.
            elapsed = np.asarray(times, dtype=float) - row['MOCAP_Start_Time']
            return (np.round(elapsed / row['Seconds_per_Frame']) + row['RADAR_Start_Frame']).astype(int)

        self.foot_lift_frames_after_actuator = to_frames(self.foot_lift_times)
        self.foot_down_frames_after_actuator = to_frames(self.foot_down_times)

        return

    def calculate_rolling_std(self, window_size=100):
        """
        Calculate the rolling standard deviation for the force vectors.

        Parameters:
        window_size (int): The number of samples to include in the rolling window.

        Returns:
        pd.DataFrame: A dataframe with the rolling standard deviation for each force vector.

        Raises:
        ValueError: If the file name contains neither 'MNTRL' nor 'MNTRR'.
        """
        data = self._require_plate_data(event_plate=False)

        print(data.head())

        force_columns = data[['Force_X', 'Force_Y', 'Force_Z']]
        self.rolling_std = force_columns.rolling(window=window_size).std()
        return self.rolling_std

    def plot_rolling_std(self):
        """
        Plot the rolling standard deviation for the force vectors.

        Raises:
        ValueError: If the file name carries no plate tag, or if
            calculate_rolling_std() has not been run yet.
        """
        data = self._require_plate_data(event_plate=False)

        if self.rolling_std is None:
            raise ValueError("Rolling standard deviation has not been calculated. Please run calculate_rolling_std() first.")

        plt.figure(figsize=(14, 7))

        for column in ['Force_X', 'Force_Y', 'Force_Z']:
            plt.plot(data['time'], self.rolling_std[column], label=f'Rolling Std of {column}')

        plt.title('Rolling Standard Deviation of Force Vectors')
        plt.xlabel('Time (s)')
        plt.ylabel('Standard Deviation')
        plt.legend()
        plt.show()

    def plot_force_vectors(self):
        """
        Plot the X, Y, Z force values over time.

        Raises:
        ValueError: If the file name contains neither 'MNTRL' nor 'MNTRR'.
        """
        data = self._require_plate_data(event_plate=False)

        plt.figure(figsize=(14, 7))

        for column in ['Force_X', 'Force_Y', 'Force_Z']:
            plt.plot(data['time'], data[column], label=f'{column}')

        plt.title('Force Vectors Over Time')
        plt.xlabel('Time (s)')
        plt.ylabel('Force (N)')
        plt.legend()
        plt.show()

    def calculate_and_plot_cop_velocity(self):
        """Plot the speed of the centre of pressure over time.

        Speed is the sample-to-sample COP displacement divided by
        ``1 / sample_frequency``. Prints "Invalid file path" and returns
        without plotting when the file name carries no plate tag.
        """
        data = self._plate_data(event_plate=False)
        if data is None:
            print("Invalid file path")
            return

        # Time between consecutive samples.
        delta_t = 1 / self.sample_frequency

        dx = np.diff(data['COP_X'])
        dy = np.diff(data['COP_Y'])
        velocity = np.sqrt(dx**2 + dy**2) / delta_t

        # Differences yield one point fewer than the input, starting at sample 1.
        time = np.arange(1, len(data['COP_X'])) * delta_t

        plt.figure(figsize=(10, 6))
        plt.plot(time, velocity, label='COP Velocity')
        plt.xlabel('Time (seconds)')
        plt.ylabel('Velocity (mm/second)')
        plt.title('Center of Pressure (COP) Velocity Over Time')
        plt.legend()
        plt.show()

    def _render_gif(self, xs, ys, frame_times, gif_folder, gif_path, duration,
                    labels=None, trail_includes_current=False):
        """Draw one PNG per point, assemble them into a GIF, remove the PNGs.

        ``frame_times[i]`` goes into the title of frame ``i``; ``labels[i]``,
        when given and non-empty, is drawn near the top of the axes. The
        temporary frame files are removed even if rendering or encoding fails.
        """
        filenames = []
        try:
            for i in range(len(xs)):
                figure = plt.figure(figsize=(6, 6))
                try:
                    trail_end = i + 1 if trail_includes_current else i
                    plt.plot(xs[:trail_end], ys[:trail_end], 'bo', markersize=2)  # Trail in blue
                    plt.plot(xs[i], ys[i], 'ro', markersize=5)  # Current position in red

                    if labels is not None and labels[i]:
                        plt.text(0.5, 0.95, labels[i], horizontalalignment='center', verticalalignment='center', transform=plt.gca().transAxes, fontsize=12, color='green')

                    plt.xlim([-25, 25])
                    plt.ylim([-45, 45])
                    plt.title(f'Time: {frame_times[i]:.2f} s')

                    filename = os.path.join(gif_folder, f'frame_{i}.png')
                    # Registered before saving so a partial file is cleaned up too.
                    filenames.append(filename)
                    plt.savefig(filename)
                finally:
                    plt.close(figure)

            with imageio.get_writer(gif_path, mode='I', duration=duration) as writer:
                for filename in filenames:
                    writer.append_data(imageio.imread(filename))
        finally:
            for filename in filenames:
                if os.path.exists(filename):
                    os.remove(filename)

    def generate_cop_trace_gif(self, gif_filename, subsample_factor=50):
        """Animate the mean-centred COP trace for the whole recording.

        Every ``subsample_factor``-th sample becomes one frame, labelled
        "Foot Up" or "Foot Down" according to the stored foot events.

        Parameters:
        gif_filename (str): File name of the GIF inside ``self.gif_output_dir``.
        subsample_factor (int): Keep one sample out of this many.

        Returns:
        str | None: Path of the written GIF, or None (after printing
            "Invalid file path") when the file name carries no plate tag.

        Raises:
        ValueError: If identify_foot_lift() has not been run or there are no
            samples to draw.
        """
        data = self._plate_data(event_plate=False)
        if data is None:
            print("Invalid file path")
            return
        self._require_foot_events()

        # Frame numbers at which the label flips; sets give O(1) lookups.
        # NOTE(review): assumes event times are measured from the first sample.
        lift_frames = {int(time * self.sample_frequency / subsample_factor) for time in self.foot_lift_times}
        down_frames = {int(time * self.sample_frequency / subsample_factor) for time in self.foot_down_times}

        cop_x = np.array(data['COP_X'])
        cop_y = np.array(data['COP_Y'])

        # Centre on the mean, then subsample.
        x_sub = (cop_x - np.mean(cop_x))[::subsample_factor]
        y_sub = (cop_y - np.mean(cop_y))[::subsample_factor]

        print(f"Length of COP X sub sample is {len(x_sub)}")
        if len(x_sub) == 0:
            raise ValueError("No COP samples available to render.")

        # The label persists from frame to frame until the next event.
        labels = []
        current_label = "Foot Down"
        for i in range(len(x_sub)):
            if i in lift_frames:
                current_label = 'Foot Up'
            elif i in down_frames:
                current_label = 'Foot Down'
            labels.append(current_label)

        frame_times = [i / self.sample_frequency * subsample_factor for i in range(len(x_sub))]

        gif_folder = self.gif_output_dir
        os.makedirs(gif_folder, exist_ok=True)
        gif_path = f"{gif_folder}/{gif_filename}"

        # Per-frame duration is 40 divided by the frame count; the unit is
        # whatever the installed imageio writer uses -- TODO confirm.
        self._render_gif(x_sub, y_sub, frame_times, gif_folder, gif_path,
                         duration=40 / len(x_sub), labels=labels)

        print(f'GIF saved to {gif_path}')
        return gif_path

    def generate_cop_trace_gif_fu_only(self, gif_filename, subsample_factor=50):
        """Animate the mean-centred COP trace for the "foot up" periods only.

        A period runs from each stored lift time to the down time at the same
        position; if there are fewer downs than lifts, the end of the
        recording is used as one extra down. The stored event lists are not
        modified.

        Parameters:
        gif_filename (str): File name of the GIF inside ``self.gif_output_dir``.
        subsample_factor (int): Step, in samples, between frames of a period.

        Returns:
        str | None: Path of the written GIF, or None (after printing
            "Invalid file path") when the file name carries no plate tag.

        Raises:
        ValueError: If identify_foot_lift() has not been run.
        """
        data = self._plate_data(event_plate=False)
        if data is None:
            print("Invalid file path")
            return
        self._require_foot_events()

        cop_x = np.array(data['COP_X'])
        cop_y = np.array(data['COP_Y'])

        x_centered = cop_x - np.mean(cop_x)
        y_centered = cop_y - np.mean(cop_y)

        # Work on copies so the instance's event lists stay untouched.
        lift_times = list(self.foot_lift_times)
        down_times = list(self.foot_down_times)
        if len(down_times) < len(lift_times):
            # Assume the last "foot up" extends to the end of the recording.
            down_times.append(len(cop_x) / self.sample_frequency)

        # NOTE(review): assumes event times are measured from the first sample.
        lift_indices = [int(time * self.sample_frequency) for time in lift_times]
        down_indices = [int(time * self.sample_frequency) for time in down_times]

        indices_to_plot = []
        for start, end in zip(lift_indices, down_indices):
            # Never step past the end of the data.
            indices_to_plot.extend(range(start, min(end, len(cop_x)), subsample_factor))

        x_sub = x_centered[indices_to_plot]
        y_sub = y_centered[indices_to_plot]
        frame_times = [index / self.sample_frequency for index in indices_to_plot]

        gif_folder = self.gif_output_dir
        os.makedirs(gif_folder, exist_ok=True)
        gif_path = f"{gif_folder}/{gif_filename}"

        self._render_gif(x_sub, y_sub, frame_times, gif_folder, gif_path,
                         duration=subsample_factor / (self.sample_frequency),
                         trail_includes_current=True)

        print(f'GIF saved to {gif_path}')
        return gif_path
# Usage example:
# The constructor takes ONE base path; it loads '<base>_f_1.tsv' and
# '<base>_f_2.tsv' itself. The file name must contain 'MNTRL' or 'MNTRR'.
# fp_data_capture = FPDataCapture('path/to/capture_MNTRL.tsv')  # Replace with actual file path
# lift_times, down_times = fp_data_capture.identify_foot_lift()
# Now you can also access the times directly on the instance
# print(fp_data_capture.foot_lift_times)
# print(fp_data_capture.foot_down_times)
