In [None]:
import mne
import os
import numpy as np
import matplotlib.pyplot as plt


# fsaverage template source space (oct-6 spacing); all labels are built on this surface.
fname_fsaverage_src = r"D:\REFTEP_ALL\REFTEP_reco\Aalto_recon_all\fsaverage\fsaverage-oct-6-src.fif"
src_to = mne.read_source_spaces(fname_fsaverage_src)

# Response positions to build labels for; each entry except 'handknob' has a fitted dipole file per subject.
position_names = ['n15', 'p30', 'n45', 'p60', 'handknob']
dipole_names = position_names[:-1]  # all except handknob
# Hand-knob coordinate in MNI Talairach / FreeSurfer surface space, converted from mm to m.
handknob_pos = np.array([-40, -15, 62]) / 1000
# Geodesic neighbourhood radius handed to mne.spatial_dist_adjacency, in metres (2e-2 m = 2 cm).
distance = 2e-2

# Source-space arrays for both hemispheres, concatenated in the order src_to[0] then src_to[1]
# (treated as lh then rh by the label loop).
all_source_positions_fsavarage = np.concatenate([src_to[0]['rr'], src_to[1]['rr']])  # every surface vertex position
used_or_not_fsavarage = np.concatenate([src_to[0]['inuse'], src_to[1]['inuse']])  # 1 = used source, 0 = not used
used_source_positions_fsaverage = all_source_positions_fsavarage[used_or_not_fsavarage == 1]
vertices_used_fsaverage = np.concatenate([src_to[0]['vertno'], src_to[1]['vertno']])  # vertex number of each used source
# Same used-source positions, kept separate per hemisphere (index 0 -> src_to[0], index 1 -> src_to[1]).
used_source_positions_fsaverage_hemis = [
    src_to[0]['rr'][src_to[0]['inuse'] == 1],
    src_to[1]['rr'][src_to[1]['inuse'] == 1],
]

# The MRI -> MNI (Talairach) transform of fsaverage is the identity, so fsaverage surface
# coordinates are used directly as MNI coordinates (mne.read_talxfm is not needed here).

# NOTE(review): not read by the label loop, which builds a per-subject annotation name instead.
parcellation_name = "dipoles_handknob_parc"
# Build labels in fsaverage space; the label loop only supports True.
trans_to_fsaverage = True
sites = ['Aalto', 'Tuebingen']

In [None]:
%matplotlib qt
# Set to True to open interactive 3D scatter plots and an fsaverage Brain view for every
# label built in the loop below (the Qt backend above makes these separate windows).
plotter = False

In [None]:
# For every subject and response position, build an fsaverage label containing all used
# fsaverage sources within `distance` (geodesic) of the source closest to that position,
# and write it as a single-label annotation on fsaverage.
if not trans_to_fsaverage:
    # Individual-subject labels are not implemented; failing early avoids silently reusing
    # `distances`/neighbourhoods left over from a previous iteration.
    raise NotImplementedError("Only fsaverage labels are implemented (set trans_to_fsaverage = True).")

# fsaverage is stored with the Aalto reconstructions (the folder of fname_fsaverage_src),
# so use it for every site rather than the current site's recon directory.
fsaverage_subjects_dir = r"D:\REFTEP_ALL\REFTEP_reco\Aalto_recon_all"

# used_source_positions_fsaverage / vertices_used_fsaverage list src_to[0] (lh) sources first,
# so a used-source index below n_used_lh is left hemisphere, anything else right hemisphere.
n_used_lh = len(src_to[0]['vertno'])

# Geodesic adjacency between used fsaverage sources. It depends only on src_to and `distance`,
# so compute it once and keep it sparse (CSR for row access) instead of densifying an
# (n_sources x n_sources) matrix for every subject and position.
# NOTE(review): requires inter-source distances stored in src_to (add_dist) — as the original did.
adjacency = mne.spatial_dist_adjacency(src_to, dist=distance).tocsr()


def _neighbourhood_indices(adjacency_csr, center_ind):
    """Return sorted, unique used-source indices adjacent to `center_ind`, plus `center_ind` itself.

    Parameters
    ----------
    adjacency_csr : scipy.sparse CSR matrix/array
        Square adjacency over used sources; entries equal to 1 mark neighbours.
    center_ind : int
        Used-source index (row of `adjacency_csr`) whose neighbourhood is returned.
    """
    start, stop = adjacency_csr.indptr[center_ind], adjacency_csr.indptr[center_ind + 1]
    row_cols = adjacency_csr.indices[start:stop]
    row_vals = adjacency_csr.data[start:stop]
    # The centre is not listed in its own adjacency row, so add it explicitly;
    # union1d also drops a duplicate should it ever be present.
    return np.union1d(row_cols[row_vals == 1], [center_ind])


def _plot_neighbourhood(hemi, pos_name, hemi_points, positions_around, pos, subjects_dir):
    """Scatter one hemisphere's used sources, the selected neighbourhood and the target position
    in 3D, then open and return an fsaverage Brain window titled '<hemi>_<pos_name>'."""
    title = hemi + "_" + pos_name
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    ax.scatter(hemi_points[:, 0], hemi_points[:, 1], hemi_points[:, 2], color='white', label=f"source pos in {hemi}")
    ax.scatter(positions_around[:, 0], positions_around[:, 1], positions_around[:, 2], color='blue',
               label=f"positions around {pos_name} loc in {hemi}")
    ax.scatter(pos[0], pos[1], pos[2], color='red', label='original pos')
    ax.set_xlabel("X")
    ax.set_ylabel("Y")
    ax.set_zlabel("Z")
    ax.legend()
    ax.set_title(title)
    plt.show()
    return mne.viz.Brain('fsaverage', hemi='both', surf='white', subjects_dir=subjects_dir, title=title)


for site in sites:
    # Raw f-strings: these Windows paths contain backslashes that are not escape sequences.
    sourcepath_site = rf'D:\REFTEP_ALL\Source_analysis\Source_analysis_{site}'
    subjects_mri_directory = rf"D:\REFTEP_ALL\REFTEP_reco\{site}_recon_all"
    for subject in sorted(os.listdir(sourcepath_site)):  # sorted for a reproducible order
        sourcepath_subject = os.path.join(sourcepath_site, subject)
        # Recon name: first 7 characters (drops a '_rep' ending if present) + '_reco'.
        # NOTE(review): assumes subject IDs are exactly 7 characters long — confirm.
        subname_reco = subject[:7] + "_reco"
        # EEG-MRI coregistration transformation matrix
        transpath = os.path.join(sourcepath_subject, f'{subject}_coreg', f'{subject}-trans.fif')
        trans = mne.transforms.read_trans(transpath)

        for pos_name in position_names:
            labelname = subject + "_around_" + pos_name + "_label"
            if pos_name in dipole_names:
                dipole_filepath = os.path.join(sourcepath_subject, f'{subject}_dipoles', f'{subject}_dipole_{pos_name}')
                dipole = mne.read_dipole(dipole_filepath)
                # dipole.pos is in head coordinates; head_to_mni returns mm, so convert to m.
                pos = mne.head_to_mni(dipole.pos[0], subname_reco, trans, subjects_dir=subjects_mri_directory) / 1000
            elif pos_name == 'handknob':
                pos = handknob_pos  # already in MNI Talairach / fsaverage surface space
            else:
                continue

            # Closest used fsaverage source to the target position ...
            distances = np.linalg.norm(used_source_positions_fsaverage - pos, axis=1)
            best_pos_ind = int(np.argmin(distances))
            print(f"{subject} {pos_name}: closest fsaverage source index {best_pos_ind}")
            # ... and every used source within `distance` of it along the surface.
            pos_inds = _neighbourhood_indices(adjacency, best_pos_ind)

            labels_around_pos = []
            brain = None
            for hemi_ind, hemi in enumerate(['lh', 'rh']):
                # Split by index: lh sources come first in the concatenated used-source arrays.
                in_hemi = (pos_inds < n_used_lh) if hemi_ind == 0 else (pos_inds >= n_used_lh)
                pos_inds_hemi = pos_inds[in_hemi]
                if len(pos_inds_hemi) == 0:
                    continue
                # mne.Label expects vertices in ascending order; keep positions aligned with them.
                order = np.argsort(vertices_used_fsaverage[pos_inds_hemi])
                sorted_vertices_around = vertices_used_fsaverage[pos_inds_hemi][order]
                sorted_positions_around = used_source_positions_fsaverage[pos_inds_hemi][order]
                if plotter:
                    brain = _plot_neighbourhood(hemi, pos_name, used_source_positions_fsaverage_hemis[hemi_ind],
                                                sorted_positions_around, pos, fsaverage_subjects_dir)
                labels_around_pos.append(mne.Label(vertices=sorted_vertices_around, hemi=hemi,
                                                   pos=sorted_positions_around, name=labelname + "_" + hemi,
                                                   subject='fsaverage'))

            if len(labels_around_pos) > 1:
                # The neighbourhood reached both hemispheres; not expected, so stop instead of
                # writing a two-hemisphere label.
                print("odd label", labels_around_pos[0] + labels_around_pos[1])
                raise ValueError("odd label")
            label_around_pos = labels_around_pos[0]  # best_pos_ind always lies in exactly one hemisphere
            label_around_pos.name = labelname
            if plotter:
                brain.add_label(label_around_pos, color="red")
            # Separate name so the config-cell `parcellation_name` is not overwritten.
            annot_name = f"{subject}_{pos_name}_label_fsaverage"
            mne.write_labels_to_annot([label_around_pos], subject='fsaverage', parc=annot_name, overwrite=True,
                                      subjects_dir=fsaverage_subjects_dir, hemi='both', sort=True)