## Input
- `tif` files, each containing only 1 channel
- a cellpose classifier (default or custom trained)

## Output
for each provided `tif`:
- segmentation mask as `png`
- segmentation mask as `npy`
- segmentation outline as `txt`
- vis folder with detection visualization for all `tif`s as `png` (good for checking segmentation results)

# 0) Imports and functions

The functions required for this to work are collected in the `pipelines/fish_utils` folder. Download the folder from `/../` in this repository and add it to the Python path with `sys.path.append("/path/to/fish_utils/")`. You can skip this if you are providing `tif`s.

In [8]:
import os
import numpy as np
from glob import glob
import matplotlib.pyplot as plt

from cellpose import models,io
from cellpose.io import imread
from cellpose import plot

import torch

In [9]:
# To make the segmentation fast, run it on a GPU when one is present.
# The second card ('cuda:1') stays the preferred device; machines with a single
# GPU fall back to 'cuda:0', and machines without CUDA fall back to the CPU so
# the notebook still runs (slowly) instead of failing when the model is loaded.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# displayed as the cell output: True when CUDA can be used
torch.cuda.is_available()

True

# 1) Parameters

In [10]:
# --- input location ---------------------------------------------------------
# top-level data folder; the tif files are searched for below it and the
# results are written below it
in_path = "/home/stumberger/ep2024/example/"

# --- cellpose model ---------------------------------------------------------
# model identifier handed to cellpose (here: path to a trained model file)
model = "/home/stumberger/ep2024/example/segmentation_model/es_20231026"

# --- segmentation settings --------------------------------------------------
# ratio of z sampling to xy sampling (e.g. 0.3 / 0.13 = 2.3)
anisotropy = 2.3
# minimum object size passed to the model
min_size = 5000
# object diameter passed to the model
diams = 120
# channel setting used for the visualization ([0, 0] is cellpose's grayscale setting)
chan = [[0, 0]]

In [11]:
# Model for segmentation.
# `model` arrives from the parameter cell as a string and is replaced here by
# the loaded network. The guard keeps this cell re-runnable: without it a second
# execution would pass the already-built CellposeModel back in as `model_type`.
# NOTE(review): cellpose also offers `pretrained_model=` for custom model files;
# `model_type=` is kept because that is what this notebook was run with --
# confirm against the installed cellpose version.
if not isinstance(model, models.CellposeModel):
    model = models.CellposeModel(model_type=model, device=device)

# In and out paths based on the upper directory.
# os.path.join avoids the doubled slash when in_path ends with "/", and sorting
# gives a reproducible processing order (glob returns files in arbitrary order).
tif_pattern = os.path.join(in_path, "tif", "*_ch0.tif")
files = sorted(glob(tif_pattern))
if not files:
    # fail here rather than letting the segmentation loop silently do nothing
    raise FileNotFoundError(f"no input files found matching {tif_pattern!r}")
out_path = os.path.join(in_path, "segmentation")

  state_dict = torch.load(filename, map_location=device)


# 2) Segmentation

In [12]:
# make sure the output folders exist (vis lives inside out_path)
vis_dir = f"{out_path}/vis"
os.makedirs(vis_dir, exist_ok=True)

# settings shared by every 3D segmentation call
eval_kwargs = dict(
    do_3D=True,
    diameter=diams,
    min_size=min_size,
    anisotropy=anisotropy,
)

# segment every input stack in turn
for filename in files:

    img = io.imread(filename)
    name = os.path.basename(filename).rsplit(".", 1)[0]
    # base output file name for this stack, handed to the cellpose writers
    out = f"{out_path}/{name}.tif"

    masks, flows, styles = model.eval(img, **eval_kwargs)

    # save the segmentation so it can be loaded in the cellpose GUI
    io.masks_flows_to_seg(img, masks, flows, out, diams)

    # save the masks to disk (tif=True requests TIFF output)
    io.save_masks(img, masks, flows, out, tif=True)

    # quick-look figure: maximum projections along the first axis of the
    # image, the label masks and the first flow output
    img_proj = img.max(axis=0)
    mask_proj = masks.max(axis=0)
    flow_proj = flows[0].max(axis=0)

    fig = plt.figure(figsize=(12, 5))
    plot.show_segmentation(fig, img_proj, mask_proj, flow_proj, channels=chan)
    plt.tight_layout()
    # the figure name is the output base name (including ".tif") plus ".png"
    fig.savefig(f"{vis_dir}/{os.path.basename(out)}.png", dpi=300)
    # close the figure so open figures do not pile up over many files
    plt.close(fig)

100%|██████████████████████████████████████████| 31/31 [00:00<00:00, 425.42it/s]
100%|██████████████████████████████████████████| 31/31 [00:00<00:00, 829.75it/s]
