In [1]:
import math
import os
import pickle as pkl
import warnings

import matplotlib.pyplot as plt
import numpy as np
from scipy.stats import gaussian_kde

In [2]:

# Summarise per-seed results (the scalar stored in each run's kl.txt,
# presumably a KL-divergence estimate) as a LaTeX table: one row per number
# of observations n, one column per simulation budget N(n).
# Each cell reads "mean/median (std) [k]", aggregated over the k seeds whose
# kl.txt exists and holds a finite float; "-" means no valid seed.

base_dir = "../res/gnk/"

# Numbers of observations n (one table row each).
observations = [100, 500, 1000, 5000]

# Simulation budgets N as functions of n (one table column each).
simulation_functions = [
    lambda n: n,
    lambda n: int(n * math.log(n)),
    lambda n: int(n ** (3/2)),
    lambda n: n ** 2
]

# Column headers, in the same order as simulation_functions. They are in math
# mode because ^ and \log are errors in LaTeX text mode.
simulation_labels = ["$N=n$", "$N=n\\log n$", "$N=n^{3/2}$", "$N=n^2$"]
assert len(simulation_labels) == len(simulation_functions), "one label per budget"

# Seeds 0 .. num_seeds-1 are looked up for every (n, N) configuration.
num_seeds = 21

# Degrees of freedom for the std column: 0 = population std (numpy's default,
# which the table has used so far), 1 = sample std.
std_ddof = 0


def load_kl_values(base_dir, n_obs, n_sims, num_seeds):
    """Read the scalar from kl.txt of every seed run for one (n_obs, n_sims) config.

    Looks in <base_dir>/npe_n_obs_<n_obs>_n_sims_<n_sims>_seed_<seed>/kl.txt for
    seed in range(num_seeds).

    Returns:
        list[float]: the parsed values. These may include nan/inf if that is
        what a run wrote.

    Missing files are skipped silently. A file that does not parse as a float
    is skipped with a warning, so it does not vanish from the seed count unnoticed.
    """
    kl_values = []
    for seed in range(num_seeds):
        run_dir = os.path.join(base_dir, f"npe_n_obs_{n_obs}_n_sims_{n_sims}_seed_{seed}")
        kl_path = os.path.join(run_dir, "kl.txt")
        if not os.path.isfile(kl_path):
            continue
        with open(kl_path, "r") as kl_file:
            text = kl_file.read().strip()
        try:
            kl_values.append(float(text))
        except ValueError:
            warnings.warn(f"Skipping unparsable KL file {kl_path!r}: {text!r}")
    return kl_values


def format_kl_cell(kl_values, ddof=0):
    """Format one table cell as 'mean/median (std) [k]' over the finite values.

    Args:
        kl_values: iterable of floats. Non-finite entries (nan/inf) are ignored.
        ddof: delta degrees of freedom passed to numpy's std.

    Returns:
        str: the formatted cell, or '-' when no finite value remains. The std is
        reported as 0 when only one finite value exists.
    """
    values = np.asarray(kl_values, dtype=float)
    finite = values[np.isfinite(values)]
    if finite.size == 0:
        return "-"
    mean_kl = finite.mean()
    median_kl = np.median(finite)
    # With ddof=1 a single value has an undefined std, so report 0 instead.
    std_dev_kl = finite.std(ddof=ddof) if finite.size > 1 else 0.0
    return f"{mean_kl:.2f}/{median_kl:.2f} ({std_dev_kl:.2f}) [{finite.size}]"


# LaTeX table header: one centred column for n, then one per simulation budget.
column_spec = "|" + "c|" * (len(simulation_functions) + 1)
latex_table = f"\\begin{{tabular}}{{{column_spec}}}\n\\hline\n"
latex_table += " & ".join(["$n$ (num obs)"] + simulation_labels) + " \\\\ \n\\hline\n"

# One row per observation level, one cell per simulation budget.
for n_obs in observations:
    row_cells = [str(n_obs)]
    for sim_func in simulation_functions:
        n_sims = sim_func(n_obs)
        kl_values = load_kl_values(base_dir, n_obs, n_sims, num_seeds)
        row_cells.append(format_kl_cell(kl_values, ddof=std_ddof))
    latex_table += " & ".join(row_cells) + " \\\\ \n\\hline\n"

# Close the table
latex_table += "\\end{tabular}"

# print (rather than display) so the backslashes come out as raw, copyable LaTeX
print(latex_table)


\begin{tabular}{|c|c|c|c|c|}
\hline
n (num obs) & N=n & N=nlog(n) & N=n^(3/2) & N=n^2 \\ 
\hline
100  & 9.48/10.29 (1.61) [5] & 8.48/9.23 (1.57) [5] & 6.44/6.29 (0.85) [5] & 4.48/4.28 (0.60) [5] \\ 
\hline
500  & 9.64/9.41 (0.92) [21] & 6.78/6.85 (0.74) [21] & 4.89/4.89 (0.67) [21] & 2.90/2.91 (0.45) [21] \\ 
\hline
1000  & 9.21/8.94 (1.12) [21] & 6.26/6.16 (0.85) [21] & 4.37/4.29 (0.46) [21] & 2.78/2.73 (0.71) [21] \\ 
\hline
5000  & 8.95/8.68 (1.13) [21] & 5.66/5.61 (0.55) [21] & 3.97/3.93 (0.43) [21] & 2.47/2.36 (0.50) [20] \\ 
\hline
\end{tabular}
