In [None]:
import os

import matplotlib.pyplot as plt
import mne
import numpy as np
import scipy
import scipy.io

import utils.valid_recs as vrecs
import utils.variables as v
import utils.valid_recs as vrecs

In [None]:
class Recording:
    """One EEG recording (subject / session / run): load, filter, ICA, plot and save.

    Constructing an instance:

    * loads ``<dir_raw>/sub-<sub>_ses-<ses>_run-<run>.mat``;
    * wraps it in an ``mne.io.RawArray``;
    * renames the channels with ``utils.variables.MAPPING``;
    * filters a copy (see ``init_filter``);
    * attaches the standard 10-20 montage to that filtered copy.

    Attributes set by ``__init__``:
        data      -- array stored under ``raw_eeg_data`` in the .mat file
        raw_arr   -- unfiltered ``mne.io.RawArray`` built from ``data``
        filt_arr  -- filtered copy of ``raw_arr`` (the only one with a montage)

    Attributes set later by the corresponding methods:
        psd          -- ``compute_and_save_filtered_psd``
        ica          -- ``init_ICA``
        reconst_arr  -- ``exclude_ICA``
    """

    # Data paths. The project root can be overridden with the EEG_STRESS_ROOT
    # environment variable; the default is the original author's checkout.
    root = os.environ.get(
        'EEG_STRESS_ROOT',
        'C:/Users/annej/OneDrive/Documents/GitHub/MASTER-eeg-stress-det',
    )
    dir_raw = root + '/Data/Raw_eeg'
    dir_filtered = root + '/Data/Init_data'
    dir_psd = root + '/Data/PSD_data'
    dir_figures = root + '/Figures'

    # Acquisition parameters
    Fs = 250  # sampling frequency in Hz (passed as sfreq to mne.create_info)
    ch_type = 'eeg'
    n_channels = 8

    # plot(): data_type -> (attribute holding the MNE object, title prefix).
    # Attribute names rather than objects, so unset stages fail only when requested.
    _PLOT_SOURCES = {
        'raw': ('raw_arr', 'Raw data'),
        'filtered': ('filt_arr', 'Filtered data'),
        'reconstructed': ('reconst_arr', 'Reconstructed data'),
    }

    def __init__(self, sub_nr, ses_nr, run_nr):
        """Load the recording, build the MNE objects and apply the initial filter.

        Parameters
        ----------
        sub_nr, ses_nr, run_nr
            Subject / session / run identifiers. They are interpolated as-is
            into the file name ``sub-<sub>_ses-<ses>_run-<run>.mat``.
        """
        self.sub_nr = sub_nr
        self.ses_nr = ses_nr
        self.run_nr = run_nr

        self.load_data()

        # create_info(int) gives placeholder channel names; they are then
        # renamed with the project mapping, which presumably yields 10-20
        # labels so set_montage below can match them by name.
        info = mne.create_info(self.n_channels, sfreq=self.Fs, ch_types=self.ch_type)
        self.raw_arr = mne.io.RawArray(self.data, info)
        mne.rename_channels(self.raw_arr.info, v.MAPPING)

        self.filt_arr = self.init_filter()

        montage = mne.channels.make_standard_montage('standard_1020')
        self.filt_arr.set_montage(montage)

    # ----------------------------------------------------------------------------------------------

    @property
    def _rec_id(self):
        """Identifier ``sub-<sub>_ses-<ses>_run-<run>`` used in file names and plot titles."""
        return f"sub-{self.sub_nr}_ses-{self.ses_nr}_run-{self.run_nr}"

    def load_data(self):
        """Read the ``raw_eeg_data`` array of ``<dir_raw>/<rec_id>.mat`` into ``self.data``."""
        path = f"{self.dir_raw}/{self._rec_id}.mat"
        self.data = scipy.io.loadmat(path)['raw_eeg_data']

    def save_data(self):
        """Write the filtered signal to ``<dir_filtered>/<rec_id>.mat`` under key ``Clean_data``.

        The data frame is exported with ``scalings=1e6`` and transposed. Its
        first column is not channel data (the time column of
        ``to_data_frame``), so the first row of the transposed array is dropped.
        """
        frame = self.filt_arr.to_data_frame(scalings=1e6)
        clean_data = np.transpose(frame.to_numpy())
        clean_dict = {
            "Clean_data": clean_data[1:, :]  # row 0 came from the non-data first column
        }
        scipy.io.savemat(f'{self.dir_filtered}/{self._rec_id}.mat', clean_dict)

    def save_psd(self):
        """Write ``self.psd.get_data()`` to ``<dir_psd>/<rec_id>.mat`` under key ``psd_data``.

        Requires ``compute_and_save_filtered_psd`` to have been called first.
        """
        psd_dict = {
            "psd_data": self.psd.get_data()
        }
        scipy.io.savemat(f'{self.dir_psd}/{self._rec_id}.mat', psd_dict)

    def init_filter(self, l_freq=1, h_freq=40, savgol_h_freq=35):
        """Return a filtered copy of ``raw_arr``; ``raw_arr`` itself is not modified.

        The filtering is applied in two steps:

        1. band-pass ``l_freq``-``h_freq`` Hz with ``filter``;
        2. Savitzky-Golay smoothing with cut-off ``savgol_h_freq`` Hz.

        The defaults reproduce the original pipeline (1-40 Hz, then 35 Hz).
        """
        # filter() and savgol_filter() modify the instance in place and return
        # it, so one copy() is enough to keep raw_arr untouched.
        # No notch filter: with the default 40 Hz upper edge, 50/100 Hz
        # (presumably mains) components are already outside the passband.
        return (
            self.raw_arr.copy()
            .filter(l_freq, h_freq)
            .savgol_filter(h_freq=savgol_h_freq, verbose=False)
        )

    def compute_and_save_filtered_psd(self, save=False):
        """Compute the PSD of ``filt_arr`` into ``self.psd`` and plot it.

        Parameters
        ----------
        save : bool
            Also write the PSD to disk via ``save_psd``. The default is False,
            so by default nothing is written.
        """
        self.psd = self.filt_arr.compute_psd()
        self.psd.plot()
        if save:
            self.save_psd()

    def init_ICA(self, n_components=None):
        """Fit an ICA on ``filt_arr`` and store it as ``self.ica``.

        Parameters
        ----------
        n_components : int | None
            Number of components; ``None`` uses ``n_channels``. A fixed
            ``random_state`` keeps the decomposition reproducible across runs.
        """
        if n_components is None:
            n_components = self.n_channels
        self.ica = mne.preprocessing.ICA(n_components=n_components, max_iter=10000, random_state=97)
        self.ica.fit(self.filt_arr)

    def plot_sources(self):
        """Show the ICA source time courses and component topographies (requires ``init_ICA``)."""
        self.ica.plot_sources(self.filt_arr, title=f'ICA components {self._rec_id}', show_scrollbars=False)
        self.ica.plot_components(colorbar=True, reject='auto')

    def plot_properties(self, components):
        """Show the ICA property plots for the component indices in ``components``."""
        self.ica.plot_properties(self.filt_arr, picks=components)

    def test_exclude(self, components):
        """Preview the effect of removing ``components`` by overlaying signals before and after.

        ``stop=300.`` limits the overlay to the beginning of the recording.
        """
        self.ica.plot_overlay(self.filt_arr, exclude=components, picks='eeg', stop=300.)

    def exclude_ICA(self, components):
        """Remove ``components`` and store the ICA-cleaned copy of ``filt_arr`` in ``reconst_arr``."""
        # Copy, so that later edits to the caller's list do not change ica.exclude.
        self.ica.exclude = list(components)
        self.reconst_arr = self.filt_arr.copy()
        self.ica.apply(self.reconst_arr)  # applied in place on the copy

    def plot(self, data_type, save=False):
        """Plot one stage of the pipeline, optionally saving the figure as PNG.

        Parameters
        ----------
        data_type : {'raw', 'filtered', 'reconstructed', 'ica'}
            'raw' / 'filtered' / 'reconstructed' plot the PSD and then the
            time series of ``raw_arr`` / ``filt_arr`` / ``reconst_arr``. The
            'reconstructed' option requires ``exclude_ICA``.
            'ica' plots the ICA source time courses of ``filt_arr`` and
            requires ``init_ICA``.
        save : bool
            If True, use the matplotlib browser backend (so a figure object is
            returned) and write it to ``<dir_figures>/<title>.png``.

        Raises
        ------
        ValueError
            If ``data_type`` is not one of the values above.
        """
        rec_id = self._rec_id

        if data_type == 'ica':
            title = f"ICA components {rec_id}"
            if save:
                with mne.viz.use_browser_backend('matplotlib'):
                    fig = self.ica.plot_sources(self.filt_arr, title=title,
                                                show_scrollbars=False)
                    fig.savefig(f'{self.dir_figures}/{title}.png')
            else:
                self.ica.plot_sources(self.filt_arr, title=title, show_scrollbars=False)
            return

        if data_type not in self._PLOT_SOURCES:
            raise ValueError(
                f"Unknown data_type {data_type!r}; expected one of "
                f"{sorted(self._PLOT_SOURCES) + ['ica']}"
            )
        attr, label = self._PLOT_SOURCES[data_type]
        data = getattr(self, attr)
        title = f"{label} {rec_id}"
        data.compute_psd().plot()

        # The saved figure uses a 30 s window, the interactive one 25 s.
        if save:
            with mne.viz.use_browser_backend('matplotlib'):
                fig = data.plot(duration=30, title=title, n_channels=self.n_channels,
                                scalings=None, show_scrollbars=False)
                fig.savefig(f'{self.dir_figures}/{title}.png')
        else:
            data.plot(duration=25, title=title, n_channels=self.n_channels,
                      scalings=None, show_scrollbars=False)



In [None]:
valid_recs = vrecs.get_valid_recs('raw', 'np')

# Filter every valid recording and write it to Data/Init_data.
# Each id is split on '_' into subject, session and run numbers.
# MNE logs at INFO level by default (RawArray creation, filtering), which
# floods the output for a batch; WARNING keeps real problems visible.
# `test` is intentionally left bound to the last recording so the
# exploratory cells below can inspect it.
with mne.use_log_level('WARNING'):
    for rec in valid_recs:
        sub_nr, ses_nr, run_nr = rec.split('_')
        test = Recording(sub_nr, ses_nr, run_nr)
        test.save_data()

In [None]:
#test.plot('raw')
#test.plot('filtered')

In [None]:
#test.compute_and_save_filtered_psd()

In [None]:
#test.init_ICA()
#test.plot_sources()

In [None]:
#test.plot_properties([2])

In [None]:
#test.test_exclude([0,1,2])

In [None]:
#test.exclude_ICA([0,1,2])
#test.plot('filtered')
#test.plot('reconstructed')

In [None]:
#test.plot('raw', save=True)
#test.plot('filtered', save=True)
#test.plot('ica', save=True)
#test.plot('reconstructed', save=True)

In [None]:
#test.save_data()