In [None]:
# | default_exp layers/upsample

# Imports

In [None]:
# | export


import torch
from einops import rearrange
from torch import nn

from vision_architectures.blocks.cnn import CNNBlock3D, CNNBlockConfig
from vision_architectures.utils.activation_checkpointing import ActivationCheckpointing
from vision_architectures.utils.custom_base_model import Field
from vision_architectures.utils.rearrange import rearrange_channels

# Config

In [None]:
# | export


class PixelShuffleUpsampleConfig(CNNBlockConfig):
    """
    Configuration class for PixelShuffleUpsample.

    Extends ``CNNBlockConfig`` with a single extra field, ``scale_factor``. ``PixelShuffleUpsample3D`` applies the
    same factor to each of the three spatial axes: it builds its conv block with
    ``out_channels * scale_factor**3`` output channels and then folds the extra channels into z, y and x.
    """

    # Per-axis upsampling multiplier; defaults to 2
    scale_factor: int = Field(2, description="Scale factor for upsampling.")

# Upsample layers

In [None]:
# | export


class PixelShuffleUpsample3D(nn.Module):
    """
    Pixel-shuffle style upsampling for 3D feature volumes.

    A ``CNNBlock3D`` first widens the channel axis to ``out_channels * scale_factor**3``. The extra channels are then
    interleaved into the z, y and x axes, so each spatial dimension grows by ``scale_factor`` while the channel count
    drops back to ``out_channels``.

    Args:
        config: Config fields for the layer. Merged with ``kwargs`` via ``config | kwargs`` and validated as a
            ``PixelShuffleUpsampleConfig``.
        checkpointing_level: Forwarded to the inner ``CNNBlock3D`` and to this layer's ``ActivationCheckpointing``
            wrapper (constructed for level 1).
        **kwargs: Additional config fields, merged on top of ``config`` before validation.
    """

    def __init__(self, config: PixelShuffleUpsampleConfig = {}, checkpointing_level: int = 0, **kwargs):
        super().__init__()

        merged_config = config | kwargs
        self.config = PixelShuffleUpsampleConfig.model_validate(merged_config)

        # The conv block must emit one group of `out_channels` per (s1, s2, s3) sub-voxel offset.
        # NOTE(review): if `model_validate` hands back the very same instance it was given, the assignment below
        # also changes `self.config.out_channels` -- confirm against the base model's validation settings.
        upscale = self.config.scale_factor
        expand_config = CNNBlockConfig.model_validate(self.config)
        widened_channels = expand_config.out_channels * (upscale**3)
        expand_config.out_channels = widened_channels
        self.expand = CNNBlock3D(expand_config, checkpointing_level)

        self.checkpointing_level1 = ActivationCheckpointing(1, checkpointing_level)

    def _forward(self, x: torch.Tensor, channels_first: bool = True) -> torch.Tensor:
        """
        Upsample ``x`` by ``scale_factor`` along z, y and x.

        Args:
            x: Input volume, (b, [in_channels], z, y, x, [in_channels]).
            channels_first: Whether the channel axis of ``x`` sits right after the batch axis. The output uses the
                same layout as the input.

        Returns:
            Tensor of shape (b, [out_channels], z * scale_factor, y * scale_factor, x * scale_factor, [out_channels]).
        """
        factor = self.config.scale_factor

        # A single rebinding name is used so intermediate tensors are released as soon as they are superseded
        feats = rearrange_channels(x, channels_first, True)
        # (b, in_channels, z, y, x)

        feats = self.expand(feats)
        # (b, out_channels * scale_factor**3, z, y, x)

        feats = rearrange(
            feats,
            "b (c s1 s2 s3) z y x -> b c (z s1) (y s2) (x s3)",
            s1=factor,
            s2=factor,
            s3=factor,
        )
        feats = feats.contiguous()
        # (b, out_channels, z * scale_factor, y * scale_factor, x * scale_factor)

        # Restore the caller's channel layout
        # (b, [out_channels], z * scale_factor, y * scale_factor, x * scale_factor, [out_channels])
        return rearrange_channels(feats, True, channels_first)

    def forward(self, *args, **kwargs):
        """Run ``_forward`` through the level-1 activation checkpointing wrapper; arguments are passed unchanged."""
        return self.checkpointing_level1(self._forward, *args, **kwargs)

In [None]:
# Smoke test: a bare 1x1x1 conv expansion (no normalization, no activation) with the default scale factor
test = PixelShuffleUpsample3D(
    in_channels=1024,
    out_channels=512,
    kernel_size=1,
    activation=None,
    normalization=None,
    padding=0,
)
display(test)

# (b, in_channels, z, y, x)
sample_input = torch.randn(2, 1024, 4, 4, 4)

# Cell output: (total parameter count, output shape)
sum(param.numel() for param in test.parameters()), test(sample_input).shape


[1;35mPixelShuffleUpsample3D[0m[1m([0m
  [1m([0mexpand[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m1024[0m, [1;36m4096[0m, [33mkernel_size[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
[1m)[0m

[1m([0m[1;36m4198400[0m, [1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m512[0m, [1;36m8[0m, [1;36m8[0m, [1;36m8[0m[1m][0m[1m)[0m[1m)[0m

# nbdev

In [None]:
# Runs the `nbdev_export` command in a shell via IPython's `!` escape. Presumably this writes the cells marked
# `# | export` to the module named by the `default_exp` directive at the top of the notebook -- confirm in nbdev docs.
!nbdev_export