In [None]:
%matplotlib inline
import numpy as np
import matplotlib.pyplot as plt

from os import path
from glob import glob
import sys
import csv
sys.path.insert(0, path.abspath('./'))
import matplotlib.patches as mpatches

from src import workdir, parse_model_parameter_file
from src.emulator_BAND import EmulatorBAND
from src.emulator import Emulator


# Define functions to compare multiple GP emulators 

In [None]:
def _standardized_residual_moments(residuals):
    """Return (mean, variance, skewness, excess kurtosis) of a 1-D sample.

    All moments are population moments (plain ``np.mean`` of powers of the
    centred sample, no bias correction). The kurtosis has 3 subtracted, so a
    normal distribution gives 0. A sample with zero variance yields NaN/inf
    for skewness and kurtosis because of the division by the variance.
    """
    mean = np.mean(residuals)
    centered = residuals - mean
    variance = np.mean(centered**2.)
    skewness = np.mean(centered**3.) / variance**(3./2.)
    kurtosis = np.mean(centered**4.) / variance**(4./2.) - 3.
    return (mean, variance, skewness, kurtosis)


def train_multiple_emulators(training_set, model_par, number_test_points, logFlag):
    """Train three GP emulators on one training set and summarize their validation errors.

    The emulators are, in order: ``EmulatorBAND`` with method 'PCGP',
    ``EmulatorBAND`` with method 'PCSK', and ``Emulator`` with ``npc=4``.
    Each one is validated through its ``testEmulatorErrors`` method.

    Parameters
    ----------
    training_set, model_par
        Forwarded unchanged to the emulator constructors.
    number_test_points
        Number of validation points, forwarded to ``testEmulatorErrors``.
    logFlag
        Forwarded to the emulator constructors as ``logTrafo``.

    Returns
    -------
    tuple
        ``(X_obs, moments, mean_rel_err)``; each element is a 3-tuple with one
        list per emulator (same emulator order as above), and each list has
        one entry per observable:

        * ``X_obs``: array ``(prediction - validation) / prediction_error``
        * ``moments``: ``(mean, variance, skewness, excess kurtosis)`` of that array
        * ``mean_rel_err``: ``mean(prediction_error / prediction)``
    """
    # All emulators are constructed before any of them is validated,
    # in the same order as the labels used by the plotting functions.
    emu_pcgp = EmulatorBAND(training_set, model_par, method='PCGP', logTrafo=logFlag)
    emu_pcsk = EmulatorBAND(training_set, model_par, method='PCSK', logTrafo=logFlag)
    emu_sklearn = Emulator(training_set, model_par, npc=4, logTrafo=logFlag)

    # NOTE: the two classes name the test-point keyword differently.
    validation_outputs = (
        emu_pcgp.testEmulatorErrors(number_test_points=number_test_points),
        emu_pcsk.testEmulatorErrors(number_test_points=number_test_points),
        emu_sklearn.testEmulatorErrors(nTestPoints=number_test_points),
    )

    # Assuming all emulators return the same number of observables;
    # the count is taken from the validation data of the first emulator.
    nObs = validation_outputs[0][2].shape[1]

    X_obs = ([], [], [])
    moments = ([], [], [])
    mean_rel_err = ([], [], [])
    for output, X_list, moment_list, err_list in zip(validation_outputs, X_obs, moments, mean_rel_err):
        # Elements 0, 1, 2 are used as prediction, prediction error and
        # validation data; all three are indexed as [point, observable].
        emu_pred, emu_pred_err, vali_data = output[0], output[1], output[2]
        for obsIdx in range(nObs):
            X = (emu_pred[:, obsIdx] - vali_data[:, obsIdx]) / emu_pred_err[:, obsIdx]
            X_list.append(X)
            err_list.append(np.mean(emu_pred_err[:, obsIdx] / emu_pred[:, obsIdx]))
            moment_list.append(_standardized_residual_moments(X))

    return X_obs, moments, mean_rel_err


Functions that write the output of the different GP emulators to file

In [None]:
def write_output_to_csv_mean_emu_err(filename, data):
    """Write each element of ``data`` as its own single-column CSV row.

    ``filename`` is opened for writing (an existing file is overwritten).
    """
    with open(filename, 'w', newline='') as handle:
        # Wrap every value in a one-element row so it lands on its own line.
        csv.writer(handle).writerows([value] for value in data)

def write_output_to_csv_moments_and_X(filename, data):
    """Write every element of ``data`` (itself an iterable of values) as one CSV row.

    ``filename`` is opened for writing (an existing file is overwritten).
    """
    with open(filename, 'w', newline='') as handle:
        csv.writer(handle).writerows(data)

def train_multiple_emulators_and_write_to_csv(training_set, model_par, number_test_points, output_file, logFlag):
    """Run ``train_multiple_emulators`` and store its results in nine CSV files.

    Parameters
    ----------
    training_set, model_par, number_test_points, logFlag
        Forwarded unchanged to ``train_multiple_emulators``.
    output_file
        Path prefix (string) of the output files. Every file is named
        ``{output_file}_{number_test_points}_{suffix}.dat`` where, for the
        emulator number k = 1, 2, 3, the suffix is ``Xk_obs``, ``moments_Xk``
        or ``mean_emulator_pred_err_obs_k``.

    The directory part of ``output_file`` is created if it does not exist.
    """
    import os  # local import: the file only imports ``path`` from ``os``

    # Create the output directory up front so a missing directory is handled
    # before the emulators are trained, not after.
    output_dir = os.path.dirname(output_file)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    X_obs, moments, mean_emulator_pred_err = train_multiple_emulators(
        training_set, model_par, number_test_points, logFlag)

    prefix = output_file + f'_{number_test_points}'

    # Standardized residuals X, one row per observable.
    for emu_number, X in enumerate(X_obs, start=1):
        write_output_to_csv_moments_and_X(f'{prefix}_X{emu_number}_obs.dat', X)

    # First four moments of X, one row per observable.
    for emu_number, moments_X in enumerate(moments, start=1):
        write_output_to_csv_moments_and_X(f'{prefix}_moments_X{emu_number}.dat', moments_X)

    # Mean relative emulator uncertainty, one value per observable.
    for emu_number, mean_err in enumerate(mean_emulator_pred_err, start=1):
        write_output_to_csv_mean_emu_err(
            f'{prefix}_mean_emulator_pred_err_obs_{emu_number}.dat', mean_err)

Read functions for the different file types

In [None]:
def read_emulator_file_errors(filename):
    """Read a text file holding one number per line and return the numbers as a list of floats."""
    with open(filename, 'r') as handle:
        return [float(raw_line.strip()) for raw_line in handle]

def read_emulator_file_moments(filename):
    """Read a comma-separated text file and return one list of floats per line."""
    with open(filename, 'r') as handle:
        return [
            list(map(float, raw_line.strip().split(',')))
            for raw_line in handle
        ]

def read_multiple_emulator_errors_files(number_test_points_list, filename):
    """Load the mean-emulator-error files of all three emulators.

    For every entry of ``number_test_points_list`` the files
    ``./emulator_output/{filename}_{n}_mean_emulator_pred_err_obs_{k}.dat``
    (k = 1, 2, 3) are read with ``read_emulator_file_errors``.

    Returns a 3-tuple of lists (one per emulator); each list holds one
    entry per element of ``number_test_points_list``, in the same order.
    """
    per_emulator = ([], [], [])
    for n_test in number_test_points_list:
        for emu_number, collected in enumerate(per_emulator, start=1):
            source = f"./emulator_output/{filename}_{n_test}_mean_emulator_pred_err_obs_{emu_number}.dat"
            collected.append(read_emulator_file_errors(source))
    return per_emulator

def read_multiple_moments_files(number_test_points_list, filename):
    """Load the moment files of all three emulators.

    For every entry of ``number_test_points_list`` the files
    ``./emulator_output/{filename}_{n}_moments_X{k}.dat`` (k = 1, 2, 3)
    are read with ``read_emulator_file_moments``.

    Returns a 3-tuple of lists (one per emulator); each list holds one
    entry per element of ``number_test_points_list``, in the same order.
    """
    per_emulator = ([], [], [])
    for n_test in number_test_points_list:
        for emu_number, collected in enumerate(per_emulator, start=1):
            source = f"./emulator_output/{filename}_{n_test}_moments_X{emu_number}.dat"
            collected.append(read_emulator_file_moments(source))
    return per_emulator

def read_multiple_X_files(number_test_points_list, filename):
    """Load the X (standardized residual) files of all three emulators.

    For every entry of ``number_test_points_list`` the files
    ``./emulator_output/{filename}_{n}_X{k}_obs.dat`` (k = 1, 2, 3)
    are read with ``read_emulator_file_moments``.

    Returns a 3-tuple of lists (one per emulator); each list holds one
    entry per element of ``number_test_points_list``, in the same order.
    """
    per_emulator = ([], [], [])
    for n_test in number_test_points_list:
        for emu_number, collected in enumerate(per_emulator, start=1):
            source = f"./emulator_output/{filename}_{n_test}_X{emu_number}_obs.dat"
            collected.append(read_emulator_file_moments(source))
    return per_emulator


Plot functions for the read files. These plots are customized for the dNdy data set with 21 observables (drawn on a 7×3 grid). The `plotformat` argument (rows, columns) might need to be changed for the other data sets.

In [None]:
def plot_emulator_errors_combined(err1, err2, err3, number_training_points, filename, plotformat):
    """Plot the mean emulator uncertainty of three emulators, one panel per observable.

    Parameters
    ----------
    err1, err2, err3
        One sequence per emulator (labelled "PCGP", "PCSK", "Scikit GP").
        Each holds one entry per x value, and each entry is indexable by
        observable number.
    number_training_points
        x values; must have as many entries as each of ``err1``..``err3``.
    filename
        Output path without extension; the figure is saved to ``filename + ".pdf"``.
    plotformat
        ``(rows, cols)`` of the panel grid. Panel ``i * cols + j`` shows
        observable ``i * cols + j``; panels beyond the number of available
        observables are left empty.
    """
    rows, cols = plotformat[0], plotformat[1]
    # squeeze=False keeps `axs` two-dimensional even for a single row/column.
    fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(10, 20), squeeze=False)

    series = ((err1, "PCGP"), (err2, "PCSK"), (err3, "Scikit GP"))
    # Number of observables present in every entry of every series.
    n_obs = min((len(entry) for errs, _ in series for entry in errs), default=0)

    for i in range(rows):
        for j in range(cols):
            ax = axs[i, j]
            idx = i * cols + j
            if idx >= n_obs:
                # Grid is larger than the data: hide the surplus panel.
                ax.axis('off')
                continue
            for errs, label in series:
                ax.plot(number_training_points, [err[idx] for err in errs], label=label)
            ax.set_title(f"Observable {idx+1}")
            if i == 0 and j == 0:
                ax.legend()
            if i == rows-1:  # Set x label for bottom row
                ax.set_xlabel("Training Points")
            if j == 0:  # Set y label for leftmost column
                ax.set_ylabel("Mean Emulator Uncertainty")

    fig.tight_layout()
    fig.savefig(filename+".pdf")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

def _plot_single_moment_grid(moment_series, moment_idx, ylabel, number_training_points, output_path, rows, cols):
    """Draw one grid figure for a single moment (selected by ``moment_idx``) and save it."""
    # squeeze=False keeps `axs` two-dimensional even for a single row/column.
    fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(10, 20), squeeze=False)
    # Number of observables present in every entry of every series.
    n_obs = min((len(entry) for moms, _ in moment_series for entry in moms), default=0)

    for i in range(rows):
        for j in range(cols):
            ax = axs[i, j]
            idx = i * cols + j
            if idx >= n_obs:
                # Grid is larger than the data: hide the surplus panel.
                ax.axis('off')
                continue
            for moms, label in moment_series:
                ax.plot(number_training_points, [mom[idx][moment_idx] for mom in moms], label=label)
            ax.set_title(f"Observable {idx+1}")
            if i == 0 and j == 0:
                ax.legend()
            if i == rows-1:  # Set x label for bottom row
                ax.set_xlabel("Training Points")
            if j == 0:  # Set y label for leftmost column
                ax.set_ylabel(ylabel)

    fig.tight_layout()
    fig.savefig(output_path)
    plt.show()
    # Close instead of plt.clf(): clf() only clears the current figure
    # and leaves it open.
    plt.close(fig)


def plot_emulator_moments_combined(mom1, mom2, mom3, number_training_points, filename, plotformat):
    """Plot the first four moments of X for three emulators, one figure per moment.

    Parameters
    ----------
    mom1, mom2, mom3
        One sequence per emulator (labelled "PCGP", "PCSK", "Scikit GP").
        Each holds one entry per x value; ``entry[obs][m]`` is moment ``m``
        (0 mean, 1 variance, 2 skewness, 3 excess kurtosis) of observable ``obs``.
    number_training_points
        x values; must have as many entries as each of ``mom1``..``mom3``.
    filename
        Output path prefix; figures are saved as ``filename + "_mean.pdf"``,
        ``"_variance.pdf"``, ``"_skewness.pdf"`` and ``"_kurtosis.pdf"``.
    plotformat
        ``(rows, cols)`` of the panel grid; panels beyond the number of
        available observables are left empty.
    """
    rows, cols = plotformat[0], plotformat[1]
    moment_series = ((mom1, "PCGP"), (mom2, "PCSK"), (mom3, "Scikit GP"))
    moment_panels = (
        (0, "Mean X", "_mean.pdf"),
        (1, "Variance X", "_variance.pdf"),
        (2, "Skewness X", "_skewness.pdf"),
        (3, "Excess Kurtosis X", "_kurtosis.pdf"),
    )
    for moment_idx, ylabel, suffix in moment_panels:
        _plot_single_moment_grid(moment_series, moment_idx, ylabel,
                                 number_training_points, filename + suffix, rows, cols)

def plot_emulator_X_combined(X1, X2, X3, filename, plotformat):
    """Plot histograms of X for three emulators against a standard normal, one panel per observable.

    Parameters
    ----------
    X1, X2, X3
        One sequence per emulator (labelled "PCGP", "PCSK", "Scikit GP");
        element ``obs`` holds the X sample of observable ``obs``.
    filename
        Output path without extension; the figure is saved to ``filename + ".pdf"``.
    plotformat
        ``(rows, cols)`` of the panel grid; panels beyond the number of
        available observables are left empty.
    """
    rows, cols = plotformat[0], plotformat[1]
    # squeeze=False keeps `axs` two-dimensional even for a single row/column.
    fig, axs = plt.subplots(nrows=rows, ncols=cols, figsize=(10, 20), squeeze=False)

    # Standard normal density used as the reference curve in every panel.
    x = np.linspace(-5, 5, 300)
    y = (1 / (np.sqrt(2*np.pi))) * np.exp(-0.5 * x**2)

    samples = ((X1, "PCGP"), (X2, "PCSK"), (X3, "Scikit GP"))
    n_obs = min(len(X1), len(X2), len(X3))

    for i in range(rows):
        for j in range(cols):
            ax = axs[i, j]
            idx = i * cols + j
            if idx >= n_obs:
                # Grid is larger than the data: hide the surplus panel.
                ax.axis('off')
                continue
            for X, label in samples:
                ax.hist(X[idx], density=True, histtype='step', label=label)
            # Drawn before the legend is built so that 'Normal dist.' is listed;
            # the negative zorder still keeps the curve behind the histograms.
            ax.plot(x, y, color = 'black', zorder = -2, linewidth = 2, label = 'Normal dist.')
            ax.set_title(f"Observable {idx+1}")
            if i == 0 and j == 0:
                ax.legend()
            if i == rows-1:  # Set x label for bottom row
                ax.set_xlabel("X")
            if j == 0:  # Set y label for leftmost column
                ax.set_ylabel("p(X)")

    fig.tight_layout()
    fig.savefig(filename+".pdf")
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


Generate the data files for different numbers of test points

In [None]:
# Model design file passed to every emulator.
model_par = "../data/modelDesign_3DMCGlauber.txt"

# dNdy / dNdeta training sets and the output prefixes used for them.
# The three lists are index-aligned: entry k of each output list belongs
# to entry k of the training set list.
training_set_list_dNdy = [
    "../data/AuAu7.7_dNdy.pkl",
    "../data/AuAu19p6_dNdy.pkl",
    "../data/AuAu200_dNdy.pkl",
    "../data/AuAu200_PHOBOSdNdeta.pkl",
]
output_file_list_dNdy = [
    "./emulator_output/7p7_dNdy_emu_out",
    "./emulator_output/19p6_dNdy_emu_out",
    "./emulator_output/200_dNdy_emu_out",
    "./emulator_output/200PHOBOS_dNdeta_emu_out",
]
# Output prefixes for the runs with the log transformation switched on.
output_file_list_LOGdNdy = [
    "./emulator_output/7p7_LOGdNdy_emu_out",
    "./emulator_output/19p6_LOGdNdy_emu_out",
    "./emulator_output/200_LOGdNdy_emu_out",
    "./emulator_output/200PHOBOS_LOGdNdeta_emu_out",
]

# pTvn / v2(eta) training sets and their index-aligned output prefixes.
training_set_list_pTvn = [
    "../data/AuAu7.7_pTvn.pkl",
    "../data/AuAu19p6_pTvn.pkl",
    "../data/AuAu200_pTvn.pkl",
    "../data/AuAu200_PHOBOSv2eta.pkl",
]
output_file_list_pTvn = [
    "./emulator_output/7p7_pTvn_emu_out",
    "./emulator_output/19p6_pTvn_emu_out",
    "./emulator_output/200_pTvn_emu_out",
    "./emulator_output/200PHOBOS_vn_emu_out",
]

Quick summary:
- The AuAu7.7_dNdy data set: none of the 1000 training points are discarded.
- The AuAu7.7_pTvn data set: 40 of the 1000 training points are discarded due to large statistical errors.
- The AuAu19p6_dNdy data set: none of the 1100 training points are discarded.
- The AuAu19p6_pTvn data set: 5 of the 1100 training points are discarded due to large statistical errors.
- The AuAu200_dNdy data set: none of the 1100 training points are discarded.
- The AuAu200_pTvn data set: 46 of the 1100 training points are discarded due to large statistical errors.
- The AuAu200_PHOBOSdNdeta data set: 1 of the 1000 training points is discarded due to large statistical errors.
- The AuAu200_PHOBOSv2eta data set: none of the 1100 training points are discarded.

In [None]:
# dNdy training sets without log transformation: train and validate all
# emulators for each number of test points and write the results to file.
for n_test_points in [900, 800, 700, 600, 500, 400, 300, 200, 100]:
    for set_idx, training_set in enumerate(training_set_list_dNdy):
        train_multiple_emulators_and_write_to_csv(
            training_set, model_par, n_test_points,
            output_file_list_dNdy[set_idx], False)

In [None]:
# Same dNdy training sets, now with the log transformation switched on;
# results go to the separate LOG output prefixes.
for n_test_points in [900, 800, 700, 600, 500, 400, 300, 200, 100]:
    for set_idx, training_set in enumerate(training_set_list_dNdy):
        train_multiple_emulators_and_write_to_csv(
            training_set, model_par, n_test_points,
            output_file_list_LOGdNdy[set_idx], True)

In [None]:
# pTvn training sets without log transformation.
for n_test_points in [900, 800, 700, 600, 500, 400, 300, 200, 100]:
    for set_idx, training_set in enumerate(training_set_list_pTvn):
        train_multiple_emulators_and_write_to_csv(
            training_set, model_par, n_test_points,
            output_file_list_pTvn[set_idx], False)

Read the files for different numbers of test points and plot the data

In [None]:
# Prefix of the emulator output files to read (files live in ./emulator_output/).
# Exactly one of the assignments below should be active.
filename_prefix = "7p7_dNdy_emu_out"
#filename_prefix = "19p6_dNdy_emu_out"
#filename_prefix = "200_dNdy_emu_out"
#filename_prefix = "200PHOBOS_dNdeta_emu_out"

# Runs with log transformation
#filename_prefix = "7p7_LOGdNdy_emu_out"
#filename_prefix = "19p6_LOGdNdy_emu_out"
#filename_prefix = "200_LOGdNdy_emu_out"
#filename_prefix = "200PHOBOS_LOGdNdeta_emu_out"

# pTvn / vn runs
#filename_prefix = "7p7_pTvn_emu_out"
#filename_prefix = "19p6_pTvn_emu_out"
#filename_prefix = "200_pTvn_emu_out"
#filename_prefix = "200PHOBOS_vn_emu_out"

# Numbers of test points whose files are read.
# Full list: [900,800,700,600,500,400,300,200,100]
test_point_counts = [900, 800]
err1, err2, err3 = read_multiple_emulator_errors_files(test_point_counts, filename_prefix)
mom1, mom2, mom3 = read_multiple_moments_files(test_point_counts, filename_prefix)
X1, X2, X3 = read_multiple_X_files(test_point_counts, filename_prefix)

In [None]:
# x axis: number of training points belonging to each set of files that was read.
# The lists below look like the axes for the other data sets (the usable
# training points minus 900, ..., 100 test points) - TODO confirm before switching.
#[60,160,260,360,460,560,660,760,860]
#[195,295,395,495,595,695,795,895,995]
#[200,300,400,500,600,700,800,900,1000]
#[154,254,354,454,554,654,754,854,954]
#[99,199,299,399,499,599,699,799,899]
number_training_points = [100,200,300,400,500,600,700,800,900]

# The files are read for a descending number of test points, i.e. an ascending
# number of training points. When only the first few test-point counts were
# read, keep only the matching leading x values; otherwise matplotlib raises
# because x and y differ in length.
number_training_points = number_training_points[:len(err1)]

plot_emulator_errors_combined(err1,err2,err3,number_training_points,"./AuAu7p7_emu_uncertainty_dNdy",(7,3))
plot_emulator_moments_combined(mom1,mom2,mom3,number_training_points,"./AuAu7p7_emu_moment_dNdy",(7,3))
# Histogram of X for the last set of files that was read.
plot_emulator_X_combined(X1[-1],X2[-1],X3[-1],"./AuAu7p7_emu_Xhist_dNdy_200trainingpoints",(7,3))

### Old functions

In [None]:
def plot_emulator_vs_truth(vali_data,vali_data_err,emu_pred,emu_pred_err):
    """Scatter emulator predictions against validation data, one figure per observable.

    Parameters
    ----------
    vali_data, vali_data_err, emu_pred
        Arrays indexed as ``[point, observable]``.
    emu_pred_err
        Array indexed as ``[point, observable, observable]``; the square root
        of its diagonal entry is used as the y error bar.
    """
    # Use the function argument; the previous code read an undefined
    # (or leaked global) name `vali_data_1` here.
    nValidationPoints, nObs = vali_data.shape

    for obsIdx in range(nObs):
        plt.figure()
        ax = plt.axes([0.12, 0.12, 0.83, 0.83])
        plt.errorbar(vali_data[:, obsIdx], emu_pred[:, obsIdx],
                    yerr=np.sqrt(emu_pred_err[:, obsIdx,obsIdx]),
                    xerr=vali_data_err[:, obsIdx],
                    marker="o", linestyle="")
        # Diagonal "prediction == truth" reference line.
        plt.plot([-200, 200], [-200, 200], '--k')
        plt.xlim([vali_data[:, obsIdx].min() - 1,
                vali_data[:, obsIdx].max() + 1])
        plt.ylim([emu_pred[:, obsIdx].min() - 1,
                emu_pred[:, obsIdx].max() + 1])
        ax.set_aspect('equal')
        ax.text(0.05, 0.95, "obs {}".format(obsIdx), fontsize=20, transform=ax.transAxes, verticalalignment='top')
        plt.xlabel("truth")
        plt.ylabel("emulator results")

def plot_emulator_vs_truth_relative(vali_data,vali_data_err,emu_pred,emu_pred_err):
    """Plot the emulator-minus-truth difference in units of the combined uncertainty.

    One figure per observable. The plotted quantity is
    ``(emu_pred - vali_data) / sqrt(vali_data_err**2 + emu_pred_err_diag)``
    for every test point, with a shaded band between -2 and +2.

    Parameters
    ----------
    vali_data, vali_data_err, emu_pred
        Arrays indexed as ``[point, observable]``.
    emu_pred_err
        Array indexed as ``[point, observable, observable]``; its diagonal
        entry is added to ``vali_data_err**2`` without squaring.
    """
    nValidationPoints, nObs = vali_data.shape

    for obsIdx in range(nObs):
        plt.figure()
        ax = plt.axes([0.12, 0.12, 0.83, 0.83])
        plt.plot(range(len(vali_data)),
                (emu_pred[:, obsIdx] - vali_data[:, obsIdx])/np.sqrt(vali_data_err[:, obsIdx]**2.+ emu_pred_err[:, obsIdx, obsIdx]),
                marker="o", linestyle="")
        # Shaded +-2 band across the whole x range.
        plt.fill_between([-nValidationPoints, nValidationPoints], [2, 2], [-2, -2], color='g', alpha=0.2)
        plt.xlim([-1,len(vali_data)+1])
        plt.ylim([-4, 4])
        ax.text(0.05, 0.95, "obs {}".format(obsIdx), fontsize=20, transform=ax.transAxes, verticalalignment='top')
        plt.xlabel("test point")
        # Raw string: "\s" is not a valid escape sequence in a normal string.
        plt.ylabel(r"relative diff. [$\sigma$]")

def plot_histogram_emulator_vs_truth_relative(vali_data_list,emu_pred_list,emu_pred_err_list,dataset_labels=None):
    """Histogram A = (GP - truth) / GP error per observable for one or more data sets.

    Parameters
    ----------
    vali_data_list, emu_pred_list
        A single array or a list of arrays indexed as ``[point, observable]``.
    emu_pred_err_list
        A single array or a list of arrays indexed as
        ``[point, observable, observable]``; the square root of the diagonal
        entry is used as the GP error.
    dataset_labels
        A single label or a list of labels, one per data set. ``None`` (the
        default) plots every data set without a label.

    For every observable one figure is shown with the histograms of all data
    sets, a standard normal reference curve, and a second legend listing the
    mean, variance, skewness and excess kurtosis of each histogram.
    """
    if not isinstance(vali_data_list, list):
        vali_data_list = [vali_data_list]
    if not isinstance(emu_pred_list, list):
        emu_pred_list = [emu_pred_list]
    if not isinstance(emu_pred_err_list, list):
        emu_pred_err_list = [emu_pred_err_list]
    if dataset_labels is None:
        # One (empty) label per data set; a single [None] would make the
        # zip below silently drop every data set after the first.
        dataset_labels = [None] * len(vali_data_list)
    elif not isinstance(dataset_labels, list):
        dataset_labels = [dataset_labels]

    nObs = vali_data_list[0].shape[1]  # Assuming all datasets have the same number of observables

    # Standard normal density used as the reference curve in every figure.
    x = np.linspace(-6, 6, 300)
    y = (1 / (np.sqrt(2*np.pi))) * np.exp(-0.5 * x**2)

    for obsIdx in range(nObs):
        fig, ax = plt.subplots()
        ax.set_xlabel("A = (GP-truth)/GPerr")
        ax.set_ylabel("p(A)")

        moment_labels = []
        handles = []
        for vali_data, emu_pred, emu_pred_err, label in zip(vali_data_list, emu_pred_list, emu_pred_err_list, dataset_labels):
            A = (emu_pred[:, obsIdx] - vali_data[:, obsIdx]) / np.sqrt(emu_pred_err[:, obsIdx, obsIdx])
            ax.hist(A, bins=25, density=True, histtype='step', label=label)

            # Compute first four moments of the histogram
            mean = np.mean(A)
            variance = np.mean((A - mean)**2.)
            skewness = np.mean((A - mean)**3.) / variance**(3./2.)
            # Excess kurtosis (0 for a normal distribution), matching the
            # γ₂ symbol and the definition in train_multiple_emulators.
            kurtosis = np.mean((A - mean)**4.) / variance**(4./2.) - 3.
            moment_string = f'μ={mean:.2f}, σ²= {variance:.2f}, γ₁={skewness:.2f}, γ₂={kurtosis:.2f}'
            # Legend entries must be plain strings, not (label, text) tuples.
            moment_labels.append(moment_string if label is None else f'{label}: {moment_string}')

            # Invisible patch: only serves as a handle for the text legend.
            handles.append(mpatches.Patch(color='none', label=label))

        ax.plot(x, y, color = 'black', zorder = 2, linewidth = 2, label = 'Normal dist.')

        ax.text(0.05, 0.8, "obs {}".format(obsIdx), fontsize=20, transform=ax.transAxes, verticalalignment='top')
        ax.legend()

        # Add a separate legend for the moment strings
        ax2 = ax.twinx()
        ax2.legend(handles, moment_labels, loc='upper left', fontsize=6)
        ax2.axis('off')  # Turn off the axis for the second legend
        plt.tight_layout()
        plt.show()
        # Release the figure so a long observable loop does not pile up open figures.
        plt.close(fig)