32 changes: 28 additions & 4 deletions monai/networks/utils.py
@@ -37,13 +37,37 @@

def one_hot(labels: torch.Tensor, num_classes: int, dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
"""
For every value v in `labels`, the values in the output will be either 1 or 0: each vector along the `dim`-th
dimension has the "one-hot" format, i.e. it has a total length of `num_classes`,
with a single one at index v and `num_classes - 1` zeros.
Note that this will include the background label, thus a binary mask should be treated as having two classes.

Args:
labels: input tensor of integers to be converted into the 'one-hot' format. Internally `labels` will be
converted to integers via `labels.long()`.
num_classes: number of output channels; the size of `labels` along `dim` will be expanded from `1` to
`num_classes`.
dtype: the data type of the output one-hot tensor.
dim: the dimension to be converted from `1` channel to `num_classes` channels.

Example:

For a tensor `labels` of dimensions [B]1[spatial_dims], return a tensor of dimensions `[B]N[spatial_dims]`
when `num_classes=N` and `dim=1`.
For every value v = labels[b, 0, h, w], the value in the result at [b, v, h, w] will be 1 and all others 0.

.. code-block:: python

from monai.networks.utils import one_hot
import torch

# a 4-D tensor whose dim 0 is the singleton channel to be expanded
a = torch.randint(0, 2, size=(1, 2, 2, 2))
out = one_hot(a, num_classes=2, dim=0)
print(out.shape)  # torch.Size([2, 2, 2, 2])

# a batched 5-D tensor with a singleton channel at dim 1
a = torch.randint(0, 2, size=(2, 1, 2, 2, 2))
out = one_hot(a, num_classes=2, dim=1)
print(out.shape)  # torch.Size([2, 2, 2, 2, 2])
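
# illustrative addition (not part of the original docstring): a binary mask
# still produces two channels, background at index 0 and foreground at index 1
mask = torch.tensor([[[0, 1], [1, 0]]])  # shape (1, 2, 2), singleton channel at dim 0
oh = one_hot(mask, num_classes=2, dim=0)
print(oh.shape)  # torch.Size([2, 2, 2])
print(oh[0])  # background channel: 1.0 wherever mask == 0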

"""
if labels.dim() == 0:
# if no channel dim, add it
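The rest of the function body is truncated in this hunk. As a reading aid, here is a minimal sketch of a scatter-based conversion consistent with the documented behaviour; it assumes `labels` has size 1 along `dim` and is an illustration only, not MONAI's actual implementation, and `one_hot_sketch` is a hypothetical name.

.. code-block:: python

    import torch

    def one_hot_sketch(labels: torch.Tensor, num_classes: int,
                       dtype: torch.dtype = torch.float, dim: int = 1) -> torch.Tensor:
        # hypothetical simplified version, for illustration only
        shape = list(labels.shape)
        if shape[dim] != 1:
            raise ValueError("labels must have size 1 along `dim`")
        shape[dim] = num_classes  # expand the channel dimension to N classes
        out = torch.zeros(shape, dtype=dtype, device=labels.device)
        # write a 1 at the position given by each integer label value along `dim`
        return out.scatter_(dim, labels.long(), 1)

    # usage: a binary mask of shape (1, 2, 2) becomes a (2, 2, 2) one-hot tensor
    mask = torch.randint(0, 2, size=(1, 2, 2))
    print(one_hot_sketch(mask, num_classes=2, dim=0).shape)  # torch.Size([2, 2, 2])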