# Pixel classification plugin using a random forest

Create a napari plugin to interactively label pixels and classify them with a random forest classifier. The workflow:

1) Read an image
2) Extract features for each pixel
3) Label pixels in napari
4) Train a random forest classifier on the labeled pixels
5) Display the predictions interactively

In [5]:
from skimage import data
import numpy as np
import matplotlib.pyplot as plt

import napari
from magicgui import magicgui

In [6]:
# Take channel index 1 (second axis) of every slice of the cells3d volume.
# NOTE(review): per the scikit-image docs this is the nuclei channel of a
# (z, channel, y, x) stack — confirm if the channel choice matters.
IMAGE3D = data.cells3d()[:, 1, :, :]

In [7]:
def get_image_features(img, sigma=2):
    """Build a per-pixel feature matrix for ``img``.

    Three features are computed for every pixel: the raw intensity, a
    Gaussian-blurred intensity, and the Sobel edge response of the blurred
    image.

    Parameters
    ----------
    img : numpy.ndarray
        Image (2D or 3D) to describe.
    sigma : float, optional
        Standard deviation of the Gaussian blur. Defaults to 2, the value
        this function previously hard-coded.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(img.size, 3)``: one row per pixel in ``ravel``
        order, with columns ``[raw, blurred, edges]``.
    """
    from skimage import filters

    # NOTE: skimage's gaussian rescales integer images to floats (typically
    # [0, 1]) unless preserve_range=True, so the raw column and the filtered
    # columns may be on different scales. Tree-based models such as random
    # forests are insensitive to per-feature scaling, so this is left as-is.
    img_blurred = filters.gaussian(img, sigma=sigma)
    img_edges = filters.sobel(img_blurred)

    # Stack the flattened feature images column-wise -> (n_pixels, n_features).
    return np.stack(
        [img.ravel(), img_blurred.ravel(), img_edges.ravel()], axis=-1
    )

def get_training_data(img_features, labels):
    """Select the annotated pixels as training samples.

    Parameters
    ----------
    img_features : numpy.ndarray
        Per-pixel feature matrix with one row per pixel, in the same order as
        ``labels.ravel()``.
    labels : numpy.ndarray
        Label image. Pixels with a value > 0 are annotated; all other values
        are ignored.

    Returns
    -------
    tuple of numpy.ndarray
        ``(X, y)``: the feature rows of the annotated pixels and their label
        values.
    """
    targets = labels.ravel()
    annotated = targets > 0
    return img_features[annotated], targets[annotated]

In [8]:
@magicgui
def random_forest_pixel_classifier(
    image: 'napari.layers.Image', 
    labels: 'napari.layers.Labels') -> 'napari.types.LabelsData':
    """Train a random forest on annotated pixels and classify every pixel.

    Features for every pixel of ``image`` come from ``get_image_features``.
    Pixels with a label value > 0 in ``labels`` are the training samples, and
    the label value is the class. The fitted classifier then predicts a class
    for every pixel of the image.

    Parameters
    ----------
    image : napari.layers.Image
        Layer whose ``.data`` is classified.
    labels : napari.layers.Labels
        Sparse annotations. Must have the same shape as ``image.data``.

    Returns
    -------
    numpy.ndarray
        Predicted class per pixel, reshaped to ``image.data.shape``. The
        ``LabelsData`` return annotation lets napari show it as a Labels layer.

    Raises
    ------
    ValueError
        If the labels shape differs from the image shape, or if no pixel has
        been annotated yet.
    """
    from sklearn.ensemble import RandomForestClassifier

    # Validate cheaply before the (potentially expensive) feature computation.
    # A mismatched shape would either fail with an IndexError or, when the
    # pixel count happens to match, silently pair features with wrong labels.
    if labels.data.shape != image.data.shape:
        raise ValueError(
            f"Labels shape {labels.data.shape} does not match image shape "
            f"{image.data.shape}."
        )
    # Without any annotation sklearn would fail with an opaque
    # "0 sample(s)" error.
    if not (labels.data > 0).any():
        raise ValueError(
            "No annotated pixels found: paint at least one label (value > 0) "
            "in the labels layer before training."
        )

    img_features = get_image_features(image.data)
    X, y = get_training_data(img_features, labels.data)

    # max_depth=2 keeps the trees shallow (presumably for speed on large
    # volumes); random_state makes repeated runs reproducible.
    clf = RandomForestClassifier(max_depth=2, random_state=1)
    clf.fit(X, y)

    return clf.predict(img_features).reshape(image.data.shape)

# Open the volume in a new viewer and attach the classifier widget as a dock
# panel. Use the explicit Viewer() + add_image() form rather than the
# napari.view_image convenience helper, which recent napari releases deprecate.
viewer = napari.Viewer()
viewer.add_image(IMAGE3D)
viewer.window.add_dock_widget(random_forest_pixel_classifier)

<napari._qt.widgets.qt_viewer_dock_widget.QtViewerDockWidget at 0x7f5ebd648eb0>