# This notebook performs all the EMG processing, primarily using MNE together with custom functions that correct artefacts via a template-matching approach

In [None]:
import scipy.io
import numpy as np
import matplotlib.pyplot as plt
import mne
from mne import create_info
from mne.io import RawArray
import pickle
from scipy.signal import firwin, lfilter
from scipy.fftpack import fft
import os
import pandas as pd
from scipy.stats import linregress
import scipy.signal
from BBO_Analysis_Functions import infer_rights, correct_drift, find_correlation_peaks, extract_correlations, correct_data_with_template, correct_data_with_template_2
import matplotlib.pyplot as plt
import matplotlib
matplotlib.use('Qt5Agg')

In [None]:
# Folder that holds the raw EMG recordings (.mat files)
EMGpath = "Path to EMG data/EMG/" 
# Collect all .mat files in the folder, leaving out test recordings.
# os.listdir returns names in arbitrary order, so the names are sorted: the
# processing loop below asserts that a file's position (index + 1) equals its
# participant number.
# NOTE(review): plain string sorting only gives numeric order if the numbers
# in the file names are zero-padded - confirm against the real file names.
# A new list is built instead of calling EMGfiles.remove() while iterating over
# EMGfiles: removing during iteration skips the element that follows each
# removed one, so two consecutive 'test' files would leave the second one in.
EMGfiles = []
for EMGfile in sorted(f for f in os.listdir(EMGpath) if f.endswith('.mat')):
    # if the file name contains test, report it and leave it out
    if 'test' in EMGfile:
        print(EMGfile)
        continue
    EMGfiles.append(EMGfile)


# Process every recording: `index` is the position in EMGfiles, `EMGfile` the file name
for index, EMGfile in enumerate(EMGfiles):
    # Participant ID = file name without its extension
    participant = os.path.splitext(EMGfile)[0]
    fpath = os.path.join(EMGpath, EMGfile)
    data = scipy.io.loadmat(fpath)
    keys = list(data)
    print(participant)
    # Sampling rate passed to MNE (as `sfreq`) when the Raw object is created
    srate = 2000

    # Root folder for all processed output; created on first use
    processed_path = "Path to EMG data/Processed/"
    if os.path.exists(processed_path):
        print(f"Folder already exists: {processed_path}")
    else:
        os.makedirs(processed_path)
        print(f"Folder created: {processed_path}")

    # Per-participant output folder
    subject_write_path = os.path.join(processed_path, participant)
    if not os.path.exists(subject_write_path):
        os.makedirs(subject_write_path)
        print(f"Folder created: {subject_write_path}")

    # Import the csv file with psychopy data
    responsepath = "Path to Psychopy data/Psychopy/"
    # Candidates are csv files named "<number>_<prefix>..." where number is
    # EMGfile[3:-4] and prefix is EMGfile[:3]; files with 'trial' in the name
    # are excluded.
    responsefiles = sorted(
        f for f in os.listdir(responsepath)
        if f.endswith('.csv')
        and f.startswith(EMGfile[3:-4] + '_')
        and (EMGfile[3:-4] + '_' + EMGfile[:3]) in f
        and 'trial' not in f
    )
    if not responsefiles:
        # Fail with a clear message instead of an IndexError on responsefiles[0]
        raise FileNotFoundError(f'No Psychopy csv file found for {EMGfile} in {responsepath}')
    filename = responsefiles[0]
    print(filename)

    # try each file to see if it is the correct one: the correct file has a
    # 'p_port' column with exactly 102 trial rows (p_port > 1), i.e. the
    # 45 left + 45 right + 12 invalid markers asserted later in this loop.
    # The comparison has to be counted with .sum(); len() of the boolean
    # Series is the total number of rows, whatever the values are.
    for filename in responsefiles:
        fpath = os.path.join(responsepath, filename)
        responses = pd.read_csv(fpath)
        if 'p_port' not in responses.columns:
            continue
        if (responses['p_port'] > 1).sum() == 102:
            break
    else:
        # No candidate matched: keep the last file that was read (the loop's
        # fall-through behaviour) but make the situation visible.
        print(f'Warning: no response file with 102 trials found for {EMGfile}; using {filename}')

    # Assert that the EMG and the response file match in the participant number
    assert int(filename.split('_')[0]) == index+1 == int(participant[3:]), 'EMG and Psychopy files do not match'

    # Store behavioural data
    # Rows with p_port > 1 are the ones used for every per-trial column below
    trial_rows = responses['p_port'] > 1

    # Single-valued entries: first non-missing value of the column
    age = responses['O_Age.response'].dropna().values[0]
    gender = responses['slider_gender.response'].dropna().values[0]

    # Per-trial columns restricted to the trial rows
    participant_response = responses.loc[trial_rows, 'response_participant.keys'].values
    print(len(participant_response))
    matches = responses.loc[trial_rows, 'response_participant.corr'].values
    rts = responses.loc[trial_rows, 'response_participant.rt'].values
    correct_resp = responses.loc[trial_rows, 'correct_resp'].values
    cues = responses.loc[trial_rows, 'cue'].values
    validity = responses.loc[trial_rows, 'validity'].values

    # Intensity ratings: every non-missing entry, not restricted to trial rows
    intensities = responses['intensity_rating.response'].dropna().values

    # Everything that gets pickled per participant at the end of the loop
    behavioural_data = {
        'age': age,
        'gender': gender,
        'participant_response': participant_response,
        'matches': matches,
        'rts': rts,
        'correct_resp': correct_resp,
        'intensities': intensities,
        'cues': cues,
        'validity': validity,
    }

    # Extract the data and the stimulus channels from the 'data' matrix of the .mat file
    recording = data['data']
    # Columns 0 and 1: left / right EMG, divided by 1000
    # NOTE(review): presumably a unit conversion - confirm the recording units
    left_signal = recording[:, 0] / 1000
    right_signal = recording[:, 1] / 1000
    # Columns 2, 3, 4: left, invalid and right trigger channels
    left_stim, invalid_stim, right_stim = (recording[:, column] for column in (2, 3, 4))

    # Drop the references to the loaded file contents
    del data, recording
    # Time axis in seconds, one entry per sample of the EMG signal
    timevec = np.arange(0, len(left_signal) / srate, 1 / srate)
    def get_indices(stim, max_events=45):
        """Locate trigger onsets in a stimulus channel.

        An onset is a position where the next sample is more than 2 units
        higher than the current one (``np.diff(stim) > 2``); the returned index
        is that of the sample just before the jump.

        Parameters
        ----------
        stim : array-like
            One-dimensional trigger channel.
        max_events : int, optional
            Largest number of onsets to keep (must be positive). If more are
            found, the count is printed and only the last ``max_events`` onsets
            are returned. Defaults to 45, the value that was previously
            hard-coded.

        Returns
        -------
        stim_idx : numpy.ndarray
            Onset indices, at most ``max_events`` of them.
        num_trig : int
            Number of onsets found before any trimming.
        """
        stim_idx = np.where(np.diff(stim) > 2)[0]
        num_trig = len(stim_idx)
        if num_trig > max_events:
            print(f'The original number of stim indices is {num_trig}')
            # Keep the most recent onsets only
            stim_idx = stim_idx[-max_events:]
        return stim_idx, num_trig

    # Stim indices: trigger positions per channel plus the untrimmed trigger counts
    left_idx, numleft = get_indices(left_stim)
    right_idx, numright = get_indices(right_stim)
    invalid_idx, numinvalid = get_indices(invalid_stim)
    # NOTE(review): these are the indices as detected here; the repaired right
    # markers and the trimmed invalid markers computed further down are not
    # written back into behavioural_data - confirm that this is intended.
    behavioural_data.update({
        'left_idx': left_idx,
        'right_idx': right_idx,
        'invalid_idx': invalid_idx,
    })

    Orig_eventcount = {'left': numleft, 'right': numright, 'invalid': numinvalid}
    print(f'Left markers: {len(left_idx)}, Right markers: {len(right_idx)}, Invalid markers: {len(invalid_idx)}')

    # Running log (a dict keyed by EMG file name) of the original trigger counts,
    # stored next to the raw data. The dict is loaded if the file exists,
    # updated with this file's counts and written back.
    # NOTE(review): the original author marked the assignment below with
    # "This will not work" without saying why - confirm the log round-trips.
    event_log = os.path.join(EMGpath, 'events_in_emg.npy')
    if os.path.exists(event_log):
        events_in_emg = np.load(event_log, allow_pickle=True).item()
    else:
        events_in_emg = {}
    events_in_emg[EMGfile] = Orig_eventcount
    np.save(event_log, events_in_emg)
    del events_in_emg

    # Fix the markers: with fewer than 45 right markers, rebuild them from the
    # response file and the left markers, then apply the drift correction
    if len(right_idx) < 45:
        print('Right markers are missing')
        right_idx = infer_rights(responses, srate, left_idx)
        print('Fixing right markers')
        right_idx = correct_drift(responses, srate, left_idx, right_idx)
        print(f'Corrected Right markers: {len(right_idx)}')

    # Only the last 12 invalid markers are kept
    invalid_idx = invalid_idx[-12:] if len(invalid_idx) > 12 else invalid_idx

    # The rest of the pipeline expects exactly 45 / 45 / 12 markers
    assert len(right_idx) == 45, 'Right markers are still missing'
    assert len(left_idx) == 45, 'Left markers are still missing'
    assert len(invalid_idx) == 12, 'Invalid markers are still missing'

    # how many lefts and rights: trial counts per cue side and validity
    left_cue, right_cue = cues == 'left', cues == 'right'
    valid_trial, invalid_trial = validity == 'valid', validity == 'invalid'
    print('Number of left cues: ', len(cues[left_cue & valid_trial]))
    print('Number of right cues: ', len(cues[right_cue & valid_trial]))
    print('Number of invalid left cues: ', len(cues[left_cue & invalid_trial]))
    print('Number of invalid right cues: ', len(cues[right_cue & invalid_trial]))
    print('First cue: ', cues[0], '; First validity: ', validity[0])

    assert len(left_signal) == len(right_signal), "Left and right signals dont have the same length."

    # Stack both channels into a (2, n_samples) array and wrap it in an MNE Raw object
    emg_data = np.vstack((left_signal, right_signal))
    print(f'EMG data shape: {emg_data.shape}')
    ch_names = ['EMG_left', 'EMG_right']  # Channel names
    ch_types = ['emg'] * len(ch_names)  # Channel types
    info = create_info(ch_names=ch_names, sfreq=srate, ch_types=ch_types)
    # Resample to 1000 Hz and keep the updated measurement info for later use
    raw_EMG = RawArray(emg_data, info).resample(1000)
    info = raw_EMG.info

    # Filter the EMG data: 3-50 Hz band-pass (firwin FIR design), then notch
    # filters at 50 and 100 Hz with a notch width of 3
    raw_EMG.filter(
        l_freq=3,
        h_freq=50,
        picks=ch_names,
        fir_design='firwin',
    )
    raw_EMG.notch_filter(
        freqs=(50, 100),
        filter_length='auto',
        picks=ch_names,
        notch_widths=3,
    )

    ############################## Clean data with template matching ##############################
    # First pass: template stored for BBO28, averaged over its first axis and
    # scaled to unit norm before it is matched against the filtered EMG
    raw_data = raw_EMG.get_data(picks='emg')
    template_28 = np.load('Path to EMG data/Processed/BBO28/BBO28_EMG_template.npy')
    norm_template_28 = np.mean(template_28, axis=0)
    norm_template_28 /= np.linalg.norm(norm_template_28)

    All_peaks = find_correlation_peaks(raw_data, norm_template_28)
    All_correlations = extract_correlations(raw_data, norm_template_28, All_peaks)
    Corrected_data, All_errors = correct_data_with_template(
        raw_data=raw_data,
        norm_template_28=norm_template_28,
        All_peaks=All_peaks,
        corrthresh=0.8,
    )

    # Second pass: same procedure with the second template, applied to the
    # output of the first pass (All_peaks / All_correlations are overwritten)
    template_2 = np.load('Path to EMG data/Processed/BBO13/template_2.npy')
    template_2 = np.mean(template_2, axis=0)
    template_2 /= np.linalg.norm(template_2)
    All_peaks = find_correlation_peaks(Corrected_data, template_2)
    All_correlations = extract_correlations(Corrected_data, template_2, All_peaks)
    Corrected_data_2, All_errors_2, subtracted_segments = correct_data_with_template_2(
        Input_data=Corrected_data,
        template_2=template_2,
        All_peaks=All_peaks,
        corrthresh=0.8,
    )

    # Wrap the cleaned signals in a new Raw object that reuses the resampled info
    Corrected_EMG = RawArray(Corrected_data_2, info)
    # del Corrected_data, All_errors, All_peaks, All_correlations

    # Define Events
    # MNE events array: one row per trigger, columns = [sample, 0, event code]
    # with codes 1 = left, 2 = right, 3 = invalid
    onsets = np.concatenate([left_idx, right_idx, invalid_idx])
    codes = np.repeat([1, 2, 3], [len(left_idx), len(right_idx), len(invalid_idx)])
    events = np.zeros((len(onsets), 3))
    events[:, 0] = onsets
    events[:, 2] = codes
    # The trigger indices refer to the recording at `srate`, while the EMG has
    # been resampled. Rescale with the actual ratio of the two sampling rates
    # (0.5 for 2000 Hz -> 1000 Hz) instead of a hard-coded division by 2, so the
    # events stay aligned if either rate is changed.
    events[:, 0] = (events[:, 0] * (Corrected_EMG.info['sfreq'] / srate)).round() # downsample the events
    events = events.astype(int)
    # Order the events by sample position
    events = events[events[:, 0].argsort()]
    event_id = {'Left': 1, 'Invalid':3, 'Right':2}

    # Epoch window: from 4 s before each event up to the event itself
    tmin, tmax = -4, 0
    epochs_EMG = mne.Epochs(
        Corrected_EMG,
        events=events,
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        baseline=None,
        preload=True,
    )
    # Baseline-correct using the whole epoch as the baseline interval
    epochs_EMG.apply_baseline(baseline=(None, None))

    # Time-frequency: Morlet wavelets at 3-50 Hz in 1 Hz steps, with the number
    # of cycles set to half the frequency; single-trial power is kept
    frequencies = np.arange(3, 51)  # Define frequencies of interest
    n_cycles = frequencies / 2
    EMG_tfr = epochs_EMG.compute_tfr(
        method='morlet',
        freqs=frequencies,
        n_cycles=n_cycles,
        picks='emg',
        output='power',
        average=False,
        return_itc=False,
    )
                                    
    # Output files for this participant, all inside the participant's folder
    file_path = os.path.join(subject_write_path, f'{participant}_behavioural_data.pkl')
    Processed_epochs_filename = os.path.join(subject_write_path, f'{participant}_EMG-epo.fif')
    Processed_tfr_filename = os.path.join(subject_write_path, f'{participant}_EMG-tfr.h5')

    # Save the extracted behavioural data
    with open(file_path, 'wb') as pickle_file:
        pickle.dump(behavioural_data, pickle_file)

    # Save the processed EMG data (an existing file is overwritten)
    epochs_EMG.save(Processed_epochs_filename, overwrite=True)
    print(f'File saved: epochs_EMG')

    # Save the processed EMG tfr data (an existing file is overwritten)
    EMG_tfr.save(Processed_tfr_filename, overwrite=True)
    print(f'File saved: EMG_tfr')

In [None]:
# To load the pickle file
# with open(file_path, 'rb') as pickle_file:
#     loaded_data = pickle.load(pickle_file)

In [91]:
# Code for template extraction: find the artefact peaks in one trial of sample_data
trial_idx = 53
this_trial = sample_data[trial_idx, 0, :]
# Detection threshold: three standard deviations of the trial
thresh = 3 * this_trial.std()
# Samples taken before / after each peak when a snippet is cut out
half_len = 264 // 2
second_half_len = half_len + 100
# Find the peaks (find_peaks returns a tuple: peak indices first, then properties)
peaks = scipy.signal.find_peaks(this_trial, height=thresh)
peak_positions = peaks[0]
# Show the trial with the detected peaks and the threshold line
plt.plot(this_trial)
plt.plot(peak_positions, this_trial[peak_positions], 'ro')
plt.axhline(thresh, color='r')
plt.show()
# Spacing between consecutive peaks, in samples
print(np.diff(peak_positions))


[264 802 263 533 266 534 266]


In [None]:
# Collect one snippet per peak that lies far enough from both ends of the trial.
# template2 is only created when it does not exist yet, so running this cell
# again (e.g. for another trial) keeps adding snippets to the same list.
if 'template2' not in globals():
    template2 = []
for peak in peaks[0]:
    # Skip peaks whose window would not fit inside the trial
    if not (half_len < peak < len(this_trial) - second_half_len):
        continue
    template2.append(this_trial[peak - half_len:peak + second_half_len])
print(len(template2))

In [36]:

# Stand-alone version of the second template-subtraction pass, operating on the
# notebook variable Corrected_data (output of the first pass).
# NOTE(review): this looks like the development version of
# correct_data_with_template_2 from BBO_Analysis_Functions - confirm. Note that
# it uses corrthresh = 0.7, while the pipeline cell calls the function with 0.8.

# Load the second template, average over its first axis and scale to unit norm
template_2 = np.load('Path to EMG data/Processed/BBO13/template_2.npy')
template_2 = template_2.mean(axis = 0)
template_2 /= np.linalg.norm(template_2)
Input_data = Corrected_data
# Candidate artefact positions per channel
All_peaks = find_correlation_peaks(Corrected_data,template_2)
corrthresh = 0.7  # minimum correlation with the template for a segment to be corrected
smooth = False  # set to True to Gaussian-smooth the template first
template_length = len(template_2)
# Work on a copy so Input_data / Corrected_data stay untouched
Corrected_data_2 = np.copy(Input_data)
if smooth:
    from scipy.ndimage import gaussian_filter1d
    template_2 = gaussian_filter1d(template_2, sigma=6)
# Taper the edges of the template: the first and last 100 samples are
# multiplied by the rising / falling half of a 200-point Hann window
negative_ramp = [0, 100]
positive_ramp = [template_length - 100, template_length]
ramp_up = np.hanning((negative_ramp[1] - negative_ramp[0]) * 2)[:(negative_ramp[1] - negative_ramp[0])]
ramp_down = np.hanning((positive_ramp[1] - positive_ramp[0]) * 2)[(positive_ramp[1] - positive_ramp[0]):]
tapered_template = np.copy(template_2)
tapered_template[:negative_ramp[1]] *= ramp_up
tapered_template[positive_ramp[0]:] *= ramp_down
# NOTE: this list is never used; subtracted_segments is re-created as a dict below
subtracted_segments = []
# Define the windows for scaling (sample ranges within the template)
# NOTE(review): presumably the ranges around the negative and the positive
# deflection of the template - confirm against a plot of template_2
negative_win = [187, 217] 
positive_win = [120, 150]

# Initialize error dictionary: one entry per peak and channel; entries of peaks
# that are skipped or do not pass the correlation threshold stay at zero
All_errors = {'left': np.zeros(len(All_peaks[0])), 'right': np.zeros(len(All_peaks[1]))}
# Residual segments (data minus scaled template), one row per peak
subtracted_segments = {
'left': np.zeros((len(All_peaks[0]), template_length)), 
'right': np.zeros((len(All_peaks[1]), template_length))}
# The scaled template that was subtracted at each peak, one row per peak
Scaled_templates = {
    'left': np.zeros((len(All_peaks[0]), template_length)),
    'right': np.zeros((len(All_peaks[1]), template_length))
}
# Process each channel and peak
for channel in range(Corrected_data_2.shape[0]):
    for idx, peak in enumerate(All_peaks[channel]):
        # Segment centred on the peak; it is read from Corrected_data_2, so
        # corrections made at earlier peaks are already included
        tempsegment = Corrected_data_2[channel, peak - (template_length // 2):peak + (template_length // 2)]
        # Skip segments that do not have the template's length (e.g. cut off at
        # the end of the recording).
        # NOTE(review): the slice spans 2 * (template_length // 2) samples, so
        # with an odd template_length every segment would be skipped - confirm
        # that the template length is even.
        if len(tempsegment) != template_length:
            continue

        # Check if it is a match (correlation with the untapered template)
        correlation = np.corrcoef(tempsegment, template_2)[0, 1]
        if correlation > corrthresh:
            # Fit the tapered template to the data by linear regression over the
            # samples from the start of positive_win to the end of negative_win
            scaling_window = [positive_win[0], negative_win[1]]
            slope, intercept, _, _, _ = linregress(
                tapered_template[scaling_window[0]:scaling_window[1]],
                tempsegment[scaling_window[0]:scaling_window[1]]
            )
            scaled_template = tapered_template * slope + intercept

            # Further scale each peak individually: the positive window is
            # rescaled by (data maximum in the window) / (maximum of the whole
            # scaled template), then the negative window by the corresponding
            # ratio of minima. The second step reads scaled_template after the
            # first in-place change, so the order of the two steps matters.
            scaled_template[positive_win[0]:positive_win[1]] *= (
                tempsegment[positive_win[0]:positive_win[1]].max() / scaled_template.max()
            )
            scaled_template[negative_win[0]:negative_win[1]] *= (
                tempsegment[negative_win[0]:negative_win[1]].min() / scaled_template.min()
            )

            # Subtract the scaled template
            subtracted_seg = tempsegment - scaled_template

            # Calculate squared error (sum of squared residuals after subtraction)
            squared_residuals = (subtracted_seg) ** 2
            error = np.sum(squared_residuals)
            # Channel 0 is stored under 'left', channel 1 under 'right'
            if channel == 0:
                All_errors['left'][idx] = error
                subtracted_segments['left'][idx] = subtracted_seg
                Scaled_templates['left'][idx] = scaled_template
            elif channel == 1:
                All_errors['right'][idx] = error
                subtracted_segments['right'][idx] = subtracted_seg
                Scaled_templates['right'][idx] = scaled_template

            # Update the corrected data: the segment is replaced by its residual
            Corrected_data_2[channel, peak - (template_length // 2):peak + (template_length // 2)] = subtracted_seg


