### Step 1: Load modules and functions

In [2]:
import xarray as xr
import h5py
from brainio.assemblies import NeuronRecordingAssembly
from pynwb import NWBHDF5IO, NWBFile
from pynwb.base import Images
from pynwb.image import RGBImage, ImageSeries
import glob, os, yaml, pynwb
import pytz  # This is required to handle timezone conversions
from datetime import datetime
from uuid import uuid4
import numpy as np
import scipy.io
import os, glob, json
import pandas as pd
from pynwb.file import Subject
import logging, sys, re
from PIL import Image
import shutil
import textwrap
import matplotlib.pyplot as plt
from IPython.display import display as display_image
import random
import hashlib

# Notebook working directory; its parent is put on sys.path and holds the inventory workbook.
cwd = os.getcwd()
_parent_dir = os.path.dirname(cwd)
sys.path.append(_parent_dir)

# Root directory searched for the per-experiment folders.
root_dir = '/braintree/home/aliya277/inventory_new'

# Experiment inventory; status columns are written into it by update_sheet.
df = pd.read_excel(_parent_dir + '/pico_inventory.xlsx', sheet_name='Sheet2')


In [25]:
def update_sheet(df, exp_path, location, text):
    """Write *text* into column *location* of the inventory row for an experiment.

    The ImageSet name is derived from the basename of *exp_path*: the part
    before the first '.', with its first '_'-separated token dropped
    (e.g. '.../prefix_name_part.ext' -> 'name_part'). *df* is modified in
    place; when several rows match, the first one is updated.

    Raises:
        IndexError: if no row of *df* has that ImageSet.
    """
    stem = os.path.basename(exp_path).split('.')[0]
    imageset = '_'.join(stem.split('_')[1:])
    matching_rows = df.index[df['ImageSet'] == imageset]
    if len(matching_rows) == 0:
        raise IndexError(
            f"No row with ImageSet == {imageset!r} in the inventory sheet (derived from {exp_path!r})."
        )
    df.at[matching_rows[0], location] = text

def extract_number(filename):
    """Return the first run of decimal digits in *filename* as an int, or 0 if there is none.

    Used as a sort key so that e.g. 'im2.png' sorts before 'im10.png'.
    """
    digits = re.search(r'\d+', filename)
    if digits is None:
        return 0
    return int(digits.group(0))

# Descriptions stored in the NWB 'StimulusSet' containers. The runs of spaces inside
# them come from the backslash line continuations of the original notebook code; they
# are kept verbatim so that newly written files carry exactly the same text as files
# written by earlier runs.
_VIDEO_COMBINED_DESCRIPTION = "This list references external files linking to movies that form the stimulus set, with each movie uniquely identified \
                        by a stimulus ID noted in the field 'starting_frame'. Each movie has an associated 'starting_frame' field, repurposed to serve as the \
                        stimulusID in PSTHs. Filenames with sequential numbers starting from 0 align with these \
                        StimulusIDs. The arrangement ensures a one-to-one correspondence between \
                        the sequence of PSTH entries and the stimulus IDs; the first PSTH entry corresponds to the first stimulus ID, and so forth. \
                        The movies in the stimulus set are organized in ascending order by their stimulus ID. \
                        Note: The original function of 'starting_frame' is altered for this specific use case."

_VIDEO_SPLIT_DESCRIPTION = "This list references external files linking to movies that form the stimulus set, with each movie uniquely identified \
                    by a stimulus ID noted in the field 'starting_frame'. The 'StimuliIDs' array acts as a crucial link, mapping \
                    each PSTH entry to its corresponding movie within this set. If the PSTHs represent responses to 'n' \
                    different stimuli, the 'StimuliIDs' array will have 'n' entries, sequentially aligning each PSTH \
                    entry with its respective stimulus ID (e.g., the first PSTH entry is linked to the stimulus ID \
                    described in the first entry of the StimuliIDs' array). The movies in the stimulus set are \
                    organized in ascending order by their stimulus IDs. \
                    Note: The original function of 'starting_frame' is altered for this specific use case."

_IMAGE_COMBINED_DESCRIPTION = "This list contains images that form the stimulus set, with each image uniquely identified \
                        by a stimulus ID noted in its description. The arrangement ensures a one-to-one correspondence between \
                        the sequence of PSTH entries and the stimulus IDs; the first PSTH entry corresponds to the first stimulus ID, and so forth. \
                        The images in the stimulus set are organized in ascending order by their stimulus ID."

_IMAGE_SPLIT_DESCRIPTION = "This list contains images that form the stimulus set, with each image uniquely identified \
                    by a stimulus ID noted in its description. The 'StimuliIDs' array acts as a crucial link, mapping \
                    each PSTH entry to its corresponding image within this set. If the PSTHs represent responses to 'n' \
                    different stimuli, the 'StimuliIDs' array will have 'n' entries, sequentially aligning each PSTH \
                    entry with its respective stimulus ID (e.g., the first PSTH entry is linked to the stimulus ID \
                    described in the first entry of the StimuliIDs' array). The images in the stimulus set are \
                    organized in ascending order by their stimulus IDs."


def _read_prom_file(experiment_path, file_name, label, sheet_column, open_handles):
    """Open one prom NWB file in append mode and summarise its contents.

    The opened handle is appended to *open_handles* so the caller can close it.
    If the file already has a 'StimulusSet' stimulus template, this is printed
    and recorded as 'Done' in *sheet_column* of the inventory sheet.

    Returns:
        (nwb_io, nwbfile, has_stimulus_set, n_stimuli, stimuli_ids) where
        n_stimuli is the common first-axis length of all 'PSTHs_QualityApproved*'
        scratch entries and stimuli_ids is the last 'StimuliIDs*' scratch entry
        (None if there is none).
    """
    nwb_io = NWBHDF5IO(os.path.join(experiment_path, file_name), "a")
    open_handles.append(nwb_io)
    nwbfile = nwb_io.read()

    has_stimulus_set = 'StimulusSet' in nwbfile.stimulus_template
    if has_stimulus_set:
        print(f'Simulus Set already exists in {label}.')
        update_sheet(df, experiment_path, sheet_column, 'Done')

    psth_lengths = []
    stimuli_ids = None
    for scratch_name in list(nwbfile.scratch):
        if scratch_name.startswith('PSTHs_QualityApproved'):
            psth_lengths.append(nwbfile.scratch[scratch_name][:].shape[0])
        if scratch_name.startswith('StimuliIDs'):
            stimuli_ids = nwbfile.scratch[scratch_name][:]
    assert psth_lengths, f'No PSTHs_QualityApproved scratch data found in {file_name}!'
    assert all(length == psth_lengths[0] for length in psth_lengths), 'Number of Stimuli are not consistent over the PSTH!'
    return nwb_io, nwbfile, has_stimulus_set, psth_lengths[0], stimuli_ids


def _add_stimulus_set(experiment_path, nwb_io, nwbfile, build_stimulus_set, label, sheet_column, show=False):
    """Attach a newly built 'StimulusSet' to *nwbfile*, write it, and record the outcome.

    *build_stimulus_set* is called with no arguments and must return the container
    to add. Any exception raised while building, adding or writing is printed and
    stored in *sheet_column* of the inventory sheet instead of being raised, so the
    experiment's other files are still processed. *nwb_io* is closed in every case.
    When *show* is true the file is passed to the notebook's ``display`` after adding.
    """
    try:
        nwbfile.add_stimulus_template(timeseries=build_stimulus_set(), use_sweep_table=False)
        print(f"Added StimulusSet to {label} nwb.")
        if show:
            display(nwbfile)
        nwb_io.write(nwbfile)
        status = 'Done'
    except Exception as error:
        print("An error occurred:", error)
        status = error
    finally:
        nwb_io.close()
    update_sheet(df, experiment_path, sheet_column, status)


def _load_rgb_stimulus(path, experiment_name, stimulus_id):
    """Load one stimulus image as an RGB array.

    Returns ``(nwb_name, data, stimulus_id)``. Non-RGB images are converted to RGB.
    If the pixel data cannot be decoded (e.g. a corrupted file), an all-zero
    placeholder of the image's declared size is used instead and the entry is named
    ``corrupted_png_<id>.png``, so stimulus IDs stay aligned with the PSTHs.
    """
    nwb_name = f'exp_{experiment_name}_{stimulus_id}.png'
    # The context manager releases the file handle PIL keeps open for lazy loading.
    with Image.open(path) as img:
        try:
            data = np.array(img if img.mode == 'RGB' else img.convert("RGB"))
        except Exception as error:
            print(f'Could not decode {path}: {error}. Using a blank placeholder image.')
            width, height = img.size
            data = np.zeros((height, width, 3))
            nwb_name = f'corrupted_png_{stimulus_id}.png'
    return nwb_name, data, stimulus_id


def _video_stimulus_set_builders(movie_names, n_stimuli, n_stimuli_train, n_stimuli_test, ids_train, ids_test):
    """Return builders for the combined/train/test video StimulusSets.

    Movies are referenced as external files at '../VideoStimulusSet/<movie>';
    'starting_frame' is repurposed to hold the stimulus ID of each movie.
    """
    external_file_prom = [os.path.join('../VideoStimulusSet', movie_path) for movie_path in movie_names]
    external_file_train = [external_file_prom[i] for i in ids_train]
    external_file_test = [external_file_prom[i] for i in ids_test]

    print(external_file_prom)
    print('------')
    print(ids_train)
    print(external_file_train)
    print(ids_test)
    print(external_file_test)

    assert n_stimuli == len(external_file_prom), 'Number of Stimuli in prom does not match number of Images!'
    assert n_stimuli_train == len(external_file_train), 'Number of Stimuli in prom train does not match number of Images!'
    assert n_stimuli_test == len(external_file_test), 'Number of Stimuli in prom test does not match number of Images!'

    def build_series(description, external_file, starting_frame):
        return ImageSeries(
            name="StimulusSet",
            description=description,
            unit="n.a.",
            external_file=external_file,
            format="external",
            rate=0.0,
            starting_frame=starting_frame,
        )

    return (
        lambda: build_series(_VIDEO_COMBINED_DESCRIPTION, external_file_prom, np.linspace(0, n_stimuli - 1, n_stimuli)),
        lambda: build_series(_VIDEO_SPLIT_DESCRIPTION, external_file_train, ids_train),
        lambda: build_series(_VIDEO_SPLIT_DESCRIPTION, external_file_test, ids_test),
    )


def _image_stimulus_set_builders(experiment_name, image_names, stimpath, count_1, check_image_order,
                                 n_stimuli, ids_train, ids_test):
    """Load the stimulus images and return builders for the combined/train/test image StimulusSets.

    Image i of *image_names* gets stimulus ID i. The combined set holds every image
    in that order; the train/test sets hold the images selected by *ids_train* /
    *ids_test*, in that order.
    """
    stimuli = []
    for stimulus_id, image in enumerate(image_names):
        if check_image_order:
            # count_1 == 1 means the file names are numbered from 1 instead of 0.
            image_counter = stimulus_id + 1 if count_1 == 1 else stimulus_id
            expected_image_name = f'im{image_counter}'
            stem = image.split(".")[0]
            accepted_names = {expected_image_name, f'{image_counter}', f'image{image_counter}', f'im{image_counter}_scrambled'}
            if stem not in accepted_names:
                print(f'Image names do not increase with +1!!! Expected: {expected_image_name} or image{image_counter} or {image_counter}, Found: {stem}')
        stimuli.append(_load_rgb_stimulus(os.path.join(stimpath, image), experiment_name, stimulus_id))

    if stimuli:
        # Progress line: description of the last image and its file name.
        print(f"StimulusID = {stimuli[-1][2]}", image_names[-1])

    assert n_stimuli == len(stimuli), 'Number of Stimuli does not match number of Images!'
    # Materialise the split IDs now so a missing 'StimuliIDs' entry fails before anything is written.
    ids_train = list(ids_train)
    ids_test = list(ids_test)

    def build_images(description, indices):
        # Fresh RGBImage objects for every container: an NWB object can have only
        # one parent, so the combined/train/test sets must not share them (sharing
        # them is what made the notebook cell need one run per file).
        return Images(
            name='StimulusSet',
            images=[
                RGBImage(
                    name=stimuli[i][0],
                    data=stimuli[i][1],
                    resolution=0.0,
                    description=f"StimulusID = {stimuli[i][2]}",
                )
                for i in indices
            ],
            description=description,
        )

    return (
        lambda: build_images(_IMAGE_COMBINED_DESCRIPTION, range(len(stimuli))),
        lambda: build_images(_IMAGE_SPLIT_DESCRIPTION, ids_train),
        lambda: build_images(_IMAGE_SPLIT_DESCRIPTION, ids_test),
    )


def update_prom_nwb(experiment_path, experiment_name, list_images_sorted, stimpath, count_1, check_image_order=True):
    """Add a 'StimulusSet' stimulus template to an experiment's prom NWB files.

    The combined ``*.prom.nwb`` file in *experiment_path* and its
    ``*.prom_train.nwb`` / ``*.prom_test.nwb`` splits are handled together, in a
    single call. Files that already contain a StimulusSet are left untouched. The
    outcome per file ('Done' or the error) is recorded in the module-level
    inventory ``df`` via ``update_sheet``. Nothing happens when there is no
    combined ``*.prom.nwb`` file.

    Args:
        experiment_path: Directory containing the prom NWB files.
        experiment_name: Experiment name, used to name the stored images.
        list_images_sorted: Stimulus file names ordered by stimulus ID (movies
            when *stimpath* is None, images otherwise).
        stimpath: Directory holding the image stimuli, or None for video
            stimuli, which are stored as external-file references.
        count_1: 1 when image file numbering starts at 1 instead of 0 (only
            affects the file-name order check).
        check_image_order: Print a warning for image names that do not follow
            the expected sequential numbering.

    Every NWB file opened here is closed before returning, also on error.
    """
    file_names = os.listdir(experiment_path)
    prom = [x for x in file_names if x.endswith('.prom.nwb')]
    prom_test = [x for x in file_names if x.endswith('.prom_test.nwb')]
    prom_train = [x for x in file_names if x.endswith('.prom_train.nwb')]

    if len(prom) == 0:
        return

    open_handles = []
    try:
        io, combined_nwb, combined, n_stimuli, _ = _read_prom_file(
            experiment_path, prom[0], 'combined', 'StimulusSet prom', open_handles)
        io_train, combined_nwb_train, train, n_stimuli_train, StimuliIDs_train = _read_prom_file(
            experiment_path, prom_train[0], 'train', 'StimulusSet prom train', open_handles)
        io_test, combined_nwb_test, test, n_stimuli_test, StimuliIDs_test = _read_prom_file(
            experiment_path, prom_test[0], 'test', 'StimulusSet prom test', open_handles)

        if combined and train and test:
            return

        if stimpath is None:
            builders = _video_stimulus_set_builders(
                list_images_sorted, n_stimuli, n_stimuli_train, n_stimuli_test,
                StimuliIDs_train, StimuliIDs_test)
        else:
            builders = _image_stimulus_set_builders(
                experiment_name, list_images_sorted, stimpath, count_1, check_image_order,
                n_stimuli, StimuliIDs_train, StimuliIDs_test)

        targets = (
            (combined, io, combined_nwb, 'combined', 'StimulusSet prom'),
            (train, io_train, combined_nwb_train, 'train', 'StimulusSet prom train'),
            (test, io_test, combined_nwb_test, 'test', 'StimulusSet prom test'),
        )
        for (exists, nwb_io, nwbfile, label, sheet_column), build in zip(targets, builders):
            if not exists:
                # Video StimulusSets are shown with display() after being added; image sets are not.
                _add_stimulus_set(experiment_path, nwb_io, nwbfile, build, label, sheet_column,
                                  show=stimpath is None)
    finally:
        # Safety net for files that already had a StimulusSet or whose processing
        # stopped early. NOTE(review): relies on NWBHDF5IO.close() being safe to
        # call on a handle _add_stimulus_set already closed.
        for nwb_io in open_handles:
            nwb_io.close()
       

### Step 2: Create a stimulus set for each experiment and add it to that experiment's NWB files.

In [None]:
# ------------------------------------------------------------------------------ 
# Add Stimulus Sets to each prom nwb file going on BrainScore. This Cell needs to be run 3 times (due to nature of nwb_images)
# ------------------------------------------------------------------------------ 
# ImageSet names of the experiments whose 'BrainScore' column is 'Y'; only these are processed below.
list_of_bs_exp_names = []
# Reset the status columns that this cell (directly and via update_sheet) fills in.
df['StimulusSet prom'] = ''
df['StimulusSet prom test'] = ''
df['StimulusSet prom train'] = ''
df['StimulusSetPath'] = ''
for index, row in df.iterrows():
    if row['BrainScore']=='Y': list_of_bs_exp_names.append(row['ImageSet'])
    
# Every entry inside the top-level directories of root_dir matched by '[exp]*'.
# NOTE(review): '[exp]*' is a glob character class — it matches names starting with
# 'e', 'x' OR 'p', not the literal prefix 'exp'. If 'exp*' was intended, change it.
experiment_file_paths = glob.glob(os.path.join(root_dir, '[exp]*', '*'))
# Directory holding the stimulus folders; their names are matched against experiment names below.
stimulus_dir          = '/braintree/data2/active/users/sgouldin/experiments-codebase'
stimuli_names         = os.listdir(stimulus_dir)

for experiment_path in experiment_file_paths: 
    experiment_name =  "_".join(os.path.basename(experiment_path).split('.')[0].split('_')[1:])


    # if not experiment_name.startswith('gratingsAdap_s'): continue    
    
    if experiment_name not in list_of_bs_exp_names: 
        continue 

    print('________________________________________________________________________________')
    print(experiment_name)
    
    # ------------------------------------------------------------------------------ 
    # Find Stimulus Directory name for each experiment. 
    # ------------------------------------------------------------------------------ 
    if experiment_name == 'domain-transfer-2023':
        stim_name = [x for x in stimuli_names if x.endswith('domain_transfer') and not x.startswith('.')]
    elif experiment_name == 'HVM-var6-2023':
        stim_name = [x for x in stimuli_names if x.endswith('HVM_var6') and not x.startswith('.')]
    elif experiment_name.startswith('gratingsAdap_'):
        stim_name = [x for x in stimuli_names if x.endswith('gratingsAdap') and not x.startswith('.')]
    elif experiment_name.startswith('gestalt'):
        stim_name = [x for x in stimuli_names if x.endswith('Gestalt') and not x.startswith('.')]
    elif experiment_name.startswith('object_relations'):
        stim_name = [x for x in stimuli_names if x.endswith('ObjectRelationships') and not x.startswith('.')]
    elif experiment_name.startswith('1_shapes'):
        stim_name = [x for x in stimuli_names if x.endswith('shapes') and not x.startswith('.')]
    elif experiment_name.startswith('food'):
        stim_name = [x for x in stimuli_names if x.endswith('Food') and not x.startswith('.')]
    elif experiment_name.startswith('shapenet360'):
        stim_name = [x for x in stimuli_names if x.endswith('ShapeNet360') and not x.startswith('.')]
    elif experiment_name.startswith('shapegen'):
        stim_name = [x for x in stimuli_names if x.endswith('ShapeGens') and not x.startswith('.')]
    elif experiment_name.startswith('sine_wave'):
        stim_name = [x for x in stimuli_names if x.endswith('sinewave_fullfield') and not x.startswith('.')]
    elif experiment_name.startswith('square_sinewave'):
        stim_name = [x for x in stimuli_names if x.endswith('squarewave_fullfield') and not x.startswith('.')]
    elif experiment_name.endswith('oasis900'):
        stim_name = [x for x in stimuli_names if x.endswith('Oasis900') and not x.startswith('.')]
    elif experiment_name.startswith('oasis900_200'):
        stim_name = [x for x in stimuli_names if x.endswith('Oasis900') and not x.startswith('.')]
    elif experiment_name.startswith('oasis100'):
        stim_name = [x for x in stimuli_names if x.endswith('OASIS100_control') and not x.startswith('.')]
    elif experiment_name.startswith('oasis900rotated'):
        stim_name = [x for x in stimuli_names if x.endswith('OasisRotated') and not x.startswith('.')]
    elif experiment_name.startswith('oasis900scrambled'):
        stim_name = [x for x in stimuli_names if x.endswith('OasisScramble') and not x.startswith('.')]
    elif experiment_name.startswith('IAPS-200on'):
        stim_name = [x for x in stimuli_names if x.endswith('IAPS') and not x.startswith('.')]
    elif experiment_name.startswith('flicker'):
        stim_name = [x for x in stimuli_names if x.startswith('flicker') and not x.startswith('.')]
    elif experiment_name.startswith('ko_context_size'):
        stim_name = [f'{x}/ko_context_size' for x in stimuli_names if x.startswith('old_rig1') and not x.startswith('.')]
    elif experiment_name.startswith('muri'):
        stim_name = [f'{x}/RSVP-MURI1320' for x in stimuli_names if x.startswith('old_rig1') and not x.startswith('.')]
    else:
        stim_name = [x for x in stimuli_names if x.endswith(experiment_name) and not x.startswith('.')]

    if len(stim_name)==0: 
        print(f'    No Stim found: {experiment_name}') 
        update_sheet(df, experiment_path, 'StimulusSet prom', 'No StimulusSet found.')
        continue

    files_starting_with_vid = [file for file in os.listdir(os.path.join(stimulus_dir, stim_name[0])) if file.startswith('vid')]
    ImageStimSetPath = None
    VideoStimSetPath = None

   
    # ------------------------------------------------------------------------------ 
    # Find 'images' folder in Stimulus Directory. (standard case)
    # ------------------------------------------------------------------------------ 
    if 'images' in os.listdir(os.path.join(stimulus_dir, stim_name[0])): 
        ImageStimSetPath = os.path.join(stimulus_dir, stim_name[0], 'images')
    
    # ------------------------------------------------------------------------------ 
    # Find 'vid...' folder.
    # ------------------------------------------------------------------------------ 
    elif len(files_starting_with_vid) > 0: 
        VideoStimSetPath = os.path.join(stimulus_dir, stim_name[0], files_starting_with_vid[0])

    # ------------------------------------------------------------------------------ 
    # Manually find folders for gratingsAdap.
    # ------------------------------------------------------------------------------ 
    elif experiment_name.startswith('gratingsAdap_'):
        season = experiment_name.split('_')[-1][-1]
        file = f'season{season}'
        list_videos = os.listdir(os.path.join(stimulus_dir, stim_name[0], file))
        list_videos = [x for x in list_videos if not x.startswith('.')]
        if list_videos[0].startswith('mv'): VideoStimSetPath = os.path.join(stimulus_dir, stim_name[0], file)
        if list_videos[0].startswith('im'): ImageStimSetPath = os.path.join(stimulus_dir, stim_name[0], file)
    
    # ------------------------------------------------------------------------------ 
    # Manually find folders for object_relations.
    # ------------------------------------------------------------------------------ 
    elif experiment_name.startswith('object_relations'):
        files_starting_with_vid = [file for file in os.listdir(os.path.join(stimulus_dir, stim_name[0])) if file.startswith('mworks')]
        VideoStimSetPath = os.path.join(stimulus_dir, stim_name[0], files_starting_with_vid[0])


    # ------------------------------------------------------------------------------ 
    # Manually find folders for oasis900.
    # ------------------------------------------------------------------------------ 
    elif experiment_name =='oasis900' or experiment_name =='oasis900_200on':
            ImageStimSetPath = os.path.join(stimulus_dir, stim_name[0], 'image_dicarlo_oasis900')

    # ------------------------------------------------------------------------------ 
    # Manually find folders for oasis100.
    # ------------------------------------------------------------------------------ 
    elif experiment_name.startswith('oasis100'): 
        if experiment_name.endswith('c'): ImageStimSetPath = os.path.join(stimulus_dir, stim_name[0], 'images_control')
        if experiment_name.endswith('o'): ImageStimSetPath = os.path.join(stimulus_dir, stim_name[0], 'images_original')

    # ------------------------------------------------------------------------------ 
    # Manually find folders for square_sinewave.
    # ------------------------------------------------------------------------------ 
    elif experiment_name == 'square_sinewave': 
        VideoStimSetPath = os.path.join(stimulus_dir, stim_name[0], 'squarewave_movies')

    else: print(f'  No Images or Videos found for {experiment_name} {stim_name}')
    # ------------------------------------------------------------------------------ 
    # For the following experiments, either SimulusSet is not found or the nwb files
    # are not creted yet. Once both are done, check how the Simulus Direcories look 
    # and update this part (or the finding StimulusDirecory part.)
    # ------------------------------------------------------------------------------ 

    # if experiment_name == 'NSD-COCO': 
    #     print("To Do")
    #     continue
    if experiment_name == 'RF': 
        print("To Do")
        continue
    elif experiment_name == 'flicker': 
        print("To Do")
        continue
    elif experiment_name == 'gestalt': 
        print("To Do")
        continue
    elif experiment_name.startswith('monkeyvalence'): 
        print("To Do")
        continue
    elif experiment_name == 'sine_wave': 
        print("To Do")
        continue

    # ------------------------------------------------------------------------------ 
    # Create StimulusSet for each ImageStimulusSet
    # ------------------------------------------------------------------------------ 
    df_index = np.where(df['ImageSet'].to_numpy() == experiment_name)[0][0]

    if ImageStimSetPath != None: 
        
        check_image_order = True
        count_1 = 0

        if experiment_name =='oasis900' or experiment_name =='oasis900_200on':
            def extract_integer(image_name):
                try:
                    return int(image_name[2:-4])
                except: pass
                
            list_images = os.listdir(ImageStimSetPath)
            list_images_sorted = [x for x in sorted(list_images, key = extract_number) if not x.startswith('.')]
            csv_path = os.path.join('/', *ImageStimSetPath.split('/')[:-1], 'image_dicarlo_oasis900.csv')
            df_csv = pd.read_csv(csv_path)
            mapping = {extract_integer(row['image_file_name']): row['filename'] for index, row in df_csv.iterrows()}
            reverse_mapping = {v: k for k, v in mapping.items()}
            list_images_sorted = sorted(list_images_sorted, key=lambda x: reverse_mapping.get(x, float('inf')))
            check_image_order = False

        elif experiment_name =='oasis900rotated':
            """
            According to the mwel file, the stimulus is organized as first 900 ori and then 900 rot.
            file_path = os.path.join('/', *path.split('/')[:-1],'image_set_definition_oriandrotated.mwel' ) 
            with open(file_path, 'r') as file:
                mwel_content = file.read()
            print(mwel_content)
            """
            list_images = os.listdir(ImageStimSetPath)
            list_images_sorted = [x for x in sorted(list_images, key = extract_number) if not x.startswith('.')]
            list_images_sorted_ori = [x for x in list_images_sorted if x.startswith('im_ori')]
            list_images_sorted_rot = [x for x in list_images_sorted if x.startswith('im_rot')]
            list_images_sorted = list_images_sorted_ori + list_images_sorted_rot
            check_image_order = False

        elif experiment_name == 'oasis900scrambled' or experiment_name == 'oasis900scrambled_200on': 
            filename = 'image_set_definition_oriandscramble.mwel'
            list_images_sorted = []
            with open(os.path.join(stimulus_dir, stim_name[0], filename), 'r') as file:
                mwel_content = file.read()
            for substring in mwel_content.split('var imagefiles')[-1].split('",\n"'):
                if substring == 'images/im9_scrambled.jpg"\n]\n': substring = 'images/im9_scrambled.jpg'
                list_images_sorted.append(substring.split('/')[-1])
            check_image_order = False

        elif experiment_name == 'NSD-COCO': 
            list_images = os.listdir(ImageStimSetPath)
            list_images_sorted = [x for x in sorted(list_images, key = extract_number) if not x.startswith('.')]
            check_image_order = False

        else:
            list_images = os.listdir(ImageStimSetPath)
            list_images_sorted = [x for x in sorted(list_images, key = extract_number) if not x.startswith('.')]

        if experiment_name == 'shapenet360' or experiment_name =='1_shapes' or experiment_name =='food' or experiment_name =='shapegen_static': # Add when list_images_sorted start with 1 and not 0 in filename. 
            count_1 = 1
            
        print(experiment_name, ImageStimSetPath, list_images_sorted)
        df.at[df_index, "StimulusSetPath"] = ImageStimSetPath
        update_prom_nwb(experiment_path, experiment_name, list_images_sorted, ImageStimSetPath, count_1, check_image_order)

    # ------------------------------------------------------------------------------ 
    # Create StimulusSet for each VideoStimulusSet
    # ------------------------------------------------------------------------------ 
    if VideoStimSetPath != None: 

        print(experiment_name, VideoStimSetPath)
        df.at[df_index, "StimulusSetPath"] = VideoStimSetPath
        list_movies = os.listdir(VideoStimSetPath)
        list_movies_sorted = [x for x in sorted(list_movies, key = extract_number) if not x.startswith('.')]

        if experiment_name == 'square_sinewave':
            file_path = '/braintree/data2/active/users/sgouldin/experiments-codebase/squarewave_fullfield/movie_definition_squarewave_set1.mwel'
            list_movies_sorted = []
            with open(file_path, 'r') as file:
                mwel_content = file.read()
            for substring in mwel_content.split('var imagefiles')[-1].split('",\n"'):
                try: 
                    subsubstring = substring.split('"')
                    for sub in subsubstring:
                        if sub.startswith('squarewave_movies'):
                            list_movies_sorted.append(sub.split('/')[-1])
                except: list_movies_sorted.append(substring.split('/')[-1])


        # ------------------------------------------------------------------------------ 
        # Copy movies into experiment file.
        # ------------------------------------------------------------------------------ 
        image_set_path = '/'.join(experiment_path.split('/')[:-1])
        try: os.mkdir(os.path.join(image_set_path, 'VideoStimulusSet'))
        except: pass
        rename_flag = False
        for movie in list_movies_sorted:
            try: shutil.copy2(os.path.join(VideoStimSetPath, movie), os.path.join(image_set_path, 'VideoStimulusSet'))
            except: rename_flag = True

        # ------------------------------------------------------------------------------ 
        # Update nwb files.
        # ------------------------------------------------------------------------------ 
        list_movies_sorted_new = []
        if rename_flag == False:
            if experiment_name == 'motionset1' or experiment_name == 'moca' or experiment_name == 'afv'  or experiment_name == 'faceemovids' or experiment_name.startswith('gratingsAdap_'):
                for movie in list_movies_sorted:
                    match = re.search(r'\d+', movie)
                    if match:
                        number = int(match.group())
                        movie_filename_new = f'exp_{experiment_name}_{number}.mp4'
                        list_movies_sorted_new.append(movie_filename_new)
                        moviepath = os.path.join(image_set_path, 'VideoStimulusSet')
                        os.rename(os.path.join(moviepath, movie), os.path.join(moviepath, movie_filename_new))

                # update_exp_nwb_movies(experiment_path, list_movies_sorted, 0, check_image_order=True)
            elif experiment_name == 'gestalt' or experiment_name == 'Co3D':
                for movie in list_movies_sorted:
                    match = re.search(r'\d+', movie)
                    if match:
                        number = int(match.group())
                        movie_filename_new = f'exp_{experiment_name}_{number-1}.mp4'
                        list_movies_sorted_new.append(movie_filename_new)
                        moviepath = os.path.join(image_set_path, 'VideoStimulusSet')
                        os.rename(os.path.join(moviepath, movie), os.path.join(moviepath, movie_filename_new))
                # update_exp_nwb_movies(experiment_path, list_movies_sorted, 1, check_image_order=True)
            elif experiment_name == 'square_sinewave' or experiment_name == 'object_relations':
                for movie, number in zip(list_movies_sorted, range(len(list_movies_sorted))):
                    movie_filename_new = f'exp_{experiment_name}_{number}.mp4'
                    list_movies_sorted_new.append(movie_filename_new)
                    moviepath = os.path.join(image_set_path, 'VideoStimulusSet')
                    os.rename(os.path.join(moviepath, movie), os.path.join(moviepath, movie_filename_new))
            else: list_movies_sorted_new = list_movies_sorted

        else: list_movies_sorted_new = list_movies_sorted
        
        print(list_movies_sorted_new)
        update_prom_nwb(experiment_path, experiment_name, list_movies_sorted_new, None, 0, check_image_order=True)

 

In [4]:
# ------------------------------------------------------------------------------ 
# Update Sheet 2 of Excel File.
# ------------------------------------------------------------------------------ 
inventory_path = f'{os.path.dirname(cwd)}/pico_inventory.xlsx'

# Read every sheet so the other sheets survive the rewrite below. The context
# manager closes the workbook before the same path is reopened for writing.
with pd.ExcelFile(inventory_path) as xls:
    sheets = {sheet: xls.parse(sheet) for sheet in xls.sheet_names}

# Replace Sheet2 with the in-memory inventory table `df`.
sheets['Sheet2'] = df

# mode='w' recreates the workbook, so every sheet read above is written back.
with pd.ExcelWriter(inventory_path, engine='openpyxl', mode='w') as writer:
    for sheet_name, sheet_df in sheets.items():
        sheet_df.to_excel(writer, sheet_name=sheet_name, index=False)

In [None]:
# ------------------------------------------------------------------------------ 
# Basic Checks for prom files. 
# ------------------------------------------------------------------------------ 

def prettyprint(string):
    """Normalise whitespace in *string* and wrap it for display.

    Every run of whitespace (spaces, tabs, newlines) becomes a single space.
    The result is then wrapped into lines of at most 120 characters.
    """
    wrapper = textwrap.TextWrapper(width=120)
    return wrapper.fill(' '.join(string.split()))

def get_image_hash(img):
    """Return the MD5 hex digest of ``img.tobytes()``.

    For PIL images ``tobytes()`` is the raw pixel buffer. The image mode and
    size are therefore NOT part of the digest.

    Args:
        img: any object exposing ``tobytes()`` (a PIL image where used here).

    Returns:
        32-character hexadecimal MD5 digest string.
    """
    # Named `digest` rather than `hash` so the builtin is not shadowed.
    digest = hashlib.md5(img.tobytes()).hexdigest()
    return digest

def are_images_identical(img1, img2):
    """Return True when two images have the same mode, size and pixel bytes.

    Comparing only ``tobytes()`` (or a hash of it) would call, e.g., a 2x3 and
    a 3x2 image identical whenever their raw buffers coincide. Mode and size
    are therefore compared first. The byte buffers are then compared directly,
    which is exact and avoids hashing both of them.

    Args:
        img1, img2: PIL images (or any objects exposing ``tobytes()``; a
            missing ``mode``/``size`` attribute is treated as ``None``).

    Returns:
        bool: True if the images are identical.
    """
    # getattr keeps the function usable with objects that only expose tobytes().
    if getattr(img1, 'mode', None) != getattr(img2, 'mode', None):
        return False
    if getattr(img1, 'size', None) != getattr(img2, 'size', None):
        return False
    return img1.tobytes() == img2.tobytes()

def _check_metadata(prom_nwb):
    """Assert the file has keywords; print a note when it has no notes."""
    assert len(prom_nwb.keywords[:]) != 0, 'No keywords in file.'
    # `not` also covers notes=None, which len() would reject with TypeError.
    if not prom_nwb.notes:
        print('No notes in file.')


def _check_psths(prom_nwb):
    """Validate the merged PSTH arrays in ``prom_nwb.scratch``.

    Both arrays must be 4-D and have the same shape. The z-scored array must
    not have a minimum of exactly 0. The raw array must be non-negative.

    Returns:
        int: number of stimuli (size of axis 0 of the raw PSTH array).

    Raises:
        AssertionError: when any check fails.
    """
    psth_zscored = prom_nwb.scratch['PSTHs_QualityApproved_ZScored_SessionMerged'][:]
    psth         = prom_nwb.scratch['PSTHs_QualityApproved_SessionMerged'][:]

    assert len(psth_zscored.shape) == 4, 'wrong shape for PSTHs_QualityApproved_ZScored_SessionMerged '
    assert len(psth.shape) == 4,         'wrong shape for PSTHs_QualityApproved_SessionMerged ' 
    assert psth_zscored.shape == psth.shape, 'PSTH Shapes different.'
    # Presumably a minimum of exactly 0 means the array was never z-scored.
    assert np.min(psth_zscored) != 0, 'ZScored_PSTH not Z-scored'
    # Raw PSTHs must be non-negative. An `== 0` test would also reject valid
    # PSTHs whose smallest entry happens to be positive.
    assert np.min(psth) >= 0, 'PSTH containing negative numbers.'
    return psth.shape[0]


def _check_image_stimuli(stimulus_set, imageset, stimulus_path, list_images_sorted,
                         prom_nwb, is_split, n_stimuli):
    """Compare an image StimulusSet with the source images on disk.

    Stored image names are expected to end in ``_<id>.<ext>``. ``<id>`` indexes
    into *list_images_sorted* (the source directory listing sorted by number).
    Mismatches are printed and displayed, never raised.
    """
    def _stimulus_id(name):
        return int(name.split('_')[-1].split('.')[0])

    list_images_nwb = sorted(stimulus_set.images.keys(), key=_stimulus_id)
    stimulus_ids = [_stimulus_id(name) for name in list_images_nwb]

    if is_split:
        # Train/test files also store the IDs of the stimuli they contain.
        stimulus_ids_nwb = list(prom_nwb.scratch['StimuliIDs'][:])
        if sorted(stimulus_ids) != sorted(stimulus_ids_nwb):
            print('Images and StimuluIDs are not identical!')
        if len(stimulus_ids_nwb) != n_stimuli:
            print('Number of Stimuli do not match with StimulusIDs and n_stimuli')

    # Iterate over the images actually stored, so subsets (whose IDs need not
    # start at 0) are fully covered.
    for name, image_id in zip(list_images_nwb, stimulus_ids):
        try:
            stored = Image.fromarray(stimulus_set.images[name][:])
            with Image.open(os.path.join(stimulus_path, list_images_sorted[image_id])) as src:
                # Round-trip through numpy so both images are built the same way.
                source = Image.fromarray(np.array(src))
        except Exception as exc:  # best effort: report and continue with the next image
            print(f'Could not compare image {name}: {exc!r}')
            continue

        if not are_images_identical(stored, source):
            display_image(source)
            display_image(stored)
            print(f'Image {image_id} are not Identical for file {list_images_sorted[image_id]}!')
            print(stimulus_path)
            print(list_images_sorted)

    # re.escape: the image-set name is literal text, not a regex.
    pattern = re.compile(re.escape(imageset) + r'_\d+\.png')
    if not all(pattern.fullmatch(name) for name in list_images_nwb):
        print(' NWB Image Names are not consistent. At least one image in source is corrupted.')


def _check_stimulus_set(prom_nwb, experiment_path, file, n_stimuli):
    """Check the StimulusSet of one NWB file against the source stimuli in ``df``.

    Best effort: every problem is printed, nothing is raised.
    """
    imageset = os.path.basename(experiment_path).split('.')[0]
    # '.prom.nwb' -> 'prom'; split files give 'prom_train' / 'prom_test'.
    is_split = file.split('.')[-2] != 'prom'

    try:
        stimulus_path = df[df['ImageSet'] == imageset.replace("exp_", "", 1)]['StimulusSetPath'].tolist()[0]
        print(stimulus_path)
        list_images_sorted = [x for x in sorted(os.listdir(stimulus_path), key=extract_number)
                              if not x.startswith('.')]
    except Exception:
        # No row for this image set, an empty/NaN path, or an unreadable directory.
        print('No StimulusPath.')
        return

    try:
        stimulus_set = prom_nwb.stimulus_template['StimulusSet']
        if hasattr(stimulus_set, 'images'):
            _check_image_stimuli(stimulus_set, imageset, stimulus_path, list_images_sorted,
                                 prom_nwb, is_split, n_stimuli)
        elif is_split:
            # Stimulus sets without `.images` (e.g. movies) expose starting_frame instead.
            if len(stimulus_set.starting_frame) != n_stimuli:
                print('Number of Stimuli do not match with StimulusIDs and n_stimuli')
    except Exception as exc:
        print(f'Stimulus set check failed: {exc!r}')


def check_prom_file(experiment_path, *, name_prefix='exp_gratingsAdap_s'):
    """Run sanity checks on the PROM NWB files in one experiment directory.

    Checks run only when a combined ``*.prom.nwb`` file is present. The
    ``*.prom_train.nwb`` / ``*.prom_test.nwb`` splits are checked as well when
    they exist. For each checked file the metadata and PSTH arrays are verified
    (AssertionError on failure). The stored StimulusSet is then compared with
    the source stimuli listed in the global ``df`` (problems are printed).

    Args:
        experiment_path: directory holding the NWB files.
        name_prefix: only file names starting with this prefix are checked.
            Defaults to 'exp_gratingsAdap_s'; pass None or '' to check all files.
    """
    names = os.listdir(experiment_path)
    prom       = [x for x in names if x.endswith('.prom.nwb')]
    prom_test  = [x for x in names if x.endswith('.prom_test.nwb')]
    prom_train = [x for x in names if x.endswith('.prom_train.nwb')]

    if not prom:
        return

    # Train/test splits may be absent. Slicing avoids an IndexError where [0] would raise.
    for file in prom[:1] + prom_train[:1] + prom_test[:1]:
        if name_prefix and not file.startswith(name_prefix):
            continue
        print('________________________________________', os.path.basename(experiment_path), file)

        # The context manager closes the HDF5 file even when an assertion fails.
        with NWBHDF5IO(os.path.join(experiment_path, file), "r") as nwb_io:
            prom_nwb = nwb_io.read()
            _check_metadata(prom_nwb)
            n_stimuli = _check_psths(prom_nwb)
            _check_stimulus_set(prom_nwb, experiment_path, file, n_stimuli)


# NOTE(review): '[exp]*' is a glob character class (first letter 'e', 'x' or 'p'),
# not the literal prefix 'exp'. Confirm this is intended before relying on it.
experiment_file_paths = sorted(glob.glob(os.path.join(root_dir, '[exp]*', '*')))
stimulus_dir          = '/braintree/data2/active/users/sgouldin/experiments-codebase'
stimuli_names         = os.listdir(stimulus_dir)


for experiment_path in experiment_file_paths:
    # check_prom_file lists the path's contents, so plain files are skipped.
    if not os.path.isdir(experiment_path):
        continue
    check_prom_file(experiment_path)