In [1]:
from jax import random, jit, vmap
import os
# Move the working directory up to the repository root (the directory named
# "ABC-SBI"); the project-local `functions` imports that follow are
# presumably resolved relative to it.
REPO_ROOT_NAME = "ABC-SBI"
path = os.getcwd()
print("Old path:", path)
# Split on the platform separator rather than a hard-coded '/'.
path = path.split(os.sep)
if REPO_ROOT_NAME not in path:
    # Fail with an explicit message instead of the bare list.index() error.
    raise ValueError(
        "Cannot find a '{}' directory in the current working directory: {}".format(
            REPO_ROOT_NAME, os.getcwd()
        )
    )
# Keep every component up to and including the repository root.
path = os.sep.join(path[:path.index(REPO_ROOT_NAME) + 1])
print("New path:", path)
os.chdir(path)
from functions.simulation import get_dataset, get_epsilon_star, get_newdataset
from functions.training import train_loop
from functions.SBC import SBC_epsilon, plot_SBC
import jax.numpy as jnp
import time
import pickle 
import lzma
from jax.scipy.stats import norm


@jit
def prior_simulator(key):
    """Draw one parameter value from the Gaussian prior.

    A standard normal sample of shape ``(1,)`` is generated from ``key``,
    scaled by the module-level ``SIGMA0`` and shifted by ``MU0``.

    NOTE(review): the function is jitted and reads ``SIGMA0`` / ``MU0`` as
    globals; they are presumably captured when the function is first traced,
    so changing them afterwards may not affect later calls -- confirm.
    """
    standard_draw = random.normal(key, (1,))
    return standard_draw * SIGMA0 + MU0

@jit
def data_simulator(key, theta):
    """Simulate a dataset of ``N_DATA`` Gaussian observations centred on ``theta``.

    Standard normal noise of shape ``(N_DATA,)`` is generated from ``key``,
    scaled by the module-level ``SIGMA``, shifted by ``theta`` and cast with
    ``astype(float)``.

    NOTE(review): the function is jitted and reads ``N_DATA`` / ``SIGMA`` as
    globals; they are presumably captured when the function is first traced
    -- confirm before changing them between calls.
    """
    scaled_noise = random.normal(key, (N_DATA,)) * SIGMA
    observations = scaled_noise + theta
    return observations.astype(float)

@jit
def discrepancy(y, y_true):
    return (jnp.mean(y) - jnp.mean(y_true))**2


Old path: /Users/antoineluciano/Documents/Recherche/ABC-SBI/examples/Gauss-Gauss
New path: /Users/antoineluciano/Documents/Recherche/ABC-SBI


In [None]:
# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------

# Root PRNG key; it is split / threaded through every stochastic step below.
key = random.PRNGKey(0)

# --- Model ---
MU0 = 0.0                 # location of the Gaussian prior (see prior_simulator)
SIGMA = 1.0               # scale of the observation noise (see data_simulator)
MODEL_ARGS = [SIGMA]      # stored alongside the results
N_DATA = 100              # number of observations per simulated dataset

# --- Dataset sizes ---
N_POINTS_TRAIN = 1_000_000
N_POINTS_TEST = 100_000
N_POINTS_EPS = 10_000     # passed to get_epsilon_star
sim_args = None

# --- Optimisation / learning-rate schedule (all forwarded to train_loop) ---
N_EPOCHS = 100
LEARNING_RATE = 1e-3
PATIENCE = 7
COOLDOWN = 0
FACTOR = 0.5
RTOL = 1e-4
ACCUMULATION_SIZE = 200
LEARNING_RATE_MIN = 1e-6

# --- Network and batching (all forwarded to train_loop) ---
BATCH_SIZE = 256
NUM_BATCH = 1024
NUM_CLASSES = 2
HIDDEN_SIZE = 256
NUM_LAYERS = 7
WDECAY = 1e-3

# --- Simulation Based Calibration (forwarded to SBC_epsilon) ---
N_GRID_FINAL = 10_000
N_GRID_EXPLO = 1_000
MINN = -50.0
MAXX = 50.0
L = 127
N_SBC = 100 * (L + 1)

# --- Output location (relative to the current working directory) ---
PATH_RESULTS = os.getcwd() + "/examples/Gauss-Gauss/results/"

# Prior scales to sweep over, expressed as multiples of SIGMA.
SIGMAS0 = [10 * SIGMA]

# Main experiment: for every prior scale, true mean and acceptance rate,
# select epsilon, simulate a test set, train the classifier, run Simulation
# Based Calibration, then save the results (pickle + plot).

# Create the output directory up front so that saving at the end of a long
# iteration cannot fail because the directory is missing.
os.makedirs(PATH_RESULTS, exist_ok=True)

# NOTE(review): prior_simulator / data_simulator are jitted and read SIGMA0,
# MU0, SIGMA and N_DATA as globals. With a single entry in SIGMAS0 this is
# harmless; with several entries the compiled functions may keep using the
# first SIGMA0 -- confirm before extending the sweep.
for SIGMA0 in SIGMAS0:
    TRUE_MUS = [.1*SIGMA0, .5*SIGMA0, SIGMA0, 1.5*SIGMA0]
    PRIOR_ARGS = [MU0, SIGMA0]
    PRIOR_LOGPDF = lambda x: norm.logpdf(x, loc = MU0, scale = SIGMA0)

    for TRUE_MU in TRUE_MUS:
        # Reset the tolerance for each observed dataset; inside the inner loop
        # the previous EPSILON_STAR is fed back into get_epsilon_star.
        EPSILON_STAR = jnp.inf
        key, subkey = random.split(key)
        TRUE_DATA = data_simulator(subkey, TRUE_MU)
        for ACCEPT_RATE in [1., .999, .99, .975, .95, .925]:
            print("\n\n--------------------")
            print("SIGMA0 = {}, TRUE_MU = {}, ACCEPT_RATE = {}".format(SIGMA0, TRUE_MU, ACCEPT_RATE))
            print("--------------------\n\n")

            # time_eps also serves as the start time of the whole iteration.
            time_eps = time.time()
            print("Selection of epsilon star...")
            EPSILON_STAR, key = get_epsilon_star(
                key, ACCEPT_RATE, N_POINTS_EPS,
                prior_simulator, data_simulator, discrepancy, TRUE_DATA,
                quantile_rate = .99, epsilon = EPSILON_STAR,
            )
            print('Time to select epsilon star: {:.2f}s\n'.format(time.time()-time_eps))

            print("Simulations of the testing dataset...")
            time_sim = time.time()
            X_test, y_test, key = get_newdataset(
                key, N_POINTS_TEST,
                prior_simulator, data_simulator, discrepancy,
                EPSILON_STAR, TRUE_DATA,
            )
            print('Time to simulate the testing dataset: {:.2f}s\n'.format(time.time()-time_sim))

            # No training set is pre-simulated: X_train / y_train are passed as
            # None and train_loop receives the simulators instead (presumably
            # it simulates training batches itself -- verify in train_loop).
            print("Training the neural network...")
            time_nn = time.time()
            params, train_accuracy, train_losses, test_accuracy, test_losses, key = train_loop(
                key, N_EPOCHS, NUM_LAYERS, HIDDEN_SIZE, NUM_CLASSES,
                BATCH_SIZE, NUM_BATCH, LEARNING_RATE, WDECAY,
                PATIENCE, COOLDOWN, FACTOR, RTOL, ACCUMULATION_SIZE, LEARNING_RATE_MIN,
                prior_simulator, data_simulator, discrepancy,
                true_data = TRUE_DATA,
                X_train = None, y_train = None,
                X_test = X_test, y_test = y_test,
                N_POINTS_TRAIN = N_POINTS_TRAIN, N_POINTS_TEST = N_POINTS_TEST,
                epsilon = EPSILON_STAR, verbose = True,
            )
            print('Time to train the neural network: {:.2f}s\n'.format(time.time()-time_nn))

            print("Simulation Based Calibration...")
            time_sbc = time.time()

            ranks, thetas_tilde, thetas, key = SBC_epsilon(
                key = key, N_SBC = N_SBC, L = L, params = params,
                epsilon = EPSILON_STAR, true_data = TRUE_DATA,
                prior_simulator = prior_simulator, prior_logpdf = PRIOR_LOGPDF,
                data_simulator = data_simulator, discrepancy = discrepancy,
                n_grid_explo = N_GRID_EXPLO, n_grid_final = N_GRID_FINAL,
                minn = MINN, maxx = MAXX,
            )

            print('Time to perform SBC: {:.2f}s\n'.format(time.time()-time_sbc))

            # Everything needed to reproduce / post-process this run.
            pickle_dico = {
                "ranks": ranks, "thetas_tilde": thetas_tilde, "thetas": thetas,
                "epsilon": EPSILON_STAR, "KEY": key,
                "N_SBC": N_SBC, "L": L,
                "N_GRID_EXPLO": N_GRID_EXPLO, 'N_GRID_FINAL': N_GRID_FINAL,
                "TRUE_DATA": TRUE_DATA, "TRUE_THETA": TRUE_MU,
                "params": params,
                "train_accuracy": train_accuracy, "test_accuracy": test_accuracy,
                "MODEL_ARGS": MODEL_ARGS, "PRIOR_ARGS": PRIOR_ARGS,
                "N_POINTS_TRAIN": N_POINTS_TRAIN, "N_POINTS_TEST": N_POINTS_TEST,
                "N_DATA": N_DATA, "N_EPOCHS": N_EPOCHS,
                "LEARNING_RATE": LEARNING_RATE, "PATIENCE": PATIENCE,
                "COOLDOWN": COOLDOWN, "FACTOR": FACTOR, "RTOL": RTOL,
                "ACCUMULATION_SIZE": ACCUMULATION_SIZE,
                "LEARNING_RATE_MIN": LEARNING_RATE_MIN,
                "BATCH_SIZE": BATCH_SIZE, "NUM_BATCH": NUM_BATCH,
                "NUM_CLASSES": NUM_CLASSES, "HIDDEN_SIZE": HIDDEN_SIZE,
                "NUM_LAYERS": NUM_LAYERS, "WDECAY": WDECAY,
            }

            # Fill the file-name template with the run parameters. Without the
            # .format() call every iteration would write to the same literal
            # "..._sigma_{}_sigma0_{}..." name and overwrite the previous run.
            NAMEFILE = PATH_RESULTS + "GaussGauss_sigma_{}_sigma0_{}_mu_{}_acc_{:.3}_eps_{:.5}".format(
                SIGMA, SIGMA0, TRUE_MU, ACCEPT_RATE, float(EPSILON_STAR)
            )
            with lzma.open(NAMEFILE+".xy", "wb") as f:
                pickle.dump(pickle_dico, f)
            print("Data saved in ", NAMEFILE+".xy")

            title = "Normal w/ known std\nsigma = {}, sigma0 = {} mu = {}\nalpha = {:.2%}, eps = {:.3} accuracy = {:.2%}".format(SIGMA, SIGMA0, TRUE_MU, ACCEPT_RATE, EPSILON_STAR, test_accuracy[-1])

            plot_SBC(ranks, L, B = 16, title = title, save_name = NAMEFILE+".png")
            print("Plot saved in ",NAMEFILE+".png")

            print("\n\n--------------------")
            print("ITERATION (ACC = {}) DONE IN {} SECONDS!".format(ACCEPT_RATE, time.time()-time_eps))
            print("--------------------\n\n")



--------------------
SIGMA0 = 10.0, TRUE_MU = 1.0, ACCEPT_RATE = 1.0
--------------------


Selection of epsilon star...
Distances: min =  3.2896725e-08 max =  1381.2203 mean =  101.74406 std =  141.39943
Time to select epsilon star: 0.64s

Simulations of the testing dataset...
Time to simulate the testing dataset: 0.90s

Training the neural network...
Initial accuracy: 50.34%, Initial test accuracy: 50.34%
Training for 100 epochs...
Epoch 1/100, mean train accuracy: 90.42%, mean test accuracy: 97.35%, lr scale: 1.0 in 8.64 sec
Epoch 2/100, mean train accuracy: 95.11%, mean test accuracy: 97.47%, lr scale: 1.0 in 7.88 sec
Epoch 3/100, mean train accuracy: 95.80%, mean test accuracy: 94.71%, lr scale: 1.0 in 8.18 sec
Epoch 4/100, mean train accuracy: 96.12%, mean test accuracy: 96.20%, lr scale: 1.0 in 7.93 sec
Epoch 5/100, mean train accuracy: 97.30%, mean test accuracy: 98.26%, lr scale: 0.5 in 7.91 sec
Epoch 6/100, mean train accuracy: 97.45%, mean test accuracy: 98.08%, lr scale: 

  0%|          | 0/12800 [00:00<?, ?it/s]