In [None]:
from CLfitter import *
from joblib import Parallel, delayed
import numpy as np
import os
import torch.nn as nn

The first step in obtaining a background prediction is defining the neural-network (NN) architecture. A small fully connected network with layer sizes 2 → 16 → 1 (one hidden layer of 16 SiLU units) was found to work well.

In [None]:
class EELSBackgroundNN(nn.Module):
    """Small fully connected network used to model the EELS background.

    Layers: Linear(2 -> 16) -> SiLU -> Linear(16 -> 1). The input's last
    dimension must be 2; the output's last dimension is 1.
    """

    def __init__(self):
        super().__init__()
        n_hidden = 16
        layers = [
            nn.Linear(2, n_hidden),
            nn.SiLU(),
            nn.Linear(n_hidden, 1),
        ]
        # Stored as `model` so state_dict keys stay "model.0.*" / "model.2.*".
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the network and return the prediction."""
        return self.model(x)

In [None]:
# Spectrum image to analyse.
path = 'data/twisted-CrSBr-2deg-eels-SI_004.dm4'
# Energy boundaries (E_min, E_I, E_max); E_I = i[1] is used as the edge onset.
i = (520, 570, 670)
# Number of spectral clusters.
n_clusters = 4


**DataHandler** loads the data.
`.read_dm4_SI` reads spectrum images from either .dm3 or .dm4 files; `core_loss_index` is the index of the core-loss dataset within the file. This index can be found by trial and error, or by opening the file in GMS (Gatan Microscopy Suite) and checking the entries. Indexing starts at 0, so keep that in mind.

**Pooler** pools the data within each energy bin. Either a Gaussian or a square kernel can be used, and the kernel radius (`sqr_radius`) can be set.

Depending on how an EELS detector acquires and processes data, negative counts can occur. All entries smaller than 1 are therefore set to 1 (this also avoids ln(0) later on).


In [None]:
# Load the spectrum image. core_loss_index selects the core-loss dataset inside the file.
# (The file is read only once, through DataHandler.)
data_handler = DataHandler()
data_handler.read_dm4_SI(path, core_loss_index=3, lowloss=False)

# Pool the spectra with a Gaussian kernel (sqr_radius=2).
pooler = Pooler(data_handler.signal, data_handler.si_size)
pooled_signal = pooler.pool_data(sqr_radius=2, gaussian_kernel=True)

# Raise every count below 1 to 1: removes negative counts and avoids ln(0) later.
# np.maximum builds a new array, so whatever pool_data returned is not mutated in place.
signal = np.maximum(pooled_signal, 1)

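Pre-processing: restrict the data to the energy window E_min < E < E_max, and build a mask that selects the pre-edge window E_min < E < E_I within it.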

In [None]:
# Energy windows derived from i = (E_min, E_I, E_max):
#   fit/clustering window: E_min < E < E_max
#   pre-edge window:       E_min < E < E_I
range_mask = (data_handler.energy_axis > i[0]) & (data_handler.energy_axis < i[-1])  # bool mask over the full energy axis
energy_range = data_handler.energy_axis[range_mask]  # shape [n_E_1]
# Boolean-mask indexing already returns a copy, so the full signal need not be copied first.
signal_range = signal[range_mask, :]  # shape [n_E_1, n_y*n_x]
pre_edge_mask = (energy_range > i[0]) & (energy_range < i[1])  # bool mask over energy_range, shape [n_E_1]

**ClusterAnalyzer** (`clusterer`) groups the spectra into `n_clusters` clusters. `.cholesky_decomp` computes the Cholesky decomposition of the clusters' covariance matrices.

**X_Builder** builds the input (X) data for the NN: both the X used for Monte Carlo (MC) training and the X used for evaluation.

In [None]:

# Cluster the windowed spectra (the pre-edge mask is passed to the clustering),
# then compute the Cholesky factors of the cluster covariance matrices.
clusterer = ClusterAnalyzer(signal_range)
clusterer.cluster_data(
    n_clusters=n_clusters,
    pre_edge_mask=pre_edge_mask,
)
clusterer.cholesky_decomp()

# Assemble the NN inputs: MC training inputs (from the cluster centres and the
# edge onset E_I = i[1]) and evaluation inputs (from the total integrated intensity).
X_builder = X_Builder(energy_range)
X_builder.prepare_X_mc_data(
    clusterer.cluster_centers,
    i[1],
)
X_builder.prepare_X_eval_data(
    clusterer.total_integrated_intensity,
)



The following cell trains the NNs in parallel and saves the results. Each parallel job trains 5 MC replicas (`n_mc_replicas`) consecutively and saves them to its own `.npz` file in `fitruns/`.

In [None]:
def train_background(ii, signal_range, pre_edge_mask, X_builder, clusterer, i,
                     n_mc_replicas=5, epochs=20000, folder_path='fitruns',
                     energy_axis=None, spatial_axis_x=None, spatial_axis_y=None):
    """Train one batch of MC background replicas and save the predictions to disk.

    Parameters
    ----------
    ii : int
        Run index; only used to make the output filename unique.
    signal_range : array
        Signal restricted to the fit window (rows indexed like ``pre_edge_mask``).
    pre_edge_mask : bool array
        Mask over the fit-window energy axis selecting the pre-edge region.
    X_builder : X_Builder
        Supplies ``X_mc`` (training inputs) and ``X_eval`` (evaluation inputs).
    clusterer : ClusterAnalyzer
        Supplies cluster means, covariances, their triangular (Cholesky) factors and labels.
    i : tuple
        (E_min, E_I, E_max); ``i[1]`` is passed as the edge onset and all three
        values appear in the output filename.
    n_mc_replicas : int, default 5
        Number of replicas trained consecutively in this call.
    epochs : int, default 20000
        Training epochs per replica.
    folder_path : str, default 'fitruns'
        Output directory; created if it does not exist.
    energy_axis, spatial_axis_x, spatial_axis_y : array, optional
        Axes saved together with the predictions. When omitted they fall back to
        the notebook globals ``energy_range`` and ``data_handler.spatial_axis_x/_y``
        (the behaviour of earlier versions of this function).

    Returns
    -------
    str
        Path of the written ``.npz`` file.
    """
    import gc

    # Backward-compatible fallback to notebook globals; pass these explicitly to
    # avoid depending on hidden state.
    if energy_axis is None:
        energy_axis = energy_range
    if spatial_axis_x is None:
        spatial_axis_x = data_handler.spatial_axis_x
    if spatial_axis_y is None:
        spatial_axis_y = data_handler.spatial_axis_y

    background_trainer = BackgroundTrainer(
        signal=signal_range,
        pre_edge_mask=pre_edge_mask,
        X_mc=X_builder.X_mc,
        X_eval=X_builder.X_eval,
        clustered_spectra_mean=clusterer.clusters_mean,
        triangular_matices=clusterer.triangular_matices,  # spelling as in the project API
        covariance_matrices=clusterer.clusters_covariance,
        cluster_labels=clusterer.clusters
    )

    background_trainer.train_MC_replica_consecutive(
        n_mc_replicas=n_mc_replicas,
        epochs=epochs,
        edge_onset=i[1],
        replica_version='covariance',
        model=EELSBackgroundNN(),
        progress=False,
        logging=False
    )

    prediction_saver = PredictionSaver(
        signal=signal_range,
        energy_axis=energy_axis,
        spatial_axis_x=spatial_axis_x,
        spatial_axis_y=spatial_axis_y,
        predictions=background_trainer.background
    )

    os.makedirs(folder_path, exist_ok=True)
    filename = f'run_{i[0]}-{i[1]}-{i[2]}v{ii}.npz'
    path_to_save = os.path.join(folder_path, filename)
    prediction_saver.save_predictions(path_to_save)

    # Drop the trainer/saver explicitly and collect: joblib may reuse this worker
    # process for later calls, so large leftovers would otherwise accumulate.
    del background_trainer, prediction_saver
    gc.collect()

    return path_to_save

# Number of parallel jobs (adjust based on your CPU cores)
n_jobs = 4
# Number of independent training runs. Each run writes one file containing
# n_mc_replicas (5 by default) replicas, so the ensemble has n_runs * 5 replicas.
n_runs = 4

Parallel(n_jobs=n_jobs)(
    delayed(train_background)(ii, signal_range, pre_edge_mask, X_builder, clusterer, i)
    for ii in range(n_runs)  # 4 runs x 5 replicas = 20 replicas in total
)

In [None]:
import numpy as np
from lmfit import Model, Parameters

class PredictionChecker:
    """Load saved background predictions and fit the L3/L2 white lines per MC replica.

    The ``.npz`` file must contain ``predictions`` (n_mc, n_spatial, n_E),
    ``signal`` (n_E, n_spatial), ``energy_axis`` (n_E,), ``spatial_axis_y`` and
    ``spatial_axis_x``.
    """

    def __init__(self, path):
        # Use the NpzFile as a context manager so its file handle is closed.
        with np.load(path) as data:
            self.predictions = data['predictions']  # shape (n_mc, n_spatial, n_E)
            self.signal = data['signal']            # shape (n_E, n_spatial)
            self.energy_axis = data['energy_axis']  # shape (n_E,)
            self.y_axis = data['spatial_axis_y']
            self.x_axis = data['spatial_axis_x']

        # Background-subtracted edge for every replica, computed once and reused
        # for both statistics.
        edges = self.signal.T[None, :, :] - self.predictions  # (n_mc, n_spatial, n_E)
        self.edge_mean = np.mean(edges, axis=0)   # shape (n_spatial, n_E)
        self.edge_stdev = np.std(edges, axis=0)   # shape (n_spatial, n_E)

        self.spatial_size = (len(self.y_axis), len(self.x_axis))

    @staticmethod
    def _two_gauss_two_arctan(x,
                              A1, mu1, sigma1,
                              A2, mu2, sigma2,
                              C1, x01, w1,
                              C2, x02, w2):
        """
        Model = 2 Gaussians + 2 shifted arctan steps.

        Each arctan term rises from 0 to C*pi around its onset x0 with width w.
        """
        g1 = A1 * np.exp(-(x - mu1) ** 2 / (2 * sigma1 ** 2))
        g2 = A2 * np.exp(-(x - mu2) ** 2 / (2 * sigma2 ** 2))
        ar1 = C1 * (np.arctan((x - x01) / (w1 + 1e-12)) + np.pi / 2.0)
        ar2 = C2 * (np.arctan((x - x02) / (w2 + 1e-12)) + np.pi / 2.0)
        return g1 + g2 + ar1 + ar2

    @staticmethod
    def _peak_amplitude_guess(x, y, mu, half_width=0.5):
        """Largest value of ``y`` within +/- ``half_width`` of ``mu``, floored at 1e-3.

        If no sample lies that close to ``mu`` (e.g. coarse energy sampling), the
        maximum over the whole window is used instead of failing on an empty max.
        """
        near = y[(x > mu - half_width) & (x < mu + half_width)]
        peak = near.max() if near.size else y.max()
        return max(peak, 1e-3)

    def _fit_single_spectrum_with_initial_guess(self, spectrum, energy_axis, mu_guesses, fit_window=5.0):
        """Fit the 2-Gaussian + 2-arctan model to ``spectrum`` around the white lines.

        The fit range is ``min(mu_guesses) - fit_window < E < max(mu_guesses) + fit_window``.

        Returns the lmfit ``ModelResult``, or ``None`` if the range holds fewer than
        5 points, all values in it are <= 0, or the fit raises (a warning is issued).
        """
        import warnings

        mask_fit = (energy_axis > min(mu_guesses) - fit_window) & (energy_axis < max(mu_guesses) + fit_window)
        x = energy_axis[mask_fit]
        y = spectrum[mask_fit]

        if len(x) < 5 or np.all(y <= 0):
            return None

        # Initial guesses
        A1_guess = self._peak_amplitude_guess(x, y, mu_guesses[0])
        A2_guess = self._peak_amplitude_guess(x, y, mu_guesses[1])
        C_guess = max((y[-1] - y[0]) * 0.3, 1e-3)

        model = Model(self._two_gauss_two_arctan)
        params = Parameters()

        # Gaussian amplitudes
        params.add("A1", value=A1_guess, min=0)
        params.add("A2", value=A2_guess, min=0)

        # Arctan step heights (lower bound of 1 kept from the original parametrisation;
        # lmfit clips a smaller initial guess to this bound)
        params.add("C1", value=C_guess, min=1)
        params.add("C2", value=C_guess, min=1)

        # Gaussian widths
        params.add("sigma1", value=0.7, min=0.1, max=2.0)
        params.add("sigma2", value=0.7, min=0.1, max=2.0)

        # Arctan widths: w2 may differ from w1 by at most +/-0.5 eV
        params.add("w1", value=0.5, min=0.01, max=2)
        params.add("delta_w", value=0.0, min=-0.5, max=0.5)
        params.add("w2", expr="w1+delta_w")

        # Peak centres, each within +/-2 eV of its guess
        params.add("mu1", value=mu_guesses[0], min=mu_guesses[0]-2, max=mu_guesses[0]+2)
        params.add("mu2", value=mu_guesses[1], min=mu_guesses[1]-2, max=mu_guesses[1]+2)

        # Step onsets tied to the peak centres (within +/-1 eV)
        params.add('deltax01', value=0, min=-1, max=1)
        params.add("x01", expr='mu1+deltax01')
        params.add('deltax02', value=0, min=-1, max=1)
        params.add("x02", expr='mu2+deltax02')

        try:
            return model.fit(y, params, x=x)
        except Exception as exc:
            # Best effort: one failed pixel must not abort the whole MC sweep.
            warnings.warn(f"White-line fit failed: {exc!r}", RuntimeWarning)
            return None

    def _fit_and_integrate_white_lines(self, spectrum, mu_guesses, fit_window=5.0, integration_window=1.5):
        """
        Fit ``spectrum`` with 2 Gaussians + 2 shifted arctans (lmfit).

        Returns
        -------
        tuple of 8 floats
            (C1, C2, A1, A2, sigma1, sigma2, mu1, mu2): arctan step heights,
            Gaussian amplitudes, Gaussian widths and Gaussian centres for the
            peaks near ``mu_guesses[0]`` (index 1) and ``mu_guesses[1]`` (index 2).
            All NaN if the spectrum could not be fitted.

        ``integration_window`` is accepted for backward compatibility and is unused.
        """
        result = self._fit_single_spectrum_with_initial_guess(
            spectrum, self.energy_axis, mu_guesses, fit_window=fit_window
        )
        if result is None:
            return (np.nan,) * 8

        best = result.best_values
        return (best["C1"], best["C2"],
                best["A1"], best["A2"],
                best["sigma1"], best["sigma2"],  # previously sigma2 was returned twice
                best["mu1"], best["mu2"])

    def white_line_calculation_MC(self, L3_guess=577, L2_guess=586, integration_window=1.5, fit_window=5):
        '''
        Fit the white lines of every spatial pixel in every MC replica.

        Parameters
        ----------
        L3_guess, L2_guess : float
            Initial guesses for the L3 and L2 peak positions (energy-axis units, eV).
        integration_window : float
            Passed through to ``_fit_and_integrate_white_lines``; currently unused.
        fit_window : float
            Margin added below the lower and above the upper guess to set the fit range.

        Returns
        -------
        tuple of 8 arrays, each of shape (n_mc, n_spatial)
            C1, C2, A1, A2, sigma1, sigma2, mu1, mu2 (NaN where a fit failed).
        '''
        n_mc = self.predictions.shape[0]
        n_spatial = self.edge_mean.shape[0]

        # results[k, j, p] = fitted parameter k for replica j, pixel p
        results = np.full((8, n_mc, n_spatial), np.nan)

        replicas = self.signal.T[None, :, :] - self.predictions

        for j in range(n_mc):
            print(f'replica {j+1}/{n_mc}')
            for p in range(n_spatial):
                results[:, j, p] = self._fit_and_integrate_white_lines(
                    replicas[j, p, :], (L3_guess, L2_guess),
                    fit_window=fit_window, integration_window=integration_window
                )

        return tuple(results)

    def plot_fit_diagnostic(self, spectrum, mu_guesses=(577, 586), fit_window=5, save_path='fig4_.svg'):
        """
        Plot a spectrum with its fit and individual components (figure for the paper).

        The figure is saved to ``save_path`` (skipped if it is falsy) and shown;
        nothing is returned.
        """
        import matplotlib.pyplot as plt

        # Run the fit
        fit_result = self._fit_single_spectrum_with_initial_guess(
            spectrum, self.energy_axis, mu_guesses, fit_window=fit_window
        )

        if fit_result is None or not fit_result.success:
            print("Fit failed.")
            return

        # Same window as the fit itself
        mask_fit = (self.energy_axis > min(mu_guesses) - fit_window) & \
                   (self.energy_axis < max(mu_guesses) + fit_window)
        x = self.energy_axis[mask_fit]
        y = spectrum[mask_fit]

        best = fit_result.best_values

        # Reconstruct components
        g1 = best["A1"] * np.exp(-(x - best["mu1"]) ** 2 / (2 * best["sigma1"] ** 2))
        g2 = best["A2"] * np.exp(-(x - best["mu2"]) ** 2 / (2 * best["sigma2"] ** 2))
        ar1 = best["C1"] * (np.arctan((x - best["x01"]) / (best["w1"] + 1e-12)) + np.pi/2)
        ar2 = best["C2"] * (np.arctan((x - best["x02"]) / (best["w2"] + 1e-12)) + np.pi/2)
        total_fit = g1 + g2 + ar1 + ar2

        # Plot
        fig = plt.figure(figsize=(8, 8))
        ax = fig.add_axes([0.2, 0.2, 0.7, 0.7])

        ax.plot(x, y, "-", label="Subtracted EELS Data", lw=2, color='black')
        # Plain "-" format: colour comes only from `color` (no fmt/kwarg conflict).
        ax.plot(x, total_fit, "-", lw=3, label="Total Fit", color='#4B4DED')
        ax.plot(x, g1, "--", lw=3, label="$L_3$ Peak", color='#5CECC4')
        ax.plot(x, g2, "--", lw=3, label="$L_2$ Peak", color='#369178')
        ax.plot(x, ar1 + ar2, ":", lw=3, label="Continuum states", color='#FF6F61')

        ax.set_xlabel("Energy loss (eV)")
        ax.set_ylabel("Intensity (a.u.)")
        ax.legend()
        ax.set_yticks([0, ])

        ax.set_xlim([x.min(), x.max()])
        if save_path:
            fig.savefig(save_path, bbox_inches='tight')
        plt.show()

In [None]:
import matplotlib.pyplot as plt

# Fit the white lines of every saved training run and store the fitted
# parameters per run in fitdata/.
os.makedirs('fitdata', exist_ok=True)

for xx in range(0, 4):  # must match the number of training runs
    # os.path.join mirrors how train_background wrote the files and works on
    # every OS (a literal backslash is only a separator on Windows).
    fname_pred = os.path.join('fitruns', f'run_{i[0]}-{i[1]}-{i[2]}v{xx}.npz')
    PC = PredictionChecker(fname_pred)

    C1, C2, A1, A2, sigma1, sigma2, mu1, mu2 = PC.white_line_calculation_MC(577, 586, fit_window=5)

    np.savez(os.path.join('fitdata', f'white_line_results-{xx}.npz'),
             C1=C1, C2=C2, A1=A1, A2=A2,
             sigma1=sigma1, sigma2=sigma2, mu1=mu1, mu2=mu2,
             spatial_size=PC.spatial_size)

    # Free the per-run arrays before loading the next run.
    del C1, C2, A1, A2, sigma1, sigma2, mu1, mu2, PC

In [None]:
import matplotlib.pyplot as plt

# Inspect the white-line fit results of one run (run 0).
with np.load(os.path.join('fitdata', f'white_line_results-{0}.npz')) as data:
    C1 = data['C1']
    C2 = data['C2']
    sigma1 = data['sigma1']
    sigma2 = data['sigma2']
    A1 = data['A1']
    A2 = data['A2']
    SS = data['spatial_size']
print(C1.shape)  # (n_mc, n_spatial)

# Summed Gaussian weight A*sigma over summed arctan step height C. This is only
# proportional to the area ratio (Gaussian area = A*sigma*sqrt(2*pi); the arctan
# term rises by C*pi), which is enough for relative maps.
wl = (A1 * sigma1 + A2 * sigma2) / (C1 + C2)
# One map per MC replica; take the replica count from the data instead of hard-coding it.
wl = wl.reshape(wl.shape[0], *SS)

# nan-aware statistics: a failed fit (NaN) in one replica no longer blanks the pixel.
fig, ax = plt.subplots()
im = ax.imshow(np.nanmean(wl, axis=0))
ax.set_title("White-line ratio: mean over MC replicas")
fig.colorbar(im, ax=ax)
plt.show()

fig, ax = plt.subplots()
im = ax.imshow(np.nanstd(wl, axis=0))
ax.set_title("White-line ratio: std over MC replicas")
fig.colorbar(im, ax=ax)
plt.show()