Skip to content

Commit

Permalink
Update NeurIPS CellSeg - to allow download from zenodo links (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
anwai98 committed Mar 30, 2024
1 parent 73f7021 commit fbb6f6c
Showing 1 changed file with 154 additions and 148 deletions.
302 changes: 154 additions & 148 deletions torch_em/data/datasets/neurips_cell_seg.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,41 @@
import json
import os
import numpy as np
from glob import glob
from typing import Union, Tuple, Any, Optional

import numpy as np
import torch

import torch_em
from . import util
from .. import ImageCollectionDataset, RawImageCollectionDataset, ConcatDataset


# Download link of the zip archive for each data split, all hosted under
# Zenodo record 10719375. The keys are the split names accepted by
# `_download_dataset`.
URL = {
"train": "https://zenodo.org/records/10719375/files/Training-labeled.zip",
"val": "https://zenodo.org/records/10719375/files/Tuning.zip",
"test": "https://zenodo.org/records/10719375/files/Testing.zip",
"unlabeled": "https://zenodo.org/records/10719375/files/train-unlabeled-part1.zip",
"unlabeled_wsi": "https://zenodo.org/records/10719375/files/train-unlabeled-part2.zip"
}

# Expected checksum of each archive, keyed like `URL`; handed to
# `util.download_source` as its `checksum` argument.
# NOTE(review): 64 hex digits, presumably SHA-256 - confirm against `util.download_source`.
CHECKSUM = {
"train": "b2383929eb8e99b2716fa0d4e2f6e03983e626a57cf00fe85175869c54aa3592",
"val": "849423d36bb8fcc2d91a5b189a3b6d93c3d4071c9701eaaa44ba393a510459c4",
"test": "3379730221f43830d30fddf131750e967c9c9bdf04f98811e852a050eb659ccc",
"unlabeled": "390b38b398b05e9e5306a024a3bd48ab22e49592cfab3c1a119eab3636b38e0d",
"unlabeled_wsi": "d1e68eba2918305eab8b846e7578ac14683de970e3fa6a7c2a4a55753be56204"
}

"""TODO: refactor the loader based on the updated data structure
- Training
- images (multi-modal training inputs)
- labels
- unlabeled (WSI)
- Tuning
- images (multi-modal tuning inputs)
- labels
- Testing
- Public
- images (multi-modal testing inputs)
- labels
- WSI (whole-slide testing inputs)
- WSI-labels
- * (results from `osilab` - ranked 1st in the challenge)
- Hidden
- images (multi-modal hidden testing inputs - unlabeled)
- * (results from `osilab` - ranked 1st in the challenge)
"""
URL = "https://drive.google.com/drive/folders/1NFplvkQzc_nHFwpnB55lw2nD6coc91VV"

# Directory (relative to the dataset root) that is expected to exist for each
# split once its archive has been extracted; `_download_dataset` uses its
# presence to decide whether a download is still needed.
# NOTE(review): "unlabeled" maps to "release-part1", which does not match the
# archive name "train-unlabeled-part1.zip" - presumably the folder name inside
# the zip; confirm against the extracted archive.
DIR_NAMES = {
"train": "Training-labeled", "val": "Tuning", "test": "Testing/Public",
"unlabeled": "release-part1", "unlabeled_wsi": "train-unlabeled-part2"
}

# File name (relative to the dataset root) under which the archive of each
# split is stored before it is unpacked.
ZIP_PATH = {
"train": "Training-labeled.zip", "val": "Tuning.zip", "test": "Testing.zip",
"unlabeled": "train-unlabeled-part1.zip", "unlabeled_wsi": "train-unlabeled-part2.zip"
}


def to_rgb(image):
Expand All @@ -40,12 +50,21 @@ def to_rgb(image):
return image


# would be better to make balanced splits for the different data modalities
# (but we would need to know mapping of images to modality)
def _get_image_and_label_paths(root, split, val_fraction):
path = os.path.join(root, "TrainLabeled")
assert os.path.exists(root), "Please download the dataset and assort the data as expected here.\
See `get_neurips_cellseg_supervised_dataset`"
def _download_dataset(root, split, download):
    """Ensure the data of one split is present below `root` and return its folder.

    The archive is fetched and unpacked only when the expected data folder
    does not exist yet; otherwise the existing folder is returned untouched.

    Args:
        root: Folder holding the archives and the extracted data. It is
            created if it does not exist.
        split: Key into `DIR_NAMES`, `ZIP_PATH`, `URL` and `CHECKSUM`.
        download: Forwarded unchanged to `util.download_source`.

    Returns:
        The path `os.path.join(root, DIR_NAMES[split])`.
    """
    os.makedirs(root, exist_ok=True)

    data_dir = os.path.join(root, DIR_NAMES[split])
    archive = os.path.join(root, ZIP_PATH[split])

    # Nothing to do when the extracted folder is already in place.
    if os.path.exists(data_dir):
        return data_dir

    util.download_source(path=archive, url=URL[split], download=download, checksum=CHECKSUM[split])
    # The archive is unpacked into `root`, not into `data_dir`.
    util.unzip(zip_path=archive, dst=root)
    return data_dir


def _get_image_and_label_paths(root, split, download):
path = _download_dataset(root, split, download)

image_folder = os.path.join(path, "images")
assert os.path.exists(image_folder)
Expand All @@ -58,115 +77,94 @@ def _get_image_and_label_paths(root, split, val_fraction):
all_label_paths.sort()
assert len(all_image_paths) == len(all_label_paths)

if split is None:
return all_image_paths, all_label_paths

split_file = os.path.join(
os.path.split(__file__)[0], f"split_{val_fraction}.json"
)

if os.path.exists(split_file):
with open(split_file) as f:
split_ids = json.load(f)[split]
else:
# split into training and val images
n_images = len(all_image_paths)
n_train = int((1.0 - val_fraction) * n_images)
image_ids = list(range(n_images))
np.random.shuffle(image_ids)
train_ids, val_ids = image_ids[:n_train], image_ids[n_train:]
assert len(train_ids) + len(val_ids) == n_images

with open(split_file, "w") as f:
json.dump({"train": train_ids, "val": val_ids}, f)

split_ids = val_ids if split == "val" else train_ids

image_paths = [all_image_paths[idx] for idx in split_ids]
label_paths = [all_label_paths[idx] for idx in split_ids]
assert len(image_paths) == len(label_paths)
return image_paths, label_paths
return all_image_paths, all_label_paths


def get_neurips_cellseg_supervised_dataset(
root, split, patch_shape,
make_rgb=True,
label_transform=None,
label_transform2=None,
raw_transform=None,
transform=None,
label_dtype=torch.float32,
n_samples=None,
sampler=None,
val_fraction=0.1,
root: Union[str, os.PathLike],
split: str,
patch_shape: Tuple[int, int],
make_rgb: bool = True,
label_transform: Optional[Any] = None,
label_transform2: Optional[Any] = None,
raw_transform: Optional[Any] = None,
transform: Optional[Any] = None,
label_dtype: torch.dtype = torch.float32,
n_samples: Optional[int] = None,
sampler: Optional[Any] = None,
download: bool = False,
):
"""Dataset for the segmentation of cells in light microscopy.
This dataset is part of the NeuRIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/.
NOTE:
- The dataset isn't available to download using an in-built functionality
- Please download the dataset from here:\
https://drive.google.com/drive/folders/1NFplvkQzc_nHFwpnB55lw2nD6coc91VV
- REMEMBER: to convert the available data in the expected directory format
This dataset is part of the NeurIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/.
"""
assert split in ("train", "val", None), split
image_paths, label_paths = _get_image_and_label_paths(root, split, val_fraction)
assert split in ("train", "val", "test"), split
image_paths, label_paths = _get_image_and_label_paths(root, split, download)

if raw_transform is None:
trafo = to_rgb if make_rgb else None
raw_transform = torch_em.transform.get_raw_transform(augmentation2=trafo)
if transform is None:
transform = torch_em.transform.get_augmentations(ndim=2)

ds = torch_em.data.ImageCollectionDataset(image_paths, label_paths,
patch_shape=patch_shape,
raw_transform=raw_transform,
label_transform=label_transform,
label_transform2=label_transform2,
label_dtype=label_dtype,
transform=transform,
n_samples=n_samples,
sampler=sampler)
ds = ImageCollectionDataset(
raw_image_paths=image_paths,
label_image_paths=label_paths,
patch_shape=patch_shape,
raw_transform=raw_transform,
label_transform=label_transform,
label_transform2=label_transform2,
label_dtype=label_dtype,
transform=transform,
n_samples=n_samples,
sampler=sampler
)
return ds


def get_neurips_cellseg_supervised_loader(
root, split,
patch_shape, batch_size,
make_rgb=True,
label_transform=None,
label_transform2=None,
raw_transform=None,
transform=None,
label_dtype=torch.float32,
n_samples=None,
sampler=None,
val_fraction=0.1,
root: Union[str, os.PathLike],
split: str,
patch_shape: Tuple[int, int],
batch_size: int,
make_rgb: bool = True,
label_transform: Optional[Any] = None,
label_transform2: Optional[Any] = None,
raw_transform: Optional[Any] = None,
transform: Optional[Any] = None,
label_dtype: torch.dtype = torch.float32,
n_samples: Optional[Any] = None,
sampler: Optional[Any] = None,
download: bool = False,
**loader_kwargs
):
"""Dataloader for the segmentation of cells in light microscopy. See `get_neurips_cellseg_supervised_dataset`."""
ds = get_neurips_cellseg_supervised_dataset(
root, split, patch_shape, make_rgb=make_rgb, label_transform=label_transform,
label_transform2=label_transform2, raw_transform=raw_transform, transform=transform,
label_dtype=label_dtype, n_samples=n_samples, sampler=sampler, val_fraction=val_fraction,
root=root,
split=split,
patch_shape=patch_shape,
make_rgb=make_rgb,
label_transform=label_transform,
label_transform2=label_transform2,
raw_transform=raw_transform,
transform=transform,
label_dtype=label_dtype,
n_samples=n_samples,
sampler=sampler,
download=download
)
return torch_em.segmentation.get_data_loader(ds, batch_size, **loader_kwargs)


def _get_image_paths(root):
path = os.path.join(root, "TrainUnlabeled")
assert os.path.exists(path), "Please download the dataset and assort the data as expected here.\
See `get_neurips_cellseg_unsupervised_dataset`"
def _get_image_paths(root, download):
path = _download_dataset(root, "unlabeled", download)
image_paths = glob(os.path.join(path, "*"))
image_paths.sort()
return image_paths


def _get_wholeslide_paths(root, patch_shape):
path = os.path.join(root, "TrainUnlabeled_WholeSlide")
assert os.path.exists(path), "Please download the dataset and assort the data as expected here.\
See `get_neurips_cellseg_unsupervised_dataset`"
def _get_wholeslide_paths(root, patch_shape, download):
path = _download_dataset(root, "unlabeled_wsi", download)
image_paths = glob(os.path.join(path, "*"))
image_paths.sort()

Expand All @@ -185,24 +183,20 @@ def _get_wholeslide_paths(root, patch_shape):


def get_neurips_cellseg_unsupervised_dataset(
root, patch_shape,
make_rgb=True,
raw_transform=None,
transform=None,
dtype=torch.float32,
sampler=None,
use_images=True,
use_wholeslide=True,
root: Union[str, os.PathLike],
patch_shape: Tuple[int, int],
make_rgb: bool = True,
raw_transform: Optional[Any] = None,
transform: Optional[Any] = None,
dtype: torch.dtype = torch.float32,
sampler: Optional[Any] = None,
use_images: bool = True,
use_wholeslide: bool = True,
download: bool = False,
):
"""Dataset for the segmentation of cells in light microscopy.
This dataset is part of the NeuRIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/.
NOTE:
- The dataset isn't available to download using an in-built functionality
- Please download the dataset from here:\
https://drive.google.com/drive/folders/1NFplvkQzc_nHFwpnB55lw2nD6coc91VV
- REMEMBER: to convert the available data in the expected directory format
This dataset is part of the NeurIPS Cell Segmentation challenge: https://neurips22-cellseg.grand-challenge.org/.
"""
if raw_transform is None:
trafo = to_rgb if make_rgb else None
Expand All @@ -212,40 +206,52 @@ def get_neurips_cellseg_unsupervised_dataset(

datasets = []
if use_images:
image_paths = _get_image_paths(root)
datasets.append(torch_em.data.RawImageCollectionDataset(image_paths,
patch_shape=patch_shape,
raw_transform=raw_transform,
transform=transform,
dtype=dtype,
sampler=sampler))
image_paths = _get_image_paths(root, download)
datasets.append(
RawImageCollectionDataset(
raw_image_paths=image_paths,
patch_shape=patch_shape,
raw_transform=raw_transform,
transform=transform,
dtype=dtype,
sampler=sampler
)
)
if use_wholeslide:
image_paths, n_samples = _get_wholeslide_paths(root, patch_shape)
datasets.append(torch_em.data.RawImageCollectionDataset(image_paths,
patch_shape=patch_shape,
raw_transform=raw_transform,
transform=transform,
dtype=dtype,
n_samples=n_samples,
sampler=sampler))
image_paths, n_samples = _get_wholeslide_paths(root, patch_shape, download)
datasets.append(
RawImageCollectionDataset(
raw_image_paths=image_paths,
patch_shape=patch_shape,
raw_transform=raw_transform,
transform=transform,
dtype=dtype,
n_samples=n_samples,
sampler=sampler
)
)
assert len(datasets) > 0
return torch.utils.data.ConcatDataset(datasets)
return ConcatDataset(*datasets)


def get_neurips_cellseg_unsupervised_loader(
root, patch_shape, batch_size,
make_rgb=True,
raw_transform=None,
transform=None,
dtype=torch.float32,
sampler=None,
use_images=True,
use_wholeslide=True,
root: Union[str, os.PathLike],
patch_shape: Tuple[int, int],
batch_size: int,
make_rgb: bool = True,
raw_transform: Optional[Any] = None,
transform: Optional[Any] = None,
dtype: torch.dtype = torch.float32,
sampler: Optional[Any] = None,
use_images: bool = True,
use_wholeslide: bool = True,
download: bool = False,
**loader_kwargs,
):
"""Dataloader for the segmentation of cells in light microscopy. See `get_neurips_cellseg_unsupervised_dataset`."""
"""Dataloader for the segmentation of cells in light microscopy. See `get_neurips_cellseg_unsupervised_dataset`.
"""
ds = get_neurips_cellseg_unsupervised_dataset(
root, patch_shape, make_rgb=make_rgb, raw_transform=raw_transform, transform=transform,
dtype=dtype, sampler=sampler, use_images=use_images, use_wholeslide=use_wholeslide
root=root, patch_shape=patch_shape, make_rgb=make_rgb, raw_transform=raw_transform, transform=transform,
dtype=dtype, sampler=sampler, use_images=use_images, use_wholeslide=use_wholeslide, download=download
)
return torch_em.segmentation.get_data_loader(ds, batch_size, **loader_kwargs)

0 comments on commit fbb6f6c

Please sign in to comment.