In [2]:
import os
import ipywidgets as widgets
from IPython.display import display, clear_output
from ipyfilechooser import FileChooser
import numpy as np
import matplotlib.pyplot as plt
from sonpy import lib as sp
from scipy.signal import butter, filtfilt
import pandas as pd
from tqdm.notebook import tqdm  # For progress bars
# Step 1: Modularize the Code
def parse_text_marks(text_marks, time_base, sample_rate):
    """Convert text marks into sample indices, upper-cased labels and trial categories.

    Args:
        text_marks: iterable of text-mark objects; each must expose a ``Tick``
            attribute and a ``GetString()`` method.
        time_base: multiplier turning ticks into seconds (``Tick * time_base``
            is treated as seconds here).
        sample_rate: sampling rate in Hz used to turn seconds into sample indices.

    Returns:
        ``(indices, texts, categories)`` — three parallel lists. A category is
        'PUSH' if the upper-cased label contains 'PUSH', otherwise 'PULL' if it
        contains 'PULL', otherwise 'UNKNOWN'.
    """
    indices, texts, categories = [], [], []
    for mark in text_marks:
        # Round rather than truncate: the float product can land a hair below
        # an exact integer (e.g. 57.99999999999999), and int() alone would then
        # place the mark one sample early.
        indices.append(int(round(mark.Tick * time_base * sample_rate)))
        text = mark.GetString().upper()
        texts.append(text)
        if 'PUSH' in text:
            category = 'PUSH'
        elif 'PULL' in text:
            category = 'PULL'
        else:
            category = 'UNKNOWN'
        categories.append(category)
    return indices, texts, categories

def read_smr_file(file_path):
    """Load every ADC and text-mark channel from a Spike2 data file.

    Args:
        file_path: path to the file (passed through ``str()`` before opening).

    Returns:
        ``(channel_data, time_base)``:
        - ``channel_data`` maps each channel title to a numpy array. ADC
          channels hold ``raw * (GetChannelScale / 6553.6) + GetChannelOffset``.
          Text-mark channels hold the text-mark objects returned by
          ``ReadTextMarks``. Channels of any other type are skipped.
        - ``time_base`` is the value of ``GetTimeBase()``.
    """
    son = sp.SonFile(str(file_path), True)

    # Title -> channel number for every channel that is not switched off.
    titled_channels = {}
    for chan in range(son.MaxChannels()):
        if son.ChannelType(chan) != sp.DataType.Off:
            titled_channels[son.GetChannelTitle(chan)] = chan
    time_base = son.GetTimeBase()

    channel_data = {}
    for title, chan in titled_channels.items():
        kind = son.ChannelType(chan)
        n_bytes = son.ChannelBytes(chan)
        bytes_per_item = son.ItemSize(chan)
        gain = son.GetChannelScale(chan) / 6553.6
        shift = son.GetChannelOffset(chan)
        n_items = n_bytes // bytes_per_item
        if kind == sp.DataType.Adc:
            channel_data[title] = np.array(sp.SonFile.ReadInts(son, chan, n_items, 0)) * gain + shift
        elif kind == sp.DataType.TextMark:
            channel_data[title] = np.array(sp.SonFile.ReadTextMarks(son, chan, n_items, 0))
    return channel_data, time_base

def process_emg(data, fs, band=(20, 450), window_size=0.3):
    """Band-pass filter, full-wave rectify and moving-average smooth an EMG trace.

    Args:
        data: 1-D EMG samples.
        fs: sampling rate in Hz.
        band: ``(low, high)`` band-pass corner frequencies in Hz.
        window_size: moving-average window length in seconds.

    Returns:
        The smoothed envelope, with the same length as ``data``.
    """
    filtered = bandpass_filter(data, band[0], band[1], fs)
    rectified = np.abs(filtered)
    # Build the kernel from the integer tap count so it sums to exactly 1.
    # Dividing ones(int(x)) by the unrounded x left a gain error whenever x was
    # not an integer. Clamp to [1, len] because np.convolve(mode='same')
    # returns max(M, N) samples, so a kernel longer than the data would
    # lengthen the output.
    n_taps = int(round(window_size * fs))
    n_taps = max(1, min(n_taps, len(rectified)))
    kernel = np.ones(n_taps) / n_taps
    return np.convolve(rectified, kernel, mode='same')

def low_pass_filter(data, cutoff, fs, order=5):
    """Apply a zero-phase (forward-backward) Butterworth low-pass filter.

    Args:
        data: samples to filter.
        cutoff: cutoff frequency in Hz.
        fs: sampling rate in Hz.
        order: Butterworth filter order.

    Returns:
        The filtered samples (``scipy.signal.filtfilt`` output).
    """
    nyquist = fs / 2.0
    numerator, denominator = butter(order, cutoff / nyquist, btype='low')
    return filtfilt(numerator, denominator, data)

def bandpass_filter(data, lowcut, highcut, fs, order=4):
    """Apply a zero-phase (forward-backward) Butterworth band-pass filter.

    Args:
        data: samples to filter.
        lowcut: lower corner frequency in Hz.
        highcut: upper corner frequency in Hz.
        fs: sampling rate in Hz.
        order: Butterworth filter order.

    Returns:
        The filtered samples (``scipy.signal.filtfilt`` output).
    """
    nyquist = fs / 2.0
    band_edges = [lowcut / nyquist, highcut / nyquist]
    numerator, denominator = butter(order, band_edges, btype='band')
    return filtfilt(numerator, denominator, data)

def _emg_peak_stats(emg_processed, baseline_len, peak_start, peak_end):
    """Summarise one processed EMG envelope for a single trial.

    Args:
        emg_processed: smoothed EMG envelope over the trial window.
        baseline_len: number of leading samples (the pre-mark part of the
            window) averaged to form the baseline; <= 0 yields a NaN baseline.
        peak_start, peak_end: half-open sample range searched for the peak.

    Returns:
        ``(baseline, peak, delta, percent_change)``. ``percent_change`` is NaN
        when the baseline is zero or not finite.
    """
    if baseline_len > 0:
        baseline = np.mean(emg_processed[:baseline_len])
    else:
        baseline = np.nan
    peak = np.max(emg_processed[peak_start:peak_end])
    delta = peak - baseline
    if np.isfinite(baseline) and baseline != 0:
        percent_change = (delta / baseline) * 100
    else:
        percent_change = np.nan
    return baseline, peak, delta, percent_change


def process_trials(trial_indices, trial_type, adjusted_indices, adjusted_texts):
    """Measure the torque and EMG response of each selected trial.

    Reads the module-level ``channel_data`` (channels 'MZ', 'EMG1', 'EMG2',
    'EMG3'), ``sample_rate`` and ``cutoff_frequency``.

    Args:
        trial_indices: positions into ``adjusted_indices`` to analyse.
        trial_type: label stored in every result (e.g. 'PUSH' or 'PULL').
        adjusted_indices: sample index of each text mark.
        adjusted_texts: label of each mark (not used in the output; kept for
            interface compatibility).

    Returns:
        A list of per-trial dicts. Each holds formatted summary strings plus
        the windows, baselines and peak-search markers used for plotting.
    """
    seconds_before, seconds_after = 1, 5  # analysis window around each mark (s)
    peak_half_width = int(0.05 * sample_rate)  # +/- 50 ms EMG peak search
    mz_full = channel_data['MZ']
    peak_emg_results = []
    for idx in trial_indices:
        first_index = adjusted_indices[idx]

        # Sample range of the analysis window, clipped to the recording.
        start_index = max(first_index - seconds_before * sample_rate, 0)
        end_index = min(first_index + seconds_after * sample_rate, len(mz_full))

        time_window = np.arange(start_index, end_index) / sample_rate
        MZ_window = mz_full[start_index:end_index]
        EMG1_window = channel_data['EMG1'][start_index:end_index]
        EMG2_window = channel_data['EMG2'][start_index:end_index]
        EMG3_window = channel_data['EMG3'][start_index:end_index]

        # Channels may differ in length; trim everything to the shortest.
        min_length = min(len(time_window), len(MZ_window), len(EMG1_window), len(EMG2_window), len(EMG3_window))
        time_window = time_window[:min_length]
        MZ_window = MZ_window[:min_length]
        EMG1_window = EMG1_window[:min_length]
        EMG2_window = EMG2_window[:min_length]
        EMG3_window = EMG3_window[:min_length]

        # Torque: low-pass the window, then find the largest absolute deviation
        # from the mean raw torque over the pre-mark samples.
        MZ_filtered = low_pass_filter(MZ_window, cutoff_frequency, sample_rate)
        pre_mark_mz = mz_full[start_index:first_index]
        mz_baseline = np.mean(pre_mark_mz) if len(pre_mark_mz) else np.nan
        abs_change = np.abs(MZ_filtered - mz_baseline)
        max_change_value = np.max(abs_change)
        max_change_index = np.argmax(abs_change)
        # "Max" is reported as baseline + |delta|; the direction of change is ignored.
        max_value = mz_baseline + max_change_value
        max_change_time = time_window[max_change_index]
        if np.isfinite(mz_baseline) and mz_baseline != 0:
            mz_perc_change = max_change_value / mz_baseline * 100
        else:
            mz_perc_change = np.nan

        EMG1_processed = process_emg(EMG1_window, sample_rate)
        EMG2_processed = process_emg(EMG2_window, sample_rate)
        EMG3_processed = process_emg(EMG3_window, sample_rate)

        # EMG baselines use only the pre-mark samples. The former fixed 2 s
        # span was longer than the 1 s lead-in, so it also averaged the first
        # second after the mark into the baseline.
        baseline_len = min(first_index - start_index, min_length)

        # Search for EMG peaks in a short window centred on the torque peak.
        peak_analysis_start_index = max(max_change_index - peak_half_width, 0)
        peak_analysis_end_index = min(max_change_index + peak_half_width, len(EMG1_processed))

        emg1_baseline, peak_EMG1, delta_EMG1, perc_change_EMG1 = _emg_peak_stats(
            EMG1_processed, baseline_len, peak_analysis_start_index, peak_analysis_end_index)
        emg2_baseline, peak_EMG2, delta_EMG2, perc_change_EMG2 = _emg_peak_stats(
            EMG2_processed, baseline_len, peak_analysis_start_index, peak_analysis_end_index)
        emg3_baseline, peak_EMG3, delta_EMG3, perc_change_EMG3 = _emg_peak_stats(
            EMG3_processed, baseline_len, peak_analysis_start_index, peak_analysis_end_index)

        peak_emg_results.append({
            'trial_type': trial_type,
            'trial_number': idx + 1,
            'MZ_change': f"Delta: {max_change_value:.2f} (Baseline: {mz_baseline:.2f}, Max: {max_value:.2f}, {mz_perc_change:.2f}%)",
            'EMG1 (Soleus)': f"Delta: {delta_EMG1:.5f} (Baseline: {emg1_baseline:.5f}, Max: {peak_EMG1:.5f}, {perc_change_EMG1:.5f}%)",
            'EMG2 (Tibialis Anterior)': f"Delta: {delta_EMG2:.5f} (Baseline: {emg2_baseline:.5f}, Max: {peak_EMG2:.5f}, {perc_change_EMG2:.5f}%)",
            'EMG3 (Gastrocnemius)': f"Delta: {delta_EMG3:.5f} (Baseline: {emg3_baseline:.5f}, Max: {peak_EMG3:.5f}, {perc_change_EMG3:.5f}%)",
            'time_window': time_window,
            'MZ_window': MZ_window,
            'MZ_filtered': MZ_filtered,
            'EMG1_processed': EMG1_processed,
            'EMG2_processed': EMG2_processed,
            'EMG3_processed': EMG3_processed,
            'mz_baseline': mz_baseline,
            'max_change_time': max_change_time,
            'peak_analysis_start_index': peak_analysis_start_index,
            'peak_analysis_end_index': peak_analysis_end_index,
            'emg1_baseline': emg1_baseline,
            'emg2_baseline': emg2_baseline,
            'emg3_baseline': emg3_baseline
        })

    return peak_emg_results


def print_results_table(results):
    """Print a summary table of trial results and return it as a DataFrame.

    Args:
        results: list of per-trial result dicts.

    Returns:
        The printed ``pandas.DataFrame``, or ``None`` (after printing a notice)
        when ``results`` is empty.
    """
    if not results:
        print("No results to display.")
        return

    # (table column, result-dict key) pairs, in display order.
    column_sources = (
        ('Trial Type', 'trial_type'),
        ('Trial Number', 'trial_number'),
        ('MZ Change', 'MZ_change'),
        ('EMG1 (Soleus)', 'EMG1 (Soleus)'),
        ('EMG2 (Tibialis Anterior)', 'EMG2 (Tibialis Anterior)'),
        ('EMG3 (Gastrocnemius)', 'EMG3 (Gastrocnemius)'),
    )
    rows = [{column: result[key] for column, key in column_sources} for result in results]
    table = pd.DataFrame(rows)
    print(table.to_string(index=False))
    return table

def plot_trials(results, trial_type, dynamic=False):
    """Plot torque and EMG for up to the first three trials of one type, and save a PDF.

    Args:
        results: per-trial result dicts.
        trial_type: label used in titles and in the PDF file name.
        dynamic: if True, switch matplotlib to interactive mode before showing.

    The figure is saved to
    ``<original_file_dir>/<original_file_name>-<trial_type>-plot.pdf``, using
    the module-level names set when a file is loaded.
    """
    if not results:
        print(f"No {trial_type} trials to plot.")
        return

    plt.figure(figsize=(18, 12))
    for i, result in enumerate(results[:3]):
        time_window = result['time_window']

        # Left column: raw and filtered torque, baseline, and peak-change marker.
        plt.subplot(3, 2, i * 2 + 1)
        plt.plot(time_window, result['MZ_window'], label='MZ (Original)', alpha=0.5)
        plt.plot(time_window, result['MZ_filtered'], label='MZ (Filtered)', linestyle='--', alpha=0.8)
        plt.axhline(y=result['mz_baseline'], color='g', linestyle='--', label='Baseline')
        # Measure on the filtered trace so the legend value matches the marker
        # position and the tabulated delta. Both are computed from the filtered
        # torque; the raw trace's noise used to inflate this number.
        max_change = np.max(np.abs(result['MZ_filtered'] - result['mz_baseline']))
        plt.axvline(x=result['max_change_time'], color='r', linestyle='--', label=f'Max Change: {max_change:.2f}')
        plt.title(f'MZ Channel (Torque) - {trial_type} Trial {i+1}')
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()

        # Right column: processed EMG envelopes, baselines and peak-search window.
        plt.subplot(3, 2, i * 2 + 2)
        plt.plot(time_window, result['EMG1_processed'], label='Soleus (EMG1)', color='blue')
        plt.plot(time_window, result['EMG2_processed'], label='Tibialis Anterior (EMG2)', color='orange')
        plt.plot(time_window, result['EMG3_processed'], label='Gastrocnemius (EMG3)', color='green')
        plt.axhline(y=result['emg1_baseline'], color='blue', linestyle=':', label='EMG1 Baseline')
        plt.axhline(y=result['emg2_baseline'], color='orange', linestyle=':', label='EMG2 Baseline')
        plt.axhline(y=result['emg3_baseline'], color='green', linestyle=':', label='EMG3 Baseline')
        plt.axvline(x=time_window[result['peak_analysis_start_index']], color='g', linestyle='--', label='Peak Analysis Window Start')
        plt.axvline(x=time_window[result['peak_analysis_end_index']-1], color='b', linestyle='--', label='Peak Analysis Window End')
        plt.title(f'EMG Channels (Processed) - {trial_type} Trial {i+1}')
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()

    plt.tight_layout()
    # Include the trial type in the name. PUSH and PULL are plotted by separate
    # calls, and a shared file name made the second PDF overwrite the first.
    pdf_path = os.path.join(original_file_dir, f"{original_file_name}-{trial_type}-plot.pdf")
    plt.savefig(pdf_path, format='pdf')
    if dynamic:
        plt.ion()
    plt.show()

# Step 2: Add Interactive Widgets for Control
file_chooser = FileChooser()

file_path_text = widgets.Text(value='', placeholder='File path will be displayed here', description='File Path:', disabled=True)
output = widgets.Output()
# Mark-review state shared by the callbacks below; filled in once a file is
# loaded and its marks are reviewed.
adjusted_indices, adjusted_texts, deleted_marks, categories = [], [], [], []
# Latest combined results table. It stays None until data has been processed,
# so the save handler sees "nothing to save" instead of hitting a NameError.
combined_df = None

def update_file_path(chooser):
    """FileChooser callback: load the chosen Spike2 file and reset all per-file state.

    Sets the module-level channel data, text marks (indices/texts/categories),
    sampling parameters and output file name/dir. Also resets the mark-review
    lists and any previous results so they match the newly loaded file.
    """
    # adjusted_indices, adjusted_texts and deleted_marks must be declared
    # global. Without that, the assignments below only bound locals and the
    # shared lists kept the previous file's (or empty) state. combined_df is
    # cleared so results from an earlier file are never saved under this
    # file's name.
    global channel_data, time_base, indices, texts, categories, sample_rate, cutoff_frequency
    global original_file_name, original_file_dir
    global adjusted_indices, adjusted_texts, deleted_marks, combined_df
    with output:
        clear_output()
        file_path = chooser.selected
        if not file_path:
            return
        file_path_text.value = file_path
        sample_rate = 2000  # Hz (fixed here, not read from the file)
        cutoff_frequency = 50  # Low-pass filter cutoff frequency for MZ
        channel_data, time_base = read_smr_file(file_path)
        if 'TextMark' not in channel_data:
            print("Warning: no 'TextMark' channel found; there are no trials to select.")
        indices, texts, categories = parse_text_marks(channel_data.get('TextMark', []), time_base, sample_rate)

        # Reset the review state so it lines up with the new marks.
        adjusted_indices = indices.copy()
        adjusted_texts = texts.copy()
        deleted_marks = [False] * len(indices)
        combined_df = None

        # Outputs (PDF/CSV) are written next to the source file.
        original_file_name = os.path.splitext(os.path.basename(file_path))[0]
        original_file_dir = os.path.dirname(file_path)

        update_selected_trials_count()
        print(f"File loaded: {file_path}")
        print("Text marks available for selection.")


file_chooser.register_callback(update_file_path)

def plot_torque(time_window, MZ_window, idx, text, plot_output, removed=False):
    """Redraw the torque trace around one text mark inside ``plot_output``.

    Args:
        time_window: sample times in seconds, or None to only clear the output.
        MZ_window: torque samples matching ``time_window``, or None.
        idx: sample index of the mark; drawn as a red dashed vertical line.
        text: mark label, used for the legend entry and the title.
        plot_output: Output widget to draw into (cleared first).
        removed: if True, also draw a red "x" at the mark, at the window's
            maximum torque value.
    """
    with plot_output:
        clear_output(wait=True)
        if time_window is None or MZ_window is None:
            return
        plt.figure(figsize=(10, 5))
        plt.plot(time_window, MZ_window, label=f'{text}')
        mark_time = idx / sample_rate
        plt.axvline(x=mark_time, color='r', linestyle='--')
        if removed:
            plt.scatter([mark_time], [np.max(MZ_window)], color='red', marker='x', s=100, label='Removed')
        plt.title(f'Torque Data Around Text Mark: {text}')
        plt.xlabel('Time (s)')
        plt.ylabel('Torque (MZ)')
        plt.legend()
        plt.show()

def update_plot(change, idx, text, adjust_mark, remove_mark, plot_output, i):
    """Observer for one mark's slider and checkbox: record the choice and redraw.

    Args:
        change: widget change notification (unused; None on the initial draw).
        idx: original sample index of the mark (kept for the observer
            signature; the plot follows the slider position).
        text: mark label.
        adjust_mark: slider holding the mark position in seconds.
        remove_mark: checkbox; when ticked the mark is excluded from analysis.
        plot_output: Output widget for this mark's plot.
        i: position of the mark in adjusted_indices / deleted_marks.
    """
    seconds_before = 2
    seconds_after = 3

    # Round, don't truncate: seconds * rate may land just below an integer,
    # which would move the mark one sample earlier on every redraw.
    mark_index = int(round(adjust_mark.value * sample_rate))
    start_index = max(mark_index - seconds_before * sample_rate, 0)
    end_index = min(mark_index + seconds_after * sample_rate, len(channel_data['MZ']))

    time_window = np.arange(start_index, end_index) / sample_rate
    MZ_window = channel_data['MZ'][start_index:end_index]

    if remove_mark.value:
        deleted_marks[i] = True
        # Pass the real window: with None windows plot_torque only cleared the
        # output, so its "Removed" marker was never drawn.
        plot_torque(time_window, MZ_window, mark_index, text, plot_output, removed=True)
    else:
        deleted_marks[i] = False
        adjusted_indices[i] = mark_index
        plot_torque(time_window, MZ_window, mark_index, text, plot_output)

    # Keep the PUSH/PULL selection label in step with the checkboxes.
    update_selected_trials_count()

def plot_torque_around_text_marks():
    """Show a review panel (plot, position slider, remove checkbox) for every text mark.

    Resets adjusted_indices / adjusted_texts / deleted_marks from the loaded
    file's marks, then draws each mark through update_plot.
    """
    global adjusted_indices, adjusted_texts, deleted_marks, categories
    with output:
        clear_output()
        # indices / channel_data / sample_rate only exist once a file has been
        # loaded; show a message instead of raising NameError.
        if 'indices' not in globals():
            print("Please select a file before plotting text marks.")
            return

        seconds_before = 10  # slider range around each mark (s)
        seconds_after = 10

        adjusted_indices = indices.copy()
        adjusted_texts = texts.copy()
        deleted_marks = [False] * len(indices)

        for i, (idx, text) in enumerate(zip(indices, texts)):
            start_index = max(idx - seconds_before * sample_rate, 0)
            end_index = min(idx + seconds_after * sample_rate, len(channel_data['MZ']))

            plot_output = widgets.Output()
            display(plot_output)

            adjust_mark = widgets.FloatSlider(
                value=idx / sample_rate,
                min=(start_index / sample_rate),
                max=(end_index / sample_rate),
                step=0.001,
                description=f'Adjust {text}:',
                continuous_update=False
            )

            remove_mark = widgets.Checkbox(value=False, description='Remove this mark', disabled=False)

            display(adjust_mark, remove_mark)

            # Bind this mark's values as defaults so each handler keeps its own
            # widgets rather than the loop's final ones.
            def _on_change(change, i=i, idx=idx, text=text, adjust_mark=adjust_mark,
                           remove_mark=remove_mark, plot_output=plot_output):
                update_plot(change, idx, text, adjust_mark, remove_mark, plot_output, i)

            adjust_mark.observe(_on_change, names='value')
            remove_mark.observe(_on_change, names='value')

            update_plot(None, idx, text, adjust_mark, remove_mark, plot_output, i)

# Control-panel widgets. The process and save buttons get their handlers
# attached further down, once those callbacks are defined.
plot_torque_button = widgets.Button(description="Plot Torque Around Text Marks")
process_button = widgets.Button(description="Process and Plot Data")
save_button = widgets.Button(description="Save Results as CSV")
selected_trials_label = widgets.Label(value="Selected PUSH: 0 | PULL: 0")


def _on_plot_torque_click(_button):
    """Button handler: rebuild the per-mark review plots."""
    return plot_torque_around_text_marks()


plot_torque_button.on_click(_on_plot_torque_click)

def update_selected_trials_count():
    """Refresh the label showing how many PUSH and PULL marks are still selected (not removed)."""
    kept = [cat for pos, cat in enumerate(categories) if not deleted_marks[pos]]
    n_push = sum(1 for cat in kept if cat == 'PUSH')
    n_pull = sum(1 for cat in kept if cat == 'PULL')
    selected_trials_label.value = f"Selected PUSH: {n_push} | PULL: {n_pull}"

def process_and_plot_data(b):
    """Button callback: analyse every non-removed PUSH/PULL mark, then plot and tabulate the results.

    Stores the combined table in the module-level ``combined_df`` for saving.
    It is None when there is nothing to save.
    """
    global combined_df
    with output:
        clear_output()
        # Drop any previous results so "Save" never writes stale data after
        # an early return below.
        combined_df = None

        n_marks = len(adjusted_indices)
        if not (len(adjusted_texts) == len(categories) == len(deleted_marks) == n_marks):
            print("Text-mark selections are out of sync with the loaded file. "
                  "Click 'Plot Torque Around Text Marks' and review the marks again.")
            return

        kept = [i for i, removed in enumerate(deleted_marks) if not removed]
        selected_indices = [adjusted_indices[i] for i in kept]
        selected_texts = [adjusted_texts[i] for i in kept]
        selected_categories = [categories[i] for i in kept]

        if not selected_indices:
            print("No valid text marks selected. Please review your selections.")
            return

        push_indices = [i for i, t in enumerate(selected_categories) if t == 'PUSH']
        pull_indices = [i for i, t in enumerate(selected_categories) if t == 'PULL']

        # Progress bar for data processing
        with tqdm(total=len(push_indices) + len(pull_indices)) as pbar:
            push_results = process_trials(push_indices, 'PUSH', selected_indices, selected_texts)
            pbar.update(len(push_indices))
            pull_results = process_trials(pull_indices, 'PULL', selected_indices, selected_texts)
            pbar.update(len(pull_indices))

        plot_trials(push_results, 'PUSH')
        plot_trials(pull_results, 'PULL')

        # Print and save results. print_results_table returns None for an
        # empty list, and pd.concat raises when every input is None, so only
        # real tables are concatenated.
        push_df = print_results_table(push_results)
        pull_df = print_results_table(pull_results)
        tables = [df for df in (push_df, pull_df) if df is not None]
        combined_df = pd.concat(tables, ignore_index=True) if tables else None
        update_selected_trials_count()

def save_results(b):
    """Button callback: write the last processed results table to CSV next to the source file."""
    # combined_df may not exist until data has been processed; look it up
    # defensively so an early click reports a message instead of NameError.
    results_df = globals().get('combined_df')
    with output:
        if results_df is None:
            print("No results to save yet. Click 'Process and Plot Data' first.")
            return
        save_path = os.path.join(original_file_dir, f"{original_file_name}-biodexresults.csv")
        results_df.to_csv(save_path, index=False)
        print(f"Results saved as '{save_path}'")

# Attach the remaining button handlers.
process_button.on_click(process_and_plot_data)
save_button.on_click(save_results)

# Render the control panel, top to bottom.
display(
    file_chooser,
    file_path_text,
    plot_torque_button,
    output,
    process_button,
    save_button,
    selected_trials_label,
)


FileChooser(path='C:\Users\anehrujee\Desktop\biodex', filename='', title='', show_hidden=False, select_desc='S…

Text(value='', description='File Path:', disabled=True, placeholder='File path will be displayed here')

Button(description='Plot Torque Around Text Marks', style=ButtonStyle())

Output()

Button(description='Process and Plot Data', style=ButtonStyle())

Button(description='Save Results as CSV', style=ButtonStyle())

Label(value='Selected PUSH: 0 | PULL: 0')