# Pixel classification using scikit-image and scikit-learn
Pixel classification is a technique for assigning pixels to multiple classes. If there are two classes (object and background), we are talking about binarization. In this example we use a [random forest classifier](https://en.wikipedia.org/wiki/Random_forest) for pixel classification.

See also
* [Scikit-learn RandomForestClassifier](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)
* [Classification of land cover by Chris Holden](https://ceholden.github.io/open-geo-tutorial/python/chapter_5_classification.html)


In [None]:
# !pip install scikit-image
# !pip install matplotlib
# !pip install nd2
# !pip install -U scikit-learn

In [None]:
from skimage.io import imsave, imread
import os
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
import pickle

In [None]:
# Path to the folder containing the images of a certain channel
# (an absolute path on the author's machine; adjust for your setup).
# It must end with a path separator because it is joined with '+' below.

filepath = '/media/eric/WD_Elements/EPFL-data/Lucas/'

# File name of the orange-channel image. It ends in '.tif', so the generic
# pims reader is selected below.
# NOTE(review): the name contains '...' in the middle and looks truncated or
# elided; verify that it matches the actual file on disk.
fn_orange = 'orange_20230920_AirGel_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd...el_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd2 (series 08) - C=1.tif'


# Full path to the input image.
orange = filepath + fn_orange

# Choose the reader from the file extension: pims_nd2 for .nd2 files,
# otherwise the generic pims package. Both are imported under the name
# `pims`, so later cells can call pims.open(...) either way.
# `!pip install` is IPython shell syntax: this cell only runs in
# Jupyter/IPython, and it invokes pip every time the cell is executed.
if orange.endswith('.nd2'):
    !pip install pims_nd2
    import pims_nd2 as pims

else:
    !pip install pims
    import pims

In [None]:
from skimage import filters

def generate_feature_stack(image, sigma=2):
    """Compute per-pixel features for ``image`` and return them as a 2D array.

    Three features are computed for every pixel: the raw intensity, a
    Gaussian-blurred intensity, and the Sobel edge map of the blurred image.

    Parameters
    ----------
    image : numpy.ndarray
        Input image. In this notebook it is a 3D z-stack.
    sigma : float, optional
        Standard deviation of the Gaussian blur (default 2, the value this
        feature set has always used). Predictions are only meaningful if this
        matches the value used when the classifier was trained.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(3, image.size)``: one row per feature (raw, blurred,
        edges) and one column per pixel, in ``ravel`` (C) order.
    """
    blurred = filters.gaussian(image, sigma=sigma)
    edges = filters.sobel(blurred)

    # scikit-learn expects one value per sample along a flat axis, so each
    # feature image is flattened to 1D and the features are stacked as rows.
    return np.stack([image.ravel(), blurred.ravel(), edges.ravel()])


# Use napari for visualization

In [None]:
import napari
%gui qt



def MakePrediction(img, img_shape, model, background_label=3):
    """Classify every pixel of a feature stack and return a label image.

    Parameters
    ----------
    img : numpy.ndarray
        Feature stack of shape ``(n_features, n_pixels)``, as returned by
        ``generate_feature_stack``. It is transposed to
        ``(n_pixels, n_features)`` before being passed to ``model.predict``.
    img_shape : tuple of int
        Shape of the original image. The flat predictions are reshaped to it.
    model : object
        Fitted classifier exposing ``predict`` (e.g. a scikit-learn
        ``RandomForestClassifier``).
    background_label : int, optional
        Class label treated as background and remapped to 0 (default 3).

    Returns
    -------
    numpy.ndarray
        Predicted labels with shape ``img_shape``. Pixels predicted as
        ``background_label`` are set to 0.
    """
    features = img.T
    print(features.shape)

    labels = model.predict(features).reshape(img_shape)
    # Make background 0 in the output label image.
    labels[labels == background_label] = 0
    return labels

def ShowNapari(img, prediction):
    """Display an image and its predicted labels in a new napari viewer.

    ``img`` is added as an image layer and ``prediction`` as a labels layer.
    The viewer is shown with ``block=True``, which waits for the window to
    be closed, and is then closed.
    """
    viewer = napari.Viewer()
    viewer.add_image(img)
    viewer.add_labels(prediction)

    viewer.show(block=True)

    viewer.close()




In [None]:
# Load the trained random-forest pixel classifier.
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load model files from a trusted source.
model_filename = '../models/RFC_3D/model_for_3D_data.pkl'
with open(model_filename, 'rb') as model_file:  # ensures the handle is closed
    model = pickle.load(model_file)

segmentation_output_dir = '../data_3D/segmentation_output/'
os.makedirs(segmentation_output_dir, exist_ok=True)

# Every Z_slices consecutive frames of the sequence are grouped into one
# z-stack (one timestep). Each frame is written into a 2D slice of the stack.
orange_seq = pims.open(orange)
Z_slices = 19
# Nominal frame size. The z-stack buffer is sized from the frames themselves,
# so frames of other sizes are handled too.
xdim = 1200
ydim = 1200

timestep = 0
n_frames = 0
for t, img in enumerate(orange_seq):
    n_frames = t + 1
    counter = t % Z_slices  # z index of this frame within its stack

    if counter == 0:
        # Start a new stack. np.zeros gives a float64 buffer (not the
        # frame's dtype), so features are computed on float data.
        orange_zstack = np.zeros((Z_slices,) + img.shape)

    orange_zstack[counter, :, :] = img

    if counter == Z_slices - 1:
        print("Timestep: ", timestep)

        # The stack now holds all Z_slices slices of this timestep.
        print(orange_zstack.shape)

        feature_stack = generate_feature_stack(orange_zstack)
        img_shape = orange_zstack.shape

        prediction = MakePrediction(feature_stack, img_shape, model)
        # Release the large intermediates before the next timestep.
        del feature_stack
        # ShowNapari(orange_zstack, prediction)
        del orange_zstack

        out_path = os.path.join(
            segmentation_output_dir, f'T_{timestep:03d}_{fn_orange}'
        )
        imsave(out_path, prediction, check_contrast=False)
        del prediction

        timestep += 1

# Frames that do not complete a full stack are never segmented; report them
# instead of dropping them silently.
leftover = n_frames % Z_slices
if leftover:
    print(f"Warning: the last {leftover} frame(s) do not form a complete "
          f"{Z_slices}-slice stack and were not segmented.")



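The following histogram cell has not been tested yet. It expects a 2D label image `result_2d` matching `img`, which is not defined in the cells above.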

In [None]:
import matplotlib.pyplot as plt

# Overlaid intensity histograms (50 bins) of the pixels assigned to each class.
# Label 0 is background, because MakePrediction remaps class 3 to 0.
# NOTE(review): `result_2d` is not defined in any of the cells above, so this
# cell raises NameError as written. It should be a 2D label image with the
# same shape as `img` (the last frame read by the loop). TODO: define it.
hist1 = plt.hist(img[result_2d == 1], 50, alpha=0.5, color='c')
hist2 = plt.hist(img[result_2d == 2], 50, alpha=0.5, color='k')
# Previously assigned to `hist2`, which overwrote class 2's histogram data.
hist3 = plt.hist(img[result_2d == 0], 50, alpha=0.5, color='m')
# plt.ylim(0,40000)
# plt.xlim(-5,10)