In [None]:
%matplotlib qt
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.widgets import Slider
import hyperspy.api as hs
from scipy.ndimage import gaussian_filter
from joblib import Parallel, delayed, cpu_count, parallel_backend
import time
from shapely.geometry import Point, Polygon
from matplotlib.backend_bases import MouseButton
from tqdm import tqdm
import pickle 
# Panel labels used for subplot titles: "(a)", "(b)", ...
alphabet = list("abcdefghijkl")


In [None]:
def gkern(l=256, sig=64, x = 0, y = 0):
    """Return an ``l`` x ``l`` separable Gaussian kernel normalised to sum to 1.

    Pixel coordinates run from -(l-1)/2 to (l-1)/2 along both axes, so
    ``x``/``y`` are offsets of the peak from the array centre (column/row).

    Parameters
    ----------
    l : int
        Side length of the square kernel.
    sig : float
        Standard deviation of the Gaussian, in pixels.
    x, y : float
        Horizontal (column) and vertical (row) offset of the peak.
    """
    half_width = (l - 1) / 2.
    coords = np.linspace(-half_width, half_width, l)

    def profile(centre):
        # Un-normalised 1-D Gaussian along one axis.
        return np.exp(-0.5 * np.square(coords - centre) / np.square(sig))

    # Rows follow y, columns follow x.
    separable = np.outer(profile(y), profile(x))
    return separable / separable.sum()

In [None]:

def find_extrema(data, threshold=30E-5, l=254):
    """Return indices of local maxima in a 1-D profile whose rising edge is steep enough.

    A sample is a candidate when it exceeds, in aggregate, the two samples on
    each side of it (see ``extrema``, applied to a 5-sample window). Candidates
    are then filtered: walking from index 0, each candidate is kept only if the
    largest single-step increase between the previous candidate (or index 0)
    and it -- including the step into the candidate itself -- exceeds
    ``threshold``.

    Parameters
    ----------
    data : 1-D sequence of numbers
        The profile (e.g. one row or column of a filtered image).
    threshold : float
        Minimum largest forward difference required to accept a candidate.
    l : int
        Maximum number of 5-sample windows to scan. Windows that would run past
        the end of ``data`` are skipped, so only indices 2 .. len(data)-3 can be
        reported.

    Returns
    -------
    numpy.ndarray
        Accepted indices in increasing order (empty array if none).
    """
    half = 2
    window = 2 * half + 1
    # Only full-width windows: a truncated window has its centre at n//2 < half,
    # which previously made the recorded index (i + 2) disagree with the sample
    # that was actually tested.
    n_windows = max(0, min(l, len(data) - window + 1))
    candidates = [i + half for i in range(n_windows) if extrema(data[i:i + window])]

    boundaries = [0] + candidates
    # Slice up to and including `end` so the step into the extremum is counted.
    extrema_indices = [
        end
        for start, end in zip(boundaries[:-1], boundaries[1:])
        if greatest_derivative(data[start:end + 1]) > threshold
    ]
    return np.array(extrema_indices)

def extrema(array):
    """Return True if the middle sample exceeds both halves of ``array`` in aggregate.

    With ``mid = len(array) // 2`` the test is
    ``sum(array[mid] - a for a in array[:mid]) > 0`` and the same for
    ``array[mid + 1:]``. This is looser than a strict local maximum: a single
    larger neighbour can be outweighed by a smaller one. An empty side sums to
    0 and therefore fails. Raises IndexError on an empty ``array``.
    """
    mid = len(array) // 2
    peak = array[mid]
    left, right = array[:mid], array[mid + 1:]
    return all(sum(peak - other for other in side) > 0.0 for side in (left, right))


def greatest_derivative(array):
    """Return the largest forward difference ``array[k+1] - array[k]``.

    Returns ``False`` when ``array`` has fewer than two samples (no
    difference exists).
    """
    if len(array) < 2:
        return False
    return max(nxt - cur for cur, nxt in zip(array, array[1:]))

def distance(point1, point2):
    """Return the SQUARED Euclidean distance between two points (no square root).

    Callers compare the result against thresholds that are therefore also
    squared distances.
    """
    delta = point1 - point2
    return np.square(delta).sum()


In [None]:
def detect_points(l, image, distance_threshold=25, derivative_threshold=15E-5):
    """Find pixels that are peaks along both their column and their row.

    Every column ``image[:, col]`` and row ``image[row, :]`` for indices
    ``0 .. l-1`` is scanned with ``find_extrema``. A pixel is accepted when it is
    a peak in both directions and is not too close to an already-accepted pixel
    (greedy, in column-major scan order).

    Parameters
    ----------
    l : int
        Number of rows/columns scanned and window count passed to
        ``find_extrema``; the image is assumed to be at least l x l (callers
        pass the image height).
    image : 2-D array
    distance_threshold : float
        Minimum SQUARED distance (see ``distance``) between accepted points.
    derivative_threshold : float
        Forwarded to ``find_extrema`` as its ``threshold``.

    Returns
    -------
    numpy.ndarray
        (k, 2) array of ``[x, y] = [column, row]`` points, or an empty array.
    """
    column_peaks = [find_extrema(image[:, col], derivative_threshold, l) for col in range(l)]
    row_peaks = [find_extrema(image[row, :], derivative_threshold, l) for row in range(l)]

    accepted = []
    for col, peaks_in_col in enumerate(column_peaks):
        for row in peaks_in_col:
            # Must also be a peak when scanning along its own row.
            if col not in row_peaks[row]:
                continue
            candidate = np.array([col, row])
            if all(distance(candidate, kept) >= distance_threshold for kept in accepted):
                accepted.append(candidate)

    return np.array(accepted)

    

In [None]:
def create_circular_mask(h, w, center=None, radius=None):
    """Return an ``h`` x ``w`` boolean array that is True inside a disc.

    Parameters
    ----------
    h, w : int
        Height (rows) and width (columns) of the mask.
    center : pair, optional
        ``(x, y)`` = (column, row) of the disc centre; defaults to the image middle.
    radius : float, optional
        Disc radius in pixels (boundary inclusive); defaults to the largest
        radius that stays inside the image from ``center``.
    """
    if center is None:
        center = (int(w/2), int(h/2))
    cx, cy = center[0], center[1]
    if radius is None:
        radius = min(cx, cy, w - cx, h - cy)

    rows, cols = np.ogrid[:h, :w]
    return ((cols - cx)**2 + (rows - cy)**2)**0.5 <= radius

In [None]:
# Initial pipeline parameters (threshold, gauss1..gauss4). plot() uses them as the
# sliders' starting values and update() overwrites them with the slider values.
threshold, gauss1, gauss2, gauss3, gauss4 = 0.1,32,5,6,7

# Artists and figure shared between plot() (which assigns them) and update().
im1, im2, im3 = None, None, None 
fig = None
def update(val):
    """Slider callback: store the current slider values and redraw every panel.

    Copies each slider value into the module-level parameter globals, reruns
    blob_detection on the global ``image`` and refreshes im1 (filtered image),
    im2 (detected points) and im3 (mask), then requests a redraw.

    Parameters
    ----------
    val : float
        Value of the slider that moved (unused; all sliders are re-read).
    """
    global im1, im2, threshold, gauss1, gauss2, gauss3, gauss4, fig
    threshold = threshold_slider.val
    gauss1 = gauss1_slider.val
    gauss2 = gauss2_slider.val
    gauss3 = gauss3_slider.val
    gauss4 = gauss4_slider.val
    # `data` and `mask` here are locals; they do not touch the module-level VBF `data`.
    # NOTE(review): blob_detection returns a single mask array by default, so this
    # three-way unpack fails unless it returns (filtered, mask, scatters) — TODO reconcile.
    data, mask, scatters = blob_detection(image, (threshold_slider.val, gauss1_slider.val, gauss2_slider.val, gauss3_slider.val, gauss4_slider.val))
    im1.set_data(data)

    print("scatters:",scatters)
    if len(scatters) != 0:
        # Columns of scatters are (x, y) = (column, row).
        im2.set_data(scatters[:,0], scatters[:,1])
    else:
        # No detections: keep one placeholder marker (presumably the centre of a
        # 256 px pattern) instead of clearing the line.
        im2.set_data([128], [128])

    
    im3.set_data(mask)  # refresh the masked-pattern panel
    

    fig.canvas.draw_idle()
    
# Every Slider ever created. Matplotlib widgets must stay referenced to remain
# responsive, so old sliders are kept alive here too.
sliders = []
threshold_slider = None
gauss1_slider = None
gauss2_slider = None
gauss3_slider = None
gauss4_slider = None
def plot(original_image, finished, mask, scatters):
    """Open the four-panel parameter-tuning figure with one slider per parameter.

    Panels: (a) original pattern, (b) filtered image, (c) original pattern with
    the detected points overlaid as stars, (d) masked pattern. Every slider's
    ``on_changed`` is wired to ``update``. The artists (im1, im2, im3), the
    figure and the sliders are stored in module-level globals so ``update`` can
    reach them.

    Parameters
    ----------
    original_image : 2-D array
        Raw diffraction pattern.
    finished : 2-D array
        Filtered image produced by the blob-detection pipeline.
    mask : 2-D array
        Masked pattern produced by the blob-detection pipeline.
    scatters : array
        (k, 2) array of (x, y) points, or an empty array.
    """
    global im1, im2, im3, fig, threshold_slider, gauss1_slider, gauss2_slider, gauss3_slider, gauss4_slider
    fig = plt.figure(figsize=(10, 10))
    fig.subplots_adjust(bottom=0.4)  # leave room below the panels for the sliders
    plt.rcParams['font.size'] = 24
    ax = [fig.add_subplot(1, 4, k + 1) for k in range(4)]

    for letter, a in zip(alphabet, ax):
        a.tick_params(left=False, right=False, labelleft=False,
                      labelbottom=False, bottom=False)
        a.set_title("(" + letter + ")")

    ax[0].imshow(original_image, cmap="Greys_r", norm="symlog")
    im1 = ax[1].imshow(finished, cmap="Greys_r", norm="symlog")
    ax[2].imshow(original_image, cmap="Greys_r", norm="symlog")

    print("scatters:", scatters)
    if len(scatters) != 0:
        xs, ys = scatters[:, 0], scatters[:, 1]
    else:
        # No detections: one placeholder marker (presumably the centre of a
        # 256 px pattern) so that im2 exists for update() to modify.
        xs, ys = [128], [128]
    im2, = ax[2].plot(xs, ys, marker="*", c="#ff7f0e", ls='', markersize=10)

    im3 = ax[3].imshow(mask, cmap="Greys_r", norm="symlog")

    # (label, valmin, valmax, valinit, bottom edge of the slider axes)
    slider_specs = [
        ('Thresholding', 0.00, 0.3, threshold, 0.3),
        ('Edge sensitivity', 1, 256, gauss1, 0.25),
        ('Difference of Gauss minuend', 1, 10, gauss2, 0.2),
        ('Difference of Gauss subtrahend', 1, 10, gauss3, 0.15),
        ('Gauss blur', 1, 10, gauss4, 0.1),
    ]
    created = []
    for label, vmin, vmax, vinit, bottom in slider_specs:
        slider = Slider(
            ax=fig.add_axes([0.25, bottom, 0.65, 0.03]),
            label=label,
            valmin=vmin,
            valmax=vmax,
            valinit=vinit,
        )
        slider.on_changed(update)
        created.append(slider)

    threshold_slider, gauss1_slider, gauss2_slider, gauss3_slider, gauss4_slider = created
    sliders.extend(created)

    plt.show()

In [None]:
def blob_detection(image, parameters, radius = 3, background_value = 0.0, return_all = False):
    """Detect bright spots in a square 2-D pattern and mask out everything else.

    Pipeline:
      1. Divide by a peak-normalised Gaussian of width ``gauss1`` centred on the
         image. Its maximum is 1 at the centre, so outer pixels are amplified
         ("edge sensitivity").
      2. Difference of Gaussians: blur with sigma ``gauss2`` minus blur with
         sigma ``gauss3``.
      3. Smooth the result with sigma ``gauss4``.
      4. ``detect_points`` on the filtered image (``threshold`` is its
         derivative threshold; the squared-distance threshold is 40, i.e. points
         closer than about 6.3 px are suppressed).
      5. Keep the ORIGINAL pixel values within ``radius`` of each point and set
         every other pixel to ``background_value``.

    Parameters
    ----------
    image : 2-D array
        Square pattern (height == width).
    parameters : sequence of 5 numbers
        ``(threshold, gauss1, gauss2, gauss3, gauss4)``.
    radius : float
        Radius in pixels of the disc kept around each detected point.
    background_value : float
        Fill value outside the discs.
    return_all : bool
        If True, return ``(filtered, mask, scatters)`` instead of only ``mask``.
        This is the form expected by the interactive tuning code.

    Returns
    -------
    mask : 2-D array
        Masked pattern (default).
    (filtered, mask, scatters) : tuple
        When ``return_all`` is True. ``scatters`` is a (k, 2) array of (x, y)
        points or an empty array.

    Raises
    ------
    ValueError
        If ``image`` is not a square 2-D array.
    """
    size = image.shape
    if len(size) != 2 or size[0] != size[1]:
        # gkern and detect_points both assume a single side length.
        raise ValueError(f"blob_detection expects a square 2-D image, got shape {size}")

    threshold, gauss1, gauss2, gauss3, gauss4 = parameters

    # 1. Radial flattening. The kernel now matches the image size instead of
    #    being hard-wired to 256 px.
    kern = gkern(l=size[0], sig=gauss1)
    flattened = image / (kern / np.max(kern))

    # 2. Difference of Gaussians (band-pass).
    dog = gaussian_filter(flattened, sigma=gauss2) - gaussian_filter(flattened, sigma=gauss3)

    # 3. Final smoothing.
    filtered = gaussian_filter(dog, sigma=gauss4)

    scatters = detect_points(size[0], filtered, distance_threshold=40, derivative_threshold=threshold)

    # 5. Union of discs around every detected point. `image` itself is never
    #    modified above, so no defensive copy is needed.
    keep = np.zeros(size, dtype=bool)
    for point in scatters:
        keep |= create_circular_mask(size[0], size[1], center=point, radius=radius)

    mask = np.where(keep, image, background_value)
    if return_all:
        return filtered, mask, scatters
    return mask

In [None]:
# Load the dataset eagerly (lazy=False) so signal.data is held in memory.
# NOTE(review): the file path is elided (`...`) in this copy — fill it in before running.
# Presumably a 4-D stack (nav_y, nav_x, det_y, det_x): the VBF cell sums over axes 2 and 3 — confirm.
signal = hs.load(..., lazy=False)


In [None]:
# Virtual bright-field (VBF) image: for every probe position, sum the detector
# intensity inside a small disc (squared radius < 65 px^2) around the detector centre.
det_h, det_w = signal.data.shape[-2:]
x, y = np.meshgrid(np.arange(det_w), np.arange(det_h))
mask = np.zeros((det_h, det_w))
mask[(x - det_w // 2)**2 + (y - det_h // 2)**2 < 65] = 1
# Index only the in-disc pixels. np.where(mask == 1, signal.data, 0.0) would
# materialise a full-size copy of the whole 4-D dataset just to zero most of it.
data = signal.data[..., mask == 1].sum(axis=-1, dtype=np.float64)


In [None]:
# Quick look at the virtual bright-field image.
vbf_fig, vbf_ax = plt.subplots()
vbf_ax.imshow(data, cmap="Greys_r")
vbf_ax.axis("off")

In [None]:
signal.plot(norm = "symlog")
# Separate figure on which regions of the VBF image are outlined by hand:
# left click adds a vertex, right click closes the current polygon.
fig, ax = plt.subplots()
ax.imshow(data, cmap = "Greys_r")
ax.axis("off")


def on_move(event):
    """Placeholder motion hook; disconnected after the first redraw."""
    if event.inaxes:
        pass


# x_points[k] / y_points[k] are the vertices of curve k. Every entry except the
# last is a closed polygon; the last one is the curve currently being drawn.
x_points = [[]]
y_points = [[]]
curves = 1  # always equals len(x_points)
line = None


def _redraw_curves():
    """Redraw the VBF image with closed polygons filled and the open curve as a line."""
    ax.cla()
    ax.imshow(data, cmap = "Greys_r")
    ax.axis("off")
    for c in range(curves):
        if c == curves - 1:
            ax.plot(x_points[c], y_points[c])
        else:
            ax.fill(x_points[c], y_points[c], alpha = 0.3)
    # Use the axes' own figure: the global `fig` is rebound later by plot().
    ax.figure.canvas.draw()


def on_click(event):
    """Left click: append a vertex. Right click: close the current polygon and start a new one."""
    global curves, binding_id
    if event.button is MouseButton.LEFT:
        if event.xdata is None or event.ydata is None:
            return  # click landed outside the axes
        x_points[-1].append(event.xdata)
        y_points[-1].append(event.ydata)
    elif event.button is MouseButton.RIGHT:
        # A polygon needs at least three vertices. This also prevents an
        # IndexError when right-clicking on an empty curve.
        if len(x_points[-1]) < 3:
            return
        x_points[-1].append(x_points[-1][0])
        y_points[-1].append(y_points[-1][0])

        x_points.append([])
        y_points.append([])

        curves += 1
    else:
        return

    _redraw_curves()
    # Drop the no-op motion hook once, on the first redraw.
    if binding_id is not None:
        ax.figure.canvas.mpl_disconnect(binding_id)
        binding_id = None


binding_id = fig.canvas.mpl_connect('motion_notify_event', on_move)
fig.canvas.mpl_connect('button_press_event', on_click)

plt.show()

In [None]:
# Rasterise the closed polygons into a label map aligned with the VBF image `data`:
# map[row, col] = k + 1 for the first polygon k containing pixel (x=col, y=row), else 0.
# The last curve is the unfinished one, so only curves 0 .. curves-2 are used.
poly = [Polygon(list(zip(x_points[i], y_points[i]))) for i in range(curves - 1)]
# NOTE: `map` shadows the builtin but later cells read it under this name.
# Sized from `data` (the image the polygons were drawn on); the previous
# (256, 128) array raised IndexError once j reached 128.
map = np.zeros(data.shape)
n_rows, n_cols = data.shape
for i in tqdm(range(n_rows)):
    for j in range(n_cols):
        pixel = Point(j, i)
        for k, region in enumerate(poly):
            if pixel.within(region):
                map[i, j] = k + 1
                break
np.save(..., map)

In [None]:
# One independent [threshold, gauss1, gauss2, gauss3, gauss4] list per closed region.
# A comprehension avoids `[[...]] * n`, which aliases a single inner list n times.
parameters = [[0.0, 0.0, 0.0, 0.0, 0.0] for _ in range(curves - 1)]

In [None]:
# Pick region i, load one representative diffraction pattern from it and open the tuning UI.
i = 0

# representative_point() is guaranteed to lie inside the polygon; the centroid
# of a non-convex hand-drawn region can fall outside it.
inside = poly[i].representative_point()
s = signal.inav[int(inside.x), int(inside.y)]
# NOTE(review): compute() is normally a lazy-signal method and `signal` was loaded
# with lazy=False — confirm this call is valid/needed.
s.compute()
image = s.data
# NOTE(review): blob_detection returns a single mask array by default, so this
# three-way unpack fails unless it returns (filtered, mask, scatters) — TODO reconcile.
processed_image, mask, scatters = blob_detection(image, (threshold, gauss1, gauss2, gauss3, gauss4))
plot(image, processed_image, mask, scatters)
# Stores the INITIAL slider values; run the next cell after tuning to store the adjusted ones.
parameters[i] = [threshold, gauss1, gauss2, gauss3, gauss4]

In [None]:
# Run after tuning the sliders: update() mirrors the slider values into these
# globals, so this records the tuned parameters for region i.
tuned = [threshold, gauss1, gauss2, gauss3, gauss4]
parameters[i] = tuned

In [None]:
# Persist the per-region parameters as a (regions, 5) array; the path is elided in this copy.
np.save(..., np.asarray(parameters))

In [None]:
def convourt_entire_signal(signal_data, parameters, points):
    """Serially run ``blob_detection`` on every probe position listed in ``points``.

    Parameters
    ----------
    signal_data : array
        Dataset indexed as ``signal_data[row, col]`` -> one 2-D pattern.
    parameters : sequence of 5 numbers
        Passed to ``blob_detection``.
    points : tuple of two index arrays
        ``(rows, cols)`` of equal length, e.g. as returned by ``np.where``.

    Returns
    -------
    list
        ``blob_detection`` results in the order of ``points``.
    """
    rows, cols = points[0], points[1]
    # Pair indices element-wise; the previous nested loop visited the Cartesian
    # product of all rows and all columns, i.e. positions outside the region.
    return [blob_detection(signal_data[r, c], parameters) for r, c in zip(rows, cols)]


def convourt_entire_signal_parallel(signal_data, parameters, points, n_jobs=-1):
    """Parallel version of ``convourt_entire_signal`` using joblib.

    Parameters are as for ``convourt_entire_signal``. ``n_jobs`` is forwarded to
    ``joblib.Parallel`` (-1 = all cores). Results keep the order of ``points``.
    """
    rows, cols = points[0], points[1]
    return Parallel(n_jobs=n_jobs)(
        delayed(blob_detection)(signal_data[r, c], parameters) for r, c in zip(rows, cols)
    )

In [None]:
# One result slot per region, and the (rows, cols) probe positions of each region
# (map labels are 1-based: region c lives at label c + 1).
results = [None for _ in parameters]
points = [np.where(map == label) for label in range(1, len(parameters) + 1)]

In [None]:
# Batch-process every region: results[c] receives the blob_detection output for each
# probe position labelled c + 1 in `map`. All regions are iterated; the old
# range(len(parameters) - 1) skipped the last one and left results[-1] as None.
for c in range(len(parameters)):
    start_time = time.perf_counter()  # monotonic, high-resolution timer
    results[c] = convourt_entire_signal_parallel(signal.data, parameters[c], points[c])
    end_time = time.perf_counter()
    print("Time taken:", end_time - start_time)


In [None]:
# Persist the per-region result lists (the path is elided in this copy).
# NOTE: unpickling executes code — only load this file if you produced it yourself.
with open(..., 'wb') as results_file:
    pickle.dump(results, results_file)