# Spherical Gaussian experiment: generative results

## Setup

In [1]:
import sys
import numpy as np
import torch
from torch.utils.data import TensorDataset
import logging
from sklearn.metrics import roc_auc_score

sys.path.append("../../")

# Root logging configuration: DEBUG level, compact "HH:MM name level message" lines.
logging.basicConfig(
    level=logging.DEBUG,
    datefmt="%H:%M",
    format="%(asctime)-5.5s %(name)-30.30s %(levelname)-7.7s %(message)s",
)
logger = logging.getLogger(__name__)

# Cap every already-registered logger whose name does not contain "manifold_flow"
# (e.g. matplotlib) at WARNING. Note that `logger` above is registered before this
# loop runs, so it is capped as well unless __name__ contains "manifold_flow".
for registered_name in logging.Logger.manager.loggerDict:
    if "manifold_flow" in registered_name:
        continue
    logging.getLogger(registered_name).setLevel(logging.WARNING)


## Load results

In [2]:
# One (result-file tag, display label) pair per experimental setup, in display order.
# The two parallel lists are indexed by the same setup index everywhere below.
setup_filenames, setup_labels = (
    list(column)
    for column in zip(
        ("2_3_0.010", "ε = 0.01"),
        ("2_3_0.001", "ε = 0.001"),
        ("2_3_0.100", "ε = 0.1"),
    )
)

In [3]:
# Three parallel lists describing every model variant; index i refers to the same variant in each.
algo_filenames, algo_additionals, algo_labels = [], [], []


def add_algo(filename, add, label):
    """Register one model variant.

    Args:
        filename: Algorithm prefix used at the start of the result filenames.
        add: Run-specific suffix that follows the setup tag in the result filenames.
        label: Human-readable name used in the printed tables.
    """
    for registry, value in zip(
        (algo_filenames, algo_additionals, algo_labels),
        (filename, add, label),
    ):
        registry.append(value)


# (filename prefix, filename suffix, label), grouped by algorithm family.
for algo_spec in (
    ("flow", "_small", "Flow"),
    ("flow", "_small_long", "Flow (long)"),
    ("flow", "_small_shallow_long", "Flow (shallow, long)"),

    ("pie", "_small", "PIE"),
    ("pie", "_small_long", "PIE (long)"),
    ("pie", "_small_shallow_long", "PIE (shallow, long)"),

    ("mf", "_small", "MAD AF"),
    ("mf", "_small_noprepost", "MAD AF (no pre/post)"),
    ("mf", "_small_complex", "MAD AF (complex)"),
    ("mf", "_small_long", "MAD AF (long)"),
    ("mf", "_small_shallow_long", "MAD AF (shallow, long)"),

    ("gamf", "_small_largebs", "OT MAD AF (Sinkhorn)"),
    ("gamf", "_small_ged_largebs", "OT MAD AF (GED)"),
    ("gamf", "_small_ged_largebs_long", "OT MAD AF (GED, long)"),
    ("gamf", "_small_ged_largebs_shallow_long", "OT MAD AF (GED, shallow, long)"),

    ("pie_specified", "_small", "Prescr. PIE"),
    ("pie_specified", "_small_long", "Prescr. PIE (long)"),
    ("pie_specified", "_small_shallow_long", "Prescr. PIE (shallow, long)"),

    ("mf_specified", "_small", "Prescr. MAD AF"),
    ("mf_specified", "_small_long", "Prescr. MAD AF (long)"),
    ("mf_specified", "_small_shallow_long", "Prescr. MAD AF (shallow, long)"),

    ("gamf_specified", "_small_largebs", "Prescr. OT MAD AF"),
    ("gamf_specified", "_small_long", "Prescr. OT MAD AF (long)"),
    ("gamf_specified", "_small_shallow_long", "Prescr. OT MAD AF (shallow, long)"),
):
    add_algo(*algo_spec)



In [4]:
def load(name, shape, numpyfy=True, result_dir="../data/results"):
    """Collect one saved result quantity for every (algorithm, setup) combination.

    Args:
        name: Trailing part of the result filename that identifies the quantity.
        shape: Shape of the NaN-filled placeholder substituted when a result file
            does not exist. If None, a missing result is represented by None.
        numpyfy: If True, pass the nested list through np.asarray before returning.
        result_dir: Directory that contains the .npy result files.

    Returns:
        Results indexed [algorithm][setup], in the order of the module-level
        algo_filenames / algo_additionals and setup_filenames lists; an array if
        numpyfy is True, otherwise a list of lists.
    """

    def load_one(algo_filename, algo_add, setup_filename):
        path = "{}/{}_2_spherical_gaussian_{}{}_{}.npy".format(
            result_dir, algo_filename, setup_filename, algo_add, name
        )
        try:
            return np.load(path)
        except FileNotFoundError:
            # A missing file is not an error here: substitute a placeholder.
            return None if shape is None else np.full(shape, np.nan)

    all_results = [
        [load_one(algo_filename, algo_add, setup_filename) for setup_filename in setup_filenames]
        for algo_filename, algo_add in zip(algo_filenames, algo_additionals)
    ]

    if numpyfy:
        return np.asarray(all_results)
    return all_results

# Quantities stored per generated sample. x_gen stays a nested list (missing runs -> None);
# the others become arrays whose missing runs are NaN placeholders with 10000 entries.
x_gen = load("samples", shape=None, numpyfy=False)
logp_gen = load("samples_likelihood", shape=(10000,))
distance_gen = load("samples_manifold_distance", shape=(10000,))

# Quantities stored for the test set (NaN placeholders with 1000 entries for missing runs).
true_test_log_likelihood = load("true_log_likelihood_test", shape=(1000,))
model_test_log_likelihood = load("model_log_likelihood_test", shape=(1000,))
model_test_reco_error = load("model_reco_error_test", shape=(1000,))

# The same model quantities stored for the "ood" result files.
model_ood_log_likelihood = load("model_log_likelihood_ood", shape=(1000,))
model_ood_reco_error = load("model_reco_error_ood", shape=(1000,))


In [5]:
# x_test file of every setup, stacked along a leading setup axis.
x_test = np.asarray(
    [
        np.load(
            "../data/samples/spherical_gaussian/spherical_gaussian_{}_x_test.npy".format(setup_tag)
        )
        for setup_tag in setup_filenames
    ]
)

# Distance of each test point from the unit sphere: | ||x|| - 1 |.
true_distances = np.abs(np.sum(np.square(x_test), axis=-1)**0.5 - 1.)

## Calculate metrics

In [6]:
# Clipping bounds applied before averaging (presumably so that a few extreme
# samples do not dominate the means).
min_logp = -100.
max_distance = 10.

# Model samples: average over the sample axis -> one value per (algorithm, setup).
mean_logp_gen = logp_gen.clip(min=min_logp).mean(axis=2)
mean_distance_gen = distance_gen.clip(max=max_distance).mean(axis=2)

# Simulator reference: one value per setup. The true log-likelihood is read from the
# result files of the first algorithm (index 0), so it is NaN if those files are missing.
mean_logp_truth = true_test_log_likelihood[0].clip(min=min_logp).mean(axis=1)
mean_distance_truth = true_distances.clip(max=max_distance).mean(axis=1)

In [7]:
def calculate_roc_auc(x0, x1):
    """Compute a direction-agnostic ROC AUC for every pair of score arrays.

    x0 and x1 are iterated in lockstep over two nesting levels; each innermost
    pair of 1-D score arrays is turned into a binary classification problem
    (x0 entries labelled 0, x1 entries labelled 1).

    Args:
        x0: Two-level nested sequence (or array) of 1-D score arrays for class 0.
        x1: Same layout as x0, holding the scores for class 1.

    Returns:
        2-D array with one entry per innermost pair: the larger of the AUC of the
        scores and the AUC of the negated scores, so the result does not depend on
        whether large or small scores indicate class 1. Entries are NaN when any
        score in the pair is non-finite.
    """

    def _symmetric_auc(this_x0, this_x1):
        # Scores of both classes in one vector; roc_auc_score cannot handle NaN/inf.
        scores = np.hstack((this_x0, this_x1))
        if not np.all(np.isfinite(scores)):
            return np.nan

        # Built-in int instead of the np.int alias, which was removed in NumPy 1.24.
        labels = np.hstack(
            (
                np.zeros(this_x0.shape[0], dtype=int),
                np.ones(this_x1.shape[0], dtype=int),
            )
        )
        return np.maximum(
            roc_auc_score(labels, scores),
            roc_auc_score(labels, -scores),
        )

    return np.asarray(
        [
            [_symmetric_auc(this_x0, this_x1) for this_x0, this_x1 in zip(x0_, x1_)]
            for x0_, x1_ in zip(x0, x1)
        ]
    )


# AUC for separating the test results from the "ood" results, once scored by the
# model log-likelihood and once by the reconstruction error.
auc_logp = calculate_roc_auc(x0=model_test_log_likelihood, x1=model_ood_log_likelihood)
auc_err = calculate_roc_auc(x0=model_test_reco_error, x1=model_ood_reco_error)

# Per entry: keep the better of the two scores and record whether it was the reco error.
auc_best = np.maximum(auc_err, auc_logp)
auc_use_err = np.greater(auc_err, auc_logp)



## Print metrics

In [12]:
# Rows follow algo_labels, columns follow setup_labels; NaN marks entries for which
# an AUC could not be computed because the inputs were not all finite.
auc_best

array([[0.694833 , 0.707393 , 0.653708 ],
       [      nan,       nan,       nan],
       [      nan,       nan,       nan],
       [0.708953 , 0.672565 , 0.646126 ],
       [      nan,       nan,       nan],
       [      nan,       nan,       nan],
       [0.658694 , 0.519913 , 0.572225 ],
       [0.507252 , 0.513103 ,       nan],
       [0.506155 , 0.5257865, 0.535804 ],
       [      nan,       nan,       nan],
       [      nan,       nan,       nan],
       [0.510948 , 0.51055  , 0.630187 ],
       [      nan,       nan,       nan],
       [      nan,       nan,       nan],
       [      nan,       nan,       nan],
       [0.73033  , 0.727107 , 0.734296 ],
       [      nan,       nan,       nan],
       [      nan,       nan,       nan],
       [0.730332 , 0.727107 , 0.734296 ],
       [      nan,       nan,       nan],
       [      nan,       nan,       nan],
       [0.5101445, 0.507926 , 0.682529 ],
       [      nan,       nan,       nan],
       [      nan,       nan,     

In [23]:
def print_results(setup):
    """Print a fixed-width table of the metrics for one setup.

    The table has a "Simulator" reference row followed by one row per entry of
    algo_labels, with the columns mean log-likelihood, mean distance and AUC.
    Rows for which neither the log-likelihood nor the distance is finite are
    printed with the label only.

    Args:
        setup: Index into setup_labels and into the setup axis of the metric arrays.
    """
    # No AUC is computed for the simulator row; it shows a 0.0 placeholder. The
    # placeholder is sized from setup_labels so that any configured setup index works.
    simulator_auc = [0.0] * len(setup_labels)

    print("{:<36.36s} | {:>8.8s} {:>8.8s} {:>8.8s}".format(setup_labels[setup], "Log p", "Dist", "AUC"))
    print("-"*65)
    for label, logp, dist, auc in zip(
        ["Simulator"] + algo_labels,
        [mean_logp_truth] + list(mean_logp_gen),
        [mean_distance_truth] + list(mean_distance_gen),
        [simulator_auc] + list(auc_best)
    ):
        if np.isfinite(logp[setup]) or np.isfinite(dist[setup]):
            print("{:<36.36s} | {:>8.2f} {:>8.3f} {:>8.3f}".format(label, logp[setup], dist[setup], auc[setup]))
        else:
            # Neither metric is available for this row: print the label with empty columns.
            print("{:<36.36s} |                           ".format(label))

In [24]:
# Table for setup index 0 (header shows setup_labels[0]).
print_results(0)

ε = 0.01                             |    Log p     Dist      AUC
-----------------------------------------------------------------
Simulator                            |     2.32    0.008    0.000
Flow                                 |     1.31    0.033    0.695
Flow (long)                          |                           
Flow (shallow, long)                 |                           
PIE                                  |     1.92    0.024    0.709
PIE (long)                           |                           
PIE (shallow, long)                  |                           
MAD AF                               |     1.32    0.032    0.659
MAD AF (no pre/post)                 |   -48.34    0.139    0.507
MAD AF (complex)                     |   -32.26    0.103    0.506
MAD AF (long)                        |                           
MAD AF (shallow, long)               |                           
OT MAD AF (Sinkhorn)                 |    -2.28    0.037    0.511
OT MAD AF 

In [25]:
# Table for setup index 1 (header shows setup_labels[1]).
print_results(1)

ε = 0.001                            |    Log p     Dist      AUC
-----------------------------------------------------------------
Simulator                            |     4.68    0.001    0.000
Flow                                 |     2.98    0.034    0.707
Flow (long)                          |                           
Flow (shallow, long)                 |                           
PIE                                  |     0.36    0.026    0.673
PIE (long)                           |                           
PIE (shallow, long)                  |                           
MAD AF                               |   -95.85    0.159    0.520
MAD AF (no pre/post)                 |   -91.93    0.117    0.513
MAD AF (complex)                     |   -91.64    0.177    0.526
MAD AF (long)                        |                           
MAD AF (shallow, long)               |                           
OT MAD AF (Sinkhorn)                 |   -69.70    0.034    0.511
OT MAD AF 

In [26]:
# Table for setup index 2 (header shows setup_labels[2]).
print_results(2)

ε = 0.1                              |    Log p     Dist      AUC
-----------------------------------------------------------------
Simulator                            |     0.05    0.081    0.000
Flow                                 |    -0.82    0.103    0.654
Flow (long)                          |                           
Flow (shallow, long)                 |                           
PIE                                  |    -0.28    0.096    0.646
PIE (long)                           |                           
PIE (shallow, long)                  |                           
MAD AF                               |    -1.54    0.179    0.572
MAD AF (no pre/post)                 |                           
MAD AF (complex)                     |    -3.26    0.145    0.536
MAD AF (long)                        |                           
MAD AF (shallow, long)               |                           
OT MAD AF (Sinkhorn)                 |     0.05    0.087    0.630
OT MAD AF 