# This notebook shows the canonical way to select spots with simple filters.
* It is possible to interact with external libraries; have a look at the `init_diagram` notebook.
* Instantiating a `Diagram` allows you to efficiently apply filters based on image processing.
* There are several ways to instantiate a `Diagram`; some of them are better than others.

In [None]:
%matplotlib notebook

import multiprocessing.pool
import pathlib

import matplotlib.pyplot as plt
import torch
import tqdm

from laueimproc import Diagram
from laueimproc.io.download import get_samples  # gives access to the dataset

In [None]:
"""Define some utils"""

def select_items(criteria: torch.Tensor, threshold: float, nb_min: int) -> torch.Tensor:
    """Select at least `nb_min` items, plus every item with `criteria` >= `threshold`.

    Parameters
    ----------
    criteria : torch.Tensor
        The score of each item, the higher the better (assumed 1d).
    threshold : float
        Every item whose score is >= `threshold` is selected.
    nb_min : int
        The minimum number of items to keep. If some of the `nb_min` best items do not
        reach `threshold`, the threshold is lowered to the `nb_min`-th best score
        (ties may lead to more than `nb_min` selected items). A value <= 0 means
        that no minimum is enforced.

    Returns
    -------
    mask : torch.Tensor
        A boolean mask of length ``len(criteria)``, True for the selected items.
    """
    if len(criteria) <= nb_min:  # not enough items to be selective, keep them all
        return torch.ones(len(criteria), dtype=torch.bool, device=criteria.device)
    if nb_min > 0:
        # Score of the `nb_min`-th best item; topk avoids sorting every item.
        nth_best = torch.topk(criteria, nb_min).values[-1].item()
        threshold = min(threshold, nth_best)
    return criteria >= threshold

## Initialize the set of all Laue patterns: `diagrams`
* The different ways of initializing a diagram are described in detail in the `init_diagram` notebook.

**WARNING! Run this cell only once (to avoid increasing memory usage).**

In [None]:
all_files = sorted(get_samples().glob("*.jp2"))  # the paths of all the sample images
# To process your own images instead, list them from their folder, e.g.:
# all_files = sorted(pathlib.Path("/path/to/your/images").iterdir())
print(f"instantiate {len(all_files)} diagram objects...")
diagrams = [Diagram(f) for f in tqdm.tqdm(all_files)]  # one Diagram per image file

## Peak search on all images (with the LaueImProc algorithm)
* The different ways to configure the peak search are described in the `peaks_search` notebook.

In [None]:
# Spot-search density, from 0.15 to 0.85: a larger density finds a higher number of spots.
SPOT_DENSITY = 0.6

for diagram in tqdm.tqdm(diagrams):
    diagram.find_spots(density=SPOT_DENSITY)

## Look at the results

In [None]:
# Preview the first 10 diagrams: print each one (its history and current state),
# then plot it, passing the maximum pixel value of its image as `vmax`.
for diagram in diagrams[:10]:
    print(diagram)
    diagram.plot(
        plt.figure(layout="tight", figsize=(8, 8)),
        vmax=diagram.image.max().item(),
    )
    plt.show()

### Basic filtering

In [None]:
"""Sorted by intensity."""

for diagram in tqdm.tqdm(diagrams, smoothing=0.01):
    # Reorder the spots of this diagram by decreasing pixel intensity (brightest first).
    intensities = diagram.compute_pxl_intensities()
    sorted_indexs = intensities.argsort(descending=True)
    diagram.filter_spots(sorted_indexs, msg="sorted by intensities", inplace=True)
    # Uncomment to inspect the result:
    # print(diagram)
    # diagram.plot(plt.figure(layout="tight", figsize=(8, 8))); plt.show()

In [None]:
"""Select rotation symetric spots."""

for diagram in tqdm.tqdm(diagrams, smoothing=0.01):
    sym = diagram.compute_rot_sym()
    # Keep at least the 10 best-scoring spots, plus every spot whose score is >= 0.8.
    selection = select_items(sym, threshold=0.8, nb_min=10)
    diagram.filter_spots(selection, msg="keep circular spots", inplace=True)
    # Uncomment to inspect the result:
    # print(diagram)
    # diagram.plot(plt.figure(layout="tight", figsize=(8, 8))); plt.show()

### Fit the spot ROIs with a Gaussian

In [None]:
"""Simple fast gaussian fit max likelihood (proba approch)."""

for diagram in tqdm.tqdm(diagrams, smoothing=0.01):
    mean, cov, infodict = diagram.fit_gaussian_em(photon_density=10900.0, eigtheta=True)
    eigtheta = infodict["eigtheta"]
    # Ratio of the two Gaussian standard deviations, >= 1.
    # Presumably eigtheta[:, 0] >= eigtheta[:, 1] are the covariance eigenvalues — TODO confirm.
    stretch = torch.sqrt(eigtheta[:, 0] / eigtheta[:, 1])
    # rot = torch.rad2deg(eigtheta[:, 2])  # spot rotation in ]-90, 90]
    # Keep at least the 10 roundest spots, plus every spot with stretch <= 1.5.
    # The threshold is written as 1 / 1.5 so that it matches this bound exactly.
    selection = select_items(1 / stretch, 1 / 1.5, 10)
    diagram.filter_spots(selection, msg="keeps spots not too stretched", inplace=True)
    # Uncomment to inspect the result:
    # print(diagram)
    # diagram.plot(plt.figure(layout="tight", figsize=(8, 8))); plt.show()