# Datasets

## ISBISegment

In [None]:
# Preview one ISBI-2012 sample: the original (un-extrapolated) region, the
# segmentation target, and the extrapolated image that is fed to the model.
import matplotlib.pyplot as plt

from src.datasets import ISBISegment

ds = ISBISegment("./datasets/ISBI-2012-challenge/", trim_seg=True)

sample = ds[0]
img = sample["img"]
seg = sample["seg"]

# subject_region is used as [x0, y0, x1, y1]: items 1/3 slice the second-to-last
# axis (rows) and items 0/2 slice the last axis (columns).
region = ds.subject_region
orig_img = img[..., region[1] : region[3], region[0] : region[2]]

for label, tensor in (("Image", img), ("Segmentation", seg)):
    print(f"{label} shape: {tensor.shape}, dtype: {tensor.dtype}")
print(f"Sample count: {len(ds)}")

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    (orig_img, {}, "Original image"),
    (seg, {"cmap": "gray"}, "Segmentation (target)"),
    (img, {}, "Extrapolated image (model input)"),
)
for axis, (tensor, imshow_kwargs, title) in zip(ax, panels):
    # Move the leading axis last (C, H, W -> H, W, C) for imshow.
    axis.imshow(tensor.numpy().transpose(1, 2, 0), **imshow_kwargs)
    axis.set_title(title)

plt.show()

## ISBICellTrack (PhC-C2DH-U373 / DIC-C2DH-HeLa)

In [None]:
# Preview one cell-tracking sample: the original (un-extrapolated) region, the
# segmentation target, and the extrapolated image that is fed to the model.
import matplotlib.pyplot as plt

from src.datasets import ISBICellTrack

ds = ISBICellTrack("./datasets/PhC-C2DH-U373")  # or "datasets/DIC-C2DH-HeLa"

sample = ds[0]
img = sample["img"]
seg = sample["seg"]

# subject_region is used as [x0, y0, x1, y1]: items 1/3 slice the second-to-last
# axis (rows) and items 0/2 slice the last axis (columns).
region = ds.subject_region
orig_img = img[..., region[1] : region[3], region[0] : region[2]]

for label, tensor in (("Image", img), ("Segmentation", seg)):
    print(f"{label} shape: {tensor.shape}, dtype: {tensor.dtype}")
print(f"Sample count: {len(ds)}")

fig, ax = plt.subplots(1, 3, figsize=(12, 4))
panels = (
    (orig_img, {}, "Original image"),
    (seg, {"cmap": "gray"}, "Segmentation (target)"),
    (img, {}, "Extrapolated image (model input)"),
)
for axis, (tensor, imshow_kwargs, title) in zip(ax, panels):
    # Move the leading axis last (C, H, W -> H, W, C) for imshow.
    axis.imshow(tensor.numpy().transpose(1, 2, 0), **imshow_kwargs)
    axis.set_title(title)

plt.show()

# Model

In [None]:
# Smoke-test the UNet forward pass on a dummy batch and display the logits shape.
import torch

from src.learner import UNet

model = UNet()
# torch.Tensor(*sizes) allocates *uninitialized* memory (may contain NaN/inf),
# so use a well-defined random input of the same shape instead.
img = torch.rand(8, 1, 572, 572)
batch = {"img": img}
# Only the output shape is inspected, so skip recording an autograd graph for
# 8 full-size images.
with torch.no_grad():
    out = model(batch)
out["logits"].shape

# Loss

In [None]:
# Compute the SegmentCrossEntropy loss of an untrained UNet on a single sample.
from src.datasets import ISBISegment
from src.learner import UNet
from src.loss import SegmentCrossEntropy

sample_id = 0
# NOTE(review): hard-coded device index 1 — presumably the second GPU; this
# will not work on a machine without that device. Adjust as needed.
device = 1

ds = ISBISegment("./datasets/ISBI-2012-challenge/")

model = UNet()
model.set_devices([device])

# NOTE(review): [92, 92, 480, 480] looks like the region of the input over
# which the loss is computed — TODO confirm against SegmentCrossEntropy.
loss_fn = SegmentCrossEntropy(device, [92, 92, 480, 480])

sample = ds[sample_id]
model_out = model(sample)
loss = loss_fn(model_out, sample)

loss

# Warp

In [None]:
# Visualize the flow produced by get_9_pt_flow and its effect on one
# non-augmented sample.
import matplotlib.pyplot as plt
from src.util.data import get_9_pt_flow, warp_image
from src.util.visualize import flow2rgb
from src.datasets import ISBISegment

ds = ISBISegment("./datasets/ISBI-2012-challenge/", do_aug=False)
sample = ds[0]
# Move the leading axis last (C, H, W -> H, W, C) for warping and plotting.
img = sample["img"].numpy().transpose(1, 2, 0)

# Flow field sized for a 572x572 image; also can use random_warp.
flow = get_9_pt_flow((572, 572), std=10)
warped_img = warp_image(img, flow)

fig, ax = plt.subplots(1, 3, figsize=(12, 4))

pictures = (img.squeeze(), flow2rgb(flow), warped_img.squeeze())
titles = ("Original image (extrapolated)", "Warping flow", "Warped image")
for axis, picture, title in zip(ax, pictures, titles):
    axis.imshow(picture)
    axis.set_title(title)

plt.show()

# Evaluator

In [None]:
# Smoke-test SegmentationEvaluator on the first few batches produced by an
# untrained UNet.
from itertools import islice

import torch
from mt_pipe.src.evaluators import SegmentationEvaluator
from src.datasets import ISBISegment
from src.learner import UNet
from torch.utils.data import DataLoader

# NOTE(review): hard-coded device index 1 — presumably the second GPU; adjust
# on machines where that device does not exist.
device = 1
num_batches = 3  # same three batches as the original `if i >= 2: break`

ds = ISBISegment("./datasets/ISBI-2012-challenge/")
dl = DataLoader(ds, 4)  # batch_size=4
model = UNet()
model.set_devices([device])
# Inference mode: puts any dropout / batch-norm layers into eval behavior.
model.eval()
evaluator = SegmentationEvaluator()
results = []
# Evaluation needs no gradients; avoid building autograd graphs per batch.
with torch.no_grad():
    for batch in islice(dl, num_batches):
        info = model(batch)
        results.append(evaluator.process_batch(batch, info))

evaluator.output(results)
