In [None]:
%load_ext autoreload
%autoreload 2

import glob
from fibsem.segmentation.model import SegmentationModel
import tifffile as tf
import matplotlib.pyplot as plt


from random import shuffle

import numpy as np
from fibsem.detection.utils import Feature, FeatureType, DetectionResult
from autoscript_sdb_microscope_client.structures import AdornedImage

from fibsem.structures import Point
from fibsem.imaging import masks
from fibsem.detection import detection
import skimage

from pathlib import Path


from pprint import pprint


from fibsem.segmentation.model import load_model

## Detection Goals


1. Detect Needle Tip
2. Detect Lamella Centre
3. Detect Lamella Edges (Right / Left, Up / Down)

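Open questions:

- How should images containing multiple lamellas be accounted for?
- Would masking the centre area help (e.g. to isolate a single lamella)?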

In [None]:


# data
from dataclasses import dataclass

from fibsem import conversions


# NOTE(review): both paths below are absolute and machine-specific; the
# notebook only runs as-is on a machine with this directory layout.
IMAGE_PATTERN = "/home/patrick/github/data/training/images/*.tif"

# collect every .tif training image and report how many were found
filenames = glob.glob(IMAGE_PATTERN)
print(len(filenames))

# model: load the segmentation checkpoint used by every cell below
checkpoint = "/home/patrick/github/fibsem/fibsem/segmentation/models/model.pt"
model = load_model(checkpoint)


@dataclass
class DetectedFeatures:
    """Result bundle for one detection pass over a single image.

    Instances are built by locate_shift_between_features_v2 and consumed by
    plot_det_result_v2, which reads `image`, `mask` and the first two entries
    of `features`.
    """
    # detected features; the producer in this notebook always stores exactly two
    features: list[Feature]
    # image the detection was run on
    image: np.ndarray
    # segmentation output of model.inference(image)
    mask: np.ndarray
    # value used for the pixel -> metre conversion of `distance`
    pixelsize: float
    # separation of the two features after convert_point_from_pixel_to_metres
    distance: Point


def to_bounding_box(contour):
    """Return the axis-aligned bounding box of a contour as [xc, yc, w, h].

    The contour is indexed as a 2-D array whose column 0 is read as the
    y coordinate and column 1 as the x coordinate. The centre is the minimum
    corner plus the floor-divided half extent, so it is rounded towards the
    minimum corner when the extent is odd.

    ref : https://muthu.co/draw-bounding-box-around-contours-skimage/
    """
    ys = contour[:, 0]
    xs = contour[:, 1]

    x_lo, y_lo = xs.min(), ys.min()
    width = xs.max() - x_lo
    height = ys.max() - y_lo

    centre_x = x_lo + width // 2
    centre_y = y_lo + height // 2

    return [centre_x, centre_y, width, height]
    

def detect_features_v2(img: np.ndarray, mask: np.ndarray, features: tuple[Feature]) -> list[Feature]:
    """Locate each requested feature and return new Feature objects with positions.

    Args:
        img: image the features are located in. Its shape gives the default
            initial point, and it is passed to the landing-post detector.
        mask: segmentation mask, passed to the needle and lamella detectors.
        features: requested features. Each one's `detection_type` selects the
            detector; its `feature_px`, when not None, is the initial estimate.

    Returns:
        One new Feature per request, in request order, with `feature_px` set
        to the detected position.

    Raises:
        TypeError: if a `detection_type` is not a FeatureType.
        ValueError: if a `detection_type` is a FeatureType that has no
            detector in this function.
    """
    lamella_types = (
        FeatureType.LamellaCentre,
        FeatureType.LamellaLeftEdge,
        FeatureType.LamellaRightEdge,
    )

    detection_features = []

    for feature in features:

        det_type = feature.detection_type
        initial_point = feature.feature_px

        if not isinstance(det_type, FeatureType):
            raise TypeError(f"Detection Type {det_type} is not supported.")

        # get the initial position estimate (defaults to the image centre)
        if initial_point is None:
            initial_point = Point(x=img.shape[1] // 2, y=img.shape[0] // 2)

        # Exactly one branch assigns feature_px per iteration. The final else
        # stops an unhandled type from reusing the previous feature's position
        # (or hitting an unbound local on the first iteration).
        if det_type == FeatureType.ImageCentre:
            # uses the initial estimate as-is, so a caller-supplied
            # feature_px is returned unchanged
            feature_px = initial_point
        elif det_type == FeatureType.NeedleTip:
            feature_px = detection.detect_needle_v4(mask)
        elif det_type in lamella_types:
            feature_px = detection.detect_lamella(mask, det_type)
        elif det_type == FeatureType.LandingPost:
            feature_px = detection.detect_landing_post_v3(img, initial_point)
        else:
            raise ValueError(f"No detector available for detection type {det_type}.")

        detection_features.append(
            Feature(detection_type=det_type, feature_px=feature_px)
        )

    return detection_features

def locate_shift_between_features_v2(image: np.ndarray, model: SegmentationModel, features: tuple[Feature], pixelsize: float) -> DetectedFeatures:
    """Segment an image, detect two features and measure their separation.

    Args:
        image: image to run the model and the detectors on.
        model: segmentation model; only its `inference` method is used.
        features: the two features to detect. The result of
            detect_features_v2 is unpacked into exactly two values, so any
            other count raises at the unpacking step.
        pixelsize: passed to the pixel -> metre conversion of the distance.

    Returns:
        A DetectedFeatures holding both detections, the image, the predicted
        mask, the pixel size and the converted distance.
    """
    # segment the image
    prediction = model.inference(image)

    # locate both requested features
    first, second = detect_features_v2(image, prediction, features)

    # separation in pixels, then converted using the pixel size
    shift_px = conversions.distance_between_points(first.feature_px, second.feature_px)
    shift_m = conversions.convert_point_from_pixel_to_metres(shift_px, pixelsize)

    return DetectedFeatures(
        features=[first, second],
        image=image,
        mask=prediction,
        pixelsize=pixelsize,
        distance=shift_m,
    )

def plot_det_result_v2(det: DetectedFeatures):
    """Show the image next to its predicted mask with both detections overlaid.

    The left panel shows `det.image` in grayscale. The right panel shows
    `det.mask` with a green cross on the first feature, a white cross on the
    second, and a dashed white line joining them. Only the first two entries
    of `det.features` are drawn.
    """
    first, second = det.features[0], det.features[1]

    fig, (image_ax, mask_ax) = plt.subplots(1, 2, figsize=(12, 7))

    image_ax.imshow(det.image, cmap="gray")
    image_ax.set_title("Image")

    mask_ax.imshow(det.mask)
    mask_ax.set_title("Prediction")

    # one marker per feature, labelled with its detection type
    for feat, marker in ((first, "g+"), (second, "w+")):
        mask_ax.plot(feat.feature_px.x, feat.feature_px.y, marker, ms=20, label=feat.detection_type.name)

    # dashed connector between the two detections
    xs = [first.feature_px.x, second.feature_px.x]
    ys = [first.feature_px.y, second.feature_px.y]
    mask_ax.plot(xs, ys, "w--")

    mask_ax.legend(loc="best")
    plt.show()


# THINGS TO TRY:
# masking centre area for lamella
# using contours to extract individual lamellas -> pick centre


# Preview detections on a random handful of training images.
N_PREVIEW = 6  # number of images to show

# NOTE: shuffle is in-place and unseeded, so each run previews different images
shuffle(filenames)
for i, fname in enumerate(filenames[:N_PREVIEW]):

    img = tf.imread(fname)

    # detect features; inference happens inside locate_shift_between_features_v2,
    # so the model is run only once per image
    features = [Feature(FeatureType.NeedleTip),
                Feature(FeatureType.ImageCentre)]
    # pixelsize is hard-coded (10e-9, used for the pixel -> metre conversion)
    # rather than read from the image
    det = locate_shift_between_features_v2(img, model, features=features, pixelsize=10e-9)

    # keep the prediction available under the notebook-level name `mask`
    mask = det.mask

    # plot
    plot_det_result_v2(det)


In [None]:


# mask helper


# centre circle
# left/right half
# top/bottom half


from fibsem.imaging import masks
import numpy as np


# blank canvas used only for its shape
arr = np.zeros(shape=(720, 1200))

# centre circle (created for inspection; not shown in the figure below)
circ_mask = masks.create_circle_mask(arr.shape, radius=128)

# quadrant and half masks; each name matches the keyword arguments used
bl_mask = masks.get_area_mask(arr, left=True, lower=True)
ul_mask = masks.get_area_mask(arr, left=True, upper=True)
br_mask = masks.get_area_mask(arr, right=True, lower=True)  # was left=True, duplicating bl_mask
ur_mask = masks.get_area_mask(arr, right=True, upper=True)
mask = masks.get_area_mask(arr, left=True)

# one titled panel per mask so the figure can be read on its own
panels = [
    ("lower left", bl_mask),
    ("upper left", ul_mask),
    ("lower right", br_mask),
    ("upper right", ur_mask),
    ("left", mask),
]

fig, ax = plt.subplots(1, 5, figsize=(15, 7))
for panel_ax, (panel_title, panel_mask) in zip(ax, panels):
    panel_ax.imshow(panel_mask)
    panel_ax.set_title(panel_title)
plt.show()



In [None]:
# contour stuff
    # options for getting individual lamella from multiple: contours, masking 
    # bboxes = []
    # contours = skimage.measure.find_contours(lamella_mask[:, :, 0].astype(np.uint8), 0.8)
    # for contour in contours:
    #     bboxes.append(to_bounding_box(contour))
        # for contour in contours:
    #     ax[1].plot(contour[:, 1], contour[:, 0], color="white", linewidth=1)
    