# Benchmark pytests in ml4gw
## Commands
Run `pytest tests/waveforms/cbc/test_cbc_waveforms.py --benchmark=N` to collect benchmark data over N samples.

Run `python plot_benchmark.py` to plot a 2D histogram of the error over chirp mass and mass ratio (the default).

## Code Overview
The two sections below explain the plotting code and the data collection code.

### Plotting Code
The plotting code assumes the data was saved by the benchmarking code, because it infers what each file contains from the file and dataset names. The code below, from `plot_benchmark.ipynb`, plots a 2D histogram with chirp mass and mass ratio on the axes and the error as the bin weight. To change the plot, pass a different plotting function to `plot_tests` and update `plotting_keys` to the parameters you want to plot. The relevant line is:
```
plotting_function(data[err_key], data[plotting_keys[0]], data[plotting_keys[1]], f"{test}_{err_key}")
```
Changing the `data[...]` arguments changes what is passed to your plotting function. This call is made once for each unique error key (`err_key`) found in the data.
```
plot_tests(tests, plotting_function=plot_err_chirp_mass, plotting_keys=['chirp_mass', 'mass_ratio'])
```
When calling `plot_tests`, specify both the plotting function you want to use and the keys it should plot.
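`plot_tests` calls the plotting function as `plotting_function(err, x, y, name)`, so any function with that signature can be swapped in. Below is a hypothetical `plot_err_scatter` (not part of the original code) that draws one point per sample, coloured by its error. The `out_dir`, `xlabel` and `ylabel` keyword arguments are assumptions added for this sketch.

```python
import os

import matplotlib

matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import matplotlib.pyplot as plt
import numpy as np


def plot_err_scatter(err, x, y, file, out_dir="benchmark_plots",
                     xlabel="Chirp Mass", ylabel="Mass Ratio"):
    """Hypothetical alternative plotting function with the same
    (err, x, y, file) signature that plot_tests uses."""
    # Flatten the lists/arrays that plot_tests passes in
    err = np.asarray(err).reshape(-1)
    x = np.asarray(x).reshape(-1)
    y = np.asarray(y).reshape(-1)

    os.makedirs(out_dir, exist_ok=True)
    fig, ax = plt.subplots(figsize=(10, 5))
    # One marker per sample, coloured by that sample's error
    points = ax.scatter(x, y, c=err, s=4, cmap="viridis")
    fig.colorbar(points, ax=ax, label="Error")
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(f"{file} ({len(err)} samples)")

    out_path = os.path.join(out_dir, f"{file}_scatter.png")
    fig.savefig(out_path)
    plt.close(fig)
    return out_path
```

It would be passed in as `plot_tests(tests, plotting_function=plot_err_scatter, plotting_keys=['chirp_mass', 'mass_ratio'])`; the extra keyword arguments keep their defaults because `plot_tests` only supplies the four positional ones.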

```python
import os

import h5py
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import LogNorm
from tqdm import tqdm


def load_data_from_h5(filename):
    """
    Loads data from an HDF5 file while preserving the original structure.
    args:
        filename (str): The path to the HDF5 file.
    returns:
        data (dict): A nested dictionary {group_key: {dataset_key: np.ndarray}}.
    """
    data = {}
    with h5py.File(filename, "r") as f:
        for group_key in tqdm(f.keys()):  # Each group holds one batch written by the benchmark
            h5_group = f[group_key]
            data.setdefault(group_key, {})  # Add the group (params and errors) if missing
            for dataset_key in h5_group.keys():
                dataset = h5_group[dataset_key]
                if dataset.shape == ():  # Scalars are wrapped in a 1-element array
                    data[group_key][dataset_key] = np.array([dataset[()]])
                else:
                    data[group_key][dataset_key] = dataset[:]
    return data


def plot_err_chirp_mass(err, chirp_mass, mass_ratio, file):
    """
    Plots the error data as a 2d histogram with chirp mass and
    mass ratio as the axes; each bin is weighted by the summed error.
    """
    # Make sure the data is 1D
    chirp_mass = np.asarray(chirp_mass).reshape(-1)
    mass_ratio = np.asarray(mass_ratio).reshape(-1)
    err = np.asarray(err).reshape(-1)

    # Copy the colormap so the global "viridis" colormap is not modified
    cmap = plt.get_cmap("viridis").copy()
    min_color = cmap(0.0)
    cmap.set_under(min_color)
    cmap.set_bad(min_color)  # LogNorm masks empty (zero-weight) bins; draw them in the lowest colour

    # Make plots
    plt.figure(figsize=(10, 5))
    plt.hist2d(chirp_mass, mass_ratio, bins=250, norm=LogNorm(), weights=err, cmap=cmap)
    plt.colorbar(label="Error")
    plt.xlabel("Chirp Mass")
    plt.ylabel("Mass Ratio")
    plt.title("Histogram of differences between LAL and ml4gw waveforms")
    plt.suptitle(f"Total number of differences ({len(err)})")

    # Save before plt.show(): in notebooks show() can leave an empty figure behind
    os.makedirs("benchmark_plots", exist_ok=True)
    out_path = f"benchmark_plots/{file}_histogram.png"
    plt.savefig(out_path)
    plt.show()
    plt.close()
    print(f"Saved plot to {out_path}")


def plot_tests(tests, plotting_function=plot_err_chirp_mass, plotting_keys=("chirp_mass", "mass_ratio")):
    """
    Loads the benchmark data for each test and plots every error key found.
    args:
        tests (list): List of tests to plot. (Uses these strings to find the files)
        plotting_function (callable): Called as plotting_function(err, x, y, name).
        plotting_keys (sequence): The two parameter keys passed as x and y.
    returns:
        None
    """
    # Change these based on where the data is stored and what it is called
    file_prefix = "benchmark_data"
    folder = "benchmark_data"
    start_file = 0  # Start file index

    for test in tests:
        num_files = 2  # Number of data files to include in plots, 0 means all files from start_file onwards
        print(f"Loading data for test: {test}")

        # Find all files matching the pattern if num_files == 0
        if num_files == 0:
            index = start_file
            while os.path.exists(f"{folder}/{file_prefix}_{test}_{index}.h5"):
                print(f"Found file: {file_prefix}_{test}_{index}.h5")
                num_files += 1
                index += 1
            if num_files == 0:
                print(f"No files found for test: {test}. Skipping...")
                continue  # Skip to the next test

        # Load each file separately before combining them into one dataset
        segmented_data = {}  # file index -> data loaded from that file
        for i in range(num_files):
            filename = f"{folder}/{file_prefix}_{test}_{i + start_file}.h5"
            segmented_data[i] = load_data_from_h5(filename)
            print(f"Loaded data: {i + start_file}")

        # Find all error keys: any dataset whose name contains "err"
        err_keys = []
        for file_data in segmented_data.values():
            for group in file_data.values():
                for key in group:
                    if "err" in key and key not in err_keys:
                        err_keys.append(key)
        print(f"Found error keys: {err_keys}")

        for err_key in err_keys:
            # Concatenate every dataset from the groups that contain this err_key
            data = {}
            for file_data in segmented_data.values():
                for group in file_data.values():
                    if err_key in group:
                        for key, values in group.items():
                            data.setdefault(key, []).extend(values)

            # Plot the err data
            print(f"Plotting {err_key}")
            plotting_function(data[err_key], data[plotting_keys[0]], data[plotting_keys[1]], f"{test}_{err_key}")


# Example usage
tests = ["phenom_p", "phenom_d"]
plot_tests(tests, plotting_function=plot_err_chirp_mass, plotting_keys=["chirp_mass", "mass_ratio"])
```
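In `plot_err_chirp_mass`, passing `weights=err` to `plt.hist2d` makes each bin's colour the *sum* of the errors that fall into it rather than their average, so densely sampled regions can look worse than sparsely sampled ones with the same per-sample error. The sketch below reproduces that binning with `np.histogram2d` on toy numbers (2 bins per axis instead of 250) and, as an optional alternative not used in the original code, divides by the per-bin count to get the mean error. Empty bins have a summed error of 0, which `LogNorm` masks; that is why the plotting code sets the colormap's "bad" colour.

```python
import numpy as np

# Toy data standing in for chirp mass, mass ratio and per-sample error
chirp_mass = np.array([10.0, 10.2, 25.0, 25.1, 25.2])
mass_ratio = np.array([0.50, 0.52, 0.90, 0.91, 0.92])
err = np.array([1e-4, 3e-4, 2e-3, 1e-3, 1e-3])

# Same idea as plt.hist2d(..., weights=err): each bin holds the SUM of errors
summed, xedges, yedges = np.histogram2d(chirp_mass, mass_ratio, bins=2, weights=err)
# Unweighted histogram: number of samples per bin
counts, _, _ = np.histogram2d(chirp_mass, mass_ratio, bins=2)

# Mean error per bin; empty bins become NaN (0 / 0) instead of raising warnings
with np.errstate(invalid="ignore", divide="ignore"):
    mean_err = summed / counts
```

Here the high-mass bin holds three samples, so its summed error is three times its mean; a mean-error map would need its own plotting function.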

### Data Collection Code
The benchmarking code performs the same comparison as the regular pytest code, e.g.
```
assert np.allclose(
    1e21 * hp_lal_data.real, 1e21 * hp_ml4gw.real.numpy(), atol=1e-3
)
```
but instead of only asserting, it computes the error and saves it with:

```python
import h5py
import numpy as np


def write_benchmark_data(filename, dataset):
    """Write benchmark data to an HDF5 file, creating a new numbered group for each call."""
    try:
        with h5py.File(filename, "a") as f:
            # Groups are named "0", "1", ...; pick the next unused number
            existing_groupnames = [int(name) for name in f.keys() if name.isdigit()]
            next_groupname = max(existing_groupnames) + 1 if existing_groupnames else 0
            group = f.require_group(str(next_groupname))
            # Create (or extend) one dataset per key in the dataset dictionary
            for key, values in dataset.items():
                values = np.array(values, dtype=np.float32)
                if key not in group:
                    if values.ndim == 0:
                        group.create_dataset(key, data=values)
                    else:
                        group.create_dataset(
                            key,
                            data=values,
                            maxshape=(None,) + values.shape[1:],  # resizable along the first axis
                            compression="gzip",
                        )
                else:
                    dset = group[key]
                    if values.ndim == 0:
                        dset[...] = values
                    else:
                        current_size = dset.shape[0]
                        new_size = current_size + values.shape[0]
                        dset.resize(new_size, axis=0)
                        dset[current_size:new_size] = values
            f.flush()
    except Exception as e:
        print(f"Error writing data to file: {e}")
        raise


# EXAMPLE (inside a test, where hp_lal_data and hp_ml4gw are already defined)
test_1_hp_real_abs_err = np.max(np.abs(1e21 * hp_lal_data.real - 1e21 * hp_ml4gw.real.numpy()))
data = {"test_1_hp_real_abs_err": test_1_hp_real_abs_err}
```
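The stored metric and the assertion are related: `np.allclose(a, b, atol=1e-3)` (with its default `rtol=1e-5`) passes when every element satisfies `|a - b| <= atol + rtol * |b|`, so a stored maximum absolute error at or below `1e-3` guarantees the original assertion would pass. The sketch below uses made-up arrays in place of real LAL and ml4gw waveforms; `max_abs_err`, `SCALE` and `ATOL` are illustrative names, not part of ml4gw.

```python
import numpy as np

SCALE = 1e21  # rescaling applied in the test before comparing
ATOL = 1e-3   # absolute tolerance used by the np.allclose assertion


def max_abs_err(lal, ml4gw):
    """Largest absolute difference after rescaling (the number the benchmark stores)."""
    return float(np.max(np.abs(SCALE * lal - SCALE * ml4gw)))


# Toy stand-ins for hp_lal_data.real and hp_ml4gw.real.numpy()
lal = np.array([1.0e-21, 2.0e-21, 3.0e-21])
close = lal + np.array([0.0, 5.0e-25, -2.0e-25])
far = lal + np.array([0.0, 2.0e-24, 0.0])

for name, other in [("close", close), ("far", far)]:
    err = max_abs_err(lal, other)
    passes = np.allclose(SCALE * lal, SCALE * other, atol=ATOL)
    # err <= ATOL is sufficient for allclose to pass; rtol only adds slack
    print(f"{name}: max_abs_err={err:.1e}, allclose={passes}")
```

Because the dictionary key (`test_1_hp_real_abs_err` in the example above) contains `err`, the plotting code picks it up automatically as an error key.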

When calling pytest, the additional option `--benchmark=N` sets the number of samples the tests run with. The output is also written in batches and split across multiple files to reduce I/O on the device. The per-file size limit is set in:

```python
import os


def get_next_file_name(base_name, extension="h5"):
    """Get the next available file name with a numeric suffix."""
    # get_file_size is not shown here; 0.05 is in whatever unit it returns
    index = 0
    while os.path.exists(f"{base_name}_{index}.{extension}") and get_file_size(f"{base_name}_{index}.{extension}") >= 0.05:
        index += 1
    return f"{base_name}_{index}.{extension}"
```

Here the `0.05` threshold in the `while` condition sets how large each file may grow before writing moves on to the next file index. The batch size is set by this fixture:

```python
import pytest


@pytest.fixture()
def batch_size():
    return 100
```

This is the number of test samples written to a file at a time.
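The batching can be sketched as a small in-memory buffer: records accumulate until `batch_size` of them have been collected, and are then handed to a writer in one call. `BatchBuffer` and `flush_fn` are hypothetical names, not taken from the test code; in practice `flush_fn` could wrap `write_benchmark_data`, whose `np.array(..., dtype=np.float32)` conversion accepts the per-key lists built here.

```python
class BatchBuffer:
    """Hypothetical helper: collect per-sample records and pass them to
    flush_fn in batches of batch_size, so each file is opened once per
    batch instead of once per sample."""

    def __init__(self, batch_size, flush_fn):
        self.batch_size = batch_size
        self.flush_fn = flush_fn
        self._rows = {}  # key -> list of values collected so far
        self._count = 0

    def add(self, record):
        """Append one sample, e.g. {"chirp_mass": 12.3, "hp_real_abs_err": 1e-4}."""
        for key, value in record.items():
            self._rows.setdefault(key, []).append(value)
        self._count += 1
        if self._count >= self.batch_size:
            self.flush()

    def flush(self):
        """Hand the buffered records to flush_fn and start a fresh batch."""
        if self._count == 0:
            return
        self.flush_fn(self._rows)
        self._rows = {}  # new dict, so the batch passed to flush_fn is not mutated
        self._count = 0
```

Whatever drives the buffer should call `flush()` once at the end; otherwise a final partial batch (fewer than `batch_size` records) is never written.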

Another useful pytest feature is selecting which tests in a file to run. Add the option `-k` followed by the name of the test function you want. In `test_cbc_waveforms.py` the three options are `test_phenom_d`, `test_phenom_p`, and `test_taylor_f2`. For example:

```bash
pytest --benchmark=100 -k test_phenom_d tests/waveforms/cbc/test_cbc_waveforms.py
```

This collects 100 samples of benchmark data for the `test_phenom_d` test (phenom_d waveform generation) only.
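Core pytest does not provide a `--benchmark` option, so it is presumably registered by the test suite itself, typically in a `conftest.py`. The sketch below shows one way to do that with pytest's `pytest_addoption` hook and `request.config.getoption`; the default of `0`, the `benchmark_samples` fixture and the `benchmark_samples_from_config` helper are assumptions, not taken from ml4gw.

```python
import pytest


def pytest_addoption(parser):
    """Register a --benchmark=N option (N = number of samples to collect)."""
    parser.addoption(
        "--benchmark",
        action="store",
        type=int,
        default=0,  # assumption: 0 means benchmark mode is off
        help="Number of samples to collect for benchmark data",
    )


def benchmark_samples_from_config(config):
    """Read the option value from a pytest Config object."""
    return config.getoption("--benchmark")


@pytest.fixture()
def benchmark_samples(request):
    # request.config is the active pytest Config object
    return benchmark_samples_from_config(request.config)
```

With this in `conftest.py`, running the command above would make a test that requests `benchmark_samples` receive `100`.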