## Multi-grain indexation based on diagram simulation and matching-rate.

In [None]:
import pprint
import math
import random
import timeit

from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import torch

from laueimproc.geometry import *
from laueimproc.geometry.indexation import *

### Reference constants and parameters
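* `LATTICES` stacks one row of 6 lattice parameters $[a, b, c, \alpha, \beta, \gamma]$ per material (lengths in meters, angles in radians); `LATTICE` is its first row.
* `PONI` holds the detector parameters $[dist, poni_1, poni_2, rot_1, rot_2, rot_3]$.
* `E_MIN` and `E_MAX` are the energy bounds of the simulation, in joules.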

The adopted convention is the pyFAI one; have a look at the pyFAI documentation for more details.

In [None]:
EV = 1.602176634e-19  # 1 eV = EV J (exact SI value of the elementary charge)
RAD = math.pi / 180.0  # 1 degree = RAD radians
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # NOTE(review): not used in the visible cells

In [None]:
# one row [a, b, c, alpha, beta, gamma] per material: lengths in m, angles in rad
LATTICES = torch.tensor([
    [5.1505e-10, 5.2166e-10, 5.3173e-10, torch.pi/2, 99.23*RAD, torch.pi/2],  # ZrO2, monoclinic
    # Al2O3 corundum in its hexagonal setting: a = b = 4.758 A, c = 12.991 A, gamma = 120 deg
    [4.758e-10, 4.758e-10, 12.991e-10, torch.pi/2, torch.pi/2, 2*torch.pi/3],  # Al2O3
])
LATTICE = LATTICES[0]  # material used for the simulations and the brute-force indexation
PONI = torch.tensor([0.07, 73.4e-3, 73.4e-3, 0.0, -torch.pi/2, 0.0])  # mode laue detector on top
DETECTOR = {"shape": (2018, 2016), "pxl": 73.4e-6}  # shape is along d1 then d2, pxl is the pixel size in m
E_MIN = 5e3 * EV  # 5 keV
E_MAX = 25e3 * EV  # 25 keV

### Simulate a diagram

#### Random draw single grain, no strain

In [None]:
# parameters of simulation: unstrained reference lattice, one random orientation
eps_a_ref, eps_b_ref, eps_c_ref = 0.0, 0.0, 0.0  # no length strain
eps_alpha_ref, eps_beta_ref, eps_gamma_ref = 0.0, 0.0, 0.0  # no angular strain
lattice_ref = LATTICE.clone()
# three rotation angles, each drawn uniformly in [0, 2*pi), in the order rot1, rot2, rot3
rot1_ref, rot2_ref, rot3_ref = (2.0 * math.pi * random.random() for _ in range(3))

#### Random draw single grain, with strain

In [None]:
# parameters of simulation: one random orientation with a small random strain
# length strains eps = (x - x0) / x0, drawn uniformly in [-1e-2, 1e-2)
eps_a_ref, eps_b_ref, eps_c_ref = (random.random() * 2e-2 - 1e-2 for _ in range(3))
# angular strains in rad, drawn uniformly in [-1e-3, 1e-3) and added directly to the angles
# (alpha = alpha_0 + eps_alpha, the first-order form of eps_alpha = tan(alpha - alpha_0))
eps_alpha_ref, eps_beta_ref, eps_gamma_ref = (random.random() * 2e-3 - 1e-3 for _ in range(3))

lattice_ref = LATTICE.clone()
for _k, _eps in enumerate((eps_a_ref, eps_b_ref, eps_c_ref)):
    lattice_ref[_k] *= 1.0 + _eps  # scale the lengths a, b, c
for _k, _eps in enumerate((eps_alpha_ref, eps_beta_ref, eps_gamma_ref), start=3):
    lattice_ref[_k] += _eps  # shift the angles alpha, beta, gamma
del _k, _eps  # keep the notebook namespace clean

# three rotation angles, each drawn uniformly in [0, 2*pi)
rot1_ref, rot2_ref, rot3_ref = (2.0 * math.pi * random.random() for _ in range(3))

#### Random draw multi grain, no strain

In [None]:
NB_GRAIN = 5  # number of grains in the simulated diagram

# parameters of simulation: NB_GRAIN unstrained grains, each with a random orientation
eps_a_ref, eps_b_ref, eps_c_ref = 0.0, 0.0, 0.0
eps_alpha_ref, eps_beta_ref, eps_gamma_ref = 0.0, 0.0, 0.0
lattice_ref = LATTICE.clone()  # one lattice shared by every grain
# one tensor of NB_GRAIN angles in [0, 2*pi) per rotation axis, drawn in the order rot1, rot2, rot3
rot1_ref, rot2_ref, rot3_ref = (2.0 * torch.pi * torch.rand(NB_GRAIN) for _ in range(3))

#### Random draw multi grain, with strain

In [None]:
NB_GRAIN = 5  # number of grains in the simulated diagram

# parameters of simulation, one value per grain
# length strains eps_a = (a - a0) / a0, drawn uniformly in [-1e-2, 1e-2)
eps_a_ref = torch.rand(NB_GRAIN) * 2e-2 - 1e-2
eps_b_ref = torch.rand(NB_GRAIN) * 2e-2 - 1e-2
eps_c_ref = torch.rand(NB_GRAIN) * 2e-2 - 1e-2
# angular strains in rad, drawn uniformly in [-1e-3, 1e-3) and added to the angles
# (alpha = alpha_0 + eps_alpha, the first-order form of eps_alpha = tan(alpha - alpha_0))
eps_alpha_ref = torch.rand(NB_GRAIN) * 2e-3 - 1e-3
eps_beta_ref = torch.rand(NB_GRAIN) * 2e-3 - 1e-3
eps_gamma_ref = torch.rand(NB_GRAIN) * 2e-3 - 1e-3
# a = a0 * (1 + eps_a): the parentheses matter, `a0 * 1.0 + eps_a` would add eps_a in meters
lattice_ref = torch.stack([
    LATTICE[0] * (1.0 + eps_a_ref),
    LATTICE[1] * (1.0 + eps_b_ref),
    LATTICE[2] * (1.0 + eps_c_ref),
    LATTICE[3] + eps_alpha_ref,
    LATTICE[4] + eps_beta_ref,
    LATTICE[5] + eps_gamma_ref,
], dim=1)  # shape (NB_GRAIN, 6)
rot1_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)
rot2_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)
rot3_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)

#### Simulation of the diagram

In [None]:
# cast every parameter into a 1-D torch tensor, so that the single-grain (python float)
# and multi-grain (tensor) variants above are handled the same way
eps_a_ref, eps_b_ref, eps_c_ref, eps_alpha_ref, eps_beta_ref, eps_gamma_ref = (
    torch.asarray(_value).reshape(-1)
    for _value in (eps_a_ref, eps_b_ref, eps_c_ref, eps_alpha_ref, eps_beta_ref, eps_gamma_ref)
)

rot1_ref, rot2_ref, rot3_ref = (
    torch.asarray(_value).reshape(-1)
    for _value in (rot1_ref, rot2_ref, rot3_ref)
)

lattice_ref = lattice_ref.reshape(-1, 6)  # one row of 6 lattice parameters per grain, or one shared row

In [None]:
# simulation
phi = torch.stack([rot1_ref, rot2_ref, rot3_ref], dim=-1)  # one row of 3 rotation angles per grain
# Simulate with lattice_ref (the possibly strained lattice drawn above) rather than the
# unstrained LATTICE, otherwise the eps_*_ref strains never reach the simulated diagram.
# NOTE(review): lattice_ref has shape (1, 6) or (NB_GRAIN, 6); this assumes Geometry
# accepts a batch of lattices alongside phi — confirm.
bragg = Geometry(lattice=lattice_ref, phi=phi, e_min=E_MIN, e_max=E_MAX, poni=PONI)
with torch.no_grad():
    points, *_ = bragg.compute_cam(cam_size=(DETECTOR["pxl"]*DETECTOR["shape"][0], DETECTOR["pxl"]*DETECTOR["shape"][1]))

# add experimental noise and false detection
points_exp = points + 5 * DETECTOR["pxl"] * torch.randn_like(points)  # gaussian noise, std of 5 pixels
cond = torch.rand_like(points_exp) < 0.05  # probability to fail the detection, drawn per coordinate
# a failed coordinate is replaced by a uniform random position on the detector
points_exp[cond[..., 0], 0] = torch.rand_like(points_exp[cond[..., 0], 0]) * DETECTOR["shape"][0] * DETECTOR["pxl"]
points_exp[cond[..., 1], 1] = torch.rand_like(points_exp[cond[..., 1], 1]) * DETECTOR["shape"][1] * DETECTOR["pxl"]

# print informations
# count the spots over every leading dimension (same as len(points) when points is (n, 2))
print(f"There are {math.prod(points.shape[:-1])} experimental spots.")
print("Lattice parameters:")
for i, lat in enumerate(lattice_ref):
    print(
        f"   {i} - "
        f"a={lat[0]:.3e}, b={lat[1]:.3e}, c={lat[2]:.3e}, alpha={lat[3]:.3f}, beta={lat[4]:.3f}, gamma={lat[5]:.3f} "
    )
print("Rotations:")
for i, (r1, r2, r3) in enumerate(zip(rot1_ref, rot2_ref, rot3_ref)):
    print(
        f"   {i} - "
        f"rot1={r1:.2f}, rot2={r2:.2f}, rot3={r3:.2f}"
    )

# display in pyfai convention
plt.title("Simulated experimental diagram")
plt.xlabel("d2")
plt.ylabel("d1")
plt.scatter(*(points.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="experimental")
plt.scatter(*(points_exp.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="noisy")
plt.legend()
plt.grid()
plt.show()

### NN pseudo Indexation
It is not a complete indexation, only an hkl class predictor.

#### Initialise the model

In [None]:
# initialise the network: an hkl class predictor over every material of LATTICES
indexator = NNIndexator(LATTICES, hkl_max=4)
print(f"bins: {indexator.bins}", "hkl classes:", sep="\n")
print(pprint.pformat(indexator.families))

#### Training

In [None]:
# training
NB_EPOCH = 50  # number of generated batches, one weight update per batch
loss = torch.nn.CrossEntropyLoss()  # hkl class prediction is a classification problem
optim = torch.optim.RAdam(indexator.weights, weight_decay=0.0005)

indexator.train()

# the context manager closes the status line even if the training is interrupted
with tqdm(total=0, position=1, bar_format="{desc}") as epoch_log:
    for _ in tqdm(range(NB_EPOCH), unit="epoch", desc="train"):
        # generate a fresh training batch
        hist, target = indexator.generate_training_batch(
            poni=PONI,
            cam_size=(DETECTOR["shape"][0] * DETECTOR["pxl"], DETECTOR["shape"][1] * DETECTOR["pxl"]),
            e_min=E_MIN,
            e_max=E_MAX,
            batch=256,
        )

        # update weights
        optim.zero_grad()
        pred = indexator(hist)
        output = loss(pred, target)
        output.backward()
        optim.step()

        # display, the loss is a cross-entropy, not a mean squared error
        epoch_log.set_description_str(f"Loss: cross-entropy={output.item():.6e}")

print("hkl class repartition:")
print(indexator._attrs["hkl_hist"])

#### Prediction

In [None]:
indexator.eval()  # leave the training mode set by indexator.train()
# NOTE(review): `indices` presumably selects which spots of points_exp are classified — confirm
material, hkl, confidence = indexator.predict_hkl(points_exp, poni=PONI, indices=[0, 42])
for _name, _value in (("material", material), ("hkl class", hkl), ("confidence", confidence)):
    print(f"{_name}: {_value}")
del _name, _value  # keep the notebook namespace clean

### Brute force indexation
We work in the $u_q$ space, not in the detector space, so the detector points first have to be transferred to the $u_q$ space.

#### Test all rotations

In [None]:
# hyperparameters of the brute-force search
PHI_RES = 2.0 * RAD  # passed as `omega`; presumably the angular step between tested orientations
PHI_MAX = 1.2 * RAD  # passed as `phi_max`; presumably the angular matching tolerance
E_MAX_INDEXATION = 10e3 * EV  # below E_MAX because computing time increases sharply with energy

# brute force indexation: returns candidate orientations and their matching rates
# (the following cells treat index 0 as the best candidate)
indexator = StupidIndexator(LATTICE, E_MAX_INDEXATION)
all_omega, all_rate = indexator(
    points_exp,
    poni=PONI,
    omega=PHI_RES,
    phi_max=PHI_MAX,
)

#### Selection

In [None]:
MIN_RATE = None  # None keeps the candidates within 80 % of the best rate, or set an int directly

# all_rate is assumed sorted in decreasing order (all_rate[0] is the best candidate,
# as used by the following cells), so the selected candidates form a prefix.
# `is None` rather than `or`, so that an explicit MIN_RATE = 0 is honoured
min_rate = round(0.8 * float(all_rate[0])) if MIN_RATE is None else MIN_RATE
# size of the selected prefix; unlike argmin on the mask, this is also right when every candidate passes
rate_index = int((all_rate >= min_rate).sum())

if rate_index == 0:
    print(f"no candidate reaches min_rate={min_rate}, the best rate is {all_rate[0]}")
else:
    print(f"{all_rate[rate_index-1]} <= rate <= {all_rate[0]}")

plt.title("Matching")
plt.ylabel("rate")
plt.plot(all_rate)
if rate_index:
    plt.hlines(all_rate[rate_index-1], 0, len(all_rate))  # lowest selected rate
    # highlight the selected candidates at their own indices
    plt.plot(range(rate_index), all_rate[:rate_index])
plt.show()

#### Visualisation

In [None]:
# generate data from the best indexed candidate (all_omega[0])
# e_min is passed too, so this simulation uses the same [E_MIN, E_MAX] energy band
# as the simulated experimental diagram it is compared with
bragg = Geometry(lattice=LATTICE, phi=all_omega[0], e_min=E_MIN, e_max=E_MAX, poni=PONI)
bragg.requires_grad_(False)  # display only, no gradient needed
points_indexed, *_ = bragg.compute_cam(cam_size=(DETECTOR["pxl"]*DETECTOR["shape"][0], DETECTOR["pxl"]*DETECTOR["shape"][1]))

# display in pyfai convention
plt.title("Indexation")
plt.xlabel("d2")
plt.ylabel("d1")
plt.scatter(*(points_exp.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="experimental")
plt.scatter(*(points_indexed.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="indexed")
plt.legend()
plt.grid()
plt.show()

### Refinement
It refines the orientation $\omega$, and it can refine $a$, $b$, $c$ and $\alpha$, $\beta$, $\gamma$ as well.

In [None]:
# initialization
INDEX = 0  # which brute-force candidate to refine, 0 is the first (best) one
PHI_MAX = 0.5 * RAD  # smaller than the 1.2 deg of the brute-force search, and overrides that PHI_MAX

refiner = Refiner(
    LATTICE,
    all_omega[INDEX],
    points_exp,
    poni=PONI,
    e_min=E_MIN,
    e_max=E_MAX,
)
print(f"initial omega: {refiner.angle}")
print(f"initial lattice: {refiner.lattice}")
print(f"initial matching rate {refiner(PHI_MAX)}")

In [None]:
# refine; this cell can be executed several times
# (presumably each run refines further from the current state of `refiner`)
rate = refiner.refine(
    PHI_MAX,
    refine_abc=True,  # also refine the lengths a, b, c
    refine_shear=True,  # presumably also refines the angles alpha, beta, gamma
)
print(f"new omega: {refiner.angle}")
print(f"new lattice: {refiner.lattice}")
print(f"new matching rate {rate}")

In [None]:
# simulate the diagram of the refined grain, in the same [E_MIN, E_MAX] energy band
# as the simulated experimental diagram (the refiner values are cast to float32)
bragg = Geometry(
    lattice=refiner.lattice.to(torch.float32),
    phi=refiner.angle.to(torch.float32),
    e_min=E_MIN,
    e_max=E_MAX,
    poni=PONI,
)
with torch.no_grad():
    points_raffined, *_ = bragg.compute_cam(cam_size=(DETECTOR["pxl"]*DETECTOR["shape"][0], DETECTOR["pxl"]*DETECTOR["shape"][1]))

# display in pyfai convention
plt.title("Refinement")
plt.xlabel("d2")
plt.ylabel("d1")
plt.scatter(*(points_exp.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="experimental")
plt.scatter(*(points_indexed.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="indexed")
plt.scatter(*(points_raffined.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="refined")
plt.legend()
plt.grid()
plt.show()