In [1]:
import itertools
import multiprocessing as mp
import os

# GPU/XLA configuration must be set before jax is imported so the backend
# picks it up when it initialises.
# os.environ["JAX_LOG_COMPILES"] = "1"
os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"  # expose GPUs 0-3 to this process
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.9"  # fraction of GPU memory XLA preallocates

from functools import partial

import h5py
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

# NOTE(review): itertools, mp, partial, plt and smart_grid_smooth are not used
# in the cells shown; `chunked` is only used by the chunked generation loops.
from adaptive_spacing import adaptive_geomspace, smart_grid_smooth
from template_generation_jax import (
    chunked,
    generate_ellipsoid_diffraction_complex,
    generate_parallelepiped_diffraction_complex,
    generate_spaced_rotations,
    q_from_xyz,
    xyyz_from_detector_geometry,
)

# Experiment explanation

<figure>
  <img src="/gpfs/exfel/exp/SPB/202503/p008047/scratch/Konstantin/image_1.png" alt="Diffraction pattern" width="1600">
  <figcaption><b>Figure 1.</b> Experiment </figcaption>
</figure>


<figure>
  <img src="/gpfs/exfel/exp/SPB/202503/p008047/scratch/Konstantin/image.png" alt="Diffraction pattern" width="1600">
  <figcaption><b>Figure 2.</b> Detector </figcaption>
</figure>


## Forward model of the measurement


$ I_{det} = \mathcal{P}\left\{ \mathcal{S}\left\{ \mathcal{M}\left\{ |\mathrm{D_{ideal}}|^2 \right\} \right\} \right\} $,
where $I_{det}$ is the measured intensity in photons, $\mathcal{P}$ is Poisson sampling, $\mathcal{S}$ is scaling to the total number of photons, $\mathcal{M}$ is the masking operation, and $\mathrm{D_{ideal}}$ is the ideal diffraction amplitude calculated from the shape parameters.

In Python this reads:
```
I_masked =  (np.abs(D_simmulated)**2) * mask
I_scaled =  I_masked/I_masked.sum() * total_photons
I_det =  np.random.poisson(I_scaled)
```

We also need to simulate 'ugly' particles, such as agglomerates or particles with an uneven density distribution. One way to do this is to modulate the idealized particle density with a Gaussian-blurred random distribution of variable sigma, where a larger sigma corresponds to a more ideal particle.


In that case the detector intensity can be calculated as
```
random_modulation  = 1+(np.random.rand(360,360))*np.exp(1j*np.random.rand(360,360)*np.pi)# it's a random complex number added to the unity transmission
random_modulation = gaussian_filter(random_modulation,sigma=1)
# here sigma controls the ugliness of the particle: ~5 is almost perfect, 1-2 is the borderline region and 0.1-0.5 is very ugly
D_modulated = fftn(ifftn(D_simmulated,axes=(-1,-2))*random_modulation,axes=(-1,-2))
I_masked =  (np.abs(D_modulated)**2) * mask
I_scaled =  I_masked/I_masked.sum() * total_photons
I_det =  np.random.poisson(I_scaled)

```

## Key experimental parameters


- Detector to sample distance = 1.2 m
- Beam Diameter = 300 nm
- Particle Beam = 20 μm
- Pixel Size = 200 μm
- Photon Energy = 6000 ± 100 eV
- Wavelength = 0.20664e-9 m 

# Template generation

In [2]:
# Output HDF5 file that collects every simulated template group below.
# NOTE(review): opened in append mode ("a"); `create_group` raises if a group
# with the same name already exists, so re-running the notebook requires
# removing/renaming this file first (or deliberately switching to "w").
TEMPLATE_FILE_PATH = "/gpfs/exfel/exp/SPB/202503/p008047/scratch/Konstantin/templates/simulated_data.h5"

template_file = h5py.File(TEMPLATE_FILE_PATH, "a")
diffraction_patterns_group = template_file.create_group("diffraction_patterns")

## Geometry

In [3]:
# Detector geometry used for every template. Each value is defined once so the
# attributes written to HDF5 always match what was used to compute q.
PIXEL_SIZE_M = 200e-6
DETECTOR_DISTANCE_M = 1.2
N_PIXELS_Y = 360
N_PIXELS_X = 360
# NOTE(review): "Key experimental parameters" lists 0.20664e-9 m (6 keV);
# 0.207e-9 is the rounded value used here — confirm which is intended.
WAVELENGTH_M = 0.207e-9

xy, z = xyyz_from_detector_geometry(
    pixel_size_m=PIXEL_SIZE_M,
    detector_distance_m=DETECTOR_DISTANCE_M,
    n_pixels_y=N_PIXELS_Y,
    n_pixels_x=N_PIXELS_X,
)
q_xyz = q_from_xyz(xy, z, WAVELENGTH_M)

# Persist the geometry alongside the templates.
geometry_group = template_file.create_group("experiment_geometry")
for geometry_name, geometry_values in (("xy", xy), ("z", z), ("q_xyz", q_xyz)):
    geometry_group.create_dataset(geometry_name, data=geometry_values)

for attr_name, attr_value in (
    ("pixel_size_m", PIXEL_SIZE_M),
    ("detector_distance_m", DETECTOR_DISTANCE_M),
    ("n_pixels_y", N_PIXELS_Y),
    ("n_pixels_x", N_PIXELS_X),
    ("wavelength_m", WAVELENGTH_M),
):
    geometry_group.attrs[attr_name] = attr_value

## Spheres




In [4]:
# define sphere diffraction generator


def generate_sphere_diffraction(q_xyz, r):
    """Sphere template: an unrotated ellipsoid whose three semi-axes all equal `r`.

    `r` is passed in the same length unit the callers use (metres below).
    """
    semi_axis = r
    return generate_ellipsoid_diffraction_complex(
        q_xyz, semi_axis, semi_axis, semi_axis, 0, 0, 0
    )


# Batch over radii (q grid shared across the batch) and compile once.
_sphere_batched = jax.vmap(generate_sphere_diffraction, in_axes=(None, 0))
generate_sphere_diffraction_v = jax.jit(_sphere_batched)

In [5]:
# 150 radii from 1 nm to 150 nm in 1 nm steps; converted to metres for the model.
sphere_radius_distribution = jnp.linspace(1, 150, 150)
sphere_radii_m = sphere_radius_distribution * 1e-9
spheres_patterns = generate_sphere_diffraction_v(q_xyz, sphere_radii_m)

spheres_group = diffraction_patterns_group.create_group("spheres")
radius_ds = spheres_group.create_dataset("A", data=sphere_radius_distribution)
radius_ds.attrs["description"] = "A in nm"
spheres_group.create_dataset("patterns", data=spheres_patterns)

<HDF5 dataset "patterns": shape (150, 360, 360), type "<f4">

In [6]:
del sphere_radius_distribution, spheres_patterns  # already written to HDF5

## Ellipsoids

In [7]:
# define in_plane ellipsoid diffraction generator


def generate_ellipsoid_in_plane_diffraction(q_xyz, A, B, phi):
    """Ellipsoid with semi-axes (A, B, A), rotated only by the third angle `phi`.

    The first two rotation angles are fixed at 0; callers pass `phi` in degrees.
    """
    axes_abc = (A, B, A)
    return generate_ellipsoid_diffraction_complex(q_xyz, *axes_abc, 0, 0, phi)


# Batch over (A, B, phi) triples with a shared q grid, compiled once.
_ellipsoid_in_plane_batched = jax.vmap(
    generate_ellipsoid_in_plane_diffraction, in_axes=(None, 0, 0, 0)
)
generate_ellipsoid_in_plane_diffraction_v = jax.jit(_ellipsoid_in_plane_batched)

In [8]:
# Semi-axis grid in nm from the adaptive spacing helper.
semi_axes = adaptive_geomspace(1, 150, 20, 0.3, 2, 30)

# All (A, B) pairs with B strictly smaller than A (A == B excluded).
# indexing="ij" + boolean masking keeps A varying slowest and B fastest.
grid_a, grid_b = jnp.meshgrid(semi_axes, semi_axes, indexing="ij")
pair_mask = grid_b < grid_a
a_values = grid_a[pair_mask]
b_values = grid_b[pair_mask]

# 36 in-plane angles: 0..175 degrees in 5 degree steps.
phi_values = jnp.linspace(0, 180, 36, endpoint=False)
n_phi = phi_values.shape[0]

# Every (A, B) pair combined with every angle (angle varies fastest).
# Rows: A [nm], B [nm], phi [deg].
all_parameters = jnp.vstack(
    (
        jnp.repeat(a_values, n_phi),
        jnp.repeat(b_values, n_phi),
        jnp.tile(phi_values, a_values.shape[0]),
    )
)

ellipsoids_group = diffraction_patterns_group.create_group("ellipsoids")
meta = ellipsoids_group.create_dataset("A_B_phi", data=all_parameters)
meta.attrs["description"] = "A and B semi-axes in nm, phi in degrees"

In [9]:
# Preallocate the on-disk pattern stack and fill it chunk by chunk so that only
# `chunk_size` patterns are materialised at a time.
dset = ellipsoids_group.create_dataset(
    "patterns",
    shape=(all_parameters.shape[1], 360, 360),  # full size
)

print("starting generation")
chunk_size = 10000

# One host copy of the (3, n) parameter table, sliced per chunk. This replaces
# the `.T.tolist()` -> `np.array(...)` round trip through Python lists (same
# float32 values, same chunk shapes).
params_host = np.asarray(all_parameters)
n_templates = params_host.shape[1]

for idx, start in enumerate(range(0, n_templates, chunk_size)):
    d_arr = jnp.asarray(params_host[:, start : start + chunk_size])
    print(idx, d_arr.shape)
    generated = generate_ellipsoid_in_plane_diffraction_v(
        q_xyz,
        d_arr[0, :] * 1e-9,  # A: nm -> m
        d_arr[1, :] * 1e-9,  # B: nm -> m
        d_arr[2, :],  # phi in degrees
    )

    dset[start : start + d_arr.shape[1], :, :] = np.asarray(generated)

starting generation
0 (3, 6840)


In [10]:
del all_parameters, generated  # ellipsoid templates are on disk now

## Squares

In [11]:
# define square (cube) diffraction generator


def generate_square_diffraction(q_xyz, r, phi):
    """Cube-shaped parallelepiped template rotated in-plane by `phi`.

    `r` is used for all three parallelepiped dimensions and the rotation is
    (0, 0, phi); callers pass `phi` in degrees.
    NOTE(review): whether `r` is a full edge or a half-edge depends on
    generate_parallelepiped_diffraction_complex's convention — TODO confirm.
    """
    return generate_parallelepiped_diffraction_complex(q_xyz, r, r, r, 0, 0, phi)


# Batch over (r, phi) pairs with a shared q grid, compiled once.
generate_square_diffraction_v = jax.jit(
    jax.vmap(generate_square_diffraction, in_axes=(None, 0, 0))
)

In [12]:
# 150 sizes (1..150 nm, 1 nm step) x 15 in-plane angles (0..84 deg, 6 deg step).
squares_radius_distribution = jnp.linspace(1, 150, 150)
square_angles_deg = jnp.linspace(0, 90, 15, endpoint=False)

# Size varies slowest, angle fastest. Rows: size [nm], angle [deg].
squares_parameters = jnp.stack(
    (
        jnp.repeat(squares_radius_distribution, square_angles_deg.shape[0]),
        jnp.tile(square_angles_deg, squares_radius_distribution.shape[0]),
    )
)

squares_patterns = generate_square_diffraction_v(
    q_xyz, squares_parameters[0] * 1e-9, squares_parameters[1]
)

squares_group = diffraction_patterns_group.create_group("squares")
squares_params_ds = squares_group.create_dataset(
    "side_nm angle_deg", data=squares_parameters
)
squares_params_ds.attrs["description"] = "A phi semi-axis in nm, phi in degrees"
squares_group.create_dataset("patterns", data=squares_patterns)

<HDF5 dataset "patterns": shape (2250, 360, 360), type "<f4">

In [13]:
# Release the (2250, 360, 360) float32 pattern stack (~1.2 GB, held in device
# memory by jax) now that it is written to HDF5 — as done for the spheres.
del squares_patterns, squares_parameters, squares_group

## Rectangles

In [14]:
# define in_plane rectangle diffraction generator


def generate_rect_in_plane_diffraction(q_xyz, A, B, phi):
    """Parallelepiped with dimensions (A, B, A), rotated only by the third angle `phi`.

    The first two rotation angles are fixed at 0; callers pass `phi` in degrees.
    """
    depth = A  # third dimension reuses A
    return generate_parallelepiped_diffraction_complex(q_xyz, A, B, depth, 0, 0, phi)


# Batch over (A, B, phi) triples with a shared q grid, compiled once.
_rect_in_plane_batched = jax.vmap(
    generate_rect_in_plane_diffraction, in_axes=(None, 0, 0, 0)
)
generate_rectangle_in_plane_diffraction_v = jax.jit(_rect_in_plane_batched)

In [15]:
# Same (A, B, phi) grid as for the ellipsoids: adaptive semi-axis grid in nm.
semi_axes = adaptive_geomspace(1, 150, 20, 0.3, 2, 30)

# All (A, B) pairs with B < A; A varies slowest, B fastest after masking.
rect_a, rect_b = jnp.meshgrid(semi_axes, semi_axes, indexing="ij")
rect_keep = rect_b < rect_a
rect_a_values = rect_a[rect_keep]
rect_b_values = rect_b[rect_keep]

# 36 in-plane angles: 0..175 degrees in 5 degree steps.
rect_phi = jnp.linspace(0, 180, 36, endpoint=False)
n_rect_phi = rect_phi.shape[0]

# Rows: A [nm], B [nm], phi [deg]; angle varies fastest.
all_parameters = jnp.vstack(
    (
        jnp.repeat(rect_a_values, n_rect_phi),
        jnp.repeat(rect_b_values, n_rect_phi),
        jnp.tile(rect_phi, rect_a_values.shape[0]),
    )
)

rectangles_group = diffraction_patterns_group.create_group("rectangles")
meta = rectangles_group.create_dataset("A_B_phi", data=all_parameters)
meta.attrs["description"] = "A and B semi-axes in nm, phi in degrees"

In [16]:
# Preallocate the on-disk pattern stack and fill it chunk by chunk so that only
# `chunk_size` patterns are materialised at a time.
dset = rectangles_group.create_dataset(
    "patterns",
    shape=(all_parameters.shape[1], 360, 360),  # full size
)

print("starting generation")
chunk_size = 10000

# One host copy of the (3, n) parameter table, sliced per chunk, instead of the
# `.T.tolist()` -> `np.array(...)` round trip through Python lists.
params_host = np.asarray(all_parameters)
n_templates = params_host.shape[1]

for idx, start in enumerate(range(0, n_templates, chunk_size)):
    d_arr = jnp.asarray(params_host[:, start : start + chunk_size])
    print(idx, d_arr.shape)
    generated = generate_rectangle_in_plane_diffraction_v(
        q_xyz,
        d_arr[0, :] * 1e-9,  # A: nm -> m
        d_arr[1, :] * 1e-9,  # B: nm -> m
        d_arr[2, :],  # phi in degrees
    )

    dset[start : start + d_arr.shape[1], :, :] = np.asarray(generated)

starting generation
0 (3, 6840)


In [17]:
del all_parameters, generated  # rectangle templates are on disk now

## Parallelepipeds

In [18]:
# Batched general parallelepiped generator: the q grid is shared, while the
# three sizes and the three rotation angles vary per template.
_parallelepiped_in_axes = (None,) + (0,) * 6
generate_parallelepiped_diffraction_v = jax.jit(
    jax.vmap(
        generate_parallelepiped_diffraction_complex,
        in_axes=_parallelepiped_in_axes,
    )
)

In [19]:
# generate_spaced_rotations is already imported from template_generation_jax in the first cell.

In [20]:
# Parameter grid for general parallelepipeds: every size triple A >= B >= C
# (nm) on an adaptive grid, combined with every orientation returned by
# generate_spaced_rotations() (three angle rows, stored as phi, theta, omega).
# No type annotation here: a bare `Array` name is not defined in this notebook
# and module-level annotations are evaluated (NameError on a fresh kernel).
all_rotations_rad = generate_spaced_rotations()
all_rotations_deg = jnp.degrees(all_rotations_rad)


size = adaptive_geomspace(2, 170, 20, 0.3, 1, 30)

# Step 1: all (A, B) pairs with B <= A; A varies slowest.
size_high = jnp.repeat(size, size.shape[0])
size_med = jnp.tile(size, size.shape[0])

mask = size_med <= size_high
sizes = jnp.vstack((size_high[mask], size_med[mask]))

# Step 2: extend each (A, B) pair with every C <= B; C varies fastest.
size_low = jnp.tile(size, sizes.shape[1])
sizes = jnp.repeat(sizes, size.shape[0], axis=1)

mask = size_low <= sizes[1, :]

sizes = jnp.vstack((sizes[:, mask], size_low[mask]))

n_size_triples = sizes.shape[1]
n_rotations = all_rotations_deg.shape[1]

# Step 3: pair every size triple with every rotation; rotation varies fastest.
sizes = jnp.repeat(sizes, n_rotations, axis=1)
all_rotations_deg = jnp.tile(all_rotations_deg, n_size_triples)
all_parameters = jnp.vstack((sizes, all_rotations_deg))

paralelipiped_group = diffraction_patterns_group.create_group("paralelipiped")
meta = paralelipiped_group.create_dataset("A_B_C_phi_theta_omega_", data=all_parameters)
meta.attrs["description"] = "A B C semi-axes in nm,  phi theta and omega in degrees"


print(all_parameters.shape)

(6, 558600)


In [21]:
# Preallocate the on-disk pattern stack and fill it chunk by chunk so that only
# `chunk_size` patterns are materialised at a time.
dset = paralelipiped_group.create_dataset(
    "patterns",
    shape=(all_parameters.shape[1], 360, 360),  # full size
)

print("starting generation")
chunk_size = 10000

# One host copy of the (6, n) parameter table, sliced per chunk, instead of
# converting all n columns to Python lists (`.T.tolist()`) and back to arrays
# twice per chunk.
params_host = np.asarray(all_parameters)
n_templates = params_host.shape[1]

for idx, start in enumerate(range(0, n_templates, chunk_size)):
    d_arr = jnp.asarray(params_host[:, start : start + chunk_size])
    print(idx, d_arr.shape)
    generated = generate_parallelepiped_diffraction_v(
        q_xyz,
        d_arr[0, :] * 1e-9,  # A: nm -> m
        d_arr[1, :] * 1e-9,  # B: nm -> m
        d_arr[2, :] * 1e-9,  # C: nm -> m
        d_arr[3, :],  # phi in degrees
        d_arr[4, :],  # theta in degrees
        d_arr[5, :],  # omega in degrees
    )

    dset[start : start + d_arr.shape[1], :, :] = np.asarray(generated)

starting generation
0 (6, 10000)
1 (6, 10000)
2 (6, 10000)
3 (6, 10000)
4 (6, 10000)
5 (6, 10000)
6 (6, 10000)
7 (6, 10000)
8 (6, 10000)
9 (6, 10000)
10 (6, 10000)
11 (6, 10000)
12 (6, 10000)
13 (6, 10000)
14 (6, 10000)
15 (6, 10000)
16 (6, 10000)
17 (6, 10000)
18 (6, 10000)
19 (6, 10000)
20 (6, 10000)
21 (6, 10000)
22 (6, 10000)
23 (6, 10000)
24 (6, 10000)
25 (6, 10000)
26 (6, 10000)
27 (6, 10000)
28 (6, 10000)
29 (6, 10000)
30 (6, 10000)
31 (6, 10000)
32 (6, 10000)
33 (6, 10000)
34 (6, 10000)
35 (6, 10000)
36 (6, 10000)
37 (6, 10000)
38 (6, 10000)
39 (6, 10000)
40 (6, 10000)
41 (6, 10000)
42 (6, 10000)
43 (6, 10000)
44 (6, 10000)
45 (6, 10000)
46 (6, 10000)
47 (6, 10000)
48 (6, 10000)
49 (6, 10000)
50 (6, 10000)
51 (6, 10000)
52 (6, 10000)
53 (6, 10000)
54 (6, 10000)
55 (6, 8600)


In [22]:
del all_parameters, generated  # parallelepiped templates are on disk now

## Add masks and photon counts for scaling

In [23]:
# Read the measured masks and per-hit scattered-photon totals from run 76 and
# store them next to the templates for the masking / photon-scaling steps of
# the forward model.
MEASURED_DATA_PATH = "/gpfs/exfel/exp/SPB/202503/p008047/scratch/Konstantin/angular_correlations_run_76.h5"
FRAMES_PER_READ = 512  # bounds peak memory while summing hits_sym

with h5py.File(MEASURED_DATA_PATH, "r") as f:
    masks = f["masks_centered"][:]

    # Sum each frame over its last two axes in blocks of frames instead of
    # loading the whole hits_sym stack into memory at once. The result dtype is
    # the one numpy picks for the same reduction on the full array.
    hits = f["hits_sym"]
    n_frames = hits.shape[0]
    result_dtype = (
        np.zeros((0,) + hits.shape[1:], dtype=hits.dtype).sum(axis=(-1, -2)).dtype
    )
    intensities = np.empty((n_frames,) + hits.shape[1:-2], dtype=result_dtype)
    for start in range(0, n_frames, FRAMES_PER_READ):
        stop = start + FRAMES_PER_READ
        intensities[start:stop] = hits[start:stop].sum(axis=(-1, -2))

measured_parameters_group = template_file.create_group("measured_parameters")
measured_parameters_group.create_dataset("masks", data=masks)
measured_parameters_group.create_dataset("scattered_photons", data=intensities)

<HDF5 dataset "scattered_photons": shape (7966,), type "<i8">

In [24]:
# Close the template file opened at the start; h5py flushes all groups and
# datasets created above to disk on close.
template_file.close()

## Add Ugly patterns 

Ugly patterns can be obtained from the simulated ones by modulation; see ***Experment_explanation.ipynb***.