In [None]:
import mne
import os
import matplotlib.pyplot as plt
import pandas as pd
import ast
import numpy as np
import scipy
import xml.etree.ElementTree as ET
from scipy.spatial.distance import cdist


%matplotlib qt

# a function for plotting digitization coordinates
def plot_digipoints(digitization_montage, subject, scatter=False):
    """Draw digitized points in a 3D axes, each annotated with its label.

    Parameters
    ----------
    digitization_montage : dict
        Maps a point label to an indexable (x, y, z) coordinate.
    subject : str
        Used as the axes title.
    scatter : bool
        If True, draw unconnected markers; otherwise connect the points with a
        line in the dict's iteration order.
    """
    labels = list(digitization_montage)
    coords = list(digitization_montage.values())
    xs, ys, zs = ([point[axis] for point in coords] for axis in range(3))

    _, ax = plt.subplots(subplot_kw={'projection': '3d'})
    ax.set_title(subject)
    draw = ax.scatter if scatter else ax.plot
    draw(xs, ys, zs)
    for x, y, z, label in zip(xs, ys, zs, labels):
        ax.text(x, y, z, label)

    # Hide axes, background panes and grid so only the points remain.
    ax.set_axis_off()
    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.fill = False
    ax.grid(False)

def plot_digipoints2d(digitization_montage, subject):
    """Plot the digitization with ``mne.viz.plot_montage`` and title it *subject*.

    The 'nas', 'lpa' and 'rpa' entries are passed to MNE as fiducials (all three
    must be present, otherwise KeyError); every other entry is a channel position.
    """
    fiducial_labels = ('nas', 'lpa', 'rpa')
    channel_positions = {
        label: position
        for label, position in digitization_montage.items()
        if label not in fiducial_labels
    }
    montage = mne.channels.make_dig_montage(
        ch_pos=channel_positions,
        nasion=digitization_montage['nas'],
        lpa=digitization_montage['lpa'],
        rpa=digitization_montage['rpa'],
    )
    figure = mne.viz.plot_montage(montage)
    figure.suptitle(subject)


# Functions for reading digitization coordinates from .nbe and .xml files,
# plus functions for detecting faulty points and correcting them.
#------------------------------------------------------------------------------------------------------------------------------------------------------------------
    
def read_digitization_xml(filepath, save_to_dir=None, subject_identifier=None):
    """Read electrode/landmark digitization coordinates from a marker XML file.

    The root element must carry a ``coordinateSpace`` attribute of 'RAS' or
    'LPS'; LPS coordinates are converted to RAS by negating x and y. Every
    child element of the root is read as one point: its ``description``
    attribute is the label and its ``ColVec3D`` child holds the coordinate in
    the ``data0``/``data1``/``data2`` attributes (millimetres, converted here to
    metres). The landmark labels NAS, LTR and RTR are renamed to MNE's 'nas',
    'lpa' and 'rpa'.

    Parameters
    ----------
    filepath : str
        Path of the XML file.
    save_to_dir, subject_identifier : str, optional
        When both are given, the result is saved with ``np.save`` as
        ``<save_to_dir>/<subject_identifier>_original_digitizations.npy``.

    Returns
    -------
    dict or False
        label -> coordinate (metres, RAS) in file order, or False when the
        coordinate space is neither RAS nor LPS (nothing is saved then).
    """
    root = ET.parse(filepath).getroot()
    coordinate_space = root.get('coordinateSpace')
    if coordinate_space not in ('RAS', 'LPS'):
        # Say why nothing was read; callers otherwise only see the False.
        print(f"unsupported coordinate space {coordinate_space!r} in {filepath}, expected RAS or LPS")
        return False
    if coordinate_space == 'LPS':
        print(f"detected LPS system in {filepath}, transforming to RAS later...")

    lps_to_ras = np.diag([-1.0, -1.0, 1.0])  # flip x and y
    landmark_names = {'NAS': 'nas', 'RTR': 'rpa', 'LTR': 'lpa'}
    digitization_montage = {}
    for marker in root:
        point_name = marker.attrib['description']  # electrode or anatomical landmark name
        colvec = marker.find('ColVec3D')
        coordinate = np.array([float(colvec.get(f'data{axis}')) for axis in range(3)]) / 1000
        if coordinate_space == 'LPS':
            coordinate = coordinate @ lps_to_ras
        digitization_montage[landmark_names.get(point_name, point_name)] = coordinate

    if save_to_dir and subject_identifier:
        np.save(f'{save_to_dir}/{subject_identifier}_original_digitizations', digitization_montage)
    return digitization_montage


def _nbe_xyz_in_meters(line_splitted):
    """Return columns 1-3 of a tab-split .nbe line as a coordinate in metres (the file stores mm)."""
    return np.array([float(line_splitted[1]), float(line_splitted[2]), float(line_splitted[3])]) / 1000


# Marker text in the .nbe "Landmarks (mm)" section -> MNE fiducial name.
_NBE_SCALP_LANDMARKS = {
    "Scalp landmark: Left ear": 'lpa',
    "Scalp landmark: Nose/Nasion": 'nas',
    "Scalp landmark: Right ear": 'rpa',
}


def _read_nbe_scalp_landmarks(lines):
    """Return {'lpa'|'nas'|'rpa': coordinate} for the scalp landmarks found in *lines*.

    Scanning stops as soon as all three are found; until then a later line for
    the same landmark overwrites an earlier one.
    """
    landmarks = {}
    for line in lines:
        for marker, name in _NBE_SCALP_LANDMARKS.items():
            if marker in line:
                landmarks[name] = _nbe_xyz_in_meters(line.split("\t"))
                break
        if len(landmarks) == len(_NBE_SCALP_LANDMARKS):
            break
    return landmarks


def _read_nbe_digitization_exam(lines, point_order):
    """Label every line of *lines* that contains "Point" with the next name in *point_order*.

    Points beyond the end of *point_order* are stored as 'At_least_bad_<index>'.
    """
    montage = {}
    point_lines = (line for line in lines if "Point" in line)
    for digi_ind, line in enumerate(point_lines):
        coordinate = _nbe_xyz_in_meters(line.split("\t"))
        point_name = point_order[digi_ind] if digi_ind < len(point_order) else f'At_least_bad_{digi_ind}'
        montage[point_name] = coordinate
    return montage


def read_digitization_nbe(filepath_to_read, save_to_dir=None, subject_identifier=None, correct_point_order_path=None, use_scalp_landmarks=False):
    """Read digitized electrode positions from a Nexstim .nbe log file.

    Every "Point" line from the last "Digitization Exam Description: " line to
    the end of the file is read (tab-separated, x/y/z in columns 1-3,
    millimetres -> metres). The points are labelled in recording order from the
    expected point order; surplus points are stored as 'At_least_bad_<index>'.

    Parameters
    ----------
    filepath_to_read : str
        Path of the .nbe file.
    save_to_dir, subject_identifier : str, optional
        When both are given, the result is saved with ``np.save`` as
        ``<save_to_dir>/<subject_identifier>_original_digitizations.npy``.
    correct_point_order_path : str or list
        Expected digitization order, either directly as a list of labels or as
        the path of a ``.npy`` file containing it. Required.
    use_scalp_landmarks : bool
        If True, 'lpa', 'nas' and 'rpa' are set (or overwritten) from the scalp
        landmarks of the "Landmarks (mm)" section. Landmarks that are not found
        are reported and left out.

    Returns
    -------
    dict or False
        label -> coordinate in metres (empty if the file has no digitization
        exam), or False when no point order is given or a line mentioning a
        coordinate system does not name the MRI coordinate system.
    """
    if isinstance(correct_point_order_path, list):  # order given directly (Chieti specification)
        point_order = correct_point_order_path
    elif correct_point_order_path:
        point_order = np.load(correct_point_order_path)  # digitization order
    else:
        print("correct_point_order_path not specified so can not know order of digitization points")
        return False
    number_of_points_to_record = len(point_order)

    # NOTE(review): 'unicode_escape' also interprets backslash sequences in the
    # log text (e.g. inside Windows paths); kept unchanged — confirm it is needed.
    with open(filepath_to_read, "r", encoding='unicode_escape') as file:
        lines = file.readlines()

    digitization_montage = None
    scalp_landmarks = {}
    for ind1, origline in enumerate(lines):
        lowered = origline.lower()
        if "coordinate system" in lowered:
            print(str(origline))
            if "mri coordinate system" not in lowered:
                print("coordinate system should be MRI coordinate system, check the .nbe log.")
                return False
        if use_scalp_landmarks and "Landmarks (mm)" in origline:
            scalp_landmarks = _read_nbe_scalp_landmarks(lines[ind1:])
        if "Digitization Exam Description: " in origline:
            # Re-read at every exam so that the last exam in the file wins.
            digitization_montage = _read_nbe_digitization_exam(lines[ind1:], point_order)

    if digitization_montage is None:
        print(f"No digitization exam found in {filepath_to_read}")
        digitization_montage = {}
    if use_scalp_landmarks:
        for name in ('lpa', 'nas', 'rpa'):
            if name in scalp_landmarks:
                digitization_montage[name] = scalp_landmarks[name]
            else:
                print(f"Scalp landmark '{name}' not found in {filepath_to_read}")

    number_of_recorded_points = len(digitization_montage)
    if number_of_points_to_record != number_of_recorded_points:
        print(f'Recorded points in {filepath_to_read} is {number_of_recorded_points} even though it should be {number_of_points_to_record}')
    if save_to_dir and subject_identifier:
        np.save(f'{save_to_dir}/{subject_identifier}_original_digitizations', digitization_montage)
    return digitization_montage

def detect_outliers_by_label(dictionaries, subjects, threshold=3):
    """Flag per-label digitization points that deviate from the other subjects.

    Each subject's points are first normalised within that subject: centred on
    the subject's mean point and divided by the norm of the per-axis maximum
    coordinate. Then, for every label, a point is flagged when its z-score on
    any axis (population SD over the subjects that have that label) exceeds
    *threshold* in absolute value.

    Parameters
    ----------
    dictionaries : list of dict
        One label -> (x, y, z) mapping per subject. Empty dicts are skipped.
    subjects : sequence
        Subject identifiers, parallel to *dictionaries*.
    threshold : float
        Absolute z-score above which a point is flagged.

    Returns
    -------
    dict
        label -> list of (subject, original position, normalised position,
        label mean, label SD) tuples for the flagged points. Labels without
        outliers are omitted.
    """
    # label -> list of (index into dictionaries/subjects, original pos, normalised pos).
    # The subject index is stored explicitly because a label can be missing
    # from some subjects, so its position in the per-label list is not the
    # subject's index.
    grouped_points = {}
    for dict_index, d in enumerate(dictionaries):
        if not d:
            continue  # nothing to normalise (np.max of an empty array would raise)
        positions = np.array(list(d.values()), dtype=float)
        meanpos = np.mean(positions, axis=0)
        norm_of_maximums = np.linalg.norm(np.max(positions, axis=0))
        for label, pos in d.items():
            normalised = (np.asarray(pos, dtype=float) - meanpos) / norm_of_maximums
            grouped_points.setdefault(label, []).append((dict_index, pos, normalised))

    outliers = {}
    for label, entries in grouped_points.items():
        points_arr = np.array([normalised for _, _, normalised in entries])
        mean = np.mean(points_arr, axis=0)
        sd = np.std(points_arr, axis=0)
        # Where sd is 0 every point equals the mean; 0/0 gives nan, which is never flagged.
        with np.errstate(divide='ignore', invalid='ignore'):
            z_scores = (points_arr - mean) / sd
        label_outliers = [
            (subjects[dict_index], np.asarray(origpos), point, mean, sd)
            for (dict_index, origpos, _), point, z_score in zip(entries, points_arr, z_scores)
            if np.any(np.abs(z_score) > threshold)
        ]
        if label_outliers:
            outliers[label] = label_outliers
    return outliers

def find_nearest_labels(locations, k):
    """Return, for every label, the set of labels of its *k* nearest points.

    Distances are Euclidean between the coordinates in *locations*
    (label -> position). For each label the closest entry of its distance row
    (normally the label itself, at distance 0) is skipped.
    """
    names = list(locations)
    coords = np.array(list(locations.values()))
    distance_matrix = cdist(coords, coords)
    return {
        name: {names[j] for j in np.argsort(row)[1:k + 1]}
        for name, row in zip(names, distance_matrix)
    }

def check_nearest_consistency(dictionaries, k, threshold):
    """Flag labels whose *k* nearest neighbours are not consistent across subjects.

    For every label of the first dictionary, the k-nearest-neighbour label sets
    (from ``find_nearest_labels``) of all dictionaries are intersected, and the
    label is reported when fewer than *threshold* neighbours are common to all.
    A dictionary that lacks the label contributes an empty neighbour set, so the
    missing label is reported instead of raising KeyError.

    Returns
    -------
    dict
        label -> {'common_neighbors': set, 'neighbor_sets': list of set}, where
        'neighbor_sets' is parallel to *dictionaries*. Empty for an empty input.
    """
    if not dictionaries:
        return {}
    all_nearest = [find_nearest_labels(d, k=k) for d in dictionaries]
    outliers = {}
    for label in all_nearest[0]:
        nearest_sets = [nearest.get(label, set()) for nearest in all_nearest]
        common_neighbors = set.intersection(*nearest_sets)
        if len(common_neighbors) < threshold:
            outliers[label] = {'common_neighbors': common_neighbors, 'neighbor_sets': nearest_sets}
    return outliers

def get_default_channel_pos(default_channel_locations_file, bad_position_names, channel_order, use_fiducials=False):
    """Read template electrode positions and keep the ones used at this site.

    Parameters
    ----------
    default_channel_locations_file : str
        Montage file readable by ``mne.channels.read_custom_montage``.
    bad_position_names : collection of str
        Labels whose digitized position is unusable.
    channel_order : collection of str
        Labels used at this site; other template points are ignored.
    use_fiducials : bool
        If True, the template's LPA/RPA/nasion points are returned as 'lpa',
        'rpa' and 'nas' (when those labels are in *channel_order*); otherwise
        they are skipped.

    Returns
    -------
    default_positions : dict
        label -> template position (the dig point's ``'r'``) for labels in
        *channel_order*.
    bad_positions : dict
        The entries of *default_positions* whose labels are in *bad_position_names*.
    """
    default_locations = mne.channels.read_custom_montage(default_channel_locations_file)
    default_channel_locations_names = default_locations.ch_names
    fiducial_names = {"LPA": "lpa", "RPA": "rpa", "NASION": "nas"}
    default_positions = {}
    bad_positions = {}
    # Count channel points separately, so naming does not depend on exactly
    # three fiducials preceding them in the dig list.
    channel_index = 0
    for digipoint in default_locations.dig:
        pos_name_ident = str(digipoint['ident'])
        fiducial_name = next((name for key, name in fiducial_names.items() if key in pos_name_ident), None)
        if fiducial_name is None:
            # NOTE(review): assumes the non-fiducial dig points appear in ch_names order
            pos_name = default_channel_locations_names[channel_index]
            channel_index += 1
        elif use_fiducials:
            pos_name = fiducial_name
        else:
            continue  # fiducials not requested
        if pos_name not in channel_order:
            continue
        position_default = digipoint['r']
        default_positions[pos_name] = position_default
        if pos_name in bad_position_names:
            bad_positions[pos_name] = position_default
    return default_positions, bad_positions


def recover_bad_digi_coordinates(bad_electrodes_in_default, good_point_positions, default_point_positions, good_point_positions_full=False):
    """Estimate positions of bad electrodes by fitting the template onto the subject.

    A similarity transform p -> scale * R @ p + t is fitted between the template
    (*default_point_positions*) and the subject's good digitized points, using
    the labels present in both. The scale is the ratio of the mean distances to
    the respective centroids, and R is the best rotation from
    ``Rotation.align_vectors``. The transform is then applied to the template
    positions of the bad electrodes with ``recover_positions``.

    Parameters
    ----------
    bad_electrodes_in_default : dict
        label -> template position of the electrodes to recover.
    good_point_positions : dict
        label -> digitized subject position of the usable points.
    default_point_positions : dict
        label -> template position of (at least) the good and bad electrodes.
    good_point_positions_full : dict or False
        If given, the recovered electrodes are added to this dict instead of
        *good_point_positions*.

    Returns
    -------
    dict or False
        The target dict, mutated in place with the recovered electrodes, or
        False if a bad electrode is already one of its keys.

    Raises
    ------
    ValueError
        If no label is shared by the good and the template positions.
    """
    default_channel_positions_good = {label: value for label, value in default_point_positions.items() if label not in bad_electrodes_in_default}
    # Pair points by label, so the fit never depends on both dicts sharing an iteration order.
    shared_labels = [label for label in good_point_positions if label in default_channel_positions_good]
    if not shared_labels:
        raise ValueError("no labels shared by the good and the default positions")
    good_positions = np.array([good_point_positions[label] for label in shared_labels], dtype=float)
    default_positions = np.array([default_channel_positions_good[label] for label in shared_labels], dtype=float)

    # Centre both point sets and get the scale between the coordinate systems.
    good_mean = np.mean(good_positions, axis=0)
    default_mean = np.mean(default_positions, axis=0)
    centered_good_points = good_positions - good_mean
    centered_good_default_points = default_positions - default_mean
    mean_norm_good_pos = np.mean(np.linalg.norm(centered_good_points, axis=1))
    mean_norm_default_pos = np.mean(np.linalg.norm(centered_good_default_points, axis=1))
    scale = mean_norm_good_pos / mean_norm_default_pos

    # align_vectors(a, b) returns the rotation that maps b onto a, i.e. R @ default ~ good.
    r, _ = scipy.spatial.transform.Rotation.align_vectors(centered_good_points, centered_good_default_points)
    R = r.as_matrix()

    # The fitted map is p -> scale * R @ p + t, so the template centroid must be
    # rotated with R itself (np.dot(mean, R) would apply R transposed).
    t = good_mean - scale * (R @ default_mean)

    T = np.eye(4)
    T[:3, :3] = scale * R
    T[:3, 3] = t
    target = good_point_positions if good_point_positions_full is False else good_point_positions_full
    return recover_positions(bad_electrodes_in_default, T, target)


def recover_positions(bad_electrodes_in_default, T, good_point_positions):
    """Map template positions of bad electrodes into the subject's coordinates.

    Each position in *bad_electrodes_in_default* is transformed with the 4x4
    homogeneous matrix *T* and added to *good_point_positions* in place.

    Returns
    -------
    dict or False
        *good_point_positions* with the recovered electrodes added, or False if
        a bad electrode is already one of its keys. In that case, electrodes
        processed before it have already been added.
    """
    for bad_electrode in bad_electrodes_in_default:
        if bad_electrode in good_point_positions:
            print("Bad electrode can not be in good electrodes")
            return False
        point_homogeneous = np.append(bad_electrodes_in_default[bad_electrode], 1)
        good_point_positions[bad_electrode] = np.dot(T, point_homogeneous)[:3]  # apply transformation

    return good_point_positions

#------------------------------------------------------------------------------------------------------------------------------------------------------------------


In [None]:
# Paths shared by all sites; the template montage is used to recover bad electrodes.
digitizations_path = r"D:\REFTEP_ALL\Digitization/"
default_channel_locations_file = os.path.join(digitizations_path, 'standard_1005.elc')

# Order in which the Aalto points should have been digitized: 62 electrodes
# (Fp1 ... Iz), then the reference and the three fiducials.
point_order_aalto = [
    'Fp1', 'Fpz', 'Fp2',
    'AF7', 'AF3', 'AFz', 'AF4', 'AF8',
    'F7', 'F5', 'F3', 'F1', 'Fz', 'F2', 'F4', 'F6', 'F8',
    'FT7', 'FC5', 'FC3', 'FC1', 'FCz', 'FC2', 'FC4', 'FC6', 'FT8',
    'T7', 'C5', 'C3', 'C1', 'Cz', 'C2', 'C4', 'C6', 'T8',
    'TP7', 'CP5', 'CP3', 'CP1', 'CPz', 'CP2', 'CP4', 'CP6', 'TP8',
    'P7', 'P5', 'P3', 'P1', 'Pz', 'P2', 'P4', 'P6', 'P8',
    'PO7', 'PO3', 'POz', 'PO4', 'PO8',
    'O1', 'Oz', 'O2', 'Iz',
    'ref', 'lpa', 'nas', 'rpa',
]
digitizations_aalto = r"D:\REFTEP_ALL\Digitization\Digitizations_Aalto/"
correct_point_order_path_aalto = os.path.join(digitizations_aalto, "correct_point_order_Aalto")
# To (re)create the .npy file:
# np.save(correct_point_order_path_aalto, point_order_aalto)
point_order_paths = {'Aalto': correct_point_order_path_aalto + '.npy'}

# Switches for the per-site cells below that read and save digitizations.
runtue = False  # Tuebingen
runaalto = False  # Aalto

In [None]:
# Tuebingen: read the newest EEGMarkers XML of every subject, save it as the
# subject's original digitization and check all saved digitizations for outliers.
if runtue:
    site = 'Tuebingen'
    filepath = rf"D:\REFTEP_ALL\Neuronavigation_nexstim_localite\{site}_localite/"
    identifier = "EEGMarkers"  # used to identify correct files
    digitizations_site = os.path.join(digitizations_path, f'Digitizations_{site}')
    for subject in os.listdir(filepath):
        if "stimulation_times" in subject:
            continue
        digitization_folder = os.path.join(filepath, subject, 'EEG')
        files_in_digitization_folder = os.listdir(digitization_folder)
        eeg_marker_files = [file for file in files_in_digitization_folder if file.startswith(identifier)]
        if not eeg_marker_files:
            print(f"No {identifier} files in {digitization_folder}, skipping {subject}")
            continue
        # the most recent file, judged by the part of the name after the prefix
        most_recent_file = max(eeg_marker_files, key=lambda file: file[len(identifier):])
        eeg_marker_filepath = os.path.join(digitization_folder, most_recent_file)
        # new subject identifier: the last 7 characters for folders containing
        # "rep" (presumably repeat sessions), otherwise the last 3
        if "rep" in subject:
            subject_identifier_new = f'REFTEP{subject[-7:]}'
        else:
            subject_identifier_new = f'REFTEP{subject[-3:]}'
        subject_directory_digitization = os.path.join(digitizations_site, subject_identifier_new)
        os.makedirs(subject_directory_digitization, exist_ok=True)
        digimontage = read_digitization_xml(eeg_marker_filepath, save_to_dir=subject_directory_digitization, subject_identifier=subject_identifier_new)
        if digimontage is False:  # unsupported coordinate space; nothing was saved
            print(f"Could not read {eeg_marker_filepath}, skipping {subject}")
            continue
        plot_digipoints(digimontage, subject_identifier_new, scatter=True)
    subject_dirs = os.listdir(digitizations_site)
    # only subject folders that actually contain a saved digitization
    subject_dirs2 = [
        directory for directory in subject_dirs
        if "REFTEP" in directory
        and os.path.exists(os.path.join(digitizations_site, directory, f'{directory}_original_digitizations.npy'))
    ]
    dictionaries = [np.load(f'{digitizations_site}/{subject}/{subject}_original_digitizations.npy', allow_pickle=True).item() for subject in subject_dirs2]
    outliers = detect_outliers_by_label(dictionaries, subject_dirs2, 3)  # check for outliers
    print("Outliers found:", len(outliers))
    for outlier in outliers:
        print(outlier, outliers[outlier])

In [None]:
# Tuebingen check (runs unconditionally, saves nothing): read the newest
# EEGMarkers XML of every subject, plot it, and look for outliers both per label
# (z-scores) and by the consistency of each electrode's nearest neighbours.
site = 'Tuebingen'
filepath = rf"D:\REFTEP_ALL\Neuronavigation_nexstim_localite\{site}_localite/"
identifier = "EEGMarkers"  # used to identify correct files
digitizations_site = os.path.join(digitizations_path, f'Digitizations_{site}')
subs = []  # subject folder names, parallel to dictionaries
dictionaries = []
for subject in os.listdir(filepath):
    if "stimulation_times" in subject:
        continue
    digitization_folder = os.path.join(filepath, subject, 'EEG')
    files_in_digitization_folder = os.listdir(digitization_folder)
    eeg_marker_files = [file for file in files_in_digitization_folder if file.startswith(identifier)]
    if not eeg_marker_files:
        print(f"No {identifier} files in {digitization_folder}, skipping {subject}")
        continue
    # the most recent file, judged by the part of the name after the prefix
    most_recent_file = max(eeg_marker_files, key=lambda file: file[len(identifier):])
    eeg_marker_filepath = os.path.join(digitization_folder, most_recent_file)
    if "rep" in subject:  # create a new subject identifier
        subject_identifier_new = f'REFTEP{subject[-7:]}'
    else:
        subject_identifier_new = f'REFTEP{subject[-3:]}'
    digimontage = read_digitization_xml(eeg_marker_filepath, save_to_dir=None, subject_identifier=None)
    if digimontage is False:  # unsupported coordinate space
        print(f"Could not read {eeg_marker_filepath}, skipping {subject}")
        continue
    plot_digipoints2d(digimontage, subject_identifier_new)
    subs.append(subject)
    dictionaries.append(digimontage)
outliers = detect_outliers_by_label(dictionaries, subs, 3)
print("Outliers found:", len(outliers))
for outlier in outliers:
    print(outlier, outliers[outlier])

# check outliers based on neighbors
outliers_pos = check_nearest_consistency(dictionaries, k=8, threshold=5)
if outliers_pos:
    print("outliers found based on neighbors")
    for label, data in outliers_pos.items():
        print(f'{label} has outliers')
        print(f"common neighbors:{sorted(data['common_neighbors'])}")
        for i, nearest in enumerate(data['neighbor_sets']):
            print(f'dict {i+1}: {nearest}')
else:
    print("no outliers found based on neighbors")

In [None]:
# Aalto: read every subject's Nexstim .nbe file, save it as the subject's
# original digitization and check all saved digitizations for outliers.
if runaalto:
    site = 'Aalto'
    filepath = rf"D:\REFTEP_ALL\Neuronavigation_nexstim_localite\{site}_nexstim/"
    identifier = ".nbe"  # used to identify correct files
    digitizations_site = os.path.join(digitizations_path, f'Digitizations_{site}')
    for subject in os.listdir(filepath):
        nbe_folder = os.path.join(filepath, subject)
        if "stimulation_times" in subject or not os.path.isdir(nbe_folder):
            continue
        nbe_files = [file for file in os.listdir(nbe_folder) if file.endswith(identifier)]
        if len(nbe_files) > 1:
            print("More than 1 .nbe file. Check again!")
            break
        if not nbe_files:
            print(f"No .nbe file in {nbe_folder}. Check again!")
            break
        nbe_filepath = os.path.join(nbe_folder, nbe_files[0])
        subject_identifier_new = f'REFTEP{subject[-3:]}'
        subject_directory_digitization = os.path.join(digitizations_site, subject_identifier_new)
        os.makedirs(subject_directory_digitization, exist_ok=True)
        # Fiducials were not recorded in 117's digitization exam, so use the
        # recorded scalp marks for nasion and ears instead.
        use_scalp_landmarks = "117" in subject
        digimontage = read_digitization_nbe(nbe_filepath, save_to_dir=subject_directory_digitization, subject_identifier=subject_identifier_new,
                                             correct_point_order_path=point_order_paths[site], use_scalp_landmarks=use_scalp_landmarks)
        if digimontage is False:
            print(f"Could not read {nbe_filepath}, check the .nbe log.")
            continue
        plot_digipoints(digimontage, subject_identifier_new)

    print(site)
    # subject folders (REFTEP119 excluded) that contain a saved digitization
    subject_dirs = [
        directory for directory in os.listdir(digitizations_site)
        if "REFTEP" in directory
        and os.path.isdir(os.path.join(digitizations_site, directory))
        and '119' not in directory
        and os.path.exists(os.path.join(digitizations_site, directory, f'{directory}_original_digitizations.npy'))
    ]
    dictionaries = [np.load(f'{digitizations_site}/{subject}/{subject}_original_digitizations.npy', allow_pickle=True).item() for subject in subject_dirs]
    outliers = detect_outliers_by_label(dictionaries, subject_dirs)

    print("Outliers found (note that these may not be all the mistakes):", len(outliers))
    for outlier in outliers:  # check outliers
        print(outlier, outliers[outlier])

## The bad digitization points found above were listed in an Excel file
## (digitization_corrections_aalto.xlsx); the next cell uses it to correct them.

In [None]:
if runaalto:
    filepath = r"D:\REFTEP_ALL\Digitization\Digitizations_Aalto"
    corrections_filename = "digitization_corrections_aalto.xlsx"
    df = pd.read_excel(os.path.join(filepath,corrections_filename))
    identifier = "_original_digitizations.npy" #used to identify correct files
    final_original_digitizations_filename = "_final_original_digitizations"
    final_corrected_digitizations_filename = "_final_corrected_digitizations"
    final_digitizations_filename = "_final_digitizations"
    correct_point_order = np.load(f'{correct_point_order_path_aalto}.npy') #aalto-specific
    use_fiducials = False
    exclusions = ['ref','rpa','lpa','nas']
    %matplotlib qt
    for file in os.listdir(filepath):
        if "REFTEP" in file and '119' not in file: #only check subjects but not REFTEP119
            subject = file
            digitization_folder = os.path.join(filepath, subject)
            files_in_digitization_folder = os.listdir(digitization_folder)
            #get the nexstim .nbe file
            original_digitization_files = [file for file in files_in_digitization_folder if file.endswith(identifier) and "final" not in file]
            if len(original_digitization_files ) > 1:
                print("More than 1 original file for digitization information. Check again!")
                break
            #get the most recent file
            #define the filepath and read the digitization file
            subject_information = df[df.iloc[:,0] == subject]
            digitized_order = subject_information['Digitized_order']
            bad_positions = subject_information['Bad_or_missing_locations']
            digitization_filepath = os.path.join(digitization_folder,original_digitization_files[0])
            digitization_montage = np.load(digitization_filepath, allow_pickle=True).item()
            if digitized_order.isna().any() and bad_positions.isna().any():
                subject_corrected_original_digitizations_filename = subject + final_original_digitizations_filename
                #print(f"{subject} seems to have ok original digitization. Saving the identical file as {subject_corrected_original_digitizations_filename}")
                np.save(os.path.join(digitization_folder,subject_corrected_original_digitizations_filename),digitization_montage)
            else:
                digivalues = list(digitization_montage.values()) #values in the recorded digitization montage
                digitized_order = digitized_order.values.tolist()
                bad_positions = bad_positions.values.tolist()
                digitized_order = ast.literal_eval(digitized_order[0])
                bad_positions = ast.literal_eval(bad_positions[0])
                if "117" in subject:
                    digitized_order = digitized_order + ['lpa','nas','rpa'] #not in digimontage but from anatomical landmarks
                digitization_montage_not_nones = {key:value for key, value in zip(digitized_order,digivalues) if key}
                #plot_digipoints(digitization_montage_not_nones,subject)
                digimontage_real_order = {key:digitization_montage_not_nones[key] for key in correct_point_order if key not in bad_positions and key != "ref"}
                #plot_digipoints(digimontage_real_order,subject)
                default_positions, bad_positions_in_default = get_default_channel_pos(default_channel_locations_file, bad_positions, correct_point_order, use_fiducials=use_fiducials)
                #plot_digipoints(default_positions, subject)
                dafault_channel_names = list(default_positions.keys())
                good_positions = {label:value for label, value in digimontage_real_order.items() if label not in bad_positions and label not in exclusions}
                digitization_positions = recover_bad_digi_coordinates(bad_positions_in_default, good_positions, default_positions)
                if "121" in subject: #manual adjustment after reco
                    digitization_positions['TP8'] = [digitization_positions['TP8'][0]+0.0065,digitization_positions['TP8'][1],digitization_positions['TP8'][2]]
                if "120" in subject: #manual adjustment after reco
                    digitization_positions['T8'] = [digitization_positions['T8'][0]-0.008,digitization_positions['T8'][1],digitization_positions['T8'][2]]
                electrode_pos_in_correct_order = {label: (digitization_positions[label] if label not in exclusions else digimontage_real_order[label]) for label in correct_point_order if label !='ref'}
                digitization_montage = electrode_pos_in_correct_order
                plot_digipoints(digitization_montage, subject)
                subject_corrected_original_digitizations_filename = subject + final_corrected_digitizations_filename
                np.save(os.path.join(digitization_folder,subject_corrected_original_digitizations_filename),digitization_montage)
            digitization_montage = {key:digitization_montage[key] for key in digitization_montage.keys() if key != 'ref'} #drop ref
            #plot_digipoints(digitization_montage, subject)
            subject_final_digitizations_filename = subject + final_digitizations_filename
            np.save(os.path.join(digitization_folder,subject_final_digitizations_filename),digitization_montage)
        
        #plot_digipoints(digimontage, subject)
    subject_dirs = os.listdir(filepath)
    subject_dirs2 = [directory for directory in subject_dirs if "REFTEP" in directory and os.path.isdir(os.path.join(filepath,directory)) and '119' not in directory and '112' not in directory]
    dictionaries = [np.load(f'{filepath}/{subject}/{subject}{final_digitizations_filename}.npy',allow_pickle=True).item() for subject in subject_dirs2]
    outliers = detect_outliers_by_label(dictionaries,subject_dirs2,3)
    #check outliers again
    print("Outliers found:", len(outliers))
    for outlier in outliers:
        print(outlier, outliers[outlier])

    outliers_pos = check_nearest_consistency(dictionaries, k=8, threshold=5)
    if outliers_pos:
        print("outliers found based on neighbors")
        for label, data in outliers_pos.items():
            print(f'{label} has outliers')
            print(f"common neighbors:{sorted(data['common_neighbors'])}")
            for i, nearest in enumerate(data['neighbor_sets']):
                print(f'dict {i+1}: {nearest}')
    else:
        print("no outliers found based on neighbors")

In [None]:
# Re-plot the final digitizations of the subjects that need a closer look.
# subject_dirs comes from the Aalto reading cell; filepath and
# final_digitizations_filename come from the correction cell.
if runaalto:
    for subject in subject_dirs:
        if not any(number in subject for number in ('114', '125')):
            continue
        dictio = np.load(f'{filepath}/{subject}/{subject}{final_digitizations_filename}.npy', allow_pickle=True).item()
        plot_digipoints(dictio, subject)

In [None]:
ready_with_digitization = True  # gates the montage/export cell below

# Load the epochs files, set the subject-specific digitization montages on them and save them as .fif files

In [None]:
import shutil  # NOTE(review): not used in this cell — confirm whether it can be removed
# Set the subject-specific electrode positions on each subject's epochs (needed
# later to create the head<->MRI transform and the forward model) and save the
# epochs as .fif files.
if ready_with_digitization:
    failed = False
    for site in ['Aalto']:
        filepath = rf"D:\REFTEP_ALL\Digitization\Digitizations_{site}"
        subjects_eeg_directory = rf"D:\REFTEP_ALL\EEG_preprocessing_data\Preprocessing_{site}"
        source_site = rf"D:\REFTEP_ALL\Source_analysis\Source_analysis_{site}"
        os.makedirs(source_site, exist_ok=True)
        for subject in os.listdir(subjects_eeg_directory):
            # NOTE(review): hard-coded subset of subjects — confirm this is intended
            if subject not in ['sub-105', 'sub-110', 'sub-118', 'sub-120']:
                continue
            eeg_filename = f"{subject}_EEG_aligned_final.set"
            eeg_filepath = os.path.join(subjects_eeg_directory, subject, eeg_filename)
            reftep_subject = 'REFTEP' + str(subject.split("-")[-1])
            if site == "Aalto":
                final_digitizations_filename = "_final_digitizations"
                digimontage = np.load(f'{filepath}/{reftep_subject}/{reftep_subject}{final_digitizations_filename}.npy', allow_pickle=True).item()
            elif site == 'Tuebingen':
                digimontage = np.load(f'{filepath}/{reftep_subject}/{reftep_subject}_original_digitizations.npy', allow_pickle=True).item()
            else:
                failed = True
                break
            sourcepath_subject = os.path.join(source_site, subject)
            os.makedirs(sourcepath_subject, exist_ok=True)
            eeg_filepath_fif = os.path.join(sourcepath_subject, f"{subject}_final_eeg-epo.fif")
            epochs = mne.io.read_epochs_eeglab(eeg_filepath)
            # Dig points before the montage is set; may be None when the file
            # carries no channel locations.
            default = epochs.info['dig'] or []
            ch_montage = {key: val for key, val in digimontage.items() if key not in ['nas', 'lpa', 'rpa'] and key in epochs.info['ch_names']}
            digitization_montage = mne.channels.make_dig_montage(ch_pos=ch_montage, nasion=digimontage['nas'], lpa=digimontage['lpa'], rpa=digimontage['rpa'])
            epochs.set_montage(digitization_montage)  # set the digitization montage to the object
            # Add the average reference as a projection to tell MNE it has been
            # applied (already done in MATLAB before); applying it later again is not a problem.
            epochs.set_eeg_reference(ref_channels='average', projection=True)
            real = epochs.info['dig']
            # A surviving pre-existing dig point means the subject montage was not applied.
            if any(digip in default for digip in real):
                print("FAILED")
                failed = True
                continue  # do not save epochs that lack the subject-specific montage
            epochs.save(eeg_filepath_fif, overwrite=True)
        if failed:
            print("FAILED")
            break