# Check the NIRS data stream 



# Initializations and imports

In [None]:
import os

import numpy as np
import pprint
import json

# If we are in the notebooks folder, go back to the root folder so the relative
# paths used below (conditions file, recordings) resolve.
if os.path.basename(os.getcwd()) == 'notebooks':
    os.chdir('..')
    print('Changed working directory to: ', os.getcwd())

# Later cells build absolute paths from `cwd`, so it must be defined whether or
# not the working directory was changed above.
cwd = os.getcwd()

# Conditions table; `names=True` takes the column names from the header row.
conditionsFile = 'data/NeuArm_conditions.csv'
conditions = np.genfromtxt(conditionsFile, delimiter=',', names=True, dtype=None, encoding=None)
print(conditions.dtype.names)


# Load a data file

In [None]:
# Select one recording (row `iLine` of the conditions table), load it and print
# a summary of each stream it contains.
import datetime

import pyxdf

iLine = 11
relativePath = conditions[iLine]['relativePath']

# load the data
fullpath = os.path.join(cwd, relativePath)
data, header = pyxdf.load_xdf(fullpath, synchronize_clocks=True, dejitter_timestamps=False, verbose=False)

# print name, type, duration and sample count of every stream
for stream in data:
    name = stream['info']['name'][0]
    streamType = stream['info']['type'][0]  # not `type`: avoid shadowing the builtin
    timeStamps = stream['time_stamps']
    nSamples = timeStamps.shape[0]
    # a stream without samples has no first/last time stamp: report a zero duration
    seconds = round(timeStamps[-1] - timeStamps[0]) if nSamples > 0 else 0
    duration = datetime.timedelta(seconds=seconds)

    print(name + " (" + streamType + ')\n  ' + str(duration) + ', ' + str(nSamples) + ' samples ')

# print the header
print(header)

# Check the NIRS data 

In [None]:
# get the NIRS data: streams of type 'NIRS' named 'Oxysoft'
nirsStreams = [
    s for s in data
    if s['info']['type'][0] == 'NIRS' and s['info']['name'][0] == 'Oxysoft'
]

# Check that we have exactly one NIRS stream. Raise rather than call exit():
# inside a notebook exit() terminates the kernel instead of just stopping this cell.
if len(nirsStreams) != 1:
    raise RuntimeError('Found ' + str(len(nirsStreams)) + ' NIRS streams. Aborting.')

# get the first (and only) NIRS stream
nirsStream = nirsStreams[0]

# parse the content of the stream
nirsData = np.array(nirsStream['time_series'])
nirsTime = np.array(nirsStream['time_stamps'])
nirsTime = nirsTime - nirsTime[0]  # time relative to the first sample
nirsDesc = nirsStream['info']['desc'][0]
nirsChannelCount = int(nirsStream['info']['channel_count'][0])
nirsSamplingRate = float(nirsStream['info']['nominal_srate'][0])

# One entry per channel declared in the stream description. Values are kept
# exactly as pyxdf returns them (they are indexed with [0] when printed).
nirsLabels = [
    {'label': chan['label'], 'unit': chan['unit'], 'type': chan['type']}
    for chan in nirsDesc['channels'][0]['channel']
]

print('NIRS data shape: ', nirsData.shape)
print('NIRS sampling rate: ', nirsSamplingRate)
# print the labels one by one
for label in nirsLabels:
    print(label['label'][0])


## New order of channels 
From Gerard, the order of channels is: 
>     nirs_data_Xdf_OD.d = streams{1,iData}.time_series([2 4 6 8 26 28 30 32 1 3 5 7 25 27 29 31],:)'; % On remplace les data
MATLAB indices are 1-based: subtract 1 from each of them to get the 0-based Python indices.


In [None]:
# MNE-style names of the 16 channels kept for analysis, in their target order.
ch_names = ['S1_D1 hbo', 'S1_D1 hbr', 'S2_D1 hbo', 'S2_D1 hbr',
            'S3_D1 hbo', 'S3_D1 hbr', 'S4_D1 hbo', 'S4_D1 hbr',
            'S5_D2 hbo', 'S5_D2 hbr', 'S6_D2 hbo', 'S6_D2 hbr',
            'S7_D2 hbo', 'S7_D2 hbr', 'S8_D2 hbo', 'S8_D2 hbr']

# Homer 2 order: Gerard's MATLAB indices shifted by -1 to 0-based Python indices.
homer2ChannelOrder = [1, 3, 5, 7, 25, 27, 29, 31, 0, 2, 4, 6, 24, 26, 28, 30]

# snirf order: one nirsData column per entry of ch_names above, i.e.
# S1_D1 hbo, S1_D1 hbr, S2_D1 hbo, S2_D1 hbr, ..., S8_D2 hbo, S8_D2 hbr
newChannelOrder = [1, 0, 3, 2, 5, 4, 7, 6, 25, 24, 27, 26, 29, 28, 31, 30]
print(newChannelOrder)

# reorder the labels according to the new channel order
labels = [nirsLabels[i] for i in newChannelOrder]

# Print each original label next to the MNE name it is mapped to. zip pairs them
# by position; list.index() would return the first *equal* dict and mismatch
# duplicated labels (and costs O(n) per lookup).
for label, chName in zip(labels, ch_names):
    print(label['label'][0], chName)

# https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7316018/
# hbo <=> 760 nm
# hbr <=> 830 nm


# Convert the NIRS data to a mne object 
The goal is to convert the NIRS data to an MNE `Raw` object (the kind of object MNE creates when it reads a SNIRF file).  
This will allow us to use the MNE functions to analyze the data.  

We follow the steps in the tutorial: https://mne.tools/mne-nirs/stable/auto_examples/general/plot_01_data_io.html#sphx-glr-auto-examples-general-plot-01-data-io-py in the section "Custom data Import" 

## Set channel sampling rate, names and types according to the MNE conventions

> In MNE-Python the naming of channels MUST follow the structure S#_D# type where # is replaced by the appropriate source and detector numbers and type is either hbo, hbr or the wavelength.

In [None]:

# Build MNE-style names "S<source>_D<detector> <wavelength>", together with the
# channel types and wavelengths, for every channel whose type is 'NIRS'.
ch_names, ch_types, wavelengths = [], [], []
for chanInfo in nirsLabels:
    if chanInfo['type'][0] != 'NIRS':
        continue  # non-NIRS channels are skipped
    tokens = chanInfo['label'][0].split(' ')
    source = tokens[1][-1]           # last character of the 2nd token
    detector = tokens[3][1:]         # 4th token without its first character
    wavelength = tokens[-1][1:-3]    # last token minus its first char and last three chars
    ch_names.append(f'S{source}_D{detector} {wavelength}')
    ch_types.append('fnirs_od')      # every NIRS channel is typed as optical density
    wavelengths.append(wavelength)
print(ch_names)
print(ch_types)
print(wavelengths)


## Set the measurement info 
Previous exploration showed that the last two columns are battery and LSL1, and that there is data in the first 8 and last 8 channels of the 32 NIRS channels.  
Consequently, we set the measurement info to include only these channels.

In [None]:
# Drop the two trailing non-NIRS columns (battery and LSL), keep the first 8 and
# last 8 of the remaining columns, and transpose to (channels, samples) for MNE.
keptColumns = list(range(8)) + list(range(-8, 0))
data = nirsData[:, :-2][:, keptColumns].T

# keep the matching names and types
ch_names = ch_names[:8] + ch_names[-8:]
ch_types = ch_types[:8] + ch_types[-8:]
print(data.shape)
for number, (chName, chType) in enumerate(zip(ch_names, ch_types), start=1):
    print(number, chName + ' ' + chType)


## Create mne Raw object

## Re-order the channels according to the MNE conventions
 > Channels must be ordered as source detector pairs with alternating frequencies 757 & 852 nm. 



In [None]:
# Sources 1-4 are paired with detector 1 and sources 5-8 with detector 2; every
# source/detector pair contributes an hbo channel followed by an hbr channel.
ch_names = [
    f'S{src}_D{1 if src <= 4 else 2} {kind}'
    for src in range(1, 9)
    for kind in ('hbo', 'hbr')
]
ch_types = [kind for _ in range(8) for kind in ('hbo', 'hbr')]

# pick the nirsData columns in that order and transpose to (channels, samples) for MNE
data = nirsData[:, newChannelOrder].T





In [None]:
import mne

# Build the measurement info from the MNE-style channel names/types and the
# stream's nominal sampling rate, then wrap `data` (MNE expects it as
# n_channels x n_times) in a RawArray.
info = mne.create_info(ch_names=ch_names, sfreq=nirsSamplingRate, ch_types=ch_types)
raw = mne.io.RawArray(data, info, verbose=True)
print(raw.info)


In [None]:
# make_dig_montage takes ch_pos as a dict: channel name -> 3D coordinates
# (array of shape (3,)) in native digitizer space, in metres.
# The coordinates below are scaled by 1/100, i.e. they are presumably in cm.
DetPos = np.array([
    [0, 0, 0],
    [0, -8.4497, 0],
]) / 100

SrcPos = np.array([
    [-2.4748, 2.4748, 0.],
    [-2.4748, -2.4748, 0.],
    [2.4748, 2.4748, 0.],
    [2.4748, -2.4748, 0.],
    [-2.4748, -5.9748, 0.],
    [-2.4748, -10.9245, 0.],
    [2.4748, -5.9748, 0.],
    [2.4748, -10.9245, 0.],
]) / 100

# detectors D1..D2 first, then sources S1..S8
ch_positions = {f'D{n}': pos for n, pos in enumerate(DetPos, start=1)}
ch_positions.update({f'S{n}': pos for n, pos in enumerate(SrcPos, start=1)})
print(ch_positions)

montage = mne.channels.make_dig_montage(ch_pos=ch_positions)
raw.set_montage(montage)

raw.plot_sensors(show_names=False)
raw.compute_psd(average='mean').plot()


## Add the wavelength information

In [None]:
# Store each channel's wavelength in loc[9] of info["chs"], the slot MNE uses for
# the wavelength of fNIRS channels.
# raw's channel i holds nirsData column newChannelOrder[i], so it must take the
# wavelength of that column, not of column i. Indexing `wavelengths` by column
# assumes the NIRS-typed channels come first, in column order, with the non-NIRS
# columns (battery, LSL) last — TODO confirm.
for i, (ch, column) in enumerate(zip(raw.info["chs"], newChannelOrder)):
    ch['loc'][9] = float(wavelengths[column])
    print(i, ch)


## Create the montage and add it to the mne Raw object

In [None]:
# Attach MNE's built-in 10-20 montage to the raw object.
# NOTE(review): standard_1020 does not appear to contain S#/D# optode names, so it
# may fail to match the S#_D# fNIRS channels, or replace the optode positions set
# earlier. Confirm this cell is still wanted.
montage = mne.channels.make_standard_montage(kind='standard_1020')
raw.set_montage(montage=montage)

In [None]:
import matplotlib.pyplot as plt

# New figure with one box per column of the raw NIRS array.
fig, ax = plt.subplots(figsize=(15, 10))
ax.boxplot(nirsData)
ax.set_xlabel('Channel')
ax.set_ylabel('Value')
ax.set_title('Boxplot for each channel')

# X-tick labels "<label> [<unit>]", one per entry of nirsLabels. This assumes one
# label per nirsData column — TODO confirm. Built directly from nirsLabels because
# the previous loop bound, `nirsCount`, was never defined.
chanLabel = [lab['label'][0] + ' [' + lab['unit'][0] + ']' for lab in nirsLabels]
ax.set_xticklabels(chanLabel, rotation='vertical')  # should be the last line before the show() command
plt.show()

    

In [None]:
# TODO - print the info in a nice way

def dmp(l, indent=0):
    """Render a nested str/list/dict structure (e.g. a pyxdf ``desc`` tree) as text.

    * A string is emitted followed by a newline.
    * Each list element starts on a new line prefixed with ``"-" * depth``.
    * A dict starts on a new line prefixed with ``"*" * depth`` and is followed by
      its ``key: value`` entries.
    * Any other value (e.g. ``None`` or a number) is rendered with ``str``.

    ``depth`` is ``indent + 1`` for the container being rendered, and its children
    are rendered at that depth, so deeper nesting gets longer prefixes.

    Args:
        l: The structure to render.
        indent: Current nesting depth; callers normally leave it at 0.

    Returns:
        The rendered text.
    """
    if isinstance(l, str):
        return l + "\n"
    depth = indent + 1
    if isinstance(l, list):
        return "".join("\n" + "-" * depth + dmp(item, depth) for item in l)
    if isinstance(l, dict):
        entries = "".join(str(key) + ": " + dmp(value, depth) for key, value in l.items())
        return "\n" + "*" * depth + entries
    # scalars (None, numbers, ...) would otherwise vanish and merge adjacent keys
    return str(l) + "\n"
 
#pprint.pprint(nirsDesc, indent=10, depth=2)
# print(dmp(nirsDesc, 0))