In [None]:
import numpy as np
import os
import mne
from reftep_util_funcs import *
import matplotlib.pyplot as plt
import pandas as pd
import shutil

In [None]:
# Latency windows (in seconds) searched for each TEP component. Everything is listed from the
# latest component to the earliest, because the fitting loop constrains each window using the
# response that was fitted just before it (i.e. the one that is later in time).
dipoles_names = ['p60', 'n45', 'p30', 'n15']
min_edges = [0.055, 0.04, 0.025, 0.012]  # window starts, same order as dipoles_names
max_edges = [0.073, 0.055, 0.04, 0.025]  # window ends, same order as dipoles_names
time_ranges = dict(zip(dipoles_names, map(list, zip(min_edges, max_edges))))  # {name: [start, end]}
min_dipole_time_diff = 0.005  # minimum gap (s) between the windows of consecutive responses
print(time_ranges)
angle_diff_thresh = 10  # orientation-difference threshold for get_good_dipole_indices (presumably degrees)
gof_perc_thresh = 0.9  # GOF-fraction threshold for get_good_dipole_indices

In [None]:
def _fixed_dipoles_at_position(forward, evoked, times, pos_ind, cache):
    """Return dipoles fitted at source position ``pos_ind``, one per time point in ``times``.

    Each fit uses a single time point (tmin == tmax) and lets dipole_to_pos choose the
    orientation maximizing R^2 (ori_fixed=None). Within one response window the fits depend
    only on the position index, so they are memoized in ``cache`` (keyed by ``pos_ind``)
    instead of being recomputed for every candidate time that selects the same source.
    NOTE(review): assumes dipole_to_pos is deterministic and that the returned Dipole objects
    are not mutated by downstream helpers.
    """
    if pos_ind not in cache:
        cache[pos_ind] = [
            dipole_to_pos(forward, evoked, t, t, pos_ind, n_times=1, maximize="r2", ori_fixed=None)[0]
            for t in times
        ]
    # shallow copy so a helper that modifies the list cannot corrupt the cache
    return list(cache[pos_ind])


def _find_best_dipole(forward, evoked, possible_times, angle_thresh, gof_thresh):
    """Search ``possible_times`` for the dipole whose stable time window scores highest.

    For every candidate time a dipole is fitted with r2_dipole. The same source position is
    then refitted at every time in ``possible_times``. The continuous run of times around the
    candidate that passes the angle/GOF criteria (get_good_dipole_indices, get_fitting_times)
    forms its window, scored as
    ``n_window_samples * mean(R^2) * sqrt(mean amplitude in nAm)``.

    Returns ``(dipole, dipole_time, dipole_gof, window_times)`` for the best candidate, or
    ``None`` if no candidate yields a non-empty window with a finite score.
    """
    best = None
    best_window_score = -np.inf
    fixed_cache = {}
    for potential_time in possible_times:
        dipole = r2_dipole(forward, evoked, tmin=potential_time, tmax=potential_time)[0]
        dipole_pos_ind = int(dipole.name.split("dipole_")[-1])  # name ends in the source position index
        dipole_time = dipole.times[0]
        dipole_fixeds = _fixed_dipoles_at_position(forward, evoked, possible_times, dipole_pos_ind, fixed_cache)

        good_indices = get_good_dipole_indices(dipole_fixeds, dipole, angle_thresh, gof_thresh)
        # nearest sample rather than exact float equality, which can fail on round-off
        best_dipole_index = int(np.argmin(np.abs(possible_times - dipole_time)))
        good_times, good_indices_continuous = get_fitting_times(possible_times, best_dipole_index, good_indices)
        if len(good_times) == 0 or len(good_indices_continuous) == 0:
            continue  # an empty window has no meaningful score and no start time

        # amplitude in nAm, square-rooted to damp its influence on the score
        amplitude_of_window = np.sqrt(np.mean([dipole_fixeds[i].amplitude[0] for i in good_indices_continuous]) * 1e9)
        # dipole gof is R^2 in percent (-inf..100); scale back to -inf..1
        r2_of_window = np.mean([dipole_fixeds[i].gof[0] / 100 for i in good_indices_continuous])
        window_score = len(good_times) * r2_of_window * amplitude_of_window
        if window_score > best_window_score:  # NaN scores never win
            best_window_score = window_score
            best = (dipole, dipole_time, dipole.gof[0], good_times)
    return best


for site in ['Tuebingen','Aalto']:
    site_path_sources = fr"D:\REFTEP_ALL\Source_analysis\Source_analysis_{site}"
    subjects_mri_directory = fr"D:\REFTEP_ALL\REFTEP_reco\{site}_recon_all"
    # Only subject folders, in a reproducible order: os.listdir order is arbitrary and may
    # include stray files, which would break the per-subject path joins below.
    subjects = sorted(
        entry for entry in os.listdir(site_path_sources)
        if os.path.isdir(os.path.join(site_path_sources, entry))
    )
    for subject in subjects:
        subname_reco = subject[:7] + "_reco"  # MRI subject name under subjects_mri_directory
        subject_source_path = os.path.join(site_path_sources, subject)
        # EEG-MRI coregistration transform (an mne Transform object, despite the variable name)
        transpath = mne.read_trans(os.path.join(subject_source_path, f'{subject}_coreg', f'{subject}-trans.fif'))
        forward = mne.read_forward_solution(os.path.join(subject_source_path, f'{subject}-fwd.fif'))
        # average response without re-applying the average-reference projection
        evoked = mne.read_epochs(os.path.join(subject_source_path, f'{subject}_final_eeg_post-epo.fif'), proj=False).average()

        # start from an empty dipole folder so results of earlier runs never mix with this one
        dipoles_folder_path = os.path.join(subject_source_path, f'{subject}_dipoles')
        if os.path.exists(dipoles_folder_path):
            shutil.rmtree(dipoles_folder_path)
        os.makedirs(dipoles_folder_path)

        # Responses are processed from latest to earliest latency (order of time_ranges). Each
        # search window must end at least min_dipole_time_diff before the start of the window
        # chosen for the previously processed (later-in-time) response.
        min_dipole_time = None  # start of the later response's fitting window; None for the first response
        best_dipoles_times = []  # chosen dipole time per response, latest first
        for response, (t1, t2) in time_ranges.items():
            if min_dipole_time is None:
                tmax_edge = t2
            else:
                tmax_edge = min(t2, min_dipole_time - min_dipole_time_diff)
            print(f"Using times from {t1} till {tmax_edge} for {response}")
            possible_times = evoked.copy().crop(t1, tmax_edge).times  # sample times in [t1, tmax_edge]

            result = _find_best_dipole(forward, evoked, possible_times, angle_diff_thresh, gof_perc_thresh)
            if result is None:
                raise RuntimeError(
                    f"No valid dipole time window found for {subject} {response} in {t1}-{tmax_edge} s"
                )
            best_dipole, best_dipole_time, best_dipole_gof, final_times = result
            min_dipole_time = np.min(final_times)
            best_dipoles_times.append(best_dipole_time)

            # save the dipole and the times at which it fits (np.save appends the .npy suffix)
            dipole_filepath = os.path.join(dipoles_folder_path, f'{subject}_dipole_{response}')
            potential_times_path = os.path.join(dipoles_folder_path, f'{subject}_dipole_{response}_fitting_times')
            best_dipole.save(dipole_filepath, overwrite=True)
            np.save(potential_times_path, final_times)

            # plot the dipole location on the subject's MRI
            diptitle = f'gof:{round(best_dipole_gof,1)}, amp:{round(best_dipole.amplitude[0]*1e9,1)}, \n time: {round(best_dipole_time*1000,2)} ms'
            print(subject, response, best_dipole_gof, best_dipole_time, final_times)
            mne.viz.plot_dipole_locations(best_dipole, trans=transpath, subject=subname_reco,
                                          subjects_dir=subjects_mri_directory, title=diptitle)
        # joint plot of the evoked response at the chosen dipole times, in chronological order
        evoked.copy().plot_joint(list(reversed(best_dipoles_times)), title=subject)