## Example of unsupervised spot classification

In [None]:
%matplotlib inline

import matplotlib.pyplot as plt
import torch

import laueimproc

### Init a dataset of diagrams
Please have a look at the ``api_init_dataset`` notebook for more details.

In [None]:
# Location of the sample data shipped with laueimproc
# (presumably a directory of diagram images, as the variable name suggests -- TODO confirm).
data_directory = laueimproc.io.get_samples()
# Dataset of diagrams built from that location; it is reused by all the following cells.
dataset = laueimproc.DiagramsDataset(data_directory)

In [None]:
def peaks_search(diagram: laueimproc.Diagram, *args, **kwargs) -> int:
    """Run the laueimproc peaks search on the diagram and return its new length.

    Every extra positional and keyword argument is forwarded unchanged
    to ``diagram.find_spots``.
    """
    diagram.find_spots(*args, **kwargs)
    nb_found = len(diagram)  # used below as the number of spots of the diagram
    return nb_found

# Search the spots of every diagram of the dataset, with the same settings for all of them.
search_settings = {"density": 0.75, "radius_aglo": 4}
nb_spots = dataset.apply(peaks_search, kwargs=search_settings)
mean_nb_spots = sum(nb_spots.values()) / len(dataset)
print(f"On average, there are {mean_nb_spots} spots per diagram.")

### Select the interesting spots
Before classifying the spots, we first need to select the family of spots we are interested in.

In [None]:
def filter_intensity(diagram: laueimproc.Diagram, threshold: float) -> int:
    """Keep only the spots whose maximum intensity reaches ``threshold``.

    The diagram is filtered through ``diagram.filter_spots`` and its
    length after the selection is returned.
    """
    # column 2 of ``compute_rois_max`` is used as the intensity of each spot
    rois_max = diagram.compute_rois_max(cache=False)
    bright_enough = rois_max[:, 2] >= threshold
    diagram.filter_spots(bright_enough, "keep intensive spots")
    return len(diagram)

def filter_size(diagram: laueimproc.Diagram, size_min: int, size_max: int) -> int:
    """Keep only the spots whose bbox sides are both within ``[size_min, size_max]``.

    The selection is done in two successive passes, each one recorded with
    its own message: the too small bboxes are removed first, then the too
    large ones.

    Parameters
    ----------
    diagram : laueimproc.Diagram
        The diagram to filter through ``diagram.filter_spots``.
    size_min : int
        The smallest accepted value for the height and the width (inclusive).
    size_max : int
        The largest accepted value for the height and the width (inclusive).

    Returns
    -------
    int
        ``len(diagram)`` once the two passes are done.
    """
    # first pass: remove the bboxes that are too small along at least one direction
    heights = diagram.bboxes[:, 2]
    widths = diagram.bboxes[:, 3]
    cond = torch.logical_and(heights >= size_min, widths >= size_min)
    diagram.filter_spots(cond, f"keep bboxes >= {size_min}")
    # second pass: the bboxes are read again because the first pass changed the selection
    heights = diagram.bboxes[:, 2]
    widths = diagram.bboxes[:, 3]
    cond = torch.logical_and(heights <= size_max, widths <= size_max)
    diagram.filter_spots(cond, f"keep bboxes <= {size_max}")
    return len(diagram)
    

In [None]:
# Fork of the dataset: the filters of the next cell are applied to this copy.
filtered_dataset = dataset.clone()
# Automatic backup of the fork, with a delay of 10 min.
# Add ``restore=False`` to the call to ignore an existing backup.
backup_file = filtered_dataset.autosave("/tmp/dataset_backup.pickle", delay="10 min")

In [None]:
# Apply the two filters, one after the other, to the forked dataset.
# The average is normalised by the size of the dataset that is actually
# filtered, so the printed value stays consistent with ``nb_spots``.
nb_spots = filtered_dataset.apply(filter_intensity, args=(0.005,))
print(f"On average, there are {sum(nb_spots.values())/len(filtered_dataset)} spots per diagram.")
nb_spots = filtered_dataset.apply(filter_size, args=(24, 128))
print(f"On average, there are {sum(nb_spots.values())/len(filtered_dataset)} spots per diagram.")

### Train the model

In [None]:
# Create the classifier from the filtered spots and warm it up.
model = filtered_dataset.train_spot_classifier(
    intensity_sensitive=True,
    scale_sensitive=True,
    space=5,
    shape=(20, 20),
)
print(f"real snapshot shape: {model.shape}")
print(f"latent square dimension: {2*model.space} x {2*model.space}")

In [None]:
# Training step. The figure is optional, it is only there for data visualization.
fig = plt.figure(figsize=(16, 12), layout="tight")
model = filtered_dataset.train_spot_classifier(model=model, fig=fig, epoch=200, batch=1000)
# Alternative: model = filtered_dataset.train_spot_classifier(model=model, epoch=100, fig=fig)
plt.show()

In [None]:
# Call the trained model on the first diagram of the filtered dataset.
# It is the last expression of the cell, so the notebook displays the returned value.
model(filtered_dataset[0])