Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 15 additions & 3 deletions monai/handlers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,12 +225,16 @@ def _compute_op(op: str, d: np.ndarray):
f.write(f"{class_labels[i]}{deli}{deli.join([f'{_compute_op(k, c):.4f}' for k in ops])}\n")


def from_engine(keys: KeysCollection):
def from_engine(keys: KeysCollection, first: bool = False):
"""
Utility function to simplify the `batch_transform` or `output_transform` args of ignite components
when handling dictionary or list of dictionaries(for example: `engine.state.batch` or `engine.state.output`).
Users only need to set the expected keys, then it will return a callable function to extract data from
dictionary and construct a tuple respectively.

If data is a list of dictionaries after decollating, extract expected keys and construct lists respectively,
for example, if data is `[{"A": 1, "B": 2}, {"A": 3, "B": 4}]`, from_engine(["A", "B"]): `([1, 3], [2, 4])`.

It can help avoid a complicated `lambda` function and make the arg of metrics more straight-forward.
For example, set the first key as the prediction and the second key as label to get the expected data
from `engine.state.output` for a metric::
Expand All @@ -242,15 +246,23 @@ def from_engine(keys: KeysCollection):
output_transform=from_engine(["pred", "label"])
)

Args:
keys: specified keys to extract data from dictionary or decollated list of dictionaries.
first: whether only extract sepcified keys from the first item if input data is a list of dictionaries,
it's used to extract the scalar data which doesn't have batch dim and was replicated into every
dictionary when decollating, like `loss`, etc.


"""
keys = ensure_tuple(keys)

def _wrapper(data):
if isinstance(data, dict):
return tuple(data[k] for k in keys)
elif isinstance(data, list) and isinstance(data[0], dict):
# if data is a list of dictionaries, extract expected keys and construct lists
ret = [[i[k] for i in data] for k in keys]
# if data is a list of dictionaries, extract expected keys and construct lists,
# if `first=True`, only extract keys from the first item of the list
ret = [data[0][k] if first else [i[k] for i in data] for k in keys]
return tuple(ret) if len(ret) > 1 else ret[0]

return _wrapper
4 changes: 2 additions & 2 deletions tests/test_integration_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,9 +182,9 @@ def _model_completed(self, engine):
train_handlers = [
LrScheduleHandler(lr_scheduler=lr_scheduler, print_lr=True),
ValidationHandler(validator=evaluator, interval=2, epoch_level=True),
StatsHandler(tag_name="train_loss", output_transform=lambda x: x[0]["loss"]),
StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
TensorBoardStatsHandler(
summary_writer=summary_writer, tag_name="train_loss", output_transform=lambda x: x[0]["loss"]
summary_writer=summary_writer, tag_name="train_loss", output_transform=from_engine("loss", first=True)
),
CheckpointSaver(save_dir=root_dir, save_dict={"net": net, "opt": opt}, save_interval=2, epoch_level=True),
_TestTrainIterEvents(),
Expand Down