In [1]:
import mne
from itertools import compress
import os.path as op
import numpy as np
import pandas as pd
from mne.io import read_raw_nirx as nirx
from copy import deepcopy
import os
from matplotlib import pyplot as plt

In [3]:
#### GLOBAL VARIABLE: LIST OF SHORT CHANNELS

# Short channels, written as (source, detector) optode pairs. Each pair is
# expanded into two channel names, 'S<source>_D<detector> hbo' followed by
# 'S<source>_D<detector> hbr', in the order the pairs are listed.
# These are broadly uninteresting to hyperscanning analysis.
ALL_SHORT_CHANNELS = [
    f'S{source}_D{detector} {chromophore}'
    for source, detector in (
        (3, 23), (4, 24), (5, 31), (8, 32),
        (10, 25), (11, 26), (12, 27), (13, 28),
        (15, 29), (16, 30), (18, 33), (19, 34),
        (20, 35), (21, 36), (23, 37), (24, 38),
    )
    for chromophore in ('hbo', 'hbr')
]

In [4]:
####### FUNCTIONS #########


# Given list of channels (Such as irrelevant short channels), drop them.
def filter_relevant_channels(fnirs_data):
    """Remove every channel listed in ALL_SHORT_CHANNELS from ``fnirs_data``.

    The removal is delegated to ``fnirs_data.drop_channels``. Nothing is
    returned; callers keep using the object they passed in.
    """
    channels_to_remove = ALL_SHORT_CHANNELS
    fnirs_data.drop_channels(channels_to_remove)

# Changes to THIS function will allow us to subdivide conversations if we wish.
def fix_convo_annotations(raw_intensity):
    """Relabel the experiment triggers and give them fixed durations.

    Steps, all applied to ``raw_intensity.annotations``:

    1. Rename trigger '9.0' to 'rest', '11.0' to 'Convo2' and '10.0' to 'Convo1'.
    2. Delete every annotation whose description is still '<i>.0' for
       i in 0..29 (i.e. all triggers that were not renamed above).
    3. Overwrite the durations with [120, 300, 60, 300, 120].

    NOTE(review): step 3 assumes exactly five annotations remain, in the order
    rest / conversation / rest / conversation / rest -- confirm for new data.

    Returns the same ``raw_intensity`` object that was passed in.
    """
    annotations = raw_intensity.annotations
    trigger_labels = {'9.0': 'rest',
                      '11.0' : 'Convo2',
                      '10.0': 'Convo1'}
    annotations.rename(trigger_labels)

    # Remove all unchosen triggers, one numeric code at a time.
    for trigger_code in range(30):
        is_unchosen = annotations.description == f'{trigger_code}.0'
        annotations.delete(is_unchosen)

    annotations.duration = np.array([120, 300, 60, 300, 120])
    return raw_intensity


# Quite simply, load the NIRSport2 data given a path, and apply fix_convo_annotations.    
def load_nirx_from_input_path(input_path):
    """Read a NIRx recording from ``input_path`` and normalise its annotations.

    The recording is opened with ``nirx`` (``mne.io.read_raw_nirx``) at
    'CRITICAL' verbosity, loaded via ``load_data()``, and then passed through
    ``fix_convo_annotations``. The result of that call is returned.
    """
    raw_intensity = nirx(input_path, verbose = 'CRITICAL').load_data()
    return fix_convo_annotations(raw_intensity)

# Converts first to optical density, computes the SCI, and interpolates if necessary.
def convert_to_haemoglobin_and_interpolate(loaded_data, sci_threshold = 0.6):
    """Convert raw intensity data to haemoglobin, interpolating bad channels.

    Parameters
    ----------
    loaded_data
        Raw intensity recording, as returned by ``load_nirx_from_input_path``.
    sci_threshold
        Channels whose scalp coupling index is below this value are marked
        bad and interpolated (method 'nearest') before the conversion.

    Returns
    -------
    The result of ``mne.preprocessing.nirs.beer_lambert_law`` (ppf=6) applied
    to the interpolated optical-density data.
    """
    nirs = mne.preprocessing.nirs
    optical_density = nirs.optical_density(loaded_data)
    coupling_index = nirs.scalp_coupling_index(optical_density)

    # Pair each channel name with its SCI value and keep the poorly coupled ones.
    poorly_coupled = [
        channel
        for channel, score in zip(optical_density.ch_names, coupling_index)
        if score < sci_threshold
    ]
    optical_density.info['bads'] = poorly_coupled
    optical_density.interpolate_bads(reset_bads = True, method = dict(fnirs = 'nearest'))

    return nirs.beer_lambert_law(optical_density, ppf=6)

# Helper function, because MNE's built in pick_channels() function edits the original object in place!
def pick_channels_deepcopy(raw_haemoglobin, channel_name):
    """Return a deep copy of ``raw_haemoglobin`` reduced to ``channel_name``.

    The copy is made first so that ``pick_channels`` is applied to the copy
    and the object passed in is left untouched.
    """
    independent_copy = deepcopy(raw_haemoglobin)
    return independent_copy.pick_channels([channel_name], verbose = False)

# Because we didn't use the hyperscan application, we need to use triggers to make time series "start" at the same time.
def truncate_time_series(single_channel_data, rounding, verbose = False, samples_after_last_trigger = 611):
    """Cut a single-channel recording down to the experiment window.

    The window starts at the sample whose (rounded) time equals the first
    trigger onset and ends ``samples_after_last_trigger`` samples after the
    sample matching the last trigger onset. Times and trigger onsets are then
    shifted so that the first trigger sits at t = 0.

    Parameters
    ----------
    single_channel_data
        Object exposing ``times``, ``_data`` (row 0 is used) and
        ``annotations.onset``.
    rounding
        Number of decimals used when matching trigger onsets to sample times.
    verbose
        If True, print array shapes before and after truncation.
    samples_after_last_trigger
        Number of samples kept after the last trigger. Defaults to 611, the
        value previously hard-coded here.
        NOTE(review): presumably the length of the final rest period at the
        device's sampling rate -- confirm before using other recordings.
        If the recording is shorter than this, the slice simply ends at the
        last available sample.

    Returns
    -------
    (output_timeseries, output_data, triggers)
        Re-zeroed sample times, the matching measurements, and the re-zeroed
        trigger onsets.

    Raises
    ------
    AssertionError
        If fewer than two triggers exist, if a trigger onset does not coincide
        with a sample time at the given rounding, or if a sanity check on the
        re-zeroed vectors fails.
    """
    # Useful for debugging
    if verbose:
        print(single_channel_data.times.shape)
        print(single_channel_data._data[0].shape)

    # Get triggers.
    triggers = np.asarray(single_channel_data.annotations.onset)
    assert len(triggers) >= 2, "Need at least two triggers to delimit the experiment."

    # Round the time vector once; it is reused for both trigger look-ups.
    rounded_times = np.round(single_channel_data.times, rounding)

    # Sanity check: is the earliest onset actually in the times vector?
    earliest_onset = np.round(triggers[0], rounding)
    start_matches = np.flatnonzero(rounded_times == earliest_onset)
    assert start_matches.size > 0, "First trigger onset not found in the times vector."
    # Sample at which the experiment starts.
    experiment_start_idx = start_matches[0]

    # Figure out when the last trigger (should be a rest trigger) occurred.
    last_onset = np.round(triggers[-1], rounding)
    end_matches = np.flatnonzero(rounded_times == last_onset)
    assert end_matches.size > 0, "Last trigger onset not found in the times vector."

    # Experiment ends APPROXIMATELY at the following index.
    experiment_end_idx = end_matches[0] + samples_after_last_trigger

    # Truncate measurement times and measurements to the experiment window.
    output_timeseries = single_channel_data.times[experiment_start_idx:experiment_end_idx]
    output_data = single_channel_data._data[0][experiment_start_idx:experiment_end_idx]
    assert len(output_timeseries) > 1, "Truncated time series has fewer than two samples."

    # Re-configure times vector and trigger onsets to the new zero.
    output_timeseries = output_timeseries - output_timeseries[0]
    triggers = triggers - triggers[0]

    # A bunch more sanity checks. The non-negativity check is vectorised and
    # genuinely fails on any negative time (the previous set-based form also
    # passed when *all* values were negative).
    assert np.all(output_timeseries >= 0), "Negative time after re-zeroing."
    assert output_timeseries[1] > 0
    assert output_timeseries[0] == 0
    assert triggers[1] > 0
    assert triggers[0] == 0
    if verbose:
        print(output_timeseries.shape)
        print(output_data.shape)
    return output_timeseries, output_data, triggers

# Saves a time series (That is, a vector of measurements and measurement times) plus associated triggers.
def save_time_series(output_timeseries, output_data, triggers, output_path, filename):
    """Write one channel's time series, and the triggers file if it is missing.

    Parameters
    ----------
    output_timeseries
        Sample times; written as the first column.
    output_data
        Measurements; written as the second column.
    triggers
        Trigger onsets; written to ``_triggers.tsv`` in ``output_path``.
    output_path
        Existing directory that receives the files.
    filename
        Base name (without extension) of the time-series file.

    Files are tab-separated with no header and no index column.

    The triggers file is shared by everything saved into ``output_path``: it
    is written only when ``_triggers.tsv`` does not exist yet, so the first
    call wins and a file left over from an earlier run is kept. (The guard
    used to look for ``triggers.tsv`` -- a name that is never written -- so it
    never fired and the file was rewritten on every call.)
    """
    data_path = os.path.join(output_path, filename + ".tsv")
    triggers_path = os.path.join(output_path, "_triggers.tsv")

    # Turn times and measurements into a two-column table and save it.
    data_dict = {"c1": output_timeseries, "c2": list(output_data)}
    pd.DataFrame(data = data_dict).to_csv(data_path, sep="\t", index=False, header = False)

    # Save triggers (identical across all channels, need only be saved once).
    if not os.path.exists(triggers_path):
        triggers_dict = {"triggers": triggers}
        pd.DataFrame(data = triggers_dict).to_csv(triggers_path, sep="\t", index=False, header = False)

# Function that bundles the whole thing together.
def nirx_to_timeseries(input_path, output_path, participant_nr, verbose = False):
    """Run the full pipeline for one recording and save one file per channel.

    The recording at ``input_path`` is loaded, converted to haemoglobin (with
    bad-channel interpolation), stripped of short channels, and then each
    remaining channel is truncated to the experiment window and written to
    ``output_path`` as '<participant_nr>_<channel name>.tsv'.

    If ``input_path`` does not exist, a warning is printed and the function
    returns None without writing anything.
    """
    try:
        intensity = load_nirx_from_input_path(input_path)
    except FileNotFoundError:
        print("WARNING: " + input_path + " not found.")
        return

    raw_haemoglobin = convert_to_haemoglobin_and_interpolate(intensity)
    filter_relevant_channels(raw_haemoglobin)

    for channel_name in raw_haemoglobin.ch_names:
        # Work on an isolated single-channel copy so the full recording survives.
        channel_copy = pick_channels_deepcopy(raw_haemoglobin, channel_name)
        times, values, onsets = truncate_time_series(channel_copy, rounding = 10, verbose = verbose)
        save_time_series(times, values, onsets, output_path, f"{participant_nr}_{channel_name}")

# Helper function which creates new directories.
def make_folder(newpath):
    """Create directory ``newpath`` (including parents) if it does not exist.

    Uses ``exist_ok=True`` rather than a separate existence check, so there is
    no window between checking and creating. Calling it on an existing
    directory is a no-op; if ``newpath`` exists but is not a directory,
    ``os.makedirs`` raises ``FileExistsError``.
    """
    os.makedirs(newpath, exist_ok=True)


# FUNCTION FOR GETTING AN OVERVIEW OF BAD CHANNELS

#https://mne.discourse.group/t/interpolation-of-bad-channels-in-fnirs-data/4100/5

#https://mne.tools/stable/auto_tutorials/preprocessing/15_handling_BadChannels.html

def check_bads(input_path, plot_name, sci_threshold = 0.4):
    """List poorly coupled long channels and save an SCI histogram.

    Parameters
    ----------
    input_path
        Path of the NIRx recording to inspect.
    plot_name
        Prefix of the histogram file, saved as
        'Plots/BadChannelsPlots/<plot_name>_bad_channel_plot.png'.
    sci_threshold
        Scalp coupling index cutoff; channels strictly below it are reported.
        Defaults to 0.4, the value previously hard-coded here.

    Returns
    -------
    list of str, or None
        Haemoglobin channel names with SCI below the cutoff, excluding those
        in ALL_SHORT_CHANNELS. None if ``input_path`` was not found (a warning
        is printed in that case).
    """
    plot_dir = "Plots/BadChannelsPlots"
    make_folder(plot_dir)
    try:
        loaded_data = load_nirx_from_input_path(input_path)
    except FileNotFoundError:
        print("WARNING: " + input_path + " not found.")
        return

    raw_optical_density = mne.preprocessing.nirs.optical_density(loaded_data)
    sci = mne.preprocessing.nirs.scalp_coupling_index(raw_optical_density)
    bad_indices = [i for i, value in enumerate(sci) if value < sci_threshold]

    # The conversion is done only to obtain the 'hbo'/'hbr' channel names,
    # which is the naming used by ALL_SHORT_CHANNELS.
    raw_haemoglobin = mne.preprocessing.nirs.beer_lambert_law(raw_optical_density, ppf=6)
    bad_channels = [raw_haemoglobin.ch_names[i] for i in bad_indices]
    bad_channels = [name for name in bad_channels if name not in ALL_SHORT_CHANNELS]

    fig, ax = plt.subplots(layout="constrained")
    ax.hist(sci)
    ax.set(xlabel="Scalp Coupling Index", ylabel="Count", xlim=[0, 1])

    # Save and close through the explicit figure handle so repeated calls in a
    # loop neither depend on pyplot's "current figure" nor accumulate figures.
    fig.savefig(plot_dir + f"/{plot_name}_bad_channel_plot.png")
    plt.close(fig)

    return bad_channels


In [None]:

# Check bad channels

# For every pair folder, compute the bad-channel list of each participant at
# Visit 1 and Visit 4, save an SCI histogram per recording, and write each
# list to ../Data/AllParticipants/BadChannelsLists/.
# NOTE(review): pair numbers come from the position in os.listdir(), whose
# order is not guaranteed by Python -- confirm it matches the folder naming.
pair_input_paths = ["../Data/AllParticipants" + "/" + folder for folder in os.listdir("../Data/AllParticipants")]
for pair_number, pair_path in enumerate(pair_input_paths):

    print(pair_number)
    if pair_number == 15: # Pair 15 isn't complete.
        continue

    ## Visit 1
    data1_visit1 = check_bads(pair_path + "/Visit1/Participant1/convo", plot_name = f"{pair_number}_p1_v1")
    data2_visit1 = check_bads(pair_path + "/Visit1/Participant2/convo", plot_name = f"{pair_number}_p2_v1")
    ## Visit 4
    data1_visit4 = check_bads(pair_path + "/Visit4/Participant1/convo", plot_name = f"{pair_number}_p1_v4")
    # Participant 2 gets its own plot name; it used to reuse "_p1_v4" and
    # overwrite Participant 1's Visit 4 histogram.
    data2_visit4 = check_bads(pair_path + "/Visit4/Participant2/convo", plot_name = f"{pair_number}_p2_v4")

    # One headerless, index-free TSV per visit/participant combination.
    for suffix, bad_channels in (("v1_p1", data1_visit1),
                                 ("v1_p2", data2_visit1),
                                 ("v4_p1", data1_visit4),
                                 ("v4_p2", data2_visit4)):
        pd.DataFrame(data = bad_channels).to_csv(f"../Data/AllParticipants/BadChannelsLists/{pair_number}_BadChannels_{suffix}.tsv", sep="\t", index=None, header = False)



In [None]:
make_folder("HaemoglobinTimeSeries")

# Now that bad channels have been checked and plotted, load all the NIRSport2 data and save as csv-based timeline files.

# The directory is listed with forward slashes, like the path prefix next to
# it: the previous "..\Data\AllParticipants" relied on the invalid escape
# sequences "\D" and "\A" and only worked with Windows path separators.
# NOTE(review): pair numbers come from the position in os.listdir(), whose
# order is not guaranteed by Python -- confirm it matches the folder naming.
pair_input_paths = ["../Data/AllParticipants" + "/" + folder for folder in os.listdir("../Data/AllParticipants")]
for pair_number, pair_path in enumerate(pair_input_paths[1:23]):
    pair_id = pair_number + 1
    print(pair_id)
    if pair_id == 15: # Pair 15 isn't complete.
        continue

    # Create output directories
    output_path_visit1 = "HaemoglobinTimeSeries" + f"/{pair_id}" + "/Visit1"
    make_folder(output_path_visit1)
    output_path_visit4 = "HaemoglobinTimeSeries" + f"/{pair_id}" + "/Visit4"
    make_folder(output_path_visit4)

    # Load data and save to appropriate output directory

    ## Visit 1
    nirx_to_timeseries(input_path = pair_path + "/Visit1/Participant1/convo", output_path = output_path_visit1, participant_nr = 1)
    nirx_to_timeseries(input_path = pair_path + "/Visit1/Participant2/convo", output_path = output_path_visit1, participant_nr = 2)
    ## Visit 4
    nirx_to_timeseries(input_path = pair_path + "/Visit4/Participant1/convo", output_path = output_path_visit4, participant_nr = 1)
    nirx_to_timeseries(input_path = pair_path + "/Visit4/Participant2/convo", output_path = output_path_visit4, participant_nr = 2)

