In [1]:
import os.path

import nibabel as nib
import numpy as np
import pandas as pd

# NOTE(review): a `global` statement at module level has no effect; `fig` is not
# assigned anywhere in the visible cells.
global fig
# Two backends are requested back to back; presumably the later `inline` call is the
# one that stays active — confirm which backend is actually wanted.
%matplotlib widget
%matplotlib inline

from skopt.space import Integer

########################
### Custom functions ###
########################
from src.loss_func import DC, get_yield, hellinger_distance
from src.phos_elect import create_grid, implant_grid, get_phosphenes
### needed for matrix rotation/translation etc.
from src.ninimplant import get_xyz
import src.utils as utils
import src.visualizations as visualizations

import src.generate_visual_sectors as gvs

# Silence warnings for the whole notebook.
# Added because of the warning "invalid value encountered in True-divide".
# NOTE(review): this hides every warning category, not only that one.
import sys
import warnings
# Only install the blanket "ignore" filter when no warning options were given explicitly.
if not sys.warnoptions:
    warnings.simplefilter("ignore")
# Also stop NumPy from reporting divide-by-zero and invalid floating-point operations.
np.seterr(divide='ignore', invalid='ignore')

In [2]:
def read_pickle_file(sub: str, hem: str) -> list:
    """Reads the saved pickle file for one subject and hemisphere.

    Parameters
    ----------
    sub : str
        The subject whose results directory is searched.
    hem : str
        The hemisphere from which we want to get the pickle file

    Returns
    -------
    data : list
        [optimized_arrays_from_f_manual, phosphenes_per_arr, phosphene_map_per_arr,
        total_contacts_xyz_moved, all_phosphenes, total_phosphene_map].
        An empty list is returned when the directory holds no ``*.pkl`` file or the
        file disappeared before it could be opened.
    """
    import glob, pickle
    RESULTS_PATH = "/home/odysseas/Desktop/UU/thesis/BayesianOpt/fsaverage_5_arrays_10x10x10/results/"
    results_dir = os.path.join(RESULTS_PATH, sub, hem)
    # Sorted so that the chosen file is deterministic if more than one pickle exists.
    filenames = sorted(glob.glob(os.path.join(results_dir, "*.pkl")))
    data = []
    if filenames:
        # Assuming there's only one file in the directory, you can take the first one
        filename = filenames[0]
        try:
            # NOTE: pickle.load can execute arbitrary code; only load files written by this pipeline.
            with open(filename, "rb") as file:
                data = pickle.load(file)
        except FileNotFoundError as e:
            print(e)
    # Always return a list, also when nothing was found (previously this fell through to None).
    return data
    

def read_best_params(hem: str) -> pd.DataFrame:
    """Reads the best fsaverage parameters of every array for the specific hemisphere.

    Parameters
    ----------
    hem : str
        The hemisphere from which we want to get the best parameters.

    Returns
    -------
    best_params_df : pd.DataFrame
        As many rows as arrays with columns: [best_alpha, best_beta, best_offset_from_base, best_shank_length]

    Raises
    ------
    FileNotFoundError
        If the fsaverage results directory of ``hem`` contains no CSV file whose
        name includes "best".
    """
    import glob
    RESULTS_PATH = "/home/odysseas/Desktop/UU/thesis/BayesianOpt/fsaverage_5_arrays_10x10x10/results/"
    results_dir = os.path.join(RESULTS_PATH, "fsaverage", hem)
    # Sorted so that the chosen file is deterministic if several CSVs match.
    csv_files = sorted(glob.glob(os.path.join(results_dir, "*.csv")))
    # Match on the file name only, so a directory called e.g. "best_run" cannot match by accident.
    best_files = [path for path in csv_files if "best" in os.path.basename(path)]
    if not best_files:
        # Fail with a clear message instead of pd.read_csv("") or a bare IndexError.
        raise FileNotFoundError(f"No 'best' results CSV found in {results_dir}")
    res_df = pd.read_csv(best_files[0])
    best_params_df = res_df.loc[:, ["best_alpha", "best_beta", "best_offset_from_base", "best_shank_length"]]
    return best_params_df

In [3]:
def get_metrics_from_fsaverage_arr_contacts(current_array: int, best_params_hem_df: pd.DataFrame):
    """Implants one array using the fsaverage best parameters and scores the result.

    The function reads its working state (grey-matter coordinates, retinotopy maps,
    target density, previously placed contacts, ...) from module-level globals that
    are set by ``implant_from_fsaverage`` before it is called.

    Parameters
    ----------
    current_array : int
        The current array number (1-based row of ``best_params_hem_df``)
    best_params_hem_df : pd.DataFrame
        The dataframe with the best parameters for all arrays; the first four columns
        are read as alpha, beta, offset_from_base and shank_length.

    Returns
    -------
    list
        ``[grid_valid_convex_hull, arr_dice]`` when the convex-hull check fails or the
        array's dice coefficient is 0; otherwise
        ``[grid_valid, arr_dice, total_dice, prop_total_dice, arr_hd, total_hd,
        prop_total_hd, arr_yield, total_yield, cost, prop_cost, array_phosphenes,
        phosphenes_so_far, array_phosphene_map, phosphene_map_so_far, array_contacts,
        total_contacts_so_far]``.
    """
    global CONFIG
    # global variables defined in main (all of them are only read here)
    global gm_mask, bin_thresh, target_density, hem
    global good_coords, good_coords_V1, polar_map, ecc_map, sigma_map
    global best_dc, best_hd
    global start_location
    global total_contacts_xyz_moved, optimized_arrays_from_f_manual

    penalty = 0.25
    best_params_list = best_params_hem_df.iloc[current_array-1].tolist()
    print(f"best parameters for array {current_array}: {best_params_list}")
    alpha, beta, offset_from_base, shank_length = best_params_list[:4]

    new_angle = (float(alpha), float(beta), 0)

    # create grid
    orig_grid = create_grid(start_location, shank_length, CONFIG["N_CONTACTPOINTS_SHANK"], CONFIG["N_COMBS"],
                            CONFIG["N_SHANKS_PER_COMB"], CONFIG["SPACING_ALONG_XY"], offset_from_origin=0)

    # implanting grid
    all_output = implant_grid(gm_mask, orig_grid, start_location, new_angle, offset_from_base)
    array_contacts, grid_valid_convex_hull = all_output[1], all_output[-1]

    print(f"grid valid convex hull: {grid_valid_convex_hull}")

    # The caller skips invalid/empty arrays, so for current_array > 1 there may still be
    # no accepted contacts (total_contacts_xyz_moved is None). Treat that case like the
    # first array instead of stacking None onto the contact coordinates.
    if current_array <= 1 or total_contacts_xyz_moved is None:
        total_contacts_so_far = array_contacts
        grid_valid = grid_valid_convex_hull
    else:
        total_contacts_so_far = np.hstack((total_contacts_xyz_moved, array_contacts))
        grid_valid_overlap = utils.get_overlap_validity(optimized_arrays_from_f_manual, array_contacts, CONFIG)
        grid_valid = grid_valid_overlap and grid_valid_convex_hull

    array_phosphenes = get_phosphenes(array_contacts, good_coords_V1, polar_map, ecc_map, sigma_map)
    phosphenes_so_far = get_phosphenes(total_contacts_so_far, good_coords_V1, polar_map, ecc_map, sigma_map)

    array_phosphene_map = utils.get_phosphene_map(array_phosphenes, CONFIG)
    phosphene_map_so_far = utils.get_phosphene_map(phosphenes_so_far, CONFIG)

    # compute dice coefficient -> should be large -> invert cost
    arr_dice, im1, im2 = DC(target_density, array_phosphene_map, bin_thresh)

    # Early exit: the caller only inspects the first two entries in this case.
    if not grid_valid_convex_hull or arr_dice == 0.0:
        return [grid_valid_convex_hull, arr_dice]

    total_dice, _, _ = DC(target_density, phosphene_map_so_far, bin_thresh)
    par1_total = 1.0 - (CONFIG["A"] * total_dice)

    prop_total_dice = total_dice / best_dc

    # compute yield -> should be 1 -> invert cost
    arr_yield = get_yield(array_contacts, good_coords)

    total_yield = get_yield(total_contacts_so_far, good_coords)

    # compute Hellinger distance -> should be small -> keep cost
    arr_hd = hellinger_distance(array_phosphene_map.flatten(), target_density.flatten())
    total_hd = hellinger_distance(phosphene_map_so_far.flatten(), target_density.flatten())

    prop_total_hd = (1 - total_hd) / (1 - best_hd)

    ## validations steps
    ####### added the first conditional to prevent lower cost functions for empty array #######
    # NOTE(review): arr_dice == 0.0 cannot be true here because of the early return above.
    if arr_dice == 0.0 or np.isnan(phosphene_map_so_far).any() or np.sum(phosphene_map_so_far) == 0:
        par1_total = 1

    if np.isnan(arr_hd) or np.isinf(arr_hd):
        par3 = 1
    else:
        par3 = CONFIG["C"] * arr_hd

    if arr_dice == 0 or par3 == 1:
        arr_yield = 0
    par2 = 1.0 - (CONFIG["B"] * arr_yield)

    ####### added the first conditional to prevent lower cost functions for empty array #######
    if arr_hd == 1.0 or np.isnan(total_hd) or np.isinf(total_hd):
        par3_total = 1
    else:
        par3_total = CONFIG["C"] * total_hd

    # combine cost functions
    cost = par1_total + par2 + par3_total
    best_cost = best_dc + 0 + best_hd
    highest_cost = 3
    # when some contact points are outside of the hemisphere (convex), add penalty
    # NOTE(review): unreachable at the moment, an invalid convex hull already returned above.
    if not grid_valid_convex_hull:
        cost = par1_total + penalty + par2 + penalty + par3_total + penalty

    # check if cost contains invalid value
    if np.isnan(cost) or np.isinf(cost):
        cost = 3

    # the proportion of the cost compared to the best possible cost
    prop_cost = 1 - ((cost - best_cost) / (highest_cost - best_cost))

    return [grid_valid, arr_dice, total_dice, prop_total_dice, arr_hd, total_hd,
            prop_total_hd, arr_yield, total_yield, cost, prop_cost,
            array_phosphenes, phosphenes_so_far, array_phosphene_map, phosphene_map_so_far,
            array_contacts, total_contacts_so_far]

In [4]:
# Load the run configuration; the keys read in this notebook include N_COMBS,
# N_SHANKS_PER_COMB, N_CONTACTPOINTS_SHANK, N_ARRAYS, WINDOWSIZE and DC_PERCENTILE.
CONFIG = utils.read_config("config.json")
print(f"configuration:\n {CONFIG['N_COMBS']} x {CONFIG['N_SHANKS_PER_COMB']} x {CONFIG['N_CONTACTPOINTS_SHANK']}")
# set file names
FNAME_ANG = "inferred_angle.mgz"
FNAME_ECC = "inferred_eccen.mgz"
FNAME_SIGMA = "inferred_sigma.mgz"
FNAME_APARC = "aparc+aseg.mgz"
FNAME_LABEL = "inferred_varea.mgz"
# set beta angle range according to hemisphere
# NOTE(review): in the visible cells these dimensions are only passed through the
# hemisphere loop and never used, since no optimisation is run here.
dim2_lh = Integer(name="beta", low=-15, high=110)
dim2_rh = Integer(name="beta", low=-110, high=15)
# Root folder that receives the per-subject results (absolute, machine-specific path).
RESULTS_PATH = "/home/odysseas/Desktop/UU/thesis/BayesianOpt/fsaverage_5_arrays_10x10x10/results/"
# Best fsaverage parameters per hemisphere, loaded once when this cell runs;
# the fsaverage results must therefore already exist on disk.
BEST_PARAMS_LH = read_best_params("LH")
BEST_PARAMS_RH = read_best_params("RH")
def implant_from_fsaverage(sub):
    """Implant the fsaverage-optimised arrays in one subject and store the metrics.

    Loads the subject's retinotopy and segmentation volumes, then, for each hemisphere,
    evaluates every array with the fsaverage best parameters via
    ``get_metrics_from_fsaverage_arr_contacts`` and writes the resulting tables, pickle
    and figures below ``RESULTS_PATH``.

    State is handed to ``get_metrics_from_fsaverage_arr_contacts`` through the
    module-level globals declared below; several of them are (re)bound as targets of the
    hemisphere loop.

    Parameters
    ----------
    sub : str
        Subject folder name, used to build both the input and the output paths.

    Returns
    -------
    None
    """
    global gm_mask, target_density, bin_thresh, hem
    global good_coords, good_coords_V1, polar_map, ecc_map, sigma_map
    global best_dc, best_hd
    global phosphenes_per_arr, start_location
    global total_contacts_xyz_moved, optimized_arrays_from_f_manual
    
    target_density = gvs.complete_gauss(windowsize=CONFIG["WINDOWSIZE"],
                                    fwhm=1200, radiusLow=0, radiusHigh=500, center=None, plotting=False)

    # scale by the maximum, then by the sum, so the target map sums to 1
    target_density /= target_density.max()
    target_density /= target_density.sum()
    bin_thresh = np.percentile(target_density, CONFIG["DC_PERCENTILE"])
    data_dir = f"/home/odysseas/Desktop/UU/thesis/BayesianOpt/input_processed_data_HCP/{sub}/T1w/mri/"
    # print(f"Loading mri scans for sub {sub}")

    # actually load data
    ang_img = nib.load(data_dir + FNAME_ANG)
    polar_map = ang_img.get_fdata()
    ecc_img = nib.load(data_dir + FNAME_ECC)
    ecc_map = ecc_img.get_fdata()
    sigma_img = nib.load(data_dir + FNAME_SIGMA)
    sigma_map = sigma_img.get_fdata()
    aparc_img = nib.load(data_dir + FNAME_APARC)
    aparc_roi = aparc_img.get_fdata()
    label_img = nib.load(data_dir + FNAME_LABEL)
    label_map = label_img.get_fdata()

    # compute valid voxels (both eccentricity and polar angle are non-zero)
    dot = (ecc_map * polar_map)
    good_coords = np.asarray(np.where(dot != 0.0))

    # filter gm per hemisphere
    # NOTE(review): in FreeSurfer's default colour table the 1000 range is ctx-lh and the
    # 2000 range ctx-rh, so the rh/lh naming here looks swapped — confirm against the
    # orientation convention used upstream before changing anything.
    # NOTE(review): the "lh" mask uses `> 2000` while the "rh" mask uses `>= 1000`,
    # so label 2000 itself is excluded.
    cs_coords_rh = np.where(aparc_roi == 1021)
    cs_coords_lh = np.where(aparc_roi == 2021)
    gm_coords_rh = np.vstack(np.where((aparc_roi >= 1000) & (aparc_roi < 2000)))
    gm_coords_lh = np.vstack(np.where(aparc_roi > 2000))
    xl, yl, zl = get_xyz(gm_coords_lh)
    xr, yr, zr = get_xyz(gm_coords_rh)
    gm_lh = np.array([xl, yl, zl]).T
    gm_rh = np.array([xr, yr, zr]).T

    # extract labels
    # both start from the same label == 1 mask; they are split per hemisphere below
    # by intersecting with the hemisphere's grey-matter coordinates
    v1_coords_rh = np.asarray(np.where(label_map == 1))
    v1_coords_lh = np.asarray(np.where(label_map == 1))

    # turn the coordinate columns into sets of (x, y, z) tuples for fast intersection
    set_rounded_good_coords = set(map(tuple, good_coords.T))
    set_rounded_gm_coords_rh = set(map(tuple, gm_coords_rh.T))
    set_rounded_gm_coords_lh = set(map(tuple, gm_coords_lh.T))
    set_rounded_v1_coords_lh = set(map(tuple, v1_coords_lh.T))
    set_rounded_v1_coords_rh = set(map(tuple, v1_coords_rh.T))

    # divide V1 coords per hemisphere
    good_coords_lh = np.array(list(set(set_rounded_good_coords) & set(set_rounded_gm_coords_lh))).T
    good_coords_rh = np.array(list(set(set_rounded_good_coords) & set(set_rounded_gm_coords_rh))).T
    v1_coords_lh = np.array(list(set(set_rounded_v1_coords_lh) & set(set_rounded_gm_coords_lh))).T
    v1_coords_rh = np.array(list(set(set_rounded_v1_coords_rh) & set(set_rounded_gm_coords_rh))).T
    
    # find center of left and right calcarine sulci (per-axis median of the label voxels)
    median_lh = [np.median(cs_coords_lh[0][:]), np.median(cs_coords_lh[1][:]), np.median(cs_coords_lh[2][:])]
    median_rh = [np.median(cs_coords_rh[0][:]), np.median(cs_coords_rh[1][:]), np.median(cs_coords_rh[2][:])]

    # get GM mask and compute dorsal/posterior planes
    # NOTE(review): this value is overwritten by the loop target `gm_mask` right below.
    gm_mask = np.where(aparc_roi != 0)

    # apply optimization to each hemisphere
    # The loop targets gm_mask, hem, start_location, good_coords and good_coords_V1 are
    # declared global above, so each iteration publishes the hemisphere's data to
    # get_metrics_from_fsaverage_arr_contacts. `dim2` is not used in the loop body.
    for (gm_mask, hem, start_location, best_params_hem_df, good_coords, good_coords_V1, dim2) in zip([gm_lh, gm_rh], ["LH", "RH"], 
                                                                                                 [median_lh, median_rh], 
                                                                                                 [BEST_PARAMS_LH, BEST_PARAMS_RH],
                                                                                                 [good_coords_lh, good_coords_rh], 
                                                                                                 [v1_coords_lh, v1_coords_rh], 
                                                                                                 [dim2_lh, dim2_rh]):

        print(f"SUBJECT {sub}, HEMISPHERE {hem}")
        utils.create_dirs(RESULTS_PATH, sub, hem)
        # reference case: phosphenes obtained when every V1 coordinate is used as a contact
        best_possible_phos = get_phosphenes(good_coords_V1, good_coords_V1, polar_map, ecc_map, sigma_map)
        best_possible_map = utils.get_phosphene_map(best_possible_phos, CONFIG)

        visualizations.visualize_phosphene_maps({}, best_possible_map, RESULTS_PATH, sub, hem, CONFIG, best=True, show=False, save=True)
        visualizations.visualize_polar_plot(best_possible_phos, RESULTS_PATH, sub, hem, CONFIG, best=True, show=False, save=True)
        visualizations.visualize_kde_polar_plot(best_possible_phos, RESULTS_PATH, sub, hem, CONFIG, best=True, show=False, save=True)

        # reference scores used by the metrics function to express results as proportions
        best_dc, _, _ = DC(target_density, best_possible_map, bin_thresh)
        best_hd = hellinger_distance(best_possible_map.flatten(), target_density.flatten())

        # reset the per-hemisphere accumulators
        total_contacts_xyz_moved = None
        phosphenes_so_far = None
        phosphene_map_so_far = None
        optimized_arrays_from_f_manual = {}
        phosphenes_per_arr = {}
        phosphene_map_per_arr = {}
        out_df_best_results = pd.DataFrame()
        # arr_current counts accepted arrays only, so stored results are numbered contiguously
        arr_current = 1
        for i in range(1, CONFIG["N_ARRAYS"] + 1):
            data = get_metrics_from_fsaverage_arr_contacts(current_array=i, best_params_hem_df=best_params_hem_df)
            grid_valid, arr_dice = data[0], data[1]

            # print(f"The best configuration for array {arr_cur} is {'valid' if grid_valid else 'invalid'}")

            # skip this array; `data` may hold only these two entries in that case
            if not grid_valid or arr_dice == 0.0:
                # print(f"should skip array {i} because it is invalid or phosphene map is empty")
                continue

            # accepted: the cumulative contacts/phosphenes now include this array
            (total_dice, prop_total_dice, arr_hd, total_hd, prop_total_hd, 
             arr_yield, total_yield, cost, prop_cost,
             array_phosphenes, phosphenes_so_far, array_phosphene_map, phosphene_map_so_far, 
             array_contacts, total_contacts_xyz_moved) = data[2:]
            
            phosphenes_per_arr[arr_current] = array_phosphenes
            phosphene_map_per_arr[arr_current] = array_phosphene_map
            optimized_arrays_from_f_manual[arr_current] = array_contacts

            res = best_params_hem_df.iloc[i-1].tolist()   # [alpha, beta, offset, shank_length]

            # visualizations.visualize_array_map(array_phosphene_map, phosphene_map_so_far, hem)
            print("*" * 35)
            print("FINISHED ARRAY", arr_current)
            print("*" * 35)

            df_best = utils.get_best_df_10x10x10(arr_current, arr_dice, total_dice, prop_total_dice, arr_yield, total_yield,
                                  arr_hd, total_hd, prop_total_hd, cost, prop_cost, res)

            out_df_best_results = pd.concat([out_df_best_results, df_best], axis=0, ignore_index=True)
            arr_current += 1
        
        # at least one array was accepted: persist tables, pickle, config and figures
        if len(phosphene_map_per_arr) > 0:
            pickle_data = [optimized_arrays_from_f_manual, phosphenes_per_arr, phosphene_map_per_arr,
                           total_contacts_xyz_moved, phosphenes_so_far, phosphene_map_so_far]
    
            utils.write_results(out_df_best_results, RESULTS_PATH, sub, hem, "best")
            utils.write_results_pickle(RESULTS_PATH, sub, hem, pickle_data)
            utils.write_params(RESULTS_PATH, sub, hem, CONFIG)

            visualizations.visualize_phosphene_maps(phosphene_map_per_arr, phosphene_map_so_far, RESULTS_PATH, sub, hem, CONFIG, show=False, save=True)
            visualizations.visualize_polar_plot(phosphenes_so_far, RESULTS_PATH, sub, hem, CONFIG, show=False, save=True)
            visualizations.visualize_kde_polar_plot(phosphenes_so_far, RESULTS_PATH, sub, hem, CONFIG, show=False, save=True)
        else:
            # nothing accepted for this hemisphere: write a placeholder table instead
            print(f"NO RESULTS FOR SUB: {sub} and hem: {hem}")
            empty_df = utils.get_empty_df()
            utils.write_results(empty_df, RESULTS_PATH, sub, hem, "empty")

In [5]:
# Every entry of the input folder is treated as a subject.
subj_list = os.listdir("/home/odysseas/Desktop/UU/thesis/BayesianOpt/input_processed_data_HCP/")
# Entries that already exist in the results folder are skipped, as is the name "exp".
processed = os.listdir(RESULTS_PATH) + ["exp"]
subj_list = [sub for sub in subj_list if sub not in processed]
# NOTE(review): os.listdir returns entries in arbitrary order, and `num` starts at 0.
for num, sub in enumerate(subj_list):
    print(f"now at {num} of {len(subj_list)}")
    implant_from_fsaverage(sub)

In [4]:
from src.utils import check_all_files

# Result folders of the different experiments; only `arr_avg` is checked below,
# `arr_5` and `arr_16` are not used in the visible cells.
arr_5 = "/home/odysseas/Desktop/UU/thesis/BayesianOpt/5_arrays_10x10x10/results/"
arr_16 = "/home/odysseas/Desktop/UU/thesis/BayesianOpt/16_arrays_1x10x10/results/"
arr_avg = "/home/odysseas/Desktop/UU/thesis/BayesianOpt/fsaverage_5_arrays_10x10x10/results/"
# Inspect the files written for the fsaverage-parameter experiment.
all_files, file_sizes, empty_files = check_all_files(arr_avg)

In [4]:
# Display how many entries check_all_files returned in `all_files`.
len(all_files)