In [1]:
from scipy import stats
from glob import glob
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import savgol_filter
import re
from scipy.signal import find_peaks, peak_widths
from scipy.interpolate import interp1d
pd.set_option('display.max_columns', 500)
from math import sqrt
import os


# Methods

In [22]:
def combine_levels(df: pd.DataFrame) -> pd.DataFrame:
    """Flatten multi-level column labels into single ``"_"``-joined names.

    Every column label (a tuple of level values when the columns are a
    MultiIndex) is joined with underscores, e.g.
    ``("MouseTopRight", "Snout", "x")`` -> ``"MouseTopRight_Snout_x"``.
    The columns of *df* are replaced in place and *df* itself is returned,
    so the function can be used inside a ``.pipe`` chain.
    """
    flat_labels = ["_".join(label_parts) for label_parts in df.columns]
    df.columns = flat_labels
    return df

def drop_likelihood(df: pd.DataFrame) -> pd.DataFrame:
    """Return a new frame without the columns whose name contains "likelihood".

    The input frame is left untouched; only the remaining (coordinate)
    columns are kept, in their original order.
    """
    likelihood_cols = [col for col in df.columns if "likelihood" in col]
    return df.drop(columns=likelihood_cols)

def fix_flip(df, boundary=700):
    """Swap TopRight/BottomLeft labels on frames where the two were mixed up.

    For every body part tracked under both ``MouseTopRight`` and
    ``MouseBottomLeft``, a frame is treated as flipped when the TopRight
    y-coordinate is ``>= boundary`` while the BottomLeft y-coordinate is
    ``< boundary``. On those frames the x and y values of the two labels
    are exchanged. Frames with NaN in either y value are never swapped
    (NaN comparisons are False).

    Parameters
    ----------
    df : pd.DataFrame
        Flat coordinate table with ``<label>_x`` / ``<label>_y`` columns for
        every pair below (see ``combine_levels``).
    boundary : float, default 700
        y-coordinate separating the two labels' expected regions
        (presumably a pixel row in the video frame -- TODO confirm).

    Returns
    -------
    pd.DataFrame
        A corrected copy of *df*; the input frame is not modified.
    """
    problem_columns = [
        ("MouseTopRight_Snout", "MouseBottomLeft_Snout"),
        ("MouseTopRight_Forelimb", "MouseBottomLeft_Forelimb"),
        ("MouseTopRight_Wrist", "MouseBottomLeft_Wrist"),
        ("MouseTopRight_Elbow", "MouseBottomLeft_Elbow"),
        ("MouseTopRight_LowerShoulder", "MouseBottomLeft_LowerShoulder"),
        ("MouseTopRight_UpperShoulder", "MouseBottomLeft_UpperShoulder"),
        ("MouseTopRight_IlliacCrest", "MouseBottomLeft_IlliacCrest"),
        ("MouseTopRight_Hip", "MouseBottomLeft_Hip"),
        ("MouseTopRight_Knee", "MouseBottomLeft_Knee"),
        ("MouseTopRight_Ankle", "MouseBottomLeft_Ankle"),
        ("MouseTopRight_Hindlimb", "MouseBottomLeft_Hindlimb"),
        ("MouseTopRight_TailBase", "MouseBottomLeft_TailBase"),
        ("MouseTopRight_TailCenter", "MouseBottomLeft_TailCenter"),
        ("MouseTopRight_TailTip", "MouseBottomLeft_TailTip"),
    ]
    # Assign through .loc on a frame we own. The previous row-wise version wrote
    # into the Series yielded by iterrows(), which are copies, so no swap ever
    # reached the DataFrame.
    fixed = df.copy()
    for top, bottom in problem_columns:
        top_cols = [f"{top}_x", f"{top}_y"]
        bottom_cols = [f"{bottom}_x", f"{bottom}_y"]
        # Each pair uses its own columns, so deciding from the input frame is
        # equivalent to deciding from the partially corrected copy.
        flipped = (df[f"{top}_y"] >= boundary) & (df[f"{bottom}_y"] < boundary)
        if not flipped.any():
            continue
        top_vals = fixed.loc[flipped, top_cols].to_numpy()
        fixed.loc[flipped, top_cols] = fixed.loc[flipped, bottom_cols].to_numpy()
        fixed.loc[flipped, bottom_cols] = top_vals
    return fixed

def read_raw_data(path: str) -> pd.DataFrame:
    """Load a tracking CSV whose columns are described by a three-row header.

    The very first line of the file is discarded; the following three lines
    become a 3-level column MultiIndex and the first column becomes the row
    index.
    """
    header_rows = [0, 1, 2]
    return pd.read_csv(path, index_col=[0], header=header_rows, skiprows=[0])

def clean_data(raw_df: pd.DataFrame) -> pd.DataFrame:
    """Turn a raw tracking table into a flat, gap-filled coordinate table.

    Steps, in order: name the row index ``"frame"``, fill missing values by
    linear interpolation, flatten the column levels (``combine_levels``),
    remove the likelihood columns (``drop_likelihood``) and finally apply
    ``fix_flip``. Each step returns the frame handed to the next one.
    """
    frame_df = raw_df.rename_axis(index="frame")
    filled_df = frame_df.interpolate(method="linear")
    for transform in (combine_levels, drop_likelihood, fix_flip):
        filled_df = transform(filled_df)
    return filled_df

def get_files():
    """Return the tracking CSV files in the current directory, in analysis order.

    Candidate files are found with the glob ``*DLC*.csv`` and kept only if
    their name contains ``<id>_<condition>_<dpi>dpi``, where condition is one
    of Sham, SCI, Ctl, CCI, Baseline, EAE or mtPst1 and dpi is an integer
    (possibly negative).

    Returns
    -------
    list[str]
        Matching file names sorted by dpi (numerically), then by condition
        name in plain, case-sensitive string order (Baseline < CCI < Ctl <
        EAE < SCI < Sham < mtPst1), then by file name so the order does not
        depend on the platform's directory listing.
    """
    pattern = re.compile(r"(\d*\w+)_(Sham|SCI|Ctl|CCI|Baseline|EAE|mtPst1)_(-?\d+)dpi")
    keyed_files = []
    for path in glob("*DLC*.csv"):
        if match := pattern.search(path):
            _animal_id, condition, dpi = match.groups()
            keyed_files.append((int(dpi), condition, path))

    # Tuple order: dpi, then condition, then file name as a deterministic tie-breaker.
    keyed_files.sort()
    return [path for _dpi, _condition, path in keyed_files]

def get_animal_name(file):
    """Derive a display name for a recording from its file name.

    Returns everything before the first ``"DLC"`` when the name contains it,
    otherwise everything before the first ``".csv"`` (or the whole name if
    neither marker occurs).
    """
    marker = 'DLC' if 'DLC' in file else '.csv'
    return file.partition(marker)[0]

def gait_widths(df, height, draw, sub_vline, neg_vline, sub_heights):
    """Detect swing phases in a limb's vertical trace and summarise them.

    Peaks of column 1 are found with ``find_peaks`` (prominence 0.12,
    height 0; retried with prominence 0.1 if none are found). Each peak's
    width at relative height ``height`` (``peak_widths``) is reported as a
    swing interval.

    Parameters
    ----------
    df : pd.DataFrame
        Two-column frame; column 1 is the trace and row ``k`` is treated as
        frame ``k`` at 200 frames per second.
    height : float
        ``rel_height`` passed to ``peak_widths`` for the swing intervals.
    draw : bool
        When true, plot the trace, peaks, troughs, subthreshold regions and
        the area under each peak.
    sub_vline, neg_vline : list of (float, float)
        Subthreshold / dragging intervals in seconds; only used for drawing
        (intervals also present in ``neg_vline`` are shaded darker).
    sub_heights : list of float
        y-values of the horizontal lines drawn across ``sub_vline``.

    Returns
    -------
    tuple
        ``(vline_points, n_peaks, avg_peak_dist, avg_peak_height, avg_peak_area)``:
        swing intervals as ``(left, right)`` in seconds, the number of peaks,
        the mean time between consecutive peaks (s), the mean peak height,
        and the mean area under each merged near-full-width peak region
        (trace units x seconds). Each average is 0 when there is nothing to
        average.
    """
    fps = 200  # frames per second, as used throughout this file
    dfy = df.iloc[:, 1].to_numpy()
    peaks, peak_dict = find_peaks(dfy, prominence=0.12, height=0)  # prominence can be modified
    if len(peaks) == 0:
        # Retry with a looser prominence before concluding there are no swings.
        peaks, peak_dict = find_peaks(dfy, prominence=0.1, height=0)
    peaks_sec = [p / fps for p in peaks]
    x_lst_sec = [k / fps for k in range(len(dfy))]

    # Mean spacing between consecutive peaks, in seconds.
    peak_dists = [later - earlier for earlier, later in zip(peaks_sec, peaks_sec[1:])]
    if peak_dists:
        avg_peak_dist = sum(peak_dists) / len(peak_dists)
    else:
        print("No peaks found.")
        avg_peak_dist = 0

    # find_peaks always reports 'peak_heights' because height= is given.
    peak_heights = peak_dict['peak_heights']
    avg_peak_height = sum(peak_heights) / len(peak_heights) if len(peak_heights) else 0

    # Area under each peak between the ends of its width at rel_height 0.99;
    # overlapping peak spans are merged and integrated once.
    wide_results_full = peak_widths(dfy, peaks, 0.99)
    wide_vline_points = list(zip(wide_results_full[2], wide_results_full[3]))
    wide_hline_heights = list(wide_results_full[1])
    peak_bounds = merge_regions(wide_vline_points)
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    fill_lst = []
    fill_baselines = []
    sum_area = 0
    for lo, hi in peak_bounds:
        left, right = int(lo), int(hi)
        peak_y = dfy[left:right]
        peak_x = x_lst_sec[left:right]
        sum_area += trapezoid(peak_y, x=peak_x)
        fill_lst.append(list(zip(peak_x, peak_y)))
        # Several peaks may have been merged into this region, so its fill
        # baseline is the lowest width-evaluation height among those peaks
        # (indexing the per-peak heights by region position mismatched them).
        member_heights = [
            h for (start, end), h in zip(wide_vline_points, wide_hline_heights)
            if lo <= start and end <= hi
        ]
        fill_baselines.append(min(member_heights) if member_heights else 0)
    avg_peak_area = sum_area / len(peak_bounds) if peak_bounds else 0

    troughs, _ = find_peaks(-dfy, prominence=0.2)  # prominence can be modified
    if len(troughs) == 0:
        troughs, _ = find_peaks(-dfy, prominence=0.01)
    troughs_sec = [t / fps for t in troughs]
    # rel_height is a fixed parameter; it may need to adapt per recording
    # unless fisheye distortion is removed / the starting position is the same.
    results_full = peak_widths(dfy, peaks, height)
    results_full_sec = [r / fps for r in results_full]
    vline_points = list(zip(results_full_sec[2], results_full_sec[3]))

    if draw:
        plt.figure(figsize=(14, 10))
        plt.rcParams["axes.linewidth"] = 2

        plt.plot(x_lst_sec, dfy, color='darkgray', linewidth=3)
        plt.plot(peaks_sec, dfy[peaks], "o", markersize=10, color='royalblue')
        plt.plot(troughs_sec, dfy[troughs], "o", markersize=10, color='lightsalmon')
        plt.tick_params(axis='x', labelsize=18, length=7, width=2)
        plt.tick_params(axis='y', labelsize=18, length=7, width=2)

        if len(sub_vline) != 0:
            sub_left, sub_right = zip(*sub_vline)
            plt.hlines(sub_heights, sub_left, sub_right, color="lightcoral")
            for widths in sub_vline:
                shade = 'firebrick' if widths in neg_vline else 'lightcoral'
                plt.axvspan(*widths, alpha=0.3, color=shade)

        # The lowest trough marks the floor; skip it when no trough exists
        # (min() of an empty sequence would raise).
        if len(troughs) > 0:
            plt.axhline(min(dfy[troughs]), color='black', linewidth=2)

        for points, baseline in zip(fill_lst, fill_baselines):
            if not points:
                continue  # region narrower than one frame: nothing to shade
            x, y = zip(*points)
            plt.fill_between(x=x, y1=y, y2=baseline, color="paleturquoise", alpha=0.5)
        plt.show()

    return vline_points, len(peaks), avg_peak_dist, avg_peak_height, avg_peak_area

def illustrate_dist(curve, dfy, floor, title=None):
    """Plot a trace together with its fitted baseline curve and a floor line.

    Parameters
    ----------
    curve : sequence of float
        Baseline value at every frame (drawn in crimson).
    dfy : sequence of float
        Trace value at every frame (drawn in grey); frames are converted to
        seconds at 200 frames per second.
    floor : float
        y-value drawn as a horizontal black reference line.
    title : str, optional
        Figure title. When omitted, the title is built from the notebook
        globals ``animalName``, ``labels2`` and ``i`` exactly as before, so
        existing callers keep their titles.
    """
    if title is None:
        title = f"{animalName} {labels2[i]}"
    seconds = [frame / 200 for frame in range(len(dfy))]
    plt.figure(figsize=(14, 10))
    plt.rcParams["axes.linewidth"] = 2
    plt.plot(seconds, dfy, color='darkgray', linewidth=3)
    plt.plot(seconds, curve, color='crimson', linewidth=1.5)
    plt.tick_params(axis='x', labelsize=18, length=7, width=2)
    plt.tick_params(axis='y', labelsize=18, length=7, width=2)
    plt.axhline(floor, color='black', linewidth=2)
    plt.title(title, fontsize=18)
    # The original referenced plt.show without calling it.
    plt.show()

def fisheye_dist(df, illustrate):
    """Subtract a piecewise-linear baseline fitted through the trace's troughs.

    Troughs of column 1 (``find_peaks`` on the negated trace, prominence
    0.07) are used as anchor points. A piecewise-linear curve through them,
    held flat from frame 0 to the first trough and from the last trough to
    the final frame, is subtracted from the trace. Before subtracting,
    samples before the first trough / from the last trough onwards are
    raised to at least that trough's value so the ends do not create dips.

    Parameters
    ----------
    df : pd.DataFrame
        Two-column frame; column 0 is copied through unchanged and column 1
        is corrected. Row ``k`` is treated as frame ``k`` at 200 fps. The
        frame is not modified.
    illustrate : bool
        When true, plot the (end-clamped) trace with the fitted baseline.

    Returns
    -------
    pd.DataFrame
        New frame with the same index and column names as *df*.

    Raises
    ------
    ValueError
        If no trough is found, since no baseline can be fitted.
    """
    fps = 200
    # Work on a private copy: to_numpy() may return a view of df's data (or a
    # read-only view under pandas Copy-on-Write), and the clamping below writes.
    dfy = df.iloc[:, 1].to_numpy(dtype=float, copy=True)
    n_frames = len(dfy)
    troughs, _ = find_peaks(-dfy, prominence=0.07)
    if len(troughs) == 0:
        raise ValueError("fisheye_dist: no troughs found; cannot fit a baseline")
    y_lst = list(dfy[troughs])
    first_y, last_y = y_lst[0], y_lst[-1]

    # Anchor the baseline at every trough, plus both ends of the recording.
    anchors_sec = np.array([0] + list(troughs) + [n_frames - 1]) / fps
    anchors_y = [first_y] + y_lst + [last_y]
    piecewise_linear = interp1d(anchors_sec, anchors_y, kind='linear', fill_value='extrapolate')
    curve = piecewise_linear(np.arange(n_frames) / fps)

    # Removes low extremes at either end of the dataset.
    dfy[:troughs[0]] = np.maximum(dfy[:troughs[0]], first_y)
    dfy[troughs[-1]:] = np.maximum(dfy[troughs[-1]:], last_y)

    new_df = pd.DataFrame()
    new_df[df.columns[0]] = df.iloc[:, 0]
    new_df[df.columns[1]] = dfy - curve

    if illustrate:
        illustrate_dist(curve, dfy, min(y_lst))

    return new_df

def correct_units(df, draw, px_per_unit=24, fps=200):
    """Rescale a two-column coordinate trace and re-index it in seconds.

    Both columns are divided by ``px_per_unit`` (24 by default; presumably
    pixels per centimetre -- TODO confirm the calibration), and the frame
    index is divided by ``fps`` (200 by default) and named ``'seconds'``.

    Parameters
    ----------
    df : pd.DataFrame
        Two-column (x, y) frame indexed by frame number. Not modified.
    draw : bool
        When true, plot the rescaled column 1 against time.
    px_per_unit : float, default 24
        Divisor applied to both coordinate columns.
    fps : float, default 200
        Divisor applied to the index to turn frames into seconds.

    Returns
    -------
    pd.DataFrame
        New frame with the same column names, rescaled values and a
        ``'seconds'`` index.
    """
    new_df = pd.DataFrame()
    new_df[df.columns[0]] = df.iloc[:, 0] / px_per_unit
    new_df[df.columns[1]] = df.iloc[:, 1] / px_per_unit
    new_df.index = df.index / fps
    new_df.index.name = 'seconds'

    if draw:
        dfy = new_df.iloc[:, 1].to_numpy()
        x_lst_sec = [x / fps for x in range(len(dfy))]
        plt.figure(figsize=(14, 10))
        plt.plot(x_lst_sec, dfy, color='darkgray', linewidth=3)
        # The original referenced plt.show without calling it.
        plt.show()

    return new_df

def remove_subthresholds(df, draw):
    """Flatten low ("subthreshold") peaks in a limb's vertical trace.

    Peaks of column 1 (``find_peaks`` with prominence 0.03 and height 0)
    whose height is below 0.2 are treated as attempted steps where the foot
    was not lifted high enough. Their full-width intervals
    (``peak_widths`` at rel_height 1) are merged where they overlap, and
    inside each merged interval the trace is replaced by the straight line
    joining its two end samples.

    Parameters
    ----------
    df : pd.DataFrame
        Two-column frame; column 1 is the trace and row ``k`` is treated as
        frame ``k`` at 200 fps. The frame is not modified.
    draw : bool
        When true, plot the trace with peaks, troughs and subthreshold
        regions.

    Returns
    -------
    tuple
        ``(new_df, sub_time, neg_time, sub_vline_points, neg_vline_points,
        hline_heights)``. Times are in seconds; the vline points are merged
        ``(left, right)`` intervals in seconds; ``hline_heights`` is the mean
        of the trace at each interval's two end samples. The "dragging"
        (negative) classification is currently disabled, so ``neg_time`` is 0
        and ``neg_vline_points`` is empty. When no peaks are found at all,
        the input frame itself is returned with zero times and empty lists.
    """
    fps = 200
    # Work on a private copy: to_numpy() may return a view of df's data (or a
    # read-only view under pandas Copy-on-Write), and the flattening writes.
    dfy = df.iloc[:, 1].to_numpy(dtype=float, copy=True)
    peaks, peak_dict = find_peaks(dfy, prominence=0.03, height=0)  # 0.01 prominence can be modified
    peaks_sec = [p / fps for p in peaks]
    troughs, _ = find_peaks(-dfy, prominence=0.03, height=0)  # 0.01
    troughs_sec = [t / fps for t in troughs]
    results_full = peak_widths(dfy, peaks, 1)
    peak_heights = peak_dict['peak_heights']

    if len(peak_heights) == 0:
        return df, 0, 0, [], [], []

    # Peaks below 0.2: the animal tried to take a step but the foot was too
    # low for the step to be considered successful.
    subthreshold_peaks = [k for k, peak_height in enumerate(peak_heights) if peak_height < 0.2]
    # "Dragging" peaks (previously height < 0.1) are disabled; the list stays
    # empty so the return shape is unchanged.
    neg_peaks = []
    if len(subthreshold_peaks) == 0:
        print("No subthreshold peaks found.")

    left_sec = results_full[2] / fps
    right_sec = results_full[3] / fps
    # Merge any overlapping peak intervals.
    sub_vline_points = merge_regions([(left_sec[k], right_sec[k]) for k in subthreshold_peaks])
    neg_vline_points = merge_regions([(left_sec[k], right_sec[k]) for k in neg_peaks])

    # Height of the line drawn across each region: mean of its two end samples
    # (taken before any flattening).
    hline_heights = [
        (dfy[int(start * fps)] + dfy[int(end * fps)]) / 2
        for start, end in sub_vline_points
    ]
    if sub_vline_points:
        line_left, line_right = zip(*sub_vline_points)
    else:
        line_left, line_right = [], []

    if draw:
        x_lst_sec = [k / fps for k in range(len(dfy))]
        plt.figure(figsize=(14, 10))
        plt.rcParams["axes.linewidth"] = 2

        plt.plot(x_lst_sec, dfy, color='darkgrey', linewidth=3)
        plt.plot(peaks_sec, dfy[peaks], "o", markersize=10, color='royalblue')
        plt.plot(troughs_sec, dfy[troughs], "o", markersize=10, color='lightsalmon')
        plt.hlines(hline_heights, line_left, line_right, color='lightcoral')
        plt.tick_params(axis='x', labelsize=18, length=7, width=2)
        plt.tick_params(axis='y', labelsize=18, length=7, width=2)
        plt.axhline(min(dfy), color='black', linewidth=2)

        for widths in sub_vline_points:
            shade = 'firebrick' if widths in neg_vline_points else 'lightcoral'
            plt.axvspan(*widths, alpha=0.3, color=shade)

        plt.show()

    # Replace each subthreshold region with the straight line joining its ends.
    pre_sub_time = 0
    for start, end in sub_vline_points:
        x1, x2 = int(start * fps), int(end * fps)
        # Regions narrower than one frame have nothing to flatten (and would
        # otherwise divide by zero).
        if x2 > x1:
            slope = (dfy[x2] - dfy[x1]) / (x2 - x1)
            intercept = dfy[x1] - slope * x1
            dfy[x1:x2] = slope * np.arange(x1, x2) + intercept
        pre_sub_time += end - start

    neg_time = 0
    for start, end in neg_vline_points:
        neg_time += end - start
    sub_time = pre_sub_time - neg_time

    new_df = pd.DataFrame()
    new_df[df.columns[0]] = df.iloc[:, 0]
    new_df[df.columns[1]] = dfy

    return new_df, sub_time, neg_time, sub_vline_points, neg_vline_points, hline_heights

def descriptive_statistics(lst):
    """Return ``(mean, population standard deviation)`` of *lst*.

    The variance divides by ``len(lst)`` (not ``len(lst) - 1``). An empty
    sequence raises ``ZeroDivisionError``.
    """
    size = len(lst)
    average = sum(lst) / size
    squared_deviations = [(value - average) ** 2 for value in lst]
    return average, sqrt(sum(squared_deviations) / size)

def merge_regions(regions):
    """Merge overlapping or touching ``(start, end)`` intervals.

    Intervals are processed in order of their start. An interval whose start
    is ``<=`` the current merged end is absorbed into it (the merged result
    becomes a new tuple); otherwise it is appended unchanged. A falsy input
    (``None`` or empty) yields an empty list.
    """
    if not regions:
        return []

    merged = []
    for region in sorted(regions, key=lambda interval: interval[0]):
        if merged and region[0] <= merged[-1][1]:
            last_start, last_end = merged[-1][0], merged[-1][1]
            merged[-1] = (last_start, max(last_end, region[1]))
        else:
            merged.append(region)
    return merged

def improved_smoothing(df, window_length, polyorder, peak_radius, draw):
    """Savitzky-Golay smooth column 1 of *df* while keeping its large extrema.

    Steps (column 0 is copied through unchanged):
    1. Smooth the whole trace with ``savgol_filter(window_length, polyorder)``.
    2. At extrema of the smoothed trace (prominence 0.5), keep whichever of
       the smoothed / original values is more extreme.
    3. At extrema of the original trace (prominence 0.5), restore the
       original values.
    4. Around each original maximum, re-smooth the samples
       ``[peak - peak_radius, peak + peak_radius)`` with a local
       Savitzky-Golay filter (window ``peak_radius``, polyorder 3) and put
       the maximum's value back.
    5. For each maximum that is still strictly above both neighbours, pull
       the one- and two-away neighbours towards it.

    Parameters
    ----------
    df : pd.DataFrame
        Two-column frame; column 1 is smoothed.
    window_length, polyorder : int
        Parameters of the global ``savgol_filter`` pass (step 1 only).
    peak_radius : int
        Half-width of the local window around each maximum; also used as the
        local filter's window length.
    draw : bool
        When true, plot the smoothed trace against time (200 frames/s).

    Returns
    -------
    pd.DataFrame
        New frame with column 0 unchanged and column 1 smoothed.
    """
    dfy = df.iloc[:, 1].to_numpy()
    
    # apply the Savitzky-Golay filter to the whole trace (returns a new array)
    y_lst = savgol_filter(dfy, window_length, polyorder)
    
    # find local maxs/mins in the original data
    maximums, _ = find_peaks(dfy, prominence=0.5)
    minimums, _ = find_peaks(-dfy, prominence=0.5)
    
    # find local maximums and minimums in the smoothed data
    smooth_maximums, _ = find_peaks(y_lst, prominence=0.5)
    smooth_minimums, _ = find_peaks(-y_lst, prominence=0.5)
    
    # at the smoothed extrema, keep the more extreme of smoothed vs original value
    y_lst[smooth_maximums] = np.maximum(y_lst[smooth_maximums], dfy[smooth_maximums])
    y_lst[smooth_minimums] = np.minimum(y_lst[smooth_minimums], dfy[smooth_minimums])
    
    # overwrite maxs/mins in smooth data with maxs/mins in original data
    y_lst[maximums] = dfy[maximums]
    y_lst[minimums] = dfy[minimums]

    # smooth the samples around each original maximum locally
    for peak in maximums:
        left = max(0, peak - peak_radius)
        right = min(len(dfy), peak + peak_radius)
        # slice of y_lst; savgol_filter below returns a new array, so this
        # still holds the pre-local-smoothing values when the peak is restored
        peak_range = y_lst[left:right]
        
        # apply the Savitzky-Golay filter with window peak_radius.
        # NOTE(review): polyorder is hard-coded to 3 here rather than using the
        # `polyorder` argument; scipy requires polyorder < window_length, so
        # peak_radius must exceed 3 -- confirm this is intended.
        smooth_peak_range = savgol_filter(peak_range, window_length=peak_radius, polyorder=3)
        
        # ensure the maximum itself does not change (minimums are not restored here)
        peak_index = peak - left
        smooth_peak_range[peak_index] = peak_range[peak_index]
        
        # overwrite the values of the original list with the smoothed values
        y_lst[left:right] = smooth_peak_range

    # forcefully smooth remaining sharp points around each original maximum
    for peak in maximums:
        # neighbour indices, clamped to the array bounds
        left_1 = max(0, peak - 1)
        left_2 = max(0, peak - 2)
        right_1 = min(len(y_lst) - 1, peak + 1)
        right_2 = min(len(y_lst) - 1, peak + 2)
        
        if y_lst[peak] > y_lst[left_1] and y_lst[peak] > y_lst[right_1]:
            # set the two-away points halfway between the peak and its one-away
            # neighbour (uses the one-away values before they are changed below)
            if left_2 != left_1:
                y_lst[left_2] = y_lst[peak] - (y_lst[peak] - y_lst[left_1]) * 0.5
            if right_2 != right_1:
                y_lst[right_2] = y_lst[peak] - (y_lst[peak] - y_lst[right_1]) * 0.5
            
            # move the one-away points 90% of the way up to the peak
            y_lst[left_1] = y_lst[peak] - (y_lst[peak] - y_lst[left_1]) * 0.1
            y_lst[right_1] = y_lst[peak] - (y_lst[peak] - y_lst[right_1]) * 0.1

    new_df = pd.DataFrame()
    new_df[df.columns[0]] = df.iloc[:,0]
    new_df[df.columns[1]] = y_lst
    
    if draw:
        dfy = new_df.iloc[:, 1].to_numpy()
        x_lst_sec = [x/200 for x in range(0,len(dfy))]
        plt.figure(figsize=(14, 10))
        plt.rcParams["axes.linewidth"] = 2 
        plt.plot(x_lst_sec, dfy, color='darkgray', linewidth=3)
        plt.tick_params(axis='x', labelsize=18, length=7, width=2)
        plt.tick_params(axis='y', labelsize=18, length=7, width=2)
        plt.show()
    
    return new_df

In [24]:
# Analysis driver: for every file returned by get_files(), run the per-limb
# pipeline (correct_units -> improved_smoothing -> fisheye_dist ->
# remove_subthresholds -> gait_widths) on four limbs, collect eight metrics
# per limb, and write one column per animal to an Excel workbook.
#
# NOTE: illustrate_dist() builds its figure title from the module-level names
# `animalName`, `labels2` and `i` bound by the loops below, so keep those
# names at module scope when plotting is enabled.

# These parameters can be modified to fine-tune analysis
window_length = 17  # Savitzky-Golay window for improved_smoothing
polyorder = 3  # Savitzky-Golay polynomial order for improved_smoothing
peak_radius = 4  # half-width (frames) of the local re-smoothing around each maximum
height = 0.19  # rel_height passed to peak_widths in gait_widths to measure swing intervals
draw = False  # True plots every pipeline stage for every limb

# (x, y) column pairs per limb, in the same order as the display names in labels2
labels = [('MouseBottomLeft_Hindlimb_x', 'MouseBottomLeft_Hindlimb_y'),
          ('MouseTopRight_Hindlimb_x', 'MouseTopRight_Hindlimb_y'),
          ('MouseBottomLeft_Forelimb_x', 'MouseBottomLeft_Forelimb_y'),
          ('MouseTopRight_Forelimb_x', 'MouseTopRight_Forelimb_y')]
labels2 = ['Left Hindlimb', 'Right Hindlimb', 'Left Forelimb', 'Right Forelimb']
mouse_dfs = []  # one (Limb, Metric) x animal frame per file
files = get_files()

for file in files:
    master_lst = []  # one metrics dict per limb, in labels order
    animalName = get_animal_name(file)
    print(animalName)

    df = read_raw_data(file)
    dfc = clean_data(df)

    for i in range(0,4):

        limb_dict = {}
        limbName = labels2[i]  # not used further in this loop
        xy = dfc.loc[:, labels[i]]
        
        # Indices 0 and 2 are the BottomLeft (left-side) labels. If the left side of the
        # animal is being viewed, it must be flipped (both x and y are negated). This is
        # because the MotoRater records it up-side down.
        if(i%2 == 0):
            xy = -xy

        units_xy = correct_units(xy,draw) # Divides coordinates by 24 and the frame index by 200 (index becomes seconds)
        smooth_xy = improved_smoothing(units_xy,window_length,polyorder,peak_radius,draw) # Smooth data while keeping the large local extrema
        undist_xy = fisheye_dist(smooth_xy,draw) # Subtracts a baseline fitted through the troughs ("fisheye distortion"). When True, plots the baseline.
        final_xy, time_subthreshold, time_neg, sub_vline, neg_vline, sub_heights = remove_subthresholds(undist_xy,draw)

        widths, num_peaks, avg_peak_dist, avg_peak_height, avg_peak_area = gait_widths(final_xy,height,draw,sub_vline,neg_vline,sub_heights) #When True, shows the gait profile by plotting limb y-coords against time.

        # Swing time: total length (s) of the peak-width intervals from gait_widths.
        time_swing = sum([(t[1] - t[0]) for t in widths])
        # Stance time: recording length (frames / 200, i.e. seconds) minus swing and
        # subthreshold time. NOTE(review): assumes swing and subthreshold intervals
        # do not overlap -- confirm.
        time_stance = len(xy)/200 - time_swing - time_subthreshold

        limb_dict = {'Time in Swing': time_swing,
                     'Time in Stance': time_stance,
                     'Time in Swinging Subthreshold': time_subthreshold,
                     'Time in Dragging Subthreshold' : time_neg,
                     'Number of Swings': num_peaks,
                     'Average Peak Distance': avg_peak_dist,
                     'Average Peak Height' : avg_peak_height,
                     'Average Peak Area' : avg_peak_area}
        master_lst.append(limb_dict)
    
    # Turn each limb's metrics into a single-column frame indexed by (Limb, Metric)
    # and stack them into one column for this animal.
    dataframes = []
    for i, measure in enumerate(master_lst):
            iterables = [[labels2[i]], measure.keys()]
            index = pd.MultiIndex.from_product(iterables, names=["Limb", "Metric"])
            dataframe = pd.DataFrame(
                data=measure.values(),
                index=index,
                columns=[animalName]
            )
            dataframes.append(dataframe)
    stacked_df = pd.concat(dataframes)
    mouse_dfs.append(stacked_df)
# One column per animal, rows are (Limb, Metric).
final_df = pd.concat(mouse_dfs, axis=1)
final_df.to_excel("MotorRater Oscillation EAE E3.xlsx")


06_Ctl_-1dpi_Ex3_4
No subthreshold peaks found.
07_Ctl_-1dpi_Ex3_1
No subthreshold peaks found.
No subthreshold peaks found.
08_Ctl_-1dpi_Ex3_2
No subthreshold peaks found.
No subthreshold peaks found.
09_Ctl_-1dpi_Ex3_4
No subthreshold peaks found.
No subthreshold peaks found.
10_Ctl_-1dpi_Ex3_3
No subthreshold peaks found.
No subthreshold peaks found.
No subthreshold peaks found.
01_EAE_-1dpi_Ex3_6
No subthreshold peaks found.
No subthreshold peaks found.
No subthreshold peaks found.
02_EAE_-1dpi_Ex3_4
No subthreshold peaks found.
No subthreshold peaks found.
No subthreshold peaks found.
03_EAE_-1dpi_Ex3_4
No subthreshold peaks found.
No subthreshold peaks found.
No subthreshold peaks found.
04_EAE_-1dpi_Ex3_2
No subthreshold peaks found.
No subthreshold peaks found.
No subthreshold peaks found.
No subthreshold peaks found.
05_EAE_-1dpi_Ex3_6
No subthreshold peaks found.
No subthreshold peaks found.
No subthreshold peaks found.
06_Ctl_1dpi_Ex3_7
No subthreshold peaks found.
No subthr