# PFN Comparison

In this notebook we compare the performance of different BODi-trained PFNs: first their regression performance, then their overlap scores, and finally their Bayesian-optimization (BO) performance.

In [None]:
import sys
#sys.path.append('/Users/timothyshinners/Library/Python/3.9/lib/python/site-packages')
import torch
import numpy as np
import scipy
import matplotlib.pyplot as plt
import pandas as pd
#import mcbo

from mcbo import task_factory
from mcbo.optimizers.bo_builder import BoBuilder, BO_ALGOS
from mcbo.optimizers.non_bo.random_search import RandomSearch
from mcbo.optimizers.non_bo.local_search import LocalSearch
from mcbo.acq_funcs import acq_factory

import pfns4bo
from pfns4bo.pfn_bo_bayesmark import PFNOptimizer
from pfns4mvbo.priors import cocabo_prior
from pfns4mvbo.pfn_plotting import showPFNPosteriorDistributions, showPFNvsCOCABOPosteriorDistributions, showPFNvsCOCABO
from pfns4mvbo.mvpfn_optimizer import MVPFNOptimizer
from pfns4mvbo.pfn_acq_func import PFNAcqFunc
from pfns4mvbo.pfn_model import PFNModel
from pfns4mvbo.evaluation import do_regression_evaluation, compare_pfn_vs_mcbo, do_validation_experiment
import re


def _bootstrap_means(col, n_boot=1000):
    """Return ``n_boot`` bootstrap replicates of the mean of ``col``.

    Each replicate is the mean of ``len(col)`` values drawn from ``col`` with
    replacement, using NumPy's global random state.
    """
    samples = np.random.choice(col, size=[n_boot, len(col)], replace=True)
    return samples.mean(axis=1)


def bootstrap_ub(col, n_boot=1000):
    """Return the upper bound of a 95% bootstrap confidence interval for the mean.

    Args:
        col: 1-D array-like of values (e.g. a pandas Series passed by
            ``groupby(...).apply``).
        n_boot: number of bootstrap replicates.

    Returns:
        The 97.5th percentile of the bootstrap distribution of the mean.
    """
    return np.quantile(_bootstrap_means(col, n_boot), 0.975)


def bootstrap_lb(col, n_boot=1000):
    """Return the lower bound of a 95% bootstrap confidence interval for the mean.

    Args:
        col: 1-D array-like of values (e.g. a pandas Series passed by
            ``groupby(...).apply``).
        n_boot: number of bootstrap replicates.

    Returns:
        The 2.5th percentile of the bootstrap distribution of the mean.
    """
    return np.quantile(_bootstrap_means(col, n_boot), 0.025)

In [None]:
# Two backends are requested in sequence; only the last %matplotlib magic
# stays in effect, so the interactive `notebook` backend is immediately
# replaced by `inline`.
%matplotlib notebook
    
%matplotlib inline

In [None]:
# Checkpoint numbers of the trained PFNs to compare; each number is spliced
# into the checkpoint path (trained_models/pfn_bodi_<n>.pth) and the result
# CSV file names read by the cells below.
# To also include checkpoints 34-38, extend with [str(n) for n in range(34, 39)].
PFN_LIST = [str(n) for n in (6, 24)]


# Model info and training curves

In [None]:
# Print each trained PFN's stored metadata and overlay its training-loss curve
# (x-axis rescaled to [0, 1] so runs of different length can be compared).
# NOTE(review): these checkpoints are whole pickled objects (`pfn.info` is
# read), so torch.load has to unpickle them; recent PyTorch releases default to
# weights_only=True, which may reject such files. Only load trusted files.
for pfn_no in PFN_LIST:
    try:
        pfn = torch.load(f'trained_models/pfn_bodi_{pfn_no}.pth')
        print(pfn_no, pfn.info)
        loss_curve = pfn.info['loss_curve']
        print(pfn_no, pfn.info.get('learning_rate', None), pfn.info.get('bptt', None))
        plt.plot(np.linspace(0, 1, len(loss_curve)), loss_curve, label=pfn_no)
    except Exception as err:
        # Best-effort: report and skip this PFN (missing file, no `info`
        # attribute, no 'loss_curve' key, ...). Unlike a bare `except:`, this
        # does not swallow KeyboardInterrupt/SystemExit and shows the cause.
        print(f'PFN {pfn_no} has no loss curve ({type(err).__name__}: {err})')

plt.xlabel('Fraction of training')
plt.ylabel('Loss')
plt.legend()

# Regression

We assess the regression performance of our trained PFNs.

In [None]:
loss = 'nll_normed'

# Load each PFN's regression results, keeping only the PFN rows. `.copy()` makes
# each filtered frame independent, so the columns added below do not write into
# a view of the unfiltered CSV (no SettingWithCopyWarning).
regression_df_list = []
for pfn_no in PFN_LIST:
    pfn_df = pd.read_csv('bodi_evaluation_data/regression_results_'+pfn_no+'.csv')
    pfn_df = pfn_df[pfn_df['model_name'] == 'PFN'].copy()
    pfn_df['PFN_number'] = pfn_no
    regression_df_list.append(pfn_df)

# Baseline: every non-PFN model row of the Casmopolitan results file.
cocabo_baseline = pd.read_csv('evaluation_data/regression_results_casmo.csv')
cocabo_baseline = cocabo_baseline[cocabo_baseline['model_name'] != 'PFN'].copy()
cocabo_baseline['PFN_number'] = 'Casmopolitan'
cocabo_baseline = cocabo_baseline.reset_index(drop=True)

# Deliberate stop: presumably the baseline still has to be switched from the
# `evaluation_data/` results to the BODi (`bodi_evaluation_data/`) ones.
assert 1==0, 'switch cocabo to bodi'

regression_df_list.append(cocabo_baseline)
# One legend label per frame in regression_df_list. Built locally so the global
# PFN_LIST is never mutated (and never has to be trimmed again afterwards).
regression_labels = PFN_LIST + ['Casmopolitan']

# Check the seed is consistent across data frames. Use positional .iloc: the PFN
# frames were filtered without resetting their index, so label 0 may not exist.
seed = regression_df_list[0]['seed'].iloc[0]
for df in regression_df_list:
    assert df['seed'].iloc[0] == seed

fig, ax = plt.subplots(3, 2, figsize=(10, 12))
ax[0,0].set_title('Loss vs Amount of Training Data, PFN')
ax[0,0].set_xlabel('Number of Training Points')
ax[0,0].set_ylabel('Loss')
ax[1,0].set_title('Loss vs Number of Dimensions, PFN')
ax[1,0].set_xlabel('Number of Dimensions')
ax[1,0].set_ylabel('Loss')
ax[2,0].set_title('Histogram of Losses, PFN')
ax[2,0].set_xlabel('Loss')
ax[0,1].set_title('Ranks')
ax[0,1].set_xlabel('Number of Training Points')
ax[0,1].set_ylabel('Rank')
ax[1,1].set_title('Ranks')
ax[1,1].set_xlabel('Number of Dimensions')
ax[1,1].set_ylabel('Rank')
ax[2,1].set_title('Best PFN')
ax[2,1].set_xlabel('Number of Training Points')
ax[2,1].set_ylabel('Number of Dimensions')

for regression_df, label in zip(regression_df_list, regression_labels):
    # Mean loss vs. amount of training data, with a 95% bootstrap band.
    by_n = regression_df.groupby('n_training_data')[loss]
    mean = by_n.mean()
    ub = by_n.apply(bootstrap_ub)
    lb = by_n.apply(bootstrap_lb)
    ax[0,0].plot(mean, label=label)
    ax[0,0].fill_between(mean.index, lb, ub, alpha=0.5)

    # Mean loss vs. total dimensionality (numerical + categorical).
    regression_df['n_dims'] = regression_df['num_dims'] + regression_df['cat_dims']
    by_dims = regression_df.groupby('n_dims')[loss]
    mean = by_dims.mean()
    ub = by_dims.apply(bootstrap_ub)
    lb = by_dims.apply(bootstrap_lb)
    ax[1,0].plot(mean, label=label)
    ax[1,0].fill_between(mean.index, lb, ub, alpha=0.5)

    ax[2,0].hist(regression_df[loss], alpha=0.3, bins=100, label=label)

# Rank the models row by row on the raw loss (e.g. 'nll_raw' for 'nll_normed');
# rank 0 is the lowest loss. This assumes all frames list the same evaluation
# setups in the same row order (only the first seed is checked above).
raw_loss = loss[0:3]+'_raw'
regression_df = regression_df_list[0]
losses = np.array([df[raw_loss] for df in regression_df_list]).T
ranks = np.argsort(np.argsort(losses, axis=1), axis=1)

avg_ranks = np.zeros(len(regression_labels))
proportion_wins = np.zeros(len(regression_labels))

for i, label in enumerate(regression_labels):
    rank_col = label+'_rank'
    regression_df[rank_col] = ranks[:, i]
    avg_ranks[i] = regression_df[rank_col].mean()
    proportion_wins[i] = (regression_df[rank_col] == 0).mean()

    by_n = regression_df.groupby('n_training_data')[rank_col]
    mean = by_n.mean()
    ub = by_n.apply(bootstrap_ub)
    lb = by_n.apply(bootstrap_lb)
    ax[0,1].plot(mean, label=label)
    ax[0,1].fill_between(mean.index, lb, ub, alpha=0.5)

    by_dims = regression_df.groupby('n_dims')[rank_col]
    mean = by_dims.mean()
    ub = by_dims.apply(bootstrap_ub)
    lb = by_dims.apply(bootstrap_lb)
    ax[1,1].plot(mean, label=label)
    ax[1,1].fill_between(mean.index, lb, ub, alpha=0.5)

    # Jittered scatter of the setups in which this model had the lowest loss.
    best = regression_df[regression_df[rank_col] == 0]
    noise_x = np.random.normal(0, 1.5, best.shape[0])
    noise_y = np.random.normal(0, 1.5, best.shape[0])
    ax[2,1].scatter(best['n_training_data']+noise_x, best['n_dims']+noise_y, s=3, label=label)

print(regression_labels)
print('Average Ranks', avg_ranks)
print('Win Proportion', proportion_wins)

for axis in ax.flat:
    axis.legend()
plt.tight_layout()

del regression_df_list

# Overlap

In [None]:
div = 'overlap'

# Load each PFN's divergence results (PFN rows only). A PFN is skipped, with a
# message, if its file cannot be read or processed ('CORRUPT') or if it has no
# 'zakharov' rows ('BAD'). `overlap_labels` records which PFN each loaded frame
# belongs to, so plot labels stay aligned when some PFNs are skipped.
divergence_df_list = []
overlap_labels = []
for pfn_no in PFN_LIST:
    print(pfn_no)
    try:
        dat_frame = pd.read_csv('bodi_evaluation_data/divergence_results_'+pfn_no+'.csv')
        if not (dat_frame['task_name'] == 'zakharov').any():
            print('BAD')
            continue
        dat_frame = dat_frame[dat_frame['model_name'] == 'PFN'].copy()
        dat_frame['PFN_number'] = pfn_no
        # Negate so that lower is better, matching the ascending ranking below;
        # plots and printouts negate again to show the original overlap.
        dat_frame['overlap'] = -dat_frame['overlap']
    except Exception as err:
        print('CORRUPT', err)
        continue
    # Only fully processed frames are kept.
    divergence_df_list.append(dat_frame)
    overlap_labels.append(pfn_no)

if not divergence_df_list:
    raise RuntimeError('No usable divergence results were loaded for ' + str(PFN_LIST))


fig, ax = plt.subplots(3, 2, figsize=(10, 12))
ax[0,0].set_title('Overlap vs Amount of Training Data')
ax[0,0].set_xlabel('Number of Training Points')
ax[0,0].set_ylabel('Overlap')
ax[1,0].set_title('Overlap vs Number of Dimensions')
ax[1,0].set_xlabel('Number of Dimensions')
ax[1,0].set_ylabel('Overlap')
ax[2,0].set_title('Histogram of Overlap')
ax[2,0].set_xlabel('Overlap')
ax[0,1].set_title('Ranks')
ax[0,1].set_xlabel('Number of Training Points')
ax[0,1].set_ylabel('Rank')
ax[1,1].set_title('Ranks')
ax[1,1].set_xlabel('Number of Dimensions')
ax[1,1].set_ylabel('Rank')
ax[2,1].set_title('Best PFN with respect to Overlap')
ax[2,1].set_xlabel('Number of Training Points')
ax[2,1].set_ylabel('Number of Dimensions')

mean_overlaps = np.zeros(len(divergence_df_list))

for i, (divergence_df, label) in enumerate(zip(divergence_df_list, overlap_labels)):
    mean_overlaps[i] = divergence_df['overlap'].mean()

    # Mean overlap vs. amount of training data, with a 95% bootstrap band
    # (values negated back to the original overlap scale).
    by_n = divergence_df.groupby('n_training_data')[div]
    mean = by_n.mean()
    ub = by_n.apply(bootstrap_ub)
    lb = by_n.apply(bootstrap_lb)
    ax[0,0].plot(-mean, label=label)
    ax[0,0].fill_between(mean.index, -lb, -ub, alpha=0.5)

    # Mean overlap vs. total dimensionality (numerical + categorical).
    divergence_df['n_dims'] = divergence_df['num_dims'] + divergence_df['cat_dims']
    by_dims = divergence_df.groupby('n_dims')[div]
    mean = by_dims.mean()
    ub = by_dims.apply(bootstrap_ub)
    lb = by_dims.apply(bootstrap_lb)
    ax[1,0].plot(-mean, label=label)
    ax[1,0].fill_between(mean.index, -lb, -ub, alpha=0.5)

    ax[2,0].hist(-divergence_df[div], alpha=0.3, bins=100, label=label)

# Rank the PFNs row by row on the negated overlap (rank 0 = highest overlap).
# Assumes all frames list the same evaluation setups in the same row order.
divergence_df = divergence_df_list[0]
losses = np.array([df['overlap'] for df in divergence_df_list]).T
ranks = np.argsort(np.argsort(losses, axis=1), axis=1)

avg_ranks = np.zeros(len(overlap_labels))
proportion_wins = np.zeros(len(overlap_labels))

for i, label in enumerate(overlap_labels):
    rank_col = label+'_rank'
    divergence_df[rank_col] = ranks[:, i]
    avg_ranks[i] = divergence_df[rank_col].mean()
    proportion_wins[i] = (divergence_df[rank_col] == 0).mean()

    by_n = divergence_df.groupby('n_training_data')[rank_col]
    mean = by_n.mean()
    ub = by_n.apply(bootstrap_ub)
    lb = by_n.apply(bootstrap_lb)
    ax[0,1].plot(mean, label=label)
    ax[0,1].fill_between(mean.index, lb, ub, alpha=0.5)

    by_dims = divergence_df.groupby('n_dims')[rank_col]
    mean = by_dims.mean()
    ub = by_dims.apply(bootstrap_ub)
    lb = by_dims.apply(bootstrap_lb)
    ax[1,1].plot(mean, label=label)
    ax[1,1].fill_between(mean.index, lb, ub, alpha=0.5)

    # Jittered scatter of the setups in which this PFN had the highest overlap.
    best = divergence_df[divergence_df[rank_col] == 0]
    noise_x = np.random.normal(0, 1.5, best.shape[0])
    noise_y = np.random.normal(0, 1.5, best.shape[0])
    ax[2,1].scatter(best['n_training_data']+noise_x, best['n_dims']+noise_y, s=3, label=label)

print(overlap_labels)
print('Average Overlap', -mean_overlaps)
print('Average Ranks', avg_ranks)
print('Win Proportion', proportion_wins)

for axis in ax.flat:
    axis.legend()
plt.tight_layout()

# BO Runs

In [None]:
def rescale(series):
    """Min-max scale ``series`` to [0, 1].

    A constant series (max == min) maps to all zeros instead of NaN (0/0).
    """
    low = series.min()
    span = series.max() - low
    if span == 0:
        return series - low
    return (series - low) / span


def _load_baseline(path, pfn_number, name_prefix):
    """Load a baseline BO result file, listed under both acquisition-function names.

    The same runs are stacked twice, as '<name_prefix>_pfnAcqFunc' and then as
    '<name_prefix>_mcboAcqFunc', so the baseline appears in both halves of the
    comparison figure.
    """
    base = pd.read_csv(path)
    base['PFN_number'] = pfn_number
    base = base.reset_index(drop=True)
    copies = []
    for suffix in ('pfnAcqFunc', 'mcboAcqFunc'):
        copy = base.copy(deep=True)
        copy['optimizer_name'] = f'{name_prefix}_{suffix}'
        copies.append(copy)
    return pd.concat(copies, axis=0, ignore_index=True)


def _plot_acq_family(results_full, ax, col, acq_tag):
    """Fill column ``col`` of the 3x2 BO figure for one acquisition-function family.

    Uses the rows of ``results_full`` whose optimizer_name contains ``acq_tag``
    ('pfnAcqFunc' or 'mcboAcqFunc') and draws one curve per
    (PFN_number, optimizer_name) pair, so different PFNs sharing an optimizer
    name are not merged:
      row 0 - best_y min-max scaled per task, mean per iteration with 95% band;
      row 1 - mean rank per iteration with 95% band;
      row 2 - jittered (iteration, task_number) scatter of rank-0 wins.
    """
    mask = results_full['optimizer_name'].str.contains(acq_tag, regex=False, na=False)
    results = results_full[mask].copy()
    # Scale per task, within this acquisition family only.
    results['best_y_scaled'] = results.groupby('task_name')['best_y'].transform(rescale)

    for (pfn_number, optimizer_name), group in results.groupby(['PFN_number', 'optimizer_name'], sort=False):
        label = f'{pfn_number} ({optimizer_name})'

        by_guess = group.groupby('nth_guess')['best_y_scaled']
        mean = by_guess.mean()
        ub = by_guess.apply(bootstrap_ub)
        lb = by_guess.apply(bootstrap_lb)
        ax[0, col].plot(mean, label=label)
        ax[0, col].fill_between(mean.index, lb, ub, alpha=0.1)

        print('mean rank for ', label, group['rank'].mean())
        by_guess = group.groupby('nth_guess')['rank']
        mean = by_guess.mean()
        ub = by_guess.apply(bootstrap_ub)
        lb = by_guess.apply(bootstrap_lb)
        ax[1, col].plot(mean, label=label)
        ax[1, col].fill_between(mean.index, lb, ub, alpha=0.1)

        print(label + ' win proportion: ', (group['rank'] == 0).mean())
        wins = group[group['rank'] == 0]
        noise_x = np.random.normal(0, 0.3, wins.shape[0])
        noise_y = np.random.normal(0, 0.3, wins.shape[0])
        ax[2, col].scatter(wins['nth_guess']+noise_x, wins['task_number']+noise_y, s=1, label=label)


BO_df_list = []
for pfn_no in PFN_LIST:
    pfn_df = pd.read_csv('casmo_evaluation_data/BO_results_pfn_'+pfn_no+'.csv')
    pfn_df['PFN_number'] = pfn_no
    BO_df_list.append(pfn_df.reset_index(drop=True))

BO_df_list.append(_load_baseline('evaluation_data/BO_results_cocabo_baseline.csv', 'CoCaBO', 'CoCaBO'))
BO_df_list.append(_load_baseline('evaluation_data/BO_results_casmo_baseline.csv', 'Casmopolitan', 'Casmo'))
BO_df_list.append(_load_baseline('evaluation_data/BO_results_random_baseline.csv', 'Random', 'Random'))

# ignore_index=True: a unique index keeps the column assignments below unambiguous.
results_full = pd.concat(BO_df_list, axis=0, ignore_index=True)
results_full['task_number'] = pd.factorize(results_full['task_name'])[0]

# Rank all sources per row position (rank 0 = lowest best_y). This assumes every
# frame in BO_df_list has the same length and row order. NOTE(review): confirm
# that the PFN files order their rows like the stacked baselines.
best_ys = np.array([df['best_y'] for df in BO_df_list]).T
ranks = np.argsort(np.argsort(best_ys, axis=1), axis=1)
# Transposed flatten lists each source's ranks in its concat order.
results_full['rank'] = ranks.T.flatten()

fig, ax = plt.subplots(3, 2, figsize=(8, 12))

families = [
    ('pfnAcqFunc', 'Optimizer Performance, PFN Acq Func'),
    ('mcboAcqFunc', 'Optimizer Performance, MCBO Acq Func'),
]
for col, (acq_tag, title) in enumerate(families):
    _plot_acq_family(results_full, ax, col, acq_tag)

    ax[0, col].set_ylabel('best_y_scaled, averaged across all runs')
    ax[0, col].set_xlabel('iteration')
    ax[0, col].set_title(title)
    ax[0, col].legend()

    ax[1, col].set_title('Ranking Plot')
    ax[1, col].set_ylabel('Average Rank Across All Setups')
    ax[1, col].set_xlabel('Iteration')
    ax[1, col].legend()

    ax[2, col].set_ylabel('task number')
    ax[2, col].set_xlabel('iteration')
    ax[2, col].legend()

fig.tight_layout()

# Direct Comparisons

We now plot the posterior distributions of the PFN and of Casmopolitan on a simple task, to compare how the two models behave.

In [None]:
# PFN checkpoint passed as `pfn_file` to MVPFNOptimizer for the direct comparison.
MODEL_FILENAME = '/'.join(['trained_models', 'pfn_bodi_23.pth'])

In [None]:
# Small mixed-variable ackley task for the side-by-side posterior plots: one
# 'num' and one 'nominal' variable group (presumably one dimension each, with
# the nominal one having 2 categories).
task_kws = {
    'variable_type': ['num', 'nominal'],
    'num_dims': [1, 1],
    'lb': -1,
    'ub': 2,
    'num_categories': [2, 2],
}
task = task_factory(task_name="ackley", **task_kws)
search_space = task.get_search_space()

# Baseline optimizer; stored as `cocabo` although it is the Casmopolitan algorithm.
n_init = 1
cocabo = BO_ALGOS['Casmopolitan'].build_bo(
    search_space=search_space,
    n_init=n_init,
    device=torch.device("cpu"),
)

# PFN-based optimizer built from the checkpoint in MODEL_FILENAME.
optimizer_kwargs = dict(
    pfn_file=MODEL_FILENAME,
    acq_func='pi',
    acq_func_optim='mab',
    fast=True,
)
mvpfn = MVPFNOptimizer(search_space=search_space, **optimizer_kwargs)
def divergence(mu_cocabo, var_cocabo, mu_pfn, var_pfn):
    """Return the mean KL divergence KL(N(mu_cocabo, var_cocabo) || N(mu_pfn, var_pfn)).

    The divergence between the two univariate Gaussians is computed elementwise
    and then averaged.

    Args:
        mu_cocabo, var_cocabo: baseline predictive means and variances (tensors).
        mu_pfn, var_pfn: PFN predictive means and variances (tensors). All four
            must broadcast together; pass same-shaped tensors for a per-point
            comparison.

    Returns:
        The average KL divergence as a Python float.
    """
    # KL(N(m1, v1) || N(m2, v2)) = 0.5*log(v2/v1) + (v1 + (m1 - m2)^2) / (2*v2) - 0.5.
    # The log term carries a factor 1/2 because it equals log(sigma2 / sigma1).
    div = 0.5 * torch.log(var_pfn / var_cocabo) + (var_cocabo + (mu_cocabo - mu_pfn) ** 2) / (2 * var_pfn) - 0.5
    return div.mean().item()

for i in [2, 4, 10, 20, 50, 100, 1000]:
    # For each training-set size i: draw i random points, then plot the PFN
    # posterior (left) next to the Casmopolitan posterior (right).
    print(i)

    fig, ax = plt.subplots(1, 2, figsize=(10, 4))

    mvpfn.restart()

    # PFN
    X_pd = search_space.sample(i)
    x = search_space.transform(X_pd).to(torch.float32)
    y = task(X_pd)
    # 200 test inputs: a 100-point grid over [0, 1] in column 0, paired with
    # column 1 = 0 (rows :100) and column 1 = 1 (rows 100:).
    test_x = torch.vstack([torch.linspace(0, 1, 100).repeat(2), torch.zeros(200)]).T.to(torch.float32)
    test_x[100:, 1] = 1

    # plot the ground truth
    ground_truth = task(search_space.inverse_transform(test_x))
    ax[0].plot(test_x[:100,0].detach(), ground_truth[:100], c='black')
    ax[0].plot(test_x[100:,0].detach(), ground_truth[100:], c='black')

    mvpfn.observe(X_pd, y)

    # Query the PFN directly with (train x, train y mapped by y_to_fit_y, test x)
    # and summarize its output distribution by mean, variance and a 95% interval.
    logits = mvpfn.model_pfn.pfn(x, mvpfn.model_pfn.y_to_fit_y(torch.from_numpy(y).to(torch.float32)), test_x.to(torch.float32))
    mean = mvpfn.model_pfn.pfn.criterion.mean(logits)
    var = mvpfn.model_pfn.pfn.criterion.variance(logits)
    lower_pfn = mvpfn.model_pfn.pfn.criterion.icdf(logits, 0.025).flatten().detach()
    upper_pfn = mvpfn.model_pfn.pfn.criterion.icdf(logits, 0.975).flatten().detach()

    # Map the mean and interval back to the original y scale.
    mean = mvpfn.model_pfn.fit_y_to_y(mean)
    lower_pfn = mvpfn.model_pfn.fit_y_to_y(lower_pfn)
    upper_pfn = mvpfn.model_pfn.fit_y_to_y(upper_pfn)

    ax[0].plot(test_x[:100,0].detach(), mean[:100].detach())
    ax[0].plot(test_x[100:,0].detach(), mean[100:].detach())
    ax[0].fill_between(test_x[:100,0].detach(),
                       lower_pfn[:100],
                       upper_pfn[:100],
                       alpha=0.5)
    ax[0].fill_between(test_x[100:,0].detach(),
                       lower_pfn[100:],
                       upper_pfn[100:],
                       alpha=0.5)

    ax[0].scatter(x[:,0], y, c=x[:,1])

    ax[0].set_title('PFN')

    # Casmopolitan: a freshly built model is fit to the same training points.
    cocabo.restart()
    cocabo = BO_ALGOS['Casmopolitan'].build_bo(search_space=search_space, n_init=n_init, device=torch.device("cpu"))
    _ = cocabo.model.fit(x, torch.from_numpy(y))
    mean_cocabo, variance_cocabo = cocabo.model.predict(test_x)
    lower_cocabo = (mean_cocabo - 1.96 * torch.sqrt(variance_cocabo)).detach().flatten()
    upper_cocabo = (mean_cocabo + 1.96 * torch.sqrt(variance_cocabo)).detach().flatten()

    ax[1].plot(test_x[:100, 0].detach(), mean_cocabo[:100].detach().flatten())
    ax[1].plot(test_x[100:, 0].detach(), mean_cocabo[100:].detach().flatten())
    ax[1].fill_between(test_x[:100, 0].detach(),
                       lower_cocabo[:100],
                       upper_cocabo[:100],
                       alpha=0.5)
    ax[1].fill_between(test_x[100:, 0].detach(),
                       lower_cocabo[100:],
                       upper_cocabo[100:],
                       alpha=0.5)

    ax[1].plot(test_x[:100,0].detach(), ground_truth[:100], c='black')
    ax[1].plot(test_x[100:,0].detach(), ground_truth[100:], c='black')
    ax[1].scatter(x[:,0], y, c=x[:,1])
    ax[1].set_title('Casmopolitan')

    # Now we calculate the divergence per test point. All four tensors are
    # flattened so that, e.g., a (200, 1) GP output and a (200,) PFN output are
    # not silently broadcast to (200, 200).
    # NOTE(review): `mean` was mapped back with fit_y_to_y, but `var` is still on
    # the PFN's fitting scale; confirm whether it must be rescaled as well.
    print(divergence(mean_cocabo.flatten(), variance_cocabo.flatten(), mean.flatten(), var.flatten()))
    print(logits.shape)