# Table 1: Test Retest stuff

## Manuscript Information
 
"Active Mutual Conjoint Estimation of Multiple Contrast Sensitivity
Functions"
Dom CP Marticorena, Quinn Wai Wong, Jake Browning, Ken Wilbur, Pinakin Davey, Aaron R. Seitz, Jacob R. Gardner, Dennis L. Barbour
_Journal of Vision_

[link to paper or preprint]

## Lab and Institution Information

NeuroMedical Informatics Lab  
Washington University in St. Louis

## Figure Description

Test retest stuff

## References

[references]


## Imports

In [1]:
# Import libraries
import os
import sys
import torch
import gpytorch as gp

import math
import numpy as np

import datetime
import json
import pickle
import tqdm

import matplotlib.pyplot as plt
import matplotlib.lines as mlines
from matplotlib.gridspec import GridSpec
import seaborn as sns

# need access to root directory to import utils
parent_dir = os.path.dirname(os.path.abspath(''))
if parent_dir not in sys.path:
    sys.path.append(parent_dir)
    
from utility.utils import *
from QuickCSF import QuickCSF, simulate

### Check versions

In [2]:
print("python version -->>", sys.version.split(" ")[0], "(expected 3.10.9)")
print("gpytorch version -->>", gp.__version__, "(expected 1.8.1)")
print("pytorch version -->>", torch.__version__, "(expected 1.13.1)")

python version -->> 3.10.9 (expected 3.10.9)
gpytorch version -->> 1.8.1 (expected 1.8.1)
pytorch version -->> 1.13.1 (expected 1.13.1)


### Run-time flags

In [3]:
qcsf_train_mode = False
train_mode = False          # create new data? Set to false if plotting existing results
verbose_mode = True         # print verbose analyses?
scrn_mode = True            # plot on screen?
save_results_mode = True    # save results to file?
save_plots_mode = True      # save plots to directory?

### Variables

In [4]:
# get data selection
jigo_file = '../data/raw/representative_curves.json'
representative_curves = load_json_from_file(jigo_file)
all_phenotypes = {
    'Quantile 1': representative_curves['quantile1-pid2-cue1-ecc3'],
    'Quantile 2': representative_curves['quantile2-pid1-cue1-ecc0'],
    'Quantile 3': representative_curves['quantile3-pid2-cue0-ecc2'],
    'Quantile 4': representative_curves['quantile4-pid3-cue0-ecc1'],
    'Quantile 5': representative_curves['quantile5-pid7-cue1-ecc3'],
    'Quantile 6': representative_curves['quantile6-pid0-cue0-ecc0'],
    'Quantile 7': representative_curves['quantile7-pid0-cue1-ecc0'],
    'Quantile 8': representative_curves['quantile8-pid1-cue1-ecc3'],
    'Quantile 9': representative_curves['quantile9-pid5-cue1-ecc0'],
    'Quantile 10': representative_curves['quantile10-pid8-cue0-ecc3'],
    'Quantile 11': representative_curves['quantile11-pid5-cue0-ecc1'],
    'Quantile 12': representative_curves['quantile12-pid8-cue1-ecc2'],
    'Quantile 13': representative_curves['quantile13-pid4-cue0-ecc1'],
    'Quantile 14': representative_curves['quantile14-pid7-cue0-ecc1'],
    'Quantile 15': representative_curves['quantile15-pid3-cue1-ecc2'],
    'Quantile 16': representative_curves['quantile16-pid0-cue0-ecc1'],
    'Quantile 17': representative_curves['quantile17-pid8-cue1-ecc0'],
    'Quantile 18': representative_curves['quantile18-pid8-cue0-ecc1'],
    'Quantile 19': representative_curves['quantile19-pid0-cue1-ecc1'],
    'Quantile 20': representative_curves['quantile20-pid4-cue0-ecc2']
}

def create_quantile_variables(num_quantiles):
    """
    Create a list of quantile names based on the number of quantiles.
    :param num_quantiles: Total number of quantiles.
    :return: List of quantile names.
    """
    return [f'Quantile {i+1}' for i in range(num_quantiles)]

# tasks and number of latents
num_latents = 2

# fixed unless something nutso happens 
num_tasks = num_latents

sampling_method = 'alternating'     # 'alternating' or 'unconstrained'
weight_decay = 1e-4
num_quantiles = 20

# Create "enums"
create_quantile_variables(num_quantiles)

# Configure for all unique pairs
num_pairs = comb(num_quantiles, 2) + num_quantiles

# choose to run multiple experiments with preset random seeds
# or a single experiment specifying your own random seeds
run_multiple_experiments = True

# list of zeros as long as num_latents
primer_random_seeds = [0, 0]
gp_random_seed = 0

# number of samples
num_halton_samples_per_task = 2
num_new_pts_per_task = 98

# choosing which figures to make
make_gp_gifs = True
make_entropy_gifs = False
make_hyper_plots = True

# directory to save plots and results
# will save to the path <save_dir_prefix>/<current_timestamp>
save_dir_prefix = 'analysis/Tables'

# Individual print flags 
print_training_hyperparameters = False
print_training_iters = False
print_progress_bar = False

# Set all to true if verbose_mode
if verbose_mode:
    print_training_hyperparameters = False
    print_training_iters = True
    print_progress_bar = False
    
# Create the bounds for the data
raw_freq_min = .5
raw_freq_max = 32
raw_contrast_min = 1e-3
raw_contrast_max = 1

x_tick_labels = [.5, 2, 8, 32]
y_tick_labels = [1, 0.1, 0.01, 0.001]

# Define how to transform the data
x_min = logFreq().forward(raw_freq_min)
x_max = logFreq().forward(raw_freq_max)
y_min = logContrast().forward(raw_contrast_max)  # max and min get flipped when inverting
y_max = logContrast().forward(raw_contrast_min)

# transform the data
def normalize_to_unit_range(d):
    return scale_data_within_range(d, (0, 1), x_min, x_max, y_min, y_max)

# marginal log resolutions of evaulation grid
x_resolution = 15  # 15 spatial frequencies per octave
y_resolution = 30  # 30 contrast units per decade

# for computing the proper prior threshold curve
psi_gamma  = 0.04  # guess rate is 4%
psi_lambda = 0.04 # lapse rate is 4%
psi_sigma = 0.08
sigmoid_type = 'logistic'

# training parameters?
num_initial_points_training_iters = 500
num_new_points_training_iters = 150
train_on_all_points_iters = 1500
sampling_strategy = 'active'
mean_module = 'constant_mean'
train_on_all_points_after_sampling = False
calculate_rmse = True
calculate_posterior = True
calculate_entropy = True

# GP hyperparameters?
learning_rate = .125
beta_for_regularization = .5
min_lengthscale = .15 # Note this changed from .2

# Set raw ghost points
raw_ghost_frequency = np.array([1, 2, 4, 8, 16, 32, 64, 128])
raw_ghost_contrast = np.array([5e-4, 5e-4, 5e-4, 5e-4, 5e-4, 5e-4, 5e-4, 1])

### Set Timestamp to Save To

In [5]:
# get current timestamp
timestamp = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
save_dir = f"{save_dir_prefix}/{timestamp}/"
ensure_directory_exists(save_dir)
print(f"Saving to... {save_dir}")

Saving to... analysis/Tables/2024-05-15_10-06-49/


### Automatically calculated variables

In [6]:
# create common variables shared among all tasks
grid, xx, yy, xs, ys = create_evaluation_grid_resolution(
        x_min, x_max, y_min, y_max, x_resolution, y_resolution)

grid_transformed = transform_dataset(grid, phi=normalize_to_unit_range)
def create_cubic_spline(curve):
    """
  This method creates a cubic spline to approximate the given curve.
  :param curve: A nx2 numpy matrix. First column is x values. Second column is y values.
  :return: The cubic spline.
  """
    x = curve[:, 0]
    y = curve[:, 1]

    cs = CubicSpline(x, y, extrapolate=False)
    return cs

# function to get ground truth curves
def get_spline(curve):
    """
    Modifies the curve by fitting a parabola through the maximum and the last points, then creates a cubic spline.
    The parabola's x-intercept is added to the curve if it is to the right of the last point.
    If the x-intercept is not to the right, the curve is not modified.
    :param curve: A list of [x, y] points.
    :return: The cubic spline.
    """
    curve = np.array(curve)
    curve[:, 0] = (np.log2(10) * curve[:, 0]) - np.log2(0.125)

    # Find the max y value and its corresponding x value
    max_y = np.max(curve[:, 1])
    max_y_index = np.argmax(curve[:, 1])
    x_max_y = curve[max_y_index, 0]

    # Use the last real point of the curve
    x_last, y_last = curve[-1, 0], curve[-1, 1]

    # Fit the parabola through the max and last point
    a = (y_last - max_y) / ((x_last - x_max_y)**2)
    coeffs = [a, -2*a*x_max_y, a*x_max_y**2 + max_y]
    roots = np.roots(coeffs)
    x_intercept = roots[0]  # Choose the first root

    # Add the intercept point to the curve only if it's to the right of the last point
    if x_intercept > x_last:
        curve = np.vstack([curve, [x_intercept, 0]])
        curve = curve[curve[:, 0].argsort()]  # Sorting by x values

    # Create and return the cubic spline
    cs = create_cubic_spline(curve)
    return cs

# get number of actively learned points for conjoint
num_pts_per_task = num_new_pts_per_task + num_halton_samples_per_task
num_new_conjoint_pts = num_new_pts_per_task * num_tasks
num_conjoint_pts = num_pts_per_task * num_tasks

# create ghost points and labels
ghost_x1 = logFreq().forward(raw_ghost_frequency)
ghost_x2 = logContrast().forward(raw_ghost_contrast)
assert len(ghost_x1) == len(ghost_x2), "x1 and x2 have diff lengths"

ghost_X = np.vstack((ghost_x1, ghost_x2)).T
ghost_y = np.array([0]*len(ghost_x2))

# create disjoint initial primer points
halton_X = get_halton_samples(xx, yy, num_halton_samples_per_task)
initial_disjoint_X = np.vstack((ghost_X, halton_X))

# create conjoint initial primer points and task indices
# simulated Halton y labels are created later for each experiment
halton_Xs = np.repeat(halton_X, num_tasks, axis=0)

num_ghost_points_per_task = len(ghost_y)
halton_task_indices = torch.arange(num_tasks).tile((num_halton_samples_per_task))

ghost_Xs = np.tile(ghost_X, (num_tasks, 1))
ghost_ys = np.tile(ghost_y, num_tasks)
ghost_task_indices = torch.arange(num_tasks).repeat_interleave(num_ghost_points_per_task)

num_disjoint_ghost_points = len(ghost_x2)
num_conjoint_ghost_points = len(ghost_task_indices)

# create initial dataset using ghost and halton samples
initial_Xs = np.vstack((ghost_Xs, halton_Xs))
initial_task_indices = torch.cat((ghost_task_indices, halton_task_indices))

# unique random seeds for halton samples and training
primer_seeds_list = [(np.arange(num_tasks) + num_tasks*i).tolist() for i in range(num_pairs)]
gp_seed_list = [(np.arange(num_tasks) + num_tasks*i).tolist() for i in range(num_pairs)]

## Model Training

In [7]:
figure_dict = {
    "num_latents": num_latents,
    "sampling_method": sampling_method,
    "num_pairs": num_pairs,
    "num_halton_samples_per_task": num_halton_samples_per_task,
    "num_new_pts_per_task": num_new_pts_per_task,
    "raw_freq_min": raw_freq_min,
    "raw_freq_max": raw_freq_max,
    "raw_contrast_min": raw_contrast_min,
    "raw_contrast_max": raw_contrast_max,
    "x_resolution": x_resolution,
    "y_resolution": y_resolution,
    "raw_ghost_frequency": raw_ghost_frequency.tolist(),
    "raw_ghost_contrast": raw_ghost_contrast.tolist(),
    "primer_seeds_list": primer_seeds_list,
    "gp_seed_list": gp_seed_list,
    "weight_decay": weight_decay,
    "min_lengthscale": min_lengthscale,
    "psi_sigma": psi_sigma,
    "sigmoid_type": sigmoid_type,
    "psi_gamma": psi_gamma,
    "psi_lambda": psi_lambda,
    "lr": learning_rate,
    "num_initial_training_iters": num_initial_points_training_iters,
    "num_new_points_training_iters": num_new_points_training_iters,
    "beta_for_regularization": beta_for_regularization,
    "train_on_all_points_after_sampling": train_on_all_points_after_sampling,
    "print_training_hyperparameters": print_training_hyperparameters,
    "print_training_iters": print_training_iters,
    "progress_bar": print_progress_bar,
    "calculate_rmse": calculate_rmse,
    "calculate_entropy": calculate_entropy,
    "calculate_posterior": calculate_posterior
}

if train_mode:
    ensure_directory_exists(save_dir)
    with open(save_dir + "run_configs.json", 'w') as file:
        json.dump(figure_dict, file, indent=2)

### Disjoint

In [8]:
if train_mode:
    for run in range(2):
        print(f"Run: {run}")
        disjoint_results_dicts = [{} for _ in range(num_tasks)]
        disjoint_gp_lists = []

        all_phenotypes_keys = list(all_phenotypes.keys()) # TODO  # Getting all the keys from your dictionary

        # Using itertools to get all unique combinations of length 2 (pairs)
        phenotype_pairs = combinations_with_replacement(all_phenotypes_keys, 2)

        # initialize pair index to iterate using actual numbers
        pair_index = 0

        # Now iterate over the first 'num_pairs' pairs
        for pair in phenotype_pairs:

            phenotype_pair = [(pheno, all_phenotypes[pheno]) for pheno in pair]

            print(f"Pair {pair}")

            # get unique random seeds for each exp
            run_seed_modifier = 8888*run
            primer_seeds = [s + run_seed_modifier for s in primer_seeds_list[pair_index]]
            gp_seeds = [s + run_seed_modifier for s in gp_seed_list[pair_index]]

            for i, (pheno, _) in enumerate(phenotype_pair):

                print(pheno)

                ground_truths = [get_spline(pheno_data) for _, pheno_data in phenotype_pair]

                cs = ground_truths[i]
                primer_seed = primer_seeds[i]
                gp_seed = gp_seeds[i]

                # get initial primer labels
                set_random_seed(primer_seed)
                halton_y = simulate_labeling(halton_X[:,0], halton_X[:,1], cs, 0, 0, sigmoid_type=sigmoid_type, psi_sigma=psi_sigma) # changed  psi_gamma, psi_lambda to 0, 0
                initial_y = np.hstack((ghost_y, halton_y))

                # run active learning
                set_random_seed(gp_seed)

                model, likelihood, X, y, rmse_list, _, posterior_list, _ = sample_and_train_gp(
                    cs,
                    grid,
                    xx,
                    yy,
                    sampling_strategy=sampling_strategy,
                    mean_module_name=mean_module,
                    psi_sigma=psi_sigma,
                    sigmoid_type=sigmoid_type,
                    psi_gamma=psi_gamma,
                    psi_lambda=psi_lambda,
                    lr=learning_rate,
                    num_initial_training_iters=num_initial_points_training_iters,
                    num_new_points_training_iters=num_new_points_training_iters,
                    num_new_points=num_new_pts_per_task,
                    beta_for_regularization=beta_for_regularization,
                    train_on_all_points_after_sampling=train_on_all_points_after_sampling,
                    train_on_all_points_iters=train_on_all_points_iters,
                    phi=normalize_to_unit_range,
                    print_training_hyperparameters=print_training_hyperparameters,
                    print_training_iters=print_training_iters,
                    progress_bar=print_progress_bar,
                    min_lengthscale=min_lengthscale,
                    calculate_rmse=calculate_rmse,
                    calculate_entropy=calculate_entropy,
                    calculate_posterior=calculate_posterior,
                    initial_Xs=initial_disjoint_X,
                    initial_ys=initial_y,
                    num_ghost_points=num_disjoint_ghost_points,
                    weight_decay=weight_decay
                )


                zz = evaluate_posterior_mean(model, likelihood, grid_transformed) \
                        .reshape(xx.shape)
                
                level = (1 - psi_lambda + psi_gamma) / 2
                GP_level_curve_y_values_on_grid_cols = []
                for intermediate_model, intermediate_likelihood in posterior_list:
                    intermediate_zz = evaluate_posterior_mean(
                        intermediate_model, 
                        intermediate_likelihood, 
                        grid_transformed
                    ).reshape(xx.shape)
                    
                    zzmin = (intermediate_zz[:, :] - level) ** 2
                    level_curve_indices = np.int64(np.argmin(zzmin, axis=0))
                    level_curve = yy[:, 0][level_curve_indices[:]]
                    GP_level_curve_y_values_on_grid_cols.append(level_curve)
                    
                ground_truth_y_values_on_grid_cols = cs(xx[0, :])

                disjoint_results_dicts[i][pair_index] = {
                    'training_seed': gp_seed,
                    'random_seed': primer_seed,
                    'X': X,
                    'y': y,
                    'zz': zz,
                    'rmse_list': rmse_list,
                    'GP_level_curve_y_values_on_grid_cols': GP_level_curve_y_values_on_grid_cols,
                    'ground_truth_y_values_on_grid_cols': ground_truth_y_values_on_grid_cols
                }

                gif_dict = {
                    'xx': xx,
                    'yy': yy,
                    'X': X,
                    'y': y,
                    'cs': cs,
                    'psi_sigma': psi_sigma,
                    'sigmoid_type': sigmoid_type,
                    'psi_gamma': psi_gamma,
                    'psi_lambda': psi_lambda,
                    'x_min': x_min,
                    'x_max': x_max,
                    'y_min': y_min,
                    'y_max': y_max,
                    'xs': xs,
                    'ys': ys,
                    'grid': grid,
                    'f': normalize_to_unit_range,
                    'posterior_list': posterior_list,
                }

                # create gif of the task
                if make_gp_gifs and save_plots_mode:
                    ntitle = pheno.replace(' ', '_')
                    gif_path = f'{save_dir}run_{run}/task{i}_{ntitle.lower()}/{pair_index}/'
                    ensure_directory_exists(gif_path)
                    create_and_save_plots(gif_dict, gif_path, ntitle, start_index=num_disjoint_ghost_points,
                                xticks_labels=x_tick_labels, yticks_labels=y_tick_labels)
                    create_gif(gif_path)

                gp_list = [model for model, _ in posterior_list]
                disjoint_gp_lists.append(gp_list)
                print()

            # Increment the index for the next iteration
            pair_index += 1
            print(pair_index)

            if save_results_mode:
                full_save_dir = f'{save_dir}run_{run}'
                ensure_directory_exists(full_save_dir)
                with open(f'{full_save_dir}/disjoint_results.pkl', 'wb') as file:
                    pickle.dump(disjoint_results_dicts, file)


In [9]:
# a = disjoint_results_dicts
# # a[task index][pair index][dict key]
# print(a[0][0].keys())
# # ['GP_level_curve_y_values_on_grid_cols'][num points collected][x value]
# print(len(a[0][0]['GP_level_curve_y_values_on_grid_cols']))
# print(a[0][0]['GP_level_curve_y_values_on_grid_cols'][7])
# print(a[0][0]['ground_truth_y_values_on_grid_cols'])

In [10]:
# for i in range(11):
    
#     yvals = a[0][0]['GP_level_curve_y_values_on_grid_cols'][i]
#     GT = a[0][0]['ground_truth_y_values_on_grid_cols']
    
#     # Generate x values based on the index of yvals
#     x = np.arange(len(yvals))

#     # Plot the curve
#     plt.figure(figsize=(8, 6))  # Set the figure size
#     plt.plot(x, yvals, label='yvals')  # Plot x vs. yvals with a label for the legend
#     plt.plot(x, GT, label='GT')
#     plt.xlabel('Index')  # Set the x-axis label
#     plt.ylabel('y-values')  # Set the y-axis label
#     plt.title('Plot of y-values vs. Index')  # Set the title of the plot
#     plt.legend()  # Show the legend
#     plt.grid(True)  # Show a grid

#     # Customize the tick labels for x and y axes
#     plt.yticks([0, 1, 2, 3])

#     plt.show()  # Display the plot

### Conjoint

In [11]:
if train_mode:
    for run in range(2):
        print(f"Run: {run}")
        conjoint_results_dicts = {}
        conjoint_gp_lists = []

        all_phenotypes_keys = list(all_phenotypes.keys()) # TODO # Getting all the keys from your dictionary

        # Using itertools to get all unique combinations of length 2 (pairs)
        phenotype_pairs = combinations_with_replacement(all_phenotypes_keys, 2)

        # initialize pair index to iterate using actual numbers
        pair_index = 0

        # Now iterate over the first 'num_pairs' pairs
        for pair in phenotype_pairs:

            phenotype_pair = [(pheno, all_phenotypes[pheno]) for pheno in pair]

            print(f"Pair {pair}")

            # get unique random seeds for each exp
            run_seed_modifier = 8888*run
            primer_seeds = [s + run_seed_modifier for s in primer_seeds_list[pair_index]]
            gp_seeds = [s + run_seed_modifier for s in gp_seed_list[pair_index]]
            gp_seed = gp_seeds[0]

            # get unique halton labels for each experiment
            halton_y_list = []

            for i, _ in enumerate(phenotype_pair):
                ground_truths = [get_spline(pheno_data) for _, pheno_data in phenotype_pair]
                cs = ground_truths[i]
                set_random_seed(primer_seeds[i])
                halton_y = simulate_labeling(halton_X[:,0], halton_X[:,1], cs, 0, 0, sigmoid_type=sigmoid_type, psi_sigma=psi_sigma)  # changed  psi_gamma, psi_lambda to 0, 0
                halton_y_list.append(halton_y)

            # stack all task halton labels then stack ghost with halton labels
            halton_ys = np.array([y for y_per_halton in zip(*halton_y_list) for y in y_per_halton])
            initial_ys = np.hstack((ghost_ys, halton_ys))

            # run active learning
            set_random_seed(gp_seed)

            model, likelihood, X, y, task_indices, rmse_list, entropy_list, posterior_list, _ = sample_and_train_gp_conjoint(
                cs=ground_truths,
                grid=grid,
                xx=xx,
                yy=yy,
                psi_sigma=psi_sigma,
                psi_gamma=psi_gamma,
                psi_lambda=psi_lambda,
                lr=learning_rate,
                num_initial_training_iters=num_initial_points_training_iters,
                num_new_points_training_iters=num_new_points_training_iters,
                num_new_points=num_new_conjoint_pts,
                beta_for_regularization=beta_for_regularization,
                phi=normalize_to_unit_range,
                print_training_hyperparameters=print_training_hyperparameters,
                print_training_iters=print_training_iters,
                train_on_all_points_after_sampling=train_on_all_points_after_sampling,
                train_on_all_points_iters=train_on_all_points_iters,
                min_lengthscale=min_lengthscale,
                initial_Xs=initial_Xs,
                initial_ys=initial_ys,
                sampling_strategy=sampling_strategy,
                num_ghost_points=num_conjoint_ghost_points,
                calculate_rmse=calculate_rmse,
                calculate_entropy=calculate_entropy,
                calculate_posterior=calculate_posterior,
                progress_bar=print_progress_bar,
                num_tasks=num_tasks,
                num_latents=num_latents,
                task_indices=initial_task_indices,
                sampling_method=sampling_method,
                weight_decay=weight_decay
            )

            zz = evaluate_posterior_mean(model, likelihood, grid_transformed) \
                .reshape((*xx.shape, num_tasks))
            
            level = (1 - psi_lambda + psi_gamma) / 2
            GP_level_curve_y_values_on_grid_cols = [list(), list()]
            for intermediate_model, intermediate_likelihood in posterior_list:
                intermediate_zz = evaluate_posterior_mean(
                    intermediate_model, 
                    intermediate_likelihood, 
                    grid_transformed
                ).reshape((*xx.shape, num_tasks))
                
                for task_index in range(num_tasks):
                    zzmin = (intermediate_zz[:,:,task_index] - level) ** 2
                    level_curve_indices = np.int64(np.argmin(zzmin, axis=0))
                    level_curve = yy[:, 0][level_curve_indices[:]]
                    GP_level_curve_y_values_on_grid_cols[task_index].append(level_curve)
                                
            ground_truth_y_values_on_grid_cols = []
            for task_index in range(num_tasks):
                ground_truth_y_values_on_grid_cols.append(ground_truths[task_index](xx[0, :]))

            conjoint_results_dicts[pair_index] = {
                'training_seed': gp_seed,
                'random_seeds': primer_seeds,
                'X': X,
                'y': y,
                'zz': zz,
                'task_indices': task_indices,
                'entropy_list': entropy_list,
                'rmse_list': rmse_list,
                'GP_level_curve_y_values_on_grid_cols': GP_level_curve_y_values_on_grid_cols,
                'ground_truth_y_values_on_grid_cols': ground_truth_y_values_on_grid_cols
            }

            gp_list = [model for model, _ in posterior_list]
            conjoint_gp_lists.append(gp_list)
            print()

            # Increment the index for the next iteration
            pair_index += 1
            print(pair_index)

            if save_results_mode:
                full_save_dir = f'{save_dir}run_{run}'
                ensure_directory_exists(save_dir)
                with open(f'{full_save_dir}/conjoint_results.pkl', 'wb') as file:
                    pickle.dump(conjoint_results_dicts, file)


In [12]:
# a = conjoint_results_dicts
# # a[pair index][dict key][task index]
# print(a[0].keys())
# # ['GP_level_curve_y_values_on_grid_cols'][num points sampled][x value]
# print(len(a[0]['GP_level_curve_y_values_on_grid_cols'][0]))
# print(a[0]['GP_level_curve_y_values_on_grid_cols'][0][7])
# print(a[0]['ground_truth_y_values_on_grid_cols'][0])

In [13]:
# for i in range(21):
    
#     yvals = a[0]['GP_level_curve_y_values_on_grid_cols'][0][i]
#     GT = a[0]['ground_truth_y_values_on_grid_cols'][0]
#     # Generate x values based on the index of yvals
#     x = np.arange(len(yvals))

#     # Plot the curve
#     plt.figure(figsize=(8, 6))  # Set the figure size
#     plt.plot(x, yvals, label='yvals')  # Plot x vs. yvals with a label for the legend
#     plt.plot(x, GT, label='GT')
#     plt.xlabel('Index')  # Set the x-axis label
#     plt.ylabel('y-values')  # Set the y-axis label
#     plt.title('Plot of y-values vs. Index')  # Set the title of the plot
#     plt.legend()  # Show the legend
#     plt.grid(True)  # Show a grid

#     # Customize the tick labels for x and y axes
#     plt.yticks([0, 1, 2, 3])

#     plt.show()  # Display the plot

### qCSF

In [14]:
if qcsf_train_mode:
    for run in range(2):
        print(f"Run: {run}")
        qcsf_results_dicts = [{} for _ in range(num_tasks)]
        qcsf_gp_lists = []

        all_phenotypes_keys = list(all_phenotypes.keys()) # TODO  # Getting all the keys from your dictionary

        # Using itertools to get all unique combinations of length 2 (pairs)
        phenotype_pairs = combinations_with_replacement(all_phenotypes_keys, 2)

        # initialize pair index to iterate using actual numbers
        pair_index = 0

        # Now iterate over the first 'num_pairs' pairs
        for pair in phenotype_pairs:

            phenotype_pair = [(pheno, all_phenotypes[pheno]) for pheno in pair]

            print(f"Pair {pair}")

            # get unique random seeds for each exp
            run_seed_modifier = 8888*run
            primer_seeds = [s + run_seed_modifier for s in primer_seeds_list[pair_index]]

            for i, (pheno, _) in enumerate(phenotype_pair):

                print(pheno)

                ground_truths = [get_spline(pheno_data) for _, pheno_data in phenotype_pair]

                cs = ground_truths[i]
                primer_seed = primer_seeds[i]

                # get initial primer labels
                set_random_seed(primer_seed)

                simulationParams = {
                    'trials': 100,
                    'stimuli': {
                        'minContrast': raw_contrast_min,
                        'maxContrast': raw_contrast_max,
                        'contrastResolution': xs,
                        'minFrequency': raw_freq_min,
                        'maxFrequency': raw_freq_max,
                        'frequencyResolution': ys,
                    },
                    'parameters': None,
                    'd': 1,  # guess rate used by qcsf model           
                    'psiGamma': 0,
                    'psiLambda': 0,
                    'psiSigma': psi_sigma,
                    "sigmoidType": sigmoid_type,
                    'showPlots': False  # prevents plt from plotting every run 
                }

                rmses, times, params, predictions = simulate.runSimulation(
                    trueThresholdCurve=np.array(phenotype_pair[i][1]),
                    **simulationParams,
                    return_intermediate_predictions=True
                )        
                
                qcsf_y_values_on_grid_cols = [[0 for _ in range(predictions[0].shape[0])]] + predictions
                    
                ground_truth_y_values_on_grid_cols = cs(xx[0, :])

                qcsf_results_dicts[i][pair_index] = {
                    'random_seed': primer_seed,
                    'rmse_list': rmses,
                    'GP_level_curve_y_values_on_grid_cols': qcsf_y_values_on_grid_cols,
                    'ground_truth_y_values_on_grid_cols': ground_truth_y_values_on_grid_cols
                }

            # Increment the index for the next iteration
            pair_index += 1
            print(pair_index)

            if save_results_mode:
                full_save_dir = f'{save_dir}run_{run}'
                ensure_directory_exists(full_save_dir)
                with open(f'{full_save_dir}/qcsf_results.pkl', 'wb') as file:
                    pickle.dump(qcsf_results_dicts, file)


In [15]:
# if qcsf_train_mode:
    
#     qcsf_results = dict() # index like this: qcsf_results['Quantile 1']['rmses'], this will return a list of rmse values

#     all_phenotypes_keys = list(all_phenotypes.keys())  # Getting all the keys from your dictionary

    
#     for i in range (21): # each pheno appears 21 times in 210 pairs
#         set_random_seed(i)
#         for pheno in all_phenotypes_keys:

#             k = pheno + ' - ' + str(i)
#             pheno_data = np.array(all_phenotypes[pheno])

#             print(k)

#             simulationParams = {
#                 'trials': 100,
#                 'stimuli': {
#                     'minContrast': raw_contrast_min,
#                     'maxContrast': raw_contrast_max,
#                     'contrastResolution': xs,
#                     'minFrequency': raw_freq_min,
#                     'maxFrequency': raw_freq_max,
#                     'frequencyResolution': ys,
#                 },
#                 'parameters': None,
#                 'd': 1,  # guess rate used by qcsf model           
#                 'psiGamma': 0,
#                 'psiLambda': 0,
#                 'psiSigma': psi_sigma,
#                 "sigmoidType": sigmoid_type,
#                 'showPlots': False  # prevents plt from plotting every run 
#             }

#             rmses, times, params = simulate.runSimulation(
#                 trueThresholdCurve=pheno_data,
#                 **simulationParams
#             )        

#             qcsf_results[k] = dict()
#             qcsf_results[k]['params'] = params
#             qcsf_results[k]['rmses'] = rmses


#             # save to disk
#             ensure_directory_exists(save_dir)
#             with open(save_dir + "qcsf_results.json", 'w') as file:
#                 json.dump(qcsf_results, file, indent=2)


## Load data

In [19]:
if train_mode:
    load_dir = save_dir
else:
    load_dir = 'C:/Repos/delete_me/Tables/data/'
#     load_dir = './analysis/Figure06/2024-04-29_23-22-15/'

if qcsf_train_mode:
    qcsf_load_dir = save_dir
else:
    qcsf_load_dir = 'C:/Repos/delete_me/Tables/QCSF/'

disjoint_runs = []
conjoint_runs = []
qcsf_runs = []
for run in range(2):
    
    with open(f'{load_dir}Run_{run}/disjoint_results.pkl', 'rb') as file:
        disjoint_runs.append(pickle.load(file))

    with open(f'{load_dir}Run_{run}/conjoint_results.pkl', 'rb') as file:
        conjoint_runs.append(pickle.load(file))
        
    with open(f'{qcsf_load_dir}run_{run}/qcsf_results.pkl', 'rb') as file:
        qcsf_runs.append(pickle.load(file))
    
    
    
    
# # load qcsf

# if qcsf_train_mode:
#     qcsf_load_dir = save_dir + "qcsf_results.json"
# else:
#     qcsf_load_dir = 'C:/Repos/delete_me/Figure06/2024-04-15_04-23-55/qcsf_results.json'

# with open(qcsf_load_dir, 'r') as file:
#     qcsf_results = json.load(file)


In [20]:

# # Sample data
# var = qcsf_runs[0][0][0]['GP_level_curve_y_values_on_grid_cols']

# import matplotlib.pyplot as plt

# # Sample data
# var = qcsf_runs[0][0][0]['GP_level_curve_y_values_on_grid_cols']
# # Plot each list in var on a separate plot
# for i, sublist in enumerate(var):
#     plt.figure()  # Create a new figure for each plot
#     plt.plot(sublist)
#     plt.xlabel('Index')
#     plt.ylabel('Value')
#     plt.title(f'Plot of List {i+1}')

# # Show all plots
# plt.show()



In [21]:
print('qcsf:')
print('qcsf_runs[run_idx][task_idx][pair_key][GP_level_curve_y_values_on_grid_cols][# sampled points (0-100)][grid_col_idx] = y-value')
print('qcsf_runs[run_idx][task_idx][pair_key][ground_truth_y_values_on_grid_cols][grid_col_idx] = y-value')
print()
print(f'Type(qcsf_runs)={type(qcsf_runs)}, Num runs: len(disjoint_runs)={len(qcsf_runs)}')
a = qcsf_runs[0]
print('a = qcsf_runs[0]')
print(f'Type(a)={type(a)}, Num tasks: len(a)={len(a)}')
print(f'Type(a[0])={type(a[0])}, Num pairs: len(list(a[0].keys()))={len(list(a[0].keys()))}')
print()
print(f'a[0].keys()={a[0].keys()}')
print()
print(f'a[0][0].keys()={a[0][0].keys()}')
print()
print('Num samples points for this task/pair:', len(a[0][0]['GP_level_curve_y_values_on_grid_cols']))

qcsf:
qcsf_runs[run_idx][task_idx][pair_key][GP_level_curve_y_values_on_grid_cols][# sampled points (0-100)][grid_col_idx] = y-value
qcsf_runs[run_idx][task_idx][pair_key][ground_truth_y_values_on_grid_cols][grid_col_idx] = y-value

Type(qcsf_runs)=<class 'list'>, Num runs: len(disjoint_runs)=2
a = qcsf_runs[0]
Type(a)=<class 'list'>, Num tasks: len(a)=2
Type(a[0])=<class 'dict'>, Num pairs: len(list(a[0].keys()))=210

a[0].keys()=dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131, 132,

In [22]:
print('Disjoint:')
print('disjoint_runs[run_idx][task_idx][pair_key][GP_level_curve_y_values_on_grid_cols][# sampled points (0-100)][grid_col_idx] = y-value')
print('disjoint_runs[run_idx][task_idx][pair_key][ground_truth_y_values_on_grid_cols][grid_col_idx] = y-value')
print()
print(f'Type(disjoint_runs)={type(disjoint_runs)}, Num runs: len(disjoint_runs)={len(disjoint_runs)}')
a = disjoint_runs[0]
print('a = disjoint_runs[0]')
print(f'Type(a)={type(a)}, Num tasks: len(a)={len(a)}')
print(f'Type(a[0])={type(a[0])}, Num pairs: len(list(a[0].keys()))={len(list(a[0].keys()))}')
print()
print(f'a[0].keys()={a[0].keys()}')
print()
print(f'a[0][0].keys()={a[0][0].keys()}')
print()
print('Num samples points for this task/pair:', len(a[0][0]['GP_level_curve_y_values_on_grid_cols']))

Disjoint:
disjoint_runs[run_idx][task_idx][pair_key][GP_level_curve_y_values_on_grid_cols][# sampled points (0-100)][grid_col_idx] = y-value
disjoint_runs[run_idx][task_idx][pair_key][ground_truth_y_values_on_grid_cols][grid_col_idx] = y-value

Type(disjoint_runs)=<class 'list'>, Num runs: len(disjoint_runs)=2
a = disjoint_runs[0]
Type(a)=<class 'list'>, Num tasks: len(a)=2
Type(a[0])=<class 'dict'>, Num pairs: len(list(a[0].keys()))=210

a[0].keys()=dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128,

In [23]:
print('Conjoint:')
print('conjoint_runs[run_idx][pair_key][GP_level_curve_y_values_on_grid_cols][task_idx][# sampled points (0-100)][grid_col_idx] = y-value')
print('conjoint_runs[run_idx][pair_key][ground_truth_y_values_on_grid_cols][task_idx][grid_col_idx] = y-value')
print()
print(f'Type(conjoint_runs)={type(conjoint_runs)}, Num runs: len(conjoint_runs)={len(conjoint_runs)}')
a = conjoint_runs[0]
print('a = conjoint_runs[0]')
print(f'Type(a[0])={type(a[0])}, Num pairs: len(list(a.keys()))={len(list(a.keys()))}')
print()
print(f'a[0].keys()={a[0].keys()}')
print()
print('Num samples points for this task/pair:', len(a[0]['GP_level_curve_y_values_on_grid_cols'][0]))
print('Num cols:', len(a[0]['GP_level_curve_y_values_on_grid_cols'][0][0]))

Conjoint:
conjoint_runs[run_idx][pair_key][GP_level_curve_y_values_on_grid_cols][task_idx][# sampled points (0-100)][grid_col_idx] = y-value
conjoint_runs[run_idx][pair_key][ground_truth_y_values_on_grid_cols][task_idx][grid_col_idx] = y-value

Type(conjoint_runs)=<class 'list'>, Num runs: len(conjoint_runs)=2
a = conjoint_runs[0]
Type(a[0])=<class 'dict'>, Num pairs: len(list(a.keys()))=210

a[0].keys()=dict_keys(['training_seed', 'random_seeds', 'X', 'y', 'zz', 'task_indices', 'entropy_list', 'rmse_list', 'GP_level_curve_y_values_on_grid_cols', 'ground_truth_y_values_on_grid_cols'])

Num samples points for this task/pair: 201
Num cols: 91


In [24]:
# keys_to_remove = [key for key in qcsf_results.keys() if key.startswith('Quantile 1') or key.startswith('Quantile 5') or key.startswith('Quantile 3')]

# for key in keys_to_remove:
#     del qcsf_results[key]

In [25]:
# for key in qcsf_results.keys():
#     print(key, qcsf_results[key]['rmses'][99])

## Create RMSE plot

### Formatting variables

In [26]:
# figure_width = 6.5  # inches
# figure_height = figure_width / 2  # inches

# dpi_val = 600              # graphics resolution
# plt.rcParams['font.family'] = 'sans-serif'

# title="RMSE Comparison"
# rmse_x_label = "Sample Count"
# rmse_y_label = "RMSE"

# legend_font_size = 8
# tick_font_size = 8
# label_font_size = 10
# title_font_size = 12

# # x_tick_labels = THESE ARE DEFINED AT THE BEGINNING
# # y_tick_labels = THESE ARE DEFINED AT THE BEGINNING

# DISJOINT, CONJOINT = "disjoint", "conjoint"
# std_transparency = 0.2

# rmse_x_ticks = [0, 20, 40, 60, 80, 100, 120, 140, 160, 180, 200]
# rmse_y_ticks = np.arange(0, .31, step=0.05)
# rmse_x_ticks_min, rmse_x_ticks_max = rmse_x_ticks[0], rmse_x_ticks[-1]
# rmse_y_ticks_min, rmse_y_ticks_max = rmse_y_ticks[0], rmse_y_ticks[-1]

# axis_tick_params = {'axis':'both', 'which':'major', 'direction':'out', 'length': 2}

# filename = 'Tables'


### Plotting

In [27]:
# for run in range(2):
    
#     print('################')
#     print(f'Run {run}')
#     print('###############')
    
#     conjoint_results_dicts = conjoint_runs[run]
#     disjoint_results_dicts = disjoint_runs[run]

#     num_pairs = len(conjoint_results_dicts.keys())

#     sample_count = np.arange(1, num_conjoint_pts + 1)

#     # Initialize arrays to store overall means and standard deviations
#     overall_mean = {DISJOINT: np.zeros(num_conjoint_pts), CONJOINT: np.zeros(num_conjoint_pts)}
#     overall_std = {DISJOINT: np.zeros(num_conjoint_pts), CONJOINT: np.zeros(num_conjoint_pts)}

#     # Calculating the overall mean and standard deviation for each condition
#     for condition, results_dicts, _ in [(DISJOINT, disjoint_results_dicts, 'blue'), (CONJOINT, conjoint_results_dicts, 'red')]:
#         all_rmses = []
#         for pair_idx in range(num_pairs):
#             for task_idx in range(num_tasks):
#                 if condition == DISJOINT:
#                     rmse_list = np.repeat(results_dicts[task_idx][pair_idx]['rmse_list'], num_tasks)
#                 elif condition == CONJOINT:
#                     rmse_list = results_dicts[pair_idx]['rmse_list'][task_idx]
#                 all_rmses.append(rmse_list)
#         overall_mean[condition] = np.mean(all_rmses, axis=0)
#         overall_std[condition] = np.std(all_rmses, axis=0)

#     # Calculate the average RMSE values for each key in qcsf_results
#     qcsf_rmses = []
#     for key in qcsf_results.keys():
#         rmses = qcsf_results[key]['rmses']
#         qcsf_rmses.append(rmses)
#     qcsf_mean = np.mean(qcsf_rmses, axis=0)
#     qcsf_mean = np.repeat(qcsf_mean[:100], 2) # create the staircase
#     qcsf_std = np.std(qcsf_rmses, axis=0)    

#     # Set up the plot
#     fig, ax = plt.subplots(figsize=(figure_width, figure_height))

#     # Plotting
#     for condition, color in [(DISJOINT, 'blue'), (CONJOINT, 'red'), ("QCSF", 'purple')]:
#         if condition == "QCSF":
#             mean = qcsf_mean
#             std = qcsf_std
#         else:
#             mean = overall_mean[condition]
#             std = overall_std[condition]
#         ax.plot(sample_count, mean, label=condition, color=color)
#         ax.fill_between(sample_count, mean + std, mean - std, alpha=std_transparency, color=color)

#     # Setting plot attributes
#     ax.tick_params(**axis_tick_params, labelsize=tick_font_size)
#     ax.set_xlim(rmse_x_ticks_min, rmse_x_ticks_max)
#     ax.set_ylim(rmse_y_ticks_min, rmse_y_ticks_max)
#     ax.set_title(title, fontsize=title_font_size)
#     ax.set_xlabel(rmse_x_label)
#     ax.set_ylabel(rmse_y_label, fontsize=label_font_size)
#     ax.legend(loc='upper right', fontsize=legend_font_size, frameon=False)
#     ax.tick_params(axis='both', labelsize=tick_font_size)
#     plt.setp(ax, xticks=rmse_x_ticks, yticks=rmse_y_ticks)

#     plt.gcf().subplots_adjust(bottom=0.15)

#     # Saving or showing the plot
#     if save_plots_mode: 
#         full_save_dir = f'{save_dir}run_{run}'
#         ensure_directory_exists(full_save_dir)
#         plt.savefig(f"{full_save_dir}/rmse_comparison.png", dpi=dpi_val)
#         plt.savefig(f"{full_save_dir}/rmse_comparison.pdf", dpi=dpi_val)
#     if scrn_mode: 
#         plt.show()
#     plt.clf()

## Make tables 1 and 2

In [28]:
GP_key = 'GP_level_curve_y_values_on_grid_cols'
ground_truth_key = 'ground_truth_y_values_on_grid_cols'

num_runs = len(conjoint_runs) # 2
num_pairs = len(conjoint_runs[0]) # 210
num_tasks = len(conjoint_runs[0][0][GP_key]) # 2
num_samples = len(conjoint_runs[0][0][GP_key][0]) # 201
num_grid_cols = len(conjoint_runs[0][0][GP_key][0][0]) # 91

'''
Format of data

y-value = conjoint_runs[run_idx][pair_idx]['GP_level_curve_y_values_on_grid_cols'][task_idx][# sampled points (0-200)][grid_col_idx (0-90)]
y-value = conjoint_runs[run_idx][pair_idx]['ground_truth_y_values_on_grid_cols'][task_idx][grid_col_idx (0-90)]

y-value = disjoint_runs[run_idx][task_idx][pair_idx]['GP_level_curve_y_values_on_grid_cols'][# sampled points (0-200)][grid_col_idx (0-90)]
y-value = disjoint_runs[run_idx][task_idx][pair_idx]['ground_truth_y_values_on_grid_cols'][grid_col_idx (0-90)]
'''
def get_disjoint_y(run, pair, task, grid_col, sample, curve):
    if curve == GP_key:
        return disjoint_runs[run][task][pair][GP_key][sample][grid_col]
    elif curve == ground_truth_key:
        return disjoint_runs[run][task][pair][ground_truth_key][grid_col]
    else:
        raise 'ERROR, WRONG KEY'
        
def get_conjoint_y(run, pair, task, grid_col, sample, curve):
    if curve == GP_key:
        return conjoint_runs[run][pair][GP_key][task][sample][grid_col]
    elif curve == ground_truth_key:
        return conjoint_runs[run][pair][ground_truth_key][task][grid_col]
    else:
        raise 'ERROR, WRONG KEY'
        
        
# Repeat the disjoint and qcsf data
num_disjoint_samples = len(disjoint_runs[0][0][0][GP_key])
for pair_index in range(num_pairs):
    for task_index in range(num_tasks):
        for run_index in range(num_runs):
            disjoint_runs[run_index][task_index][pair_index][GP_key] = [
                np.copy(disjoint_runs[run_index][task_index][pair_index][GP_key][sample])
                for sample in range(num_disjoint_samples)
                for i in range(2)
            ][1:]
            
            qcsf_runs[run_index][task_index][pair_index][GP_key] = [
                np.copy(qcsf_runs[run_index][task_index][pair_index][GP_key][sample])
                for sample in range(num_disjoint_samples)
                for i in range(2)
            ][1:]
            
            
        

# put stuff in numpy so it's easier to work with
np_disjoint_GP = np.empty((num_runs, num_pairs, num_tasks, num_samples, num_grid_cols))
np_conjoint_GP = np.empty((num_runs, num_pairs, num_tasks, num_samples, num_grid_cols))
np_ground_truth = np.empty((num_runs, num_pairs, num_tasks, num_samples, num_grid_cols))
np_qcsf = np.empty((num_runs, num_pairs, num_tasks, num_samples, num_grid_cols))
for run_index in range(num_runs):
    for pair_index in range(num_pairs):
        print(pair_index)
        for task_index in range(num_tasks):
            for sample_index in range(num_samples):
                for grid_col_index in range(num_grid_cols):
                
                    np_disjoint_GP[run_index, pair_index, task_index, sample_index, grid_col_index] \
                    = disjoint_runs[run_index][task_index][pair_index][GP_key][sample_index][grid_col_index]
                    
                    np_conjoint_GP[run_index, pair_index, task_index, sample_index, grid_col_index] \
                    = conjoint_runs[run_index][pair_index][GP_key][task_index][sample_index][grid_col_index]  
                
                    np_ground_truth[run_index, pair_index, task_index, sample_index, grid_col_index] \
                    = conjoint_runs[run_index][pair_index][ground_truth_key][task_index][grid_col_index]   
                    
                    np_qcsf[run_index, pair_index, task_index, sample_index, grid_col_index] \
                    = qcsf_runs[run_index][task_index][pair_index][GP_key][sample_index][grid_col_index]

0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
0
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
9

In [29]:
print('axes: (run, pair, task, # sampled points, grid col)')
print("np_disjoint_GP dimensions:", np_disjoint_GP.shape)
print("np_conjoint_GP dimensions:", np_conjoint_GP.shape)
print("np_ground_truth dimensions:", np_ground_truth.shape)
print("np_qcsf dimensions:", np_qcsf.shape)

axes: (run, pair, task, # sampled points, grid col)
np_disjoint_GP dimensions: (2, 210, 2, 201, 91)
np_conjoint_GP dimensions: (2, 210, 2, 201, 91)
np_ground_truth dimensions: (2, 210, 2, 201, 91)
np_qcsf dimensions: (2, 210, 2, 201, 91)


In [30]:

DISJOINT = 'Disjoint'
CONJOINT = 'Conjoint'
GT = 'Ground Truth'
QCSF = 'qCSF'

stuff_to_run = [
    ((QCSF, 0), (QCSF, 1)),
    ((DISJOINT, 0), (DISJOINT, 1)),
    ((CONJOINT, 0), (CONJOINT, 1)),
    ((QCSF, 0), (GT, 0)),
    ((DISJOINT, 0), (GT, 0)),
    ((CONJOINT, 0), (GT, 0)),
]

samples_list = [20, 50, 100]

# Statistics to calculate
def signed_diff(a, b):
    sign_diff = b - a
    return {
        'Mean': np.mean(sign_diff),
        'Std': np.std(sign_diff)
    }

def absolute_diff(a, b):
    abs_diff = np.abs(b - a)
    return {
        'Mean': np.mean(abs_diff),
        'Std': np.std(abs_diff)
    }
    
def root_mean_square_diff(a, b):
    return np.sqrt(np.mean(np.square(b - a)))
    
statistics = {
    'Signed Difference (2nd-1st)': signed_diff,
    'Absolute Difference': absolute_diff,
    'Root Mean Square Difference': root_mean_square_diff
}

condition_mapping = {
    DISJOINT: np_disjoint_GP,
    CONJOINT: np_conjoint_GP,
    GT: np_ground_truth,
    QCSF: np_qcsf
}

for stat_name, stat_func in statistics.items():
    print(f'Statistic: {stat_name}')
    
    for stuff in stuff_to_run:
        a_condition = stuff[0][0]
        a_run = stuff[0][1]
        b_condition = stuff[1][0]
        b_run = stuff[1][1]
        
        print(f'  Curves: {a_condition}(run {a_run}) vs {b_condition}(run {b_run})')
        
        for samples in samples_list:
            print(f'    Samples: {samples} | ', end='')
            
            a = condition_mapping[a_condition][a_run, :, :, samples, :]
            b = condition_mapping[b_condition][b_run, :, :, samples, :]
            just_GT = condition_mapping[GT][0, :, :, samples, :]

            # Create a mask for NaN values
            mask_a = ~np.isnan(a)
            mask_b = ~np.isnan(b)
            mask_just_GT = ~np.isnan(just_GT)
            mask_combined = mask_a & mask_b & mask_just_GT

            # Apply the mask to both arrays
            a_filtered = a[mask_combined]
            b_filtered = b[mask_combined]

            # Calculate the statistic using the corresponding function from the dictionary
            stat_result = stat_func(a_filtered, b_filtered)

            if isinstance(stat_result, dict):
                for k, v in stat_result.items():
                    print(f'{k}: {v:.4f} | ', end='')
            else:
                print(f'Value: {stat_result:.4f}', end='')
                
            print()
                    
        print()

Statistic: Signed Difference (2nd-1st)
  Curves: qCSF(run 0) vs qCSF(run 1)
    Samples: 20 | Mean: 0.0035 | Std: 0.3013 | 
    Samples: 50 | Mean: 0.0029 | Std: 0.1972 | 
    Samples: 100 | Mean: 0.0013 | Std: 0.1623 | 

  Curves: Disjoint(run 0) vs Disjoint(run 1)
    Samples: 20 | Mean: 0.0222 | Std: 0.4587 | 
    Samples: 50 | Mean: 0.0048 | Std: 0.2500 | 
    Samples: 100 | Mean: -0.0034 | Std: 0.1174 | 

  Curves: Conjoint(run 0) vs Conjoint(run 1)
    Samples: 20 | Mean: -0.0108 | Std: 0.3494 | 
    Samples: 50 | Mean: -0.0082 | Std: 0.1862 | 
    Samples: 100 | Mean: 0.0011 | Std: 0.1086 | 

  Curves: qCSF(run 0) vs Ground Truth(run 0)
    Samples: 20 | Mean: -0.0494 | Std: 0.2210 | 
    Samples: 50 | Mean: 0.0026 | Std: 0.1512 | 
    Samples: 100 | Mean: 0.0099 | Std: 0.1287 | 

  Curves: Disjoint(run 0) vs Ground Truth(run 0)
    Samples: 20 | Mean: 0.0273 | Std: 0.3486 | 
    Samples: 50 | Mean: 0.0130 | Std: 0.1915 | 
    Samples: 100 | Mean: -0.0029 | Std: 0.0875 | 

  Cur

In [None]:
# import matplotlib.pyplot as plt

# # Assuming np_disjoint_GP, np_conjoint_GP, and np_ground_truth are defined

# # Get the actual dimensions from the arrays
# num_samples = np_disjoint_GP.shape[3]  # Assuming the sample axis is the 4th axis
# num_grid_cols = np_disjoint_GP.shape[4]  # Assuming the grid column axis is the 5th axis

# # Define the number of samples per figure
# samples_per_figure = 10

# # Calculate the number of figures needed
# num_figures = (num_samples + samples_per_figure - 1) // samples_per_figure

# # Plot samples in batches of samples_per_figure
# for fig_index in range(num_figures):
#     start_sample = fig_index * samples_per_figure
#     end_sample = min((fig_index + 1) * samples_per_figure, num_samples)
    
#     # Create the subplots for the current batch of samples
#     fig, axs = plt.subplots(end_sample - start_sample, figsize=(8, 6 * (end_sample - start_sample)))
    
#     # Iterate over sample indices and plot each one on a separate subplot
#     for sample_index in range(start_sample, end_sample):
#         ax_index = sample_index - start_sample
        
#         # Plot np_disjoint_GP for the current sample index
#         axs[ax_index].plot(np.arange(1, num_grid_cols + 1), np_disjoint_GP[0, 0, 0, sample_index],
#                             label=f'Disjoint GP - Sample {sample_index}')
        
#         # Plot np_conjoint_GP for the current sample index
#         axs[ax_index].plot(np.arange(1, num_grid_cols + 1), np_conjoint_GP[0, 0, 0, sample_index],
#                             label=f'Conjoint GP - Sample {sample_index}')
        
#         # Plot np_ground_truth for the current sample index
#         axs[ax_index].plot(np.arange(1, num_grid_cols + 1), np_ground_truth[0, 0, 0, sample_index],
#                             label=f'Ground Truth - Sample {sample_index}')
        
#         # Set y-axis limits from 0 to 3
#         axs[ax_index].set_ylim(0, 3)
        
#         # Add labels, title, legend, and grid to each subplot
#         axs[ax_index].set_xlabel('X values')
#         axs[ax_index].set_ylabel('Y values')
#         axs[ax_index].set_title(f'Sample {sample_index}')
#         axs[ax_index].legend()
#         axs[ax_index].grid(True)
    
#     # Adjust layout and show the plot
#     plt.tight_layout()
#     plt.show()
