In [17]:
from skimage.io import imsave
import os
import numpy as np
import pickle
from skimage import filters
import time

In [14]:
# path to the folder containing the images of a certain channel

filepath = '/media/eric/WD_Elements/EPFL-data/Lucas/'
fn_channel = 'green_20230920_AirGel_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd...el_WTmSc_dwspFmN_mix_LAM02J-L_CIP_5.5hpi_23.5h.nd2 (series 08) - C=0.tif'
channel_name = "green"

Z_slices = 19
xdim = 1200
ydim = 1200

visualize = False

model_filename = '../models/RFC_3D/model_for_3D_data.pkl'
segmentation_output_dir = '../data_3D/' + channel_name + '/segmentation_output/'

if not os.path.exists(segmentation_output_dir):
    os.makedirs(segmentation_output_dir)

full_fn = filepath + fn_channel

# if full_fn endswith .nd2 then import pims_nd2
if full_fn.endswith('.nd2'):
    !pip install pims_nd2
    import pims_nd2 as pims

else:
    !pip install pims
    import pims

if visualize:
    import napari
    %gui qt



In [15]:

def generate_feature_stack(image, sigma=2):
    """Compute the per-pixel feature vectors fed to the pixel classifier.

    Parameters
    ----------
    image : numpy.ndarray
        Image of any dimensionality (the processing cell passes a 3D z-stack).
    sigma : float, optional
        Standard deviation of the Gaussian blur. Defaults to 2, the value this
        notebook has always used. Presumably it must match the features the
        model was trained on -- TODO confirm before changing it.

    Returns
    -------
    numpy.ndarray
        Array of shape (3, image.size). The rows are the raw intensities, the
        Gaussian-blurred intensities, and the Sobel filter of the blurred
        image, each flattened.
    """
    # determine features
    blurred = filters.gaussian(image, sigma=sigma)
    edges = filters.sobel(blurred)

    # scikit-learn expects one flat vector per feature, so each nD image is
    # flattened with ravel() before stacking.
    feature_stack = [
        image.ravel(),
        blurred.ravel(),
        edges.ravel(),
    ]

    # return the stack as a (n_features, n_pixels) numpy array
    return np.asarray(feature_stack)


# napari is used only for optional visualization (it is imported in the config cell when `visualize` is True).

In [16]:
def MakePrediction(img, img_shape, model, background_label=3):
    """Classify every pixel and reshape the predicted labels to image shape.

    Parameters
    ----------
    img : numpy.ndarray
        Feature stack of shape (n_features, n_pixels), as produced by
        generate_feature_stack. It is transposed to (n_pixels, n_features)
        because the model predicts one row per pixel.
    img_shape : tuple of int
        Shape of the original image. Its product must equal n_pixels.
    model : object
        Fitted classifier exposing ``predict`` (here the pickled RFC).
    background_label : int, optional
        Class label that is replaced by 0 in the output (default 3, the
        background class).

    Returns
    -------
    numpy.ndarray
        Label image of shape ``img_shape`` with ``background_label`` set to 0.
    """
    # asarray lets predictors that return plain sequences work too;
    # for an ndarray it is a no-op.
    labels = np.asarray(model.predict(img.T)).reshape(img_shape)
    labels[labels == background_label] = 0  # make background 0
    return labels

def ShowNapari(img, prediction):
    """Show an image and its predicted labels in a napari viewer.

    The call blocks until the viewer window is closed, and the viewer is then
    closed. It requires ``napari`` to have been imported; the config cell does
    this only when ``visualize`` is True.

    Parameters
    ----------
    img : array-like
        Image shown as an image layer (the processing cell passes the z-stack).
    prediction : array-like
        Label image shown as a labels layer.
    """
    viewer = napari.Viewer()
    try:
        viewer.add_image(img)
        viewer.add_labels(prediction)
        viewer.show(block=True)
    finally:
        # release the viewer even if adding a layer or showing the window fails
        viewer.close()




In [18]:

# Load the trained pixel classifier; the file handle is closed right after loading.
# NOTE(review): pickle.load can execute arbitrary code -- only load model files from a trusted source.
with open(model_filename, 'rb') as model_file:
    model = pickle.load(model_file)

# The reader yields 2D planes; every Z_slices consecutive planes form the 3D stack of one timestep.
stack_seq = pims.open(full_fn)


timestep = 0
counter = 0
n_frames = 0

ref_time = time.time()

for t, img in enumerate(stack_seq):
    n_frames = t + 1
    # position of this plane inside the current z-stack
    counter = t % Z_slices

    if counter == 0:
        # Start a new z-stack. np.zeros yields float64, not the image dtype. The
        # features are computed from this float array, and changing the dtype may
        # change them (skimage rescales integer inputs), so keep it as trained.
        img_stack = np.zeros((Z_slices, xdim, ydim))

    img_stack[counter, :, :] = img

    if counter == Z_slices - 1:
        # the z-stack of this timestep is complete: segment and save it
        print("Timestep: ", timestep)

        feature_stack = generate_feature_stack(img_stack)
        img_shape = img_stack.shape

        prediction = MakePrediction(feature_stack, img_shape, model)

        # 1 is the label for the cells, 2 is blurry stuff, 3 is background
        prediction[prediction != 1] = 0

        # free large intermediates as early as possible
        del feature_stack
        if visualize:
            ShowNapari(img_stack, prediction)

        del img_stack

        out_fn = 'T_' + str(timestep).zfill(3) + '_' + fn_channel
        imsave(os.path.join(segmentation_output_dir, out_fn), prediction, check_contrast=False)
        del prediction

        timestep += 1
        print("it took ", np.round(time.time() - ref_time, 0), " seconds to process this timestep")
        ref_time = time.time()

# Planes that do not complete a z-stack are never segmented; report them instead of dropping them silently.
leftover = n_frames % Z_slices
if leftover:
    print("Warning: the last", leftover, "frame(s) do not form a complete stack of",
          Z_slices, "z-slices and were not segmented")



Timestep:  0
it took  49.61208486557007  seconds to process this timestep
Timestep:  1
it took  51.51386475563049  seconds to process this timestep
Timestep:  2
it took  54.104984998703  seconds to process this timestep
Timestep:  3
it took  54.5662055015564  seconds to process this timestep
Timestep:  4
it took  54.02252221107483  seconds to process this timestep
Timestep:  5
