# Conditional Gaussian on a sphere

## Setup

In [1]:
%matplotlib inline

import sys
import numpy as np
import matplotlib
from matplotlib import pyplot as plt
import torch
from torch.utils.data import TensorDataset
import logging
from mpl_toolkits.mplot3d import Axes3D

sys.path.append("../../")

logging.basicConfig(
    format="%(asctime)-5.5s %(name)-30.30s %(levelname)-7.7s %(message)s",
    datefmt="%H:%M",
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
# Output of all other modules (e.g. matplotlib)
for key in logging.Logger.manager.loggerDict:
    if "manifold_flow" not in key:
        logging.getLogger(key).setLevel(logging.WARNING)


## Load results

In [2]:
setup_filenames = [
    "2_3_0.010",
    "2_3_0.001",
    "2_3_0.100",
]
setup_labels = [
    "0.01",
    "0.001",
    "0.1",
]

In [3]:
algo_filenames = []
algo_additionals = []
algo_labels = []

def add_algo(filename, add, label):
    algo_filenames.append(filename)
    algo_additionals.append(add)
    algo_labels.append(label)
    
    
add_algo("flow", "_small", "Flow")
add_algo("flow", "_small_long", "Flow (long)")
add_algo("flow", "_small_shallow_long", "Flow (shallow, long)")

add_algo("pie", "_small", "PIE") 
add_algo("pie", "_small_long", "PIE (long)") 
add_algo("pie", "_small_shallow_long", "PIE (shallow, long)") 

add_algo("mf", "_small", "MAD AF")
add_algo("mf", "_small_noprepost", "MAD AF (no pre/post)")
add_algo("mf", "_small_complex", "MAD AF (complex)") 
add_algo("mf", "_small_long", "MAD AF (long)")
add_algo("mf", "_small_shallow_long", "MAD AF (shallow, long)")

add_algo("emf", "_small", "MAD AF + Enc.")

add_algo("gamf", "_small_largebs", "OT MAD AF")
add_algo("gamf", "_small_hugebsbs", "OT MAD AF (5k batchsize)") 
add_algo("gamf", "_small_largebs_long", "OT MAD AF (long)") 
add_algo("gamf", "_small_largebs_shallow_long", "OT MAD AF (shallow, long)") 

add_algo("pie_specified", "_small", "Prescr. PIE") 
add_algo("pie_specified", "_small_long", "Prescr. PIE (long)") 
add_algo("pie_specified", "_small_shallow_long", "Prescr. PIE (shallow, long)") 

add_algo("mf_specified", "_small", "Prescr. MAD AF") 
add_algo("mf_specified", "_small_long", "Prescr. MAD AF (long)") 
add_algo("mf_specified", "_small_shallow_long", "Prescr. MAD AF (shallow, long)")

add_algo("gamf_specified", "_small_largebs", "Prescr. OT MAD AF")
add_algo("gamf_specified", "_small_hugebs", "Prescr. OT MAD AF (5k batchsize)") 
add_algo("gamf_specified", "_small_largebs_long", "Prescr. OT MAD AF (long)") 
add_algo("gamf_specified", "_small_largebs_shallow_long", "Prescr. OT MAD AF (shallow, long)") 


In [4]:
def load(quantity, shape, numpyfy=True, result_dir="../data/results"):
    all_results = []
    
    for algo_filename, algo_add in zip(algo_filenames, algo_additionals):
        results = []
        
        for setup_filename in setup_filenames:
            try:
                results.append(np.load(
                    "{}/{}_2_conditional_spherical_gaussian_{}{}_{}.npy".format(
                        result_dir, algo_filename, setup_filename, algo_add, quantity
                    )
                ))
            except FileNotFoundError as e:
                print(e)
                results.append(np.nan*np.ones(shape))
            
        all_results.append(results)
    
    return np.asarray(all_results) if numpyfy else all_results

true_posterior_samples = load("true_posterior_samples", (1000,2))
model_posterior_samples = load("model_posterior_samples", (1000, 2))
mmds = load("mmd", (1,))

true_test_log_likelihood = load("true_log_likelihood_test", (11*11, 1000,))
model_test_log_likelihood = load("model_log_likelihood_test", (11*11, 1000,))
model_test_reco_error = load("model_reco_error_test", (1000,))
parameter_grid = load("parameter_grid_test", (11*11,2))

model_test_log_likelihood.shape

[Errno 2] No such file or directory: '../data/results/mf_2_conditional_spherical_gaussian_2_3_0.010_small_long_true_posterior_samples.npy'
[Errno 2] No such file or directory: '../data/results/mf_2_conditional_spherical_gaussian_2_3_0.001_small_long_true_posterior_samples.npy'
[Errno 2] No such file or directory: '../data/results/mf_2_conditional_spherical_gaussian_2_3_0.100_small_long_true_posterior_samples.npy'
[Errno 2] No such file or directory: '../data/results/mf_2_conditional_spherical_gaussian_2_3_0.001_small_shallow_long_true_posterior_samples.npy'
[Errno 2] No such file or directory: '../data/results/mf_2_conditional_spherical_gaussian_2_3_0.100_small_shallow_long_true_posterior_samples.npy'
[Errno 2] No such file or directory: '../data/results/emf_2_conditional_spherical_gaussian_2_3_0.010_small_true_posterior_samples.npy'
[Errno 2] No such file or directory: '../data/results/emf_2_conditional_spherical_gaussian_2_3_0.001_small_true_posterior_samples.npy'
[Errno 2] No such f

[Errno 2] No such file or directory: '../data/results/gamf_2_conditional_spherical_gaussian_2_3_0.100_small_hugebsbs_true_log_likelihood_test.npy'
[Errno 2] No such file or directory: '../data/results/gamf_2_conditional_spherical_gaussian_2_3_0.010_small_largebs_long_true_log_likelihood_test.npy'
[Errno 2] No such file or directory: '../data/results/gamf_2_conditional_spherical_gaussian_2_3_0.001_small_largebs_long_true_log_likelihood_test.npy'
[Errno 2] No such file or directory: '../data/results/gamf_2_conditional_spherical_gaussian_2_3_0.100_small_largebs_long_true_log_likelihood_test.npy'
[Errno 2] No such file or directory: '../data/results/gamf_2_conditional_spherical_gaussian_2_3_0.010_small_largebs_shallow_long_true_log_likelihood_test.npy'
[Errno 2] No such file or directory: '../data/results/gamf_2_conditional_spherical_gaussian_2_3_0.001_small_largebs_shallow_long_true_log_likelihood_test.npy'
[Errno 2] No such file or directory: '../data/results/gamf_2_conditional_spherical

(26, 3, 121, 1000)

In [5]:
x_test = np.asarray([
    np.load("../data/samples/conditional_spherical_gaussian/conditional_spherical_gaussian_{}_x_test.npy".format(setup_filename))
    for setup_filename in setup_filenames
])

true_distances = np.abs(np.sum(x_test**2, axis=-1)**0.5 - 1.)

In [6]:
true_expected_nll = -2. * np.mean(true_test_log_likelihood, axis=-1)
model_expected_nll = -2. * np.mean(model_test_log_likelihood, axis=-1)

n_observed = 20
true_observed_nll = -2. * np.sum(true_test_log_likelihood[:,:,:,:n_observed], axis=-1)
model_observed_nll = -2. * np.sum(model_test_log_likelihood[:,:,:,:n_observed], axis=-1)

## Print metrics

In [7]:
def print_results(setup):
    print("{:<36.36s} | {:>8.8s}".format(setup_labels[setup], "MMD"))
    print("-"*47)
    for label, mmd in zip(algo_labels, mmds):
        if np.isfinite(logp[setup]) or np.isfinite(dist[setup]):
            print("{:<36.36s} | {:>8.2f}".format(label, mmd[setup]))
        else:
            print("{:<36.36s} |         ".format(label))

In [8]:
def print_all_results():
    print(
        "{:<36.36s} | {:>8.8s} | {:>8.8s} | {:>8.8s}".format(
        "epsilon", setup_labels[0], setup_labels[1], setup_labels[2]
        )
    )
    print(
        "{:<36.36s} | {:>8.8s} | {:>8.8s} | {:>8.8s}".format(
        "", "MMD", "MMD", "MMD"
        )
    )
    print("-"*69)
    for label, mmd in zip(
        algo_labels,
        mmds
    ):
        def _f(val):
            return "{:>8.3f}".format(val) if np.isfinite(val) else "        "
        
        print(
            "{:<36.36s} | {} | {} | {}".format(
                label, _f(mmd[0]), _f(mmd[1]), _f(mmd[2])
            )
        )
        

In [9]:
print_all_results()

epsilon                              |     0.01 |    0.001 |      0.1
                                     |      MMD |      MMD |      MMD
---------------------------------------------------------------------
Flow                                 |    0.391 |    0.285 |    0.099
Flow (long)                          |    1.262 |    1.070 |    0.359
Flow (shallow, long)                 |    0.456 |    0.220 |    0.038
PIE                                  |    0.321 |    0.008 |    0.149
PIE (long)                           |    0.934 |    0.959 |    0.258
PIE (shallow, long)                  |    1.066 |    0.861 |    0.091
MAD AF                               |    0.367 |    0.009 |    0.012
MAD AF (no pre/post)                 |    0.796 |    1.218 |    0.216
MAD AF (complex)                     |    1.354 |    0.689 |    0.646
MAD AF (long)                        |          |          |         
MAD AF (shallow, long)               |    1.330 |          |         
MAD AF + Enc.       

In [None]:
print_results(1)

In [None]:
print_results(2)