### Step 1: Load modules and functions

In [None]:
# ------------------------------------------------------------------------------
# Analysis parameters used by the quality checks below.
# ------------------------------------------------------------------------------
# Number of bootstrap resamples, forwarded as `nboot` to the quality classes.
n_boot = 10_000
# Bounds (in ms) of the response window that create_norm_assembly sums over.
integration_start_ms = 70
integration_end_ms = 170

In [1]:
import json, sys
import xarray as xr
import numpy as np
import pandas as pd
import h5py
from brainio.assemblies import NeuronRecordingAssembly
from pynwb import NWBHDF5IO, NWBFile
import glob, os
from datetime import datetime
import pytz  # This is required to handle timezone conversions
import sys
import io as ioprint
cwd = os.getcwd()
sys.path.append(os.path.dirname(cwd))
from ndashboard.nquality.raw_data_template import SessionNeuralData
from ndashboard.nquality.quality_within_session import Session
from ndashboard.nquality.quality_across_sessions import LongitudinalQuality

def get_unix_timestamp(date_str, time_str, date_format='%Y%m%d', time_format='%H%M%S'):
    """Convert a date string and a time string into a Unix timestamp.

    The two strings are joined with a single space and parsed with
    ``f"{date_format} {time_format}"``.

    Args:
        date_str: Date text matching ``date_format`` (default ``YYYYMMDD``).
        time_str: Time text matching ``time_format`` (default ``HHMMSS``).
        date_format: ``strptime`` format of ``date_str``.
        time_format: ``strptime`` format of ``time_str``.

    Returns:
        float: Seconds since the Unix epoch.

    Raises:
        ValueError: If the strings do not match the given formats.
    """
    # Function-scope import: the file only imports the `datetime` class.
    from datetime import timezone

    dt = datetime.strptime(f"{date_str} {time_str}", f"{date_format} {time_format}")
    # A naive result is taken to be UTC.  If the caller's format carried its
    # own offset (%z) the parsed value is already aware and is used as-is.
    if dt.tzinfo is None:
        dt = dt.replace(tzinfo=timezone.utc)
    return dt.timestamp()

def generate_timestamps(start_timestamp, interval_ms, length):
    """Return ``length`` timestamps spaced ``interval_ms`` milliseconds apart.

    Args:
        start_timestamp: Timestamp (in seconds) of the first entry.
        interval_ms: Spacing between consecutive entries, in milliseconds.
        length: Number of timestamps to generate.

    Returns:
        numpy.ndarray: ``start_timestamp + i * interval_ms / 1000`` for
        ``i = 0 .. length - 1``.
    """
    step_s = interval_ms / 1000
    # Count with an integer arange and scale afterwards: np.arange with a
    # float stop/step can yield one element too many or too few because of
    # rounding, which would break callers that need exactly `length` values.
    return start_timestamp + np.arange(length) * step_s

def create_norm_assembly(psth, meta, start_timestemp = None):
    """Build a (presentation x neuroid) assembly of window-summed responses.

    Args:
        psth: Array whose axes are used as (image, repetition, time_bin,
            neuroid) -- the order of the ``dims`` passed to the DataArray.
        meta: Indexable whose first three entries are consumed as
            ``np.arange(meta[0], meta[1], meta[2])``, i.e. start of the first
            bin, end of the binning range and bin width.  They are combined
            with ``integration_start_ms`` / ``integration_end_ms``, so they
            must be expressed in milliseconds as well.
        start_timestemp: Optional timestamp (seconds) assigned to the first
            presentation.  When omitted, placeholder values evenly spaced
            between 0 and the integration window length are used instead.

    Returns:
        tuple: ``(assembly, numchannels)`` where ``assembly`` is the
        DataArray summed over the integration window with dims
        ``('presentation', 'neuroid')`` and ``numchannels`` is its size along
        the second axis.

    Raises:
        AssertionError: If the number of bins implied by ``meta`` differs
            from ``psth.shape[2]``.
    """
    timebase = np.arange(meta[0], meta[1], meta[2])
    timebins = np.asarray([[int(x), int(x)+int(meta[2])] for x in timebase])
    assert len(timebase) == psth.shape[2], f"Number of bins is not correct. Expected {len(timebase)} got {psth.shape[2]}"

    assembly = xr.DataArray(psth,
                    coords={'repetition': ('repetition', list(range(psth.shape[1]))),
                            'stimulus_id': ('image', list(range(psth.shape[0]))),
                            'time_bin_id': ('time_bin', list(range(psth.shape[2]))),
                            'time_bin_start': ('time_bin', [x[0] for x in timebins]),
                            'time_bin_stop': ('time_bin', [x[1] for x in timebins])},
                    dims=['image', 'repetition', 'time_bin', 'neuroid'])

    # Flatten (image, repetition) into a single 'presentation' dimension.
    assembly = assembly.stack(presentation=('image', 'repetition')).reset_index('presentation')
    # NOTE(review): recent xarray releases prefer `drop_vars` over `drop`;
    # kept as-is because the installed version is not visible from here.
    assembly = assembly.drop('image')

    # Indices of the bins that bound the integration window, counted from the
    # first bin (which starts at meta[0]).
    first_bin = int(0-(meta[0]/meta[2])+(integration_start_ms/meta[2]))
    last_bin = int(0-(meta[0]/meta[2])+(integration_end_ms/meta[2]))
    assembly = assembly.isel(time_bin=slice(first_bin, last_bin)).sum('time_bin').transpose('presentation', 'neuroid')

    # Identity test: `== None` would be evaluated element-wise (and then fail
    # inside `if`) should an array-like timestamp ever be passed.
    if start_timestemp is None:
        assembly = assembly.assign_coords({'unix_timestamp': ('presentation', np.linspace(0, integration_end_ms-integration_start_ms, assembly.shape[0]))})
    else:
        timestamps = generate_timestamps(start_timestemp, interval_ms=integration_end_ms-integration_start_ms, length=assembly.shape[0])
        assembly   = assembly.assign_coords({'unix_timestamp': ('presentation', timestamps)})
    numchannels = assembly.shape[1]

    return assembly, numchannels

def find_norm_with_date(realdate):
    """Return the normalizer entries under ``root_dir`` recorded on ``realdate``.

    Every candidate's base name is split on '.', and its second-to-last piece
    is expected to look like ``<date>_<time>``; the candidate is kept when
    ``<date>`` equals ``realdate``.

    Args:
        realdate: Date string to look for, in the same format as the date
            embedded in the entry names.

    Returns:
        list: Matching paths in glob order, or ``[None]`` when nothing
        matches (callers test the first element against ``None``).
    """
    # NOTE(review): '[norm]*' and '[!h5]*' are glob character classes -- they
    # match "first character is one of n/o/r/m" and "first character is
    # neither h nor 5", not the literal prefixes 'norm' / 'h5'.  Left
    # unchanged because the directory layout is not visible from here.
    candidates = glob.glob(os.path.join(root_dir, '[norm]*', '*', '*', '[!h5]*'))

    matches = []
    for candidate in candidates:
        name_parts = os.path.basename(candidate).split('.')
        if len(name_parts) < 2:
            # No '<stamp>.<ext>' structure at all; cannot be a session entry.
            continue
        stamp = name_parts[-2].split('_')
        # Skip entries that do not follow '<date>_<time>' instead of crashing
        # on the tuple unpacking, so a stray file cannot abort the whole scan.
        if len(stamp) == 2 and stamp[0] == realdate:
            matches.append(candidate)

    return matches if matches else [None]

def get_list_of_days(days):
    """Return, for each entry of ``days``, the text after its last '.'.

    Entries without a '.' are returned unchanged.
    """
    return [entry.rsplit('.', 1)[-1] for entry in days]

def quality_within_session(day):
    """Run the within-session quality check on the normalizer recorded on ``day``.

    Only the first entry returned by ``find_norm_with_date`` is evaluated:
    every branch of the loop body returns.

    Args:
        day: Date string identifying the recording day.

    Returns:
        The ``'pvalue_signal_variance'`` data of the session's quality
        dataset, or ``[None]`` when no normalizer / nwb file is found or the
        check fails for any reason (the reason is printed).
    """
    norm_paths = find_norm_with_date(day)
    for norm_path in norm_paths:
        if norm_path is None:
            print(f'No normalizer found for day {day}.')
            return [None]

        print(norm_path)
        # NOTE(review): '*[nwb]' matches any name ending in n, w or b (glob
        # character class), not only '*.nwb'.
        nwb_candidates = glob.glob(os.path.join(norm_path, '*[nwb]'))
        print(nwb_candidates)
        if not nwb_candidates:
            # Previously an IndexError; report it like any other missing input.
            print(f'No nwb file found in {norm_path}.')
            return [None]
        norm_nwb_file_path = nwb_candidates[0]

        try:
            # The context manager closes the file even when a scratch entry
            # is missing; `[:]` copies the data into memory before that.
            with NWBHDF5IO(norm_nwb_file_path, "r") as io:
                norm_nwbfile = io.read()
                psth = norm_nwbfile.scratch['psth'][:]
                meta = norm_nwbfile.scratch['psth meta'][:]

            # norm FOSS has 86 and norm HVM has 26 images
            if psth.shape[0] not in (86, 26):
                raise ValueError(
                    f'Unexpected number of normalizer images: {psth.shape[0]} (expected 86 or 26).'
                )

            print(f"Within: using normalizer file {os.path.basename(norm_nwb_file_path)}")
            print(psth.shape, meta)
            da, _ = create_norm_assembly(psth, meta)
            session = SessionNeuralData(da)
            session = Session(session, boot_seed=0, nboot=n_boot)
            ds_quality = session.ds_quality
            pvalues = ds_quality['pvalue_signal_variance'].data
            print('P Values of first session: ', pvalues)

            return pvalues
        except Exception as e:
            # Best-effort by design: the caller records the failure in the
            # inventory sheet and moves on to the next experiment.
            print(e)
            print(f'No psth available for normalizer {day}.')
            return [None]

    # Only reached if the path list was empty; keep the sentinel contract.
    return [None]

def _read_normalizer_psth(nwb_path):
    """Return the 'psth' and 'psth meta' scratch arrays stored in ``nwb_path``."""
    # The context manager closes the file even when a scratch entry is
    # missing; `[:]` copies the data into memory before the file is closed.
    with NWBHDF5IO(nwb_path, "r") as io:
        nwbfile = io.read()
        return nwbfile.scratch['psth'][:], nwbfile.scratch['psth meta'][:]


def quality_across_sessions(first_day, comparing_day):
    """Run the across-session quality check between two recording days.

    Only the first (first_day, comparing_day) normalizer pair is evaluated:
    every branch of the loop body returns.

    Args:
        first_day: Date string of the reference recording day.
        comparing_day: Date string of the day compared against it.

    Returns:
        The ``'pvalue_signal_variance'`` data of the longitudinal quality
        dataset (callers index it as ``[session, channel]``), or
        ``[None, None]`` when a normalizer is missing, the channel counts
        differ, or the check fails for any reason (the reason is printed).
    """
    norm_paths_1 = find_norm_with_date(first_day)
    norm_paths_2 = find_norm_with_date(comparing_day)

    # Pads the first-day list so that zip() yields two pairs when the
    # comparing day has two normalizers.
    if len(norm_paths_2) == 2: norm_paths_1.append(norm_paths_1[0])

    for norm_path_1, norm_path_2 in zip(norm_paths_1, norm_paths_2):
        if norm_path_1 is None or norm_path_2 is None:
            print(f'No normalizer found for day {first_day, comparing_day}.')
            return [None, None]

        try:
            # NOTE(review): '*[nwb]' matches any name ending in n, w or b
            # (glob character class), not only '*.nwb'.
            nwb_files_1 = glob.glob(os.path.join(norm_path_1, '*[nwb]'))
            nwb_files_2 = glob.glob(os.path.join(norm_path_2, '*[nwb]'))
            if not nwb_files_1 or not nwb_files_2:
                # Previously an uncaught IndexError on the `[0]` lookup.
                raise FileNotFoundError(f'No nwb file found in {norm_path_1} or {norm_path_2}.')
            norm_nwb_file_path_1 = nwb_files_1[0]
            norm_nwb_file_path_2 = nwb_files_2[0]

            psth_1, meta_1 = _read_normalizer_psth(norm_nwb_file_path_1)
            psth_2, meta_2 = _read_normalizer_psth(norm_nwb_file_path_2)

            # norm FOSS has 86 and norm HVM has 26 images
            for psth in (psth_1, psth_2):
                if psth.shape[0] not in (86, 26):
                    raise ValueError(
                        f'Unexpected number of normalizer images: {psth.shape[0]} (expected 86 or 26).'
                    )
            # only compare from the same normalizer image set
            if psth_1.shape[0] != psth_2.shape[0]:
                raise ValueError(
                    f'Normalizer image sets differ: {psth_1.shape[0]} vs {psth_2.shape[0]} images.'
                )

            print(f"Across: using normalizer file {os.path.basename(norm_nwb_file_path_1)} and {os.path.basename(norm_nwb_file_path_2)} ")
            print(psth_1.shape, meta_1, psth_2.shape, meta_2)

            n_channel_1 = psth_1.shape[-1]
            da, _ = create_norm_assembly(psth_1, meta_1)
            session_1 = SessionNeuralData(da)

            n_channel_2 = psth_2.shape[-1]
            da, _ = create_norm_assembly(psth_2, meta_2)
            session_2 = SessionNeuralData(da)

            if n_channel_1 != n_channel_2:
                print(f"Number of channels do not match. {n_channel_1} != {n_channel_2} ")
                return [None, None]

            session = LongitudinalQuality([session_1, session_2], boot_seed=0, nboot=n_boot)
            ds_quality = session.ds_quality
            return ds_quality['pvalue_signal_variance'].data

        except Exception as error:
            # Best-effort by design: the caller records the failure in the
            # inventory sheet and moves on.
            print(error)
            print(f'No psth available for normalizers {first_day, comparing_day}.')
            return [None, None]

    # Only reached if one of the path lists was empty; keep the sentinel contract.
    return [None, None]

def update_sheet(df, exp_nwb_path, text):
    """Write ``text`` into the 'proc_nwb' cell of the row for an experiment.

    The image-set name is taken from the base name of ``exp_nwb_path``:
    everything before the first '.', minus its first '_'-separated token
    (e.g. ``exp_nat_300.sub_x`` -> ``nat_300``).  The first row whose
    'ImageSet' equals that name is updated in place.

    Args:
        df: Inventory DataFrame with 'ImageSet' and 'proc_nwb' columns.
        exp_nwb_path: Path whose base name encodes the image set.
        text: Status text to store.

    Returns:
        None.  When no row matches, a message is printed and ``df`` is left
        untouched (this used to raise IndexError and abort the whole run).
    """
    stem = os.path.basename(exp_nwb_path).split('.')[0]
    imageset = '_'.join(stem.split('_')[1:])

    matching_rows = df.index[df['ImageSet'] == imageset]
    if len(matching_rows) == 0:
        print(f"ImageSet '{imageset}' not found in sheet; '{text}' was not recorded.")
        return
    df.at[matching_rows[0], 'proc_nwb'] = text
        
# Root of the inventory tree that is searched for normalizers and experiments.
root_dir = '/braintree/home/aliya277/inventory_new'

# Inventory sheet; every experiment starts out flagged as not quality-checked.
df = pd.read_excel(f"{os.path.dirname(cwd)}/pico_inventory.xlsx", sheet_name='Sheet2')
df['proc_nwb'] = 'No QC.'
# list of experiments sarah marked as 'not going on brainscore'
# (any row whose 'BrainScore' column is not 'Y')
list_exp_not_using = [
    image_set
    for image_set, brainscore_flag in zip(df['ImageSet'], df['BrainScore'])
    if brainscore_flag != 'Y'
]

SubjectName = 'pico'
storage_dir = '/braintree/home/aliya277/inventory_new'

### Step 2: For each experiment, run a within-session and an across-session quality check and save the resulting p-values in the experiment's NWB files.

In [None]:
# NOTE(review): '[exp]*' is a glob character class (first character is one of
# e/x/p), not the literal prefix 'exp'.
experiment_file_paths = glob.glob(os.path.join(root_dir, '[exp]*', '*'))
experiment_file_paths = [d for d in experiment_file_paths if 'VideoStimulusSet' not in os.path.basename(d)]


for experiment_path in experiment_file_paths:
    experiment_name = os.path.basename(experiment_path)

    # glob returns entries in arbitrary order, but the first entry is used as
    # the reference ("initial") recording, so order the day folders by the
    # date suffix after their last '.'.
    # NOTE(review): '*[!npy][!txt][!nwb]' constrains the last three characters
    # one by one (character classes); it is not an extension filter.
    days = sorted(
        glob.glob(os.path.join(experiment_path, '*[!npy][!txt][!nwb]')),
        key=lambda day_path: day_path.split('.')[-1],
    )
    n_days = len(days)
    # The entries of `days` are already complete paths (they come from the
    # glob above), so they are not joined with experiment_path again.
    n_sessions = sum(
        len(glob.glob(os.path.join(day, '*', '*proc.nwb')))  # TO DO: REMOVE OLD
        for day in days
    )
    # No day folders means no sessions; that case is reported further down
    # instead of failing on days[0].
    first_day = days[0].split('.')[-1] if days else None

    #if not os.path.basename(experiment_path) == 'exp_hvm.sub_solo': continue
    # if not os.path.basename(experiment_path) == 'exp_nat300.sub_solo': continue
    # if not os.path.basename(experiment_path) == 'exp_bold5000.sub_solo': continue

    print('________________________________________________________________________________')
    print(f'{experiment_name} has {n_days} days and {n_sessions} sessions')

    # ------------------------------------------------------------------------------
    # Skip code, if experiment is not wanted (see list above) or if nwb not exists.
    # ------------------------------------------------------------------------------

    # Derive the image-set name the same way update_sheet does, so that names
    # containing '_' are matched against the exclusion list as a whole.
    imageset_name = '_'.join(experiment_name.split('.')[0].split('_')[1:])
    if imageset_name in list_exp_not_using:
        print(imageset_name)
        update_sheet(df, experiment_path, 'Experiment is not going on BrainScore')
        continue

    if n_sessions == 0:
        print(f'{experiment_name} has no nwb file.')
        update_sheet(df, experiment_path, 'No nwb files in experiment.')
        continue

    first_day_cache = os.path.join(experiment_path, 'pvalues_first_day.npy')
    if os.path.isfile(first_day_cache):
        pvalues_first_day = np.load(first_day_cache)
    else:
        # ------------------------------------------------------------------------------
        # Do within Session QC for the sessions of the first day.
        # ------------------------------------------------------------------------------
        pvalues_first_day = quality_within_session(first_day)

        if pvalues_first_day[0] is None:
            update_sheet(df, experiment_path, 'No normalizers.')
            continue
        np.save(first_day_cache, pvalues_first_day)

    if n_days > 1:
        cross_session_cache = os.path.join(experiment_path, 'corss_session_pvalues.npy')
        run_again = True
        if os.path.isfile(cross_session_cache):
            corss_session_pvalues = np.load(cross_session_cache)
            # A cache with the wrong number of days is stale; recompute it.
            run_again = len(corss_session_pvalues) != n_days-1

        if run_again:
            # ------------------------------------------------------------------------------
            # Do across session QC for the other sessions with the first day.
            # ------------------------------------------------------------------------------
            corss_session_pvalues = [
                quality_across_sessions(first_day, comparing_day)
                for comparing_day in get_list_of_days(days[1:])
            ]

            if any(element is None for sublist in corss_session_pvalues for element in sublist):
                print("Did not find all the normalizers.")
                update_sheet(df, experiment_path, 'No normalizers.')
                continue

            # Sanity print: does the first-day row of each across-session
            # result flag the same channels (p < 0.05) as the within-session
            # check?  Comparing the boolean masks avoids the shape error that
            # comparing np.where() index arrays raises when the number of
            # significant channels differs.
            first_day_significant = pvalues_first_day < 0.05
            for val in corss_session_pvalues:
                print(np.array_equal(first_day_significant, val[0,:] < 0.05))

            np.save(cross_session_cache, corss_session_pvalues)

    # ------------------------------------------------------------------------------
    # Add p-values to respective experiment nwb files.
    # ------------------------------------------------------------------------------

    for i_day, day in enumerate(days):
        print(f'Checking if P-Values are added to experiment file for day {i_day+1}')
        exp_nwb_paths = glob.glob(os.path.join(day, '*', '*proc.nwb'))

        # Day 0 uses the within-session result; later days use the second row
        # of their across-session result.
        if i_day == 0:
            current_pvalues = pvalues_first_day
        else:
            current_pvalues = corss_session_pvalues[i_day-1][1,:]

        if current_pvalues[0] is None: continue

        for exp_nwb_path in exp_nwb_paths:

            print(i_day, os.path.basename(exp_nwb_path))
            # The context manager closes the append-mode handle even if
            # reading or writing fails part-way.
            with NWBHDF5IO(exp_nwb_path, "a") as io:
                exp_nwbfile = io.read()
                try:
                    n_channel = exp_nwbfile.scratch['psth'][:].shape[-1]
                    n_channel_norm = pvalues_first_day.shape[0]
                    if n_channel != n_channel_norm:
                        raise ValueError(f'{n_channel} != {n_channel_norm}.')
                except Exception as e: print(e, 'Check Channel numbers for normalizers and experiment.')

                try:
                    exp_nwbfile.scratch['PValuesPerChannel'][:]
                except Exception:
                    print("Adding pvalues to nwb.")
                    # NOTE: the indentation of the continuation lines below is
                    # part of the stored description text; it is kept
                    # unchanged so newly written files match existing ones.
                    exp_nwbfile.add_scratch(
                        current_pvalues,
                        name="PValuesPerChannel",
                        description=f"An array of length equal to the number of electrodes, where each entry corresponds to \
                            the p-value for the electrode with the matching ID. For the initial recording, a within-session \
                            quality check is conducted; for subsequent recordings, a cross-session quality comparison is \
                            performed with the initial recording. Parameters used: boot_seed=0, nboot={n_boot}.",
                    )

                    io.write(exp_nwbfile)
                else:
                    print('     P-Values are already added.')
    update_sheet(df, experiment_path, 'P-Values added.')
                
# Update Sheet 2
inventory_path = f'{os.path.dirname(cwd)}/pico_inventory.xlsx'
# Load every sheet and release the file handle before the same workbook is
# overwritten below; the writer recreates the file from scratch (mode='w'),
# so all sheets have to be written back, not just the modified one.
with pd.ExcelFile(inventory_path) as xls:
    sheets = {sheet: xls.parse(sheet) for sheet in xls.sheet_names}
sheets['Sheet2'] = df
with pd.ExcelWriter(inventory_path, engine='openpyxl', mode='w') as writer:
    for sheet_name, sheet_df in sheets.items():
        sheet_df.to_excel(writer, sheet_name=sheet_name, index=False)

