# Morlet Wavelet Time-Frequency Analysis

### Modules

In [None]:
import mne 
import numpy as np
np.set_printoptions(threshold=10000)
import pickle
import os
import gc

### Define paths, subject, and conditions

In [None]:
# Root folder containing one sub-folder per subject.
data_path = "./"

# Subject identifier ('xx' looks like a placeholder -- set before running).
subj = 'xx'

# Conditions to process, in order: music then speech, each with one
# "produce" condition and three "perceive" variants.
condition_list = [
    'produce_music',
    'perceive_music_produced',
    'perceive_music_new',
    'perceive_music_newrepetition',
    'produce_speech',
    'perceive_speech_produced',
    'perceive_speech_new',
    'perceive_speech_newrepetition',
]

# Folder where the per-condition result pickles are written.
dictionary_path = os.path.join(data_path, subj, "dictionary/")

## Run Time-Frequency Analysis 

### Compute the TFA separately for each channel (channels selected in `picks`)

In [None]:
def calculate_TFA(channel, epochs, tmin, tmax, fmin, fmax, frequencies, n_cycles,
                  *, trim=1, end=300, epoch_duration=2.98):
    """Compute Morlet power for one channel and average it within fixed windows.

    Morlet-wavelet power is computed with ``mne.time_frequency.tfr_morlet``
    (averaged across epochs). The result is cropped to drop ``trim`` seconds
    at each edge of the ``[0, end]`` span. The remaining samples are split
    into consecutive, non-overlapping windows of ``epoch_duration`` seconds,
    and power is averaged over the samples of each window. Only complete
    windows are kept; a trailing partial window is discarded.

    Parameters
    ----------
    channel : str
        Channel to analyse (passed to ``picks``).
    epochs : mne.Epochs
        Epochs to transform.
    tmin, tmax, fmin, fmax
        Not used by this function; kept so existing calls keep working.
    frequencies : array-like
        Wavelet frequencies in Hz.
    n_cycles : float or array-like
        Number of cycles per wavelet.
    trim : float, optional
        Seconds removed at the start and before ``end``.
    end : float, optional
        End time (s) of the analysed span; the crop stops at ``end - trim``.
    epoch_duration : float, optional
        Length in seconds of each averaging window.

    Returns
    -------
    numpy.ndarray
        Float64 array of shape (n_windows, n_channels, n_freqs). With a single
        channel name, n_channels is 1.

    Raises
    ------
    ValueError
        If ``epoch_duration`` is shorter than one sample.
    """
    print(channel)

    # Morlet-wavelet power, averaged across epochs (average=True).
    power = mne.time_frequency.tfr_morlet(
        inst=epochs, freqs=frequencies, n_cycles=n_cycles, use_fft=False,
        return_itc=False, decim=1, n_jobs=-1, picks=channel, zero_mean=True,
        average=True, output='power', verbose=None,
    )

    sfreq = power.info['sfreq']
    window_len = int(epoch_duration * sfreq)
    if window_len < 1:
        raise ValueError(
            f"epoch_duration={epoch_duration} s is shorter than one sample at {sfreq} Hz"
        )

    # Crop away `trim` seconds at each edge: start one sample after
    # trim*sfreq and stop at the sample index (end - trim)*sfreq.
    times = power.times
    start_idx = int(trim * sfreq + 1)
    stop_idx = int((end - trim) * sfreq)
    power = power.crop(tmin=times[start_idx], tmax=times[stop_idx])

    # Power array indexed (channel, frequency, time).
    data = np.asarray(power.data, dtype=np.float64)
    n_chan, n_freq, n_times = data.shape

    # Cut into non-overlapping windows of `window_len` samples. Keeping only
    # complete windows avoids a shape mismatch when n_times is not an exact
    # multiple of window_len.
    n_windows = n_times // window_len
    windows = data[:, :, :n_windows * window_len].reshape(
        n_chan, n_freq, n_windows, window_len
    )

    # Average over time within each window -> (n_windows, n_channels, n_freqs).
    timepoints_averaged = np.ascontiguousarray(
        windows.mean(axis=-1).transpose(2, 0, 1)
    )

    # Release the large intermediate arrays before the next channel.
    del power, data, windows
    gc.collect()

    return timepoints_averaged
    

In [None]:
# TFA parameters, identical for every condition.
tmin = 0
tmax = 300
fmin = 1
fmax = 180
# 50 log-spaced frequencies between fmin and fmax (Hz).
frequencies = np.logspace(np.log10(fmin), np.log10(fmax), num=50, base=10)
# Cycles scale with frequency: n_cycles = f / frequencies[0].
n_cycles = frequencies / frequencies[0]

for condition in condition_list:
    print(condition)

    # Locate this condition's preprocessed day-1 bipolar epochs file.
    preprocessed_path = os.path.join(data_path, subj, "preprocessed", condition)
    matches = sorted(
        fname for fname in os.listdir(preprocessed_path)
        if 'day1_bipolar_epochs_preprocessed.fif' in fname
    )
    if not matches:
        raise FileNotFoundError(
            f"No file containing 'day1_bipolar_epochs_preprocessed.fif' in {preprocessed_path}"
        )
    if len(matches) > 1:
        print(f"Warning: several matching files in {preprocessed_path}; using {matches[-1]}")
    epochs = mne.read_epochs(os.path.join(preprocessed_path, matches[-1]), preload=False)

    # Compute the TFA channel by channel to bound peak memory.
    picks = epochs.ch_names
    results = []
    for j, channel in enumerate(picks, start=1):
        results.append(
            calculate_TFA(channel, epochs, tmin, tmax, fmin, fmax, frequencies, n_cycles)
        )
        print(j)
        gc.collect()

    # Stack per-channel results along the channel axis
    # -> (n_epochs, n_channels, n_freqs). The epoch count comes from the data
    # rather than being hard-coded.
    if results:
        morlet_results = np.concatenate(results, axis=1)
    else:
        morlet_results = np.empty((0, 0, len(frequencies)))

    # Store the result, creating the output folder if needed.
    os.makedirs(dictionary_path, exist_ok=True)
    out_file = os.path.join(dictionary_path, condition + '_day1_morlet_results.pickle')
    with open(out_file, 'wb') as f:
        pickle.dump(morlet_results, f)

    # Free memory before the next condition.
    del morlet_results, results, picks, epochs
    gc.collect()
    