# This notebook shows the canonical way to select some spots with a simple filter
* It is possible to interact with external libraries; have a look at the `init_diagram` notebook.
* Instantiating a `Diagram` allows you to efficiently apply filters based on image processing.
* There are several ways to instantiate a `Diagram`; some of them are better than others.

In [None]:
%matplotlib notebook

import multiprocessing.pool
import pathlib

import matplotlib.pyplot as plt
import torch
import tqdm

import os

from laueimproc import Diagram
from laueimproc.io.download import get_samples  # gives access to the dataset

# Initialize the set of all Laue patterns: `diagrams`
* The different ways of initializing a diagram are described in detail in the `init_diagram` notebook.

**WARNING! Run this cell only once (to avoid increasing memory usage).**

In [None]:
# Alternative data sources, kept for reference:
#all_files = sorted(get_samples().glob("*.jp2"))  # the list of all images path
#all_files = list(pathlib.Path("/data/visitor/a322855/bm32/20240221/RAW_DATA/Almardi/Almardi_map2DGemardi3_GOOD_0004/scan0002").iterdir())

# Folder holding the .tif images of the scan.
papath = "/data/visitor/a322855/bm32/20240221/RAW_DATA/Almardi/Almardi_2D_Almardi/scan0001/"


def myfuncsort(elem):
    """Return the integer found after the last '_' of the file name (extension removed).

    Used as a sort key so that files are ordered numerically (2 before 10)
    rather than alphabetically.
    """
    return int(pathlib.Path(elem).stem.rsplit('_', 1)[-1])


# List the folder only once, keeping only the .tif files, ordered by image number.
all_files = sorted(pathlib.Path(papath).glob('*.tif'), key=myfuncsort)

CCDLabel = 'sCMOS'

print('preparing diagram %s objects -----'%len(all_files))
diagrams = [Diagram(f) for f in tqdm.tqdm(all_files)]
nbimages = len(diagrams)
# Full scale used to convert image values back to counts.
# NOTE(review): `imagedynamics` stays undefined for any other detector label.
if CCDLabel in ('sCMOS','MARCCD165'):
    imagedynamics = 65535 # 2**16 -1


# peak search on all images (with opencv)

In [None]:
# density ranges from 0 to 1: a larger density yields a higher number of spots
# Run the peak search on every loaded diagram (tqdm only displays the progress).
for diagram in tqdm.tqdm(diagrams):
    diagram.find_spots(density=0.7)

# look at the results

In [None]:
# you can launch the peak search on a single image
# NOTE(review): 22000 is a hard-coded position in `diagrams`; it raises
# IndexError when fewer images were loaded.
_idx=22000
mydiagram = diagrams[_idx]
#mydiagram.find_spots(density=0.65)
print(mydiagram)

In [None]:
# Be careful: the position of a file in the sorted list used to build `diagrams`
# may differ from the image number written in the file name.
print('index in sorted list of files in a folder:', _idx)
print('full path to image file',mydiagram.file)

In [None]:
# Folder that contains the image file of this diagram.
mydiagram.file.parent

In [None]:
# Store the parent folder on the diagram object under a `folder` attribute
# (set here by the notebook), then display it.
mydiagram.folder = mydiagram.file.parent
mydiagram.folder

In [None]:
# Largest pixel value of the image at position `_idx`, scaled by `imagedynamics`.
# NOTE(review): presumably `image` is normalized to [0, 1], so the product is in
# detector counts — confirm.
print('maximum intensity')
int(diagrams[_idx].image.max()*imagedynamics)

In [None]:
# Without `dim`, torch.argmax returns the index of the maximum in the
# flattened image (a single 1D index).
idx = torch.argmax(mydiagram.image)
idx

In [None]:
# Shape of the image tensor, needed to convert a flat index into a 2D index.
mydiagram.image.shape

In [None]:
def getpixelsposfrom1dindex(idx, shape=(2018,2016)):
    """Convert a flat (row-major) index into the corresponding 2D index.

    Parameters
    ----------
    idx : int or tensor
        Index into the flattened image, e.g. the result of ``torch.argmax(image)``.
    shape : sequence
        Shape of the 2D image; only ``shape[1]`` (the length of the second
        axis) is used. The default is hard-coded, so pass ``image.shape``
        for images of any other size.

    Returns
    -------
    tuple
        ``(i, j)`` where ``i`` is the index along the first axis and ``j`` the
        index along the second axis, in that order.
        NOTE(review): which of them is called pixel X or pixel Y depends on the
        display convention — confirm before labelling.
    """
    ncols = shape[1]
    return idx // ncols, idx % ncols

# `idx` was computed on `mydiagram.image`, so use that image's real shape
# rather than the hard-coded default.
getpixelsposfrom1dindex(idx, shape=mydiagram.image.shape)

In [None]:
# Switch to the first diagram of the list. `_idx` is rebound here, but
# `mydiagram` still refers to the diagram selected earlier.
_idx = 0
# vmin/vmax restrict the colour scale to the 1000..1200 range once divided by 65535.
diagrams[_idx].plot(plt.figure(layout='tight'), vmin=1000./65535, vmax=1200./65535)
plt.show()
# First bounding box and maximum image value (scaled by 65535) of that diagram.
print(diagrams[_idx].bboxes[0])
print(diagrams[_idx].image.max()*65535)

# GUI image and results browser

In [None]:
from ipywidgets import widgets, interactive, interact_manual
import numpy as np
# Single figure/axes pair reused by the browser below.
myfig, ax = plt.subplots()
# Initial view: the transposed image of the first diagram, narrow colour scale.
ppp = ax.imshow(diagrams[0].image.T, vmin=1000./65535, vmax=1200./65535)
print(diagrams[0].image.max()*65535)
# force=True makes the tensor -> numpy conversion work even if the tensor
# is not a plain CPU tensor.
bboxes = diagrams[0].bboxes.numpy(force=True)
# Draw every bounding box as a closed rectangle: each column of the stacked
# arrays holds the 5 vertices (the first one repeated) of one box.
# Column 0 is used as x, column 1 as y, columns 2 and 3 as the extents along x and y.
roiplot =ax.plot(
        np.vstack((
            bboxes[:, 0],
            bboxes[:, 0]+bboxes[:, 2],
            bboxes[:, 0]+bboxes[:, 2],
            bboxes[:, 0],
            bboxes[:, 0],
        )),
        np.vstack((
            bboxes[:, 1],
            bboxes[:, 1],
            bboxes[:, 1]+bboxes[:, 3],
            bboxes[:, 1]+bboxes[:, 3],
            bboxes[:, 1],
        )),
        color="blue",
        # do not let the rectangles change the view limits set by the image
        scalex=False,
        scaley=False,
    )

def _bbox_outlines(bboxes):
    """Return the (x, y) vertex arrays of the closed rectangle around each bbox.

    `bboxes` is an (n, 4) array; column 0 is used as x, column 1 as y and
    columns 2 and 3 as the extents along x and y. Each returned array has
    shape (5, n): one column per box, the first vertex repeated to close it.
    """
    x0, y0 = bboxes[:, 0], bboxes[:, 1]
    x1, y1 = x0 + bboxes[:, 2], y0 + bboxes[:, 3]
    return np.vstack((x0, x1, x1, x0, x0)), np.vstack((y0, y0, y1, y1, y0))


def plotimage(idx=0):
    """Redraw the shared axes `ax` with image number `idx` and its spot boxes.

    Parameters
    ----------
    idx : int
        Position of the diagram in the global `diagrams` list.

    The axes title is set to the file name of the displayed image.
    """
    # Look the diagram up first so that a bad index leaves the current plot intact.
    di = diagrams[idx]
    print('idx',idx)

    # clear() drops the previous image and every previously drawn box.
    ax.clear()
    ax.imshow(di.image.T, vmin=1000./65535, vmax=1200./65535)

    xs, ys = _bbox_outlines(di.bboxes.numpy(force=True))
    # scalex/scaley=False: the boxes must not change the view limits of the image.
    ax.plot(xs, ys, color="blue", scalex=False, scaley=False)

    ax.set_title('%s' % os.path.basename(di.file))

# w1 = interactive(plotimage, idx=(0,len(diagrams)-1))
# display(w1)



def show(b):
    """Button callback: plot the image whose index is typed in `windex`.

    `b` is the clicked button (unused). Invalid input is reported with a
    message instead of raising inside the widget callback.
    """
    try:
        imageindex = int(windex.value)
    except ValueError:
        print('invalid image index %r: an integer is expected' % windex.value)
        return
    # Negative indices keep their usual Python meaning (counting from the end).
    if not -len(diagrams) <= imageindex < len(diagrams):
        print('image index %d out of range (%d images loaded)' % (imageindex, len(diagrams)))
        return
    plotimage(imageindex)


btn = widgets.Button(description='show')
display(btn)

windex = widgets.Text(value='6',
    description='image index:',
   )
display(windex)

btn.on_click(show)

#from ipywidgets import Button, HBox, VBox




In [None]:
# For every position along axis 1 of the image, the axis-0 index of the maximum.
allidx = torch.argmax(mydiagram.image,dim=0)
# Shape of one region of interest (the first one).
roishape = mydiagram.rois[0].shape

In [None]:
# pixel coordinates of the ROI centers
# NOTE(review): the X, Y (or j, i) axis order claimed by this notebook is not
# established here — confirm against the laueimproc documentation.
mydiagram.centers

In [None]:
# All bounding boxes, then the one at position 703.
# NOTE(review): 703 is hard-coded; this fails if the diagram has fewer spots.
mydiagram.bboxes, mydiagram.bboxes[703]

In [None]:
# Properties of a single ROI, selected by its position in the diagram.
roi_ix= 36

roidata = mydiagram.rois[roi_ix]

# array of intensities, zero padded so that all rois have the same size
# (multiplied by 65535 for display)
print('array of pixel intensity', roidata*65535)
# center of this roi, in pixels
print('center at pixel:', mydiagram.centers[roi_ix])
# the corresponding bbox: 4 values, two anchor coordinates then two box sizes
# NOTE(review): the order of the axes (X/Y vs i/j) is not established here — confirm.
mydiagram.bboxes[roi_ix]

In [None]:
# box sizes of all rois: the last two columns of the bboxes
mydiagram.bboxes[:,2:]

In [None]:
# max and argmax reduced over dim 0, i.e. across the rois: the result holds,
# for each pixel position inside the roi window, the largest value over all
# rois and the index of the roi where it occurs (not one maximum per roi).
mydiagram.rois.max(0) # (0) output shape without the dim 0 

# find maximum and position for all rois

In [None]:
lp = mydiagram

#-------------------------------------------------------
nbrois = len(lp.rois)
# Local indices of the maximum inside each roi (rois is indexed [roi, i, j]).
# Step 1: maximum over axis 1 (i) -> per roi and per j, the best value and its i.
imax2D = lp.rois.max(1)
# Step 2: the j holding the largest of these values, then the matching i.
jmax = imax2D.values.argmax(1)
imax = imax2D.indices[range(nbrois),jmax]
allmax = lp.rois[range(nbrois),imax, jmax]*65535

print("local position of roi's max intensity:\n",imax, jmax)
print("max intensity of each roi",allmax)
# Each local index must be shifted by the anchor of the SAME axis:
# imax (rois axis 1) goes with bbox column 0, jmax (rois axis 2) with column 1.
# NOTE(review): assumes bboxes are (anchor_i, anchor_j, height, width) as in
# the laueimproc documentation — confirm.
pixelY = imax+lp.bboxes[:, 0]
pixelX = jmax+lp.bboxes[:, 1]
print("global pixel position of roi's maxima:\n",pixelX,pixelY)

In [None]:
# Spot object at position 703 (hard-coded position, must exist in the diagram).
lp.spots[703]

In [None]:
# Manual check on one roi: compare the computed maximum with the raw roi
# values, its bbox, its center and the computed global position.
roi_ix = 703
# maximum found for this roi, and the top-left 8x8 corner of the roi (x 65535)
print(allmax[roi_ix],lp.rois[roi_ix][:8,:8]*65535)
# NOTE(review): the axis names printed below are this notebook's assumption — confirm.
print('bbox  Xtopleft, Ytopleft, boxsizeX, boxsizeY',lp.bboxes[roi_ix])
print("centers  pixel X, Y",lp.centers[roi_ix])
print("local imax, jmax (deltaY, deltaX)",imax[roi_ix],jmax[roi_ix])
# 0.01929 is a value typed in by hand, converted here with the 65535 factor
print(pixelX[roi_ix], pixelY[roi_ix], 0.01929*65535)

In [None]:
# axis-inversion problem with the box size!?  because: pixelX (column j), pixelY (row i), box size along rows, box size along columns
# the value displayed in the plot, 0.01929 (*65535 = 1264), does not match 194

In [None]:
# Sort by bbox size: the key is the larger of the two box sizes of each spot,
# in descending order.
idxsort = torch.argsort(mydiagram.bboxes[:,2:].max(1).values, descending=True)
# Passing indices to filter_spots: presumably returns a new diagram with the
# spots reordered accordingly (no inplace flag is given) — confirm.
didi = mydiagram.filter_spots(idxsort, msg='sort by bbox')

In [None]:
# Number of rois in the sorted diagram.
len(didi.rois)

In [None]:
# Select the spots whose larger box size is < 10.
# `cond` is a boolean mask with one entry per spot.
cond= didi.bboxes[:,2:].max(1).values < 10
didi2 = didi.filter_spots(cond, msg='select only small bbox < 10')

In [None]:
# Text summary of the diagram after the size filter.
print(didi2)

In [None]:
# Centers of the spots kept by the size filter.
didi2.centers

In [None]:
# Plot the filtered diagram with a narrow colour scale (vmin/vmax are given
# in the same units as the image values).
didi2.plot(plt.figure(layout="tight", figsize=(8, 8)), vmin=0.0129, vmax=0.02)

In [None]:
# Select the near-square boxes: longest side / shortest side < 1.2.
bboxratios = didi2.bboxes[:,2:].max(1).values/didi2.bboxes[:,2:].min(1).values
condratio= bboxratios < 1.2
# The message is stored with the filter, so it must state the threshold really applied.
didi3 = didi2.filter_spots(condratio, msg='select only  bbox with ratio < 1.2')

In [None]:
# Text summary of the diagram after the aspect-ratio filter.
print(didi3)

In [None]:
# Plot the diagram after the aspect-ratio filter, same colour scale as before.
didi3.plot(plt.figure(layout="tight", figsize=(8, 8)), vmin=0.0129, vmax=0.02)

In [None]:
lp = didi3  # diagram object

#----------------------------------
nbrois = len(lp.rois)
# Local indices of the maximum inside each roi (rois is indexed [roi, i, j]).
# Maximum over axis 1 (i): per roi and per j, the best value and its i index.
imax2D = lp.rois.max(1)
# The j holding the largest of these values, then the matching i.
jmax = imax2D.values.argmax(1)
imax = imax2D.indices[range(nbrois),jmax]
# Value of each roi at its maximum, multiplied by 65535.
allmax = lp.rois[range(nbrois),imax, jmax]*65535

print("local position of roi's max intensity:\n",imax, jmax)
print("max intensity of each roi",allmax)
# Shift the local indices by the bbox anchors: imax with column 0, jmax with column 1.
# NOTE(review): assumes bboxes are (anchor_i, anchor_j, height, width) — confirm.
pixelY = imax+lp.bboxes[:, 0]
pixelX = jmax+lp.bboxes[:, 1]
print("global pixel position of roi's maxima:\n",pixelX,pixelY)
#-------------------------------------
#-------------------------------------

In [None]:
# check  DOES NOT WORK   wrong ROI or center ...
# Compares, for the first roi: computed maximum, top-left 8x8 corner of the roi,
# its center, and a hand-typed value converted with the 65535 factor.
roi_ix = 0
allmax[roi_ix],lp.rois[roi_ix][:8,:8]*65535, lp.centers[roi_ix], 0.02261*65535

In [None]:
# Sanity check: both counts are expected to be the same (one center per roi).
len(lp.centers), len(lp.rois)

In [None]:
# Centers next to the per-roi maxima, for visual comparison.
lp.centers, allmax

In [None]:
import copy
# Keep a copy before `didi3` is modified in place below.
# NOTE(review): copy.copy is a shallow copy; whether it isolates the spot data
# from in-place changes on `didi3` depends on the Diagram class — confirm.
didi3bis = copy.copy(didi3)

In [None]:
# Order of the spots from the largest to the smallest `allmax` value.
sorted_indexs = torch.argsort(allmax, descending=True)
# Same reordering applied twice for comparison: in place on `didi3`, and as a
# new diagram `didi4` built from the copy `didi3bis`.
didi3.filter_spots(sorted_indexs, msg="sorted by intensities", inplace=True) #
didi4 = didi3bis.filter_spots(sorted_indexs, msg="sorted by intensities", inplace=False)

In [None]:
# Check on the in-place sorted diagram: the reordered maximum at position 300
# next to the roi stored at the same position in `didi3`.
roi_ix = 300
allmax[sorted_indexs][roi_ix], didi3.rois[roi_ix]*65535

In [None]:
# Same check on `didi4`, the diagram returned by the non in-place call.
roi_ix = 300
allmax[sorted_indexs][roi_ix], didi4.rois[roi_ix]*65535

In [None]:
# Centers of the sorted diagram next to the maxima in the same order.
didi4.centers, allmax[sorted_indexs]

In [None]:
# Convert a hand-typed image value with the 65535 factor used throughout
# this notebook.
0.04169*65535

In [None]:
# Indices of the spots of diagram `_idx` whose first coordinate of the first
# output of fit_gaussian_em (presumably the fitted mean position — confirm)
# is within 4 of 895.
pos =torch.where(torch.abs(diagrams[_idx].fit_gaussian_em()[0][:,0]-895)<4)
pos[0]

In [None]:
# First output of the gaussian fit for all spots of `didi3`.
didi3.fit_gaussian_em()[0]

In [None]:
# Second output of the gaussian fit for the spot at position 500
# (hard-coded position, must exist in `didi3`).
didi3.fit_gaussian_em()[1][500]

In [None]:
# History of `didi3`; presumably the `msg` texts passed to filter_spots — confirm.
didi3.history

In [None]:
# Print and plot the first two diagrams. The loop header has to be active:
# an indented body without it is an IndentationError.
for diagram in diagrams[:2]:
    print(diagram)
    # colour scale: from 1000/65000 up to 10% of the image maximum
    diagram.plot(plt.figure(layout="tight", figsize=(8, 8)), vmin=1000/65000., vmax=.1*diagram.image.max().item())
    plt.show()

### Basic filtering

In [None]:
from laueimproc.opti.manager import DiagramManager
# Turn on verbose output of the diagram manager.
# NOTE(review): the attribute is set on a freshly created instance; this only
# has a lasting effect if DiagramManager is a singleton — confirm.
DiagramManager().verbose=True

In [None]:
"""Sorted by intensity."""

# Reorder the spots of every diagram from the most to the least intense.
for diagram in tqdm.tqdm(diagrams, smoothing=0.01):
    # Each diagram must be sorted with its OWN intensities: a tensor computed on
    # another diagram (such as `allmax` from the exploration cells) has a
    # different length and a different spot order.
    intensities = diagram.compute_pxl_intensities()
    sorted_indexs = torch.argsort(intensities, descending=True)
    diagram.filter_spots(sorted_indexs, msg="sorted by intensities", inplace=True)
    # print(diagram)
    # diagram.plot(plt.figure(layout="tight", figsize=(8, 8))); plt.show()

In [None]:
"""Define some utils"""

def select_items(criteria: torch.Tensor, threshold: float, nb_min: int) -> torch.Tensor:
    """Select at least `nb_min` items, plus every item with `criteria` >= `threshold`.

    Parameters
    ----------
    criteria : torch.Tensor
        1D tensor of scores, the higher the better.
    threshold : float
        Every item whose score reaches this value is selected.
    nb_min : int
        Minimum number of selected items. When fewer items pass the threshold,
        it is lowered to the score of the `nb_min`-th best item. A value <= 0
        means "no minimum".

    Returns
    -------
    torch.Tensor
        Boolean mask with one entry per item of `criteria` (True = selected).
        Items tied with the `nb_min`-th best are all kept, so more than
        `nb_min` items can be selected.
    """
    nb_items = len(criteria)
    if nb_items <= nb_min:
        # Not enough items to choose from: keep everything, as a mask like below.
        return torch.ones(nb_items, dtype=torch.bool, device=criteria.device)
    if nb_min > 0:
        # Score of the nb_min-th best item; without this guard, nb_min == 0 would
        # read index -1 (the worst item) and select everything.
        kth_best = torch.sort(criteria, descending=True).values[nb_min - 1].item()
        threshold = min(threshold, kth_best)
    return criteria >= threshold

In [None]:
"""Select rotation symetric spots."""

# For every diagram, keep the spots with a rotation-symmetry score >= 0.8,
# and never fewer than the 10 best ones (see `select_items`).
for diagram in tqdm.tqdm(diagrams, smoothing=0.01):
    sym = diagram.compute_rot_sym()  # one score per spot — presumably higher = rounder; confirm
    selection = select_items(sym, 0.8, 10)  # 10 best or round peaks
    diagram.filter_spots(selection, msg="keep circular spots", inplace=True)
    # print(diagram)
    # diagram.plot(plt.figure(layout="tight", figsize=(8, 8))); plt.show()

### Fit spots roi with gaussian

In [None]:
"""Gaussian fit max likelihood."""

# Fit the spots of two diagrams only (positions 99 and 100 of the list).
# NOTE(review): the slice is empty when fewer than 100 diagrams are loaded; the
# loop then does nothing and `mean`, `cov`, `infodict` are never defined.
for diagram in tqdm.tqdm(diagrams[99:101], smoothing=0.01):
    # Three outputs are unpacked; `tol=True` and `eigtheta=True` presumably add
    # the entries of the same name to `infodict` — confirm.
    mean, cov, infodict = diagram.fit_gaussian_em(photon_density=10900.0, tol=True, eigtheta=True)
    print(infodict)
    #confidence = 3.0 * infodict["tol"]  # 99% confidence interval on pixel position
    #selection = select_items(-confidence, -0.5, 10)  # 10 best or position +- 0.5pxl
    #diagram.filter_spots(selection, msg="remove spots with uncertain position", inplace=True)
    # print(diagram)
    # diagram.plot(plt.figure(layout="tight", figsize=(8, 8))); plt.show()

In [None]:
# full width (factor 2) at 3 sigma along the two principal axes (pca1, pca2);
# the third column is the inclination.
# `infodict` is the one left by the last iteration of the fit loop.
# NOTE(review): the square root and the factor 6 are applied to every column,
# including the inclination one, which is then not meaningful — confirm.
2*3*torch.sqrt(infodict['eigtheta'])