# Imports

In [None]:
import cv2
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.utils.data
import os
from PIL import Image
from scipy.ndimage import label

In [None]:
from autoscript_tem_microscope_client import TemMicroscopeClient
from autoscript_tem_microscope_client.enumerations import *
from autoscript_tem_microscope_client.structures import *

# Acquire image

In [None]:
microscope = TemMicroscopeClient()
microscope.connect()
try:
    # HAADF STEM image, 512x512 preset, 1e-6 dwell time
    image = microscope.acquisition.acquire_stem_image(DetectorType.HAADF, ImageSize.PRESET_512, 1e-6)
finally:
    # Release the client connection even if the acquisition raises.
    microscope.disconnect()

# Plot the acquired image. AutoScript acquisitions return an image object whose pixels are
# exposed as `.data` (presumably an AdornedImage — TODO confirm); plain arrays are used as-is.
image_data = image if isinstance(image, np.ndarray) else np.asarray(image.data)
plt.figure(figsize=(8, 6))
plt.imshow(image_data, cmap='gray')
plt.title("Image")
plt.axis('off')
plt.show()

' Code for plot image'

# NP detection

In [None]:
# Load every PyTorch checkpoint (*.pt) in the working directory, in a stable (sorted) order.
# Match the full ".pt" extension: endswith('pt') would also pick up e.g. "script" or "receipt".
# NOTE: torch.load unpickles arbitrary Python objects — only load checkpoints from trusted sources.
files = sorted(file for file in os.listdir() if file.endswith('.pt'))
print(files)
model_list = [torch.load(file) for file in files]

In [None]:
# Checkpoint used by the detection functions below (placeholder — replace with the real path).
# NOTE: torch.load unpickles the file, so only point this at a trusted checkpoint.
model_path = "model path/name"
model = torch.load(model_path)

In [None]:
def isolate_particles(pred_mask, threshold=0.5):
    """Split a probability map into individually labelled connected particles.

    Args:
        pred_mask: numpy array of per-pixel scores.
        threshold: pixels strictly greater than this value count as foreground.

    Returns:
        (labeled_mask, num_particles): ``labeled_mask`` is 0 on background and
        1..num_particles on each connected component found by ``scipy.ndimage.label``
        with its default structuring element; ``num_particles`` is the component count.
    """
    # Binarise the probability map, then group touching foreground pixels.
    foreground = (pred_mask > threshold).astype(np.uint8)
    labels, count = label(foreground)
    return labels, count

In [None]:
def get_square_bbox(indices, image_shape, ratio = 1, min_size = 0.2):
    """Return a square box around one particle, scaled by ``ratio`` and kept inside the image.

    Args:
        indices: (rows, cols) pixel coordinates of the particle, as returned by ``np.where``.
        image_shape: image shape; only the first two entries (height, width) are used.
        ratio: factor applied to the particle's larger extent (max - min) to get the box side.
        min_size: boxes whose side is <= this value are discarded.

    Returns:
        ``(start_row, start_col, side)`` as plain ints, or ``None`` when the particle has no
        pixels or its box is too small. The box is centred on the particle where possible and
        shifted so it never extends past the image borders. If it would be larger than the
        image, it is shrunk to the smaller image dimension.
    """
    rows, cols = indices
    if len(rows) == 0 or len(cols) == 0:
        return None  # no pixels carry this label

    # Particle limits. The extent below is max - min, i.e. one less than the pixel span.
    min_row, max_row = int(np.min(rows)), int(np.max(rows))
    min_col, max_col = int(np.min(cols)), int(np.max(cols))
    side = int(max(max_col - min_col, max_row - min_row) * ratio)
    if side <= min_size:
        return None

    height, width = int(image_shape[0]), int(image_shape[1])
    # A box larger than the image would get a negative start index, and slicing with it
    # wraps around or comes out empty. Cap the side at the smaller image dimension.
    side = min(side, height, width)
    half_side = side // 2

    # Centre of the particle
    row_center = (min_row + max_row) // 2
    col_center = (min_col + max_col) // 2

    # Centre the box on the particle, then clamp it into [0, height) x [0, width).
    start_row = min(max(row_center - half_side, 0), height - side)
    start_col = min(max(col_center - half_side, 0), width - side)

    return (start_row, start_col, side)

In [None]:
def plot_particle(image, bbox, title="Partícula"):
    """Show the square crop ``bbox`` of ``image`` in its own grayscale figure.

    Args:
        image: array supporting 2-D ``[row, col]`` slicing.
        bbox: ``(start_row, start_col, side)`` as produced by ``get_square_bbox``.
        title: figure title.
    """
    top, left, size = bbox
    crop = image[top:top + size, left:left + size]

    fig, ax = plt.subplots()
    ax.imshow(crop, cmap='gray')
    ax.set_title(title)
    ax.axis('off')
    plt.show()

In [None]:
def detect_and_plot(microscope, th = 0.5):
    """Acquire a HAADF STEM image, segment nanoparticles (NPs) with ``model`` and plot the results.

    Figures are shown in this order:
        1. One figure with the global image, the model prediction and the global image
           with a square drawn around every detected NP.
        2. One crop per detected NP (NP 1, NP 2, ...).

    Args:
        microscope: connected TemMicroscopeClient used for the acquisition.
        th: probability threshold passed to ``isolate_particles``.

    Uses the notebook-level ``model`` loaded in an earlier cell.
    """
    image = microscope.acquisition.acquire_stem_image(DetectorType.HAADF, ImageSize.PRESET_512, 1e-6)
    # AutoScript acquisitions return an image object whose pixels are exposed as `.data`
    # (presumably an AdornedImage — TODO confirm); plain arrays are used as-is.
    frame = image if isinstance(image, np.ndarray) else np.asarray(image.data)

    pred = model.predict(image)
    # NOTE(review): assumes predict() returns a batch whose first entry is the probability map.
    pred_mask = np.squeeze(pred[0])
    labeled_mask, num_particles = isolate_particles(pred_mask, threshold=th)

    # matplotlib reads 3-channel arrays as RGB. With a BGR copy, (0, 0, 255) would render
    # blue, so build an RGB overlay and draw the squares in red (255, 0, 0).
    # NOTE(review): cvtColor only accepts 8-bit, 16-bit or float32 pixels — confirm detector dtype.
    image_color = cv2.cvtColor(frame, cv2.COLOR_GRAY2RGB)

    particles = []  # (label, bbox) of every NP whose box survived get_square_bbox
    for x in range(1, num_particles + 1):
        bbox = get_square_bbox(np.where(labeled_mask == x), frame.shape, 2, 0.2)
        if bbox is None:
            continue
        particles.append((x, bbox))
        start_row, start_col, side = (int(v) for v in bbox)
        cv2.rectangle(image_color, (start_col, start_row), (start_col + side, start_row + side), (255, 0, 0), 2)

    print(f"Num de partículas detectadas: {num_particles}")

    fig, axes = plt.subplots(1, 3, figsize=(10, 5))

    axes[0].imshow(frame, cmap='gray')
    axes[0].set_title("Imagen original")
    axes[0].axis('off')

    axes[1].imshow(labeled_mask, cmap='gray')
    axes[1].set_title("Predicción del modelo")
    axes[1].axis('off')

    axes[2].imshow(image_color)
    axes[2].set_title("Imagen con partículas en cuadrados")
    axes[2].axis('off')

    plt.tight_layout()
    plt.show()

    # Individual crops come after the overview, as described in the docstring.
    for x, bbox in particles:
        plot_particle(frame, bbox, title=f"Partícula {x}")

' \nCode for detect and plot:\n    Global image\n    Prediction\n    Global image with squares\n    Crop NP 1\n    Crop NP 2\n    ...\n'

In [None]:
microscope = TemMicroscopeClient()
microscope.connect()
try:
    detect_and_plot(microscope)
finally:
    # Release the client connection even if acquisition or detection raises.
    microscope.disconnect()

# NP acquisition

In [None]:
def detect_and_save(microscope, index = 1, th = 0.5):
    """Acquire an overview, detect the NPs, then visit each NP and save a centred image of it.

    Steps:
        1. Take a HAADF STEM overview image.
        2. Detect the NPs with ``model``.
        3. Move the stage to each detected zone and acquire a new image of the NP.
        4. Save everything under ``Images/<index:04d>_Image/``.

    Args:
        microscope: connected TemMicroscopeClient.
        index: run number used in the output directory and file names (zero-padded to 4 digits).
        th: probability threshold passed to ``isolate_particles``.

    If NPs are found, the stage is moved back to its starting position afterwards,
    even when an acquisition fails part-way.
    """
    image = microscope.acquisition.acquire_stem_image(DetectorType.HAADF, ImageSize.PRESET_512, 1e-6)
    # AutoScript acquisitions return an image object whose pixels are exposed as `.data`
    # (presumably an AdornedImage — TODO confirm); plain arrays are used as-is.
    frame = image if isinstance(image, np.ndarray) else np.asarray(image.data)

    # NOTE(review): unlike detect_and_plot, this expects predict() to also return the NP
    # zones (objects with .x / .y stage offsets) — confirm against the model API.
    pred, zones = model.predict(image)

    pred_mask = np.squeeze(pred[0])
    labeled_mask, num_particles = isolate_particles(pred_mask, threshold=th)
    # BGR is correct here: this overlay is written with cv2.imwrite, which expects BGR.
    image_color = cv2.cvtColor(frame, cv2.COLOR_GRAY2BGR)

    directory = f"Images/{index:04d}_Image"
    os.makedirs(directory, exist_ok=True)

    image_path = os.path.join(directory, f"image_{index:04d}.png")
    if hasattr(image, "save"):
        image.save(image_path)  # use the acquisition object's own writer when it has one
    else:
        cv2.imwrite(image_path, frame)
    # numpy arrays have no .save(). Write the label map as a 16-bit PNG so the label ids
    # survive (assumes fewer than 65536 particles).
    cv2.imwrite(os.path.join(directory, f"prediction_{index:04d}.png"), labeled_mask.astype(np.uint16))

    if num_particles == 0:
        return  # nothing to visit; the stage was never moved

    for x in range(1, num_particles + 1):
        bbox = get_square_bbox(np.where(labeled_mask == x), frame.shape, 2, 0.2)
        if bbox is not None:
            start_row, start_col, side = (int(v) for v in bbox)
            cv2.rectangle(image_color, (start_col, start_row), (start_col + side, start_row + side), (0, 0, 255), 2)
    cv2.imwrite(os.path.join(directory, f"detection_{index:04d}.png"), image_color)

    initial_position = microscope.specimen.stage.position
    previous_x, previous_y = None, None
    try:
        for i, zone in enumerate(zones):
            if previous_x is None or previous_y is None:
                # First move: offset by half the field of view. This presumably accounts for zone
                # coordinates being measured from the image corner while the stage sits at the
                # image centre — TODO confirm.
                fov = microscope.optics.scan_field_of_view
                extra = fov / 2
                microscope.specimen.stage.relative_move(StagePosition(x=zone.x - extra, y=zone.y - extra))
            else:
                # Later moves are relative to the previously visited zone.
                # TODO: consider specimen.stage.move_to_pixel instead.
                microscope.specimen.stage.relative_move(StagePosition(x=zone.x - previous_x, y=zone.y - previous_y))

            previous_x, previous_y = zone.x, zone.y
            centered_image = microscope.acquisition.acquire_stem_image(DetectorType.HAADF, ImageSize.PRESET_512, 1e-6)
            centered_image.save(os.path.join(directory, f"particle_{i:04d}.png"))
    finally:
        # Always return the stage to where it started, even if a move or acquisition failed.
        microscope.specimen.stage.absolute_move_safe(initial_position)

In [None]:

microscope = TemMicroscopeClient()
microscope.connect()
try:
    microscope.optics.optical_mode = OpticalMode.STEM
    detect_and_save(microscope)
finally:
    # Release the client connection even if acquisition, stage moves or saving raise.
    microscope.disconnect()

# NP acquisition with stage movement

In [None]:
def build_spiral_coordinates(total_cells = 12):
    """Return the unit steps that trace a square spiral outward from the start cell.

    Despite the name, the entries are *directions* (one (a, b) unit step per move), not
    absolute coordinates. The legs cycle through (0, 1), (-1, 0), (0, -1), (1, 0). Their
    lengths go 1, 1, 2, 2, 3, 3, ..., growing by one every two turns.

    Args:
        total_cells: number of steps to generate; may be a float, in which case steps are
            produced while the count is still below it.

    Returns:
        list of ``(a, b)`` tuples of length ``ceil(total_cells)`` (empty if <= 0).
    """
    unit_steps = ((0, 1), (-1, 0), (0, -1), (1, 0))
    steps = []
    leg_length = 1
    leg = 0

    while len(steps) < total_cells:
        heading = unit_steps[leg % 4]
        for _ in range(leg_length):
            if len(steps) >= total_cells:
                break
            steps.append(heading)
        leg += 1
        # Every second turn the spiral widens by one cell.
        if leg % 2 == 0:
            leg_length += 1

    return steps

In [None]:
def movement(grid_x, grid_y, step_size, microscope):
    """Take two consecutive stage steps along (grid_x, grid_y), running detect_and_plot after each.

    Args:
        grid_x, grid_y: step direction, e.g. one entry of ``build_spiral_coordinates``.
        step_size: stage displacement per step, multiplied by ``grid_x`` / ``grid_y``.
        microscope: connected TemMicroscopeClient.
    """
    dx = grid_x * step_size
    dy = grid_y * step_size
    for _step in range(2):
        microscope.specimen.stage.relative_move(StagePosition(x=dx, y=dy))
        detect_and_plot(microscope)

In [None]:
def main(num_images=1000, step_size=0.0001):
    """Scan the specimen along a square spiral, detecting and plotting NPs at every stop.

    Each spiral step is taken twice by ``movement`` (one image per move), so
    ``num_images // 2`` spiral steps give ``num_images`` acquisitions for even values.

    Args:
        num_images: total number of images to acquire.
        step_size: stage displacement per move. Presumably metres, as passed to StagePosition
            — TODO confirm units.
    """
    microscope = TemMicroscopeClient()
    microscope.connect()
    try:
        microscope.optics.optical_mode = OpticalMode.STEM

        # Integer division: a float count (num_images / 2) would add an extra step when odd.
        total_steps = num_images // 2
        movement_direction = build_spiral_coordinates(total_cells=total_steps)

        for (x, y) in movement_direction:
            movement(x, y, step_size, microscope)
    finally:
        # Release the client connection even if a move or acquisition fails mid-scan.
        microscope.disconnect()