In [None]:
from cellpose import models
from tifffile import imread
from tifffile import imwrite
from glob import glob
import numpy as np
import napari
import time
from glob import glob
import os
import torch
import gc

# Configuration for the segmentation run, plus discovery of the run directories.
#
# The cellpose nuclei model is trained on a mean diameter of 17, but our data
# is closer to 13. It may be better to rescale the nuclei training data to a
# diameter of 13 and train everything together, as suggested by
# https://cellpose.readthedocs.io/en/latest/models.html
diam_mean = 13

# Location of the trained model. It was trained on
# G:\Data\IBIN_Nina\workspace\nina_cellpose\training_data\with_reslice\mask_only_right
model_path = "C:/Users/OPMuser/nina_cellpose/training_data/with_reslice/model/models/cellpose_residual_on_style_on_concatenation_off__2022_10_18_14_57_00.795536"

# Directory the segmentation masks are saved into.
save_dir = "T:/IBIN_Nina/workspace/main_live4/trained_masks_new_model/all_runs_Hugh_test/"

# Root directory of the dataset, laid out as
# <root>/run_001/main/YY-MM-DD/HH-MM-SS/cst4/run_xxx/field_xxx/excxxx_filterxxx
root_dir = "T:/IBIN_Nina/temp/20221125_live_plate4/"

# Fixed-depth pattern matching every run directory. An earlier recursive
# "**/cst4/run_*" pattern was assigned here and immediately overwritten, so
# only this one ever took effect.
search_dir = os.path.join(root_dir, "main/*/*/cst4/run_*")

print(search_dir)
# glob() returns matches in arbitrary order; sort them so that a run's position
# in the list is reproducible. The pattern has no "**", so recursive matching
# is not needed.
all_runs = sorted(glob(search_dir))
print(all_runs)
print(len(all_runs))

In [None]:
# Release Python garbage and cached CUDA memory before starting a new batch.
gc.collect()
torch.cuda.empty_cache()
# NOTE: setting os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:64"
# was tried here as an attempt to prevent out-of-memory errors.

# Diameter handed to the model for this batch (replaces any earlier value).
diam_mean = 15

search_dir = os.path.join(root_dir, "main*/*/*/cst4/run_*")
print(search_dir)
# glob() gives no ordering guarantee; sort so that the index of a run in
# all_runs (which later becomes the run number) is stable between executions.
# The pattern contains no "**", so recursive matching is not needed.
all_runs = sorted(glob(search_dir))
print(all_runs)

# Field numbers (counted from one) to process, e.g. [1, 5, 16, 49, 52, 64].
# None means process every field.
fields_to_include = None
# Indices into all_runs (counted from zero) to process, e.g. [10, 11].
# None means process every run.
runs_to_include = None

print(np.array(runs_to_include))
if runs_to_include is not None:
    runs_dirs = [all_runs[n] for n in runs_to_include]
else:
    runs_dirs = all_runs
    runs_to_include = range(len(all_runs))

# Append the parameters of this batch to a log file, one entry per line
# (previously the entries were written without separators and ran together).
with open("process_details.txt", "a") as details_file:
    details_file.write("data path: " + root_dir + "\n")
    details_file.write("model path: " + model_path + "\n")
    details_file.write("diam_mean: %.1f" % diam_mean + "\n")

# Segment every selected field of every selected run and save the masks.
#
# The model is built lazily, once, on the first field that actually needs
# segmenting, instead of once per field. Reruns where every mask already
# exists therefore never load it.
model = None

# Pair each run directory with its run index. Iterating over runs_dirs (not
# all_runs) avoids indexing past the end of runs_dirs when only a subset of
# runs is selected through runs_to_include.
for rundir, n in zip(runs_dirs, runs_to_include):
    # Field directories (field_0001/ etc.). Sorted because the position in this
    # list is used as the field number and glob() gives no ordering guarantee.
    fields = sorted(glob(os.path.join(rundir, "field_*")))
    print(rundir)
    with open("info.txt", "a") as info_file:
        info_file.write(rundir + "\n")

    # Output directories are numbered from one, run indices from zero.
    save_dir_run = os.path.join(save_dir, "run_%03d" % (n + 1))

    # Field numbers count from one.
    for field_no, field_dir in enumerate(fields, start=1):
        if fields_to_include is not None and field_no not in fields_to_include:
            continue

        os.makedirs(save_dir_run, exist_ok=True)

        fname = 'field%03d_trained.tif' % field_no
        savepath = os.path.join(save_dir_run, fname)

        # Makes the batch resumable: finished fields are not recomputed.
        if os.path.exists(savepath):
            print("Segmented image", savepath, "exists, skipping", n )
            continue

        start_time = time.time()
        load_dir = os.path.join(field_dir, 'exc561_filter605')
        # Sorted so the images are always stacked in file-name order.
        seq = sorted(glob(os.path.join(load_dir, "*.tif")))
        if not seq:
            # Nothing to read; report it and carry on with the next field
            # rather than aborting the whole batch inside imread().
            print("No .tif files found in", load_dir, "- skipping")
            continue
        stack = imread(seq)

        print("segmenting %s" % field_dir)

        # Remove the constant background/offset added by the camera etc.
        # Subtract in a signed type so values below the offset go negative
        # instead of wrapping, clamp those to zero, then convert back to
        # unsigned.
        stack = np.asarray(stack, dtype=np.int32) - 100
        stack = np.clip(stack, 0, None).astype(np.uint16)

        if model is None:
            model = models.CellposeModel(pretrained_model=model_path,
                                         diam_mean=diam_mean,
                                         model_type=None,
                                         gpu=True,
                                         torch=True,
                                         net_avg=True)

        output = model.eval(stack, channels=[0, 0], do_3D=True, diameter=diam_mean)
        masks = output[0]

        segment_time = time.time() - start_time
        print("Time taken to segment:", segment_time, "s")

        ## save
        print("saving:", savepath)
        imwrite(savepath, masks)

# n and rundir only exist if at least one run was iterated.
if runs_dirs:
    print("finished on", n, rundir)