# 01 - XRF HDF5 Data Loading

This notebook demonstrates loading X-ray fluorescence (XRF) data stored in
MAPS HDF5 format using `h5py`. We cover:

1. Opening an HDF5 file and inspecting its group hierarchy
2. Listing available elements
3. Computing basic per-element statistics
4. Visualizing elemental maps

```python
import h5py
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline
```

## 1. Open a MAPS HDF5 File

MAPS stores fitted elemental maps in the dataset `/MAPS/XRF_fits`, a 3-D array of
shape `(n_elements, rows, cols)`. Each slice along the first axis is the 2-D map
for one element.
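To make this layout concrete without a real scan, the sketch below builds a tiny MAPS-like file in memory and reads one element back. The element names, map size, and random values are synthetic assumptions. Only the `/MAPS/XRF_fits` and `/MAPS/channel_names` paths come from the layout described above. The `core` driver with `backing_store=False` keeps the file in RAM, so nothing is written to disk.

```python
import h5py
import numpy as np

# Synthetic stand-in for a MAPS file (hypothetical elements and values).
elements = [b"Fe", b"Cu", b"Zn"]
rows, cols = 5, 7
rng = np.random.default_rng(0)

with h5py.File("demo_maps.h5", "w", driver="core", backing_store=False) as demo:
    maps = demo.create_group("MAPS")
    maps.create_dataset("channel_names", data=np.array(elements))
    maps.create_dataset(
        "XRF_fits",
        data=rng.random((len(elements), rows, cols), dtype=np.float32),
    )

    fits = demo["/MAPS/XRF_fits"]
    print(fits.shape)    # (3, 5, 7): (n_elements, rows, cols)
    zn_map = fits[2]     # indexing one element reads a single 2-D map
    print(zn_map.shape)  # (5, 7)
```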

```python
DATA_PATH = "../data/sample_xrf.h5"  # adjust to your file

# Open read-only; the handle stays open until the final cell closes it.
f = h5py.File(DATA_PATH, "r")

def print_tree(group, indent=0):
    """Recursively print the HDF5 hierarchy with dataset shapes and dtypes."""
    prefix = "  " * indent
    for key in group:
        item = group[key]
        if isinstance(item, h5py.Group):
            print(f"{prefix}[Group] {key}")
            print_tree(item, indent + 1)
        else:
            print(f"{prefix}[Dataset] {key}  shape={item.shape}  dtype={item.dtype}")

print_tree(f)
```
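h5py can also walk the hierarchy itself through `Group.visititems`, which calls a function with each object's relative path and the object. The sketch below collects every dataset path and shape into a dict instead of printing, which is handy for programmatic checks. The helper `collect_datasets` and the in-memory file contents are hypothetical.

```python
import h5py
import numpy as np

def collect_datasets(group):
    """Return {relative_path: shape} for every dataset below `group`."""
    found = {}

    def visitor(name, obj):
        # Returning None (implicitly) tells visititems to keep walking.
        if isinstance(obj, h5py.Dataset):
            found[name] = obj.shape

    group.visititems(visitor)
    return found

# Hypothetical in-memory file mimicking part of the MAPS layout.
# create_dataset builds the intermediate "MAPS" group automatically.
with h5py.File("tree_demo.h5", "w", driver="core", backing_store=False) as demo:
    demo.create_dataset("MAPS/XRF_fits", data=np.zeros((3, 4, 5)))
    demo.create_dataset("MAPS/channel_names", data=np.array([b"Fe", b"Cu", b"Zn"]))
    tree = collect_datasets(demo)

print(tree)
```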

## 2. List Available Elements

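The `/MAPS/channel_names` dataset stores element symbols in the same order as
the first axis of `XRF_fits`: entry `i` labels the map `XRF_fits[i]`. h5py may
return these names as `bytes`, so they are decoded to `str`.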
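Because names and slices share the same index, a small lookup table lets you pull a map by element symbol instead of by position. `element_map` is a hypothetical helper, not part of h5py or MAPS. It accepts byte or `str` names, and `fits` can be a NumPy array or an `h5py.Dataset`, since both return a 2-D array when indexed with an integer.

```python
import numpy as np

def element_map(fits, names, symbol):
    """Return the 2-D map for `symbol`, where names[i] labels fits[i]."""
    # Normalize names to str so b"Fe" and "Fe" compare equal.
    clean = [n.decode().strip() if isinstance(n, bytes) else str(n).strip() for n in names]
    index = {name: i for i, name in enumerate(clean)}
    if symbol not in index:
        raise KeyError(f"{symbol!r} not in channel_names: {clean}")
    return fits[index[symbol]]

# Synthetic example: 3 elements on a 2 x 2 grid.
fits = np.arange(12, dtype=float).reshape(3, 2, 2)
names = np.array([b"Fe", b"Cu", b"Zn"])
cu = element_map(fits, names, "Cu")
print(cu)
```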

```python
maps_group = f["/MAPS"]

# Names may come back as bytes (fixed-length strings); decode only when needed.
channel_names = [
    name.decode().strip() if isinstance(name, bytes) else str(name).strip()
    for name in maps_group["channel_names"][:]
]
xrf_fits = maps_group["XRF_fits"]  # shape: (n_elements, rows, cols)

print(f"Number of elements: {len(channel_names)}")
print(f"Elements: {channel_names}")
print(f"Map shape per element: {xrf_fits.shape[1:]}")
```

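## 3. Basic Per-Element Statistics

Summarize each elemental map by its mean, standard deviation, minimum, and
maximum. Reading the whole array with `[:]` is the simplest approach; for very
large scans, indexing one element at a time (`xrf_fits[i]`) reads only that map.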

```python
data = xrf_fits[:]  # load into memory

print(f"{'Element':<8} {'Mean':>12} {'Std':>12} {'Min':>12} {'Max':>12}")
print("-" * 60)
for i, name in enumerate(channel_names):
    arr = data[i]
    print(f"{name:<8} {arr.mean():12.4f} {arr.std():12.4f} {arr.min():12.4f} {arr.max():12.4f}")
```
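The loop can also be written as vectorized NumPy reductions over the two spatial axes, giving one value per element. This sketch uses the NaN-aware variants (`np.nanmean` and friends) on the assumption that a map may contain NaN pixels; with plain `mean`, a single NaN makes that element's statistic NaN. `per_element_stats` is a hypothetical helper and the input array is synthetic. Like `arr.std()`, `np.nanstd` defaults to the population standard deviation (`ddof=0`).

```python
import numpy as np

def per_element_stats(data):
    """Mean, std, min, max of each map in an (n_elements, rows, cols) array, ignoring NaNs."""
    spatial = (1, 2)  # reduce over rows and cols, keep the element axis
    return {
        "mean": np.nanmean(data, axis=spatial),
        "std": np.nanstd(data, axis=spatial),
        "min": np.nanmin(data, axis=spatial),
        "max": np.nanmax(data, axis=spatial),
    }

# Synthetic input: two 2 x 2 maps, the second with one NaN pixel.
demo = np.array([
    [[1.0, 2.0], [3.0, 4.0]],
    [[10.0, np.nan], [20.0, 30.0]],
])
stats = per_element_stats(demo)
print(stats["mean"])
```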

## 4. Visualize Elemental Maps

Display a grid of elemental concentration maps to get an initial sense of
spatial distribution.
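A few very bright pixels can squeeze the rest of a map into a narrow band of the colormap. One common remedy, not part of the original plotting cell, is to clip the display range to low and high percentiles and pass them to `imshow` as `vmin`/`vmax`. The 1st/99th percentile cutoffs are arbitrary choices, and `display_limits` is a hypothetical helper.

```python
import numpy as np

def display_limits(image, low=1.0, high=99.0):
    """Return (vmin, vmax) from percentiles of the finite pixels in `image`."""
    finite = image[np.isfinite(image)]
    if finite.size == 0:
        return None, None  # let imshow autoscale an all-NaN map
    vmin, vmax = np.percentile(finite, [low, high])
    return float(vmin), float(vmax)

# Synthetic map: a smooth ramp with one hot pixel.
img = np.linspace(0.0, 1.0, 100).reshape(10, 10)
img[0, 0] = 1000.0
vmin, vmax = display_limits(img)
print(vmin, vmax)
# Possible use in the plotting loop:
# axes[i].imshow(data[i], cmap="inferno", vmin=vmin, vmax=vmax)
```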

In [None]:
n_elements = len(channel_names)
ncols = 4
nrows = int(np.ceil(n_elements / ncols))

fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 3.5 * nrows))
axes = axes.flatten()

for i, name in enumerate(channel_names):
    im = axes[i].imshow(data[i], cmap="inferno")
    axes[i].set_title(name)
    axes[i].axis("off")
    fig.colorbar(im, ax=axes[i], fraction=0.046, pad=0.04)

for j in range(i + 1, len(axes)):
    axes[j].axis("off")

plt.tight_layout()
plt.show()

```python
f.close()  # release the HDF5 file handle
```
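As an alternative to the explicit `f.close()`, opening the file in a `with` block closes it even if a cell raises partway through. Arrays read with `[:]` inside the block stay usable afterwards because they are NumPy copies, whereas `h5py.Dataset` handles become invalid once the file closes. The temporary file and its contents are synthetic, created only so the sketch runs on its own.

```python
import os
import tempfile

import h5py
import numpy as np

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "sample_xrf_demo.h5")  # hypothetical file name

    # Write a small synthetic MAPS-like file.
    with h5py.File(path, "w") as out:
        out.create_dataset("MAPS/XRF_fits", data=np.ones((2, 3, 4), dtype=np.float32))

    # Read it back; the file closes automatically when the block exits.
    with h5py.File(path, "r") as h5:
        handle = h5["/MAPS/XRF_fits"]
        maps_data = handle[:]  # NumPy copy, independent of the open file

    # The copy survives; the dataset handle is no longer valid.
    print(maps_data.shape, handle.id.valid)
```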