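## Segmentation using StarDist

Segment 2D patches of Harvard sample A06 (channels Ch1 and Ch13) with the pretrained StarDist2D `2D_versatile_fluo` model.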

#### 1. Setup & Data loading

In [19]:
from torch.utils.data import DataLoader

from cellgroup.configs import DataConfig
from cellgroup.data import InMemoryDataset
from cellgroup.data.datasets.harvard import SampleHarvard, ChannelHarvard, get_fnames
from cellgroup.data.utils import in_memory_collate_fn

In [20]:
# DataLoader settings stored on the config and read back when building the loader.
dloader_settings = {
    "num_workers": 0,
    "collate_fn": in_memory_collate_fn,
}

# Patch-based 2D dataset configuration for sample A06, channels Ch1 and Ch13.
dset_config = DataConfig(
    samples=[SampleHarvard.A06],
    channels=[ChannelHarvard.Ch1, ChannelHarvard.Ch13],
    time_steps=(32, 42, 2),  # NOTE(review): presumably (start, stop, step) — confirm in DataConfig
    img_dim="2D",
    patch_size=(256, 256),
    patch_overlap=(128, 128),
    batch_size=1,
    dloader_kwargs=dloader_settings,
)

In [21]:
import os

# Root folder of the Cellgroup data. Set CELLGROUP_DATA_DIR to point elsewhere
# instead of editing this cell; the original cluster path is the fallback.
DATA_DIR = os.environ.get("CELLGROUP_DATA_DIR", "/group/jug/federico/data/Cellgroup")

# Load the configured samples/channels into memory as a patch dataset.
dset = InMemoryDataset(
    data_dir=DATA_DIR,
    data_config=dset_config,
    get_fnames_fn=get_fnames,
)

In [22]:
# Unshuffled loader so patches come out in dataset order (needed to keep
# predictions aligned with their patch metadata). A missing `num_workers`
# falls back to 0, because DataLoader errors on None. A missing `collate_fn`
# falls back to None, i.e. PyTorch's default collation.
dloader = DataLoader(
    dset,
    batch_size=dset_config.batch_size,
    shuffle=False,
    num_workers=dset_config.dloader_kwargs.get("num_workers", 0),
    collate_fn=dset_config.dloader_kwargs.get("collate_fn"),
)

In [23]:
# Grab only the first batch to sanity-check shapes and values.
loader_iter = iter(dloader)
batch = next(loader_iter)
del loader_iter

In [24]:
# Split the batch into the patch array and its metadata, then print the
# array's shape/mean/std and the first metadata entry as a sanity check.
patches, tinfos = batch
shape_mean_std = (patches.shape, patches.mean(), patches.std())
print(*shape_mean_std)
first_info = tinfos[0]
print(len(tinfos), first_info)

(1, 256, 256) -0.994490284538771 0.005252108658092157
1 {<Axis.N>: array(<Sample.A06>, dtype=object), <Axis.C>: array(<Channel.Ch1>, dtype=object), <Axis.T>: 0, <Axis.P>: array(PatchInfo(array_shape=(6456, 6380), last_tile=False, overlap_crop_coords=((0, 192), (0, 192)), stitch_coords=((0, 192), (0, 192))),
      dtype=object)}


#### 2. Model setup

In [25]:
from stardist.models import StarDist2D

In [26]:
# Registered name of the pretrained StarDist2D model to download/load.
STARDIST_PRETRAINED_NAME = "2D_versatile_fluo"
model = StarDist2D.from_pretrained(STARDIST_PRETRAINED_NAME)

Found model '2D_versatile_fluo' for 'StarDist2D'.
Loading network weights from 'weights_best.h5'.
Loading thresholds from 'thresholds.json'.
Using default values: prob_thresh=0.479071, nms_thresh=0.3.


#### 3. Get predictions

In [27]:
from tqdm import tqdm
from csbdeep.utils import normalize # possibly needed

In [None]:
def get_stardist_predictions(model: StarDist2D, dloader: DataLoader, **predict_kwargs):
    """Run StarDist instance segmentation on every patch yielded by `dloader`.

    Parameters
    ----------
    model : StarDist2D
        StarDist2D model used for inference.
    dloader : DataLoader
        Loader yielding `(patches, pinfos)` batches. `patches` is assumed to
        stack 2D patches along its first axis, shape `(B, Y, X)` (the batch
        inspected above is `(1, 256, 256)`). `pinfos` is assumed to hold one
        metadata entry per patch — TODO confirm for batch_size > 1.
    **predict_kwargs
        Extra keyword arguments forwarded to `model.predict_instances`
        (e.g. `normalizer`, `prob_thresh`, `nms_thresh`, `n_tiles`).

    Returns
    -------
    pred_patches : list
        One label image per patch: the `labels` element of the
        `(labels, details)` tuple returned by `predict_instances`.
    patch_infos : list
        Patch metadata entries, aligned one-to-one with `pred_patches`.

    Notes
    -----
    NOTE(review): pretrained StarDist models are usually applied to
    percentile-normalized images (e.g. csbdeep `normalize(img, 1, 99.8)`).
    The batch inspected above has mean ≈ -0.99 and std ≈ 0.005, so consider
    passing a `normalizer` through `predict_kwargs` or normalizing upstream.
    """
    pred_patches = []
    patch_infos = []
    for patches, pinfos in tqdm(dloader):
        # Predict each YX patch of the batch separately. `squeeze()` only
        # produced a 2D image when the batch size was 1.
        for patch, pinfo in zip(patches, pinfos):
            # `predict_instances` returns (labels, details); keep only the labels.
            labels, _details = model.predict_instances(patch, axes="YX", **predict_kwargs)
            pred_patches.append(labels)
            patch_infos.append(pinfo)
    return pred_patches, patch_infos

In [29]:
# Segment every patch in the loader (long-running: one StarDist call per patch).
stardist_outputs = get_stardist_predictions(model=model, dloader=dloader)
segmented_patches, patch_infos = stardist_outputs

  0%|          | 0/24500 [00:00<?, ?it/s]


ValueError: axes (SYX) must be of length 2.

In [None]:
import matplotlib.pyplot as plt

# Build the example from variables defined in earlier cells (`patches`,
# `model`) so this cell runs on a fresh kernel: take the first inspected
# patch and segment it with the same call used for the full prediction.
sample_patch = patches[0]
labels, _details = model.predict_instances(sample_patch, axes="YX")

fig, (ax_img, ax_lbl) = plt.subplots(1, 2, figsize=(12, 5))
ax_img.imshow(sample_patch, cmap="gray")
ax_img.set_title("Input patch")
ax_img.axis("off")
ax_lbl.imshow(labels)
ax_lbl.set_title("StarDist instance labels")
ax_lbl.axis("off")
plt.show()
