In [None]:
# | default_exp nets/unetr_3d_decoder

# Imports

In [None]:
# | export

import torch
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

from vision_architectures.blocks.cnn import CNNBlock3D
from vision_architectures.docstrings import populate_docstring
from vision_architectures.utils.custom_base_model import CustomBaseModel, Field, model_validator

# Config

In [None]:
# | export

# Kernel sizes are accepted either as a single int or as a 3-tuple of ints.
KernelSizeType = int | tuple[int, int, int]


class UNetR3DDecoderConfig(CustomBaseModel):
    """Settings of the decoder: output channel count and the kernel sizes of its convolutions."""

    num_outputs: int = Field(..., description="The number of output channels")
    conv_kernel_size: KernelSizeType = Field(..., description="The kernel size of the convolution layers")
    final_layer_kernel_size: KernelSizeType = Field(..., description="The kernel size of the final layer")


class UNetR3DStageConfig(CustomBaseModel):
    """Channel counts and patch sizes on the input and output side of a single stage."""

    in_dim: int = Field(..., description="The number of input channels")
    out_dim: int = Field(..., description="The number of output channels")
    in_patch_size: tuple[int, int, int] = Field(..., description="The patch size of the input")
    out_patch_size: tuple[int, int, int] = Field(..., description="The patch size of the output")


class UNetR3DConfig(CustomBaseModel):
    """Top-level configuration: input channel count, decoder settings and the ordered list of stages.

    Consecutive stages must chain: the ``in_dim`` / ``in_patch_size`` of every stage has to equal the
    ``out_dim`` / ``out_patch_size`` of the stage before it.
    """

    in_channels: int = Field(..., description="The number of input channels")

    decoder: UNetR3DDecoderConfig = Field(..., description="The decoder configuration")
    stages: list[UNetR3DStageConfig] = Field(..., description="The stage configurations")

    @model_validator(mode="after")
    def validate(self):
        """Check that the input of every stage matches the output of the previous stage.

        Returns:
            The config instance itself.

        Raises:
            ValueError: If the dims or the patch sizes of two consecutive stages do not line up.
        """
        # Explicit raises instead of ``assert`` so the checks are not stripped when Python runs with ``-O``.
        for index, (previous, current) in enumerate(zip(self.stages, self.stages[1:]), start=1):
            if current.in_dim != previous.out_dim:
                raise ValueError(
                    "in_dim of each stage should match the out_dim of the previous stage "
                    f"(stage {index}: in_dim={current.in_dim}, previous out_dim={previous.out_dim})"
                )
            if current.in_patch_size != previous.out_patch_size:
                raise ValueError(
                    "in_patch_size of each stage should match the out_patch_size of the previous stage "
                    f"(stage {index}: in_patch_size={current.in_patch_size}, "
                    f"previous out_patch_size={previous.out_patch_size})"
                )
        return self

In [None]:
# Example config with four stages. Each stage's in_dim / in_patch_size repeats the previous stage's
# out_dim / out_patch_size, which is what the UNetR3DConfig validator requires.
test_config = UNetR3DConfig.model_validate(
    {
        "in_channels": 1,
        "decoder": {
            "conv_kernel_size": (3, 3, 3),
            "final_layer_kernel_size": (5, 5, 5),
            "num_outputs": 5,
        },
        "stages": [
            {
                "in_dim": 12,
                "in_patch_size": (1, 4, 4),
                "out_dim": 12,
                "out_patch_size": (1, 4, 4),
            },
            {
                "in_dim": 12,
                "in_patch_size": (1, 4, 4),
                "out_dim": 48,
                "out_patch_size": (2, 8, 8),
            },
            {
                "in_dim": 48,
                "in_patch_size": (2, 8, 8),
                "out_dim": 192,
                "out_patch_size": (4, 16, 16),
            },
            {
                "in_dim": 192,
                "in_patch_size": (4, 16, 16),
                "out_dim": 768,
                "out_patch_size": (8, 32, 32),
            },
        ],
    }
)

# Show the validated config as the cell output.
test_config


[1;35mUNetR3DConfig[0m[1m([0m
    [33min_channels[0m=[1;36m1[0m,
    [33mdecoder[0m=[1;35mUNetR3DDecoderConfig[0m[1m([0m
        [33mnum_outputs[0m=[1;36m5[0m,
        [33mconv_kernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m,
        [33mfinal_layer_kernel_size[0m=[1m([0m[1;36m5[0m, [1;36m5[0m, [1;36m5[0m[1m)[0m
    [1m)[0m,
    [33mstages[0m=[1m[[0m
        [1;35mUNetR3DStageConfig[0m[1m([0m[33min_dim[0m=[1;36m12[0m, [33mout_dim[0m=[1;36m12[0m, [33min_patch_size[0m=[1m([0m[1;36m1[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m, [33mout_patch_size[0m=[1m([0m[1;36m1[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m[1m)[0m,
        [1;35mUNetR3DStageConfig[0m[1m([0m[33min_dim[0m=[1;36m12[0m, [33mout_dim[0m=[1;36m48[0m, [33min_patch_size[0m=[1m([0m[1;36m1[0m, [1;36m4[0m, [1;36m4[0m[1m)[0m, [33mout_patch_size[0m=[1m([0m[1;36m2[0m, [1;36m8[0m, [1;36m8[0m[1m)[0m[1m)[0m,
        [1;

# Architecture

### Basic block

In [None]:
# | export


@populate_docstring
class _UNetR3DBlock(nn.Module):
    """The conv and deconv layers that make up a UNetR3D block. {CLASS_DESCRIPTION_3D_DOC}"""

    @populate_docstring
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        conv_kernel_size,
        deconv_kernel_size,
        is_first_layer: bool,
        checkpointing_level: int = 0,
    ):
        """Initialize the UNetR3DBlock.

        Args:
            in_dim: Input dimension of the block.
            out_dim: Output dimension of the block.
            conv_kernel_size: Kernel size for the convolution layer. Either a single int, which is used
                for all three axes, or a sequence with one int per axis.
            deconv_kernel_size: Kernel size for the deconvolution layer. It is also used as its stride.
            is_first_layer: Whether this is the first layer of the UNetR3D. When False, the deconvolution
                is built for twice in_dim input channels, because forward concatenates the previous
                layer's output onto the convolution output.
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
        """
        super().__init__()

        # The config type allows a bare int kernel size; expand it so the per-axis padding below can be computed.
        if isinstance(conv_kernel_size, int):
            conv_kernel_size = (conv_kernel_size,) * 3
        # Half the kernel size per axis keeps the spatial size unchanged at stride 1 (for odd kernel sizes).
        conv_padding = tuple(k // 2 for k in conv_kernel_size)

        self.conv = CNNBlock3D(
            in_channels=in_dim,
            out_channels=in_dim,
            kernel_size=conv_kernel_size,
            stride=1,
            padding=conv_padding,
            normalization="batchnorm3d",
            activation="relu",
            sequence="CNA",
            checkpointing_level=checkpointing_level,
        )

        # Every block except the first receives a skip input that is concatenated along the channel axis.
        deconv_in_dim = in_dim if is_first_layer else in_dim * 2
        self.deconv = CNNBlock3D(
            in_channels=deconv_in_dim,
            out_channels=out_dim,
            transposed=True,
            kernel_size=deconv_kernel_size,
            stride=deconv_kernel_size,
            padding=0,
            normalization="batchnorm3d",
            activation="relu",
            sequence="CNA",
            checkpointing_level=checkpointing_level,
        )

    @populate_docstring
    def forward(self, current_layer_output, previous_layer_output=None) -> torch.Tensor:
        """Process a single scale of the UNet.

        Args:
            current_layer_output: Current layer output.
            previous_layer_output: Previous layer output. When given, it is concatenated to the
                convolved current layer output along dim 1 before the deconvolution.

        Returns:
            {OUTPUT_3D_DOC}
        """
        x = self.conv(current_layer_output)
        if previous_layer_output is not None:
            x = torch.cat([x, previous_layer_output], dim=1)
        x = self.deconv(x)
        return x

In [None]:
# First (deepest) block: there is no previous layer output, so only the current embedding is passed.
test = _UNetR3DBlock(768, 192, (3, 3, 3), (2, 2, 2), is_first_layer=True)
display(test)
output_shape = test(torch.randn(2, 768, 4, 8, 8)).shape
# Pin the expected shape so a regression fails the cell instead of relying on a visual check.
assert output_shape == (2, 192, 8, 16, 16), output_shape
display(output_shape)


[1;35m_UNetR3DBlock[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m768[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m768[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mdeconv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mCo

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m192[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

In [None]:
# Non-first block: the previous layer output is concatenated to the convolved embedding before the deconv.
test = _UNetR3DBlock(192, 48, (3, 3, 3), (2, 2, 2), is_first_layer=False)
display(test)
output_shape = test(torch.randn(2, 192, 8, 16, 16), torch.randn(2, 192, 8, 16, 16)).shape
# Pin the expected shape so a regression fails the cell instead of relying on a visual check.
assert output_shape == (2, 48, 16, 32, 32), output_shape
display(output_shape)


[1;35m_UNetR3DBlock[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m192[0m, [1;36m192[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m192[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
    [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m[33menabled[0m=[3;91mFalse[0m[1m)[0m
  [1m)[0m
  [1m([0mdeconv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mCo

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m48[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m

### Complete decoder

In [None]:
# | export


@populate_docstring
class UNetR3DDecoder(nn.Module, PyTorchModelHubMixin):
    """UNetR3DDecoder made using multiple conv and deconv blocks."""

    @populate_docstring
    def __init__(self, config: UNetR3DConfig = {}, checkpointing_level: int = 0, **kwargs):
        """Initialize the UNetR3DDecoder.

        Args:
            config: {CONFIG_INSTANCE_DOC}
            checkpointing_level: {CHECKPOINTING_LEVEL_DOC}
            **kwargs: {CONFIG_KWARGS_DOC}
        """
        super().__init__()

        self.config = UNetR3DConfig.model_validate(config | kwargs)

        # Read everything from the validated ``self.config``. The raw ``config`` argument may be a plain dict
        # (including the empty default combined with kwargs), which has no ``stages`` / ``in_channels`` attributes.
        stages = self.config.stages
        in_channels = self.config.in_channels
        decoder_config = self.config.decoder

        # The config allows int kernel sizes; expand them once so per-axis paddings can be derived.
        conv_kernel_size = self._as_triple(decoder_config.conv_kernel_size)
        final_kernel_size = self._as_triple(decoder_config.final_layer_kernel_size)
        final_padding = tuple(k // 2 for k in final_kernel_size)

        # Blocks are created from the last configured stage back to the first one.
        self.blocks = nn.ModuleList()
        last_block_index = len(stages) - 1
        for block_index, stage in enumerate(reversed(stages)):
            if block_index == last_block_index:
                # The last block maps back to ``in_channels`` and upsamples by the full patch size of the first stage.
                out_dim = in_channels
                deconv_kernel_size = stage.out_patch_size
            else:
                # Other blocks upsample by the ratio between their stage's output and input patch sizes.
                out_dim = stage.in_dim
                deconv_kernel_size = tuple(
                    out_size // in_size for out_size, in_size in zip(stage.out_patch_size, stage.in_patch_size)
                )

            self.blocks.append(
                _UNetR3DBlock(
                    in_dim=stage.out_dim,
                    out_dim=out_dim,
                    conv_kernel_size=conv_kernel_size,
                    deconv_kernel_size=deconv_kernel_size,
                    is_first_layer=block_index == 0,
                    checkpointing_level=checkpointing_level,
                )
            )

        # Convolution applied to the raw scan before it is concatenated with the decoded features.
        self.scan_conv = nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size=final_kernel_size,
            padding=final_padding,
        )
        # Takes the concatenation of the scan features and the decoded features, hence ``in_channels * 2``.
        self.final_conv = nn.Conv3d(
            in_channels * 2,
            decoder_config.num_outputs,
            kernel_size=final_kernel_size,
            padding=final_padding,
        )

    @staticmethod
    def _as_triple(kernel_size):
        """Return ``kernel_size`` as a 3-tuple, repeating a single int for all three axes."""
        if isinstance(kernel_size, int):
            return (kernel_size,) * 3
        return tuple(kernel_size)

    @populate_docstring
    def forward(self, embeddings, scan) -> torch.Tensor:
        """Process the multi-scale input embeddings and scan datapoint.

        Args:
            embeddings: {INPUT_3D_DOC}
            scan: {INPUT_3D_DOC}

        Returns:
            {OUTPUT_3D_DOC}

        Raises:
            ValueError: If the number of embeddings differs from the number of decoder blocks.
        """
        # embeddings is a list of (B, C_layer, D_layer, W_layer, H_layer)
        if len(embeddings) != len(self.blocks):
            raise ValueError(f"Expected {len(self.blocks)} embeddings (one per stage), got {len(embeddings)}")

        # The blocks were built starting from the last stage, so walk the embeddings in reverse order.
        embeddings = embeddings[::-1]

        decoded = None
        for block_index, embedding in enumerate(embeddings):
            block = self.blocks[block_index]
            if block_index == 0:
                decoded = block(embedding)
            else:
                # The previous block's output is passed along as the skip input.
                decoded = block(embedding, decoded)

        high_resolution_embeddings = self.scan_conv(scan)
        final_embeddings = torch.cat([high_resolution_embeddings, decoded], dim=1)
        decoded = self.final_conv(final_embeddings)

        return decoded

    @staticmethod
    def _reduce(loss, reduction):
        """Apply ``reduction`` to ``loss``.

        Args:
            loss: Tensor of loss values.
            reduction: ``None`` to return the loss unchanged, ``"mean"`` or ``"sum"``.

        Returns:
            The reduced (or untouched) loss.

        Raises:
            NotImplementedError: For any other reduction value.
        """
        if reduction is None:
            return loss
        elif reduction == "mean":
            return loss.mean()
        elif reduction == "sum":
            return loss.sum()
        else:
            raise NotImplementedError("Please implement the reduction type")

    @staticmethod
    def soft_dice_loss_fn(
        prediction: torch.Tensor,
        target: torch.Tensor,
        reduction="mean",
        ignore_index: int = -100,
        smooth: float = 1e-5,
    ):
        """Soft dice loss, averaged over classes, with one value per batch element before reduction.

        Both prediction and target should be 5D tensors of the form (batch_size, num_classes, *spatial dims);
        the three spatial dims are flattened before the loss is computed.

        Args:
            prediction: probability scores for each class
            target: should be binary masks. Entries equal to ``ignore_index`` are excluded.
            reduction: ``None``, ``"mean"`` or ``"sum"``, applied over the batch dimension.
            ignore_index: Target value marking entries to ignore. Pass ``None`` to disable masking.
            smooth: Constant added to the numerator and the denominator of the dice ratio.

        Returns:
            The reduced loss, or a tensor of shape (batch_size,) when ``reduction`` is ``None``.
        """

        num_classes = prediction.shape[1]

        prediction = rearrange(prediction, "b n d h w -> b n (d h w)").contiguous()
        target = rearrange(target, "b n d h w -> b n (d h w)").contiguous()

        if ignore_index is not None:
            # Remove gradients of the predictions based on the target
            mask = target != ignore_index
            prediction = prediction * mask
            target = target * mask

        intersection = (prediction * target).sum(dim=2)
        denominator = (prediction**2).sum(dim=2) + (target**2).sum(dim=2)
        per_class_dice = (2 * intersection + smooth) / (denominator + smooth)

        loss = 1 - (1 / num_classes) * per_class_dice.sum(dim=1)
        loss = UNetR3DDecoder._reduce(loss, reduction)

        return loss

    @staticmethod
    def cross_entropy_loss_fn(
        prediction: torch.Tensor,
        target: torch.Tensor,
        reduction="mean",
        ignore_index: int = -100,
        smooth: float = 1e-5,
    ):
        """Cross entropy between target masks and predicted probabilities, one value per batch element.

        Both prediction and target should be 5D tensors of the form (batch_size, num_classes, *spatial dims);
        the three spatial dims are flattened before the loss is computed.

        Args:
            prediction: probability scores for each class
            target: should be binary masks. Entries equal to ``ignore_index`` are excluded.
            reduction: ``None``, ``"mean"`` or ``"sum"``, applied over the batch dimension.
            ignore_index: Target value marking entries to ignore. Pass ``None`` to disable masking.
            smooth: Constant added to the probabilities before taking the log.

        Returns:
            The reduced loss, or a tensor of shape (batch_size,) when ``reduction`` is ``None``.
        """

        # The normaliser is the full voxel count per sample; ignored entries do not reduce it.
        num_voxels = torch.prod(torch.tensor(prediction.shape[2:]))

        prediction = rearrange(prediction, "b n d h w -> b n (d h w)").contiguous()
        target = rearrange(target, "b n d h w -> b n (d h w)").contiguous()

        if ignore_index is not None:
            # Remove gradients of the predictions based on the target
            mask = target != ignore_index
            prediction = prediction * mask
            target = target * mask

        loss = -(1 / num_voxels) * (target * torch.log(prediction + smooth)).sum(dim=(1, 2))
        loss = UNetR3DDecoder._reduce(loss, reduction)

        return loss

    @staticmethod
    def loss_fn(
        prediction: torch.Tensor,
        target: torch.Tensor,
        reduction="mean",
        weight_dsc=1.0,
        weight_ce=1.0,
        ignore_index=-100,
        smooth: float = 1e-5,
        return_components=False,
    ):
        """Weighted sum of the soft dice loss and the cross entropy loss.

        Both prediction and target should be 5D tensors of the form (batch_size, num_classes, *spatial dims).

        Args:
            prediction: probability scores for each class
            target: should be binary masks.
            reduction: ``None``, ``"mean"`` or ``"sum"``, applied after the two losses are combined.
            weight_dsc: Weight of the soft dice loss.
            weight_ce: Weight of the cross entropy loss.
            ignore_index: Target value marking entries to ignore in both losses.
            smooth: Smoothing constant forwarded to both losses.
            return_components: Whether to also return the individually reduced losses.

        Returns:
            The combined loss, or ``(loss, [dice_loss, cross_entropy_loss])`` when ``return_components`` is True.
        """

        # Keep both components unreduced so they can be combined per sample first.
        loss1 = UNetR3DDecoder.soft_dice_loss_fn(
            prediction, target, reduction=None, ignore_index=ignore_index, smooth=smooth
        )
        loss2 = UNetR3DDecoder.cross_entropy_loss_fn(
            prediction, target, reduction=None, ignore_index=ignore_index, smooth=smooth
        )
        loss = weight_dsc * loss1 + weight_ce * loss2

        loss = UNetR3DDecoder._reduce(loss, reduction)

        if return_components:
            loss1 = UNetR3DDecoder._reduce(loss1, reduction)
            loss2 = UNetR3DDecoder._reduce(loss2, reduction)
            return loss, [loss1, loss2]
        return loss

In [None]:
test = UNetR3DDecoder(test_config)
display(test)
# One embedding per configured stage, listed in stage order, followed by the scan itself.
o = test(
    [
        torch.randn(2, 12, 32, 64, 64),
        torch.randn(2, 48, 16, 32, 32),
        torch.randn(2, 192, 8, 16, 16),
        torch.randn(2, 768, 4, 8, 8),
    ],
    torch.randn(2, 1, 32, 256, 256),
)
# num_outputs=5 channels at the spatial size of the scan; pinned so a regression fails the cell.
assert o.shape == (2, 5, 32, 256, 256), o.shape
display(o.shape)


[1;35mUNetR3DDecoder[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35m_UNetR3DBlock[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mCNNBlock3D[0m[1m([0m
        [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m768[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
        [1m([0mnorm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m768[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mact[1m)[0m: [1;35mReLU[0m[1m([0m[1m)[0m
        [1m([0mcheckpointing_level1[1m)[0m: [1;35mActivationCheckpointing[0m[1m([0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m5[0m, [1;36m32[0m, [1;36m256[0m, [1;36m256[0m[1m][0m[1m)[0m

In [None]:
# Turn the decoder output into per-class probabilities and draw a random 0/1 target of the same shape.
pred = torch.softmax(o, dim=1)
gt = torch.randint(0, 2, pred.shape)

print(pred.shape, gt.shape)
# Displays the combined loss followed by its [soft dice, cross entropy] components.
test.loss_fn(pred, gt, return_components=True)

torch.Size([2, 5, 32, 256, 256]) torch.Size([2, 5, 32, 256, 256])



[1m([0m
    [1;35mtensor[0m[1m([0m[1;36m4.7840[0m, [33mgrad_fn[0m=[1m<[0m[1;95mMeanBackward0[0m[39m>[0m[1;39m)[0m[39m,[0m
[39m    [0m[1;39m[[0m[1;35mtensor[0m[1;39m([0m[1;36m0.6329[0m[39m, [0m[33mgrad_fn[0m[39m=<MeanBackward0>[0m[1;39m)[0m[39m, [0m[1;35mtensor[0m[1;39m([0m[1;36m4.1511[0m[39m, [0m[33mgrad_fn[0m[39m=<MeanBackward0[0m[1m>[0m[1m)[0m[1m][0m
[1m)[0m

In [None]:
pred = torch.softmax(o, dim=1)
# Every target entry equals the default ignore_index, so nothing should contribute to the loss.
gt = torch.full(pred.shape, -100)

print(pred.shape, gt.shape)
ignored_loss, ignored_components = test.loss_fn(pred, gt, return_components=True)
# A fully ignored target must yield a zero total loss and zero components.
for loss_value in (ignored_loss, *ignored_components):
    assert torch.allclose(loss_value, torch.zeros_like(loss_value), atol=1e-6), loss_value
(ignored_loss, ignored_components)

torch.Size([2, 5, 32, 256, 256]) torch.Size([2, 5, 32, 256, 256])



[1m([0m
    [1;35mtensor[0m[1m([0m[1;36m0[0m., [33mgrad_fn[0m=[1m<[0m[1;95mMeanBackward0[0m[39m>[0m[1;39m)[0m[39m,[0m
[39m    [0m[1;39m[[0m[1;35mtensor[0m[1;39m([0m[1;36m0[0m[39m., [0m[33mgrad_fn[0m[39m=<MeanBackward0>[0m[1;39m)[0m[39m, [0m[1;35mtensor[0m[1;39m([0m[1;36m0[0m[39m., [0m[33mgrad_fn[0m[39m=<MeanBackward0[0m[1m>[0m[1m)[0m[1m][0m
[1m)[0m

# nbdev

In [None]:
!nbdev_export