25 changes: 25 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,31 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

---

## [0.4.4] - 2026-04-23

### Added

#### Dataset wrapper and MONAI augmentation adapter (`virtual_stain_flow/datasets/`):

Introduces a lightweight wrapper abstraction for dataset composition and a MONAI-compatible adapter for dictionary-based augmentation pipelines. This enables augmentation workflows to be layered on top of existing dataset implementations without modifying core dataset logic.

- **`BaseWrapperDataset`** (`base_wrapper_dataset.py`): Abstract wrapper base class that forwards dataset access to an underlying dataset instance and provides recursive access to the original base dataset via the `original` property. Establishes a reusable pattern for composing dataset behaviors (e.g., augmentation, caching, preprocessing) while preserving compatibility with existing dataset APIs.
- **`MonaiAdapter`** (`monai_aug_adapter_dataset.py`): Wrapper dataset that adapts `(input, target)` tuple samples into MONAI dictionary format (`{"input": ..., "target": ...}`), applies optional MONAI `Compose` transforms, and returns transformed samples back as `(input, target)` tuples for trainer compatibility.
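
A minimal usage sketch of both additions (`base_dataset` stands in for any existing `BaseImageDataset`/`CropImageDataset` instance):

```python
from monai.transforms import Compose, RandFlipd

from virtual_stain_flow.datasets.monai_aug_adapter_dataset import MonaiAdapter

# Wrap an existing dataset; samples stay (input, target) tuples
augmented = MonaiAdapter(
    base_dataset,
    transform=Compose([RandFlipd(keys=["input", "target"], prob=0.5, spatial_axis=0)]),
)
x, y = augmented[0]        # transformed sample, still a tuple
base = augmented.original  # recursive access to the unwrapped base dataset
```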

#### MONAI augmentation usage example (`examples/`):

- Added/updated **`4.data_augmentation_example.ipynb`** demonstrating:
- construction of a base dataset and crop dataset,
- application of MONAI dictionary transforms through `MonaiAdapter`,
- visualization of repeated stochastic augmentations,
- integration of the augmented dataset into a standard training dataloader/trainer workflow.

### Refactored

#### Visualization suite to support `BaseWrapperDataset`

Visualization utilities such as `plot_dataset_grid` now accept wrapped datasets, unwrapping to the underlying base dataset via the `original` property where full-FOV context is needed.

---

## [0.4.3] - 2025-12-16

### Added
461 changes: 461 additions & 0 deletions examples/4.data_augmentation_example.ipynb

Large diffs are not rendered by default.

240 changes: 240 additions & 0 deletions examples/nbconverted/4.data_augmentation_example.py
@@ -0,0 +1,240 @@
#!/usr/bin/env python
# coding: utf-8

# # Examples of incorporating the MONAI image augmentation suite into training

# ## Dependencies

# In[ ]:


import re
import pathlib
from typing import List

import pandas as pd
from monai.transforms import (
Compose,
EnsureTyped,
RandFlipd,
RandRotate90d,
RandAffined,
RandGaussianNoised,
RandGaussianSmoothd,
RandAdjustContrastd
)

from virtual_stain_flow.datasets.base_dataset import BaseImageDataset
from virtual_stain_flow.datasets.crop_dataset import CropImageDataset
from virtual_stain_flow.datasets.monai_aug_adapter_dataset import MonaiAdapter
from virtual_stain_flow.transforms.normalizations import MaxScaleNormalize
from virtual_stain_flow.evaluation.visualization import plot_dataset_grid


# ## Pathing and Additional utils

# In[ ]:


DATA_PATH = pathlib.Path("/YOUR/DATA/PATH/") # Change to where the download_data script outputs data

# Sanity check for data existence
if not DATA_PATH.exists() or not DATA_PATH.is_dir():
raise FileNotFoundError(f"Data path {DATA_PATH} does not exist or is not a directory.")

# Matches filenames like:
# r01c01f01p01-ch1sk1fk1fl1.tiff
FIELD_RE = re.compile(
r"(r\d{2}c\d{2}f\d{2}p01)-ch(\d+)sk1fk1fl1\.tiff$"
)

def _collect_field_prefixes(
plate_dir: pathlib.Path,
max_fields: int = 16,
) -> List[str]:
"""
Scan a JUMP CPJUMP1 plate directory and collect distinct field prefixes.
    Expects image filenames like:
        r01c01f01p01-ch1sk1fk1fl1.tiff
"""
prefixes: List[str] = []
for path in sorted(plate_dir.glob("*.tiff")):
m = FIELD_RE.match(path.name)
if not m:
continue
prefix = m.group(1) # e.g. "r01c01f01p01"
if prefix not in prefixes:
prefixes.append(prefix)
if len(prefixes) >= max_fields:
break
return prefixes

def build_file_index(
plate_dir: pathlib.Path,
max_fields: int = 16,
) -> pd.DataFrame:
"""
    Helper function to build a file index that specifies
    the relationship of images across channels and fields of view (FOVs).
The result can directly be supplied to BaseImageDataset to create a
dataset with the correct image pairs.
"""

fields = _collect_field_prefixes(
plate_dir,
max_fields=max_fields,
)

file_index_list = []
for field in fields:
sample = {}
        for chan in plate_dir.glob(f"**/{field}*.tiff"):
match = FIELD_RE.match(chan.name)
            if match and match.group(2):
                sample[f"ch{match.group(2)}"] = str(chan)

file_index_list.append(sample)

file_index = pd.DataFrame(file_index_list)
file_index.dropna(how='all', inplace=True)
if file_index.empty:
raise ValueError(f"No files found in {plate_dir} matching the expected pattern.")

return file_index.loc[:, sorted(file_index.columns)]


# In[3]:


# For stable WGAN training, we don't want the dataset to be so small that the
# discriminator quickly memorizes it and overpowers the generator.
# So a larger, 2048 FOV subset of CPJUMP1 (brightfield and Hoechst channels) is used here as the demo dataset.
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1 for details
file_index = build_file_index(DATA_PATH, max_fields=64)
print(file_index.head())


# ## Create dataset from CPJUMP1 and take center crops

# In[4]:


# Create a dataset with Brightfield as input and Hoechst as target
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1
# for which channel codes correspond to which channel
dataset = BaseImageDataset(
file_index=file_index,
check_exists=True,
pil_image_mode="I;16",
input_channel_keys=["ch7"],
target_channel_keys=["ch5"],
)
print(f"Dataset length: {len(dataset)}")
print(
f"Input channels: {dataset.input_channel_keys}, target channels: {dataset._target_channel_keys}"
)

cropped_dataset = CropImageDataset.from_base_dataset(
dataset,
crop_size=128,
transforms=MaxScaleNormalize(
normalization_factor='16bit'
)
)
plot_dataset_grid(
dataset=cropped_dataset,
indices=[0],
wspace=0.025,
hspace=0.05
)


# ## Transformation example

# In[5]:


monai_transform = Compose([
EnsureTyped(keys=["input", "target"]),
RandFlipd(keys=["input", "target"], prob=0.5, spatial_axis=0),
RandFlipd(keys=["input", "target"], prob=0.5, spatial_axis=1),
RandRotate90d(keys=["input", "target"], prob=0.5, max_k=3),
RandAffined(
keys=["input", "target"],
prob=0.7,
rotate_range=(0.0, 0.0, 0.15),
translate_range=(0, 0), # no translate
scale_range=(0.0, 0.0), # no scale
padding_mode="border",
),
RandGaussianSmoothd(
keys=["input"],
prob=0.2,
sigma_x=(0.25, 0.5), # more aggressive smoothing to simulate out-of-focus
sigma_y=(0.25, 0.5),
),
RandAdjustContrastd(
keys=["input"],
prob=0.2,
gamma=(0.95, 1.05), # small variation to avoid unrealistic contrast change
invert_image=False,
retain_stats=True,
),
RandGaussianNoised(
keys=["input"],
prob=0.2,
mean=0.0, # no bias
std=1e-4, # subtle salt and pepper
),
])

augmented_dataset = MonaiAdapter(cropped_dataset, transform=monai_transform)
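

# ## Optional sanity check of stochastic augmentation
# A minimal sketch (assuming the samples are tensor-like): fetching the same index twice
# should usually yield different augmented inputs.

# In[ ]:


import torch

x1, _ = augmented_dataset[0]
x2, _ = augmented_dataset[0]
# With per-transform prob < 1.0 both draws can occasionally be identical,
# so treat this as a heuristic check rather than a guarantee
print("identical draws:", torch.equal(torch.as_tensor(x1), torch.as_tensor(x2)))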


# ## Visualize the same augmented dataset multiple times to see effects of augmentation
# Note that augmentation is applied only to the crop; the full FOV shown is always unaugmented

# In[6]:


for i in range(5):
plot_dataset_grid(
dataset=augmented_dataset,
indices=[0], # only first sample to better see difference
wspace=0.025,
hspace=0.05
)


# ## Use `MonaiAdapter` for training as you would with any image dataset
#
# e.g.
# ```python
# ...
#
# # Make a train loader from the augmented dataset
# train_loader = DataLoader(
# augmented_dataset,
# batch_size=batch_size,
# shuffle=True,
# )
# ...
#
# # feed to trainer
# trainer = SingleGeneratorTrainer(
# model=...,
# optimizer=...,
# losses=...,
# loss_weights=...,
# device='cuda',
# train_loader=train_loader
# )
#
# # optionally, if you want to use the plot prediction callback
# plot_callback = PlotPredictionCallback(
# name="...",
# dataset=cropped_dataset, # non-augmented dataset recommended for consistency
# # but augmented datasets also work here
# ...
# )
# ```
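
# A concrete, minimal version of the loader step above (a sketch; the batch size of 4 is illustrative):

# In[ ]:


from torch.utils.data import DataLoader

# MonaiAdapter is a standard map-style dataset, so it plugs directly into a DataLoader
train_loader = DataLoader(augmented_dataset, batch_size=4, shuffle=True)
batch_x, batch_y = next(iter(train_loader))
print(batch_x.shape, batch_y.shape)
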
43 changes: 43 additions & 0 deletions src/virtual_stain_flow/datasets/base_wrapper_dataset.py
@@ -0,0 +1,43 @@
"""
base_wrapper_dataset.py

Defines a simple BaseWrapperDataset scheme that wraps any BaseImageDataset
and forwards dataset access to it.
"""

from abc import ABC, abstractmethod
from typing import Union

from .base_dataset import BaseImageDataset
from .crop_dataset import CropImageDataset

class BaseWrapperDataset(ABC):

def __init__(
self,
dataset: Union[BaseImageDataset, CropImageDataset]
):
self._dataset = dataset
# optionally do something to the dataset

def __len__(self):
return len(self._dataset)

@abstractmethod
def __getitem__(self, idx):
# retrieve images from dataset
input, target = self._dataset[idx]

# do something to the input and target here
# (e.g. apply transformations, generate crops, cache in RAM, etc.)

return input, target

@property
def original(self) -> Union[BaseImageDataset, CropImageDataset]:
"""
Access the original underlying dataset for metadata etc.
"""
if isinstance(self._dataset, BaseWrapperDataset):
return self._dataset.original
return self._dataset
41 changes: 41 additions & 0 deletions src/virtual_stain_flow/datasets/monai_aug_adapter_dataset.py
@@ -0,0 +1,41 @@
"""
monai_aug_adapter_dataset.py

Defines MonaiAdapter, a wrapper dataset that applies MONAI dictionary
transforms to (input, target) samples from an underlying dataset.
"""

from monai.transforms import (
Compose
)

from .base_dataset import BaseImageDataset
from .base_wrapper_dataset import BaseWrapperDataset

class MonaiAdapter(BaseWrapperDataset):
"""
    Adapter dataset that wraps any BaseImageDataset and returns samples
    in a dictionary format compatible with MONAI transforms and pipelines.
    Specifically, each sample is packed into a dictionary with keys "input"
    and "target" containing the input and target tensors respectively,
    run through the MONAI transforms if provided, and finally returned
    as a tuple of (input, target) tensors. Using this adapter without any
    MONAI transforms is pointless, as the data is just wrapped and unwrapped.
"""
def __init__(
self,
base_dataset: BaseImageDataset,
transform: Compose | None = None
):
super().__init__(base_dataset)
self._transform = transform

def __len__(self):
return len(self._dataset)

def __getitem__(self, idx):
x, y = self._dataset[idx]

sample = {"input": x, "target": y}

if self._transform is not None:
sample = self._transform(sample)

return sample["input"], sample["target"]