In [None]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import hyperspy.api as hs
from scipy.ndimage import gaussian_filter
from joblib import Parallel, delayed, cpu_count, parallel_backend
import time
from matplotlib.backend_bases import MouseButton
from tqdm import tqdm


In [None]:
def gkern(l=256, sig=64, x = 0, y = 0):
    """Return an ``l`` x ``l`` Gaussian kernel normalised so its entries sum to 1.

    Sample coordinates run from -(l-1)/2 to (l-1)/2 along each axis, so
    (x, y) = (0, 0) puts the peak at the array centre. ``x`` shifts the peak
    along columns and ``y`` along rows; ``sig`` is the standard deviation in
    the same (pixel) units.
    """
    half_span = (l - 1) / 2.
    coords = np.linspace(-half_span, half_span, l)

    def profile(offset):
        # 1-D unnormalised Gaussian along one axis, centred at `offset`.
        return np.exp(-0.5 * np.square(coords - offset) / np.square(sig))

    # Rows follow the y profile, columns the x profile.
    kernel = np.outer(profile(y), profile(x))
    return kernel / kernel.sum()

In [None]:

def find_extrema(data, threshold=30E-5, l=254):
    """Return indices of peaks in a 1-D profile that follow a sufficiently steep rise.

    A sample is a candidate peak when ``extrema`` accepts the 5-sample window
    centred on it. The profile is then cut into segments between consecutive
    candidates (the first segment starts at index 0), and a candidate is kept
    when the largest one-step increase (``greatest_derivative``) inside the
    segment from the previous candidate up to -- but not including -- itself
    exceeds ``threshold``.

    Parameters
    ----------
    data : 1-D sequence/array of values.
    threshold : minimum one-step rise a kept peak must be preceded by.
    l : maximum number of window start positions to scan (0 .. l-1); callers
        in this notebook pass the profile length. Only complete 5-sample
        windows are used, so at most ``len(data) - 4`` windows are scanned.

    Returns
    -------
    np.ndarray of the kept peak indices (an empty float array when none).
    """
    window = 5
    # Truncated windows at the tail used to be scanned as well; a 3-sample
    # window reported its centre one index too far right (i + 2 instead of
    # i + 1), so restrict the scan to full windows.
    n_windows = min(l, len(data) - window + 1)
    peak_flags = [extrema(data[i:i + window]) for i in range(n_windows)]
    # Index 0 seeds the first segment so the rise before the first peak is measured.
    points_indices = [0] + [i + window // 2 for i, is_peak in enumerate(peak_flags) if is_peak]

    # Only the seed index: no candidate peaks at all.
    if len(points_indices) < 2:
        return np.array([])

    # NOTE(review): each segment stops just before the peak, so the final step
    # up to the peak itself is not part of the measured rise -- confirm intent.
    derivatives = [
        greatest_derivative(data[start:stop])
        for start, stop in zip(points_indices[:-1], points_indices[1:])
    ]
    extrema_indices = [stop for stop, rise in zip(points_indices[1:], derivatives) if rise > threshold]

    return np.array(extrema_indices)

def extrema(array):
    """Return whether the centre sample dominates both of its sides in aggregate.

    The centre is ``array[len(array) // 2]``. The sums of (centre - neighbour)
    over the samples to its left and over the samples to its right must both
    be strictly positive.
    """
    mid = len(array) // 2
    centre = array[mid]
    left_margin = sum(centre - value for value in array[:mid])
    right_margin = sum(centre - value for value in array[mid + 1:])
    return left_margin > 0.0 and right_margin > 0.0


def greatest_derivative(array):
    """Largest forward difference ``array[i + 1] - array[i]`` in ``array``.

    Returns ``False`` (not a number) when there are fewer than two samples,
    so there is no difference to take.
    """
    if len(array) < 2:
        return False
    return max(nxt - cur for cur, nxt in zip(array[:-1], array[1:]))

def distance(point1, point2):
    """Squared Euclidean distance between two points (no square root is taken).

    Any threshold compared against this value is therefore in squared units.
    """
    return np.square(point1 - point2).sum()


In [None]:
def detect_points(l, image, distance_threshold=25, derivative_threshold=15E-5):
    """Find points that are peaks along both their column and their row.

    Every column ``image[:, col]`` and row ``image[row, :]`` (for indices
    0 .. l-1) is scanned with ``find_extrema``. A pixel is accepted when it is
    a peak in its column profile and also in its row profile, and its squared
    distance (``distance``) to every point accepted so far is at least
    ``distance_threshold``.

    Parameters
    ----------
    l : number of columns/rows to scan, also forwarded to ``find_extrema``;
        presumably the side length of a square image -- confirm for other shapes.
    image : 2-D array to search.
    distance_threshold : minimum squared separation between accepted points.
    derivative_threshold : forwarded to ``find_extrema`` as its ``threshold``.

    Returns
    -------
    np.ndarray of ``[column, row]`` pairs, in acceptance order.
    """
    column_peaks = [find_extrema(image[:, col], derivative_threshold, l) for col in range(l)]
    row_peaks = [find_extrema(image[row, :], derivative_threshold, l) for row in range(l)]

    accepted = []
    for col, rows_in_col in enumerate(column_peaks):
        for row in rows_in_col:
            # Must also be a peak along its own row.
            if col not in row_peaks[row]:
                continue
            candidate = np.array([col, row])
            # Greedy suppression: reject candidates near an already accepted point.
            if all(distance(candidate, kept) >= distance_threshold for kept in accepted):
                accepted.append(candidate)

    return np.array(accepted)

    

In [None]:
# Per-image state for the 100 sampled images (must match the number of images
# drawn in the sampling cell): scatters[i] holds the points detected in image i
# and parameters[i] its [threshold, gauss1, gauss2, gauss3, gauss4].
# An entry of [0, 0, 0, 0, 0] marks an image whose parameters were never set.
scatters = [[] for _ in range(100)]
parameters = [[0, 0, 0, 0, 0] for _ in range(100)]
# Stored as a list like every other entry (it was a tuple, which never compares
# equal to the [0, 0, 0, 0, 0] list sentinel).
parameters[0] = [0.1, 32, 5, 6, 7]

In [None]:
# Display-only comparison (changes no state): shows whether parameters[1]
# currently equals [0.1, 0, 0, 0, 0].
[0.1, 0, 0, 0, 0] == parameters[1]

In [None]:
# Initial detection parameters, also used as the initial slider positions:
# threshold, edge sensitivity (gauss1), DoG minuend (gauss2),
# DoG subtrahend (gauss3) and final blur (gauss4).
threshold, gauss1, gauss2, gauss3, gauss4 = 0.1, 32, 5, 6, 7

# 256x256 pixel-coordinate grids: x holds the column index, y the row index.
x, y = np.meshgrid(np.arange(256), np.arange(256))
mask = np.ones((256, 256))

# Image handles (created in plot) and the index of the image on display.
im0 = im1 = im2 = None
idx = 0

fig = plt.figure(figsize=(10, 10))
# Leave the bottom 40% of the figure for the sliders.
fig.subplots_adjust(bottom=0.4)
def update(val = None):
    """Slider callback: re-run detection on image ``idx`` with the slider values.

    Copies the five slider values into the module-level parameter globals,
    stores the detected points in ``scatters[idx]`` and refreshes the processed
    image (``im1``) and mask (``im2``) plots. ``val`` is ignored.
    """
    global im0, im1, im2, threshold, gauss1, gauss2, gauss3, gauss4, scatters
    current = (
        threshold_slider.val,
        gauss1_slider.val,
        gauss2_slider.val,
        gauss3_slider.val,
        gauss4_slider.val,
    )
    threshold, gauss1, gauss2, gauss3, gauss4 = current

    processed, masked, points = blob_detection(image, current, idx)
    scatters[idx] = points

    im1.set_data(processed)
    im2.set_data(masked)
    fig.canvas.draw_idle()
    
# Every Slider created by plot() is appended here.
sliders = []
# Individual slider handles; plot() replaces these placeholders.
threshold_slider = gauss1_slider = gauss2_slider = gauss3_slider = gauss4_slider = None

def press(event):
    """Key handler for stepping through the image stack.

    "enter"/"right" move to the next image and "left" to the previous one; any
    other key is ignored. Before moving, the current global parameters are
    stored in ``parameters[idx]``. An image visited for the first time (its
    entry is still [0, 0, 0, 0, 0]) starts from the previous image's
    parameters. The sliders are moved to the loaded values, detection is re-run
    and all three panels are redrawn.
    """
    global idx, scatters, parameters, threshold, gauss1, gauss2, gauss3, gauss4
    if event.key in ("enter", "right"):
        step = 1
    elif event.key == "left":
        step = -1
    else:
        # Other keys used to reload parameters[idx] and silently discard any
        # unsaved slider tuning; now they do nothing.
        return

    # Remember the tuning for the image being left. Previously only
    # "enter"/"right" saved it, so "left" lost it even though scatters[idx]
    # had already been computed with those settings by update().
    parameters[idx] = [threshold, gauss1, gauss2, gauss3, gauss4]

    new_idx = idx + step
    # Stay inside the stack: stepping past the end raised IndexError and
    # stepping below 0 wrapped to the last image through negative indexing.
    if not 0 <= new_idx < min(len(image), len(parameters), len(scatters)):
        return
    idx = new_idx
    if step == 1 and parameters[idx] == [0, 0, 0, 0, 0]:
        # First visit: start from the previous image's settings, copied so the
        # two entries do not share one list object.
        parameters[idx] = list(parameters[idx - 1])

    print(idx)
    threshold, gauss1, gauss2, gauss3, gauss4 = parameters[idx]

    # Show the loaded values on the sliders. Without this the sliders kept the
    # previous image's values, and the next slider change (update() reads all
    # five sliders) re-applied them. Events are muted so update() does not run
    # once per slider.
    for slider, value in zip(
        (threshold_slider, gauss1_slider, gauss2_slider, gauss3_slider, gauss4_slider),
        (threshold, gauss1, gauss2, gauss3, gauss4),
    ):
        if slider is None:
            continue
        slider.eventson = False
        try:
            slider.set_val(value)
        finally:
            slider.eventson = True

    processed, masked, points = blob_detection(image, (threshold, gauss1, gauss2, gauss3, gauss4), idx)
    scatters[idx] = points
    im0.set_data(image[idx])
    im1.set_data(processed)
    im2.set_data(masked)

    fig.canvas.draw()
    


def plot(original_image, finished, mask):
    """Build the interactive tuning figure on the global ``fig``.

    Panels (a), (b) and (c) show ``original_image[idx]``, ``finished`` and
    ``mask`` on a symlog grey scale. Five sliders (threshold and the four
    Gaussian sigmas, initialised from the module-level globals) are wired to
    ``update``, and key presses are routed to ``press``.

    Parameters
    ----------
    original_image : image stack; ``original_image[idx]`` is shown in panel (a).
    finished : 2-D processed image (as returned by ``blob_detection``).
    mask : 2-D masked image (as returned by ``blob_detection``).
    """
    global im0, im1, im2, threshold_slider, gauss1_slider, gauss2_slider, gauss3_slider, gauss4_slider

    plt.rcParams['font.size'] = 24
    ax = [fig.add_subplot(131), fig.add_subplot(132), fig.add_subplot(133)]

    # Panel titles used an undefined name `alphabet` (NameError on a fresh
    # kernel); the labels are spelled out here instead.
    for panel, a in zip("abc", ax):
        a.tick_params(left=False, right=False, labelleft=False,
                      labelbottom=False, bottom=False)
        a.set_title("(" + panel + ")")

    im0 = ax[0].imshow(original_image[idx], cmap="Greys_r", norm="symlog")
    im1 = ax[1].imshow(finished, cmap="Greys_r", norm="symlog")
    im2 = ax[2].imshow(mask, cmap="Greys_r", norm="symlog")

    # (bottom edge of the slider axes, label, valmin, valmax, valinit), top to bottom.
    slider_specs = [
        (0.30, 'Thresholding', 0.00, 0.3, threshold),
        (0.25, 'Edge sensitivity', 1, 256, gauss1),
        (0.20, 'Difference of Gauss minuend', 1, 10, gauss2),
        (0.15, 'Difference of Gauss subtrahend', 1, 10, gauss3),
        (0.10, 'Gauss blur', 1, 10, gauss4),
    ]
    new_sliders = [
        Slider(
            ax=fig.add_axes([0.25, bottom, 0.65, 0.03]),
            label=label,
            valmin=vmin,
            valmax=vmax,
            valinit=vinit,
        )
        for bottom, label, vmin, vmax, vinit in slider_specs
    ]
    threshold_slider, gauss1_slider, gauss2_slider, gauss3_slider, gauss4_slider = new_sliders
    sliders.extend(new_sliders)

    for slider in new_sliders:
        slider.on_changed(update)
    fig.canvas.mpl_connect('key_press_event', press)

    plt.show()



In [None]:
def blob_detection(image, parameters, idx = 0, radius = 5, background_value = 0.0):
    """Filter one image, detect peak points in it and mask everything else.

    Pipeline:
    1. Multiply by (1 - normalised Gaussian, sigma 5), which zeroes the image
       where that kernel peaks (the array centre) and damps its surroundings.
    2. Divide by a normalised Gaussian of sigma ``gauss1``, boosting pixels far
       from the centre.
    3. Difference of Gaussians: blur(sigma ``gauss2``) - blur(sigma ``gauss3``).
    4. Blur the result with sigma ``gauss4``.
    5. Find points with ``detect_points`` (squared-distance threshold 40).

    Parameters
    ----------
    image : 2-D image, or 3-D stack from which ``image[idx]`` is taken.
        Assumed square, as ``gkern`` and ``detect_points`` build l x l data.
    parameters : (threshold, gauss1, gauss2, gauss3, gauss4).
    idx : index into a 3-D stack; ignored for 2-D input.
    radius : radius in pixels of the disc kept around each detected point.
        The parameter used to be ignored in favour of a hard-coded
        ``distance**2 <= 25``; the default 5 reproduces that behaviour.
    background_value : value given to pixels outside every disc.

    Returns
    -------
    (processed, mask, scatters): the filtered image, a copy of the input with
    pixels outside the discs set to ``background_value``, and the detected
    ``[column, row]`` points.
    """
    if len(image.shape) == 3:
        image = image[idx]

    size = image.shape
    original_image = image.copy()

    threshold, gauss1, gauss2, gauss3, gauss4 = parameters

    # Kernels are built at the image's own side length (gkern defaults to 256).
    kern = gkern(l=size[0], sig=5)
    image = image * (1 - kern / np.max(kern))

    # NOTE(review): for small gauss1 the kernel's corner values can underflow
    # to 0, making this division produce inf -- consider clamping.
    kern = gkern(l=size[0], sig=gauss1)
    image = image / (kern / np.max(kern))

    dog = gaussian_filter(image, sigma=gauss2)
    dog -= gaussian_filter(image, sigma=gauss3)

    image = gaussian_filter(dog, sigma=gauss4)

    scatters = detect_points(size[0], image, distance_threshold=40, derivative_threshold=threshold)

    # Pixel-coordinate grids for this image (column index -> x, row index -> y),
    # built locally instead of reading the notebook-global 256x256 `x`, `y`.
    rows, cols = np.indices(size)
    keep = np.zeros(size, dtype=bool)
    for point in scatters:
        keep |= (cols - point[0]) ** 2 + (rows - point[1]) ** 2 <= radius ** 2

    mask = np.where(keep, original_image, background_value)
    return image, mask, scatters


In [None]:
def load_signal(filename, lazy = True):
    """Open ``filename`` with HyperSpy's ``hs.load`` and return the result.

    ``lazy`` is forwarded unchanged to ``hs.load``.
    """
    loaded = hs.load(filename, lazy=lazy)
    return loaded


In [None]:
# Replace the `...` placeholder with the path of the data file to open.
signal = load_signal(filename=..., lazy=True)

In [None]:
# Stack the signal images at 100 random navigation positions for manual tuning
# (the count must match the length of `scatters` / `parameters`).
# A fixed seed makes the selection reproducible, and the bounds come from the
# signal itself instead of assuming a 256x256 scan.
# Assumes a 2-D navigation space given in (x, y) order, matching `inav`.
rng = np.random.default_rng(0)
nav_x, nav_y = signal.axes_manager.navigation_shape
data = np.array([
    signal.inav[int(rng.integers(nav_x)), int(rng.integers(nav_y))].data
    for _ in range(100)
])

In [None]:
# `image` is the global stack that update() and press() read from.
image = data
processed_image, mask, _ = blob_detection(image, (threshold, gauss1, gauss2, gauss3, gauss4), idx=0)
plot(original_image=image, finished=processed_image, mask=mask)

In [None]:
# The per-image point lists are ragged (different point counts, and [] for
# unvisited images), which np.save cannot turn into a regular array
# (NumPy >= 1.24 raises on inhomogeneous shapes). Store them in a 1-D object
# array instead; read back with np.load(path, allow_pickle=True).
scatters_array = np.empty(len(scatters), dtype=object)
for slot, points in enumerate(scatters):
    scatters_array[slot] = points
np.save(..., scatters_array)

In [None]:
import pickle
def save_masks(x_points, signal):
    """Write the detected points and the image stack to disk.

    ``x_points`` is pickled to one file and ``signal`` is written with
    ``np.save`` to another; both paths are ``...`` placeholders here.
    """
    with open(..., "wb") as handle:
        pickle.dump(x_points, handle)

    np.save(..., signal)

In [None]:
save_masks(x_points=scatters, signal=data)