In [None]:
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

import autoden as ad

# Automatically reload imported modules (e.g. `autoden`) before each cell runs,
# so edits to their source are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Interactive figures via the ipympl backend.
%matplotlib widget

In [None]:
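# Provenance of "05_xrf_spectra.csv" (read by the next cell): uncomment this cell
# to regenerate it from the raw HDF5 acquisition. Requires h5py and access to the
# raw-data path below; the `NDArray` annotation also needs its import.
# import h5py
# from numpy.typing import NDArray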

# with h5py.File(
#     "/data/projects/computational_imaging/ghost_imaging/datasets/2022-09-09_ID19_IHMI1497_XRF-GI/gi_W250mask_CuFeUnknown_3wires/gi_W250mask_CuFeUnknown_3wires_0002/gi_W250mask_CuFeUnknown_3wires_0002.h5"
# ) as h5f:
#     data_xrf_h5 = h5f["7.1/instrument/fxid19_det0/data"]
#     if isinstance(data_xrf_h5, h5py.Dataset):
#         data_xrf: NDArray = data_xrf_h5[()]
#     else:
#         raise ValueError("Not a dataset!")

# print(data_xrf.shape)
# en_keV = np.round(np.arange(data_xrf.shape[-1]) * 0.00501 - 0.02, decimals=3)

# df = pd.DataFrame(data_xrf, columns=en_keV)
# df.to_csv("05_xrf_spectra.csv")

In [None]:
# Load the exported spectra: one row per spectrum, one column per energy channel,
# with the channel energies (keV) as column labels. `index_col=0` consumes the
# row-index column that `DataFrame.to_csv` writes, instead of slicing it off by hand.
xrf_data = pd.read_csv("05_xrf_spectra.csv", index_col=0)

# The last energy channel is excluded, as in the original analysis.
# NOTE(review): presumably that channel holds detector overflow counts — confirm.
xrf_channels = xrf_data.iloc[:, :-1]

energies_keV = xrf_channels.columns.astype(float).to_numpy()
counts = xrf_channels.to_numpy()
assert counts.shape[1] == energies_keV.size, "energy axis and count columns disagree"

fig, ax = plt.subplots(1, 1, figsize=(8, 3))
ax.plot(energies_keV, counts[0], label="Spectrum #1")
ax.plot(energies_keV, counts.mean(axis=0), label="Mean of spectra")
ax.set_xlim(energies_keV[0], energies_keV[-1])
ax.set_xlabel("keV")
ax.set_ylabel("Photon counts")
ax.set_title("Raw XRF spectra")
ax.grid()
ax.legend()
fig.tight_layout()

In [None]:
# Train a 1-D DnCNN with autoden's N2N trainer on the leading spectra, then
# predict a denoised spectrum.
# NOTE(review): N2N presumably means Noise2Noise, i.e. the training rows are
# assumed to be independent noisy measurements of the same spectrum — confirm.
N_TRAIN_SPECTRA = 4  # number of leading rows of `counts` used for training

dncnn_params = ad.NetworkParamsDnCNN(n_dims=1, n_features=4, n_layers=6)
model = dncnn_params.get_model()
n2n = ad.N2N(model=model, reg_val=None)

n2n_data = n2n.prepare_data(counts[:N_TRAIN_SPECTRA])
train_options = dict(epochs=10_000, lower_limit=0.0, learning_rate=1e-2, restarts=1)
_ = n2n.train(*n2n_data, **train_options)  # the training return value is not used here

# Inference is run on the first element returned by `prepare_data`.
pred_counts = n2n.infer(n2n_data[0])

In [None]:
# Compare the network prediction with the first raw spectrum and the mean spectrum.
# NOTE(review): the network was trained on `counts[:4]` only, while the mean below
# is taken over every row of `counts`.
fig, ax = plt.subplots(1, 1, figsize=(8, 3))
ax.plot(energies_keV, counts[0], ".", label="Spectrum #1")
ax.plot(energies_keV, counts.mean(axis=0), label="Mean of spectra")
ax.plot(energies_keV, pred_counts, label="Predicted spectrum")
ax.set(
    xlim=(energies_keV[0], energies_keV[-1]),
    xlabel="keV",
    ylabel="Photon counts",
    title="Raw vs. N2N-predicted XRF spectrum",
)
ax.grid()
ax.legend()
fig.tight_layout()

In [None]:
# Energy windows (start, end) in keV to zoom into; the last one is a closer view
# inside the third.
ranges: tuple[tuple[float, float], ...] = ((2.5, 7.5), (7.75, 9.75), (15.0, 18.5), (15.2, 16.25))

# The column-wise mean does not depend on the window, so compute it once.
mean_counts = counts.mean(axis=0)

# squeeze=False keeps a 2-D Axes array even when there is a single window.
fig, axs = plt.subplots(len(ranges), 1, figsize=(8, len(ranges) * 2.5), squeeze=False)
for ax, (e_s, e_e) in zip(axs[:, 0], ranges):
    bin_s = np.abs(energies_keV - e_s).argmin()
    # +1 so the channel closest to the upper bound is included in the slice.
    bin_e = np.abs(energies_keV - e_e).argmin() + 1
    window = slice(bin_s, bin_e)
    ax.plot(energies_keV[window], counts[0, window], ".", label="Spectrum #1")
    ax.plot(energies_keV[window], mean_counts[window], label="Mean of spectra")
    ax.plot(energies_keV[window], pred_counts[window], label="Predicted spectrum")
    ax.set_xlim(e_s, e_e)
    ax.set_xlabel("keV")
    ax.set_ylabel("Photon counts")
    ax.grid()
    ax.legend()
fig.tight_layout()