In [None]:
# | default_exp utils/rearrange

# Imports

In [None]:
# | export


import numpy as np
import torch
from einops import rearrange

# Rearranges

In [None]:
# | export


def make_channels_first(x: torch.Tensor | np.ndarray):
    """Move the trailing (channel) axis so it sits right after the first axis.

    The layout change is ``b ... d -> b d ...``: the first axis stays in
    place, the last axis becomes the second axis, and the axes in between
    keep their relative order.

    Args:
        x: The input tensor / array in channels last layout. Expected to have
            at least 3 dimensions (this is not checked here).

    Returns:
        The input rearranged to channels first layout. Torch tensors are
        additionally made contiguous; any other input is returned exactly as
        ``rearrange`` produced it.
    """
    channels_first = rearrange(x, "b ... d -> b d ...")

    # Only torch tensors get the extra contiguity step.
    if not torch.is_tensor(channels_first):
        return channels_first
    return channels_first.contiguous()

In [None]:
# Demo: run one torch tensor and one numpy array, both channels last, through
# make_channels_first and print the shape before and after.
channels_last_shape = (1, 4, 4, 4, 16)
tests = [
    torch.randn(*channels_last_shape),
    np.random.randn(*channels_last_shape),
]

for sample in tests:
    converted = make_channels_first(sample)
    print(tuple(sample.shape), tuple(converted.shape))

(1, 4, 4, 4, 16) (1, 16, 4, 4, 4)
(1, 4, 4, 4, 16) (1, 16, 4, 4, 4)


In [None]:
# | export


def make_channels_last(x: torch.Tensor | np.ndarray):
    """Move the second (channel) axis to the very end.

    The layout change is ``b d ... -> b ... d``: the first axis stays in
    place, the second axis becomes the last axis, and the remaining axes keep
    their relative order.

    Args:
        x: The input tensor / array in channels first layout. Expected to
            have at least 3 dimensions (this is not checked here).

    Returns:
        The input rearranged to channels last layout. Torch tensors are
        additionally made contiguous; any other input is returned exactly as
        ``rearrange`` produced it.
    """
    channels_last = rearrange(x, "b d ... -> b ... d")

    # Only torch tensors get the extra contiguity step.
    if not torch.is_tensor(channels_last):
        return channels_last
    return channels_last.contiguous()

In [None]:
# Demo: run one torch tensor and one numpy array, both channels first, through
# make_channels_last and print the shape before and after.
channels_first_shape = (1, 16, 4, 4, 4)
tests = [
    torch.randn(*channels_first_shape),
    np.random.randn(*channels_first_shape),
]

for sample in tests:
    converted = make_channels_last(sample)
    print(tuple(sample.shape), tuple(converted.shape))

(1, 16, 4, 4, 4) (1, 4, 4, 4, 16)
(1, 16, 4, 4, 4) (1, 4, 4, 4, 16)


# nbdev

In [None]:
# IPython shell escape: runs the `nbdev_export` command line tool. The cells
# tagged `# | export` are presumably what it writes to the `utils/rearrange`
# module named by the `default_exp` directive at the top of this notebook.
!nbdev_export