### Step 1: Load modules and functions

In [None]:
# ------------------------------------------------------------------------------
# Analysis parameters
# ------------------------------------------------------------------------------
# Channels with a p-value below this threshold count as quality-approved (QC).
p_value_threshold = 0.05

# Time window (ms) that is averaged when computing the normalizer statistics.
integration_start_ms = 70
integration_end_ms = 170

# Fraction of the stimuli assigned to the train split; the test split gets 1 - this.
test_train_split = 0.85

# Fraction of channels that have to be good for a session to be kept.
min_QC_channel_tresh = 0.10

# Minimum number of QC channels required to merge.
min_number_QC_channel_tresh = 15


In [1]:
import xarray as xr
import h5py
from brainio.assemblies import NeuronRecordingAssembly
from pynwb import NWBHDF5IO, NWBFile
import glob, os, yaml
import pytz  # This is required to handle timezone conversions
from datetime import datetime
from uuid import uuid4
import numpy as np
import scipy.io
import os, glob, json
import pandas as pd
from pynwb.file import Subject
import logging, sys, shutil
import re

# Root of the recording inventory on disk.
root_dir = '/braintree/home/aliya277/inventory_new'

# Make the parent of the working directory importable.
cwd = os.getcwd()
sys.path.append(os.path.dirname(cwd))

def read_names(filename):
    """Extract the array assignment and the channel number from a file name.

    The part of `filename` before the first '.' is split on '-'; the second and
    third fields are returned as a 2-element numpy array [assignment, number].
    An IndexError is raised when fewer than three fields are present.
    """
    stem = filename.split('.')[0]
    fields = stem.split('-')
    return np.asarray([fields[1], fields[2]])

def create_prom_nwb(config, path, experiment_name):
    """Create an in-memory NWBFile carrying session, subject, device and electrode metadata.

    Args:
        config: Nested mapping with the recording configuration. The sections
            read here are 'session_info', 'metadata', 'general', 'subject',
            'hardware', 'software' and 'array_info'.
        path: Recording directory. Its 'SpikeTimes' sub-directory is listed to
            derive the electrode names (not used for subject 'solo', whose
            names are read from a JSON mapping file).
        experiment_name: Stored as the session_id of the NWB file.

    Returns:
        The populated NWBFile. Nothing is written to disk here.
    """
    desired_timezone = pytz.timezone('US/Eastern')

    # ---------------------------------------------------------------------
    # Experiment notes: pull the Fixation / Visual / ON-OFF settings out of
    # the free-text session description.
    # ---------------------------------------------------------------------
    recorded_description = config['session_info']['session_description']
    strict_match = re.search(
        r'Fixation\s*:\s*(\d+), Visual\s*:\s*(\d+), ON/OFF\s*:\s*(\d+/\d+)',
        recorded_description)
    if strict_match:
        notes_experiment = strict_match.group(0)
    else:
        # Regex alternatives are tried left to right, so a leading plain `\d+`
        # alternative would capture only the first number of "100-200" or
        # "500/500". Match the number and make the "-N" / "/N" tail optional.
        loose_matches = re.findall(
            r'(Fixation|Visual|ON/OFF)\s*:\s*(\d+(?:[-/]\d+)?)',
            recorded_description)
        notes_experiment = ', '.join(f'{key}: {value}' for key, value in loose_matches)

    # Adjacent literals are used instead of backslash continuations, which
    # would embed the source indentation into the stored description.
    session_description = (
        'This NWB is derived from the processing of neurophysiological recordings in the inferior temporal (IT) '
        'cortex of one macaque monkey (indicated in the file). These recordings were conducted during Rapid Serial Visual Presentation (RSVP) of '
        'randomized images, each presented at the center of gaze during fixation. The dataset in this file includes responses from a number of '
        'individual neural recording sites collected via chronically implanted Utah arrays. For each recording site, the multiunit or single unit '
        'response of that site is summarized as a set of peristimulus time histograms (PSTH). Each PSTH is derived from the site’s response to repeated, '
        'temporally randomized, presentations of the same image. Thus, the number of PSTHs is identical to the number of images in the stimulus set '
        '(typically 100-2000 images). The number of repetitions of each image is indicated in the file and the data from individual repetitions are '
        'available in the file. For larger image sets these recordings were made over a series of consecutive days (usually <5 days). Quality control '
        'measures (e.g. consistency of response pattern over images) collected on each day are used to check for site stability across those days '
        'before producing the final PSTH estimates for each site. The corresponding visual images used in these recordings are linked or stored in '
        'the "Stimulus_Template" section of each NWB file.'
    )

    ################ CREATE NWB FILE WITH METADATA ################################
    ###############################################################################
    nwbfile = NWBFile(
        session_description     = session_description,
        identifier              = str(uuid4()),
        session_start_time      = desired_timezone.localize(config['metadata']['session_start_time']),
        file_create_date        = datetime.now(desired_timezone),
        experimenter            = config['general']['lab_info']['experimenter'],
        experiment_description  = config['general']['experiment_info']['experiment_description'],
        session_id              = experiment_name,
        lab                     = config['general']['lab_info']['lab'],
        institution             = config['general']['lab_info']['university'],
        keywords                = config['general']['experiment_info']['keywords'],
        surgery                 = config['general']['experiment_info']['surgery'],
        notes                   = notes_experiment
    )

    ################ CREATE SUBJECT ################################################
    ################################################################################
    nwbfile.subject = Subject(
        subject_id  = config['subject']['subject_id'],
        date_of_birth= config['subject']['date_of_birth'],
        species     = config['subject']['species'],
        sex         = config['subject']['sex'],
        description = config['subject']['description'],
    )

    ################ CREATE HARDWARE LINKS #########################################
    ################################################################################
    hardware = config['hardware']
    nwbfile.create_device(
        name        = hardware['system_name'],
        description = hardware['system_description'],
        manufacturer= hardware['system_manuf']
    )

    # NOTE(review): the adapter device is named after its manufacturer, unlike
    # the other devices which use a '*_name' entry - confirm this is intended.
    nwbfile.create_device(
        name        = hardware['adapter_manuf'],
        description = hardware['adapter_description'],
        manufacturer= hardware['adapter_manuf']
    )

    nwbfile.create_device(
        name        = hardware['monitor_name'],
        description = hardware['monitor_description'],
        manufacturer= hardware['monitor_manuf']
    )

    nwbfile.create_device(
        name        = hardware['photodiode_name'],
        description = hardware['photodiode_description'],
        manufacturer= hardware['photodiode_manuf']
    )

    # The software versions are stored as a pseudo-device.
    nwbfile.create_device(
        name        = 'Software Used',
        description = str(['Mworks Client: '+config['software']['mwclient_version'],
                           'Mworks Server: '+config['software']['mwserver_version'],
                           'OS: '+config['software']['OS'],
                           'Intan :'+config['software']['intan_version']])
    )

    ################ CREATE ELECTRODE LINKS ########################################
    ################################################################################
    electrodes = nwbfile.create_device(
        name        = hardware['electrode_name'],
        description = hardware['electrode_description'],
        manufacturer= hardware['electrode_manuf']
    )

    # `names` is a 2-column array: [array assignment, channel number] per electrode.
    if config['subject']['subject_id'] != 'solo':
        spike_files = sorted(os.listdir(os.path.join(path, 'SpikeTimes')))
        names = np.vstack([read_names(spike_file) for spike_file in spike_files])
    else:
        with open('/braintree/home/aliya277/sachis_data/solo_mapping.json', 'r') as file:
            data = json.load(file)
        split_ids = [value.split('-') for value in data['neuroid_id'].values()]
        names = np.array([[pair[0], int(pair[1])] for pair in split_ids], dtype='object')

    nwbfile.add_electrode_column(name="label", description="label of electrode")
    groups, count_groups = np.unique(names[:,0], return_counts=True)
    ids = names[:,1]

    # Parse the electrode labeling once instead of once per electrode.
    electrode_labeling = json.loads(config['array_info']['intan_electrode_labeling_[row,col,id]'])

    # `counter` runs over all electrodes across groups; it indexes both `ids`
    # and the labeling list.
    counter = 0
    # create ElectrodeGroups A, B, C, ..
    for group, count_group in zip(groups, count_groups):
        array_info = config['array_info'][f'array_{group}']

        # NOTE(review): the adapter version is only reported when there are
        # exactly six groups - presumably a specific implant layout; confirm.
        if len(groups) == 6:
            electrode_description = "Serialnumber: {}. Adapter Version: {}".format(
                array_info['serialnumber'], array_info['adapterversion'])
        else:
            electrode_description = "Serialnumber: {}".format(array_info['serialnumber'])

        electrode_group = nwbfile.create_electrode_group(
            name        = f"group_{group}",
            description = electrode_description,
            device      = electrodes,
            location    = 'hemisphere, region, subregion: '+str([array_info['hemisphere'],
                                                                 array_info['region'],
                                                                 array_info['subregion']]),
            position    = array_info['position']
        )

        # create Electrodes 001, 002, ..., 032 in ElectrodeGroups per channel
        for _ in range(count_group):
            nwbfile.add_electrode(
                group       = electrode_group,
                label       = ids[counter],
                location    = 'row, col, elec'+str(electrode_labeling[counter])
            )
            counter += 1

    return nwbfile

# Load the inventory sheet and add the (initially empty) bookkeeping columns
# that are filled in while the prom files are created.
df = pd.read_excel(os.path.dirname(cwd)+'/pico_inventory.xlsx', sheet_name='Sheet2')
for bookkeeping_column in (
    'prom',
    'prom train',
    'prom test',
    'removed nan reps/session',
    'number of qc channels',
    'number of total reps',
    'sessions kept',
):
    df[bookkeeping_column] = ''


### Step 2: Create per experiment nwb file with good channels only

In [2]:
def experiment_processed(df, exp_path):
    """Report whether the inventory sheet marks this experiment as p-value annotated.

    Args:
        df: Inventory DataFrame with at least the columns 'ImageSet' and 'proc_nwb'.
        exp_path: Experiment directory. The image-set name is the part of its
            basename between the first underscore and the first dot
            (e.g. 'exp_NSD-COCO.sub_pico' -> 'NSD-COCO').

    Returns:
        True if the 'proc_nwb' cell of the first matching row starts with
        "P-Values added."; otherwise the cell content is printed and False is
        returned.

    Raises:
        IndexError: If no row of the sheet matches the image-set name.
    """
    stem = os.path.basename(exp_path).split('.')[0]
    imageset = '_'.join(stem.split('_')[1:])

    matching_labels = df.index[df['ImageSet'] == imageset].tolist()
    # `df.index[...]` yields index *labels*, so the cell is looked up by label;
    # a positional lookup only agrees with it for a default RangeIndex.
    status = df.at[matching_labels[0], 'proc_nwb']

    # A non-string cell (e.g. NaN for an empty spreadsheet cell) means the
    # experiment has not been processed yet rather than being an error.
    if isinstance(status, str) and status.startswith("P-Values added."):
        return True

    print("    ", status)
    return False

def update_sheet(df, exp_path, text, which_nwb):
    """Write `text` into column `which_nwb` of the sheet row belonging to an experiment.

    The image-set name is the part of the basename of `exp_path` between the
    first underscore and the first dot; the first row whose 'ImageSet' equals
    that name is modified in place. An IndexError is raised if no row matches.
    """
    stem = os.path.basename(exp_path).split('.')[0]
    name_parts = stem.split('_')[1:]
    # An empty list is passed through unchanged, exactly as before.
    imageset = '_'.join(name_parts) if name_parts else name_parts

    matching_labels = df.index[df['ImageSet'] == imageset].tolist()
    df.at[matching_labels[0], which_nwb] = text

def find_norm_with_date(realdate):
    """Return the normalizer paths recorded on `realdate`, or None if there are none.

    Candidates are globbed three levels below `root_dir`. The second-to-last
    dot-separated field of each basename is expected to be '<date>_<time>';
    a path is kept when its date equals `realdate`.
    """
    # NOTE(review): '[norm]*' and '[!h5]*' are glob character classes (names
    # starting with n/o/r/m, and names not starting with h or 5), not literal
    # prefixes - confirm this matches the intended directory layout.
    candidates = glob.glob(os.path.join(root_dir, '[norm]*', '*', '*', '[!h5]*'))

    matching_paths = []
    for candidate in candidates:
        stamp = os.path.basename(candidate).split('.')[-2]
        date, _time = stamp.split('_')
        if date == realdate:
            matching_paths.append(candidate)

    return matching_paths or None

def zscore_psth_using_normalizers(psth, realdate, integration_start_ms, integration_end_ms):
    """Z-score `psth` per channel with the normalizer recording of the same day.

    The per-channel mean and standard deviation are taken from the normalizer
    PSTH (grey image excluded), averaged over the time bins inside
    [integration_start_ms, integration_end_ms) and computed across all images
    and repetitions. The first normalizer whose channel count matches `psth`
    is used.

    Args:
        psth: 4-D array whose last axis is channels.
        realdate: Recording date used to look up the normalizer files.
        integration_start_ms: Start (inclusive) of the integration window in ms.
        integration_end_ms: End (exclusive) of the integration window in ms.

    Returns:
        (psth_zscored, normalizer_psth, normalizer_meta, norm_nwb_file_path)
        for the normalizer that was used, or [None] when no normalizer exists
        for `realdate` or none of them is usable.

    Raises:
        ValueError: If a matching normalizer has an unexpected number of images.
    """
    norm_paths = find_norm_with_date(realdate)
    # find_norm_with_date returns None (not an empty list) when nothing matches.
    if norm_paths is None:
        print(f'No normalizer found for day {realdate}.')
        return [None]

    for norm_path in norm_paths:
        # NOTE(review): '*[nwb]' matches any name ending in n, w or b.
        norm_nwb_file_path = glob.glob(os.path.join(norm_path, '*[nwb]'))[0]
        print(f"Using normalizer file {os.path.basename(norm_nwb_file_path)}")

        # The context manager closes the file on every path, including `continue`.
        with NWBHDF5IO(norm_nwb_file_path, "r") as io:
            norm_nwbfile = io.read()
            try:
                normalizer_psth = norm_nwbfile.scratch['psth'][:]
                normalizer_meta = norm_nwbfile.scratch['psth meta'][:]
            except Exception as error:
                # Skip this file instead of continuing with missing or stale
                # arrays from a previous iteration.
                print(error)
                continue

        # A normalizer with a different channel count cannot be broadcast against psth.
        if normalizer_psth.shape[-1] != psth.shape[-1]:
            continue

        # This part of the code has been recycled from Sachi's Code:
        # /spike-tools-chong/spike_tools/utils/spikeutils.py/combine_sessions()

        # num_images x num_repetitions x num_timebins x num_channels
        assert len(normalizer_psth.shape) == 4 , 'Normalizer PSTH has wrong shape.'

        # psth meta is [start_time_ms, stop_time_ms, tb_ms]
        timebase = np.arange(int(normalizer_meta[0]), int(normalizer_meta[1]), int(normalizer_meta[2]))
        t_cols = np.where((timebase >= integration_start_ms) & (timebase < integration_end_ms))[0]

        n_images = normalizer_psth.shape[0]
        if n_images == 86:  # norm_FOSS: 85 normalizers images + 1 gray image
            images_no_grey = np.arange(86) != 26  # drop the grey image at index 26
            normalizer_p_no_grey = normalizer_psth[images_no_grey, :, :, :]
        elif n_images == 26:  # norm_HVM: the grey image is the last one
            normalizer_p_no_grey = normalizer_psth[:-1, :, :, :]
        else:
            raise ValueError(
                f'Unexpected number of normalizer images ({n_images}) in '
                f'{os.path.basename(norm_nwb_file_path)}; expected 86 or 26.')

        # Mean over the integration window, then flatten images x reps.
        n_p = np.nanmean(normalizer_p_no_grey[:, :, t_cols, :], 2)
        n_p = n_p.reshape(-1, normalizer_p_no_grey.shape[-1])

        mean_response_normalizer = np.nanmean(n_p, 0)   # Mean across images x reps
        std_response_normalizer  = np.nanstd(n_p, 0)    # Std across images x reps

        psth_zscored = np.subtract(psth, mean_response_normalizer[np.newaxis, np.newaxis, np.newaxis, :])
        # `where` without `out` leaves the skipped positions uninitialised.
        # Dividing in place keeps channels with zero std mean-subtracted but
        # unscaled, which is deterministic.
        np.divide(psth_zscored, std_response_normalizer[np.newaxis, np.newaxis, np.newaxis, :],
                  out=psth_zscored, where=std_response_normalizer != 0)

        return psth_zscored, normalizer_psth, normalizer_meta, norm_nwb_file_path

    # Every candidate was skipped (read failure or channel-count mismatch).
    print(f'No usable normalizer found for day {realdate}.')
    return [None]


In [12]:
# ------------------------------------------------------------------------------
# Collect the names of all image sets flagged for BrainScore in the inventory
# sheet and list the experiment directories (video stimulus sets excluded).
# ------------------------------------------------------------------------------
list_of_bs_exp_names = []
for index, row in df.iterrows():
    if row['BrainScore'] != 'Y':
        continue
    list_of_bs_exp_names.append(row['ImageSet'])

experiment_file_paths = [
    candidate
    for candidate in glob.glob(os.path.join(root_dir, '[exp]*', '*'))
    if 'VideoStimulusSet' not in os.path.basename(candidate)
]

for experiment_path in experiment_file_paths: 
    # ------------------------------------------------------------------------------ 
    # Skip files, which do not go on BrainScore.
    # ------------------------------------------------------------------------------ 
    experiment_name =  "_".join(os.path.basename(experiment_path).split('.')[0].split('_')[1:])

    if experiment_name not in list_of_bs_exp_names: 
        continue 
        
    # if os.path.basename(experiment_path)!='exp_NSD-COCO.sub_pico': continue 

    days    = glob.glob(os.path.join(experiment_path, '*[!npy][!nwb]'))
    n_days  = len(days)
    n_sessions = 0
    for day in days :
        n_sessions += len(glob.glob(os.path.join(experiment_path, day, '*',  '*proc.nwb')))

    print('________________________________________________________________________________')
    print(f'{os.path.basename(experiment_path)} has {n_days} days and {n_sessions} sessions')
    
    # ------------------------------------------------------------------------------ 
    # Skip files, which have no nwb files or if combined nwb file already exists.
    # ------------------------------------------------------------------------------ 
    if n_sessions == 0: continue
    combined_exists = False
    train_exists    = False
    test_exists     = False
    min_QC_channel_tresh_local = 0.0

    if os.path.isfile(os.path.join(experiment_path, f"{os.path.basename(experiment_path)}.prom.nwb")):
        print(f'Prom nwb file exists for {os.path.basename(experiment_path)}')
        io = NWBHDF5IO(os.path.join(experiment_path, f"{os.path.basename(experiment_path)}.prom.nwb"), "r") 
        combined_nwb = io.read()
        min_QC_channel_tresh_local = float(combined_nwb.scratch['PSTHs_QualityApproved_SessionMerged'].description.split(' %')[-2].split(' ')[-1])/100
        total_reps = combined_nwb.scratch['PSTHs_QualityApproved_SessionMerged'][:].shape[1]
        total_QC_channels = combined_nwb.scratch['PSTHs_QualityApproved_SessionMerged'][:].shape[3]
        pattern = r'\d{8}_\d{6}'
        dates_times = re.findall(pattern, combined_nwb.scratch['PSTHs_QualityApproved_SessionMerged'].description)

        io.close()

        print(f'QC channel threshold: {min_QC_channel_tresh_local*100} %')
        print(f'Total reps number, QC channel number: {total_reps}, {total_QC_channels}')
        update_sheet(df, experiment_path, 'Done.', 'prom')
        update_sheet(df, experiment_path, total_QC_channels , 'number of qc channels')
        update_sheet(df, experiment_path, str(dates_times) , 'sessions kept')
        update_sheet(df, experiment_path, total_reps , 'number of total reps')
        combined_exists = True

    if os.path.isfile(os.path.join(experiment_path, f"{os.path.basename(experiment_path)}.prom_train.nwb")):
        print(f'Prom train nwb file exists for {os.path.basename(experiment_path)}')
        # io = NWBHDF5IO(os.path.join(experiment_path, f"{os.path.basename(experiment_path)}.prom_train.nwb"), "r") 
        # combined_nwb_train = io.read()
        update_sheet(df, experiment_path, 'Done', 'prom train')
        train_exists    = True

    if os.path.isfile(os.path.join(experiment_path, f"{os.path.basename(experiment_path)}.prom_test.nwb")):
        print(f'Prom test nwb file exists for {os.path.basename(experiment_path)}')
        io = NWBHDF5IO(os.path.join(experiment_path, f"{os.path.basename(experiment_path)}.prom_test.nwb"), "r") 
        combined_nwb_test = io.read()
        update_sheet(df, experiment_path, 'Done.', 'prom test')
        test_exists     = True
        

    if combined_exists and train_exists and test_exists: continue 

    # ------------------------------------------------------------------------------ 
    # Create combined nwb files, if all previous steps are taken.
    # ------------------------------------------------------------------------------ 
    if experiment_processed(df, experiment_path) == True: 
        first_day = days[0].split('.')[-1]
        with open(os.path.join(glob.glob(os.path.join(experiment_path, days[0], '*'))[0],f"config_nwb.yaml") , "r") as f:
                first_rec_config = yaml.load(f, Loader = yaml.FullLoader)
                if combined_exists == False: combined_nwb        = create_prom_nwb(first_rec_config, glob.glob(os.path.join(experiment_path, days[0], '*'))[0], os.path.basename(days[0]).split('.')[0])
                if train_exists    == False: combined_nwb_train  = create_prom_nwb(first_rec_config, glob.glob(os.path.join(experiment_path, days[0], '*'))[0], os.path.basename(days[0]).split('.')[0])
                if test_exists     == False: combined_nwb_test   = create_prom_nwb(first_rec_config, glob.glob(os.path.join(experiment_path, days[0], '*'))[0], os.path.basename(days[0]).split('.')[0])
            
        combined_good_channel_ids       = []
        combined_date_times             = []
        combined_normalizer_date_times  = []
        train_IDs                       = []
        test_IDs                        = []
        removed_nan_reps                = []
        
        # ------------------------------------------------------------------------------ 
        # Load all the proc nwb files and combine them to a prom nwb file.
        # ------------------------------------------------------------------------------ 

        pass_saving       = False
        all_good_channels = []
        all_psth          = []
        all_psth_zscored  = []
        all_psth_normalizers = []

        for day, i_day in zip(days, range(n_days)) :
            print(f'Loading files from day {i_day+1}/{n_days}')
            exp_nwb_paths = (glob.glob(os.path.join(experiment_path, day, '*',  '*proc.nwb')))
            
            for exp_nwb_path in exp_nwb_paths:
                
                # ------------------------------------------------------------------------------ 
                # Load recording nwb file.
                # ------------------------------------------------------------------------------ 
                try:
                    io = NWBHDF5IO(exp_nwb_path, "r") 
                    exp_nwbfile = io.read()
                except:
                    print(f"Cannot open nwb. file {os.path.basename(exp_nwb_path)}")
                    io.close()
                    update_sheet(df, experiment_path, f"Cannot open nwb. file {os.path.basename(exp_nwb_path)}", 'prom')
                    pass_saving = True
                    continue 
                    
                try: 
                    psth        = exp_nwbfile.scratch['psth'][:]
                    psth_meta   = exp_nwbfile.scratch['psth meta'][:]

                    if os.path.basename(experiment_path) == 'exp_muri1320.sub_pico' and i_day == 0:
                        # This particular session for this experiment has been calculated with different 
                        # parameter settings within an experiment. Therefore, I am manually correcting this previous error. 
                        psth = psth[:,:,10:40,:]
                        psth_meta = np.array([0, 300, 10])
                    
                    print(f'    PSTH Meta of file {i_day+1}/{n_days}: {psth_meta}')
                    
                    # ------------------------------------------------------------------------------ 
                    # Define Train and Test set stimulus ids for this stimulus set. This is 
                    # consistend over the whole project.
                    # ------------------------------------------------------------------------------ 
                    if i_day == 0:
                        n_stimuli   = psth.shape[0]
                        stimulus_id = np.arange(n_stimuli)
                        np.random.shuffle(stimulus_id)
                        split_index = int(n_stimuli * test_train_split)
                        
                        train_ids = stimulus_id[:split_index]
                        test_ids  = stimulus_id[split_index:]

                        train_ids = sorted(train_ids)
                        test_ids  = sorted(test_ids)

                        train_IDs.append(train_ids)
                        test_IDs.append(test_ids)  
                    else:
                        train_ids = train_IDs[0]
                        test_ids  = test_IDs[0]

                    # print('    Train IDs: ',train_ids)
                    # print('    Test IDs: ', test_ids)

                    # ------------------------------------------------------------------------------ 
                    # Calculate z-scored psth.
                    # ------------------------------------------------------------------------------                     
                    current_day = day.split('.')[-1]
                    nan_exists = True
                    nan_rep_counter = 0
                    while nan_exists == True:
                        if np.isnan(psth).sum() != 0: 
                            nan_rep_counter += 1
                            psth = psth[:,:-1,:,:] # remove last rep of nans 
                        else: nan_exists = False
                    removed_nan_reps.append(nan_rep_counter)
                    print(f'    Removed nan-reps: {nan_rep_counter}')
                    psth_zscored, normalizer_psth, normalizer_meta, normalizer_filename = zscore_psth_using_normalizers(psth, current_day, integration_start_ms, integration_end_ms)
                    
                    # TO DO!! CHECK IF THIS IS NECESSARY
                    if normalizer_meta[0]!=0 and normalizer_meta[1]!=300:
                        normalizer_psth = normalizer_psth[:,:,int(-normalizer_meta[0]/10):int(-normalizer_meta[0]/10)+30,:]
                        normalizer_meta = np.array([0, 300, 10])
                    
                    #all_meta_normalizers.append(normalizer_meta)
                    all_psth_normalizers.append(normalizer_psth)
                    combined_normalizer_date_times.append(os.path.basename(normalizer_filename).split('.')[-3])

                    # ------------------------------------------------------------------------------ 
                    # Remove all the bad channels for the z-scored and not z-scored psth. 
                    # This is the quality checked (QC) psth
                    # ------------------------------------------------------------------------------ 
                    p_values = exp_nwbfile.scratch['PValuesPerChannel'][:]
                    good_channel_ids = p_values<p_value_threshold
                    combined_good_channel_ids.append(good_channel_ids)
                    combined_date_times.append(os.path.basename(exp_nwb_path).split('.')[-3])
                    
                    psth_QC         = psth[:,:,:,good_channel_ids] 
                    psth_zscored_QC = psth_zscored[:,:,:,good_channel_ids] 

                    print(f"    psth_zscored and original psth shape:   {psth_zscored.shape} {psth.shape}")
                    print(f'    psth_QC and original psth shape:        {psth_QC.shape} {psth.shape}')
                    print(f'    psth_zscored_QC and psth_zscored shape: {psth_zscored_QC.shape} {psth_zscored.shape}')
                    print(f'    normalizer psth shape and meta:         {normalizer_psth.shape} {normalizer_meta}')
                    # ------------------------------------------------------------------------------ 
                    # Remove the last 'nan' rep.
                    # ------------------------------------------------------------------------------
                    for i in range(nan_rep_counter):
                        if np.isnan(psth_QC).sum() != 0: psth_QC = psth_QC[:,:-1,:,:] # remove last rep of nans 
                        if np.isnan(psth_zscored_QC).sum() != 0: psth_zscored_QC = psth_zscored_QC[:,:-1,:,:] # remove last rep of nans 
                        if np.isnan(psth_zscored).sum() != 0: psth_zscored = psth_zscored_QC[:,:-1,:,:] # remove last rep of nans 
                        
                    assert np.isnan(psth_QC).sum() == 0, "Nan Repetitions are still present. This should not be the case."

                    print(f'    psth_QC shape after removing nan-reps: {psth_QC.shape}')
                    all_psth.append(psth)
                    all_psth_zscored.append(psth_zscored)                  

                    # ------------------------------------------------------------------------------ 
                    # Add to psth_QC and psth_zscored_QC combined nwb files. 
                    # One file containing all the data, one test and one train.
                    # ------------------------------------------------------------------------------ 
                    def description_QC(psth_meta, z_scored, train_test): 
                        if z_scored == True: psth_array_type = 'Z-scored (using the normalizer recordings from this day) '
                        else: psth_array_type =''
                        if train_test != 'combined': 
                            StimIDs = 'A corresponding index-to-stimulus ID mapping is available in the array StimuliIDs. '
                        else: StimIDs =''
                        return f"{psth_array_type}PSTH array with dimensions corresponding to \
                            [stimuli x repetitions x time bins x good quality channels], \
                            where 'quality channels' are those with p-values (calculated using the normalizer recordings) less \
                            than {p_value_threshold}, indicating high signal quality.\
                            {StimIDs}The PSTH meta are the following for [start_time_ms, stop_time_ms, tb_ms]: {psth_meta}"
                        
                    if combined_exists == False:
                        combined_nwb.add_scratch(
                            psth_QC,
                            name=f"PSTHs_QualityApproved_{os.path.basename(exp_nwb_path).split('.')[-3]}",
                            description=description_QC(psth_meta, z_scored=False, train_test='combined'))
                        combined_nwb.add_scratch(
                            psth_zscored_QC,
                            name=f"PSTHs_QualityApproved_ZScored_{os.path.basename(exp_nwb_path).split('.')[-3]}",
                            description=description_QC(psth_meta, z_scored=True, train_test='combined'))
                    if train_exists == False:
                        combined_nwb_train.add_scratch(
                            psth_QC[train_ids,:,:,:],
                            name=f"PSTHs_QualityApproved_{os.path.basename(exp_nwb_path).split('.')[-3]}",
                            description=description_QC(psth_meta, z_scored=False, train_test='train'))
                        combined_nwb_train.add_scratch(
                            psth_zscored_QC[train_ids,:,:,:],
                            name=f"PSTHs_QualityApproved_ZScored_{os.path.basename(exp_nwb_path).split('.')[-3]}",
                            description=description_QC(psth_meta, z_scored=True, train_test='train'))
                    if test_exists == False:
                        combined_nwb_test.add_scratch(
                            psth_QC[test_ids,:,:,:],
                            name=f"PSTHs_QualityApproved_{os.path.basename(exp_nwb_path).split('.')[-3]}",
                            description=description_QC(psth_meta, z_scored=False, train_test='test'))
                        combined_nwb_test.add_scratch(
                            psth_zscored_QC[test_ids,:,:,:],
                            name=f"PSTHs_QualityApproved_ZScored_{os.path.basename(exp_nwb_path).split('.')[-3]}",
                            description=description_QC(psth_meta, z_scored=True, train_test='test'))
                    
                except Exception as error:
                    print("An error occurred:", error) 
                    io.close()
                    pass_saving = True
                    update_sheet(df, experiment_path, error, 'prom')
                    continue 

                io.close()
    

        # ------------------------------------------------------------------------------ 
        # Combine all sessions using the common good channels. This step is also performed 
        # if only one session is available, to keep the standardized format of the prom.nwb files.
        # If there is a day, which has very low number of good channels, we will discard that 
        # session in the prom psth. 
        # ------------------------------------------------------------------------------
        if pass_saving == True: continue
        
        prom_psth_QC          = []
        prom_psth_QC_train    = []
        prom_psth_QC_test     = []
            
        prom_psth_zscored_QC          = []
        prom_psth_zscored_QC_train    = []
        prom_psth_zscored_QC_test     = []

        prom_psth_normalizers         = []

        if min_QC_channel_tresh_local == 0.0: min_QC_channel_tresh_local = min_QC_channel_tresh

        # find the session id's to keep, i.e. sessions which have enough good channels
        n_channels    = len(combined_good_channel_ids[0])
        min_nchannels = int(n_channels*min_QC_channel_tresh_local)
        proc_session_mask_keep  = np.full(len(combined_good_channel_ids), False, dtype=bool) 
        for i_session in range(len(combined_good_channel_ids)):
            if sum(combined_good_channel_ids[i_session]) >= min_nchannels: proc_session_mask_keep[i_session] = True

           
        combined_good_channel_ids_keep = [array for array, include in zip(combined_good_channel_ids, proc_session_mask_keep) if include]
        common_QC_channels             = np.logical_and.reduce(combined_good_channel_ids_keep)

        # Reduce min_QC_channel_tresh if the number of common quality approved channels is lower than min_number_QC_channel_tresh
        if common_QC_channels.sum() < min_number_QC_channel_tresh:
            print(f'    Reducing the threshold to {min_QC_channel_tresh_local-0.05}, because number of common QC channels is {common_QC_channels.sum()}')
            min_QC_channel_tresh_local = min_QC_channel_tresh_local-0.05
            n_channels    = len(combined_good_channel_ids[0])
            min_nchannels = int(n_channels*min_QC_channel_tresh_local)
            proc_session_mask_keep  = np.full(len(combined_good_channel_ids), False, dtype=bool) 
            for i_session in range(len(combined_good_channel_ids)):
                if sum(combined_good_channel_ids[i_session]) >= min_nchannels: proc_session_mask_keep[i_session] = True
                
            combined_good_channel_ids_keep = [array for array, include in zip(combined_good_channel_ids, proc_session_mask_keep) if include]
            common_QC_channels             = np.logical_and.reduce(combined_good_channel_ids_keep)
        
        if common_QC_channels.sum() < min_number_QC_channel_tresh: pass_saving == True
        
        print(f'    Per Session QC channel threshold: {min_QC_channel_tresh_local*100} %')
        print('    Number of common QC channels: ', common_QC_channels.sum())

        if pass_saving == True: continue
            
        # Merge all kept sessions: restrict every PSTH to the common quality approved
        # channels and stack the sessions with np.hstack (second axis for these 4-D arrays).
        train_rows = list(train_ids)
        test_rows  = list(test_ids)

        first_session = True
        if n_sessions > 0:
            per_session = zip(all_psth, all_psth_zscored, all_psth_normalizers)
            for i, (psth_QC, psth_zscored_QC, psth_normalizer) in enumerate(per_session):
                # Sessions that did not pass the per-session QC threshold are left out.
                if not proc_session_mask_keep[i]:
                    continue

                session_raw     = psth_QC[:, :, :, common_QC_channels]
                session_zscored = psth_zscored_QC[:, :, :, common_QC_channels]

                if first_session:
                    # The first kept session seeds the merged arrays.
                    prom_psth_QC       = session_raw
                    prom_psth_QC_train = session_raw[train_rows, :, :, :]
                    prom_psth_QC_test  = session_raw[test_rows, :, :, :]

                    prom_psth_zscored_QC       = session_zscored
                    prom_psth_zscored_QC_train = session_zscored[train_rows, :, :, :]
                    prom_psth_zscored_QC_test  = session_zscored[test_rows, :, :, :]

                    prom_psth_normalizers = psth_normalizer
                    first_session = False
                    continue

                # Every further kept session is appended to what has been merged so far.
                prom_psth_QC       = np.hstack((prom_psth_QC,       session_raw))
                prom_psth_QC_train = np.hstack((prom_psth_QC_train, session_raw[train_rows, :, :, :]))
                prom_psth_QC_test  = np.hstack((prom_psth_QC_test,  session_raw[test_rows, :, :, :]))

                prom_psth_zscored_QC       = np.hstack((prom_psth_zscored_QC,       session_zscored))
                prom_psth_zscored_QC_train = np.hstack((prom_psth_zscored_QC_train, session_zscored[train_rows, :, :, :]))
                prom_psth_zscored_QC_test  = np.hstack((prom_psth_zscored_QC_test,  session_zscored[test_rows, :, :, :]))

                prom_psth_normalizers = np.hstack((prom_psth_normalizers, psth_normalizer))

        if n_sessions == 1:
            # A single session is always kept and uses its own channel mask,
            # overriding whatever the threshold-based selection decided.
            proc_session_mask_keep[0] = True
            common_QC_channels = combined_good_channel_ids[0]
            psth_QC, psth_zscored_QC, psth_normalizer = all_psth[0], all_psth_zscored[0], all_psth_normalizers[0]

            prom_psth_QC         = psth_QC[:, :, :, common_QC_channels]
            prom_psth_zscored_QC = psth_zscored_QC[:, :, :, common_QC_channels]

            prom_psth_QC_train = prom_psth_QC[train_rows, :, :, :]
            prom_psth_QC_test  = prom_psth_QC[test_rows, :, :, :]

            prom_psth_zscored_QC_train = prom_psth_zscored_QC[train_rows, :, :, :]
            prom_psth_zscored_QC_test  = prom_psth_zscored_QC[test_rows, :, :, :]

            prom_psth_normalizers = psth_normalizer

        # Report the shapes of everything that is about to be stored.
        print("---------- PROM FILES: -----------------------------------------")
        print(f'    prom_QC shape: {prom_psth_QC.shape}, prom_QC_train shape: {prom_psth_QC_train.shape}, prom_QC_test shape: {prom_psth_QC_test.shape}')
        print(f'    prom_zscored_QC shape: {prom_psth_zscored_QC.shape}, prom_zscored_QC_train shape: {prom_psth_zscored_QC_train.shape}, prom_zscored_QC_test shape: {prom_psth_zscored_QC_test.shape}')
        print(f'    normalizers shape: {prom_psth_normalizers.shape}')
        
        
        # ------------------------------------------------------------------------------ 
        # Add masks and stimIDs to the combined nwb files.
        # ------------------------------------------------------------------------------ 
        def description_prom_QC(psth_meta, combined_date_times, z_scored):
            """Build the description text stored with a session-merged PSTH array.

            psth_meta           -- reported as [start_time_ms, stop_time_ms, tb_ms].
            combined_date_times -- one entry per session; only entries whose value in
                                   proc_session_mask_keep is true are listed.
            z_scored            -- True labels the array as z-scored.

            min_QC_channel_tresh_local and proc_session_mask_keep are read from the
            enclosing scope at call time.
            """
            label           = 'Z-scored ' if z_scored == True else ''
            percent         = min_QC_channel_tresh_local*100
            kept_date_times = [entry for entry, include in zip(combined_date_times, proc_session_mask_keep) if include]
            # The backslash continuations are part of one string literal, so the
            # leading spaces of each continued line end up in the returned text.
            return f"Array of shape [stimuli x all repetitions x time bins x logical 'and' of quality approved channels], \
                containing all {label}PSTHs, which have at least {percent} % quality approved channels, stacked in one matrix.\
                All repetitions are stacked, and the quality approved channels \
                of the recordings are combined using a logical 'and' operation, i.e., \
                common_quality_approved_channels = np.logical_and.reduce(QualityApprovedChannelMasks[days with at least {percent} % quality approved channels]). \
                The resulting common_quality_approved_channels is used to mask the channel dimensions. \
                Notably, if n_sessions equals 1, the merged PSTH will include the PSTHs_QualityApproved with all quality approved channels,\
                regardless of whether this number meets the {percent} % threshold.\
                This array contains the quality approved recordings from the following Dates & Times: \
                {kept_date_times}\
                The PSTH meta are the following for [start_time_ms, stop_time_ms, tb_ms]: {psth_meta}"
        
        # Description texts stored alongside the scratch entries added below.
        # Each one is a single f-string continued with backslashes, so the leading
        # spaces of the continued lines are part of the stored text.

        # Text for the per-session channel masks; embeds p_value_threshold and all
        # recording dates & times. (The variable name is misspelled, but it is
        # referenced under this name by the add_scratch calls below.)
        descripion_QualityApprovedChannelMasks=f"List of boolean arrays, indexed by recording date & time,\
            each of num channel length, marking 'quality approved channels' as True for p-value < {p_value_threshold}.\
            Dates & times covered:{combined_date_times}"
        # Text for the stimulus-ID arrays stored in the train and test files.
        description_StimulusID=f"List of array containting the stimulus IDs that directly correspond to the stimulus \
            indexes in PSTHs_QualityApproved and PSTHs_QualityApproved_ZScored. Each entry in the array aligns with the respective stimulus ID in the PSTH;\
            for example, the first entry in the array corresponds to the stimulus ID of the first entry in the PSTH, and so on.\
            The corresponding stimulus for each stimulus ID is located in the respective stimulus set."
        # Text for the merged normalizer PSTHs; embeds the normalizer dates & times.
        description_normalizer=f"Array of shape [stimuli x all repetitions x time bins x channels], \
            containing PSTH data from the normalizer files, used in the quality check and z-scoring.\
            Normalizer Dates & Times: {combined_normalizer_date_times}"

        
        def _add_merged_scratch(nwb_file, merged_psth, merged_psth_zscored, stimulus_ids=None):
            """Attach the session-merged arrays to `nwb_file` as scratch entries.

            Entries are added in this order: channel masks, stimulus IDs (only when
            `stimulus_ids` is given), merged PSTHs, merged z-scored PSTHs, merged
            normalizer PSTHs. Masks, normalizers and the description texts are taken
            from the enclosing scope.
            """
            nwb_file.add_scratch(
                combined_good_channel_ids,
                name="QualityApprovedChannelMasks",
                description=descripion_QualityApprovedChannelMasks)

            if stimulus_ids is not None:
                nwb_file.add_scratch(
                    stimulus_ids,
                    name="StimuliIDs",
                    description=description_StimulusID)

            nwb_file.add_scratch(
                merged_psth,
                name="PSTHs_QualityApproved_SessionMerged",
                description=description_prom_QC(psth_meta, combined_date_times, z_scored=False))

            nwb_file.add_scratch(
                merged_psth_zscored,
                name="PSTHs_QualityApproved_ZScored_SessionMerged",
                description=description_prom_QC(psth_meta, combined_date_times, z_scored=True))

            nwb_file.add_scratch(
                prom_psth_normalizers,
                name="PSTHs_Normalizers_SessionMerged",
                description=description_normalizer)

        # Each target is only touched when its *_exists flag is False; the full
        # file carries no stimulus IDs, the train/test files carry their own.
        if combined_exists == False:
            _add_merged_scratch(combined_nwb, prom_psth_QC, prom_psth_zscored_QC)

        if train_exists == False:
            _add_merged_scratch(combined_nwb_train, prom_psth_QC_train, prom_psth_zscored_QC_train,
                                stimulus_ids=train_ids)

        if test_exists == False:
            _add_merged_scratch(combined_nwb_test, prom_psth_QC_test, prom_psth_zscored_QC_test,
                                stimulus_ids=test_ids)
        
        # Date/time entries of the sessions that made it into the merged arrays.
        sessions_kept = [entry for entry, include in zip(combined_date_times, proc_session_mask_keep) if include]

        # Record the outcome of this experiment in the inventory sheet.
        for target in ('prom', 'prom train', 'prom test'):
            update_sheet(df, experiment_path, 'Done.', target)
        update_sheet(df, experiment_path, removed_nan_reps, 'removed nan reps/session')
        update_sheet(df, experiment_path, common_QC_channels.sum(), 'number of qc channels')
        update_sheet(df, experiment_path, str(sessions_kept), 'sessions kept')

        print(' Sessions kept: ', sessions_kept)
        # ------------------------------------------------------------------------------ 
        # Save experiment nwb file.
        # ------------------------------------------------------------------------------ 
        print('... saving combined NWB Files.')

        # All three files live in the experiment folder and share its name as prefix.
        nwb_basename = os.path.basename(experiment_path)

        # Each file is written only when its *_exists flag is False. The `with`
        # blocks close the HDF5 handle even if write() raises; the original
        # open/write/close sequence left the file open in that case.
        if combined_exists==False:
            with NWBHDF5IO(os.path.join(experiment_path, f"{nwb_basename}.prom.nwb"), "w") as io:
                io.write(combined_nwb)
            print("Combined file saved.")

        if train_exists == False:
            with NWBHDF5IO(os.path.join(experiment_path, f"{nwb_basename}.prom_train.nwb"), "w") as io:
                io.write(combined_nwb_train)
            print("Train file saved.")

        if test_exists==False:
            with NWBHDF5IO(os.path.join(experiment_path, f"{nwb_basename}.prom_test.nwb"), "w") as io:
                io.write(combined_nwb_test)
            print("Test file saved.")

        # display(combined_nwb)
        # display(combined_nwb_train)
        # display(combined_nwb_test)
                 

________________________________________________________________________________
exp_muri1320.sub_pico has 11 days and 12 sessions
     Across: using normalizer file norm_FOSS.sub_pico.20220615_113442.proc.nwb and norm_FOSS.sub_pico.20220706_141433.proc.nwb 
Number of channels do not match. 288 != 192 
Across: using normalizer file norm_FOSS.sub_pico.20220615_113442.proc.nwb and norm_FOSS.sub_pico.20220706_142235.proc.nwb 
'psth'
No psth available for normalizers ('20220615', '20220706').

________________________________________________________________________________
exp_ko_context_size.sub_pico has 1 days and 1 sessions
Loading files from day 1/1


Saving combined NWB Files.
Combined file saved.
Train file saved.
Test file saved.
________________________________________________________________________________
exp_robustness_guy_d1_v40.sub_pico has 1 days and 3 sessions
Loading files from day 1/1
Within: using normalizer file norm_FOSS.sub_pico.20230928_101016.proc.nwb


  n_p = np.nanmean(normalizer_p_no_grey[:, :, t_cols, :], 2) # then mean 70-170 time bins


Normalized and Original PSTH (330, 50, 30, 192) (330, 50, 30, 192)
Within: using normalizer file norm_FOSS.sub_pico.20230928_101016.proc.nwb
Normalized and Original PSTH (330, 4, 30, 192) (330, 4, 30, 192)
Within: using normalizer file norm_FOSS.sub_pico.20230928_101016.proc.nwb
Normalized and Original PSTH (330, 11, 30, 192) (330, 11, 30, 192)
Saving combined NWB Files.
Combined file saved.
Train file saved.
Test file saved.
________________________________________________________________________________
exp_muri1320-2023-v1.sub_pico has 4 days and 8 sessions
Loading files from day 1/4
Within: using normalizer file norm_FOSS.sub_pico.20230127_160227.proc.nwb
Normalized and Original PSTH (1320, 4, 30, 192) (1320, 4, 30, 192)
Within: using normalizer file norm_FOSS.sub_pico.20230127_160227.proc.nwb
Normalized and Original PSTH (1320, 5, 30, 192) (1320, 5, 30, 192)
Loading files from day 2/4
Within: using normalizer file norm_FOSS.sub_pico.20230130_140402.proc.nwb
Normalized and Original

  n_p = np.nanmean(normalizer_p_no_grey[:, :, t_cols, :], 2) # then mean 70-170 time bins


Normalized and Original PSTH (300, 8, 200, 192) (300, 8, 200, 192)
Loading files from day 2/4
Within: using normalizer file norm_HVM.sub_pico.20230502_145301.proc.nwb
Normalized and Original PSTH (300, 8, 200, 192) (300, 8, 200, 192)
Loading files from day 3/4
An error occurred: 'psth'
Loading files from day 4/4
Within: using normalizer file norm_HVM.sub_pico.20230504_114437.proc.nwb
Normalized and Original PSTH (300, 8, 200, 192) (300, 8, 200, 192)
________________________________________________________________________________
exp_IAPS.sub_pico has 2 days and 2 sessions
Loading files from day 1/2
Within: using normalizer file norm_FOSS.sub_pico.20230517_115439.proc.nwb
Normalized and Original PSTH (1183, 15, 30, 192) (1183, 15, 30, 192)
Loading files from day 2/2
Within: using normalizer file norm_FOSS.sub_pico.20230518_103908.proc.nwb
Normalized and Original PSTH (1183, 15, 30, 192) (1183, 15, 30, 192)
Saving combined NWB Files.
Combined file saved.
Train file saved.
Test file saved

In [None]:
# Update Sheet 2: replace 'Sheet2' of the inventory workbook with `df`,
# keeping every other sheet as it was read from disk.
inventory_path = f'{os.path.dirname(cwd)}/pico_inventory.xlsx'

# Read all sheets inside a `with` block so the workbook handle is released
# before the same file is overwritten below (the original never closed it).
with pd.ExcelFile(inventory_path) as xls:
    sheets = {sheet: xls.parse(sheet) for sheet in xls.sheet_names}

sheets['Sheet2'] = df

with pd.ExcelWriter(inventory_path, engine='openpyxl', mode='w') as writer:
    for sheet_name, sheet_df in sheets.items():
        sheet_df.to_excel(writer, sheet_name=sheet_name, index=False)

### Step 3: Validate the experiment nwb files

In [None]:
# Open the merged NWB file of the first experiment that has a readable one and
# display it; report experiments whose file is missing or unreadable on the way.
# NOTE(review): '[exp]*' is a glob character class (any name starting with 'e',
# 'x' or 'p'), not the literal prefix 'exp'; the startswith() filter below only
# applies to the entries inside those folders - confirm the pattern is intended.
experiment_file_paths = glob.glob(os.path.join(root_dir, '[exp]*', '*'))
for experiment_path in experiment_file_paths:
    experiment_basename = os.path.basename(experiment_path)
    if not experiment_basename.startswith('exp'):
        continue

    # Step 2 writes '<name>.prom.nwb'; the former '<name>_combined.nwb' name is
    # still accepted so that files from older runs can be validated as well.
    candidate_paths = [
        os.path.join(experiment_path, f"{experiment_basename}.prom.nwb"),
        os.path.join(experiment_path, f"{experiment_basename}_combined.nwb"),
    ]
    path = next((candidate for candidate in candidate_paths if os.path.isfile(candidate)), None)
    if path is None:
        print(f'No combined nwb found in: {experiment_basename}')
        continue

    try:
        # The context manager closes the file even if reading or displaying fails.
        with NWBHDF5IO(path, "r") as io:
            nwbfile = io.read()
            display(nwbfile)
    except Exception as error:
        # Keep going with the next experiment, but say why this one failed.
        print(f'This File can not be opened: {experiment_basename} ({error})')
        continue

    # One successfully displayed file is enough.
    break
