Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion monai/apps/deepgrow/interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import torch

from monai.data import decollate_batch, list_data_collate
from monai.engines import SupervisedEvaluator, SupervisedTrainer
from monai.engines.utils import IterationEvents
from monai.transforms import Compose
Expand Down Expand Up @@ -74,6 +75,9 @@ def __call__(self, engine: Union[SupervisedTrainer, SupervisedEvaluator], batchd
batchdata[self.key_probability] = torch.as_tensor(
([1.0 - ((1.0 / self.max_interactions) * j)] if self.train else [1.0]) * len(inputs)
)
batchdata = self.transforms(batchdata)
# decollate batch data to execute click transforms
batchdata_list = [self.transforms(i) for i in decollate_batch(batchdata, detach=True)]
# collate list into a batch for next round interaction
batchdata = list_data_collate(batchdata_list)

return engine._iteration(engine, batchdata)
103 changes: 26 additions & 77 deletions monai/apps/deepgrow/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from typing import Callable, Dict, Optional, Sequence, Union

import numpy as np
Expand Down Expand Up @@ -144,7 +145,7 @@ def _apply(self, label, sid):
def __call__(self, data):
d = dict(data)
self.randomize(data)
d[self.guidance] = self._apply(d[self.label], self.sid)
d[self.guidance] = json.dumps(self._apply(d[self.label], self.sid).astype(int).tolist())
return d


Expand All @@ -159,7 +160,7 @@ class AddGuidanceSignald(Transform):
guidance: key to store guidance.
sigma: standard deviation for Gaussian kernel.
number_intensity_ch: channel index.
batched: whether input is batched or not.

"""

def __init__(
Expand All @@ -168,17 +169,16 @@ def __init__(
guidance: str = "guidance",
sigma: int = 2,
number_intensity_ch: int = 1,
batched: bool = False,
):
self.image = image
self.guidance = guidance
self.sigma = sigma
self.number_intensity_ch = number_intensity_ch
self.batched = batched

def _get_signal(self, image, guidance):
dimensions = 3 if len(image.shape) > 3 else 2
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
if dimensions == 3:
signal = np.zeros((len(guidance), image.shape[-3], image.shape[-2], image.shape[-1]), dtype=np.float32)
else:
Expand Down Expand Up @@ -210,16 +210,9 @@ def _get_signal(self, image, guidance):
return signal

def _apply(self, image, guidance):
if not self.batched:
signal = self._get_signal(image, guidance)
return np.concatenate([image, signal], axis=0)

images = []
for i, g in zip(image, guidance):
i = i[0 : 0 + self.number_intensity_ch, ...]
signal = self._get_signal(i, g)
images.append(np.concatenate([i, signal], axis=0))
return images
signal = self._get_signal(image, guidance)
image = image[0 : 0 + self.number_intensity_ch, ...]
return np.concatenate([image, signal], axis=0)

def __call__(self, data):
d = dict(data)
Expand All @@ -234,26 +227,17 @@ class FindDiscrepancyRegionsd(Transform):
"""
Find discrepancy between prediction and actual during click interactions during training.

If batched is true:

label is in shape (B, C, D, H, W) or (B, C, H, W)
pred has same shape as label
discrepancy will have shape (B, 2, C, D, H, W) or (B, 2, C, H, W)

Args:
label: key to label source.
pred: key to prediction source.
discrepancy: key to store discrepancies found between label and prediction.
batched: whether input is batched or not.

"""

def __init__(
self, label: str = "label", pred: str = "pred", discrepancy: str = "discrepancy", batched: bool = True
):
def __init__(self, label: str = "label", pred: str = "pred", discrepancy: str = "discrepancy"):
self.label = label
self.pred = pred
self.discrepancy = discrepancy
self.batched = batched

@staticmethod
def disparity(label, pred):
Expand All @@ -266,13 +250,7 @@ def disparity(label, pred):
return [pos_disparity, neg_disparity]

def _apply(self, label, pred):
if not self.batched:
return self.disparity(label, pred)

disparity = []
for la, pr in zip(label, pred):
disparity.append(self.disparity(la, pr))
return disparity
return self.disparity(label, pred)

def __call__(self, data):
d = dict(data)
Expand All @@ -286,53 +264,32 @@ def __call__(self, data):
class AddRandomGuidanced(Randomizable, Transform):
"""
Add random guidance based on discrepancies that were found between label and prediction.

If batched is True, input shape is as below:

Guidance is of shape (B, 2, N, # of dim) where B is batch size, 2 means positive and negative,
N means how many guidance points, # of dim is the total number of dimensions of the image
(for example if the image is CDHW, then # of dim would be 4).

Discrepancy is of shape (B, 2, C, D, H, W) or (B, 2, C, H, W)

Probability is of shape (B, 1)

else:

Guidance is of shape (2, N, # of dim)

Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W)

Probability is of shape (1)
input shape is as below:
Guidance is of shape (2, N, # of dim)
Discrepancy is of shape (2, C, D, H, W) or (2, C, H, W)
Probability is of shape (1)

Args:
guidance: key to guidance source.
discrepancy: key that represents discrepancies found between label and prediction.
probability: key that represents click/interaction probability.
batched: whether input is batched or not.

"""

def __init__(
self,
guidance: str = "guidance",
discrepancy: str = "discrepancy",
probability: str = "probability",
batched: bool = True,
):
self.guidance = guidance
self.discrepancy = discrepancy
self.probability = probability
self.batched = batched
self._will_interact = None

def randomize(self, data=None):
probability = data[self.probability]
if not self.batched:
self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])
else:
self._will_interact = []
for p in probability:
self._will_interact.append(self.R.choice([True, False], p=[p, 1.0 - p]))
self._will_interact = self.R.choice([True, False], p=[probability, 1.0 - probability])

def find_guidance(self, discrepancy):
distance = distance_transform_cdt(discrepancy).flatten()
Expand Down Expand Up @@ -368,24 +325,16 @@ def add_guidance(self, discrepancy, will_interact):

def _apply(self, guidance, discrepancy):
guidance = guidance.tolist() if isinstance(guidance, np.ndarray) else guidance
if not self.batched:
pos, neg = self.add_guidance(discrepancy, self._will_interact)
if pos:
guidance[0].append(pos)
guidance[1].append([-1] * len(pos))
if neg:
guidance[0].append([-1] * len(neg))
guidance[1].append(neg)
else:
for g, d, w in zip(guidance, discrepancy, self._will_interact):
pos, neg = self.add_guidance(d, w)
if pos:
g[0].append(pos)
g[1].append([-1] * len(pos))
if neg:
g[0].append([-1] * len(neg))
g[1].append(neg)
return np.asarray(guidance)
guidance = json.loads(guidance) if isinstance(guidance, str) else guidance
pos, neg = self.add_guidance(discrepancy, self._will_interact)
if pos:
guidance[0].append(pos)
guidance[1].append([-1] * len(pos))
if neg:
guidance[0].append([-1] * len(neg))
guidance[1].append(neg)

return json.dumps(np.asarray(guidance).astype(int).tolist())

def __call__(self, data):
d = dict(data)
Expand Down
39 changes: 30 additions & 9 deletions tests/test_deepgrow_interaction.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,39 +11,60 @@

import unittest

import numpy as np
import torch

from monai.apps.deepgrow.interaction import Interaction
from monai.apps.deepgrow.transforms import (
AddGuidanceSignald,
AddInitialSeedPointd,
AddRandomGuidanced,
FindAllValidSlicesd,
FindDiscrepancyRegionsd,
)
from monai.data import Dataset
from monai.engines import SupervisedTrainer
from monai.engines.utils import IterationEvents
from monai.transforms import Activationsd, Compose, ToNumpyd
from monai.transforms import Activationsd, Compose, ToNumpyd, ToTensord


def add_one(engine):
if engine.state.best_metric is -1:
if engine.state.best_metric == -1:
engine.state.best_metric = 0
else:
engine.state.best_metric = engine.state.best_metric + 1


class TestInteractions(unittest.TestCase):
def run_interaction(self, train, compose):
data = []
for i in range(5):
data.append({"image": torch.tensor([float(i)]), "label": torch.tensor([float(i)])})
network = torch.nn.Linear(1, 1)
data = [{"image": np.ones((1, 2, 2, 2)).astype(np.float32), "label": np.ones((1, 2, 2, 2))} for _ in range(5)]
network = torch.nn.Linear(2, 2)
lr = 1e-3
opt = torch.optim.SGD(network.parameters(), lr)
loss = torch.nn.L1Loss()
dataset = Dataset(data, transform=None)
train_transforms = Compose(
[
FindAllValidSlicesd(label="label", sids="sids"),
AddInitialSeedPointd(label="label", guidance="guidance", sids="sids"),
AddGuidanceSignald(image="image", guidance="guidance"),
ToTensord(keys=("image", "label")),
]
)
dataset = Dataset(data, transform=train_transforms)
data_loader = torch.utils.data.DataLoader(dataset, batch_size=5)

iteration_transforms = [Activationsd(keys="pred", sigmoid=True), ToNumpyd(keys="pred")]
iteration_transforms = [
Activationsd(keys="pred", sigmoid=True),
ToNumpyd(keys=["image", "label", "pred"]),
FindDiscrepancyRegionsd(label="label", pred="pred", discrepancy="discrepancy"),
AddRandomGuidanced(guidance="guidance", discrepancy="discrepancy", probability="probability"),
AddGuidanceSignald(image="image", guidance="guidance"),
ToTensord(keys=("image", "label")),
]
iteration_transforms = Compose(iteration_transforms) if compose else iteration_transforms

i = Interaction(transforms=iteration_transforms, train=train, max_interactions=5)
self.assertEqual(len(i.transforms.transforms), 2, "Mismatch in expected transforms")
self.assertEqual(len(i.transforms.transforms), 6, "Mismatch in expected transforms")

# set up engine
engine = SupervisedTrainer(
Expand Down
40 changes: 17 additions & 23 deletions tests/test_deepgrow_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,6 @@

IMAGE = np.array([[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]])
LABEL = np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]])
BATCH_IMAGE = np.array([[[[[1, 0, 2, 0, 1], [0, 1, 2, 1, 0], [2, 2, 3, 2, 2], [0, 1, 2, 1, 0], [1, 0, 2, 0, 1]]]]])
BATCH_LABEL = np.array([[[[[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]])

DATA_1 = {
"image": IMAGE,
Expand Down Expand Up @@ -61,24 +59,22 @@
}

DATA_3 = {
"image": BATCH_IMAGE,
"label": BATCH_LABEL,
"pred": np.array([[[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]]),
"image": IMAGE,
"label": LABEL,
"pred": np.array([[[[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 1, 1, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]]]]),
}

DATA_4 = {
"image": BATCH_IMAGE,
"label": BATCH_LABEL,
"guidance": np.array([[[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]]),
"image": IMAGE,
"label": LABEL,
"guidance": np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),
"discrepancy": np.array(
[
[
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
]
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
]
),
"probability": [1.0],
"probability": 1.0,
}

DATA_5 = {
Expand Down Expand Up @@ -192,11 +188,11 @@
ADD_INITIAL_POINT_TEST_CASE_1 = [
{"label": "label", "guidance": "guidance", "sids": "sids"},
DATA_1,
np.array([[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]),
"[[[1, 0, 2, 2]], [[-1, -1, -1, -1]]]",
]

ADD_GUIDANCE_TEST_CASE_1 = [
{"image": "image", "guidance": "guidance", "batched": False},
{"image": "image", "guidance": "guidance"},
DATA_2,
np.array(
[
Expand Down Expand Up @@ -233,18 +229,16 @@
DATA_3,
np.array(
[
[
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
]
[[[[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
[[[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]]],
]
),
]

ADD_RANDOM_GUIDANCE_TEST_CASE_1 = [
{"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability", "batched": True},
{"guidance": "guidance", "discrepancy": "discrepancy", "probability": "probability"},
DATA_4,
np.array([[[[1, 0, 2, 2], [1, 0, 1, 3]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]]),
"[[[1, 0, 2, 2], [1, 0, 1, 3]], [[-1, -1, -1, -1], [-1, -1, -1, -1]]]",
]

ADD_GUIDANCE_FROM_POINTS_TEST_CASE_1 = [
Expand Down Expand Up @@ -398,7 +392,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
add_fn = AddInitialSeedPointd(**arguments)
add_fn.set_random_state(seed)
result = add_fn(input_data)
np.testing.assert_allclose(result[arguments["guidance"]], expected_result)
self.assertEqual(result[arguments["guidance"]], expected_result)


class TestAddGuidanceSignald(unittest.TestCase):
Expand All @@ -422,7 +416,7 @@ def test_correct_results(self, arguments, input_data, expected_result):
add_fn = AddRandomGuidanced(**arguments)
add_fn.set_random_state(seed)
result = add_fn(input_data)
np.testing.assert_allclose(result[arguments["guidance"]], expected_result, rtol=1e-5)
self.assertEqual(result[arguments["guidance"]], expected_result)


class TestAddGuidanceFromPointsd(unittest.TestCase):
Expand Down