# Creating a napari-convpaint segmentation model

In [None]:
import numpy as np
import skimage
from napari_convpaint.conv_paint_model import ConvpaintModel

# Training pair: one microscope image and its hand-painted annotation layer.
image = skimage.io.imread("Data/microscope/aegypti/ag_06.tif")
annotations = skimage.io.imread("Data/annotations_ag_06.tif")

# Feature extractor: EfficientNet-B0, layer index 1, single scale (1), CUDA enabled.
extractor_config = {
    "fe_name": "efficient_netb0",
    "fe_use_cuda": True,
    "fe_layers": [1],
    "fe_scalings": [1],
}
model = ConvpaintModel(**extractor_config)

# Tiling / classifier / smoothing settings applied via set_params.
training_params = {
    "image_downsample": 1,
    "tile_annotations": True,
    "tile_image": True,
    "clf_iterations": 100,
    "seg_smoothening": 3,
}
model.set_params(**training_params)

# Annotations are passed to training as uint8 label values.
annotations = annotations.astype(np.uint8)

# Train on a channels-first copy of the image; assumes the TIFF is stored (H, W, C) — TODO confirm.
model.train(np.transpose(image, (2, 0, 1)), annotations)

0:	learn: 0.4344805	total: 59.3ms	remaining: 5.87s
1:	learn: 0.2817036	total: 66.9ms	remaining: 3.28s
2:	learn: 0.1774183	total: 75ms	remaining: 2.42s
3:	learn: 0.1273576	total: 82.6ms	remaining: 1.98s
4:	learn: 0.0807750	total: 89.9ms	remaining: 1.71s
5:	learn: 0.0560956	total: 98.3ms	remaining: 1.54s
6:	learn: 0.0452980	total: 107ms	remaining: 1.42s
7:	learn: 0.0375513	total: 114ms	remaining: 1.31s
8:	learn: 0.0315362	total: 122ms	remaining: 1.23s
9:	learn: 0.0251572	total: 130ms	remaining: 1.17s
10:	learn: 0.0222977	total: 138ms	remaining: 1.12s
11:	learn: 0.0203142	total: 146ms	remaining: 1.07s
12:	learn: 0.0182062	total: 154ms	remaining: 1.03s
13:	learn: 0.0167104	total: 162ms	remaining: 996ms
14:	learn: 0.0158393	total: 171ms	remaining: 969ms
15:	learn: 0.0149588	total: 179ms	remaining: 940ms
16:	learn: 0.0142842	total: 187ms	remaining: 911ms
17:	learn: 0.0137138	total: 194ms	remaining: 884ms
18:	learn: 0.0132466	total: 202ms	remaining: 861ms
19:	learn: 0.0127415	total: 210ms	rem

<catboost.core.CatBoostClassifier at 0x7fe498949340>

In [13]:
# Version tag of the trained Convpaint model; it becomes the file name under Models/.
model_name = "Enet_vx"
model.save(f"Models/{model_name}.pkl", create_yml=False)

# Testing a segmentation model

In [None]:
import numpy as np
import skimage
import matplotlib.pyplot as plt
from napari_convpaint.conv_paint_model import ConvpaintModel

# Apply a saved Convpaint model to an image it was not trained on and show the class map.
model_name = "Enet_v7"
image = skimage.io.imread('Data/microscope/aegypti/ag_07.tif')

# Pass the path by keyword (as the batch cell does) instead of relying on positional order.
model = ConvpaintModel(model_path="Models/" + model_name + ".pkl")

# The model is fed a channels-first copy (as in training). `image` itself stays channels-last,
# because the object-extraction cell crops it as image[rows, cols].
# Assumes the TIFF is stored (H, W, C) — TODO confirm.
image_chw = np.moveaxis(image, -1, 0)
segment = model.segment(image_chw)

fig, ax = plt.subplots()
ax.imshow(segment, cmap='gray')
ax.set_title(f"Convpaint segmentation ({model_name})")
ax.set_axis_off()

# Extracting individual objects from a segmented picture

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from skimage.measure import label, regionprops
from skimage.morphology import remove_small_objects, remove_small_holes, dilation, footprint_rectangle, erosion
from skimage.util import img_as_ubyte

# Turn the class map `segment` into one masked crop of `image` per detected object.
# Assumes `image` is channels-last (H, W[, C]) and `segment` is its (H, W) class map — TODO confirm.

rec = 25  # side length (px) of the square footprint used for erosion and dilation

# The crops below index image[rows, cols]; a channels-first image would silently slice the
# wrong axes, so fail loudly instead.
assert image.shape[:2] == segment.shape[:2], (
    f"image {image.shape} and segment {segment.shape} are not spatially aligned "
    "(is `image` channels-first?)"
)

# Fill holes in label 2 (presumably the object class) and label connected components.
# Then open with a rec x rec square: erode, drop objects under 20000 px, dilate back.
mask_cleaned = remove_small_holes(segment == 2, area_threshold=5000)
labeled_overlay = label(mask_cleaned)
labeled_overlay = erosion(labeled_overlay, footprint=footprint_rectangle((rec, rec)))
labeled_overlay = remove_small_objects(labeled_overlay, min_size=20000)
labeled_overlay = dilation(labeled_overlay, footprint_rectangle((rec, rec)), mode='ignore')

regions = regionprops(labeled_overlay)

fig, ax = plt.subplots()
ax.imshow(labeled_overlay, cmap='gray')
ax.set_title(f"{len(regions)} objects after clean-up")

segmented_image = []

for region in regions:
    # Bounding box of the region
    minr, minc, maxr, maxc = region.bbox

    # Crop of the image, and this region's own mask inside that box
    cropped_image = image[minr:maxr, minc:maxc]
    mask = labeled_overlay[minr:maxr, minc:maxc] == region.label

    # Zero every pixel outside the region, including other objects that share the box.
    # The 2-D mask is reshaped to broadcast over any trailing channel axis, so grayscale
    # (H, W) images work as well as (H, W, C).
    mask_nd = mask.reshape(mask.shape + (1,) * (cropped_image.ndim - 2))
    masked_image = np.where(mask_nd, cropped_image, 0).astype(cropped_image.dtype, copy=False)

    segmented_image.append(masked_image)

## Segmenting many pictures and measuring each segment


In [None]:
import glob
import os

import numpy as np
import pandas as pd
import skimage
from skimage.measure import label, regionprops
from skimage.morphology import remove_small_holes, remove_small_objects, dilation, erosion, footprint_rectangle
from skimage.util import img_as_ubyte
from napari_convpaint.conv_paint_model import ConvpaintModel

# Batch version of the single-picture cell, run over every microscope image:
#   - segment the image and split the segmentation into individual objects;
#   - save each object's masked crop and mask as PNG;
#   - collect shape and sharpness features per object into `df`.
#
# NOTE(review): `segmentation`, `measure_sharpness` and `gradient_sharpness` are not defined
# in the cells shown here; they must come from an earlier cell or module — TODO confirm.

SEGMENT_IMAGE_DIR = "Data/segmented_image"
SEGMENT_MASK_DIR = "Data/segmented_mask"
# Create the output folders up front so the saves below also work on a fresh checkout.
os.makedirs(SEGMENT_IMAGE_DIR, exist_ok=True)
os.makedirs(SEGMENT_MASK_DIR, exist_ok=True)

model = ConvpaintModel(model_path="Models/Enet_v7.pkl")
# glob's order depends on the file system; sort so image indices (and the saved file
# names, which embed them) are reproducible.
image_paths = sorted(glob.glob("Data/microscope/**/*.tif", recursive=True))

images = []
labels = []
segments = []
segmented_image = []
data = []
rec = 25  # side length (px) of the square footprint used for erosion and dilation
footprint = footprint_rectangle((rec, rec))  # loop-invariant, build once

for path in image_paths:
    images.append(skimage.io.imread(path))
    # The class label is the name of the folder containing the image (e.g. "aegypti").
    labels.append(os.path.basename(os.path.dirname(path)))

for i, image in enumerate(images):
    segment = segmentation(image, model)
    segments.append(segment)

    # Same clean-up as the single-picture cell: fill holes in label 2, label components,
    # then erode / drop objects under 20000 px / dilate back.
    mask_cleaned = remove_small_holes(segment == 2, area_threshold=5000)
    labeled_overlay = label(mask_cleaned)
    labeled_overlay = erosion(labeled_overlay, footprint=footprint)
    labeled_overlay = remove_small_objects(labeled_overlay, min_size=20000)
    labeled_overlay = dilation(labeled_overlay, footprint, mode='ignore')
    regions = regionprops(labeled_overlay)

    for j, region in enumerate(regions):
        minr, minc, maxr, maxc = region.bbox
        cropped_image = image[minr:maxr, minc:maxc]
        mask = labeled_overlay[minr:maxr, minc:maxc] == region.label
        # Zero everything outside this region; reshape the 2-D mask so it broadcasts over channels.
        mask_nd = mask.reshape(mask.shape + (1,) * (cropped_image.ndim - 2))
        masked_image = np.where(mask_nd, cropped_image, 0).astype(cropped_image.dtype, copy=False)
        segmented_image.append(masked_image)
        image_gray = skimage.color.rgb2gray(masked_image)

        skimage.io.imsave(f"{SEGMENT_IMAGE_DIR}/image{i}_segment_{j}.png", img_as_ubyte(masked_image))
        skimage.io.imsave(f"{SEGMENT_MASK_DIR}/mask{i}_segment_{j}.png", img_as_ubyte(mask))

        area = region.area
        perimeter = region.perimeter
        # 4*pi*A/P^2: about 1 for a disc, smaller for elongated or ragged outlines.
        roundness = 4 * np.pi * area / (perimeter ** 2) if perimeter != 0 else 0
        length = region.axis_major_length
        width = region.axis_minor_length
        # Guarded like roundness: a degenerate (line-like) region has a zero minor axis.
        ratio = length / width if width != 0 else 0
        laplacian = measure_sharpness(image_gray)
        edge = gradient_sharpness(image_gray)

        data.append({
            'image': i,
            'segment': j,
            'area': area,
            'perimeter': perimeter,
            'roundness': roundness,
            'length': length,
            'width': width,
            'len_wid_ratio': ratio,
            'laplacian' : laplacian,
            'edge': edge
        })

df = pd.DataFrame(data)

# Training a random-forest model to identify and classify segments

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report, multilabel_confusion_matrix
from sklearn.model_selection import GridSearchCV

# Tune a random forest that predicts segment classes from the per-segment features.
# NOTE(review): X (feature table) and y (a multi-column label frame; the evaluation cells
# use y.columns) are built outside the cells shown here — TODO confirm.

model = RandomForestClassifier(random_state=42)  # seeded so the search is reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 5*3*2*3*3*3*2 = 1620 candidates x 5 folds = 8100 fits — expect a long run.
param_grid = {
    "n_estimators": [10, 50, 100, 200, 500],
    "criterion": ["gini", "entropy", "log_loss"],
    "max_features": ["sqrt", "log2"],
    "max_depth": [None, 5, 10],
    "min_samples_split": [2, 5, 10],
    "min_samples_leaf": [1, 2, 4],
    "bootstrap": [True, False]
}

grid_search_forest = GridSearchCV(
    model,
    param_grid,
    cv=5,
    # Plain 'precision' only supports a binary target and errors on a multi-column y;
    # average the per-label precision instead.
    scoring='precision_macro',
    n_jobs=-1,
    verbose=2,
)
# Search on the training split only. Fitting on all of X would let X_test influence
# model selection and inflate the test scores reported below.
grid_search_forest.fit(X_train, y_train)
best_forest_model = grid_search_forest.best_estimator_
y_pred = best_forest_model.predict(X_test)

In [None]:
import pandas as pd
from sklearn.metrics import classification_report

# Per-label precision / recall / F1 on the held-out split, plus the averaged rows.
# zero_division=0 yields the same 0.0 as the default for undefined scores, without the warnings.
report = classification_report(
    y_test, y_pred, target_names=y.columns, output_dict=True, zero_division=0
)
report_df = pd.DataFrame(report).transpose()
report_df.round(3)

In [None]:
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd

# Collapse each multilabel row (one value per label column) into a single string key, so
# every label combination becomes one class of an ordinary confusion matrix.
combined_y_test = y_test.apply(lambda row: ''.join(row.astype(str)), axis=1)
combined_y_pred = pd.DataFrame(y_pred, columns=y.columns).apply(lambda row: ''.join(row.astype(str)), axis=1)

labels_sorted = sorted(set(combined_y_test) | set(combined_y_pred))
# Pass the label order explicitly so the matrix rows/columns are guaranteed to line up
# with the tick labels, instead of relying on sklearn's internal ordering.
conf_matrix_combined = confusion_matrix(combined_y_test, combined_y_pred, labels=labels_sorted)

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(conf_matrix_combined, annot=True, fmt='d', cmap='Blues',
            xticklabels=labels_sorted, yticklabels=labels_sorted, ax=ax)
ax.set_xlabel("Vorhergesagte Label-Kombination")
ax.set_ylabel("Tatsächliche Label-Kombination")
ax.set_title("Multilabel Confusion Matrix (kombinierte Klassen)")
fig.tight_layout()
plt.show()

In [None]:
from joblib import dump

# Persist the tuned random forest so it can be reloaded without re-running the grid search.
forest_model_path = 'Models/segmentation_identifier_random_forest.joblib'
dump(best_forest_model, forest_model_path)

In [None]:
from joblib import load

# joblib files are pickle-based: only load model files you created yourself or trust.
forest_model_file = 'Models/segmentation_identifier_random_forest.joblib'
loaded_model = load(forest_model_file)