# 01 - XRF HDF5 Data Loading

This notebook demonstrates how to load X-ray fluorescence (XRF) data stored in
MAPS HDF5 format using `h5py`. We cover:

1. Opening an HDF5 file and inspecting its group hierarchy
2. Listing available element channels
3. Extracting elemental maps as NumPy arrays
4. Computing basic statistics per element
5. Visualizing elemental maps

In [None]:
import h5py
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

## 1. Open the HDF5 file

MAPS-format HDF5 files typically store fitted elemental maps in `/MAPS/XRF_fits`,
the matching element labels in `/MAPS/channel_names`, and raw per-pixel spectra
in `/MAPS/mca_arr`.
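If no MAPS file is at hand, a small synthetic file with the same layout lets the rest of the notebook run end to end. The sketch below is a hypothetical test fixture, not output from the MAPS software: the helper name `write_synthetic_maps_file`, the element labels, the array sizes, the dtypes, and the `(n_energy, rows, cols)` axis order assumed for `mca_arr` are all illustrative choices.

```python
import os
import tempfile

import h5py
import numpy as np


def write_synthetic_maps_file(path, channel_names=("Ca", "Fe", "Zn"),
                              rows=20, cols=30, n_energy=256, seed=0):
    """Write a small MAPS-like HDF5 file (hypothetical test fixture).

    Paths follow the layout described above; shapes, dtypes and the
    (n_energy, rows, cols) order used for mca_arr are assumptions.
    """
    rng = np.random.default_rng(seed)
    with h5py.File(path, "w") as out:
        maps = out.create_group("MAPS")
        # Fixed-length byte strings, so reading them back yields bytes
        maps.create_dataset("channel_names",
                            data=np.array(channel_names, dtype="S"))
        # One fitted 2D map per channel, stacked along axis 0
        fits = rng.gamma(2.0, 1.0, size=(len(channel_names), rows, cols))
        maps.create_dataset("XRF_fits", data=fits.astype("f4"))
        # Raw spectra: Poisson-distributed counts per energy bin per pixel
        mca = rng.poisson(5.0, size=(n_energy, rows, cols))
        maps.create_dataset("mca_arr", data=mca.astype("u4"))
    return path


tmpdir = tempfile.mkdtemp()
demo_path = write_synthetic_maps_file(os.path.join(tmpdir, "synthetic_xrf.h5"))
print(demo_path)
```

Setting `DATA_PATH = demo_path` in the next cell should then let the remaining cells run against the synthetic data.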

In [None]:
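# Update this path to point to your MAPS HDF5 file
DATA_PATH = "../data/sample_xrf.h5"

# Open read-only; the handle stays open for the following cells and is
# closed in the final cell of this notebook
f = h5py.File(DATA_PATH, "r")
print("Top-level groups:", list(f.keys()))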
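In a script, or whenever the file should not stay open across cells, a `with` block closes the handle automatically, even if an error is raised. The sketch below also checks that the expected MAPS paths exist before any data are read. The helper name `missing_maps_paths` and the choice of required paths are assumptions for illustration, and the tiny file it writes exists only to make the example self-contained.

```python
import os
import tempfile

import h5py
import numpy as np

# Assumed minimum layout needed by the cells in this notebook
REQUIRED_PATHS = ("MAPS/XRF_fits", "MAPS/channel_names")


def missing_maps_paths(path, required=REQUIRED_PATHS):
    """Return the required paths that are absent from the HDF5 file at `path`."""
    # The `with` block closes the file when it exits, even on an exception
    with h5py.File(path, "r") as h5:
        # `"a/b" in group` tests a nested path without raising if it is missing
        return [p for p in required if p not in h5]


# Throwaway file that deliberately lacks channel_names
check_path = os.path.join(tempfile.mkdtemp(), "check.h5")
with h5py.File(check_path, "w") as out:
    # Intermediate group "MAPS" is created automatically
    out.create_dataset("MAPS/XRF_fits", data=np.zeros((2, 4, 4), dtype="f4"))

print(missing_maps_paths(check_path))  # ['MAPS/channel_names']
```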

In [None]:
def print_hdf5_tree(group, indent=0):
    """Recursively print HDF5 group hierarchy."""
    for key in group:
        item = group[key]
        prefix = "  " * indent
        if isinstance(item, h5py.Group):
            print(f"{prefix}[Group] {key}")
            print_hdf5_tree(item, indent + 1)
        else:
            print(f"{prefix}[Dataset] {key}  shape={item.shape}  dtype={item.dtype}")

print_hdf5_tree(f)
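h5py can also walk the hierarchy itself: `Group.visititems` calls a function with each object's path (relative to the starting group) and the object. The sketch below collects every dataset's path, shape and dtype into a list instead of printing, which is convenient for programmatic checks. `list_datasets` is a hypothetical helper name, and only metadata is touched, so no array data are read.

```python
import os
import tempfile

import h5py
import numpy as np


def list_datasets(group):
    """Return (path, shape, dtype) for every dataset below `group`."""
    found = []

    def visitor(name, obj):
        # `name` is the path relative to `group`; groups are skipped
        if isinstance(obj, h5py.Dataset):
            found.append((name, obj.shape, str(obj.dtype)))
        # Returning None lets visititems continue the traversal

    group.visititems(visitor)
    return found


tree_path = os.path.join(tempfile.mkdtemp(), "tree.h5")
with h5py.File(tree_path, "w") as out:
    out.create_dataset("MAPS/XRF_fits", data=np.zeros((2, 4, 4), dtype="f4"))
    out.create_dataset("MAPS/channel_names", data=np.array(["Ca", "Fe"], dtype="S"))

with h5py.File(tree_path, "r") as h5:
    for path, shape, dtype in list_datasets(h5):
        print(path, shape, dtype)
```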

## 2. List available elements

The `channel_names` dataset contains the element labels that correspond
to slices along axis 0 of `XRF_fits`.

In [None]:
maps_group = f["/MAPS"]
channel_names = [name.decode() for name in maps_group["channel_names"][:]]
print(f"Number of element channels: {len(channel_names)}")
print("Elements:", channel_names)
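Because each label's position in `channel_names` is its index along axis 0 of `XRF_fits`, a dictionary from label to index makes it easy to pick out a single element. The hypothetical helper below accepts either `bytes` (what h5py 3 returns by default for string datasets) or `str`, and strips padding whitespace. If a label were repeated, the later index would win.

```python
def build_channel_index(raw_names):
    """Map each element label to its index along axis 0 of XRF_fits."""
    index = {}
    for i, name in enumerate(raw_names):
        # numpy.bytes_ is a subclass of bytes, so it is decoded here too
        label = name.decode() if isinstance(name, bytes) else str(name)
        index[label.strip()] = i
    return index


idx = build_channel_index([b"Ca", b"Fe", "Zn "])
print(idx["Fe"])  # 1
```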

## 3. Extract elemental maps

In [None]:
xrf_fits = maps_group["XRF_fits"][:]  # shape: (n_elements, rows, cols)
print(f"XRF fits array shape: {xrf_fits.shape}")

# Build a dictionary mapping element name -> 2D map
element_maps = {name: xrf_fits[i] for i, name in enumerate(channel_names)}
print(f"Loaded maps for: {list(element_maps.keys())[:10]} ...")
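The cell above reads the whole `XRF_fits` array into memory with `[:]`. For large scans it can be cheaper to index the `h5py.Dataset` directly, since h5py then requests only the selected slice from the file. A minimal sketch, with the hypothetical helper `read_channel` and default paths taken from the layout described earlier:

```python
import os
import tempfile

import h5py
import numpy as np


def read_channel(path, element,
                 names_path="MAPS/channel_names", fits_path="MAPS/XRF_fits"):
    """Read one elemental map by label without loading the full stack."""
    with h5py.File(path, "r") as h5:
        names = [n.decode() for n in h5[names_path][:]]
        i = names.index(element)  # raises ValueError if the label is absent
        # Indexing the Dataset (not an in-memory copy) reads just plane i;
        # the result is a NumPy array that stays valid after the file closes
        return h5[fits_path][i]


# Tiny self-contained demo file: 2 channels of 3x4 pixels
channels_path = os.path.join(tempfile.mkdtemp(), "channels.h5")
with h5py.File(channels_path, "w") as out:
    out.create_dataset("MAPS/channel_names", data=np.array(["Ca", "Fe"], dtype="S"))
    out.create_dataset("MAPS/XRF_fits",
                       data=np.arange(24, dtype="f4").reshape(2, 3, 4))

fe = read_channel(channels_path, "Fe")
print(fe.shape)  # (3, 4)
```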

## 4. Basic statistics per element

In [None]:
print(f"{'Element':<8} {'Mean':>12} {'Std':>12} {'Min':>12} {'Max':>12}")
print("-" * 60)
for name in channel_names:
    m = element_maps[name]
    print(f"{name:<8} {m.mean():>12.4f} {m.std():>12.4f} {m.min():>12.4f} {m.max():>12.4f}")
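`m.mean()` and the other reductions above return `nan` as soon as a single pixel is NaN, which can happen if a fit failed somewhere in the scan, and min/max are easily dominated by isolated hot pixels. The sketch below (hypothetical helper `channel_stats`; the 1st/99th percentile cut-offs are an arbitrary choice) ignores non-finite pixels, counts them, and adds low/high percentiles as a more robust range.

```python
import numpy as np


def channel_stats(m, low_pct=1.0, high_pct=99.0):
    """Summary statistics for one 2D map, ignoring NaN/inf pixels."""
    m = np.asarray(m, dtype=float)
    vals = m[np.isfinite(m)]
    stats = {"n_bad": int(m.size - vals.size)}
    if vals.size == 0:
        # Every pixel is non-finite: nothing meaningful to summarize
        stats.update(mean=np.nan, std=np.nan, min=np.nan, max=np.nan,
                     p_low=np.nan, p_high=np.nan)
        return stats
    p_low, p_high = np.percentile(vals, [low_pct, high_pct])
    stats.update(mean=float(vals.mean()), std=float(vals.std()),
                 min=float(vals.min()), max=float(vals.max()),
                 p_low=float(p_low), p_high=float(p_high))
    return stats


sample_map = np.array([[1.0, 2.0, np.nan], [3.0, 4.0, np.inf]])
print(channel_stats(sample_map))
```

In the notebook this could replace the per-element reductions, e.g. `channel_stats(element_maps[name])` inside the loop.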

## 5. Visualization

In [None]:
# Select a subset of elements to display
elements_to_show = channel_names[:6]
ncols = 3
nrows = int(np.ceil(len(elements_to_show) / ncols))

fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows))
for ax, name in zip(axes.ravel(), elements_to_show):
    im = ax.imshow(element_maps[name], cmap="viridis")
    ax.set_title(name)
    ax.axis("off")
    fig.colorbar(im, ax=ax, fraction=0.046)

# Hide unused axes
for ax in axes.ravel()[len(elements_to_show):]:
    ax.set_visible(False)

plt.tight_layout()
plt.show()
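By default `imshow` scales each panel to its own data min/max, so a few hot pixels can wash out the rest of a map. Clipping to percentiles (2nd to 98th here, an arbitrary choice) is a common alternative, and stacking three clipped maps into the red, green and blue channels gives a tri-color overlay for judging co-localization. The helpers `normalize_percentile` and `rgb_composite` are hypothetical names for this sketch.

```python
import numpy as np


def normalize_percentile(img, low=2.0, high=98.0):
    """Linearly rescale `img` to [0, 1] between its low/high percentiles."""
    img = np.asarray(img, dtype=float)
    finite = np.isfinite(img)
    if not finite.any():
        return np.zeros_like(img)
    lo, hi = np.percentile(img[finite], [low, high])
    if hi <= lo:
        # Flat map: return zeros rather than divide by zero
        return np.zeros_like(img)
    # Non-finite pixels are mapped to `lo`, i.e. 0 after scaling
    scaled = (np.where(finite, img, lo) - lo) / (hi - lo)
    return np.clip(scaled, 0.0, 1.0)


def rgb_composite(red, green, blue, low=2.0, high=98.0):
    """Stack three elemental maps into a (rows, cols, 3) RGB image in [0, 1]."""
    return np.dstack([normalize_percentile(c, low, high) for c in (red, green, blue)])


rng = np.random.default_rng(0)
r_map, g_map, b_map = (rng.random((10, 12)) for _ in range(3))
rgb = rgb_composite(r_map, g_map, b_map)
print(rgb.shape)  # (10, 12, 3)
```

For example, `plt.imshow(rgb_composite(element_maps["Fe"], element_maps["Zn"], element_maps["Ca"]))` would overlay three maps, assuming those labels appear in `channel_names`.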

In [None]:
f.close()
print("HDF5 file closed.")