In [None]:
import os
import matplotlib.pyplot as plt
import numpy as np
import yaml
from argparse import ArgumentParser

# Render every piece of figure text through LaTeX, in 20 pt Helvetica.
plt.rc("text", usetex=True)
plt.rc("font", family="Helvetica", size=20)

# A notebook has no __file__, so anchor paths on the parent of the current
# working directory (typically the folder holding this notebook).
script_dir = os.path.split(os.path.abspath(''))[0]

# ----------------------------------------------------------------------------------- #

# Training logs to plot: one logs.yaml per subplot, filled in order.
log_file_paths = [
    f"{script_dir}/experiments/output/physical_im/J_0.1_do/2023-09-01_11:07:27+0200/logs.yaml",
]

# Subplot grid layout (rows x columns).
rows_number, columns_number = 1, 1

# Where the finished figure is written (PDF).
output_path = f"{script_dir}/../6467a173cb828c16e8ae9ac3/learning_curves.pdf"

# ----------------------------------------------------------------------------------- #
def _read_learning_curves(path):
    """Read one ``logs.yaml`` training log and return its two learning curves.

    The file must contain ``training_params.epochs_number``,
    ``initial_metrics.cosin_sim`` / ``initial_metrics.mean_trace_dist`` and,
    for every epoch ``i = 1..epochs_number``, an integer key ``i`` holding
    ``cosin_sim`` and ``mean_trace_dist`` entries.

    Returns:
        (infidelity, mean_trace_dist): two 1-D numpy arrays of length
        ``epochs_number + 1``. Element 0 is the initial (pre-training) value and
        element ``i`` is the value logged for epoch ``i``. ``infidelity`` is
        ``1 - cosin_sim``.

    Raises:
        yaml.YAMLError: if the file is not valid YAML.
        KeyError: if one of the expected entries is missing.
    """
    with open(path, "r", encoding="utf-8") as log_file:
        log = yaml.safe_load(log_file)
    epochs_number = int(log["training_params"]["epochs_number"])
    initial = log["initial_metrics"]
    per_epoch = [log[i] for i in range(1, epochs_number + 1)]
    cosin_sim = [float(initial["cosin_sim"])] + [float(e["cosin_sim"]) for e in per_epoch]
    mean_trace_dist = [float(initial["mean_trace_dist"])] + [float(e["mean_trace_dist"]) for e in per_epoch]
    return 1 - np.array(cosin_sim), np.array(mean_trace_dist)


# Fail early with a readable message instead of an obscure plt.subplot error.
assert len(log_file_paths) <= rows_number * columns_number, (
    f"{len(log_file_paths)} log files do not fit into a "
    f"{rows_number}x{columns_number} subplot grid"
)

fig = plt.figure(figsize=(columns_number * 6.4 / 1.5, rows_number * 4.8 / 1.5))
for k, path in enumerate(log_file_paths):
    try:
        infidelity, mean_trace_dist = _read_learning_curves(path)
    except yaml.YAMLError as exc:
        # Skip the file explicitly: carrying on would either hit an undefined
        # variable or silently re-plot the previous file's data.
        print(f"Skipping {path}: {exc}")
        continue
    # x = 1 is the initial value and x = i + 1 is epoch i. The +1 shift keeps
    # every abscissa positive, which the logarithmic x-axis requires.
    epochs = np.arange(1, len(infidelity) + 1)
    ax = fig.add_subplot(rows_number, columns_number, k + 1)
    ax.plot(epochs, infidelity, 'b-')
    ax.plot(epochs, mean_trace_dist, 'b--')
    ax.set_yscale('log')
    ax.set_xscale('log')
    ax.legend([r"$1 - F$", r"${\rm err}$"], frameon=False)
# One shared x-axis label for the whole figure, placed below the bottom row.
fig.text(0.5, -0.1, r"${\rm Epochs \ number}$", ha='center')
# plt.show() avoids the "non-GUI backend" warning that fig.show() raises inline.
# `fig` stays valid for savefig in the next cell.
plt.show()

In [None]:
# Export the figure; "tight" bbox keeps the shared x-label placed below the axes.
fig.savefig(fname=output_path, bbox_inches="tight")