From 6f75ac43dbffad6efeebd2f397a27720f12e7c3b Mon Sep 17 00:00:00 2001 From: Wenqi Li Date: Mon, 5 Jul 2021 15:53:26 +0100 Subject: [PATCH] enhance one-hot documentation Signed-off-by: Wenqi Li --- monai/networks/utils.py | 32 ++++++++++++++++++++++++++++---- 1 file changed, 28 insertions(+), 4 deletions(-) diff --git a/monai/networks/utils.py b/monai/networks/utils.py index bb02f78de9..d85175ef7e 100644 --- a/monai/networks/utils.py +++ b/monai/networks/utils.py @@ -37,13 +37,37 @@ def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor: """ - For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]` - for `num_classes` N number of classes. + For every value v in `labels`, the value in the output will be either 1 or 0. Each vector along the `dim`-th + dimension has the "one-hot" format, i.e., it has a total length of `num_classes`, + with a one and `num_class-1` zeros. + Note that this will include the background label, thus a binary mask should be treated as having two classes. + + Args: + labels: input tensor of integers to be converted into the 'one-hot' format. Internally `labels` will be + converted into integers `labels.long()`. + num_classes: number of output channels, the corresponding length of `labels[dim]` will be converted to + `num_classes` from `1`. + dtype: the data type of the output one_hot label. + dim: the dimension to be converted to `num_classes` channels from `1` channel. Example: - For every value v = labels[b,1,h,w], the value in the result at [b,v,h,w] will be 1 and all others 0. - Note that this will include the background label, thus a binary mask should be treated as having 2 classes. + For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]` + when `num_classes=N` number of classes and `dim=1`. + + .. code-block:: python + + from monai.networks.utils import one_hot + import torch + + a = torch.randint(0, 2, size=(1, 2, 2, 2)) + out = one_hot(a, num_classes=2, dim=0) + print(out.shape) # torch.Size([2, 2, 2, 2]) + + a = torch.randint(0, 2, size=(2, 1, 2, 2, 2)) + out = one_hot(a, num_classes=2, dim=1) + print(out.shape) # torch.Size([2, 2, 2, 2, 2]) + """ if labels.dim() == 0: # if no channel dim, add it