## Multi-grain indexation based on diagram simulation and matching-rate.

In [None]:
import math
import random
import timeit

from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import torch

from laueimproc.geometry import *

### Reference constants and parameters
* `LATTICE` is the vector of the six lattice parameters $[a, b, c, \alpha, \beta, \gamma]$
* `PONI` are the detector parameters $[dist, poni_1, poni_2, rot_1, rot_2, rot_3]$

The convention adopted is the pyFAI convention. Have a look at the documentation for more details.

In [None]:
LATTICE = torch.tensor([3.6e-10, 3.6e-10, 3.6e-10, torch.pi/2, torch.pi/2, torch.pi/2])  # copper
PONI = torch.tensor([0.07, 73.4e-3, 73.4e-3, 0.0, -torch.pi/2, 0.0])  # mode laue detector on top
DETECTOR = {"shape": (2018, 2016), "pxl": 73.4e-6}  # shape is along d1 then d2

In [None]:
EV = 1.60e-19  # 1 eV = EV J
RAD = math.pi / 180.0
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Associate the functions
* `laueimproc` provides atomic functions to work with Bragg diffraction. It is up to you to compose them.
* Have a look at the `diffraction` tab of the API documentation.

In [None]:
def lattice_to_reciprocal(lattice):
    primitive = lattice_to_primitive(lattice)
    reciprocal = primitive_to_reciprocal(primitive)
    return reciprocal

def hkl_reciprocal_rot_to_uq(hkl, reciprocal, rot):
    reciprocal_rotated = rotate_crystal(reciprocal, rot)
    u_q = hkl_reciprocal_to_uq(hkl, reciprocal_rotated)
    return u_q
    
def uq_poni_to_detector(u_q, poni):
    u_f = uq_to_uf(u_q)
    point, dist = ray_to_detector(u_f, poni)
    point = point[dist > 0, :]  # ray wrong direction => virtual intersection
    point = point[point[..., 0] > 0, :]  # out of detector top
    point = point[point[..., 0] < DETECTOR["shape"][0] * DETECTOR["pxl"], :]  # out of detector bottom
    point = point[point[..., 1] > 0, :]  # out of detector left
    point = point[point[..., 1] < DETECTOR["shape"][1] * DETECTOR["pxl"], :]  # out of detector right
    return point

def detector_poni_to_uq(point, poni):
    u_f = detector_to_ray(point, poni)
    u_q = uf_to_uq(u_f)
    return u_q

In [None]:
def full_simulation(hkl, lattice, rot, poni):
    reciprocal = lattice_to_reciprocal(lattice)
    u_q = hkl_reciprocal_rot_to_uq(hkl, reciprocal, rot)
    point = uq_poni_to_detector(u_q, poni)
    return point

#### Timing comparison

In [None]:
# simple timing

BATCH = 100  # number of simulated diagrams
ROT = torch.eye(3)

hkl = select_hkl(LATTICE, e_max=25e3*EV, keep_harmonics=False)

# case float64
lattice, rot, poni = LATTICE.clone().to(torch.float64), ROT.clone().to(torch.float64), PONI.clone().to(torch.float64)
speed = min(timeit.repeat(lambda: full_simulation(hkl, lattice, rot, poni), repeat=10, number=BATCH)) / BATCH
print(f"float64: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")

# case float32
lattice, rot, poni = LATTICE.clone().to(torch.float32), ROT.clone().to(torch.float32), PONI.clone().to(torch.float32)
speed = min(timeit.repeat(lambda: full_simulation(hkl, lattice, rot, poni), repeat=10, number=BATCH)) / BATCH
print(f"float32: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")

# case float64 batched
lattice, rot, poni = LATTICE.clone().to(torch.float64), ROT.clone().to(torch.float64), PONI.clone().to(torch.float64)
lattice = lattice[None, :].expand(BATCH, -1)
speed = min(timeit.repeat(lambda: full_simulation(hkl, lattice, rot, poni), repeat=10, number=1)) / BATCH
print(f"float64 batched: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")

# case float32 batched
lattice, rot, poni = LATTICE.clone().to(torch.float32), ROT.clone().to(torch.float32), PONI.clone().to(torch.float32)
lattice = lattice[None, :].expand(BATCH, -1)
speed = min(timeit.repeat(lambda: full_simulation(hkl, lattice, rot, poni), repeat=10, number=1)) / BATCH
print(f"float32 batched: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")

# case float64 batched compiled
lattice, rot, poni = LATTICE.clone().to(torch.float64), ROT.clone().to(torch.float64), PONI.clone().to(torch.float64)
lattice = lattice[None, :].expand(BATCH, -1)
full_simulation_comp = torch.compile(full_simulation, dynamic=False)
speed = min(timeit.repeat(lambda: full_simulation_comp(hkl, lattice, rot, poni), repeat=10, number=1)) / BATCH
print(f"float64 batched compiled: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")

# case float32 batched compiled
lattice, rot, poni = LATTICE.clone().to(torch.float32), ROT.clone().to(torch.float32), PONI.clone().to(torch.float32)
lattice = lattice[None, :].expand(BATCH, -1)
full_simulation_comp = torch.compile(full_simulation, dynamic=False)
speed = min(timeit.repeat(lambda: full_simulation_comp(hkl, lattice, rot, poni), repeat=10, number=1)) / BATCH
print(f"float32 batched compiled: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")

if DEVICE.type == "cuda":
    # case float64 batched compiled gpu
    lattice, rot, poni = (
        LATTICE.clone().to(dtype=torch.float64, device=DEVICE),
        ROT.clone().to(dtype=torch.float64, device=DEVICE),
        PONI.clone().to(dtype=torch.float64, device=DEVICE),
    )
    lattice = lattice[None, :].expand(BATCH, -1)
    full_simulation_comp = torch.compile(full_simulation, dynamic=False)
    speed = min(timeit.repeat(lambda: full_simulation_comp(hkl, lattice, rot, poni), repeat=10, number=1)) / BATCH
    print(f"float64 batched compiled gpu: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")
    
    # case float32 batched compiled gpu
    lattice, rot, poni = (
        LATTICE.clone().to(dtype=torch.float32, device=DEVICE),
        ROT.clone().to(dtype=torch.float32, device=DEVICE),
        PONI.clone().to(dtype=torch.float32, device=DEVICE),
    )
    lattice = lattice[None, :].expand(BATCH, -1)
    full_simulation_comp = torch.compile(full_simulation, dynamic=False)
    speed = min(timeit.repeat(lambda: full_simulation_comp(hkl, lattice, rot, poni), repeat=10, number=1)) / BATCH
    print(f"float32 batched compiled gpu: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")

### Simulate a diagram

#### Random draw single grain, no strain

In [None]:
# parameters of simulation
eps_a_ref = eps_b_ref = eps_c_ref = 0.0
eps_alpha_ref = eps_beta_ref = eps_gamma_ref = 0.0
lattice_ref = LATTICE.clone()
rot1_ref = 2.0 * math.pi * random.random()
rot2_ref = 2.0 * math.pi * random.random()
rot3_ref = 2.0 * math.pi * random.random()

#### Random draw single grain, with strain

In [None]:
# parameters of simulation
eps_a_ref = random.random() * 2e-2 - 1e-2  # eps_a = (a - a0) / a0
eps_b_ref = random.random() * 2e-2 - 1e-2
eps_c_ref = random.random() * 2e-2 - 1e-2
eps_alpha_ref = random.random() * 2e-3 - 1e-3  # eps_alpha = tan(alpha - alpha_0)
eps_beta_ref = random.random() * 2e-3 - 1e-3
eps_gamma_ref = random.random() * 2e-3 - 1e-3
lattice_ref = LATTICE.clone()
lattice_ref[0] *= 1.0 + eps_a_ref
lattice_ref[1] *= 1.0 + eps_b_ref
lattice_ref[2] *= 1.0 + eps_c_ref
lattice_ref[3] += eps_alpha_ref
lattice_ref[4] += eps_beta_ref
lattice_ref[5] += eps_gamma_ref
rot1_ref = 2.0 * math.pi * random.random()
rot2_ref = 2.0 * math.pi * random.random()
rot3_ref = 2.0 * math.pi * random.random()

#### Random draw multi grain, no strain

In [None]:
NB_GRAIN = 5

# parameters of simulation
eps_a_ref = eps_b_ref = eps_c_ref = 0.0
eps_alpha_ref = eps_beta_ref = eps_gamma_ref = 0.0
lattice_ref = LATTICE.clone()
rot1_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)
rot2_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)
rot3_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)

#### Random draw multi grain, with strain

In [None]:
NB_GRAIN = 5

# parameters of simulation
eps_a_ref = torch.rand(NB_GRAIN) * 2e-2 - 1e-2  # eps_a = (a - a0) / a0
eps_b_ref = torch.rand(NB_GRAIN) * 2e-2 - 1e-2
eps_c_ref = torch.rand(NB_GRAIN) * 2e-2 - 1e-2
eps_alpha_ref = torch.rand(NB_GRAIN) * 2e-3 - 1e-3  # eps_alpha = tan(alpha - alpha_0)
eps_beta_ref = torch.rand(NB_GRAIN) * 2e-3 - 1e-3
eps_gamma_ref = torch.rand(NB_GRAIN) * 2e-3 - 1e-3
lattice_ref = torch.cat([
    (LATTICE[0] * 1.0 + eps_a_ref).unsqueeze(1),
    (LATTICE[1] * 1.0 + eps_b_ref).unsqueeze(1),
    (LATTICE[2] * 1.0 + eps_c_ref).unsqueeze(1),
    (LATTICE[3] + eps_alpha_ref).unsqueeze(1),
    (LATTICE[4] + eps_beta_ref).unsqueeze(1),
    (LATTICE[5] + eps_gamma_ref).unsqueeze(1),
], dim=1)
rot1_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)
rot2_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)
rot3_ref = 2.0 * torch.pi * torch.rand(NB_GRAIN)

#### Simulation of the diagram

In [None]:
# cast into torch tensor
eps_a_ref = torch.asarray(eps_a_ref).reshape(-1)
eps_b_ref = torch.asarray(eps_b_ref).reshape(-1)
eps_c_ref = torch.asarray(eps_c_ref).reshape(-1)
eps_alpha_ref = torch.asarray(eps_alpha_ref).reshape(-1)
eps_beta_ref = torch.asarray(eps_beta_ref).reshape(-1)
eps_gamma_ref = torch.asarray(eps_gamma_ref).reshape(-1)

rot1_ref = torch.asarray(rot1_ref).reshape(-1)
rot2_ref = torch.asarray(rot2_ref).reshape(-1)
rot3_ref = torch.asarray(rot3_ref).reshape(-1)

lattice_ref = lattice_ref.reshape(-1, 6)

In [None]:
# simulation
rot_ref = angle_to_rot(rot1_ref, rot2_ref, rot3_ref, cartesian_product=False)
hkl = select_hkl(lattice_ref, e_max=25e3*EV, keep_harmonics=False)
if len(lattice_ref) != 1 and len(rot_ref) != 1:  # default behavour is cartesian product
    points = torch.cat([
        full_simulation(hkl, lattice_ref_, rot_ref_, PONI)
        for lattice_ref_, rot_ref_ in zip(lattice_ref, rot_ref)
    ])
else:
    points = full_simulation(hkl, lattice_ref, rot_ref, PONI)

# add experimental noise and false detection
points_exp = points + 5 * DETECTOR["pxl"] * torch.randn_like(points)
cond = torch.rand_like(points_exp) < 0.05  # proba to faild the detection
points_exp[cond[..., 0], 0] = torch.rand_like(points_exp[cond[..., 0], 0]) * DETECTOR["shape"][0] * DETECTOR["pxl"]
points_exp[cond[..., 1], 1] = torch.rand_like(points_exp[cond[..., 1], 1]) * DETECTOR["shape"][1] * DETECTOR["pxl"]

# print informations
print(f"There are {len(points)} experimental spots.")
print("Lattice parameters:")
for i, lat in enumerate(lattice_ref):
    print(
        f"   {i} - "
        f"a={lat[0]:.3e}, b={lat[1]:.3e}, c={lat[2]:.3e}, alpha={lat[3]:.3f}, beta={lat[4]:.3f}, gamma={lat[5]:.3f} "
    )
print("Rotations:")
for i, (r1, r2, r3) in enumerate(zip(rot1_ref, rot2_ref, rot3_ref)):
    print(
        f"   {i} - "
        f"rot1={r1:.2f}, rot2={r2:.2f}, rot3={r3:.2f}"
    )

# display in pyfai convention
plt.title("Simulated experimental diagram")
plt.xlabel("d2")
plt.ylabel("d1")
plt.scatter(*(points.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="experimental")
plt.scatter(*(points_exp.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="noisy")
plt.legend()
plt.grid()
plt.show()

### Indexation

#### Preparation
We work in the $u_q$ space, not in the detector space. First we have to transfer the detector points to the $u_q$ space.

In [None]:
# hyperparameters
ANGLE_RESOL = 0.6 * RAD
ANGLE_MAX_MATCHING = 0.4 * RAD
E_MAX = 15e3 * EV

# compute invariants parameters
all_rot = angle_to_rot(
    torch.arange(-torch.pi / 4, torch.pi / 4, ANGLE_RESOL),
    torch.arange(-torch.pi / 4, torch.pi / 4, ANGLE_RESOL),
    torch.arange(-torch.pi / 4, torch.pi / 4, ANGLE_RESOL),
).reshape(-1, 3, 3)
hkl = select_hkl(LATTICE, e_max=E_MAX, keep_harmonics=False)  # (d, 3)
uq_exp = detector_poni_to_uq(points_exp, PONI).movedim(0, -2).contiguous()  # (3, n)
reciprocal = lattice_to_reciprocal(LATTICE)  # (3, 3)

# compile functions
hkl_reciprocal_rot_to_uq_comp = torch.compile(hkl_reciprocal_rot_to_uq, fullgraph=True, dynamic=False)

#### Brute force indexation

In [None]:
# simulate
BATCH_SIZE = 1000

all_rate = torch.empty(*torch.broadcast_shapes(uq_exp.shape[:-2], all_rot.shape[:-2]))
for i in tqdm(range(0, len(all_rot), BATCH_SIZE), unit_scale=BATCH_SIZE):
    rot_batch = all_rot[i:i+BATCH_SIZE]
    uq_simul = hkl_reciprocal_rot_to_uq_comp(hkl, reciprocal, rot_batch)
    uq_simul = uq_simul.movedim(0, -2).contiguous()
    all_rate[i:i+BATCH_SIZE] = compute_matching_rate(uq_exp, uq_simul, phi_max=ANGLE_MAX_MATCHING)

#### Selection

In [None]:
MIN_RATE = None  # or set int directely

sorted_rate_indices = torch.argsort(all_rate, descending=True)
min_rate = MIN_RATE or round(0.8 * float(all_rate[sorted_rate_indices[0]]))
rate_indices = sorted_rate_indices[:torch.argmin((all_rate[sorted_rate_indices] >= min_rate).view(torch.uint8))]

print(f"{all_rate[rate_indices[-1]]} <= rate <= {all_rate[rate_indices[0]]}")

plt.title("Matching")
plt.ylabel("rate")
plt.plot(all_rate[sorted_rate_indices])
plt.hlines(all_rate[rate_indices[-1]], 0, len(all_rate))
plt.plot(torch.linspace(0, len(all_rate), len(rate_indices)), all_rate[rate_indices])
plt.show()

#### Visualisation

In [None]:
points_indexed = full_simulation(hkl, LATTICE, all_rot[rate_indices[0]], PONI)

# display in pyfai convention
plt.title("Indexation")
plt.xlabel("d2")
plt.ylabel("d1")
plt.scatter(*(points_exp.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="experimental")
plt.scatter(*(points_indexed.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="indexed")
plt.legend()
plt.grid()
plt.show()

### Refinement

In [None]:
INDEX = 0

# pick out initial parameters
rot1_init, rot2_init, rot3_init = rot_to_angle(all_rot[rate_indices][INDEX])
lattice_init = LATTICE.unsqueeze(0).expand(len(rate_indices), -1)[INDEX]  # (n, 6)
rot1 = rot1_init.clone()
rot1.requires_grad = True
rot2 = rot2_init.clone()
rot2.requires_grad = True
rot3 = rot3_init.clone()
rot3.requires_grad = True
lattice = lattice_init.clone()
lattice.requires_grad = True

# maximisation of matching rate by gradient climbing
optim = torch.optim.SGD([rot1, rot2, rot3], lr=math.radians(1e-5))
for _ in range(100):
    # compute matching rate
    reciprocal = lattice_to_reciprocal(lattice)
    rot = angle_to_rot(rot1, rot2, rot3, cartesian_product=False)
    uq_simul = hkl_reciprocal_rot_to_uq_comp(hkl, reciprocal, rot)
    uq_simul = uq_simul.movedim(0, -2).contiguous()
    rate = compute_matching_rate_continuous(uq_exp, uq_simul, phi_max=ANGLE_MAX_MATCHING)
    print(f"rate={rate}")
    
    # compute gradient
    optim.zero_grad()
    (-rate).backward()
    # print(f"grad=({rot1.grad}, {rot2.grad}, {rot3.grad})")
    rot1.grad = torch.clamp(torch.nan_to_num(rot1.grad), -1e2, 1e2)
    rot2.grad = torch.clamp(torch.nan_to_num(rot2.grad), -1e2, 1e2)
    rot3.grad = torch.clamp(torch.nan_to_num(rot3.grad), -1e2, 1e2)
    # rot1.grad = torch.clamp(rot1.grad, -1.0, 1.0)
    # rot2.grad = torch.clamp(rot2.grad, -1.0, 1.0)
    # rot3.grad = torch.clamp(rot3.grad, -1.0, 1.0)
    
    # update values
    optim.step()

#### Visualisation

In [None]:
points_indexed = full_simulation(hkl, LATTICE, all_rot[rate_indices[0]], PONI)
rot_raffined = angle_to_rot(rot1, rot2, rot3, cartesian_product=False).detach()
points_raffined = full_simulation(hkl, LATTICE, rot_raffined, PONI)

# display in pyfai convention
plt.title("Indexation")
plt.xlabel("d2")
plt.ylabel("d1")
plt.scatter(*(points_exp.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="experimental")
plt.scatter(*(points_indexed.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="indexed")
plt.scatter(*(points_raffined.flip(-1) / DETECTOR["pxl"]).movedim(-1, 0), label="raffined")
plt.legend()
plt.grid()
plt.show()