In [1]:
# | default_exp unetr3d_decoder

# Imports

In [2]:
# | export

import torch

from torch import nn

# Architecture

### Basic block

In [3]:
# | export


class UNetR3DBlock(nn.Module):
    """Single decoder stage: convolve the encoder features, merge the skip path, upsample.

    The block applies a channel-preserving 3D convolution to the encoder
    feature map of its stage, optionally concatenates the output of the
    previous (deeper) decoder block along the channel axis, and upsamples the
    result with a transposed convolution whose stride equals its kernel size.

    Args:
        in_dim: Number of channels of ``current_layer``.
        out_dim: Number of channels produced by the transposed convolution.
        conv_kernel_size: Kernel size of the convolution, either a single int
            or a sequence with one entry per spatial axis. Padding is
            ``k // 2`` per axis, which preserves the spatial size for odd
            kernel sizes.
        deconv_kernel_size: Kernel size and stride of the transposed
            convolution (int or sequence); each spatial dimension is
            multiplied by this factor.
        is_first_layer: ``True`` when the block receives no ``previous_layer``.
            Otherwise the transposed convolution is built for ``in_dim * 2``
            input channels to take the concatenated tensor.
    """

    def __init__(self, in_dim, out_dim, conv_kernel_size, deconv_kernel_size, is_first_layer):
        super().__init__()

        # nn.Conv3d accepts a bare int, but the per-axis padding below needs a
        # sequence, so expand an int to one entry per spatial axis.
        if isinstance(conv_kernel_size, int):
            conv_kernel_size = (conv_kernel_size,) * 3
        conv_kernel_size = tuple(conv_kernel_size)

        self.conv = nn.Conv3d(
            in_dim, in_dim, kernel_size=conv_kernel_size, padding=tuple(k // 2 for k in conv_kernel_size)
        )

        # Every block except the first concatenates the previous decoder
        # output (also `in_dim` channels) onto the convolved features.
        deconv_in_dim = in_dim if is_first_layer else in_dim * 2
        self.deconv = nn.ConvTranspose3d(
            deconv_in_dim, out_dim, kernel_size=deconv_kernel_size, stride=deconv_kernel_size, padding=0
        )

    def forward(self, current_layer, previous_layer=None):
        """Decode one stage.

        Args:
            current_layer: Encoder feature map with ``in_dim`` channels on
                axis 1.
            previous_layer: Output of the previous decoder block, or ``None``
                for the first block. It is concatenated on axis 1, so every
                other axis must match the convolved ``current_layer``.

        Returns:
            The upsampled tensor with ``out_dim`` channels.
        """
        x = self.conv(current_layer)
        if previous_layer is not None:
            x = torch.cat([x, previous_layer], dim=1)
        return self.deconv(x)

In [4]:
# First decoder block: no skip input, so the transposed conv sees 768 channels.
first_block = UNetR3DBlock(768, 192, (3, 3, 3), (2, 2, 2), is_first_layer=True)
display(first_block)

first_block_out = first_block(torch.randn(2, 768, 4, 8, 8))
# The stride-2 transposed conv doubles each spatial dim and maps 768 -> 192 channels.
assert first_block_out.shape == (2, 192, 8, 16, 16)
display(first_block_out.shape)


[1;35mUNetR3DBlock[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m768[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
  [1m([0mdeconv[1m)[0m: [1;35mConvTranspose3d[0m[1m([0m[1;36m768[0m, [1;36m192[0m, [33mkernel_size[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m192[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

In [5]:
# Later decoder block: the previous block's output is concatenated, so the
# transposed conv is built for 2 * 192 = 384 input channels.
skip_block = UNetR3DBlock(192, 48, (3, 3, 3), (2, 2, 2), is_first_layer=False)
display(skip_block)

skip_block_out = skip_block(torch.randn(2, 192, 8, 16, 16), torch.randn(2, 192, 8, 16, 16))
# Spatial dims double again and the channels drop to 48.
assert skip_block_out.shape == (2, 48, 16, 32, 32)
display(skip_block_out.shape)


[1;35mUNetR3DBlock[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m192[0m, [1;36m192[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
  [1m([0mdeconv[1m)[0m: [1;35mConvTranspose3d[0m[1m([0m[1;36m384[0m, [1;36m48[0m, [33mkernel_size[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m[1m)[0m
[1m)[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m48[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m

### Complete decoder

In [6]:
# | export


class UNetR3DDecoder(nn.Module):
    """Stack of ``UNetR3DBlock`` modules that decodes a list of stage embeddings.

    One block is created per entry of ``config["stages"]``, walking the stages
    from last to first. Each block takes ``stage["_out_dim"]`` input channels.
    Intermediate blocks emit ``stage["_in_dim"]`` channels and upsample by the
    per-axis ratio ``_out_patch_size // _in_patch_size``. The final block
    emits ``config["decoder"]["num_outputs"]`` channels and upsamples by the
    stage's full ``_out_patch_size``.
    """

    def __init__(self, config):
        super().__init__()

        self.blocks = nn.ModuleList()

        stages = config["stages"]
        num_stages = len(stages)
        for depth in range(num_stages):
            # depth 0 is the last configured stage; the final depth is the first one.
            stage = stages[num_stages - 1 - depth]
            in_dim = stage["_out_dim"]

            if depth == num_stages - 1:
                # Last decoder block: produce the outputs at the stage's full patch size.
                out_dim = config["decoder"]["num_outputs"]
                deconv_kernel_size = stage["_out_patch_size"]
            else:
                out_dim = stage["_in_dim"]
                deconv_kernel_size = tuple(
                    out_size // in_size
                    for out_size, in_size in zip(stage["_out_patch_size"], stage["_in_patch_size"])
                )

            block = UNetR3DBlock(
                in_dim=in_dim,
                out_dim=out_dim,
                conv_kernel_size=config["decoder"]["conv_kernel_size"],
                deconv_kernel_size=deconv_kernel_size,
                is_first_layer=(depth == 0),
            )
            self.blocks.append(block)

    def forward(self, embeddings):
        """Run the blocks over ``embeddings``, starting from the last entry.

        ``embeddings`` is a list of (B, C_layer, D_layer, W_layer, H_layer)
        tensors. Block ``k`` receives the ``k``-th embedding counted from the
        end, plus the output of block ``k - 1`` for ``k > 0``.

        Returns the output of the last block that ran, or ``None`` when
        ``embeddings`` is empty.
        """
        reversed_embeddings = embeddings[::-1]

        decoded = None
        for block_idx, embedding in enumerate(reversed_embeddings):
            block = self.blocks[block_idx]
            if block_idx == 0:
                decoded = block(embedding)
            else:
                decoded = block(embedding, decoded)

        return decoded

In [7]:
# Four-stage configuration; the decoder walks the stages from last to first.
test_config = {
    "decoder": {
        "conv_kernel_size": (3, 3, 3),
        "num_outputs": 5,
    },
    "stages": [
        {
            "_in_dim": 12,
            "_in_patch_size": (1, 4, 4),
            "_out_dim": 12,
            "_out_patch_size": (1, 4, 4),
        },
        {
            "_in_dim": 12,
            "_in_patch_size": (1, 4, 4),
            "_out_dim": 48,
            "_out_patch_size": (2, 8, 8),
        },
        {
            "_in_dim": 48,
            "_in_patch_size": (2, 8, 8),
            "_out_dim": 192,
            "_out_patch_size": (4, 16, 16),
        },
        {
            "_in_dim": 192,
            "_in_patch_size": (4, 16, 16),
            "_out_dim": 768,
            "_out_patch_size": (8, 32, 32),
        },
    ],
}

test = UNetR3DDecoder(test_config)
display(test)

# One embedding per stage, in the same order as test_config["stages"].
test_embeddings = [
    torch.randn(2, 12, 32, 64, 64),
    torch.randn(2, 48, 16, 32, 32),
    torch.randn(2, 192, 8, 16, 16),
    torch.randn(2, 768, 4, 8, 8),
]
test_output = test(test_embeddings)

# The final (1, 4, 4) transposed conv turns (32, 64, 64) into (32, 256, 256)
# and emits `num_outputs` channels.
assert test_output.shape == (2, test_config["decoder"]["num_outputs"], 32, 256, 256)
display(test_output.shape)


[1;35mUNetR3DDecoder[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mUNetR3DBlock[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m768[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m[1m)[0m
      [1m([0mdeconv[1m)[0m: [1;35mConvTranspose3d[0m[1m([0m[1;36m768[0m, [1;36m192[0m, [33mkernel_size[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m2[0m, [1;36m2[0m, [1;36m2[0m[1m)[0m[1m)[0m
    [1m)[0m
    [1m([0m[1;36m1[0m[1m)[0m: [1;35mUNetR3DBlock[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m192[0m, [1;36m192[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m5[0m, [1;36m32[0m, [1;36m256[0m, [1;36m256[0m[1m][0m[1m)[0m

In [8]:
# Imported here rather than in the exported import cell: this inspection cell
# carries no `# | export` directive.
from neuro_utils.describe import describe_model

# Show the parameter counts of the decoder instance built in the previous cell.
describe_model(test)

* 'smart_union' has been removed


Total Parameters: 18,326,189
+------------------------+------------+
|         Module         | Parameters |
+------------------------+------------+
|  blocks.0.conv.weight  | 15,925,248 |
|   blocks.0.conv.bias   |    768     |
| blocks.0.deconv.weight | 1,179,648  |
|  blocks.0.deconv.bias  |    192     |
|  blocks.1.conv.weight  |  995,328   |
|   blocks.1.conv.bias   |    192     |
| blocks.1.deconv.weight |  147,456   |
|  blocks.1.deconv.bias  |     48     |
|  blocks.2.conv.weight  |   62,208   |
|   blocks.2.conv.bias   |     48     |
| blocks.2.deconv.weight |   9,216    |
|  blocks.2.deconv.bias  |     12     |
|  blocks.3.conv.weight  |   3,888    |
|   blocks.3.conv.bias   |     12     |
| blocks.3.deconv.weight |   1,920    |
|  blocks.3.deconv.bias  |     5      |
+------------------------+------------+


# nbdev

In [9]:
# Shell escape: run the nbdev export command, which writes the `# | export`
# cells to the module named by the `default_exp` directive at the top.
!nbdev_export