# Pixel classification using Scikit image
Pixel classification is a technique for assigning pixels to multiple classes. If there are two classes (object and background), we are talking about binarization. In this example we use a [random forest classifier](https://en.wikipedia.org/wiki/Random_forest) for pixel classification.

See also
* [Scikit-image random forest](https://scikit-learn.org/stable/modules/generated/sklearn.ensemble.RandomForestClassifier.html)
* [Classification of land cover by Chris Holden](https://ceholden.github.io/open-geo-tutorial/python/chapter_5_classification.html)


In [None]:
# !pip install scikit-image
# !pip install matplotlib
# !pip install nd2
# !pip install -U scikit-learn

In [1]:
from skimage.io import imsave, imread
import os
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV


# Define input path

In [2]:
# path to the folder containing the images of a certain channel

filepath = '/media/eric/WD_Elements/EPFL-data/Lucas/'

fn_orange = 'orange_20230920_AirGel_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd...el_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd2 (series 08) - C=1.tif'
fn_green = 'green_20230920_AirGel_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd...el_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd2 (series 08) - C=0.tif'


orange = filepath + fn_orange
green = filepath + fn_green

# if orange endswith .nd2 then import pims_nd2
if orange.endswith('.nd2'):
    import pims_nd2 as pims

else:
    import pims


In [7]:
import napari
%gui qt

def Draw_and_Save(img,output_dir,output_fn):
    viewer = napari.Viewer()
    viewer.add_image(img)
    labels = viewer.add_labels(np.zeros(img.shape,dtype='uint8'))


    viewer.show(block=True)

    # get the drawings first
    manual_annotations = labels.data

    # save the drawings
    manual_labels_filename = output_dir + '/' + output_fn
    imsave(manual_labels_filename, manual_annotations, check_contrast=False)
    viewer.close()


# Prepare GT and save it together with the image

In [3]:
output_dir = '../data_3D/'


In [4]:


# create output directory
import os
if not os.path.exists(output_dir):
    os.makedirs(output_dir)

if not os.path.exists(output_dir + "orange"):
    os.makedirs(output_dir + "orange/GT")

if not os.path.exists(output_dir + "green"):
    os.makedirs(output_dir + "green/GT")

if not os.path.exists(output_dir + "orange/RAW"):
    os.makedirs(output_dir + "orange/RAW")

if not os.path.exists(output_dir + "green/RAW"):
    os.makedirs(output_dir + "green/RAW")


orange_seq = pims.open(orange)
green_seq = pims.open(green)
Z_slices = 19
xdim = 1200
ydim = 1200

timestep = 0
counter = 0
for t, (img_o, img_g) in enumerate(zip(orange_seq,green_seq)):
    print("Timestep: ", timestep)

    # iterate through all z slices and stack them together
    if np.mod(t,Z_slices) == 0:
        # prepare empty array with the same dtype as the first image
        counter = 0
        orange_zstack = np.zeros((Z_slices,xdim,ydim))
        green_zstack = np.zeros((Z_slices,xdim,ydim))

    orange_zstack[counter,:,:] = img_o
    green_zstack[counter,:,:] = img_g

    if np.mod(t,Z_slices) == Z_slices-1:

        # now there is a 3D image stack of 19 slices
        print(orange_zstack.shape)
        print(green_zstack.shape)

        # start napari, save GT and RAW images
        # keep the basename of fn
        fn_orange = os.path.basename(fn_orange)
        fn_orange = os.path.splitext(fn_orange)[0]

        Draw_and_Save(orange_zstack, output_dir + "orange/GT", fn_orange + '_T_'+str(timestep).zfill(3)+'.tif')
        imsave(output_dir + "orange/RAW/" + fn_orange + '_T_'+str(timestep).zfill(3)+'.tif', orange_zstack,check_contrast=False)

        Draw_and_Save(green_zstack, output_dir + "green/GT", fn_green + '_T_'+str(timestep).zfill(3)+'.tif')
        imsave(output_dir + "green/RAW/" + fn_green + '_T_'+str(timestep).zfill(3)+'.tif', green_zstack,check_contrast=False)


        timestep += 1

    counter += 1



KeyboardInterrupt: 

# Now work with the GT to train the model

In [6]:


# Now go through all existing GT files that contain non-zero pixels
green_gt_fn = os.listdir(output_dir + "green/GT")
orange_gt_fn = os.listdir(output_dir + "orange/GT")



## Generating a feature stack
Pixel classifiers such as the random forest classifier takes multiple images as input. We typically call these images a feature stack because for every pixel exist now multiple values (features). In the following example we create a feature stack containing three features:
* The original pixel value
* The pixel value after a Gaussian blur
* The pixel value of the Gaussian blurred image processed through a Sobel operator.

Thus, we denoise the image and detect edges. All three images serve the pixel classifier to differentiate positive an negative pixels.

In [9]:
from skimage import filters

def generate_feature_stack(image):
    # determine features
    blurred = filters.gaussian(image, sigma=2)
    edges = filters.sobel(blurred)

    # collect features in a stack
    # The ravel() function turns a nD image into a 1-D image.
    # We need to use it because scikit-learn expects values in a 1-D format here.
    feature_stack = [
        image.ravel(),
        blurred.ravel(),
        edges.ravel()
    ]

    # return stack as numpy-array
    return np.asarray(feature_stack)


## Formating data
We now need to format the input data so that it fits to what scikit learn expects. Scikit-learn asks for an array of shape (n, m) as input data and (n) annotations. n corresponds to number of pixels and m to number of features. In our case m = 3.

In [10]:
def format_data(feature_stack, annotation):
    # reformat the data to match what scikit-learn expects
    # transpose the feature stack
    X = feature_stack.T
    # make the annotation 1-dimensional
    y = annotation.ravel()

    # remove all pixels from the feature and annotations which have not been annotated
    mask = y > 0
    X = X[mask]
    y = y[mask]

    return X, y

In [11]:

# in case you have not much memory as I do, you should run only one channel at a time

channels = ['orange', 'green']
filename_array = [orange_gt_fn, green_gt_fn]


# make this to a for loop later
channel = channels[0]
filenames = filename_array[0]


X_stack = []
y_stack = []

for fn in filenames[:1]:
    print(fn)

    img = imread(output_dir + channel + "/RAW/" + fn)
    feature_stack = generate_feature_stack(img)
    X, y = format_data(feature_stack, img)
   
    X_stack.append(X)
    y_stack.append(y)

    # delete variables to save memory
    del X, y, img, feature_stack
    
X_stack = np.concatenate(X_stack)
y_stack = np.concatenate(y_stack)

orange_20230920_AirGel_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd...el_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd2 (series 08) - C=1.tif_T_000.tif


In [12]:
y_stack.shape, X_stack.shape

((27360000,), (27360000, 3))

In [None]:
# feature_stack = generate_feature_stack(mutant)

# # show feature images
# import matplotlib.pyplot as plt
# fig,axes = plt.subplots(1, 3, figsize=(10,10))


# %gui qt
# # reshape(image.shape) is the opposite of ravel() here. We just need it for visualization.
# axes[0].imshow(feature_stack[0].reshape(mutant.shape), cmap=plt.cm.gray)
# axes[1].imshow(feature_stack[1].reshape(mutant.shape), cmap=plt.cm.gray)
# axes[2].imshow(feature_stack[2].reshape(mutant.shape), cmap=plt.cm.gray)

## Interactive segmentation
We can also use napari to annotate some regions as negative (label = 1) and positive (label = 2).

# Training begins

In [None]:
filename = 'transwell_denoised_2_categories.pkl'

In [12]:
# train classifier if not trained yet
classifier = RandomForestClassifier()
classifier.fit(X_stack, y_stack)

# # save classifier
# import pickle
# pickle.dump(classifier, open(filename, 'wb'))

# # make prediction
# prediction = classifier.predict(feature_stack.T)



## Gridsearch

In [None]:

param_grid = {
    'n_estimators': [50,100],  # Vary the number of trees
    'max_depth': [2, 3],       # Vary the maximum depth of trees
}


grid_search = GridSearchCV(classifier, param_grid, cv=5)
grid_search.fit(X, y)  # X and y are your training data and labels, respectively



In [None]:
results = grid_search.cv_results_

# Extract the mean scores and reshape them into a grid
scores = np.array(results['mean_test_score']).reshape(len(param_grid['n_estimators']),
                                                      len(param_grid['max_depth']))



# Create a heatmap of the mean scores
plt.imshow(scores, cmap='viridis', origin='lower')
plt.colorbar(label='Mean Score')
plt.xlabel('min_samples_split')
plt.ylabel('max_depth')
plt.title('Grid Search Mean Scores')
plt.show()

In [None]:
best_classifier = grid_search.best_estimator_


In [None]:
# save classifier 
import pickle
pickle.dump(best_classifier, open(filename, 'wb'))

# Now go to the other notebook for prediction. 