Skip to content

Commit

Permalink
Add recording of runtime data (#8)
Browse files Browse the repository at this point in the history
* Add folder for timing results

* Add code to create, clear, and add to a timing dictionary for running each evidence network
* Add timing recordings to key parts of the three scripts 

* Bump version

* Re-add style files and ignore files that went AWOL
  • Loading branch information
ThomasGesseyJones committed Aug 16, 2023
1 parent a8a8073 commit 3f58236
Show file tree
Hide file tree
Showing 9 changed files with 120 additions and 14 deletions.
4 changes: 2 additions & 2 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ Introduction

:Name: FullyBayesianForecastsExample
:Author: Thomas Gessey-Jones
:Version: 0.0.7
:Version: 0.0.8
:Homepage: https://github.com/ThomasGesseyJones/FullyBayesianForecastsExample
:Paper: TBD

Expand All @@ -18,4 +18,4 @@ Introduction
:target: https://github.com/ThomasGesseyJones/ErrorAffirmations/blob/main/LICENSE
:alt: License information

Example of a fully bayesian forecasts using an evidence network applied to 21-cm cosmology.
Example of a fully bayesian forecasts using an evidence network applied to 21-cm cosmology.
File renamed without changes.
File renamed without changes.
File renamed without changes.
1 change: 1 addition & 0 deletions figures_and_results/timing_data/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
*
81 changes: 79 additions & 2 deletions train_evidence_network.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import yaml
import os
import matplotlib.pyplot as plt
import pickle as pkl
import time


# Parameters
Expand Down Expand Up @@ -72,6 +74,58 @@ def load_configuration_dict() -> dict:
return yaml.safe_load(file)


def timing_filename(noise_sigma: float) -> str:
"""Get the filename for the timing data.
Parameters
----------
noise_sigma : float
The noise sigma in K.
Returns
-------
filename : str
The filename for the timing data.
"""
folder = os.path.join('figures_and_results', 'timing_data')
os.makedirs(folder, exist_ok=True)
return os.path.join(folder, f'en_noise_{noise_sigma:.4f}_timing_data.pkl')


def clear_timing_data(timing_file: str):
"""Clear the timing data file.
Parameters
----------
timing_file : str
The filename for the timing data.
"""
with open(timing_file, 'wb') as file:
pkl.dump({}, file)


def add_timing_data(timing_file: str, entry_name: str, time_s: float):
"""Add timing data to the timing data file.
Parameters
----------
timing_file : str
The filename for the timing data.
entry_name : str
The name of the entry to add.
time_s : float
The time to add in seconds.
"""
if os.path.isfile(timing_file):
with open(timing_file, 'rb') as file:
timing_data = pkl.load(file)
else:
timing_data = {}
timing_data[entry_name] = time_s
with open(timing_file, 'wb') as file:
pkl.dump(timing_data, file)


# Priors
def create_globalemu_prior_samplers(config_dict: dict) -> Collection[Callable]:
"""Create a prior sampler over the globalemu parameters.
Expand Down Expand Up @@ -167,14 +221,21 @@ def main():
# IO
sigma_noise = get_noise_sigma()
config_dict = load_configuration_dict()
timing_file = timing_filename(sigma_noise)

# Set-up simulators
start = time.time()
noise_only_simulator, noisy_signal_simulator = assemble_simulators(
config_dict, sigma_noise)
end = time.time()
add_timing_data(timing_file, 'simulator_assembly', end - start)

# Create and train evidence network
start = time.time()
en = EvidenceNetwork(noise_only_simulator, noisy_signal_simulator)
en.train()
end = time.time()
add_timing_data(timing_file, 'network_training', end - start)

# Save the network
network_folder = os.path.join("models", f'en_noise_{sigma_noise:.4f}')
Expand All @@ -183,11 +244,16 @@ def main():
en.save(network_file)

# Perform blind coverage test
plt.style.use(os.path.join('figures', 'mnras_single.mplstyle'))
start = time.time()
plt.style.use(os.path.join('figures_and_results', 'mnras_single.mplstyle'))
fig, ax = plt.subplots()
_ = en.blind_coverage_test(plotting_ax=ax, num_validation_samples=10_000)
fig.savefig(os.path.join('figures', 'blind_coverage_tests',
figure_folder = os.path.join('figures_and_results', 'blind_coverage_tests')
os.makedirs(figure_folder, exist_ok=True)
fig.savefig(os.path.join(figure_folder,
f'en_noise_{sigma_noise:.4f}_blind_coverage.pdf'))
end = time.time()
add_timing_data(timing_file, 'bct', end - start)

# Verification evaluations for comparison with other methods
verification_ds_per_model = config_dict['verification_data_sets_per_model']
Expand All @@ -198,6 +264,17 @@ def main():
f'noise_{sigma_noise:.4f}_verification_data.npz'),
data=data, labels=labels, log_bayes_ratios=log_bayes_ratios)

# Verification evaluations for comparison with other methods
verification_ds_per_model = config_dict['verification_data_sets_per_model']
data, labels = en.get_simulated_data(verification_ds_per_model)
log_bayes_ratios = en.evaluate_log_bayes_ratio(data)
os.makedirs('verification_data', exist_ok=True)
np.savez(os.path.join('verification_data',
f'noise_{sigma_noise:.4f}_verification_data.npz'),
data=np.squeeze(data),
labels=np.squeeze(labels),
log_bayes_ratios=np.squeeze(log_bayes_ratios))


if __name__ == "__main__":
main()
23 changes: 17 additions & 6 deletions verification_with_polychord.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@
# Required imports
from __future__ import annotations
from typing import Callable, Tuple
from train_evidence_network import get_noise_sigma, load_configuration_dict
from train_evidence_network import get_noise_sigma, load_configuration_dict, \
timing_filename, add_timing_data
from simulators.twenty_one_cm import load_globalemu_emulator, \
global_signal_experiment_measurement_redshifts, GLOBALEMU_INPUTS, \
GLOBALEMU_PARAMETER_RANGES
Expand All @@ -33,6 +34,7 @@
from copy import deepcopy
from scipy.stats import truncnorm
import matplotlib.pyplot as plt
import time

# Parameters
CHAIN_DIR = "chains"
Expand Down Expand Up @@ -217,6 +219,7 @@ def main():
# Get noise sigma and configuration data
sigma_noise = get_noise_sigma()
config_dict = load_configuration_dict()
timing_file = timing_filename(sigma_noise)

# Load verification data
verification_data_file = (
Expand All @@ -243,6 +246,7 @@ def main():
pc_nlike = []

settings = None
start = time.time()
for data in v_data:
# Can find noise only evidence analytically
log_z_noise_only = noise_only_log_evidence(data, sigma_noise)
Expand Down Expand Up @@ -292,23 +296,30 @@ def main():
if rank != 0:
return

# Record timing data
end = time.time()
add_timing_data(timing_file, 'total_polychord_log_k',
end - start)
add_timing_data(timing_file, 'average_polychord_log_k',
(end - start) / v_data.shape[0])

try:
shutil.rmtree(settings.base_dir)
except OSError:
pass

# Save PolyChord log bayes ratios if needed for later comparison
pc_log_bayes_ratios = np.array(pc_log_bayes_ratios)
pc_log_bayes_ratios = np.squeeze(np.array(pc_log_bayes_ratios))
polychord_data_file = (
os.path.join('verification_data',
f'noise_{sigma_noise:.4f}_polychord_log_k.npz'))
np.savez(polychord_data_file, log_bayes_ratios=pc_log_bayes_ratios)

# Create output directory for results of comparison
os.makedirs(os.path.join("figures",
os.makedirs(os.path.join("figures_and_results",
"polychord_verification"), exist_ok=True)
numeric_results_filename = os.path.join(
"figures",
"figures_and_results",
"polychord_verification",
f"polychord_verification_"
f"en_noise_{sigma_noise:.4f}_K_results.txt")
Expand All @@ -334,7 +345,7 @@ def main():
numeric_results_file.close()

# Plot results
plt.style.use(os.path.join('figures', 'mnras_single.mplstyle'))
plt.style.use(os.path.join('figures_and_results', 'mnras_single.mplstyle'))
fig, ax = plt.subplots()
ax.scatter(en_log_bayes_ratios, pc_log_bayes_ratios, c='C0')
min_log_z = np.min([np.min(en_log_bayes_ratios),
Expand All @@ -348,7 +359,7 @@ def main():
# Save figure
fig.tight_layout()
filename = os.path.join(
"figures",
"figures_and_results",
"polychord_verification",
f"polychord_verification_"
f"en_noise_{sigma_noise:.4f}_K.pdf")
Expand Down
25 changes: 21 additions & 4 deletions visualize_forecast.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,14 @@
from typing import Collection
from evidence_networks import EvidenceNetwork
from train_evidence_network import get_noise_sigma, load_configuration_dict, \
assemble_simulators
assemble_simulators, timing_filename, add_timing_data
from matplotlib.ticker import MaxNLocator
import os
import numpy as np
from pandas import DataFrame
import matplotlib.pyplot as plt
from math import erf
import time


# Parameters
Expand Down Expand Up @@ -378,8 +379,10 @@ def main():
# IO
sigma_noise = get_noise_sigma()
config_dict = load_configuration_dict()
timing_file = timing_filename(sigma_noise)

# Set up simulators
start = time.time()
noise_only_simulator, noisy_signal_simulator = assemble_simulators(
config_dict, sigma_noise)

Expand All @@ -388,15 +391,22 @@ def main():
network_folder = os.path.join("models", f'en_noise_{sigma_noise:.4f}')
network_file = os.path.join(network_folder, "global_signal_en.h5")
en.load(network_file)
end = time.time()
add_timing_data(timing_file, 'en_loading', end - start)

# Generate mock data for forecast and evaluate log Bayes ratio
start = time.time()
num_data_sets = config_dict["br_evaluations_for_forecast"]
mock_data_w_signal, signal_params = \
noisy_signal_simulator(num_data_sets)
log_bayes_ratios = en.evaluate_log_bayes_ratio(mock_data_w_signal)
end = time.time()
add_timing_data(timing_file, 'en_fbf_log_k_evaluations',
end - start)

# Set-up plotting style and variables
plt.style.use(os.path.join('figures', 'mnras_single.mplstyle'))
start = time.time()
plt.style.use(os.path.join('figures_and_results', 'mnras_single.mplstyle'))
plt.rcParams.update({'figure.figsize': (3.33, 3.33)})
plt.rcParams.update({'ytick.labelsize': 6})
plt.rcParams.update({'xtick.labelsize': 6})
Expand All @@ -414,7 +424,7 @@ def main():
parameters_to_log = config_dict["parameters_to_log"]

# Plotting
os.makedirs(os.path.join("figures",
os.makedirs(os.path.join("figures_and_results",
"detectability_triangle_plots"), exist_ok=True)
for detection_threshold in detection_thresholds:
fig = detectability_corner_plot(
Expand All @@ -426,14 +436,21 @@ def main():
parameters_to_log,
plotting_ranges={'tau': (0.040, 0.075)})
filename = os.path.join(
"figures",
"figures_and_results",
"detectability_triangle_plots",
f"detectability_triangle_"
f"{str(detection_threshold).replace(' ', '_')}_"
f"noise_{sigma_noise:.4f}_K.pdf")
fig.savefig(filename)
plt.close(fig)

# Store timing data
end = time.time()
add_timing_data(timing_file, 'total_fbf_plotting',
end - start)
add_timing_data(timing_file, 'average_fbf_plotting',
(end - start)/len(detection_thresholds))


if __name__ == "__main__":
main()

0 comments on commit 3f58236

Please sign in to comment.