# Now use the ground-truth (GT) annotations to train the model

In [None]:
from skimage.io import imread
import os
import numpy as np
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV
import pickle
import matplotlib.pyplot as plt

In [None]:
data_dir = '../data_3D/'
channel_name = "green"  # rename to WT, wsp, or whichever mutant you are training on
model_directory = "../models/RFC_3D/"

# List every ground-truth (GT) annotation file for this channel. The GT folder
# lives at <data_dir>/<channel_name>/GT/ -- the same place the training loop
# reads annotations from -- so build the path with os.path.join instead of
# concatenating strings (which previously dropped the "/" before "GT").
# Hidden entries (e.g. ".ipynb_checkpoints", ".DS_Store") are skipped because
# they are not images, and the list is sorted so the training data are always
# assembled in the same order.
gt_dir = os.path.join(data_dir, channel_name, "GT")
all_filenames = sorted(fn for fn in os.listdir(gt_dir) if not fn.startswith("."))



In [None]:
from skimage import filters

def generate_feature_stack(image):
    """Compute per-pixel features for an n-D image.

    Three features are produced for every pixel, in this order:
    the raw intensity, a Gaussian-smoothed intensity (sigma=2), and the
    Sobel filter response of the smoothed image.

    Returns a numpy array with one row per feature and one column per pixel.
    Each feature image is flattened to 1-D because scikit-learn expects
    per-sample feature vectors rather than n-D images.
    """
    smoothed = filters.gaussian(image, sigma=2)
    edge_map = filters.sobel(smoothed)

    # Flatten each feature image and stack them row-wise (feature x pixel).
    return np.asarray([layer.ravel() for layer in (image, smoothed, edge_map)])


## Formatting data


In [None]:
def format_data(feature_stack, annotation):
    """Turn a feature stack and its annotation into scikit-learn training arrays.

    Parameters
    ----------
    feature_stack : numpy.ndarray
        Features laid out one feature per row and one pixel per column
        (the layout produced by ``generate_feature_stack``).
    annotation : numpy.ndarray
        Label image; once flattened it must have one entry per pixel.
        Pixels whose label is not greater than 0 count as unannotated.

    Returns
    -------
    X : numpy.ndarray
        Feature matrix with one row per annotated pixel.
    y : numpy.ndarray
        The labels of those annotated pixels.
    """
    labels = annotation.ravel()
    # Keep only pixels that carry a positive label.
    is_annotated = labels > 0
    samples = feature_stack.T  # one row per pixel, as scikit-learn expects
    return samples[is_annotated], labels[is_annotated]

In [None]:

# If you are short on memory, run only one channel_name at a time: every
# RAW/GT pair of the selected channel is loaded while building the stacks.

X_stack = []
y_stack = []

raw_dir = os.path.join(data_dir, channel_name, "RAW")
gt_dir = os.path.join(data_dir, channel_name, "GT")

for fn in all_filenames:
    print(fn)

    # Load the raw image and its annotation; both share the same file name.
    img = imread(os.path.join(raw_dir, fn))
    annotation = imread(os.path.join(gt_dir, fn))

    # Every pixel needs exactly one label; otherwise the boolean mask in
    # format_data fails with an unhelpful IndexError.
    if img.shape != annotation.shape:
        raise ValueError(
            f"{fn}: raw image shape {img.shape} does not match "
            f"annotation shape {annotation.shape}"
        )

    # Build per-pixel features and keep only the annotated pixels.
    feature_stack = generate_feature_stack(img)
    X, y = format_data(feature_stack, annotation)
    # Release the full-size volumes before the next file is loaded.
    del img, annotation, feature_stack

    X_stack.append(X)
    y_stack.append(y)

    # delete variables to save memory
    del X, y

# np.concatenate fails cryptically on an empty list; say what is actually wrong.
if not X_stack:
    raise RuntimeError(f"No annotation files found in {gt_dir}; nothing to train on.")

X_stack = np.concatenate(X_stack)
y_stack = np.concatenate(y_stack)

if y_stack.size == 0:
    raise RuntimeError(f"The files in {gt_dir} contain no annotated (non-zero) pixels.")

In [None]:
# Drop pixels that have any non-finite feature value (+inf, -inf or NaN):
# scikit-learn rejects infinite inputs. Checking all columns with .all()
# keeps this correct however many features generate_feature_stack returns.
finite_rows = np.isfinite(X_stack).all(axis=1)

X_stack = X_stack[finite_rows, :]
y_stack = y_stack[finite_rows]


# Training begins

In [None]:
# File name of the saved model; it is written inside model_directory.
# NOTE(review): the name does not include channel_name, so training another
# channel/mutant overwrites the previously saved model -- confirm intended.
filename = 'model_for_3D_data.pkl'

In [None]:
# Tune a random forest with a 5-fold cross-validated grid search.
# This always (re)trains. random_state fixes the forest's random sampling so
# repeated runs produce the same CV scores and select the same model.
classifier = RandomForestClassifier(random_state=0)

param_grid = {
    'n_estimators': [50, 100],  # number of trees
    'max_depth': [2, 3],        # maximum depth of each tree
}

grid_search = GridSearchCV(classifier, param_grid, cv=5)
# X_stack: one row of features per annotated pixel; y_stack: the matching labels.
grid_search.fit(X_stack, y_stack)



In [None]:
results = grid_search.cv_results_

depth_values = param_grid['max_depth']
tree_values = param_grid['n_estimators']

# Place every candidate's mean CV score by its actual parameter values rather
# than reshaping mean_test_score. GridSearchCV orders candidates by sorted
# parameter name (max_depth before n_estimators, last name varying fastest),
# so a plain reshape easily transposes the grid or mislabels the axes.
scores = np.full((len(depth_values), len(tree_values)), np.nan)
for params, score in zip(results['params'], results['mean_test_score']):
    row = depth_values.index(params['max_depth'])
    col = tree_values.index(params['n_estimators'])
    scores[row, col] = score

# Heatmap of the mean scores: rows = max_depth, columns = n_estimators.
plt.imshow(scores, cmap='viridis', origin='lower')
plt.colorbar(label='Mean Score')
plt.xticks(range(len(tree_values)), tree_values)
plt.yticks(range(len(depth_values)), depth_values)
plt.xlabel('n_estimators')
plt.ylabel('max_depth')
plt.title('Grid Search Mean Scores')
plt.show()

In [None]:
# Estimator with the best mean CV score. GridSearchCV was built without
# overriding refit, so under scikit-learn's default (refit=True) this model has
# been refit on all of X_stack / y_stack.
best_classifier = grid_search.best_estimator_

In [None]:
# Save the best classifier found by the grid search.
# makedirs(..., exist_ok=True) also creates any missing parent folders
# (e.g. "../models") and does nothing if the directory already exists.
os.makedirs(model_directory, exist_ok=True)

# The with-block guarantees the file is flushed and closed after writing.
# Only unpickle this file in trusted contexts: pickle can execute code on load.
with open(os.path.join(model_directory, filename), 'wb') as f:
    pickle.dump(best_classifier, f)

# Now go to the other notebook for prediction.