# Made by Gleb Perevoznyuk, HSE. 2024
## All rights reserved. Do not modify or remove this notice.

# Installation, imports and functions

## For setting up your environment, you only need to install these packages once per machine and environment.


In [None]:
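# (os is part of the Python standard library -- no installation needed; the PyPI package "os-sys" is an unrelated third-party project)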
# !pip install openpyxl
# !pip install pandas
# (pathlib is part of the Python standard library since Python 3.4 -- no installation needed)
# !pip install numpy
# !pip install matplotlib
# !pip install scipy
# !pip install mne
# !pip install screeninfo


## For each session, you need to rerun these import statements to ensure all necessary libraries are loaded.


In [None]:
import os
import openpyxl
import pandas as pd
from pathlib import Path
import re
import mne


import numpy as np

#%matplotlib widget
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt
from screeninfo import get_monitors

import matplotlib.pyplot as plt

from scipy.signal import find_peaks
from scipy.signal import butter, filtfilt

## Functions.

In [None]:
def create_participant_condition_dict(folder_path):
    '''
    Builds a nested dictionary describing every BrainVision header file found under a folder.

    File names are split on underscores. The participant name is taken from the first two
    parts and the condition from the fourth part (plus the fifth part for 'B' conditions).

    Parameters:
    folder_path (str): Path to the folder containing .vhdr files. Sub-folders are searched too.

    Returns:
    dict: A dictionary where the keys are participant names, and the values are nested dictionaries
          with conditions as keys and a nested dictionary containing the file path under 'path'.

    Example:
    Input:
    The specified folder contains the following files:
    MaSh_13_s1_B_1.vhdr
    MaSh_13_s1_BM_1.vhdr

    Output:
    {
        'MaSh_13': {
            'B_1': {'path': 'C:\\path\\to\\folder\\MaSh_13_s1_B_1.vhdr'},
            'BM': {'path': 'C:\\path\\to\\folder\\MaSh_13_s1_BM_1.vhdr'}
        }
    }

    Description:
    1. The function recursively iterates through all files and directories in the specified folder
       (in sorted order, so the result does not depend on the file system's listing order).
    2. Only files with a .vhdr extension that do not end with 'ttt.vhdr' are considered.
    3. The file name is split on '_'. Names with fewer than four parts cannot be interpreted
       and are skipped with a printed message.
    4. The participant name is formed from the first two parts of the file name.
    5. If the fourth part contains 'B' but not 'BM' and a fifth part exists, the condition is
       '<fourth>_<fifth>' (the fifth part is cut at its first '.').
       Otherwise the condition is the fourth part cut at its first '.'.
    6. The file path is stored under the 'path' key; a later file with the same participant and
       condition overwrites an earlier one.
    '''

    participant_dict = {}

    for root, dirs, files in os.walk(folder_path):
        # Sorting in place makes os.walk descend into sub-folders in a reproducible order.
        dirs.sort()
        for file_name in sorted(files):
            if not file_name.endswith('.vhdr') or file_name.endswith('ttt.vhdr'):
                continue

            parts = file_name.split('_')
            if len(parts) < 4:
                # Without a fourth part there is no condition to read.
                print(f'Skipping {file_name}: file name does not match the expected pattern')
                continue

            participant_name = parts[0] + '_' + parts[1]
            is_b_condition = 'B' in parts[3] and 'BM' not in parts[3]
            if is_b_condition and len(parts) > 4:
                condition = parts[3] + '_' + parts[4].split('.')[0]
            else:
                condition = parts[3].split('.')[0]

            conditions = participant_dict.setdefault(participant_name, {})
            conditions[condition] = {'path': os.path.join(root, file_name)}

    return participant_dict
    

In [None]:
def upload_and_filtrate(vhdr_file):
    """
    Reads a BrainVision recording into memory and reports its basic metadata.

    Despite the name, no filtering is performed here; the data are only loaded.

    Parameters:
    vhdr_file (str): Path to the BrainVision .vhdr file.

    Returns:
    tuple: A tuple containing:
        - raw: The object returned by mne.io.read_raw_brainvision (loaded with preload=True).
        - sfreq: The value of raw.info["sfreq"] (sampling frequency).
        - ch_list: The value of raw.ch_names (channel names).
    """
    recording = mne.io.read_raw_brainvision(vhdr_file, preload=True)
    return recording, recording.info["sfreq"], recording.ch_names


In [None]:
def filtrations(raw, fs, high_cutoff=15, order=2, notch_freq_range=(48, 52)):
    """
    Applies a Butterworth high-pass filter and a Butterworth band-stop (notch) filter to the raw data.

    Both filters are applied with scipy.signal.filtfilt (forward-backward filtering) along the
    time axis of raw.get_data(). The input object is not modified; a new RawArray is returned.

    Parameters:
    raw (mne.io.Raw): The raw data object.
    fs (float): The sampling frequency of the data, in Hz.
    high_cutoff (float): Cut-off frequency of the high-pass filter, in Hz. Default: 15.
    order (int): Order passed to scipy.signal.butter for both filters. Default: 2.
    notch_freq_range (sequence of two floats): Lower and upper edge of the band-stop filter,
                                               in Hz. Default: (48, 52).

    Returns:
    mne.io.RawArray: A new raw object holding the filtered data, built with raw.info, and
                     carrying annotations rebuilt from the events of `raw`.
    """
    nyquist = 0.5 * fs

    # High-pass filter (removes content below high_cutoff)
    b, a = butter(order, high_cutoff / nyquist, btype='high')
    filtered_data = filtfilt(b, a, raw.get_data(), axis=1)

    # Band-stop (notch) filter around the given frequency range
    notch_low, notch_high = notch_freq_range[0], notch_freq_range[1]
    br, ar = butter(order, [notch_low / nyquist, notch_high / nyquist], btype='bandstop')
    filtered_data = filtfilt(br, ar, filtered_data, axis=1)

    # Wrap the filtered samples in a new Raw object that reuses the original measurement info
    raw_filtered = mne.io.RawArray(filtered_data, raw.info)

    # Transfer events: annotations -> events -> annotations with fixed descriptions.
    # NOTE(review): per the MNE documentation for a dict `event_desc`, only the event codes
    # listed here are converted; any other marker present in `raw` is presumably dropped.
    events, _ = mne.events_from_annotations(raw)
    event_id = {99999: 'New Segment/', 128: 'Stimulus/S128'}
    annotations = mne.annotations_from_events(events, fs, event_id)
    raw_filtered.set_annotations(annotations)

    return raw_filtered
    

In [None]:
def plot_selected_channels(raw, raw_filtered, channel_names, first_label, second_label, participant_name, condition_name):
    """
    Draws, one subplot per channel, the signal before and after filtering.

    Parameters:
    raw (mne.io.Raw): Object providing the unfiltered signal (and the time axis via raw.times).
    raw_filtered (mne.io.Raw): Object providing the filtered signal.
    channel_names (list of str): Channels to draw, one subplot row each.
    first_label (str): Legend label for the signal taken from `raw`.
    second_label (str): Legend label for the signal taken from `raw_filtered` (drawn dashed).
    participant_name (str): Participant name shown in each subplot title.
    condition_name (str): Condition name shown in each subplot title.

    Returns:
    None. The figure is shown and then closed.
    """
    time_axis = raw.times
    n_rows = len(channel_names)

    # Two inches of height per channel row
    plt.figure(figsize=(15, n_rows * 2))

    for row, name in enumerate(channel_names, start=1):
        before = raw.get_data(picks=name)[0]
        after = raw_filtered.get_data(picks=name)[0]

        plt.subplot(n_rows, 1, row)
        plt.plot(time_axis, before, label=first_label)
        plt.plot(time_axis, after, label=second_label, linestyle='--')
        plt.title(f'{participant_name} - {condition_name} - Signal Before and After Filtering - {name}')
        plt.xlabel('Time (s)')
        plt.ylabel('Amplitude')
        plt.legend()

    plt.tight_layout()
    plt.show()
    plt.close()


In [None]:
def picking_channel(raw, ch_list):
    """
    Splits a raw object into independent single-channel copies.

    Parameters:
    raw (mne.io.Raw): The raw data object; it is copied, never modified.
    ch_list (list of str): Names of the channels to extract.

    Returns:
    dict: Maps each channel name to a copy of `raw` on which pick([channel_name]) was called.
    """
    def _single_channel_copy(name):
        # Copy first so that picking does not alter the caller's object
        single = raw.copy()
        single.pick([name])
        return single

    return {name: _single_channel_copy(name) for name in ch_list}


In [None]:
def epoching_by_channel(channel_data_dict):
    """
    Cuts every single-channel recording into epochs around event code 128 ('TMS pulse').

    Each epoch spans -0.5 s to 1 s relative to the event and is baseline-corrected
    using the interval from -0.1 s to -0.01 s.

    Parameters:
    channel_data_dict (dict): Maps channel names to raw data objects.

    Returns:
    dict: Maps the same channel names to the mne.Epochs objects built from their raw data.
    """
    result = {}
    for name, single_channel_raw in channel_data_dict.items():
        # Only the events array is needed; the automatically derived id mapping is discarded
        events = mne.events_from_annotations(single_channel_raw)[0]
        tms_event_id = {'TMS pulse': 128}
        result[name] = mne.Epochs(
            single_channel_raw,
            events,
            tms_event_id,
            tmin=-0.5,
            tmax=1,
            baseline=(-0.1, -0.01),
            preload=True,
        )
    return result


In [None]:
def process_participant_data(participant_dict, plot_filtering=False):
    """
    Processes the recording of each participant and condition.
    The processing includes loading, filtering and epoching; results are stored in the input dictionary.

    Parameters:
    participant_dict (dict): A dictionary with participant names as keys and conditions as nested dictionaries.
                             Each condition contains the path to the corresponding .vhdr file under 'path'.
    plot_filtering (bool): If True, plot the original and filtered signal of every channel right after
                           each recording is filtered. Default: False (no plots).

    Returns:
    dict: The same dictionary object that was passed in (it is updated in place). Every condition gains the keys
          'sfreq' (sampling frequency), 'ch_list' (channel names) and 'epochs' (dict of epochs per channel).
    """
    for part_name, conditions in participant_dict.items():
        for cond_name, info in conditions.items():
            # Load the recording and store its metadata next to the path
            raw, sfreq, ch_list = upload_and_filtrate(info['path'])
            info['sfreq'] = sfreq
            info['ch_list'] = ch_list

            # Apply filtering to the raw data
            raw_filtered = filtrations(raw, sfreq)

            # Plot immediately instead of keeping every recording in memory until the end
            if plot_filtering:
                plot_selected_channels(raw, raw_filtered, ch_list, 'Original signal', 'Filtered signal', part_name, cond_name)

            # Create epochs for each channel from single-channel copies of the filtered data
            info['epochs'] = epoching_by_channel(picking_channel(raw_filtered, ch_list))

    return participant_dict


In [None]:
def calculate_p2p_amplitude(epoch_data):
    """
    Calculates the peak-to-peak amplitude of one epoch.

    Local maxima and local minima are located with scipy.signal.find_peaks; the amplitude is the
    difference between the highest local maximum and the lowest local minimum.

    Parameters:
    epoch_data (array-like): One-dimensional signal samples (a NumPy array or a plain sequence).

    Returns:
    tuple: (p2p_amplitude, peak_val, trough_val, peak_idx, trough_idx)
        - p2p_amplitude: peak_val - trough_val as a float, or 0 if no local maximum or
          no local minimum was found.
        - peak_val / trough_val: Signal values at the chosen maximum / minimum, or None.
        - peak_idx / trough_idx: Sample indices (positions in epoch_data, not seconds) of the
          chosen maximum / minimum, or None.
    """
    # Accept lists/tuples too: negation below needs an array
    signal = np.asarray(epoch_data)

    peaks, _ = find_peaks(signal)
    troughs, _ = find_peaks(-signal)

    if peaks.size == 0 or troughs.size == 0:
        return 0, None, None, None, None

    # Locate each extremum once and read its value from that position
    peak_idx = peaks[np.argmax(signal[peaks])]
    trough_idx = troughs[np.argmin(signal[troughs])]
    peak_val = signal[peak_idx]
    trough_val = signal[trough_idx]
    p2p_amplitude = float(peak_val) - float(trough_val)

    return p2p_amplitude, peak_val, trough_val, peak_idx, trough_idx


In [None]:
import matplotlib.pyplot as plt
from screeninfo import get_monitors

def _draw_channel_epoch(ax, times, epoch_data, channel_name, epoch_number, time_window):
    """Draws one channel's epoch on `ax`, marks its extrema and returns its peak-to-peak amplitude."""
    p2p_amplitude, peak_val, trough_val, peak_idx, trough_idx = calculate_p2p_amplitude(epoch_data)

    ax.plot(times, epoch_data, label=f'{channel_name} - Epoch {epoch_number}')
    if peak_val is not None and trough_val is not None:
        ax.plot(times[peak_idx], peak_val, 'ro')  # Red dot for the peak
        ax.plot(times[trough_idx], trough_val, 'yo')  # Yellow dot for the trough
        ax.axvline(times[peak_idx], color='r', linestyle='--')  # Red line for the peak
        ax.axvline(times[trough_idx], color='y', linestyle='--')  # Yellow line for the trough
        ax.text(0.5, 0.8, f'P2P Amplitude: {p2p_amplitude:.2f}',
                transform=ax.transAxes, fontsize=10, verticalalignment='top',
                bbox=dict(facecolor='white', alpha=0.8))
    ax.axvline(x=0, color='k', linestyle='-', label='TMS Pulse')
    ax.set_xlim(time_window)
    ax.set_title(channel_name)
    ax.set_xlabel('Time (s)')
    ax.set_ylabel('Amplitude')

    ax.legend(loc='upper left')

    # Drop the first tick on each axis
    ax.set_yticks(ax.get_yticks()[1:])
    ax.set_xticks(ax.get_xticks()[1:])

    return p2p_amplitude


def plot_epochs_for_all_channels(epoched_dict, ch_list, condition, time_window=(-0.5, 0.5), auto_accept=False):
    """
    Plots all channels for each epoch iteratively with the specified time window and collects
    the peak-to-peak amplitude of every accepted epoch.

    For every epoch one figure with a two-column grid of subplots (one per channel) is shown.
    Unless auto_accept is True, the user is then asked to accept ('1') or reject (anything else)
    the epoch; rejected epochs are recorded as NaN for all channels.

    Parameters:
    epoched_dict (dict): Dictionary mapping channel names to epoched data (single-channel epochs).
    ch_list (list of str): List of channel names to plot; each must be a key of epoched_dict.
    condition (str): The condition name, used in titles and as the name of the DataFrame index.
    time_window (tuple): Time window to display on the x-axis.
    auto_accept (bool): If True, automatically accept all epochs; if False, prompt for user input.

    Returns:
    pd.DataFrame: DataFrame containing the peak-to-peak amplitudes, one column per channel and
                  one row per epoch. Amplitudes are computed on data requested in microvolts.
    """
    p2p_data = {muscle: [] for muscle in epoched_dict}

    if ch_list:
        # Size the figure from the first monitor's resolution (100 pixels per inch)
        monitor = get_monitors()[0]
        fig_width = monitor.width / 100  # Adjust scale factor as needed
        fig_height = monitor.height / 100

        # Extract every channel's samples once instead of once per epoch and channel
        channel_data = {name: epoched_dict[name].get_data(units='uV') for name in ch_list}
        channel_times = {name: epoched_dict[name].times for name in ch_list}

        n_rows = len(ch_list) // 2 + len(ch_list) % 2
        n_epochs = len(epoched_dict[ch_list[0]])
    else:
        # Nothing to plot or measure without channels
        n_epochs = 0

    for epoch_idx in range(n_epochs):
        fig, axs = plt.subplots(n_rows, 2, figsize=(fig_width, fig_height), sharex=True)
        fig.suptitle(f'Epoch {epoch_idx + 1} - Condition: {condition}', fontsize=16)
        print(f'Epoch {epoch_idx + 1} - Condition: {condition}')

        # Amplitudes measured while drawing are reused below instead of being recomputed
        epoch_p2p = {}
        for ax, channel_name in zip(axs.flat, ch_list):
            epoch_p2p[channel_name] = _draw_channel_epoch(
                ax,
                channel_times[channel_name],
                channel_data[channel_name][epoch_idx, 0, :],
                channel_name,
                epoch_idx + 1,
                time_window,
            )

        plt.tight_layout()
        plt.show()
        plt.close()

        if auto_accept:
            accepted = True
        else:
            answer = input("Enter 0 to reject this epoch or 1 to accept this epoch: ")
            # Surrounding whitespace must not turn an intended "1" into a rejection
            accepted = answer.strip() == '1'

        for channel_name in ch_list:
            p2p_data[channel_name].append(epoch_p2p[channel_name] if accepted else float('nan'))

    p2p_df = pd.DataFrame(p2p_data)
    p2p_df.index.name = condition

    return p2p_df


In [None]:
def process_epochs_and_plot(participant_data):
    """
    Runs the epoch-by-epoch plotting and amplitude extraction for every participant and condition.

    For each condition, plot_epochs_for_all_channels is called with that condition's
    'epochs' and 'ch_list' entries and the condition name.

    Parameters:
    participant_data (dict): Maps participant names to dictionaries of conditions; every condition
                             dictionary must provide the keys 'epochs' and 'ch_list'.

    Returns:
    dict: Maps participant names to dictionaries that map condition names to the value returned
          by plot_epochs_for_all_channels for that condition.
    """
    return {
        participant: {
            condition: plot_epochs_for_all_channels(info['epochs'], info['ch_list'], condition)
            for condition, info in conditions.items()
        }
        for participant, conditions in participant_data.items()
    }


In [None]:
def save_ptp_amplitude_to_excel(processed_data, folder_path):
    """
    Saves the peak-to-peak amplitude data for each participant and condition into an Excel file.

    One sheet is written per participant. Within a sheet the DataFrames of all conditions are
    placed side by side, with the condition name as the top level of the column header.
    The file is named '<first participant>_All_ptp_amplitude.xlsx' even when several
    participants are present. The input DataFrames are not modified.

    Parameters:
    processed_data (dict): Dictionary mapping participant names to dictionaries that map
                           condition names to DataFrames. Values that are not DataFrames are
                           reported and skipped.
    folder_path (str): Path to the folder where the Excel file will be saved.

    Returns:
    None. If there is nothing to write, a message is printed and no file is created.
    """
    if not processed_data:
        print('No participants to save')
        return

    # Build every sheet first so the workbook is only opened when there is something to write
    sheets = {}
    for part_name, conditions in processed_data.items():
        condition_frames = []

        for cond_name, df in conditions.items():
            if not isinstance(df, pd.DataFrame):
                print(f"No DataFrame found for {part_name} - {cond_name}")
                continue

            print(f"Processing {part_name} - {cond_name}")
            if isinstance(df.columns, pd.MultiIndex):
                condition_df = df
            else:
                # Label a copy so the caller's DataFrame keeps its original columns
                condition_df = df.copy()
                condition_df.columns = pd.MultiIndex.from_product([[cond_name], df.columns])
            condition_frames.append(condition_df)

        if condition_frames:
            combined_df = pd.concat(condition_frames, axis=1)
            # Keep the row index unnamed regardless of how many conditions were combined
            combined_df = combined_df.rename_axis(None, axis=0)
        else:
            combined_df = pd.DataFrame()

        if not combined_df.empty:
            sheets[part_name] = combined_df
        else:
            print(f"No data to write for {part_name}")

    if not sheets:
        print('No data to save')
        return

    participant_name = next(iter(processed_data))
    output_file = os.path.join(folder_path, f'{participant_name}_All_ptp_amplitude.xlsx')

    with pd.ExcelWriter(output_file) as writer:
        for part_name, combined_df in sheets.items():
            combined_df.to_excel(writer, sheet_name=part_name)

    print(f'Saved P2P amplitudes to {output_file}')

# Run all the code below to get data

In [None]:
# Folder that is searched (including sub-folders) for .vhdr recordings; the Excel output
# is later saved into this same folder. Edit the path to match your machine.
folder_path = r'C:\Users\himik\OneDrive\gotlibb\HSE\mep_extr\data\MaSh_13' #!!! CHANGE !!!
# Build the participant -> condition -> {'path': ...} dictionary from the file names
participant_dict = create_participant_condition_dict(folder_path)

# Show the dictionary as the cell output to check that all expected files were found
participant_dict

In [None]:
# Load, filter and epoch every recording; participant_dict itself is updated and returned
processed_data = process_participant_data(participant_dict)
# Show the processed dictionary as the cell output
processed_data

In [None]:
# Plot each epoch, ask for accept/reject on each one, and collect peak-to-peak amplitudes
p2p_results = process_epochs_and_plot(processed_data)

In [None]:
# Show the collected peak-to-peak amplitude tables as the cell output
p2p_results

In [None]:
# Write the amplitude tables to an Excel file inside folder_path
save_ptp_amplitude_to_excel(p2p_results, folder_path)