## Multi-grain diagram simulation.

In [None]:
import math
import random
import timeit

from tqdm.autonotebook import tqdm
import matplotlib.pyplot as plt
import torch

from laueimproc.geometry import *
from laueimproc.geometry.model import BraggModel
from laueimproc.geometry.indexation import StupidIndexator

### Reference constants and parameters
* `LATTICE` is the vector of the 6 lattice parameters $[a, b, c, \alpha, \beta, \gamma]$ (lengths in meters, angles in radians)
* `PONI` is the vector of the detector parameters $[dist, poni_1, poni_2, rot_1, rot_2, rot_3]$ (distances in meters, angles in radians)

The adopted convention is the pyFAI convention. Have a look at the documentation for more details.

In [None]:
# copper: cubic cell, a = b = c = 3.6e-10 and alpha = beta = gamma = pi/2 rad
LATTICE = torch.tensor([3.6e-10] * 3 + [torch.pi / 2] * 3)
# mode laue detector on top: [dist, poni_1, poni_2, rot_1, rot_2, rot_3]
PONI = torch.tensor([0.07, 73.4e-3, 73.4e-3, 0.0, -torch.pi / 2, 0.0])
# detector size in pixels (shape is along d1 then d2) and size of one pixel
DETECTOR = dict(shape=(2018, 2016), pxl=73.4e-6)

In [None]:
EV = 1.602176634e-19  # 1 eV = EV J (exact SI value of the elementary charge, was truncated to 1.60e-19)
RAD = math.pi / 180.0  # 1 degree = RAD radians
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # first GPU if any, else CPU

### Associate the functions
* `laueimproc` provides atomic functions to handle Bragg diffraction; it is up to you to compose them.
* Have a look at the `diffraction` tab of the API documentation.

In [None]:
def lattice_to_reciprocal(lattice):
    """Convert the lattice parameters into the reciprocal base.

    The lattice parameters are first turned into the primitive base,
    which is then converted into its reciprocal base.
    """
    return primitive_to_reciprocal(lattice_to_primitive(lattice))

def hkl_reciprocal_rot_to_uq(hkl, reciprocal, rot):
    """Compute the diffraction vectors u_q of the ``hkl`` planes.

    The reciprocal base is rotated by ``rot`` (via ``rotate_crystal``) before
    being combined with ``hkl`` (via ``hkl_reciprocal_to_uq``).
    """
    return hkl_reciprocal_to_uq(hkl, rotate_crystal(reciprocal, rot))
    
def uq_poni_to_detector(u_q, poni, detector=None):
    """Project the diffraction vectors u_q onto the detector plane.

    Parameters
    ----------
    u_q : torch.Tensor
        Diffraction vectors, converted into diffracted rays by ``uq_to_uf``.
    poni : torch.Tensor
        Detector parameters, forwarded to ``ray_to_detector``.
    detector : dict, optional
        Mapping with ``"shape"`` (number of pixels along d1 then d2) and
        ``"pxl"`` (size of one pixel). Defaults to the module-level ``DETECTOR``.

    Returns
    -------
    point : torch.Tensor
        The points that really hit the detector. The last axis holds the
        (d1, d2) coordinates; all leading axes are flattened by the mask.
    """
    if detector is None:
        detector = DETECTOR
    u_f = uq_to_uf(u_q)
    point, dist = ray_to_detector(u_f, poni)
    d1_max = detector["shape"][0] * detector["pxl"]
    d2_max = detector["shape"][1] * detector["pxl"]
    # Build a single mask and index once: each boolean indexing allocates a new
    # tensor and has a data-dependent output shape (a graph break under torch.compile).
    keep = (
        (dist > 0)  # ray wrong direction => virtual intersection
        & (point[..., 0] > 0)  # out of detector top
        & (point[..., 0] < d1_max)  # out of detector bottom
        & (point[..., 1] > 0)  # out of detector left
        & (point[..., 1] < d2_max)  # out of detector right
    )
    return point[keep, :]

def detector_poni_to_uq(point, poni):
    """Back-project detector points into diffraction vectors u_q.

    The detector points are turned into diffracted rays u_f with
    ``detector_to_ray``, then into u_q with ``uf_to_uq``.
    """
    return uf_to_uq(detector_to_ray(point, poni))

In [None]:
def full_simulation(hkl, lattice, rot, poni):
    """Simulate the detector points of a Laue diagram.

    Chain: lattice parameters -> reciprocal base -> rotated u_q of every hkl
    -> points on the detector.
    """
    return uq_poni_to_detector(
        hkl_reciprocal_rot_to_uq(hkl, lattice_to_reciprocal(lattice), rot),
        poni,
    )

### Timing simulation

In [None]:
# simple timing

BATCH = 100  # number of simulated diagrams
ROT = torch.eye(3)

hkl = select_hkl(lattice_to_reciprocal(LATTICE), e_max=25e3*EV, keep_harmonics=False)


def _best_time(func, args, number, device):
    """Return the best wall time, in seconds, of ``number`` calls of ``func(*args)``.

    The minimum of 10 repeats is kept. One untimed warm-up round runs first so
    that the ``torch.compile`` tracing/compilation is excluded from the
    measurement. On CUDA, the device is synchronized at the end of each round
    because kernels are launched asynchronously: without it the timer could
    stop before the GPU work is finished.
    """
    def run():
        for _ in range(number):
            func(*args)
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    run()  # warm-up
    return min(timeit.repeat(run, repeat=10, number=1))


_CPU = torch.device("cpu")
_CASES = [  # (label, dtype, batched, compiled, device)
    ("float64", torch.float64, False, False, _CPU),
    ("float32", torch.float32, False, False, _CPU),
    ("float64 batched", torch.float64, True, False, _CPU),
    ("float32 batched", torch.float32, True, False, _CPU),
    ("float64 batched compiled", torch.float64, True, True, _CPU),
    ("float32 batched compiled", torch.float32, True, True, _CPU),
]
if DEVICE.type == "cuda":
    _CASES += [
        ("float64 batched compiled gpu", torch.float64, True, True, DEVICE),
        ("float32 batched compiled gpu", torch.float32, True, True, DEVICE),
    ]

for label, dtype, batched, compiled, device in _CASES:
    lattice, rot, poni = (
        LATTICE.clone().to(dtype=dtype, device=device),
        ROT.clone().to(dtype=dtype, device=device),
        PONI.clone().to(dtype=dtype, device=device),
    )
    if batched:
        lattice = lattice[None, :].expand(BATCH, -1)  # one call simulates BATCH diagrams
    if compiled:
        full_simulation_comp = torch.compile(full_simulation, dynamic=False)
        func = full_simulation_comp
    else:
        func = full_simulation
    # hkl was computed on the CPU: move it next to the other tensors to avoid mixing devices
    args = (hkl.to(device), lattice, rot, poni)
    number = 1 if batched else BATCH  # either way, BATCH diagrams are simulated per round
    speed = _best_time(func, args, number, device) / BATCH
    print(f"{label}: it takes {speed*1e6:.2f}us by simulation <=> {1.0/speed:.2f}Hz")