## Find the weighted distribution of a 2d gaussian in a spot image

In [None]:
import math

import matplotlib.pyplot as plt
import torch
import tqdm.autonotebook as tqdm

import laueimproc

In [None]:
def concat_mosaic(rois: torch.Tensor) -> torch.Tensor:
    """Concatenate the images in a regular grid.

    The images are laid out in row-major order: image ``k`` is pasted at
    grid row ``k // width`` and grid column ``k % width``, where
    ``width = max(1, round(sqrt(len(rois))))``. The number of grid rows is the
    smallest one that can hold every image; unused cells are filled with zeros.

    Parameters
    ----------
    rois : torch.Tensor
        The stack of images, of shape (n, roi_height, roi_width).

    Returns
    -------
    mosaic : torch.Tensor
        The 2d image of shape (grid_height*roi_height, grid_width*roi_width),
        allocated with the default torch dtype.
    """
    assert isinstance(rois, torch.Tensor), rois.__class__.__name__
    assert rois.ndim == 3, rois.shape
    nb_rois, roi_height, roi_width = rois.shape
    width = max(1, round(math.sqrt(nb_rois)))
    # ceil and not round: rounding down (ex 5 images -> 2x2 grid) would drop the last images
    height = max(1, math.ceil(nb_rois / width))
    mosaic = torch.zeros((height*roi_height, width*roi_width))  # empty cells stay at 0
    for index, roi in enumerate(rois):
        i_rel = (index // width) * roi_height
        j_rel = (index % width) * roi_width
        mosaic[i_rel:i_rel+roi_height, j_rel:j_rel+roi_width] = roi
    return mosaic

### Generation of random gaussian mixture images

In [None]:
# configuration of the synthetic gaussian mixture images

NB_GAUSSIANS = 3  # the number of gaussians mixed in each generated image
NB_PATCHES = 12  # batch dimension, number of generated images
ROIS_SHAPE = (40, 50)  # the dimension (height, width) in pixels of the generated images
NB_PHOTONS = 100_000  # the number of photons drawn per image (split evenly between the gaussians)
BKG_NOISE = 1e-3  # the std of the additional gaussian noise (images are normalized in [0, 1]), 0 to disable

In [None]:
# relative gaussians weight, normalized so that each patch sums to 1
# before normalization the weights lie in [1/NB_GAUSSIANS, 1 + 1/NB_GAUSSIANS),
# so the weakest gaussian keeps more than 1/(NB_GAUSSIANS+1) of the strongest intensity
eta = torch.rand((NB_PATCHES, NB_GAUSSIANS)) + 1/NB_GAUSSIANS
eta /= eta.sum(dim=1, keepdim=True)

# gaussian location (row, column), uniformly drawn in the central 60% of the roi
mean = torch.rand((NB_PATCHES, NB_GAUSSIANS, 2))
mean *= torch.asarray([[[0.6*ROIS_SHAPE[0], 0.6*ROIS_SHAPE[1]]]])
mean += torch.asarray([[[0.2*ROIS_SHAPE[0], 0.2*ROIS_SHAPE[1]]]])

# covariance matrices
theta = torch.rand(NB_PATCHES, NB_GAUSSIANS, 1) * (2 * torch.pi)
cos_theta, sin_theta = torch.cos(theta), torch.sin(theta)
# 2d rotation matrix [[cos, -sin], [sin, cos]], flattened row by row then reshaped
rot = torch.cat([cos_theta, -sin_theta, sin_theta, cos_theta], dim=2)
rot = rot.reshape(NB_PATCHES, NB_GAUSSIANS, 2, 2)

cov = torch.empty((NB_PATCHES, NB_GAUSSIANS, 2, 2))
cov[:, :, [0, 1], [0, 1]] = (  # the diagonal holds the std along each principal axis
    torch.rand((NB_PATCHES, NB_GAUSSIANS, 2))
    * min(ROIS_SHAPE) / 6  # biggest gaussian fit in roi (+/- 3 * std spans the smallest side)
)
cov[:, :, [0, 1], [1, 0]] = 0.0
cov *= cov  # cov is var = std**2
cov = rot @ cov @ rot.mT  # random rotation

In [None]:
# draw random photons in detector
# NOTE(review): every gaussian receives the same number of photons, `eta` is not used here
distribution = torch.distributions.multivariate_normal.MultivariateNormal(loc=mean, covariance_matrix=cov)
samples = distribution.sample((NB_PHOTONS // NB_GAUSSIANS,))  # shape: (n_samples, n_patches, nb_gaussians, 2)

# integration of photons contribution
# floor and not truncation, otherwise the photons in ]-1, 0[ are folded onto the first row / column
samples = torch.floor(samples).to(torch.int32)
rows, cols = samples[..., 0], samples[..., 1]
# each coordinate is tested on its own: testing only the flat index lets a photon that leaves
# the detector by the left or the right side wrap around onto the neighbouring row
inside = (rows >= 0) & (rows < ROIS_SHAPE[0]) & (cols >= 0) & (cols < ROIS_SHAPE[1])
rois = rows*ROIS_SHAPE[1] + cols  # flat pixel index, shape: (n_samples, n_patches, nb_gaussians)
roislist = []
for patch in range(NB_PATCHES):
    roi = rois[:, patch, :][inside[:, patch, :]]  # removes photons out of the detector
    roi = torch.bincount(roi, minlength=ROIS_SHAPE[0]*ROIS_SHAPE[1]).to(torch.float32)  # photons per pixel
    roi = roi.reshape(1, *ROIS_SHAPE)
    roislist.append(roi)
rois = torch.cat(roislist)  # shape: (n_patches, height, width)
rois /= rois.amax(dim=(1, 2), keepdim=True)  # the brightest pixel of each patch is set to 1

# add noise
if BKG_NOISE:
    rois += torch.randn((NB_PATCHES, *ROIS_SHAPE)) * BKG_NOISE
    rois = torch.clamp(rois, 0.0, 1.0, out=rois)

# set the rois into a new diagram
diagram = laueimproc.Diagram(concat_mosaic(rois))
step = max(1, round(math.sqrt(len(rois))))  # number of patches per mosaic row, same as in concat_mosaic
# each spot is given as (row anchor, column anchor, roi) with the anchors matching the mosaic layout
diagram.set_spots([(rois.shape[1]*(i//step), rois.shape[2]*(i%step), roi) for i, roi in enumerate(rois)])

# _ = diagram.plot(plt.figure(figsize=(8, 8))); plt.show()

### Find the gaussians with the EM algorithm
* for more details on this algorithm, please refer to the `laueimproc` documentation

In [None]:
NBR_TRIES = 3  # number of times we fit each roi, forwarded as `nbr_tries` to `fit_gaussians_em`
CRITERIA = "bic"  # "aic" or "bic", the model selection criteria used to choose the number of gaussians

In [None]:
# try with different nbr of gaussians, from 1 to NB_GAUSSIANS+3 included
all_results = {}
for nbr_clusters in tqdm.tqdm(range(1, NB_GAUSSIANS+4)):
    # the criteria name is used both as the keyword flag and as the key of the returned infodict
    mean, cov, eta, infodict = diagram.fit_gaussians_em(
        nbr_clusters=nbr_clusters, nbr_tries=NBR_TRIES, **{CRITERIA: True}, cache=False
    )
    all_results[nbr_clusters] = {"mean": mean, "cov": cov, "eta": eta, "criteria": infodict[CRITERIA]}

# keep the best fit: for each spot, the number of gaussians that minimizes the criteria
results = []
for i in range(len(diagram)):
    # a single min over the sorted keys: ties go to the smallest number of gaussians,
    # and no float equality lookup is needed (it failed with IndexError on nan criteria)
    nbr_clusters = min(sorted(all_results), key=lambda n, i=i: float(all_results[n]["criteria"][i]))
    results.append({k: v[i] for k, v in all_results[nbr_clusters].items()})

In [None]:
# plot the images
# the gaussians of all the spots are flattened along a single leading dimension
mean = torch.cat([r["mean"].reshape(-1, 2) for r in results])
cov = torch.cat([r["cov"].reshape(-1, 2, 2) for r in results])
eta = torch.cat([r["eta"].reshape(-1) for r in results])
axe = diagram.plot(plt.figure(figsize=(8, 8)))
plt.scatter(mean[:, 1], mean[:, 0])  # x is the second coordinate of mean, y is the first one
plt.show()

# plot the numbers of gaussians
plt.figure().set_figheight(8)  # a single figure, a second plt.figure() call leaves an empty figure behind
plt.title(f"number of gaussians criteria {CRITERIA}")
# plt.plot(sorted(all_results), [all_results[c]["criteria"] for c in sorted(all_results)], "-o")
all_nbr_clusters = sorted(all_results)
# criteria averaged over all the spots, converted to plain floats for matplotlib
plt.plot(all_nbr_clusters, [float(all_results[c]["criteria"].mean()) for c in all_nbr_clusters], "-o")
plt.show()

### On real data

In [None]:
# replace the synthetic diagram by one built from what `laueimproc.io.get_sample()` returns
# (presumably a sample image shipped with laueimproc, TODO confirm)
diagram = laueimproc.Diagram(laueimproc.io.get_sample())
# search the spots in the diagram; NOTE(review): the meaning of `radius_aglo` is defined by laueimproc, confirm in its doc
diagram.find_spots(radius_aglo=20)