In [None]:
# | default_exp utils/rearrange

# Imports

In [None]:
# | export


import numpy as np
import torch
from einops import rearrange

# Rearranges

In [None]:
# | export


def make_channels_first(x: torch.Tensor | np.ndarray) -> torch.Tensor | np.ndarray:
    """Convert an n-dimensional tensor or array to channels first format.

    The last axis (channels) is moved to position 1, directly after the leading
    axis (``b`` in the einops pattern). All axes in between keep their relative
    order, e.g. ``(b, h, w, d) -> (b, d, h, w)``.

    Args:
        x: The input tensor / array in channels last format. Needs at least
            2 dimensions; with exactly 2 dimensions the result has the same shape.

    Returns:
        The input tensor / array in channels first format. Both torch tensors and
        numpy arrays are returned with a contiguous memory layout.
    """
    x = rearrange(x, "b ... d -> b d ...")
    # An axis move like this is a strided (non-contiguous) view. Compact it for
    # both backends so downstream code sees the same layout whatever the input type.
    if torch.is_tensor(x):
        return x.contiguous()
    if isinstance(x, np.ndarray):
        return np.ascontiguousarray(x)
    return x

In [None]:
tests = [
    torch.randn(1, 4, 4, 4, 16),
    np.random.randn(1, 4, 4, 4, 16),
]

for input_ in tests:
    output_ = make_channels_first(input_)
    print(tuple(input_.shape), tuple(output_.shape))
    # Assert rather than eyeball the printout so regressions fail the notebook test run.
    assert type(output_) is type(input_)
    assert tuple(output_.shape) == (1, 16, 4, 4, 4)
    # The last axis must land at position 1 with the data unchanged.
    assert np.array_equal(np.asarray(output_), np.moveaxis(np.asarray(input_), -1, 1))

(1, 4, 4, 4, 16) (1, 16, 4, 4, 4)
(1, 4, 4, 4, 16) (1, 16, 4, 4, 4)


In [None]:
# | export


def make_channels_last(x: torch.Tensor | np.ndarray) -> torch.Tensor | np.ndarray:
    """Convert an n-dimensional tensor or array to channels last format.

    The axis at position 1 (channels) is moved to the end. The leading axis
    (``b`` in the einops pattern) stays first and the remaining axes keep their
    relative order, e.g. ``(b, d, h, w) -> (b, h, w, d)``.

    Args:
        x: The input tensor / array in channels first format. Needs at least
            2 dimensions; with exactly 2 dimensions the result has the same shape.

    Returns:
        The input tensor / array in channels last format. Both torch tensors and
        numpy arrays are returned with a contiguous memory layout.
    """
    x = rearrange(x, "b d ... -> b ... d")
    # An axis move like this is a strided (non-contiguous) view. Compact it for
    # both backends so downstream code sees the same layout whatever the input type.
    if torch.is_tensor(x):
        return x.contiguous()
    if isinstance(x, np.ndarray):
        return np.ascontiguousarray(x)
    return x

In [None]:
tests = [
    torch.randn(1, 16, 4, 4, 4),
    np.random.randn(1, 16, 4, 4, 4),
]

for input_ in tests:
    output_ = make_channels_last(input_)
    print(tuple(input_.shape), tuple(output_.shape))
    # Assert rather than eyeball the printout so regressions fail the notebook test run.
    assert type(output_) is type(input_)
    assert tuple(output_.shape) == (1, 4, 4, 4, 16)
    # Axis 1 must land at the end with the data unchanged.
    assert np.array_equal(np.asarray(output_), np.moveaxis(np.asarray(input_), 1, -1))

(1, 16, 4, 4, 4) (1, 4, 4, 4, 16)
(1, 16, 4, 4, 4) (1, 4, 4, 4, 16)


In [None]:
# | export


def rearrange_channels(x: torch.Tensor | np.ndarray, cur_channels_first: bool, new_channels_first: bool):
    """Rearrange the channels of a tensor / array to either channels_first or channels_last format.

    Args:
        x: The input tensor / array.
        cur_channels_first: Whether the input tensor / array is in channels first format.
        new_channels_first: Whether the output should be in channels first format.

    Returns:
        The input tensor / array with the channels rearranged.
    """

    if cur_channels_first is new_channels_first:
        return x
    elif cur_channels_first:
        return make_channels_last(x)
    else:
        return make_channels_first(x)

In [None]:
tests = [
    (torch.randn(1, 16, 4, 4, 4), True),
    (torch.randn(1, 4, 4, 4, 16), False),
    (np.random.randn(1, 16, 4, 4, 4), True),
    (np.random.randn(1, 4, 4, 4, 16), False),
]

for new_channels_first in [True, False]:
    expected_shape = (1, 16, 4, 4, 4) if new_channels_first else (1, 4, 4, 4, 16)
    for input_, cur_channels_first in tests:
        output_ = rearrange_channels(input_, cur_channels_first, new_channels_first)
        print(tuple(input_.shape), cur_channels_first, new_channels_first, tuple(output_.shape))
        # Assert rather than eyeball the printout: the layout must match the request,
        # and a no-op conversion must hand back the very same object.
        assert tuple(output_.shape) == expected_shape
        if cur_channels_first == new_channels_first:
            assert output_ is input_

(1, 16, 4, 4, 4) True True (1, 16, 4, 4, 4)
(1, 4, 4, 4, 16) False True (1, 16, 4, 4, 4)
(1, 16, 4, 4, 4) True True (1, 16, 4, 4, 4)
(1, 4, 4, 4, 16) False True (1, 16, 4, 4, 4)
(1, 16, 4, 4, 4) True False (1, 4, 4, 4, 16)
(1, 4, 4, 4, 16) False False (1, 4, 4, 4, 16)
(1, 16, 4, 4, 4) True False (1, 4, 4, 4, 16)
(1, 4, 4, 4, 16) False False (1, 4, 4, 4, 16)


# nbdev

In [None]:
!nbdev_export