In [None]:
import numpy as np
import os
import mne
from reftep_util_funcs import *
import matplotlib.pyplot as plt
import pandas as pd
import shutil

In [None]:
# Names of the dipoles to fit. Each name is also a column of the peak-time sheet
# loaded below (the fitting cell reads subject_df[dipole_name]).
dipoles_names = ['n15', 'p30', 'n45', 'p60']
# Peak-time sheet with a 'subject_id' column. Presumably one row per subject, since
# the fitting cell takes .iloc[0]. Values are presumably in ms, since the fitting
# cell divides them by 1000.
df = pd.read_excel(r"D:\REFTEP_ALL\Source_analysis\Peak_times_all_n15_p30_n45_p60.xlsx")

In [None]:
# For every subject of every site, fit one free dipole per entry of `dipoles_names`
# within a time range around that subject's peak latency (from the Excel sheet).
# Then refit a dipole fixed at the best-fit position at each TIME_STEP inside the
# range. Keep the contiguous window of times around the best-fit latency where the
# fixed dipole still reaches GOF_RATIO of the free-fit gof and its orientation stays
# within MAX_ANGLE_DEG of the free fit.
# The dipole and that window are saved to <subject>/<subject>_dipoles/.

# Peak latencies (s) must be strictly increasing and lie strictly inside
# (PEAK_TMIN, PEAK_TMAX); otherwise the subject is skipped.
PEAK_TMIN, PEAK_TMAX = 0.0, 0.08
TIME_STEP = 0.001      # s; grid on which the fixed-position dipole is refitted
GOF_RATIO = 0.9        # fixed-dipole gof must reach this fraction of the free-fit gof
MAX_ANGLE_DEG = 30     # max orientation deviation from the free fit, in degrees


def _peak_times_are_valid(times, tmin=PEAK_TMIN, tmax=PEAK_TMAX):
    """Return True if `times` is non-empty, strictly increasing and inside (tmin, tmax)."""
    times = list(times)
    if not times:
        return False
    increasing = all(a < b for a, b in zip(times, times[1:]))
    return increasing and min(times) > tmin and max(times) < tmax


def _time_grid(t1, t2, step=TIME_STEP):
    """Return t1, t1 + step, ..., t2 (t2 included), rounded to 3 decimals.

    Built from an integer sample count instead of np.arange with a float step,
    which can add or drop the end point because of floating-point rounding.
    """
    n_steps = int(round((t2 - t1) / step))
    return np.round(t1 + step * np.arange(n_steps + 1), 3)


def _angle_deg(u, v):
    """Angle between vectors u and v in degrees (norms divided out explicitly)."""
    cos = np.dot(u, v) / (np.linalg.norm(u) * np.linalg.norm(v))
    return np.degrees(np.arccos(np.clip(cos, -1.0, 1.0)))


def _contiguous_run(good, best_index):
    """Return slice(start, stop) of the unbroken True run in `good` around best_index.

    The run extends left from best_index (inclusive) and right from best_index + 1.
    If good[best_index] is False, the run starts at best_index + 1.
    """
    start = best_index
    while start >= 0 and good[start]:
        start -= 1
    stop = best_index + 1
    while stop < len(good) and good[stop]:
        stop += 1
    return slice(start + 1, stop)


skipped_subjects = []
for site in ['Tuebingen', 'Aalto']:
    site_path_sources = fr"D:\REFTEP_ALL\Source_analysis\Source_analysis_{site}"
    subjects_mri_directory = fr"D:\REFTEP_ALL\REFTEP_reco\{site}_recon_all"
    for subject in sorted(os.listdir(site_path_sources)):
        subname_reco = subject[:7] + "_reco"
        subject_source_path = os.path.join(site_path_sources, subject)
        final_times_to_fit_to = {dipole_name: None for dipole_name in dipoles_names}

        # Peak latencies for this subject from the Excel sheet (ms -> s).
        subject_df = df[df['subject_id'] == subject[4:]]
        if subject_df.empty:
            print(f"FAILED: no peak times for {subject} (subject_id {subject[4:]!r}); skipping")
            skipped_subjects.append(subject)
            continue
        center_times = {dipole_name: subject_df[dipole_name].iloc[0] / 1000 for dipole_name in dipoles_names}
        timelist = list(center_times.values())
        # Skip only this subject (the old `break` dropped every remaining subject
        # of the site). NaN peaks also fail this check.
        if not _peak_times_are_valid(timelist):
            print(f"FAILED: peak times {timelist} of {subject} are not strictly increasing "
                  f"within ({PEAK_TMIN}, {PEAK_TMAX}) s; skipping")
            skipped_subjects.append(subject)
            continue

        min_edges, max_edges = define_fitting_ranges(center_times, 1/4, 0.011)  # min time of 11 ms
        print(min_edges, max_edges, center_times)
        time_ranges = {dipole_name: [min_edges[ind], max_edges[ind]] for ind, dipole_name in enumerate(dipoles_names)}
        # EEG-MRI coregistration transform (a Transform object, not a path).
        transpath = mne.read_trans(os.path.join(subject_source_path, f'{subject}_coreg', f'{subject}-trans.fif'))
        forward = mne.read_forward_solution(os.path.join(subject_source_path, f'{subject}-fwd.fif'))
        evoked = mne.read_epochs(os.path.join(subject_source_path, f'{subject}_final_eeg-epo.fif')).average()
        evoked.copy().crop(-0.05, 0.15).plot_joint(timelist, title=subject)

        # Start from an empty output folder so no stale results from earlier runs remain.
        dipoles_folder_path = os.path.join(subject_source_path, f'{subject}_dipoles')
        if os.path.exists(dipoles_folder_path):
            shutil.rmtree(dipoles_folder_path)
        os.makedirs(dipoles_folder_path)

        dipole_times = []
        for response, time_range in time_ranges.items():
            t1, t2 = np.round(time_range[0], 3), np.round(time_range[1], 3)
            # Free dipole fit: best position, orientation and latency within [t1, t2].
            dipole, best_free_ori_stc, evoked_fit, best_match, best_pos_ind, residual = lsq_dipole(forward, evoked, tmin=t1, tmax=t2)
            dipole_pos_ind = int(dipole.name.split("dipole_")[-1])  # position index encoded in the dipole name
            dipole_time = np.round(dipole.times[0], 3)
            dipole_gof = dipole.gof[0]
            dipole_ori = dipole.ori[0]
            dipole_times.append(dipole_time)

            # Refit a dipole fixed at the best position at every grid time in [t1, t2]
            # (tmin == tmax: one time point per fit).
            possible_times = _time_grid(t1, t2)
            dipole_fixeds = [
                lsq_dipole_to_pos(forward, evoked, t, t, dipole_pos_ind, n_times=1)[0]
                for t in possible_times
            ]

            # A grid time is "good" if the fixed dipole keeps >= GOF_RATIO of the free-fit
            # gof and its orientation deviates by at most MAX_ANGLE_DEG from the free fit.
            good = np.array([
                dipole_fixed.gof[0] >= dipole_gof * GOF_RATIO
                and _angle_deg(dipole_fixed.ori[0], dipole_ori) <= MAX_ANGLE_DEG
                for dipole_fixed in dipole_fixeds
            ], dtype=bool)

            # Nearest grid time rather than exact float equality, which raised
            # IndexError whenever the fitted latency was not exactly on the grid.
            best_dipole_index = int(np.argmin(np.abs(possible_times - dipole_time)))
            # possible_times is ascending, so the slice is already sorted.
            final_times = possible_times[_contiguous_run(good, best_dipole_index)]

            # Save the dipole and its admissible fitting times.
            dipole_filepath = os.path.join(dipoles_folder_path, f'{subject}_dipole_{response}')
            potential_times_path = os.path.join(dipoles_folder_path, f'{subject}_dipole_{response}_fitting_times')
            dipole.save(dipole_filepath, overwrite=True)
            np.save(potential_times_path, final_times)

            diptitle = f'gof:{round(dipole.gof[0],1)}, amp:{round(dipole.amplitude[0]*1e9,1)}, \n time: {dipole.times[0]*1000} ms'
            print(subject, response, dipole.gof[0], dipole.times[0], final_times)
            final_times_to_fit_to[response] = dipole.times[0]
            mne.viz.plot_dipole_locations(dipole, trans=transpath, subject=subname_reco, subjects_dir=subjects_mri_directory, title=diptitle)
        print(final_times_to_fit_to)
        evoked.copy().crop(-0.05, 0.15).plot_joint(dipole_times, title=subject)

if skipped_subjects:
    print("Skipped subjects:", skipped_subjects)