# X-ray Reflectivity Fitting and Visualization: Simplified Workflow

This notebook provides a streamlined workflow for X-ray reflectivity (XRR) data analysis, including model fitting, parameter inspection, and advanced visualization. It is designed for reproducibility and safe handling of fit results, with frequent saving and protection of the standard `fitting_results_fixed.pkl`.

**Outline:**

1. Import Libraries and Set Up Environment
2. Load and Prepare Data
3. Load Optical Constants
4. Define Utility Functions for Fitting and Visualization
5. Build Layered Slab Model
6. Fit the Model to Data and Save Results
7. Load and Inspect Fitting Results
8. Visualize Reflectivity and Structure
9. Visualize Optical Constants and Orientation Profiles
10. Visualize 2D Reflectivity Maps and Anisotropy


## 1. Import Libraries and Set Up Environment

Import all required libraries for data handling, fitting, and visualization. Set consistent plotting styles.


In [None]:
import numpy as np
import pandas as pd
import polars as pl
import matplotlib.pyplot as plt
import seaborn as sns
import pickle
import os
from typing import Any, Dict, List
from matplotlib.figure import Figure

# Fitting libraries
from refnx.analysis import Transform, GlobalObjective, Parameters
from refnx.reflect.interface import Erf, Step
import pyref.fitting as fit

# Set plotting styles: seaborn whitegrid base style, notebook-sized fonts,
# inward ticks on all four sides and a thin dashed grid on both tick levels.
plt.style.use("seaborn-v0_8-whitegrid")
sns.set_context("notebook")
plt.rcParams.update(
    {
        "xtick.direction": "in",
        "ytick.direction": "in",
        "xtick.top": True,
        "ytick.right": True,
        "grid.linestyle": "--",
        "grid.linewidth": 0.5,
        "axes.grid.which": "both",
    }
)

## 2. Load and Prepare Data

Read the processed reflectivity data from the parquet file, keep only the monolayer ("mono") samples, group the measurements by beamline energy, and build one dataset per energy for fitting.


In [None]:
# Load reflectivity data for the monolayer samples.
# Filtering first keeps the sort small.  A single multi-key sort is used
# because polars' `sort` does not promise to keep the order of equal elements
# unless `maintain_order=True`, so chaining `.sort("Q").sort("pol")` could
# scramble the Q ordering inside each polarization.
df = (
    pl.read_parquet("june_processed.parquet")
    .filter(pl.col("sample").str.starts_with("mono"))
    .sort(["pol", "Q"], descending=[True, False])
)

# Display a preview
display(df.head())

# Group data by energy and prepare one dataset per energy.
# Keys of `data` are the string form of the energy (e.g. "250.0").
data = {}
for en, g in df.group_by("Beamline Energy [eV]", maintain_order=True):
    Q = g["Q"].to_numpy()
    R = g["r"].to_numpy()
    # Uncertainty estimate: 8 % of R plus a term of 0.2e-6 * Q, capped at
    # 90 % of R so that R - dR always stays positive.
    dR = np.minimum(0.08 * R + 0.2e-6 * Q, 0.9 * R)
    # `en` is the group key tuple; its first entry is the energy value.
    data[str(en[0])] = fit.XrayReflectDataset(data=(Q, R, dR))

# Define energy batches for fitting
energy_batches = [
    np.array([250, 283.7]),  # Example: structure energies
    df["Beamline Energy [eV]"].unique().to_numpy(),  # All measured energies
]
display(energy_batches[0])

## 3. Load Optical Constants

Load optical constants from CSV and define interpolation functions for use in model construction.


In [None]:
# Optical-constant table; the columns read below are "energy", "n_xx",
# "n_zz", "n_ixx" and "n_izz".
ooc = pd.read_csv("optical_constants.csv")


def ooc_function(energy, ooc=ooc, theta=0.0, density=1.0):
    """Return interpolated, density-scaled optical constants at ``energy``.

    Each of the four table columns is linearly interpolated with ``np.interp``
    and multiplied by ``density``.  When ``theta`` is non-zero the (xx, zz)
    pairs are mixed with the sin²/cos² weights below; ``theta`` is passed
    straight to ``np.sin``/``np.cos``, i.e. it is in radians.

    Returns
    -------
    tuple
        ``(n_xx, n_zz, n_ixx, n_izz)``.
    """
    energy_grid = ooc["energy"]
    n_xx, n_zz, n_ixx, n_izz = (
        np.interp(energy, energy_grid, ooc[column]) * density
        for column in ("n_xx", "n_zz", "n_ixx", "n_izz")
    )
    if theta == 0.0:
        # No rotation requested: hand back the tabulated components.
        return n_xx, n_zz, n_ixx, n_izz

    rot_xx = 0.5 * (n_xx * (1 + np.cos(theta) ** 2) + n_zz * np.sin(theta) ** 2)
    rot_zz = n_xx * np.sin(theta) ** 2 + n_zz * np.cos(theta) ** 2
    rot_ixx = 0.5 * (n_ixx * (1 + np.cos(theta) ** 2) + n_izz * np.sin(theta) ** 2)
    rot_izz = n_ixx * np.sin(theta) ** 2 + n_izz * np.cos(theta) ** 2
    return rot_xx, rot_zz, rot_ixx, rot_izz

## 4. Define Utility Functions for Fitting and Visualization

Helper functions for model construction, parameter setting, safe saving/loading of objectives, and plotting.


In [None]:
def save_objective(obj: Any, base_name: str = "fit_obj", folder: str = "fit_results"):
    """Pickle ``obj`` to the first unused ``{base_name}_v{N}.pkl`` in ``folder``.

    The folder is created if needed.  ``N`` starts at 1 and is increased until
    a file name that does not exist yet is found, so existing files (including
    fitting_results_fixed.pkl) are never overwritten.  The chosen path is
    printed; nothing is returned.
    """
    os.makedirs(folder, exist_ok=True)
    version = 1
    target = os.path.join(folder, f"{base_name}_v{version}.pkl")
    while os.path.exists(target):
        version += 1
        target = os.path.join(folder, f"{base_name}_v{version}.pkl")
    with open(target, "wb") as handle:
        pickle.dump(obj, handle)
    print(f"Saved objective to {target}")


def load_latest_objective(base_name: str = "fit_obj", folder: str = "fit_results"):
    """Load the highest-numbered ``{base_name}_v{N}.pkl`` pickle in ``folder``.

    Versions are compared numerically, so ``_v10`` is newer than ``_v9``
    (a plain string sort would rank ``_v9`` last).  Only files that match the
    exact ``{base_name}_v{N}.pkl`` pattern are considered.

    Returns
    -------
    object
        The unpickled objective.

    Raises
    ------
    FileNotFoundError
        If ``folder`` does not exist or contains no matching file.
    """
    import re

    pattern = re.compile(rf"{re.escape(base_name)}_v(\d+)\.pkl")
    versions = {}
    for fname in os.listdir(folder):
        match = pattern.fullmatch(fname)
        if match:
            versions[int(match.group(1))] = fname
    if not versions:
        raise FileNotFoundError("No saved objectives found.")
    latest = versions[max(versions)]
    # NOTE: unpickling can execute arbitrary code; only load files you wrote.
    with open(os.path.join(folder, latest), "rb") as f:
        obj = pickle.load(f)
    print(f"Loaded objective from {latest}")
    return obj


def plot_reflectivity_and_structure(global_obj, figsize=(10, 6)):
    """Plot reflectivity (left) and structure profile (right) per objective.

    One row of axes is created for every entry of ``global_obj.objectives``.
    Each row shows the objective's own ``plot`` output on a log y-scale and
    the ``plot`` output of that objective's ``model.structure``.

    Parameters
    ----------
    global_obj
        Object exposing an ``objectives`` sequence; every objective must have
        ``plot``, ``model.energy`` and ``model.structure``.
    figsize
        Figure size forwarded to ``plt.subplots``.

    Returns
    -------
    tuple
        ``(fig, ax)`` where ``ax`` is always a 2-D array of shape
        ``(n_objectives, 2)``.

    Raises
    ------
    ValueError
        If ``global_obj`` holds no objectives.
    """
    objectives = global_obj.objectives
    if len(objectives) == 0:
        raise ValueError("global_obj contains no objectives to plot.")
    # squeeze=False keeps the axes array 2-D even for a single objective.
    fig, ax = plt.subplots(
        nrows=len(objectives),
        ncols=2,
        figsize=figsize,
        gridspec_kw={"width_ratios": [2.5, 1]},
        squeeze=False,
    )
    for (refl_ax, struct_ax), objective in zip(ax, objectives):
        model = objective.model
        # Reflectivity
        objective.plot(ax=refl_ax, show_anisotropy=False)
        refl_ax.set_yscale("log")
        refl_ax.set_ylabel("Reflectivity")
        refl_ax.set_title(f"{model.energy:.1f} eV")
        # Structure: use this objective's own structure.  A lookup keyed by
        # the energy string would hand every objective sharing an energy the
        # same (last) structure.
        model.structure.plot(ax=struct_ax)
        struct_ax.set_ylabel("SLD")
    plt.tight_layout()
    plt.show()
    return fig, ax


def print_structure_params(fit_objectives, structure_property):
    """Print a table of the varying parameters named ``structure_property``.

    Every slab of every objective's ``model.structure`` is inspected.  For
    "density" and "rotation" the slab's ``sld`` attribute is checked first;
    the slab itself is checked in every case.  Only parameters whose ``vary``
    flag is truthy are listed, with value, stderr (0 when missing) and bounds.
    """
    check_sld = structure_property in ["density", "rotation"]
    varying = []
    for objective in fit_objectives.objectives:
        for slab in objective.model.structure:
            owners = []
            if check_sld and hasattr(slab, "sld"):
                owners.append(slab.sld)
            owners.append(slab)
            for owner in owners:
                if not hasattr(owner, structure_property):
                    continue
                candidate = getattr(owner, structure_property)
                if candidate.vary:
                    varying.append(candidate)

    rule = "-" * 60
    print(f"Number of {structure_property} parameters: {len(varying)}")
    print(rule)
    for entry in varying:
        if entry.bounds is None:
            bounds = "(None, None)"
        else:
            bounds = f"({entry.bounds.lb:.3f}, {entry.bounds.ub:.3f})"
        val = entry.value
        err = entry.stderr or 0
        print(f"| {entry.name:30} | {val:10.3f} ± {err:8.3f} | {str(bounds):18} |")
    print(rule)

## 5. Build Layered Slab Model

Construct the multilayer slab model for each energy. Each layer represents a physical region (e.g., vacuum, surface, ZnPc, contamination, SiO₂, substrate).


In [None]:
# Chemical formula string for ZnPc; not referenced by the other visible cells.
ZNPC = "C32H16N8Zn"
# arcsin(sqrt(2/3)) ≈ 0.955 rad (≈ 54.7°); used as the lower rotation bound
# of the ZnPc layer.
MA = np.arcsin(np.sqrt(2 / 3))  # Magic angle in radians


def vacuum(energy):
    """Return the vacuum slab named ``Vacuum_{energy}`` with nothing varying.

    The slab is built from ``fit.MaterialSLD("", 0, ...)`` called with
    ``(0, 0)``; thickness, roughness and density are all held fixed.
    """
    material = fit.MaterialSLD("", 0, name=f"Vacuum_{energy}")
    layer = material(0, 0)
    layer.thick.setp(vary=False)
    layer.rough.setp(vary=False)
    layer.sld.density.setp(vary=False)
    return layer


def substrate(energy, thick=0, rough=1.2, density=2.44):
    """Return the fixed "Si" substrate slab named ``Substrate_{energy}``.

    Thickness, roughness and density are all held fixed; the density also
    receives bounds of (2, 3), which only matter if it is later set to vary.
    """
    material = fit.MaterialSLD(
        "Si", density=density, energy=energy, name=f"Substrate_{energy}"
    )
    layer = material(thick, rough)
    layer.thick.setp(vary=False)
    layer.rough.setp(vary=False)
    layer.sld.density.setp(vary=False, bounds=(2, 3))
    return layer


def sio2(energy, thick=8.22, rough=6.153, density=2.15):
    """Return the "SiO2" slab named ``Oxide_{energy}`` with all parameters free.

    Bounds: thickness (8, 12), roughness (0, 8), density (1, 2.3).
    """
    material = fit.MaterialSLD(
        "SiO2", density=density, energy=energy, name=f"Oxide_{energy}"
    )
    layer = material(thick, rough)
    layer.thick.setp(vary=True, bounds=(8, 12))
    layer.rough.setp(vary=True, bounds=(0, 8))
    layer.sld.density.setp(vary=True, bounds=(1, 2.3))
    return layer


def contamination(energy, thick=4.4, rough=2, density=1.0):
    """Return the slab named ``Contamination_{energy}`` with all parameters free.

    Built from ``fit.UniTensorSLD`` on the ``ooc`` table with an initial
    rotation of 0.81 rad.  Bounds: density (1, 1.8), rotation (π/4, 7π/8),
    thickness (0, 12), roughness (0, 5).
    """
    material = fit.UniTensorSLD(
        ooc,
        density=density,
        rotation=0.81,
        energy=energy,
        name=f"Contamination_{energy}",
    )
    layer = material(thick, rough)
    layer.sld.density.setp(vary=True, bounds=(1, 1.8))
    layer.sld.rotation.setp(vary=True, bounds=(np.pi / 4, 7 * np.pi / 8))
    layer.thick.setp(vary=True, bounds=(0, 12))
    layer.rough.setp(vary=True, bounds=(0, 5))
    return layer


def surface(energy, thick=3.3, rough=1, density=1.0):
    """Return the slab named ``Surface_{energy}`` with all parameters free.

    Built from ``fit.UniTensorSLD`` on the ``ooc`` table.  Bounds: density
    (1, 1.8), rotation (0, π/4), thickness (0, 12), roughness (0, 5).

    The initial rotation guess of 0.8 rad lies just above the upper rotation
    bound (π/4 ≈ 0.785 rad), so it is clamped into the bounds before the slab
    is built; the parameter therefore never starts outside its allowed range.
    """
    name = f"Surface_{energy}"
    rotation_bounds = (0, np.pi / 4)
    rotation_start = min(max(0.8, rotation_bounds[0]), rotation_bounds[1])
    slab = fit.UniTensorSLD(
        ooc, density=density, rotation=rotation_start, energy=energy, name=name
    )(thick, rough)
    slab.sld.density.setp(vary=True, bounds=(1, 1.8))
    slab.sld.rotation.setp(vary=True, bounds=rotation_bounds)
    slab.thick.setp(vary=True, bounds=(0, 12))
    slab.rough.setp(vary=True, bounds=(0, 5))
    return slab


def znpc(energy, thick=191, rough=8.8, density=1.61):
    """Return the slab named ``ZnPc_{energy}`` with all parameters free.

    Built from ``fit.UniTensorSLD`` on the ``ooc`` table with an initial
    rotation of 1.35 rad.  Bounds: density (1.2, 1.8), rotation (MA, π/2),
    thickness (180, 210), roughness (2, 16).
    """
    material = fit.UniTensorSLD(
        ooc,
        density=density,
        rotation=1.35,
        energy=energy,
        name=f"ZnPc_{energy}",
    )
    layer = material(thick, rough)
    layer.sld.density.setp(vary=True, bounds=(1.2, 1.8))
    layer.sld.rotation.setp(vary=True, bounds=(MA, np.pi / 2))
    layer.thick.setp(vary=True, bounds=(180, 210))
    layer.rough.setp(vary=True, bounds=(2, 16))
    return layer


def construct_slab(energy, offset=0):
    """Assemble the full layer stack for one energy.

    The layers are joined with ``|`` in the order vacuum, surface, ZnPc,
    contamination, SiO2, substrate.  Every layer is built at
    ``round(energy + offset, 1)``, and the stack is named
    ``Monolayer_{that rounded energy}`` so its name agrees with its layers
    and carries no floating-point noise (e.g. 283.7 + 0.1).
    """
    offset_energy = round(energy + offset, 1)
    slab = (
        vacuum(offset_energy)
        | surface(offset_energy)
        | znpc(offset_energy)
        | contamination(offset_energy)
        | sio2(offset_energy)
        | substrate(offset_energy)
    )
    slab.name = f"Monolayer_{offset_energy}"
    return slab


# Build one stack per structure energy, keyed by the energy rounded to 0.1
# and converted to a string.
stacks = {}
for e in energy_batches[0]:
    stacks[str(round(e, 1))] = construct_slab(e)
print("Stack keys:", stacks.keys())

## 6. Fit the Model to Data and Save Results

Fit the model to the data using the defined objective and fitting routines. Save the objective and fit results to a versioned pickle file after each major fitting step.


In [None]:
def fitting(obj, recursion_limit=2, workers=-1, *, mcmc=False, **kwargs):
    """Fit a deep copy of ``obj`` and return ``(fitted_copy, fitter)``.

    Steps, all on the copy so ``obj`` itself is left untouched:

    1. ``fit.CurveFitter`` is created with ``**kwargs``.
    2. One "differential_evolution" fit against target "nlpost"
       (``polish=False``, ``updating="deferred"``, ``workers`` forwarded).
    3. ``recursion_limit`` rounds of two default-method fits, first against
       "nlpost" and then against "nll".

    ``fit.fitters._fix_bounds(..., by_bounds=True)`` is called after step 2
    and after every round of step 3.  ``mcmc`` is accepted but not used here.
    """
    from copy import deepcopy

    working = deepcopy(obj)
    curve_fitter = fit.CurveFitter(working, **kwargs)
    curve_fitter.fit(
        "differential_evolution",
        target="nlpost",
        workers=workers,
        polish=False,
        updating="deferred",
    )
    fit.fitters._fix_bounds(working, by_bounds=True)
    for _ in range(recursion_limit):
        for target in ("nlpost", "nll"):
            curve_fitter.fit(target=target)
        fit.fitters._fix_bounds(working, by_bounds=True)
    return working, curve_fitter


# Build one layer stack per measured energy, keyed like `data` (str(energy)).
models = {}
for e in energy_batches[1]:
    models[str(e)] = construct_slab(e)

# Pair each stack with its dataset; every objective is compared on a logY scale.
objectives = []
for e in energy_batches[1]:
    key = str(e)
    objectives.append(fit.Objective(models[key], data[key]))
for objective in objectives:
    objective.transform = Transform("logY")
global_obj = fit.GlobalObjective(objectives)

# Fit and save results (do NOT overwrite fitting_results_fixed.pkl)
fit_obj, fitter = fitting(global_obj)
save_objective(fit_obj, base_name="fit_obj", folder="fit_results")

## 7. Load and Inspect Fitting Results

Load the latest or a specific saved fit result. Print and inspect varying parameters and structure parameters for each fit.


In [None]:
# Reload the most recent versioned fit (fitting_results_fixed.pkl is never touched)
fit_obj = load_latest_objective(base_name="fit_obj", folder="fit_results")

# Show everything that was allowed to vary in the fit
print("Varying parameters:")
print(fit_obj.varying_parameters())

# Per-property tables of the varying structure parameters
structure_properties = ("thick", "rough", "density", "rotation")
for prop in structure_properties:
    print_structure_params(fit_obj, prop)

## 8. Visualize Reflectivity and Structure

Plot the experimental and fitted reflectivity curves for selected energies, and the corresponding structure (SLD) profiles.


In [None]:
# Overview figure: one row per fitted energy, figure height grows with the row count
n_rows = len(fit_obj.objectives)
plot_reflectivity_and_structure(fit_obj, figsize=(10, 2 * n_rows))

# Optional second figure restricted to a few energies.  Matching is an exact
# float comparison against `model.energy`; nothing is drawn if none match.
selected_energies = [250.0, 283.7]
objectives_to_plot = []
for obj in fit_obj.objectives:
    if obj.model.energy in selected_energies:
        objectives_to_plot.append(obj)
if objectives_to_plot:
    subset = fit.GlobalObjective(objectives_to_plot)
    plot_reflectivity_and_structure(subset, figsize=(8, 6))

## 9. Visualize Optical Constants and Orientation Profiles

Plot the optical constants (δ and β) as a function of energy, evaluated at the fitted orientation (rotation) and density of each layer, with the measured energies marked on the curves.


In [None]:
def plot_optical_constants_with_energies(
    ooc, energies, en_shift=0.0, theta=0.0, density=1.0, *, title=None, xlim=None
):
    """Plot δ (top) and β (bottom) versus energy with markers at ``energies``.

    The curves span ``min(energies) + en_shift - 10`` to
    ``max(energies) + en_shift + 10`` on 1000 points and are evaluated with
    ``ooc_function`` at the given ``theta`` and ``density``.  A dashed gray
    vertical segment joins the xx and zz values at every entry of ``energies``.

    Parameters
    ----------
    ooc, theta, density
        Forwarded unchanged to ``ooc_function``.
    energies
        Energies to mark on both panels.
    en_shift
        Shift applied to the plotted energy window.
    title : str, optional
        Title placed on the top panel.
    xlim : tuple, optional
        ``(left, right)`` x-limits; the panels share their x-axis.

    ``title`` and ``xlim`` have to be given here because the figure is shown
    inside this function: pyplot calls made after ``plt.show()`` returns act
    on a new, empty figure instead of this one.
    """
    fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(8, 8), sharex=True)
    emin = min(energies) + en_shift - 10
    emax = max(energies) + en_shift + 10
    energy_plot = np.linspace(emin, emax, 1000)

    # ooc_function is built on np.interp, so the whole grid is evaluated in
    # one call instead of one Python-level call per point.
    n_xx_plot, n_zz_plot, n_ixx_plot, n_izz_plot = ooc_function(
        energy_plot, ooc=ooc, theta=theta, density=density
    )
    ax1.plot(energy_plot, n_xx_plot, "b-", linewidth=1.5, label="δ_xx")
    ax1.plot(energy_plot, n_zz_plot, "r-", linewidth=1.5, label="δ_zz")
    ax1.set_ylabel("δ (Real part)")
    ax1.legend()
    ax2.plot(energy_plot, n_ixx_plot, "b-", linewidth=1.5, label="β_xx")
    ax2.plot(energy_plot, n_izz_plot, "r-", linewidth=1.5, label="β_zz")
    ax2.set_ylabel("β (Imaginary part)")
    ax2.set_xlabel("Energy (eV)")
    ax2.legend()

    # NOTE(review): markers sit at the unshifted energies while the plotted
    # window includes en_shift — confirm this is the intended convention.
    marker_energies = np.asarray(energies, dtype=float)
    m_xx, m_zz, m_ixx, m_izz = ooc_function(
        marker_energies, ooc=ooc, theta=theta, density=density
    )
    ax1.vlines(
        marker_energies,
        ymin=np.minimum(m_xx, m_zz),
        ymax=np.maximum(m_xx, m_zz),
        color="gray",
        linestyle="--",
        alpha=0.5,
    )
    ax2.vlines(
        marker_energies,
        ymin=np.minimum(m_ixx, m_izz),
        ymax=np.maximum(m_ixx, m_izz),
        color="gray",
        linestyle="--",
        alpha=0.5,
    )

    if title is not None:
        ax1.set_title(title)
    if xlim is not None:
        ax2.set_xlim(*xlim)
    plt.tight_layout()
    plt.show()


# Collect the fitted orientation and density of every layer whose rotation
# varied.  Layers are keyed by the part of the slab name before the first "_";
# when several objectives hold the same layer, the last one visited wins.
orient = {}
densities = {}
for objective in fit_obj.objectives:
    for slab in objective.model.structure:
        slab_name = slab.name.split("_")[0]
        if hasattr(slab, "sld") and hasattr(slab.sld, "rotation"):
            if slab.sld.rotation.vary:
                orient[slab_name] = slab.sld.rotation.value
            if hasattr(slab.sld, "density") and slab.sld.density.vary:
                densities[slab_name] = slab.sld.density.value

# Use the first model's energy offset when it has one, otherwise no shift.
energy_offset = (
    fit_obj.objectives[0].model.energy_offset.value
    if hasattr(fit_obj.objectives[0].model, "energy_offset")
    else 0.0
)

# One figure per layer.  Title and x-range are passed into the plotting
# function so they are applied before the figure is shown.
for slab_name, orientation in orient.items():
    plot_optical_constants_with_energies(
        ooc,
        energy_batches[1],
        en_shift=energy_offset,
        theta=orientation,
        density=densities.get(slab_name, 1.61),
        title=(
            f"Optical Constants for {slab_name} at Orientation "
            f"{np.degrees(orientation):.1f}°"
        ),
        xlim=(283, 289.5),
    )

## 10. Visualize 2D Reflectivity Maps and Anisotropy

Interpolate reflectivity data onto a common theta grid and plot 2D maps of reflectivity and anisotropy ratio for s- and p-polarizations.


In [None]:
# --- Interpolate reflectivity onto a common theta grid ---
hc = 12398.4198  # eV·Å (photon energy times wavelength)


def q_to_theta(q, energy):
    """Convert momentum transfer ``q`` to the angle theta in degrees.

    Uses sin(theta) = q * wavelength / (4π) with wavelength = hc / energy.
    With ``hc`` in eV·Å, ``energy`` is in eV and ``q`` presumably in Å⁻¹.
    """
    sin_theta = q * (hc / energy) / (4 * np.pi)
    return np.rad2deg(np.arcsin(sin_theta))


def theta_to_q(theta, energy):
    """Convert the angle ``theta`` in degrees to momentum transfer q.

    Inverse of ``q_to_theta``: q = (4π / wavelength) * sin(theta).
    """
    prefactor = 4 * np.pi / (hc / energy)
    return prefactor * np.sin(np.deg2rad(theta))


# Find the theta window covered by every dataset (intersection of the s and p
# ranges of all objectives), capped at 60°.  The lower limit has to be the
# LARGEST of the per-dataset minima: starting from 0 and taking the minimum
# would always leave it at 0, and np.interp would then pad each curve with its
# first measured value below the measured range.
min_theta, max_theta = 0, 60
for objective in fit_obj.objectives:
    energy = objective.model.energy
    for q_values in (objective.data.s.x, objective.data.p.x):
        theta_values = q_to_theta(q_values, energy)
        min_theta = max(min_theta, np.min(theta_values))
        max_theta = min(max_theta, np.max(theta_values))
if not min_theta < max_theta:
    raise ValueError(
        "Datasets share no common theta range "
        f"(lower limit {min_theta}, upper limit {max_theta})."
    )
common_theta = np.linspace(min_theta, max_theta, 1000)

# Resample the s and p curves of every objective onto `common_theta`.
# np.interp needs increasing x values, so each curve is ordered by theta first.
interpolated_data = {}
for objective in fit_obj.objectives:
    energy = objective.model.energy
    resampled = {}
    for pol in ("s", "p"):
        channel = getattr(objective.data, pol)
        theta = q_to_theta(channel.x, energy)
        order = np.argsort(theta)
        resampled[pol] = np.interp(common_theta, theta[order], channel.y[order])
    interpolated_data[energy] = resampled

# Plot 2D reflectivity maps (rows: theta grid, columns: energies)
energies = sorted(interpolated_data.keys())
s_reflectivity_map = np.array([interpolated_data[e]["s"] for e in energies])
p_reflectivity_map = np.array([interpolated_data[e]["p"] for e in energies])
log_s_map = np.log(s_reflectivity_map)
log_p_map = np.log(p_reflectivity_map)

# Shared colour scale for both panels.  Only finite values are used, because
# np.log gives -inf/nan for non-positive reflectivity and that would ruin the
# scale; with all-finite data this equals the plain min/max.
finite_logs = np.concatenate(
    [log_s_map[np.isfinite(log_s_map)], log_p_map[np.isfinite(log_p_map)]]
)
vmin = finite_logs.min()
vmax = finite_logs.max()

# Label the axes with real theta and energy values.  A bare array would be
# labelled with row/column indices, contradicting the "Theta (degrees)" label.
theta_labels = np.round(common_theta, 1)
tick_step = max(1, len(common_theta) // 10)  # draw about ten theta labels

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6), sharey=True)
panels = (
    (ax1, log_s_map, "Blues", "s-polarized Reflectivity Map"),
    (ax2, log_p_map, "Reds", "p-polarized Reflectivity Map"),
)
for axis, log_map, cmap, panel_title in panels:
    sns.heatmap(
        pd.DataFrame(log_map.T, index=theta_labels, columns=energies),
        ax=axis,
        cmap=cmap,
        cbar_kws={"label": "log(Reflectivity)"},
        vmin=vmin,
        vmax=vmax,
        yticklabels=tick_step,
    )
    axis.set_title(panel_title)
    axis.set_xlabel("Energy (eV)")
ax1.set_ylabel("Theta (degrees)")
plt.tight_layout()
plt.show()

# Plot anisotropy ratio map: (s - p) / (s + p), centred on 0 in the colour map
anisotropy_ratio = (s_reflectivity_map - p_reflectivity_map) / (
    s_reflectivity_map + p_reflectivity_map
)
# Wrap the array in a DataFrame so the ticks show real theta and energy values;
# a bare array would be labelled with row/column indices.
anisotropy_frame = pd.DataFrame(
    anisotropy_ratio.T, index=np.round(common_theta, 1), columns=energies
)
plt.figure(figsize=(8, 6))
sns.heatmap(
    anisotropy_frame,
    cmap="coolwarm",
    center=0,
    cbar_kws={"label": "(s-p)/(s+p)"},
    yticklabels=max(1, len(common_theta) // 10),  # draw about ten theta labels
)
plt.title("Anisotropy Ratio Map")
plt.ylabel("Theta (degrees)")
plt.xlabel("Energy (eV)")
plt.tight_layout()
plt.show()