-
Notifications
You must be signed in to change notification settings - Fork 1
Add support for data augmentation via adapter layer to monai image transformation suite #25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
628d417
Add BaseWrapperDataset and MonaiAdapter classes to integrate Monai au…
wli51 91542c3
Modify visualization modules to support wrapper dataset
wli51 8aaa4f4
Add data augmentation example using MONAI transforms and script
wli51 f24389e
Fix index check not considering empty list error
wli51 d0ea53f
Update dataset parameter descriptions to include BaseWrapperDataset i…
wli51 bb186d8
enhance plot prediction logic to re-use inference time input and targ…
wli51 a014360
fix typo
wli51 65ce370
fix typos in transformation example section and consistency note
wli51 8931842
Add dataset wrapper and MONAI augmentation adapter with usage example
wli51 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,240 @@ | ||
| #!/usr/bin/env python | ||
| # coding: utf-8 | ||
|
|
||
| # # Examples for incorporating monai image augmentation suite for training | ||
|
|
||
| # ## Dependencies | ||
|
|
||
| # In[ ]: | ||
|
|
||
|
|
||
| import re | ||
| import pathlib | ||
| from typing import List | ||
|
|
||
| import pandas as pd | ||
| from monai.transforms import ( | ||
| Compose, | ||
| EnsureTyped, | ||
| RandFlipd, | ||
| RandRotate90d, | ||
| RandAffined, | ||
| RandGaussianNoised, | ||
| RandGaussianSmoothd, | ||
| RandAdjustContrastd | ||
| ) | ||
|
|
||
| from virtual_stain_flow.datasets.base_dataset import BaseImageDataset | ||
| from virtual_stain_flow.datasets.crop_dataset import CropImageDataset | ||
| from virtual_stain_flow.datasets.monai_aug_adapter_dataset import MonaiAdapter | ||
| from virtual_stain_flow.transforms.normalizations import MaxScaleNormalize | ||
| from virtual_stain_flow.evaluation.visualization import plot_dataset_grid | ||
|
|
||
|
|
||
| # ## Pathing and Additional utils | ||
|
|
||
| # In[ ]: | ||
|
|
||
|
|
||
# Root folder containing the downloaded CPJUMP1 images;
# change this to wherever the download_data script wrote its output.
DATA_PATH = pathlib.Path("/YOUR/DATA/PATH/")

# Fail fast if the data directory is missing or not a directory.
if not (DATA_PATH.exists() and DATA_PATH.is_dir()):
    raise FileNotFoundError(f"Data path {DATA_PATH} does not exist or is not a directory.")

# Matches CPJUMP1 filenames such as:
#   r01c01f01p01-ch1sk1fk1fl1.tiff
# group(1) -> field prefix (e.g. "r01c01f01p01"), group(2) -> channel number
FIELD_RE = re.compile(
    r"(r\d{2}c\d{2}f\d{2}p01)-ch(\d+)sk1fk1fl1\.tiff$"
)
|
|
||
def _collect_field_prefixes(
    plate_dir: pathlib.Path,
    max_fields: int = 16,
) -> List[str]:
    """
    Scan a JUMP CPJUMP1 plate directory and collect distinct field prefixes.

    Filenames are expected to look like:
        r01c01f01p01-ch1sk1fk1fl1.tiff
    At most ``max_fields`` distinct prefixes are returned, in sorted
    filename order.
    """
    collected: List[str] = []
    for tiff_path in sorted(plate_dir.glob("*.tiff")):
        hit = FIELD_RE.match(tiff_path.name)
        if hit is None:
            continue
        field_prefix = hit.group(1)  # e.g. "r01c01f01p01"
        if field_prefix not in collected:
            collected.append(field_prefix)
            if len(collected) >= max_fields:
                break
    return collected
|
|
||
def build_file_index(
    plate_dir: pathlib.Path,
    max_fields: int = 16,
) -> pd.DataFrame:
    """
    Build a file index relating images across channels and fields/FOVs.

    Each row corresponds to one field prefix and each "ch<N>" column holds
    the path to that field's image for channel N. The result can be supplied
    directly to BaseImageDataset to create a dataset with the correct
    image pairs.

    :param plate_dir: Plate directory to scan for .tiff images.
    :param max_fields: Maximum number of distinct fields to index.
    :return: DataFrame with one row per field and one "ch<N>" column per channel,
        columns sorted for deterministic order.
    :raises ValueError: If no files matching the expected pattern are found.
    """

    fields = _collect_field_prefixes(
        plate_dir,
        max_fields=max_fields,
    )

    file_index_list = []
    for field in fields:
        sample = {}
        # BUGFIX: glob relative to plate_dir (the function's argument) rather
        # than the module-level DATA_PATH global, so the function indexes the
        # directory it was actually asked to index.
        for chan in plate_dir.glob(f"**/{field}*.tiff"):
            match = FIELD_RE.match(chan.name)
            # group(2) is the channel number captured by FIELD_RE
            if match and match.group(2):
                sample[f"ch{match.group(2)}"] = str(chan)
        file_index_list.append(sample)

    file_index = pd.DataFrame(file_index_list)
    # Drop fields for which no channel matched at all
    file_index.dropna(how='all', inplace=True)
    if file_index.empty:
        raise ValueError(f"No files found in {plate_dir} matching the expected pattern.")

    # Channel columns in sorted order for a stable layout
    return file_index.loc[:, sorted(file_index.columns)]
|
|
||
|
|
||
# In[3]:


# For a stable wGAN we want to avoid a dataset so small that the
# discriminator quickly memorizes it and overpowers the generator,
# so a larger subset of CPJUMP1 (brightfield and Hoechst channels)
# is used as the demo dataset.
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1 for details
file_index = build_file_index(DATA_PATH, max_fields=64)
print(file_index.head())


# ## Create dataset from CPJUMP1 and take center crops

# In[4]:


# Brightfield (ch7) is the input and Hoechst (ch5) the target.
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1
# for which channel codes correspond to which channel
dataset = BaseImageDataset(
    file_index=file_index,
    check_exists=True,
    pil_image_mode="I;16",
    input_channel_keys=["ch7"],
    target_channel_keys=["ch5"],
)
print(f"Dataset length: {len(dataset)}")
# NOTE(review): this reaches into the private _target_channel_keys attribute
# while the input side uses the public input_channel_keys — verify whether a
# public target accessor exists on BaseImageDataset.
print(
    f"Input channels: {dataset.input_channel_keys}, target channels: {dataset._target_channel_keys}"
)

# Center-crop the full FOVs into 128x128 patches and normalize
# intensities assuming a 16-bit dynamic range.
cropped_dataset = CropImageDataset.from_base_dataset(
    dataset,
    crop_size=128,
    transforms=MaxScaleNormalize(
        normalization_factor='16bit'
    )
)
plot_dataset_grid(
    dataset=cropped_dataset,
    indices=[0],
    wspace=0.025,
    hspace=0.05
)
|
|
||
|
|
||
| # ## Transformation example | ||
|
|
||
| # In[5]: | ||
|
|
||
|
|
||
# Augmentation pipeline. Geometric transforms (flips / rotations / affine)
# are applied to BOTH input and target so the image pair stays registered;
# intensity perturbations are applied to the input only.
monai_transform = Compose([
    EnsureTyped(keys=["input", "target"]),
    RandFlipd(keys=["input", "target"], prob=0.5, spatial_axis=0),
    RandFlipd(keys=["input", "target"], prob=0.5, spatial_axis=1),
    RandRotate90d(keys=["input", "target"], prob=0.5, max_k=3),
    RandAffined(
        keys=["input", "target"],
        prob=0.7,
        rotate_range=(0.0, 0.0, 0.15),
        translate_range=(0, 0),  # no translation
        scale_range=(0.0, 0.0),  # no scaling
        padding_mode="border",
    ),
    RandGaussianSmoothd(
        keys=["input"],
        prob=0.2,
        sigma_x=(0.25, 0.5),  # fairly aggressive smoothing to mimic out-of-focus frames
        sigma_y=(0.25, 0.5),
    ),
    RandAdjustContrastd(
        keys=["input"],
        prob=0.2,
        gamma=(0.95, 1.05),  # narrow gamma range to avoid unrealistic contrast shifts
        invert_image=False,
        retain_stats=True,
    ),
    RandGaussianNoised(
        keys=["input"],
        prob=0.2,
        mean=0.0,   # zero-mean: adds no intensity bias
        std=1e-4,   # very subtle additive Gaussian noise
    ),
])

# Wrap the cropped dataset so the transform runs on every sample access.
augmented_dataset = MonaiAdapter(cropped_dataset, transform=monai_transform)
|
|
||
|
|
||
| # ## Visualize the same augmented dataset multiple times to see effects of augmentation | ||
| # Note that augmentation is only applied to the crop and the shown full FOV is always un-augmented | ||
|
|
||
| # In[6]: | ||
|
|
||
|
|
||
# Re-plot the same sample several times; every access re-applies the random
# MONAI transforms, so differences between repetitions show the augmentation.
for _ in range(5):  # loop index unused
    plot_dataset_grid(
        dataset=augmented_dataset,
        indices=[0],  # only first sample to better see difference
        wspace=0.025,
        hspace=0.05,
    )
|
|
||
|
|
||
| # ## Use `MonaiAdapter` for training as would with any image dataset | ||
| # | ||
| # e.g. | ||
| # ```python | ||
| # ... | ||
| # | ||
| # # Make train loader from augmented dataset | ||
| # train_loader = DataLoader( | ||
| # augmented_dataset, | ||
| # batch_size=batch_size, | ||
| # shuffle=True, | ||
| # ) | ||
| # ... | ||
| # | ||
| # # feed to trainer | ||
| # trainer = SingleGeneratorTrainer( | ||
| # model=..., | ||
| # optimizer=..., | ||
| # losses=..., | ||
| # loss_weights=..., | ||
| # device='cuda', | ||
| # train_loader=train_loader | ||
| # ) | ||
|
wli51 marked this conversation as resolved.
|
||
| # | ||
| # # optionally, if want to use plot prediction callback | ||
| # plot_callback = PlotPredictionCallback( | ||
| # name="...", | ||
| # dataset=crop_dataset, # non-augmented dataset recommended for consistency | ||
| # # but augment datasets also work here | ||
| # ... | ||
| # ) | ||
| # ``` | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,43 @@ | ||
| """ | ||
| base_wrapper_dataset.py | ||
|
|
||
| Defines a simple BaseWrapperDataset scheme that can wrap any BaseImageDataset | ||
| and forwards all method calls to it. | ||
| """ | ||
|
|
||
| from abc import ABC, abstractmethod | ||
| from typing import Union | ||
|
|
||
| from .base_dataset import BaseImageDataset | ||
| from .crop_dataset import CropImageDataset | ||
|
|
||
class BaseWrapperDataset(ABC):
    """
    Abstract base for datasets that wrap another dataset.

    Stores the wrapped dataset and delegates ``__len__`` to it; subclasses
    implement ``__getitem__`` to post-process the wrapped samples (apply
    transformations, generate crops, cache in RAM, etc.).
    """

    def __init__(
        self,
        dataset: Union[BaseImageDataset, CropImageDataset]
    ):
        # Keep a handle on the wrapped dataset; subclasses may do more here.
        self._dataset = dataset

    def __len__(self):
        # Same length as the wrapped dataset.
        return len(self._dataset)

    @abstractmethod
    def __getitem__(self, idx):
        # Default behavior: pass the wrapped sample through unchanged.
        # Subclasses override this to transform the (input, target) pair.
        x, y = self._dataset[idx]
        return x, y

    @property
    def original(self) -> Union[BaseImageDataset, CropImageDataset]:
        """
        Access the original (innermost) underlying dataset, unwrapping any
        nested wrappers, e.g. for metadata access.
        """
        inner = self._dataset
        if isinstance(inner, BaseWrapperDataset):
            return inner.original
        return inner
41 changes: 41 additions & 0 deletions
41
src/virtual_stain_flow/datasets/monai_aug_adapter_dataset.py
|
wli51 marked this conversation as resolved.
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,41 @@ | ||
| """ | ||
| monai_aug_adapter_dataset.py | ||
| """ | ||
|
|
||
| from monai.transforms import ( | ||
| Compose | ||
| ) | ||
|
|
||
| from .base_dataset import BaseImageDataset | ||
| from .base_wrapper_dataset import BaseWrapperDataset | ||
|
|
||
class MonaiAdapter(BaseWrapperDataset):
    """
    Wrapper that makes a BaseImageDataset usable with MONAI transforms.

    Each sample of the wrapped dataset is packed into a dictionary with
    keys "input" and "target", run through the provided MONAI ``Compose``
    transform (if any), and unpacked back into an ``(input, target)`` tuple.
    Without a transform the adapter is a pure no-op wrap/unwrap, so it is
    only useful when a MONAI transform is supplied.
    """

    def __init__(
        self,
        base_dataset: BaseImageDataset,
        transform: Compose | None = None
    ):
        super().__init__(base_dataset)
        # Optional MONAI dictionary transform applied per sample.
        self._transform = transform

    def __len__(self):
        # Delegate to the wrapped dataset.
        return len(self._dataset)

    def __getitem__(self, idx):
        x, y = self._dataset[idx]

        # MONAI dictionary transforms operate on dicts keyed by tensor role.
        sample = {"input": x, "target": y}
        if self._transform is not None:
            sample = self._transform(sample)

        # Unpack back to the (input, target) tuple convention.
        return sample["input"], sample["target"]
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.