## Segmentation using StarDist

#### 1. Setup & Data loading

In [1]:
import os

from torch.utils.data import DataLoader

from cellgroup.configs import DataConfig
from cellgroup.data import InMemoryDataset
from cellgroup.data.datasets.harvard import SampleHarvard, ChannelHarvard, get_fnames
from cellgroup.data.utils import in_memory_collate_fn

In [2]:
# Data selection: sample A06, channels Ch1 and Ch13, 2D images, and the time
# points described by (32, 42, 2) — presumably (start, stop, step); TODO confirm
# against DataConfig.
dset_config = DataConfig(
    img_dim="2D",
    samples=[SampleHarvard.A06],
    channels=[ChannelHarvard.Ch1, ChannelHarvard.Ch13],
    time_steps=(32, 42, 2),
    # Tiling: 256x256 patches overlapping by 128 pixels along each axis.
    patch_size=(256, 256),
    patch_overlap=(128, 128),
    # Loader settings, forwarded to the DataLoader cell below: one patch per
    # batch, no worker processes, and the project's in-memory collate function.
    batch_size=1,
    dloader_kwargs=dict(
        num_workers=0,
        collate_fn=in_memory_collate_fn,
    ),
)

In [3]:
# Root of the Cellgroup data. Set the CELLGROUP_DATA_DIR environment variable
# to run on another machine; otherwise the original cluster path is used.
dset = InMemoryDataset(
    data_dir=os.environ.get("CELLGROUP_DATA_DIR", "/group/jug/federico/data/Cellgroup"),
    data_config=dset_config,
    get_fnames_fn=get_fnames,
)

In [4]:
# Wrap the dataset in a PyTorch DataLoader using the batching, worker and
# collate settings from the data config. shuffle=False keeps patches in
# dataset order.
dloader = DataLoader(
    dset,
    batch_size=dset_config.batch_size,
    shuffle=False,
    # Fall back to PyTorch's default (0 workers) when the config omits the key;
    # passing None through would make DataLoader raise.
    num_workers=dset_config.dloader_kwargs.get("num_workers", 0),
    # None selects PyTorch's default collate function.
    collate_fn=dset_config.dloader_kwargs.get("collate_fn"),
)

In [5]:
# Take only the first batch from a fresh iterator so its contents can be
# inspected in the next cell (raises StopIteration if the loader is empty).
batch = next(iter(dloader))

In [6]:
# A batch is a (patches, patch-info list) pair.
patches, tinfos = batch

# Shape of the patch batch plus its intensity mean and standard deviation.
print(
    patches.shape,
    patches.mean(),
    patches.std(),
    sep=" ",
)
# Number of patch-info records, followed by the first record.
print(len(tinfos), tinfos[0], sep=" ")

(1, 256, 256) -0.994490284538771 0.005252108658092157
1 {<Axis.N>: array(<Sample.A06>, dtype=object), <Axis.C>: array(<Channel.Ch1>, dtype=object), <Axis.T>: 0, <Axis.P>: array(PatchInfo(array_shape=(6456, 6380), last_tile=False, overlap_crop_coords=((0, 192), (0, 192)), stitch_coords=((0, 192), (0, 192))),
      dtype=object)}


#### 2. Model setup

In [7]:
from stardist.models import StarDist2D

2024-12-13 15:17:45.483873: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-12-13 15:17:45.483914: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-12-13 15:17:45.483932: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-12-13 15:17:45.489713: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [8]:
# Pretrained StarDist model to load; StarDist restores its weights and the
# stored prob/nms thresholds.
PRETRAINED_MODEL_NAME = "2D_versatile_fluo"

model = StarDist2D.from_pretrained(PRETRAINED_MODEL_NAME)
# from_pretrained only prints an error and returns None for an unknown name;
# fail here instead of with an AttributeError at prediction time.
if model is None:
    raise RuntimeError(
        f"No pretrained StarDist2D model named {PRETRAINED_MODEL_NAME!r} was found."
    )

Found model '2D_versatile_fluo' for 'StarDist2D'.


2024-12-13 15:17:47.251505: I tensorflow/compiler/xla/stream_executor/cuda/cuda_gpu_executor.cc:894] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero. See more at https://github.com/torvalds/linux/blob/v6.0/Documentation/ABI/testing/sysfs-bus-pci#L344-L355
2024-12-13 15:17:47.253369: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2211] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.


#### 3. Get predictions

In [None]:
from tqdm import tqdm
from csbdeep.utils import normalize # possibly needed

In [None]:
def get_stardist_predictions(model: StarDist2D, dloader: DataLoader, norm_percentiles=None):
    """Run StarDist instance segmentation on every batch served by ``dloader``.

    Parameters
    ----------
    model : StarDist2D
        Loaded StarDist model; ``predict_instances`` is called once per batch.
    dloader : DataLoader
        Loader yielding ``(patches, patch_infos)`` batches. Each batch must
        squeeze to a single 2D (Y, X) image, i.e. batch size 1 with one
        channel per patch.
    norm_percentiles : (float, float) or None, optional
        If given as ``(low, high)``, each patch is percentile-normalized with
        ``csbdeep.utils.normalize(patch, low, high)`` before prediction.
        StarDist's pretrained models expect normalized input, and ``(1, 99.8)``
        is the setting used in the StarDist examples. Normalization is per
        patch, so near-empty patches get their noise stretched. ``None``
        (default) feeds patches unchanged.

    Returns
    -------
    pred_patches : list
        One ``model.predict_instances`` result per batch (per the StarDist
        docs, a ``(labels, details)`` tuple).
    patch_infos : list
        The patch-info object of each batch, in the same order.

    Raises
    ------
    ValueError
        If a batch does not squeeze to a 2D image.
    """
    pred_patches = []
    patch_infos = []
    for patches, pinfos in tqdm(dloader):
        # Drop singleton batch/channel axes: StarDist is called with axes="YX".
        patch = patches.squeeze()
        if patch.ndim != 2:
            raise ValueError(
                "get_stardist_predictions expects batches that squeeze to a 2D "
                f"(Y, X) image, got batch shape {tuple(patches.shape)}; "
                "use batch_size=1 with single-channel patches."
            )
        if norm_percentiles is not None:
            low, high = norm_percentiles
            patch = normalize(patch, low, high)
        preds = model.predict_instances(patch, axes="YX")
        pred_patches.append(preds)
        patch_infos.append(pinfos)
    return pred_patches, patch_infos

In [None]:
# Segment every batch served by the loader. Predictions and patch infos come
# back as two parallel lists with one entry per batch, in loader order.
segmented_patches, patch_infos = get_stardist_predictions(
    model=model,
    dloader=dloader,
)

In [None]:
import matplotlib.pyplot as plt

# Visual check: the first input patch next to its StarDist label image.
# The loader is not shuffled, so `patches` (the first batch, inspected above)
# is the input behind `segmented_patches[0]`.
sample_patch = patches.squeeze()  # same squeeze the prediction loop applies
labels = segmented_patches[0][0]  # predict_instances returns (labels, details)

fig, ax = plt.subplots(1, 2, figsize=(12, 5))
ax[0].imshow(sample_patch)
ax[0].set_title("Input patch")
ax[1].imshow(labels)
ax[1].set_title("StarDist labels")
plt.show()
