In [None]:
import os

# Folder holding the AgePie recordings. Set the NEUARM_DATA_DIR environment
# variable to run the notebook on another machine; the default is the
# original author's local checkout.
DATA_DIR = os.environ.get(
    "NEUARM_DATA_DIR",
    "/Users/denismottet/Documents/GitHub/NeuArm-DataAnalysis/data/AgePie",
)
s_file = os.path.join(DATA_DIR, "AgePie_A16.snirf")

# Crop window, expressed in samples (indices into the SNIRF time vector):
# the cropped file keeps samples [delay, delay + duration).
delay = 6985
duration = 8105


# Copy the original file to a new file

The copy is then edited in place: its data, time vector and stimulus onsets are replaced by a cropped segment of the original recording, so the original file itself is never modified.


In [None]:
# Working copy (relative path, i.e. written to the current working directory)
# that the cropping cell below edits in place.
new_fname = "new_AgePie_A16_2.snirf"

from snirf import Snirf


def copy_snirf_file(file_name, new_file_name):
    """Write a copy of the SNIRF file ``file_name`` to ``new_file_name``.

    The source is opened read-only and is closed even if ``save`` raises.
    """
    with Snirf(file_name, "r") as snirf:
        snirf.save(new_file_name)

copy_snirf_file(s_file, new_fname)

# Sanity-check the copy. Read-only access is enough for validation; call
# validation_results.display() for the detailed report.
with Snirf(new_fname, "r") as snirf:
    validation_results = snirf.validate()
    print("copy is valid: {}".format(validation_results.is_valid()))



# Load the original SNIRF file with MNE

In [None]:

import mne

# Load the ORIGINAL (uncropped) recording; preload=True reads every sample
# into memory.
raw = mne.io.read_raw_snirf(s_file, preload=True, verbose="CRITICAL")

raw_data = raw.get_data()
# The samples are already in memory (preload=True), so the file handle can
# be released now.
raw.close()
print("raw_data.shape: {}".format(raw_data.shape))

# plot the raw data (to see the markers-annotations); the trailing ';' keeps
# Jupyter from rendering the returned figure a second time
raw.plot();


In [None]:
import numpy as np

snirf = Snirf(new_fname, "r+")

# Handles on the first nirs block and on its first data block; every read
# below goes through them.
nirs_block = snirf.nirs[0]
data_block = nirs_block.data[0]

# time series, time vector and measurement list of the recording
nirs_data = data_block.dataTimeSeries
nirs_time = data_block.time
nirs_meas_list = data_block.measurementList
n_channels, n_samples = len(nirs_meas_list), len(nirs_time)
print("n_channels = {}, n_samples = {}".format(n_channels, n_samples))
print("nirs_data.shape = {}".format(nirs_data.shape))

# recording metadata
tags = nirs_block.metaDataTags
print("date = {}".format(tags.MeasurementDate))
print("time = {}".format(tags.MeasurementTime))
print("Subject  = {}".format(tags.SubjectID))

# 2-D optode positions of the probe
probe = nirs_block.probe
print("sourcePos2D 2D = {}".format(probe.sourcePos2D))
print("detectorPos2D 2D = {}".format(probe.detectorPos2D))

def get_channel_names(snirf):
    """Return one "<source>-<detector>-<wavelength>" label per data column.

    The source, detector and wavelength of each column are read from the
    ``measurementList`` of ``nirs[0].data[0]`` (SNIRF indices are 1-based),
    so the i-th label describes the i-th column of ``dataTimeSeries``. The
    previous version guessed the source/detector pairing from the optode
    counts, which does not match the real channel layout in general.
    When the optional source/detector labels are missing, ``S<i>`` / ``D<i>``
    is used instead.

    Parameters
    ----------
    snirf : Snirf
        Open SNIRF object; only ``nirs[0]`` is inspected.

    Returns
    -------
    list of str
        One label per entry of the measurement list.
    """
    probe = snirf.nirs[0].probe
    meas_list = snirf.nirs[0].data[0].measurementList
    source_labels = probe.sourceLabels
    detector_labels = probe.detectorLabels
    wavelengths = probe.wavelengths

    def _has(labels):
        # sourceLabels / detectorLabels are optional in SNIRF
        return labels is not None and len(labels) > 0

    def _label(labels, index, prefix):
        # `index` is 1-based, as stored in the measurement list
        return labels[index - 1] if _has(labels) else "{}{}".format(prefix, index)

    n_sources = len(source_labels) if _has(source_labels) else len({m.sourceIndex for m in meas_list})
    n_detectors = len(detector_labels) if _has(detector_labels) else len({m.detectorIndex for m in meas_list})
    print("n_sources = {}, n_detectors = {}, n_wavelengths = {}".format(n_sources, n_detectors, len(wavelengths)))

    return [
        "{}-{}-{:3.0f}".format(
            _label(source_labels, m.sourceIndex, "S"),
            _label(detector_labels, m.detectorIndex, "D"),
            wavelengths[m.wavelengthIndex - 1],
        )
        for m in meas_list
    ]


def print_annotations(annotations):
    """Print each stimulus of ``annotations`` as "<name> = <row>".

    A 1-D ``data`` array is printed on a single line; a multi-row ``data``
    array gets one line per row.
    """
    for idx in range(len(annotations)):
        stim = annotations[idx]
        if stim.data.ndim == 1:
            rows = [stim.data]
        else:
            rows = [stim.data[k] for k in range(len(stim.data))]
        for row in rows:
            print("{} = {}".format(stim.name, row))

try:
    # print the annotations
    print("annotations")
    print_annotations(snirf.nirs[0].stim)

    # The kept segment is samples [delay, delay + duration). Check it up front:
    # slicing past the end of the recording would silently write a shorter
    # segment instead of failing.
    if delay < 0 or duration < 1 or delay + duration > n_samples:
        raise ValueError(
            "crop window [{}, {}) does not fit in the {} recorded samples".format(
                delay, delay + duration, n_samples
            )
        )

    # modify the annotations
    beg = nirs_time[delay]
    # last sample actually kept by the slices below (not one past it)
    end = nirs_time[delay + duration - 1]
    print("beg = {}, end = {}, duration ={}".format(beg, end, end - beg))

    # Shift every onset so it is relative to the first kept sample. The time
    # vector is re-zeroed by the same amount below, so stimuli and samples
    # share one time base whose first sample is at 0.
    stims = snirf.nirs[0].stim
    for idx in range(len(stims)):
        stim_data = stims[idx].data
        if stim_data is None or np.size(stim_data) == 0:
            continue  # stimulus group without any event
        shifted = np.array(stim_data, dtype=float)  # work on a copy
        if shifted.ndim == 1:
            shifted[0] -= beg  # a single event stored as one 1-D row
        else:
            shifted[:, 0] -= beg  # column 0 holds the onsets
        # assign back rather than relying on in-place mutation of the
        # returned array being picked up by save()
        stims[idx].data = shifted

    # print the new annotations
    print("new annotations")
    print_annotations(snirf.nirs[0].stim)

    # modify the time and the nirs data
    new_nirs_data = nirs_data[delay:delay + duration, :]
    # re-zero the time vector to match the shifted onsets (keeping the
    # absolute times would leave the file's time axis starting at `beg`)
    new_nirs_time = nirs_time[delay:delay + duration] - beg

    print("new_nirs_time: ", new_nirs_time.shape)
    print("new_nirs_data: ", new_nirs_data.shape)

    snirf.nirs[0].data[0].dataTimeSeries = new_nirs_data
    snirf.nirs[0].data[0].time = new_nirs_time

    # save the modifications
    snirf.save()
finally:
    # always release the r+ handle, even if one of the edits above failed
    snirf.close()


# plot the nirs time series (full, uncropped recording)
import matplotlib.pyplot as plt

# `snirf` was closed at the end of the cropping step, so read the probe
# labels through a fresh read-only handle on the original file.
with Snirf(s_file, "r") as snirf_src:
    channel_names = get_channel_names(snirf_src)

fig, ax = plt.subplots()
lines = ax.plot(nirs_time, nirs_data, linewidth=0.5)
# zip() stops at the shorter sequence: a label/column count mismatch leaves
# extra lines unlabeled instead of raising.
for line, name in zip(lines, channel_names):
    line.set_label(name)
# shade the samples that were written to the cropped file
ax.axvspan(nirs_time[delay], nirs_time[delay + duration - 1], color="grey", alpha=0.15, label="cropped segment")
ax.set_title("nirs time series")
ax.set_xlabel("time (s)")
ax.set_ylabel("nirs")
ax.legend()
plt.show()


# Load the new file and plot the data

In [None]:
# Load the cropped copy. Warnings are kept visible here: MNE reports problems
# such as annotations falling outside the data range, which is exactly what
# this verification step should surface.
new = mne.io.read_raw_snirf(new_fname, preload=True, verbose="WARNING")

# trailing ';' keeps Jupyter from rendering the returned figure a second time
new.plot();



# Verify the annotations of the cropped file

In [None]:
# get the annotations (there are no separate events in this file)
annotations = new.annotations
sfreq = new.info["sfreq"]

# the cropped file must hold exactly the requested number of samples
assert new.n_times == duration, "expected {} samples, got {}".format(duration, new.n_times)

# length of the cropped recording, in seconds
duration_s = duration / sfreq

# print all annotations and flag any onset outside the cropped recording
# (note: MNE may already have dropped such annotations on load, with a warning)
print(annotations)
for i, (t, e) in enumerate(zip(annotations.onset, annotations.description)):
    print("{:2d}: {:10.3f}s = {} ".format(i, t, e))
    if t < 0:
        print("time of event is before the start of the recording")
    elif t > duration_s:
        print("time of event is greater than duration")


