# Test the consistency of model correction on spectra

Design:

1. Generate a spectrum consisting of only peaks
2. Insert random components (Baseline, Cosmic rays, Noise) into several copies of that spectrum
3. Have the models correct the copies and save the peak prediction
4. Measure the consistency of the predictions by computing, at each wavelength, the variance of the predicted peak signal across the copies. The smaller the variance, the higher the consistency

In [None]:
# Imports
from Scripts.essentials import *
from Scripts.generator import *

In [None]:
# Global matplotlib appearance used by the per-example figures below.
plt.rcParams.update({"font.size": 40, "font.family": "Times New Roman"})

# Boxplot outliers: small, almost transparent gray dots.
flier_props = {
    "marker": "o",
    "markerfacecolor": "gray",
    "alpha": 0.05,
    "markersize": 5,
    "linestyle": "none",
    "markeredgecolor": "gray",
}
# Boxplot mean marker: a red horizontal dash.
mean_props = dict(marker="_", markerfacecolor="Red", markeredgecolor="Red")

# Settings handed to the project generators: peak-width bounds and the
# number of samples per spectrum.
length = 1024
min_width, max_width = 10, 300

# Number of ground-truth peak vectors to generate.
num_peaks = 500

# Number of random component sets applied to every peak vector.
num_components = 1000

# Ground-truth peak vectors, generated under a fixed seed for reproducibility.
np.random.seed(2024)
peak_vectors = []
for _ in range(num_peaks):
    peak_vectors.append(
        generate_peaks(
            length=length,
            min_peak_width=min_width,
            max_peak_width=max_width,
        )
    )
peaks = np.array(peak_vectors)

# Separate seed for the disturbing components.
np.random.seed(42)
# Colour used for mean curves in the figures.
mean_color = "orangered"

# Baselines, cosmic rays and noise, one of each per component set.
baselines, cosmic_rays, noise = [], [], []

for _ in range(num_components):
    baseline, cosmic, raw_noise, _unused = generate_spectrum(length, min_width, max_width)
    # Amplify the generated noise by a random integer factor; drawn after
    # generate_spectrum so the random stream is consumed in the same order.
    noise_gain = np.random.randint(2.0, 10.0)

    baselines.append(baseline)
    cosmic_rays.append(cosmic)
    noise.append(raw_noise * noise_gain)

baselines, cosmic_rays, noise = (
    np.array(collected) for collected in (baselines, cosmic_rays, noise)
)

# --- Models under comparison -------------------------------------------------

# Ensemble model: build the architecture, then restore its weights.
ensemble_model = make_ensemble()
ensemble_model.load_weights(filepath="Models/ensembleModelRes.h5")

# Standard model: same procedure.
standard_model = make_standard()
standard_model.load_weights(filepath="Models/standardModel.h5")

# Cascaded model, restored as a complete saved Keras model.
cascade = tf.keras.models.load_model("Models/unet_three.23-64.96.h5")

# Wahl model, provided by its own module.
import WahlModel
wahl_model = WahlModel.load_model()

# --- Per-model metric accumulators (one entry appended per peak vector) ------
ensemble_stats, standard_stats, cascaded_stats, wahl_stats = [], [], [], []

# Consistency experiment.
#
# For every ground-truth peak vector P: add P to every (normalised) random
# component set, let each model predict the peak signal, plot the spread of
# the predictions and append summary metrics to the per-model stats lists.


def _consistency_stats(pred, peak):
    """Return consistency statistics for one model's predictions.

    Args:
        pred: predictions for all corrupted copies. Axis 0 is reduced over and
            axis 1 is masked with ``peak`` (assumes a 2-D
            (copies, wavelengths) array -- TODO confirm against the models).
        peak: 1-D ground-truth peak vector used to split wavelengths into
            peak sites (``peak > 0``) and non-peak sites (``peak == 0``).

    Returns:
        dict with the per-wavelength ``std``, ``mean``, ``max`` and ``min``
        across copies, the absolute ``z_score`` of every prediction, and the
        per-wavelength variance at peak sites (``var_p``) and at non-peak
        sites (``var_non_p``).
    """
    pred_std = np.std(pred, axis=0)
    pred_mean = np.mean(pred, axis=0)
    return {
        "std": pred_std,
        "mean": pred_mean,
        "max": np.max(pred, axis=0),
        "min": np.min(pred, axis=0),
        # The small constant keeps the division finite where the std is 0.
        "z_score": np.abs(np.nan_to_num((pred - pred_mean) / (pred_std + 0.00001))),
        "var_p": np.var(pred[:, peak > 0], axis=0),
        "var_non_p": np.var(pred[:, peak == 0], axis=0),
    }


# Display name and metric list of every model, in plotting order.
names = ["Ensemble", "Standard", "Kazemzadeh et al.", "Wahl"]
lists = [ensemble_stats, standard_stats, cascaded_stats, wahl_stats]

# The disturbing components do not depend on the peak vector, so their sum and
# its normalisation are computed once instead of once per peak vector.
disturbance = noise + (cosmic_rays + baselines)
row_min = np.min(disturbance, axis=1, keepdims=True)
row_max = np.max(disturbance, axis=1, keepdims=True)
# Scale every row to [0, 1], then to 3/4: three of the four components.
disturbance = (disturbance - row_min) / (row_max - row_min) * (3 / 4)

for ix, raw_peak in enumerate(peaks):
    # P contributes the remaining 1/4. It is added after the normalisation so
    # that it is identical in every corrupted copy.
    p = raw_peak / 4
    X = disturbance + np.expand_dims(p, 0)
    # Trailing axis for the models.
    X = np.expand_dims(X, -1)
    X_flat = np.squeeze(X)

    ## DL-predictions ##

    # Standard and ensemble models: the last output is the peak prediction.
    standard_preds = standard_model.predict(X, verbose=0)[-1]
    ensemble_preds = ensemble_model.predict(X, verbose=0)[-1]

    # Cascaded model: the input is scaled by 1000 and the last output is
    # scaled back by the same factor.
    cascaded_preds = np.squeeze(cascade.predict(X * 1000, verbose=0)[-1] / 1000)

    # Wahl model: centre each spectrum, divide by its norm and scale by 256
    # (the suggested preprocessing), then undo that scaling on the output.
    wahl_mean = np.expand_dims(np.mean(X_flat, axis=1), -1)
    wahl_norm = np.expand_dims(np.linalg.norm(X_flat, axis=1), -1)
    wahl_prep = np.expand_dims(256 * (X_flat - wahl_mean) / wahl_norm, -1)
    wahl_preds = wahl_model.predict(wahl_prep, verbose=0)
    wahl_preds = (wahl_norm * wahl_preds / 256) + wahl_mean
    # Normalise by the global maximum to make the Wahl output comparable to
    # the other models. Done once here so that the shared z-score axis limit
    # below is derived from exactly the data that is plotted and stored.
    wahl_preds = wahl_preds / np.max(wahl_preds)

    preds = [ensemble_preds, standard_preds, cascaded_preds, wahl_preds]
    model_stats = [_consistency_stats(pred, p) for pred in preds]

    # Common upper limit of the z-score axis across the four model figures.
    z_axis_max = np.max([np.max(stats["z_score"]) for stats in model_stats])

    # --- Figure: the corrupted input spectra -------------------------------
    fig, ax = plt.subplots(1, 1, sharey=False, figsize=(11, 11))
    # Band between the minimum and maximum of X at each wavelength.
    ax.fill_between(np.arange(X_flat.shape[1]),
                    np.max(X_flat, axis=0),
                    np.min(X_flat, axis=0),
                    color="Black", alpha=0.5, label="Distribution")
    ax.plot(np.mean(X_flat, axis=0), ls="--", color=mean_color, label="Mean", linewidth=3.0)
    ax.plot(p, ls="-", color="Green", label="P", linewidth=4.0)
    ax.set_ylim([-0.05, 1.3])
    ax.set_yticks([0, 0.5, 1])
    ax.set_xticks([])
    ax.legend()
    ax.title.set_text("Generated spectra")
    fig.tight_layout()
    plt.savefig("Figures/ConsistencyExperiments/consistency_example" + str(ix) + "_spectra.png",
                transparent=True,
                dpi=1000,
                bbox_inches='tight',
                pad_inches=0.5)
    plt.show()
    # Release the figure; many are created over the whole experiment.
    plt.close(fig)

    # --- One figure per model ----------------------------------------------
    for model_index, (name, stat_list, stats) in enumerate(zip(names, lists, model_stats)):
        pred_std = stats["std"]
        pred_z_score = stats["z_score"]
        var_pred_p = stats["var_p"]
        var_pred_non_p = stats["var_non_p"]

        fig, (ax1, ax2, ax3, ax4) = plt.subplots(1, 4, sharey=False, figsize=(25, 8))

        # Band between the minimum and maximum prediction at each wavelength.
        ax1.fill_between(np.arange(len(pred_std)),
                         stats["max"],
                         stats["min"],
                         color="Green", alpha=0.5, label="Distribution")
        ax1.plot(stats["mean"], ls="--", color=mean_color, label="Mean", linewidth=4.0)
        ax1.set_ylim([-0.05, 1.3])
        ax1.set_yticks([0, 0.5, 1])
        ax1.set_xticks([])
        ax1.legend(fontsize=30)
        ax1.title.set_text("P-predictions")

        # Standard deviation at each wavelength.
        ax2.plot(pred_std, label="Mean: " + str(np.round(np.mean(pred_std), 3)),
                 color=mean_color, linewidth=4.0)
        ax2.set_ylim([-0.05, 1.3])
        ax2.set_yticks([0, 0.5, 1])
        ax2.set_xticks([])
        ax2.title.set_text("Standard deviation")
        ax2.legend(fontsize=30)

        # Absolute z-score at each wavelength (values are already absolute).
        ax3.fill_between(np.arange(pred_z_score.shape[1]),
                         np.max(pred_z_score, axis=0),
                         np.min(pred_z_score, axis=0),
                         color=mean_color, alpha=0.5, label="Distribution")
        ax3.plot(np.mean(pred_z_score, axis=0),
                 label="Mean: " + str(np.round(np.mean(pred_z_score), 3)),
                 color=mean_color, linewidth=4.0)
        ax3.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        ax3.set_yticks(np.linspace(0, z_axis_max, 3))
        ax3.set_xticks([])
        ax3.set_ylim([-0.2, z_axis_max + 0.2])
        ax3.title.set_text("Absolute z-score")
        ax3.legend(fontsize=30)

        # Variance at peak sites versus non-peak sites.
        ax4.boxplot([var_pred_p, var_pred_non_p],
                    labels=["P > 0", "P = 0"],
                    showmeans=True,
                    flierprops=flier_props,
                    meanprops=mean_props)
        ax4.set_yticks(np.linspace(0, np.max([0.008, np.max(var_pred_p)]), 3))
        ax4.yaxis.set_major_formatter(FormatStrFormatter('%.3f'))
        ax4.title.set_text("Variance")

        # Store the summary metrics of this peak vector for this model.
        stat_list.append([np.mean(pred_z_score),
                          np.mean(pred_std),
                          np.mean(var_pred_p),
                          np.mean(var_pred_non_p),
                          np.max(pred_z_score),
                          np.max(pred_std),
                          np.max(var_pred_p),
                          np.max(var_pred_non_p),
                          ])

        fig.suptitle(name, fontsize=40, y=0.9)
        fig.tight_layout()
        plt.savefig("Figures/ConsistencyExperiments/consistency_example" + str(ix) + "_" + str(model_index) + ".png",
                    format="png", transparent=True,
                    dpi=1000,
                    bbox_inches='tight',
                    pad_inches=0.5)
        plt.show()
        plt.close(fig)

    # Progress: number of peak vectors processed so far.
    print(ix + 1)

In [None]:
# Persist the collected metrics of every model to Results/.
for stem, collected in (
    ("ensemble", ensemble_stats),
    ("standard", standard_stats),
    ("cascaded", cascaded_stats),
    ("Wahl", wahl_stats),
):
    np.save("Results/Consistency_" + stem + "_stats.npy", collected)

In [None]:
# Reload the stored metrics, so the experiment cells above need not be re-run,
# then average them over all peak vectors and show the result as a table.
ensemble_stats = np.load("Results/Consistency_ensemble_stats.npy")
standard_stats = np.load("Results/Consistency_standard_stats.npy")
cascaded_stats = np.load("Results/Consistency_cascaded_stats.npy")
wahl_stats = np.load("Results/Consistency_Wahl_stats.npy")

# Column titles, in the order the metrics were appended.
header = [
    " Mean Absolute z-score",
    "mean_std",
    "mean_P_variance",
    "mean_non_P_variance",
    "Max Absolute z-score",
    "Max std",
    "max_P_variance",
    "max_non_P_variance",
]

# One row per model: the mean of every metric column.
stats = {
    "Ensemble": np.mean(ensemble_stats, axis=0),
    "Standard": np.mean(standard_stats, axis=0),
    "Kazemzadeh et al.": np.mean(cascaded_stats, axis=0),
    "Wahl": np.mean(wahl_stats, axis=0),
}

df = pd.DataFrame.from_dict(stats, orient="index", columns=header)
df.style.format(decimal=',', thousands='.', precision=4)

In [None]:
# Print the table as LaTeX with two decimals.
# The styled view gets its own name so that `df` stays a DataFrame; rebinding
# `df` to the Styler made this cell (and the table cell) fail when re-run.
latex_table = df.style.format(decimal=',', thousands='.', precision=2)
print(latex_table.to_latex())

In [None]:
# Summary boxplots: distribution of each metric over all peak vectors, with
# one box per model. One figure for the "mean" metrics, one for the "max".
plt.rcParams.update({'font.size': 25})
flier_props = dict(marker='o', markerfacecolor='gray', alpha = 0.5, markersize=5,
                  linestyle='none', markeredgecolor="gray")
mean_props = {"marker": "_", 'markerfacecolor': "Red", 'markeredgecolor': "Red"}
num_ticks = 4
names = ["Ensemble", "Standard", "Kazemzadeh et al.", "Wahl"]


def _plot_metric_row(panels, out_path):
    """Draw one row of per-model boxplots and save it to ``out_path``.

    Args:
        panels: sequence of ``(column, title, tick_format, ticks_from_min)``
            tuples, one per subplot. ``column`` indexes the metric inside the
            ``*_stats`` arrays, ``tick_format`` is the y-tick format string,
            and ``ticks_from_min`` makes the y-ticks start at the smallest
            value instead of 0.
        out_path: file the figure is written to.
    """
    model_stats = [ensemble_stats, standard_stats, cascaded_stats, wahl_stats]
    fig, axes = plt.subplots(ncols=len(panels), figsize=(13, 7))

    for axis, (column, title, tick_format, ticks_from_min) in zip(axes, panels):
        per_model = [stats.T[column] for stats in model_stats]
        axis.boxplot(per_model,
                     showmeans=True,
                     flierprops=flier_props,
                     meanprops=mean_props)

        # Local names on purpose: assigning to `max`/`min` at notebook level
        # would shadow the builtins for every later cell.
        upper = np.max(per_model)
        lower = np.min(per_model) if ticks_from_min else 0
        axis.set_yticks(np.linspace(lower, upper, num_ticks))
        axis.set_xticklabels(names, rotation=90)
        axis.yaxis.set_major_formatter(FormatStrFormatter(tick_format))
        axis.yaxis.grid(True, linestyle='-', which='major', color='lightgrey',
                        alpha=0.5)
        axis.title.set_text(title)

    plt.tight_layout()
    plt.savefig(out_path, format="png", transparent=True,
                dpi=1000,
                bbox_inches='tight',
                pad_inches=0.5)
    plt.show()


# Metrics averaged per peak vector (columns 0-2).
_plot_metric_row(
    [
        (0, "Mean Absolute z-score", '%.2f', False),
        (1, "Mean Standard Deviation", '%.2f', False),
        (2, "Mean P-variance", '%.3f', False),
    ],
    "Figures/ConsistencyExperiments/BoxplotMean.png",
)

# Per-peak-vector maxima (columns 4-6). Only the z-score panel starts its
# ticks at the smallest observed value.
_plot_metric_row(
    [
        (4, "Max Absolute z-score", '%.2f', True),
        (5, "Max Standard Deviation", '%.2f', False),
        (6, "Max P-variance", '%.2f', False),
    ],
    "Figures/ConsistencyExperiments/BoxplotMax.png",
)