# Benchmark on Brain-Score using NWB files.

In [None]:
import os
from pynwb import NWBHDF5IO, NWBFile
from pynwb.base import Images
from pynwb.image import RGBImage, ImageSeries
from brainio.stimuli import StimulusSet
from PIL import Image
import re
import numpy as np
import pandas as pd
import json
import xarray as xr
from brainio.assemblies import NeuronRecordingAssembly, NeuroidAssembly
import os
from pathlib import Path
import json
from brainscore_vision import load_benchmark
from brainio.packaging import package_stimulus_set, package_data_assembly
from brainscore_vision import load_benchmark, load_dataset, load_ceiling, load_model, load_metric
import brainio
from brainscore_vision.metric_helpers.transformations import CrossValidation
from brainscore_vision.benchmark_helpers.neural_common import NeuralBenchmark, average_repetition
from brainscore_vision.utils import LazyLoad


def extract_number(filename):
    """Return the first run of digits found after the last '_' in *filename*.

    Only the text following the final underscore is searched (the whole
    string when there is no underscore). If that part contains no digit,
    0 is returned.
    """
    tail = filename.rpartition('_')[2]
    digits = re.search(r'\d+', tail)
    if digits is None:
        return 0
    return int(digits.group())

def get_stimuli(nwb_file, experiment_path, exp_name):
    """Export the stimulus images of an NWB file and build their metadata table.

    Every image listed in ``nwb_file.stimulus_template['StimulusSet']`` is
    looked up as ``exp_<exp_name>_<number>.png`` and written to
    ``<experiment_path>/images`` unless a file with that name already exists.

    Parameters
    ----------
    nwb_file : object exposing ``stimulus_template['StimulusSet']`` with an
        ``images`` collection of names and item access by image name.
    experiment_path : str
        Directory that receives the ``images`` sub-directory.
    exp_name : str
        Experiment name used to build the image file names.

    Returns
    -------
    (pandas.DataFrame, list[str])
        One metadata row per exported image (consecutive ``stimulus_id``
        starting at 0, in ascending image-number order) and the matching list
        of image file paths.

    Raises
    ------
    OSError
        If the ``images`` directory cannot be created.
    """
    # Columns that are present in the table but left blank for this dataset.
    blank_columns = (
        'background_id', 's', 'rxy', 'tz', 'category_name', 'rxz_semantic',
        'ty', 'ryz', 'object_name', 'variation', 'size', 'rxy_semantic',
        'ryz_semantic', 'rxz',
    )

    stimulus_set = nwb_file.stimulus_template['StimulusSet']
    image_ids = [int(x.split('_')[-1].split('.png')[0])
                 for x in sorted(list(stimulus_set.images), key=extract_number)]

    # exist_ok makes re-runs harmless while still surfacing real failures
    # (e.g. permissions) instead of silently swallowing them.
    image_dir = os.path.join(experiment_path, 'images')
    os.makedirs(image_dir, exist_ok=True)

    image_paths = []
    stimuli = []

    print("Iterating over the images ...")
    for i in image_ids:
        file_name = f'exp_{exp_name}_{i}.png'
        image_path = os.path.join(image_dir, file_name)
        try:
            image = stimulus_set[file_name][:]
            im = Image.fromarray(image)
            if not os.path.isfile(image_path):
                im.save(image_path)
        except Exception as e:
            # Best effort: report the problem and skip this image so the
            # remaining ones are still exported.
            print(e)
            continue

        # Ids are assigned only to images that were exported successfully,
        # so they stay consecutive even when some images are skipped.
        stimulus_id = len(stimuli)
        image_paths.append(image_path)
        stimuli.append({
            'stimulus_id': stimulus_id,
            'image_id': stimulus_id,
            'id': stimulus_id,
            'image_number': i,
            'image_file_name': file_name,
            'filename': file_name,
            **dict.fromkeys(blank_columns, ''),
        })

    stimuli = pd.DataFrame(stimuli)
    return stimuli, image_paths

def filter_neuroids(assembly, threshold):
    """Return *assembly* restricted to neuroids whose ceiling is >= *threshold*.

    The ceiling is computed with the 'internal_consistency' ceiling loaded
    through ``load_ceiling``; its ``raw`` values are aggregated with
    ``CrossValidation().aggregate`` before being compared to *threshold*.
    """
    consistency = load_ceiling('internal_consistency')
    raw_ceiling = consistency(assembly).raw
    aggregated_ceiling = CrossValidation().aggregate(raw_ceiling)
    keep = aggregated_ceiling >= threshold
    return assembly[{'neuroid': keep}]

def load_assembly(assembly, average_repetitions, region):
    """Select one region and the first time bin of *assembly*.

    The assembly is restricted to *region*, gets a 'region' coordinate on the
    'neuroid' dimension, is loaded, and is reduced to ``time_bin_id == 0``.
    It is then checked against the module-level NUMBER_OF_TRIALS and
    VISUAL_DEGREES. When *average_repetitions* is true the result of
    ``average_repetition`` is returned, otherwise the per-repetition assembly.
    """
    assembly = assembly.sel(region=region)
    region_labels = [region] * len(assembly['neuroid'])
    assembly['region'] = ('neuroid', region_labels)
    assembly.load()
    # time_bin_id 0 is the 70-170ms bin
    assembly = assembly.sel(time_bin_id=0).squeeze('time_bin')
    assert NUMBER_OF_TRIALS == len(np.unique(assembly.coords['repetition']))
    assert VISUAL_DEGREES == assembly.attrs['image_size_degree']
    if not average_repetitions:
        return assembly
    return average_repetition(assembly)

def _NeuralBenchmark(assembly_, region):
    """Build a PLS ``NeuralBenchmark`` for *region* from *assembly_*.

    Two lazily loaded views of the assembly are prepared through
    ``load_assembly``: one with repetitions averaged (scored against the
    model) and one with repetitions kept (passed to the
    'internal_consistency' ceiling). The benchmark identifier uses the
    module-level ``exp_name``.
    """
    def _load_with_repetitions(region=region):
        return load_assembly(assembly_, average_repetitions=False, region=region)

    def _load_averaged(region=region):
        return load_assembly(assembly_, average_repetitions=True, region=region)

    assembly_repetition = LazyLoad(_load_with_repetitions)
    assembly = LazyLoad(_load_averaged)
    metric = load_metric('pls')
    ceiler = load_ceiling('internal_consistency')

    def _ceiling():
        return ceiler(assembly_repetition)

    return NeuralBenchmark(
        identifier=f'Aliya2024.{exp_name}.{region}-pls',
        version=1,
        assembly=assembly,
        similarity_metric=metric,
        visual_degrees=VISUAL_DEGREES,
        number_of_trials=NUMBER_OF_TRIALS,
        ceiling_func=_ceiling,
        parent=region,
    )

def get_meta(nwb_file):
    """Parse the PSTH timing triple from the scratch-data description.

    The description of 'PSTHs_QualityApproved_ZScored_SessionMerged' is
    expected to end with
    ``[start_time_ms, stop_time_ms, tb_ms]: [<start> <stop> <step>]``;
    the whitespace-separated numbers after that label are returned as an
    integer numpy array.
    """
    description = nwb_file.scratch['PSTHs_QualityApproved_ZScored_SessionMerged'].description
    bracketed = description.rpartition('[start_time_ms, stop_time_ms, tb_ms]: ')[2]
    tokens = bracketed.strip('[]').split()
    return np.array(tokens, dtype=int)

def get_neuroids(nwb_file):
    """Collect per-electrode metadata from the NWB electrodes table.

    For every row of ``nwb_file.electrodes`` a record is built with:

    * ``col`` / ``row`` - taken from the bracketed location string (second
      and first value respectively), when it matches the expected pattern;
    * ``elec`` - the channel label, zero-padded to three digits;
    * ``arr`` - text after 'Serialnumber: ' in the electrode group description;
    * ``hemisphere`` / ``region`` / ``subregion`` - parsed from the electrode
      group location, when it matches the expected pattern;
    * ``bank`` - last '_'-separated part of the group name;
    * ``animal`` - the module-level ``subject``;
    * ``neuroid_id`` - ``<bank>-<elec>``.

    Returns
    -------
    pandas.DataFrame
        One row per electrode, in table order.
    """
    electrodes = nwb_file.electrodes
    data_list = []
    for i in range(len(electrodes['location'])):
        location_item = electrodes['location'][i]
        group_item = electrodes['group'][i]
        bank = electrodes['group_name'][i].split('_')[-1]

        # Labels look like '<number>_...'; fall back to the raw label when
        # the leading part is not an integer.
        raw_label = electrodes['label'][i]
        try:
            label = int(raw_label.split('_')[0])
        except (AttributeError, TypeError, ValueError):
            label = raw_label

        # Use one zero-padding rule for both 'elec' and 'neuroid_id' so the
        # two always agree (previously 'elec' got an extra '0' for labels
        # >= 10). Non-integer labels are used as they are.
        if isinstance(label, (int, np.integer)):
            elec = f"{label:03d}"
        else:
            elec = f"{label}"

        data_dict = {}

        # NOTE(review): the dots in this pattern are unescaped and therefore
        # match any character; kept unchanged - confirm against real data.
        location_match = re.search(r'\[(\d+).0, (\d+).0, (\d+).0\]', location_item)
        if location_match:
            data_dict['col'] = location_match.group(2)
            data_dict['row'] = location_match.group(1)
        data_dict['elec'] = elec

        data_dict['arr'] = group_item.description.split('Serialnumber: ')[-1]

        group_match = re.search(r"\['(\w+)', '(\w+)', '(\w+)'\]", group_item.location)
        if group_match:
            data_dict['hemisphere'] = group_match.group(1)
            data_dict['region'] = group_match.group(2)
            data_dict['subregion'] = group_match.group(3)

        data_dict['bank'] = bank
        data_dict['animal'] = subject
        data_dict['neuroid_id'] = f"{bank}-{elec}"

        data_list.append(data_dict)

    neuroid_meta = pd.DataFrame(data_list)
    return neuroid_meta

def get_QC_neurids(nwb_file):
    '''
    Build two channel masks from the 'QualityApprovedChannelMasks' scratch data.

    Returns a tuple ``(filtered_neurids, common_QC_channels)``:

    * ``filtered_neurids`` - logical OR: a channel is kept when its mask is
      set for at least one repetition. Each day's mask row is repeated once
      per repetition found in that day's 'PSTHs_QualityApproved_20*' entry
      (days taken in sorted key order).
    * ``common_QC_channels`` - logical AND of the masks along their first axis.

    The total number of repetitions must equal the second dimension of
    'PSTHs_ZScored_SessionMerged'.
    '''
    scratch = nwb_file.scratch
    psth = scratch['PSTHs_ZScored_SessionMerged'][:]
    common_QC_channels = np.logical_and.reduce(scratch['QualityApprovedChannelMasks'])
    channel_masks_day = scratch['QualityApprovedChannelMasks'][:]

    daily_keys = [key for key in sorted(scratch.keys())
                  if key.startswith('PSTHs_QualityApproved_20')]

    per_repetition_masks = []
    for day, key in enumerate(daily_keys):
        nreps_per_day = scratch[key][:].shape[1]
        per_repetition_masks.extend(channel_masks_day[day, :] for _ in range(nreps_per_day))
    channel_mask_all = np.array(per_repetition_masks)

    assert channel_mask_all.shape[0] == psth.shape[1]
    return np.any(channel_mask_all, axis=0), common_QC_channels

def _assign_frame_coords(assembly, frame, dim, mask=None):
    """Attach each column of *frame* to *assembly* as a coordinate along *dim*.

    When *mask* is given it is applied to every column's values first.
    """
    # DataFrame.items() is used instead of iteritems(), which pandas 2.0 removed.
    for column_name, column_data in frame.items():
        values = column_data.values if mask is None else column_data.values[mask]
        assembly = assembly.assign_coords(**{f'{column_name}': (dim, list(values))})
    return assembly


def load_responses(nwb_file, stimuli, use_QC_data=True, do_filter_neuroids=False, use_brainscore_filter_neuroids_method=False):
    """Build a NeuronRecordingAssembly of time-binned rates from an NWB file.

    Parameters
    ----------
    nwb_file : opened NWB file whose ``scratch`` holds the PSTH arrays.
    stimuli : pandas.DataFrame
        Stimulus metadata as returned by ``get_stimuli``; every column becomes
        a coordinate on the image dimension.
    use_QC_data : bool
        If true, read 'PSTHs_QualityApproved_ZScored_SessionMerged' and label
        its channels with the metadata of the channels approved on all days
        (AND mask). Otherwise read 'PSTHs_ZScored_SessionMerged' and label the
        channels with the full metadata.
    do_filter_neuroids : bool
        If true, drop neuroids after the assembly is built (see below).
    use_brainscore_filter_neuroids_method : bool
        Only used when *do_filter_neuroids* is true. If true, neuroids are
        kept when ``filter_neuroids`` accepts them on the normalizer PSTH
        (70-170 ms, threshold 0.7); if false, neuroids approved on at least
        one repetition (OR mask) are kept.

    When *do_filter_neuroids* is false and *use_QC_data* is true, the assembly
    is restricted to the neuroid ids of the AND mask.

    Returns
    -------
    NeuronRecordingAssembly
        Dimensions ('presentation', 'neuroid', 'time_bin'), with attrs
        'image_size_degree' = 8 and 'stim_on_time_ms' = 100.
    """
    #-----------------------------------------------------------------------------------------------------------------------------
    # Get the PSTH
    #-----------------------------------------------------------------------------------------------------------------------------
    if use_QC_data:
        psth = nwb_file.scratch['PSTHs_QualityApproved_ZScored_SessionMerged'][:]
    else:
        psth = nwb_file.scratch['PSTHs_ZScored_SessionMerged'][:]
    meta = get_meta(nwb_file)
    qc_array_or, qc_array_and = get_QC_neurids(nwb_file)

    #-----------------------------------------------------------------------------------------------------------------------------
    # Compute firing rates.
    #-----------------------------------------------------------------------------------------------------------------------------
    timebins = [[70, 170], [170, 270], [50, 100], [100, 150], [150, 200], [200, 250], [70, 270]]
    timebase = np.arange(meta[0], meta[1], meta[2])
    assert len(timebase) == psth.shape[2]
    # Shaped time bins x images x repetitions x channels
    rate = np.empty((len(timebins), psth.shape[0], psth.shape[1], psth.shape[3]))
    for idx, (start, stop) in enumerate(timebins):
        t_cols = np.where((timebase >= start) & (timebase < stop))[0]
        rate[idx] = np.mean(psth[:, :, t_cols, :], axis=2)

    #-----------------------------------------------------------------------------------------------------------------------------
    # Load neuroid metadata and image metadata
    #-----------------------------------------------------------------------------------------------------------------------------
    image_id = stimuli.image_number
    neuroid_meta = get_neuroids(nwb_file)

    assembly = xr.DataArray(rate,
                            coords={'repetition': ('repetition', list(range(rate.shape[2]))),
                                    'time_bin_id': ('time_bin', list(range(rate.shape[0]))),
                                    'time_bin_start': ('time_bin', [x[0] for x in timebins]),
                                    'time_bin_stop': ('time_bin', [x[1] for x in timebins]),
                                    'image_id': ('image', image_id)},
                            dims=['time_bin', 'image', 'repetition', 'neuroid'])
    # The quality-approved PSTH is labelled with the AND-masked metadata rows.
    assembly = _assign_frame_coords(assembly, neuroid_meta, 'neuroid',
                                    mask=qc_array_and if use_QC_data else None)

    # Sort images by image number first, then overwrite the image coordinates
    # with the stimulus table sorted by its own 'image_id'.
    assembly = assembly.sortby(assembly.image_id)
    stimuli = stimuli.sort_values(by='image_id').reset_index(drop=True)
    assembly = _assign_frame_coords(assembly, stimuli, 'image')
    assembly = assembly.sortby(assembly.id)

    # Collapse dimensions 'image' and 'repetitions' into a single 'presentation' dimension
    assembly = assembly.stack(presentation=('image', 'repetition')).reset_index('presentation')
    assembly = NeuronRecordingAssembly(assembly)

    #-----------------------------------------------------------------------------------------------------------------------------
    # Decide which neuroid ids to keep
    #-----------------------------------------------------------------------------------------------------------------------------
    kept_neuroid_ids = None
    if do_filter_neuroids and use_brainscore_filter_neuroids_method:
        # Filter noisy electrodes using the normalizer set. It is read only
        # here because no other branch needs it.
        normalizer_psth = nwb_file.scratch['PSTHs_Normalizers_SessionMerged'][:]
        if normalizer_psth.shape[0] == 26:
            normalizer_psth = normalizer_psth[:-1, :, :, :]  # remove grey image
        t_cols = np.where((timebase >= 70) & (timebase < 170))[0]
        normalizer_rate = np.mean(normalizer_psth[:, :, t_cols, :], axis=2)
        n_images = normalizer_rate.shape[0]
        normalizer_assembly = xr.DataArray(normalizer_rate,
                                           coords={'repetition': ('repetition', list(range(normalizer_rate.shape[1]))),
                                                   'image_id': ('image', list(range(n_images))),
                                                   'id': ('image', list(range(n_images)))},
                                           dims=['image', 'repetition', 'neuroid'])
        normalizer_assembly = _assign_frame_coords(normalizer_assembly, neuroid_meta, 'neuroid')

        # had to add this part TODO: remove the last grey image from normalizer set when doing the nwb conversion
        normalizer_assembly = normalizer_assembly.assign_coords(
            stimulus_id=('image', list(np.linspace(1, n_images, n_images, dtype=int))))
        normalizer_assembly = normalizer_assembly.stack(presentation=('image', 'repetition')).reset_index('presentation')
        normalizer_assembly = normalizer_assembly.drop('image')
        normalizer_assembly = normalizer_assembly.transpose('presentation', 'neuroid')
        normalizer_assembly = NeuronRecordingAssembly(normalizer_assembly)

        filtered_assembly = filter_neuroids(normalizer_assembly, 0.7)
        kept_neuroid_ids = filtered_assembly.neuroid_id
    elif do_filter_neuroids:
        # Channels approved on at least one repetition.
        kept_neuroid_ids = neuroid_meta['neuroid_id'].values[qc_array_or]
    elif use_QC_data:
        # Channels approved on every day.
        kept_neuroid_ids = neuroid_meta['neuroid_id'].values[qc_array_and]

    if kept_neuroid_ids is not None:
        assembly = assembly.sel(neuroid=np.isin(assembly.neuroid_id, kept_neuroid_ids))

    assembly = assembly.transpose('presentation', 'neuroid', 'time_bin')

    # Add other experiment related info
    assembly.attrs['image_size_degree'] = 8
    assembly.attrs['stim_on_time_ms'] = 100

    return assembly


#------------------------------------------------------------------------------------------------------------------
# Load the NWB file and get the stimuli.
#------------------------------------------------------------------------------------------------------------------

exp_name    = 'domain-transfer-2023'
subject     = 'pico'
# The NWB file lives inside the experiment directory and shares its base name.
experiment_path = f'/braintree/home/aliya277/inventory_new/exp_{exp_name}/exp_{exp_name}.sub_{subject}'
nwb_file_path   = f'{experiment_path}/exp_{exp_name}.sub_{subject}.prom.nwb'

print(f'DATASET: {exp_name}')
### Load nwb file
print("Loading the NWB file ...")
io = NWBHDF5IO(nwb_file_path, "r")
nwb_file = io.read()

print("Getting my Stimulus Set ...")
aliya_stimuli, aiya_imagepaths = get_stimuli(nwb_file, experiment_path, exp_name)

# ------------------------------------------------------------------------------------------------------------------
# Create Assemblies.
# ------------------------------------------------------------------------------------------------------------------
print("Creating Assemblies ...")

# Assembly 1: built from the QualityApproved SessionMerged PSTH stored in the NWB file.
# This QC method is quite strict.
aliya_assembly1 = load_responses(
    nwb_file, aliya_stimuli,
    use_QC_data=True, do_filter_neuroids=False, use_brainscore_filter_neuroids_method=False,
)
# Assembly 2: a slightly looser QC method using the p-values from the NWB file.
# This method is more similar to the Brain-Score method.
aliya_assembly2 = load_responses(
    nwb_file, aliya_stimuli,
    use_QC_data=False, do_filter_neuroids=True, use_brainscore_filter_neuroids_method=False,
)
# Assembly 3: the Brain-Score QC method.
aliya_assembly3 = load_responses(
    nwb_file, aliya_stimuli,
    use_QC_data=False, do_filter_neuroids=True, use_brainscore_filter_neuroids_method=True,
)

#------------------------------------------------------------------------------------------------------------------
# Create Benchmarks.
#------------------------------------------------------------------------------------------------------------------
print("Creating Benchmarks ...")

VISUAL_DEGREES = 8
# The number of trials is the size of the PSTH's second axis.
psth = nwb_file.scratch['PSTHs_ZScored_SessionMerged'][:]
NUMBER_OF_TRIALS = psth.shape[1]

def update_assembly(assembly):
    """Name *assembly* and attach this experiment's StimulusSet to it.

    The StimulusSet is built from the module-level ``aliya_stimuli`` table and
    ``aiya_imagepaths``. The assembly is modified in place and also returned.
    """
    stimulus_set_name = f"Aliya2024_{exp_name}"

    stimuli = StimulusSet(aliya_stimuli)
    stimuli.stimulus_paths = aiya_imagepaths
    stimuli.name = stimulus_set_name
    stimuli.identifier = stimulus_set_name

    assembly.attrs['stimulus_set'] = stimuli
    assembly.name = f'dicarlo.{exp_name}.Aliya2024'
    return assembly

# One IT benchmark per assembly, built in the same order as the assemblies.
aliya2024_benchmark1, aliya2024_benchmark2, aliya2024_benchmark3 = [
    _NeuralBenchmark(update_assembly(qc_assembly), 'IT')
    for qc_assembly in (aliya_assembly1, aliya_assembly2, aliya_assembly3)
]
# NOTE(review): the file is closed before scoring; presumably safe because
# load_responses already read the arrays with [:] - confirm.
io.close()

#------------------------------------------------------------------------------------------------------------------
# Score Benchmarks.
#------------------------------------------------------------------------------------------------------------------
import shutil
import os

print('Scoring Models ...')
list_models=['pixels', 'tv_efficientnet-b1', 'alexnet', 'resnet50_julios', 'yudixie_resnet50_imagenet1kpret_0_240312', 'eBarlow_Vanilla', 'eMMCR_Vanilla']

# Directories that are deleted after every model, whether or not scoring succeeded.
directory_path1 = '/home/aliya277/.model-tools'
directory_path2 = '/home/aliya277/.result_caching'

for model_ in list_models:
    try:
        print('------------------------------------------')
        print('------------------------------------------')
        print('Scoring Model: ',model_)
        model = load_model(model_)
        # All three scores are computed before anything is printed, so a
        # failure in any benchmark suppresses the report for this model.
        score_1_aliya = aliya2024_benchmark1(model)
        score_2_aliya = aliya2024_benchmark2(model)
        score_3_aliya = aliya2024_benchmark3(model)
        for label, score in (('Score1:', score_1_aliya),
                             ('Score2:', score_2_aliya),
                             ('Score3:', score_3_aliya)):
            print('\n------------------------------------------')
            print(label)
            print(score)
        print('\n')

    except Exception as e: print(e)

    for directory_path in (directory_path1, directory_path2):
        if not os.path.exists(directory_path):
            print(f"Directory {directory_path} does not exist.")
            continue
        shutil.rmtree(directory_path)
        print(f"Directory {directory_path} has been removed successfully.")
