In [None]:
from CLfitter import *
from Validation import *

from joblib import Parallel, delayed
import torch.nn as nn

In [None]:
class EELSBackgroundNN(nn.Module):
    """Small fully connected network: 2 inputs -> 16 hidden units (SiLU) -> 1 output."""

    def __init__(self):
        super().__init__()
        # Layers are instantiated in forward order so that parameter
        # initialisation draws from the RNG in the same sequence as the stack.
        input_layer = nn.Linear(2, 16)
        activation = nn.SiLU()
        output_layer = nn.Linear(16, 1)
        self.model = nn.Sequential(input_layer, activation, output_layer)

    def forward(self, x):
        """Run ``x`` through the sequential stack and return the result."""
        prediction = self.model(x)
        return prediction

In [None]:
# Builds the synthetic spectral image, wraps it in a DataHandler, restricts it to
# the energy window of interest, and prepares the clustering / design-matrix
# objects (`clusterer`, `X_builder`) that the training cell consumes.
# All names defined here are notebook globals used by later cells.

#----------- Synthetic data generation -----------
SNR = 10              # forwarded as `gaussian_snr` to the generator below
pre_edge_region = 30  # width of the pre-edge window, in energy-axis units (plotted as eV later)

E_start, E_stop = 300,600
E_edge = E_start+pre_edge_region  # edge onset = start of axis + pre-edge width (330)

# Positional arguments 20, 20, 900 match the 20x20 spatial grid (np.arange(20) axes
# below) and the 900 x 400 reshape that follows.
SIG = SpectralImageGenerator(20, 20, 900, E_start, E_stop)
syntheticdata, energy_axis = SIG.generate_realistic_spectral_image(gaussian_snr = SNR, A_range=(1e5,1e6),r_range=(2,3),
                                                                    scale= 50)
# Flatten to 2-D: 900 rows x 400 (= 20*20) columns.
# NOTE(review): reshape does not move axes -- this presumably relies on the generator
# returning the energy axis first; confirm against SpectralImageGenerator.
syntheticdata = syntheticdata.reshape(900, 400)

# Background stored on the generator, reshaped the same way; used as the reference
# ("ground truth") when the predictions are validated.
ground_truth = SIG.background.reshape(900,400)


#----------- Data loading -----------
handler = DataHandler()
handler.other_data(syntheticdata, np.arange(20), np.arange(20), energy_axis)
signal =  handler.signal.copy()  # work on a copy so the handler's own array is left untouched
# Energy bounds: (lower limit, edge onset, upper limit). i[1] is later passed on as
# the edge onset; the +-5 margins make the strict comparisons below span the full axis.
i = (E_start-5 ,E_edge ,E_stop+5)
        

    
#----------- pre-processing -----------
range_mask = (handler.energy_axis > i[0]) & (handler.energy_axis < i[-1])  # Range for clustering shape [n_E_1]

energy_range = handler.energy_axis[range_mask] # shape [n_E_1]
signal_range = signal.copy()[range_mask,:]  # shape [n_E_1, n_y*n_x]
ground_truth_range = ground_truth.copy()[range_mask,:]
# The `> i[0]` term is already guaranteed by range_mask; the effective condition is
# energy below the edge onset i[1].
pre_edge_mask = (energy_range > i[0]) & (energy_range < i[1])  # Range for pre-edge shape [n_E_2]


#----------- clustering  -----------
# Cluster the windowed spectra, then run the Cholesky step whose results
# (`triangular_matices`) are handed to BackgroundTrainer in the training cell.
clusterer = ClusterAnalyzer(signal_range)
clusterer.cluster_data(n_clusters = 6, pre_edge_mask=pre_edge_mask,)
clusterer.cholesky_decomp()

# Prepare the inputs later read back as X_builder.X_mc and X_builder.X_eval.
X_builder = X_Builder(energy_range)
X_builder.prepare_X_mc_data(clusterer.cluster_centers, i[1])
X_builder.prepare_X_eval_data(clusterer.total_integrated_intensity)

In [None]:

def train_background(ii, signal_range, pre_edge_mask, X_builder, clusterer, i,
                     n_mc_replicas=10, epochs=20000, output_dir='validationresults'):
    """Train one batch of Monte-Carlo background replicas and save the result to disk.

    Intended to be dispatched through ``joblib.delayed``: nothing is returned,
    the prediction is written to ``<output_dir>/pred_<ii>.npz`` under the key
    ``pred``.

    Parameters
    ----------
    ii : int
        Run index; only used to build the output file name.
    signal_range : array-like
        Signal restricted to the fitted energy window (passed as ``signal``).
    pre_edge_mask : array-like of bool
        Mask selecting the pre-edge region of ``signal_range``'s energy axis.
    X_builder : X_Builder
        Provides the ``X_mc`` and ``X_eval`` inputs.
    clusterer : ClusterAnalyzer
        Provides ``clusters_mean``, ``triangular_matices``,
        ``clusters_covariance`` and ``clusters``.
    i : sequence
        Energy bounds; ``i[1]`` is forwarded as the edge onset.
    n_mc_replicas : int, optional
        Number of replicas to train (default 10, the previous hard-coded value).
    epochs : int, optional
        Training epochs (default 20000, the previous hard-coded value).
    output_dir : str, optional
        Directory receiving the ``.npz`` file (default ``'validationresults'``).
    """
    import os  # local import keeps the function self-contained when shipped to joblib workers

    # Create the output directory *before* training: previously a missing directory
    # only surfaced as FileNotFoundError in np.savez after the whole (long) run.
    os.makedirs(output_dir, exist_ok=True)

    background_trainer = BackgroundTrainer(
        signal=signal_range,
        pre_edge_mask=pre_edge_mask,
        X_mc=X_builder.X_mc,
        X_eval=X_builder.X_eval,
        clustered_spectra_mean=clusterer.clusters_mean,
        triangular_matices=clusterer.triangular_matices,
        covariance_matrices=clusterer.clusters_covariance,
        cluster_labels=clusterer.clusters
    )

    background_trainer.train_MC_replica_consecutive(
        n_mc_replicas=n_mc_replicas,
        epochs=epochs,
        edge_onset=i[1],
        replica_version='covariance',
        model = EELSBackgroundNN(),
        # Progress output and logging are switched off for these batch runs.
        progress = False,
        logging = False
    )

    np.savez(os.path.join(output_dir, f'pred_{ii}.npz'), pred = background_trainer.background)



# Size of the joblib worker pool; tune this to the CPU cores actually available.
n_jobs = 80

# One delayed training call per run index (only index 0 at present), dispatched through joblib.
worker_pool = Parallel(n_jobs=n_jobs)
training_calls = (
    delayed(train_background)(run_index, signal_range, pre_edge_mask, X_builder, clusterer, i)
    for run_index in range(0, 1)
)
worker_pool(training_calls)

# Persist the cluster statistics and the raw inputs; the analysis cells reload these files.
mc_arrays = {
    'cov': clusterer.clusters_covariance,
    'mean': clusterer.clusters_mean,
    'labels': clusterer.clusters,
}
np.savez('validationresults/mc_data.npz', **mc_arrays)
np.savez('validationresults/run_data.npz', **{'signal': signal, 'GT': ground_truth})


In [None]:
import numpy as np
import matplotlib.pyplot as plt

# ---- Load the saved replica predictions and the reference data ----
preds = []
for ii in range(1):
    # np.load on an .npz returns a lazily-read archive; the context manager
    # closes the underlying file once the array has been extracted.
    with np.load(fr'validationresults/pred_{ii}.npz') as pred_file:
        preds.append(pred_file['pred'])

# Open run_data.npz once and pull both arrays from it.
with np.load(r'validationresults/run_data.npz') as run_file:
    GT = run_file['GT']
    signal = run_file['signal']

# Stack the replicas of all runs along axis 0 (the axis reduced by the statistics below).
predictions = np.concatenate(preds, axis=0)

pred_std = np.std(predictions, axis=0).T
pred_mean = np.mean(predictions, axis=0).T

# ---- Pull distribution: (prediction - ground truth) / uncertainty ----
pred_median = np.median(predictions, axis=0)
# np.percentile returns the percentiles in the order requested: 16th first, then 84th.
lower, upper = np.percentile(predictions, [16, 84], axis=0)
# Upper one-sided error (84th percentile minus median); same value as before,
# now computed from correctly named bounds.
pred_error = upper - pred_median

normalized_difference_to_theory = (predictions - GT.T) / pred_error
# Mask extreme outliers (including +-inf from a zero error) so they do not stretch the histogram range.
normalized_difference_to_theory[np.abs(normalized_difference_to_theory) > 100] = np.nan

fig, ax = plt.subplots(figsize=(8,8))  # square figure

# Histogram
ax.hist(normalized_difference_to_theory.flatten(), bins=100, density=True, 
        alpha=0.6, color='#5CECC4', edgecolor='#d45087', linewidth=0.5, label='Replica Distribution')

# Theoretical Normal distribution (zero mean, unit variance) for comparison
x = np.linspace(-10, 10, 10000)
ax.plot(x, 1/np.sqrt(2*np.pi)*np.exp(-x**2/2), c='#003f5c', lw=2, label='Normal Distribution')

# Labels
ax.set_xlabel('Difference to Theory ($\\sigma$)', fontsize=20)
ax.set_ylabel('Probability Density', fontsize=20)

# Limit
ax.set_xlim(-5,5)

ax.set_yticks([0])
ax.set_xticks([-4,-2,0,2,4])

ax.legend(fontsize=20)

# Tight layout
fig.tight_layout()
plt.savefig('fig3_histogram.svg', bbox_inches='tight')
plt.savefig('fig3_histogram.png', bbox_inches='tight', dpi=300)
plt.show()

In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl
from mpl_toolkits.axes_grid1 import ImageGrid
import sys
import os

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), '..', 'Code')))

from CLfitter import *

# ---- Reload the cluster statistics and the raw signal saved by the training cell ----
cluster_data = np.load(r'validationresults/mc_data.npz')
run_data = np.load(r'validationresults/run_data.npz')
signal = run_data['signal']
labels = cluster_data['labels']
cov = cluster_data['cov']
mean = cluster_data['mean']

# ---- Classify every spectrum by its log total integrated intensity (TII) ----
# NOTE(review): this rebinds the global `clusterer` to a new analyzer built on the full
# signal, with mean.shape[1]+1 clusters and no pre-edge mask -- confirm this is intended.
clusterer = ClusterAnalyzer(signal)
clusterer.cluster_data(n_clusters = mean.shape[1]+1)
centers = clusterer.cluster_centers
TII = np.log(signal.sum(axis=0))
center_hi = centers.max()
center_lo = centers.min()
TII_max = (TII >= center_hi)     # at or above the largest cluster centre
TII_min = (TII <= center_lo)     # at or below the smallest cluster centre
TII_middle = (TII <= center_hi) & (TII >= center_lo)  # between the two extremes (inclusive)

# ---- Spread of the predictions across replicas ----
# `preds` is NOT reloaded here: it must still be in memory from the histogram cell.
predictions = np.concatenate(preds, axis=0)
pred_median = np.median(predictions, axis=0)
lower, upper = np.percentile(predictions, [16,84], axis=0)
# Half-width of the 16-84 percentile band, truncated along axis 1 to len(cov) entries
# so that it matches the diagonal of each covariance matrix used below.
pred_error = (upper-lower)[:,:len(cov)]/2
print(pred_error.shape, cov.shape)

# ---- Per-cluster standard deviation converted to linear scale ----
cluster_std = np.zeros_like(pred_error)
# The loop variable is deliberately not called `i`: that name holds the
# (lower, edge, upper) energy tuple the training cell reads via i[1].
for cluster_idx in range(np.max(labels)+1):
    mu = mean[:, cluster_idx]                    # shape (E,)
    sigma2 = np.diag(cov[:, :, cluster_idx])     # diag of covariance, shape (E,)

    # Standard deviation of a log-normal variable whose logarithm has mean `mu`
    # and variance `sigma2` (mean/cov are treated as log-space statistics here).
    sigma_linear = np.sqrt( (np.exp(sigma2) - 1) * np.exp(2*mu + sigma2) )
    cluster_std[labels == cluster_idx] = sigma_linear


TII_expanded = np.repeat(TII[:, None], cluster_std.shape[1], axis=1).flatten()

plt.rcParams['font.size'] = 20

# ---- square figure with manually placed axes ----
fig = plt.figure(figsize=(8,8))
ax = fig.add_axes([0.2,0.2,0.7,0.7])

# ---- one scatter layer per intensity regime, drawn in the listed order ----
# Both axes are expressed relative to the smallest cluster standard deviation.
std_floor = cluster_std.min()
scatter_groups = (
    (TII_max, '#bc5090', 'Upper-Extrapolated Values'),
    (TII_middle, '#003f5c', 'Interpolated Values'),
    (TII_min, "#7d72c6", 'Lower-Extrapolated Values'),
)
for group_mask, group_color, group_label in scatter_groups:
    sc = ax.scatter(
        cluster_std[group_mask].flatten() / std_floor,
        pred_error[group_mask].flatten() / std_floor,
        c=group_color,
        s=6,
        edgecolors='none',
        label=group_label,
    )

# ---- 1:1 reference line ----
xmin, xmax = ax.get_xlim()
x_line = np.linspace(0, 5, 100)
ax.plot(x_line, x_line, linestyle='--', color='#ff6361', lw = 3)

ax.set_xlim(0, 5)
ax.set_ylim(0, 10)

plt.tight_layout()
# This version of the figure is saved without ticks (and without a legend).
ax.set_yticks([])
ax.set_xticks([])

plt.savefig('scatterplotcolored_inside.svg')
plt.savefig('scatterplotcolored_inside.png', dpi = 300)
plt.show()

# ---- second figure: same layout, with axis labels and legend ----
fig = plt.figure(figsize=(8,8))
ax = fig.add_axes([0.2,0.2,0.7,0.7])

# ---- one marker per intensity regime ----
# Only the element at flat index 1 of each group is drawn.
# NOTE(review): this raises IndexError if a group holds fewer than two values.
std_floor = cluster_std.min()
legend_groups = (
    (TII_max, '#bc5090', 'Upper-Extrapolated Values'),
    (TII_middle, '#003f5c', 'Interpolated Values'),
    (TII_min, '#7d72c6', 'Lower-Extrapolated Values'),
)
for group_mask, group_color, group_label in legend_groups:
    sc = ax.scatter(
        cluster_std[group_mask].flatten()[1] / std_floor,
        pred_error[group_mask].flatten()[1] / std_floor,
        c=group_color,
        s=6,
        edgecolors='none',
        label=group_label,
    )

# ---- axes & formatting ----
ax.set_ylabel('NN uncertainty (a.u.)', fontsize=20)
ax.set_xlabel('MC replica uncertainty (a.u.)', fontsize=20)

# ---- 1:1 reference line ----
xmin, xmax = ax.get_xlim()
x_line = np.linspace(0, 5, 100)
ax.plot(x_line, x_line, linestyle='--', color='#FF6F61', lw = 3)

ax.set_xlim(0, 5)
ax.set_ylim(0, 10)

plt.tight_layout()
lgnd = plt.legend(markerscale=4)

# NOTE(review): the PNG name lacks the `_outside` suffix used by the SVG (and the
# `_inside` pair of the previous figure) -- confirm whether that is intentional.
plt.savefig('scatterplotcolored_outside.svg')
plt.savefig('scatterplotcolored.png', dpi = 300)
plt.show()

In [None]:
# Compares Monte-Carlo replicas drawn from the cluster covariance with the spread of
# the actual spectra in the same cluster (log intensity vs. energy loss).
# NOTE(review): relies on `signal_range`, `pre_edge_mask`, `X_builder`, `energy_range`
# and `clusterer` left in memory by earlier cells; `clusterer` is whichever analyzer
# was assigned last.
background_trainer = BackgroundTrainer(
    signal=signal_range,
    pre_edge_mask=pre_edge_mask,
    X_mc=X_builder.X_mc,
    X_eval=X_builder.X_eval,
    clustered_spectra_mean=clusterer.clusters_mean,
    triangular_matices=None,
    covariance_matrices=clusterer.clusters_covariance,
    cluster_labels=clusterer.clusters

)

# Number of replicas drawn per cluster (unchanged: 1 initial draw + 2 more).
# NOTE(review): three samples is very few for a 5-95 percentile band; the original
# commented-out variant drew 1000 -- raise this if the band should be meaningful.
N_PLOT_REPLICAS = 3

# The loop variable is deliberately not called `i`: that name holds the
# (lower, edge, upper) energy tuple the training cell reads via i[1].
for cluster_idx in range(1):
    cluster_ids = clusterer.clusters == cluster_idx
    # Original cluster mean (log-space); not used by the plot below.
    cluster_mean = clusterer.clusters_mean[:, cluster_idx]

    # Draw replicas one after another and stack them along a new leading axis.
    replicas = np.stack(
        [
            background_trainer._generate_mc_replica_covariance()[:, cluster_idx]
            for _ in range(N_PLOT_REPLICAS)
        ],
        axis=0,
    )
    print(replicas.shape)
    replica_mean = np.mean(replicas, axis=0)
    replica_size = replica_mean.shape[0]
    print(replica_size)

    # Restrict the energy axis and the spectra to the length covered by the replicas.
    energy_window = energy_range[:replica_size]
    cluster_spectra = signal_range[:replica_size][:, cluster_ids]

    plt.figure(figsize=(8, 8), dpi=300)
    plt.plot(energy_window, replica_mean, label="Replica mean", color="#5CECC4", linewidth=4)
    plt.fill_between(
        energy_window,
        np.percentile(replicas, 5, axis=0),
        np.percentile(replicas, 95, axis=0),
        color='#5CECC4',
        alpha=0.5, label="Replica uncertainty (90% CL)",
        hatch = 'xx'
    )

    # Measured spectra of the same cluster: mean of the logs, and log of the percentiles.
    plt.plot(energy_window, np.mean(np.log(cluster_spectra), axis = 1), label="Cluster mean", color="#FF6F61")
    plt.fill_between(
        energy_window,
        np.log(np.percentile(cluster_spectra, 5, axis=1)),
        np.log(np.percentile(cluster_spectra, 95, axis=1)),
        alpha=0.3, color='#FF6F61',
        label="Cluster Uncertainty (90% CL)",
        hatch = '..'
    )
    plt.xlabel("Energy Loss [eV]",fontsize = 20)
    plt.ylabel("Log Intensity [a.u.]",fontsize = 20)
    plt.yticks([])
    # Show only the first 90 energy channels.
    plt.xlim(energy_range[0], energy_range[90])
    plt.legend(fontsize = 18)
    plt.savefig(r'ch2_XX_validationreplicas.pdf', bbox_inches="tight")