In [1]:
# | default_exp nets/unetr_3d_decoder

# Imports

In [2]:
# | export

import torch
from einops import rearrange
from huggingface_hub import PyTorchModelHubMixin
from torch import nn

# Architecture

### Conv and deconv blocks

In [3]:
# | export


class UNetR3DConvBlock(nn.Module):
    """Conv3d -> BatchNorm3d -> ReLU block that keeps the channel count at ``dim``.

    Args:
        dim: Number of input and output channels.
        kernel_size: Either a single int (used on all three spatial axes) or a
            per-axis sequence of three ints. Padding is ``k // 2`` per axis, which
            preserves the spatial size for odd kernel sizes.
    """

    def __init__(self, dim, kernel_size):
        super().__init__()

        # nn.Conv3d accepts a scalar kernel size; normalize it so the per-axis padding below works too.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size,) * 3

        self.conv = nn.Conv3d(
            dim,
            dim,
            kernel_size=kernel_size,
            padding=tuple(k // 2 for k in kernel_size),
            # The following affine BatchNorm has its own learnable shift, so a conv bias is redundant.
            bias=False,
        )
        self.batch_norm = nn.BatchNorm3d(dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Apply conv, batch norm and ReLU; the output shape equals ``x.shape`` for odd kernels."""
        x = self.conv(x)
        x = self.batch_norm(x)
        x = self.relu(x)
        return x

In [4]:
# | export


class UNetR3DDeConvBlock(nn.Module):
    """Learned upsampling block: ConvTranspose3d -> BatchNorm3d -> ReLU.

    The transposed convolution uses ``stride == kernel_size`` and no padding, so each
    spatial axis of the input is multiplied by the matching kernel size.

    Args:
        in_dim: Channels of the input tensor.
        out_dim: Channels of the output tensor.
        kernel_size: Per-axis upsampling factor, used as both kernel size and stride.
    """

    def __init__(self, in_dim, out_dim, kernel_size):
        super().__init__()

        upsample_options = dict(kernel_size=kernel_size, stride=kernel_size, padding=0, bias=False)
        self.deconv = nn.ConvTranspose3d(in_dim, out_dim, **upsample_options)
        self.batch_norm = nn.BatchNorm3d(out_dim)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        """Upsample ``x``, then normalize and rectify the result."""
        return self.relu(self.batch_norm(self.deconv(x)))

### Basic block

In [5]:
# | export


class UNetR3DBlock(nn.Module):
    """One decoder stage: refine an embedding, optionally fuse the previous stage, then upsample.

    Args:
        in_dim: Channels of ``current_layer``.
        out_dim: Channels produced by the upsampling deconv.
        conv_kernel_size: Kernel size of the refining ``UNetR3DConvBlock``.
        deconv_kernel_size: Per-axis upsampling factor of the ``UNetR3DDeConvBlock``.
        is_first_layer: When False, ``forward`` expects ``previous_layer`` and concatenates it
            onto the conv output. It is assumed to carry ``in_dim`` channels, hence the deconv
            input of ``2 * in_dim``.
    """

    def __init__(self, in_dim, out_dim, conv_kernel_size, deconv_kernel_size, is_first_layer):
        super().__init__()

        self.conv = UNetR3DConvBlock(in_dim, conv_kernel_size)
        deconv_in_dim = in_dim if is_first_layer else 2 * in_dim
        self.deconv = UNetR3DDeConvBlock(deconv_in_dim, out_dim, deconv_kernel_size)

    def forward(self, current_layer, previous_layer=None):
        """Refine ``current_layer``, concatenate ``previous_layer`` on channels if given, and upsample."""
        features = self.conv(current_layer)
        if previous_layer is None:
            return self.deconv(features)
        return self.deconv(torch.cat((features, previous_layer), dim=1))

In [6]:
test = UNetR3DBlock(768, 192, (3, 3, 3), (2, 2, 2), is_first_layer=True)
display(test)
first_block_out = test(torch.randn(2, 768, 4, 8, 8))
# A (2, 2, 2) deconv doubles every spatial axis and maps 768 -> 192 channels.
assert first_block_out.shape == (2, 192, 8, 16, 16)
display(first_block_out.shape)


[1;35mUNetR3DBlock[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mUNetR3DConvBlock[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m768[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mbatch_norm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m768[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mrelu[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mdeconv[1m)[0m: [1;35mUNetR3DDeConvBlock[0m[1m([0m
    [1m([0mdeconv[1m)[0m: [1;35mConvTranspose3d[0m[1m([0m[1;36m768[0m, [1;36m192[0m, [33mkernel_size[

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m192[0m, [1;36m8[0m, [1;36m16[0m, [1;36m16[0m[1m][0m[1m)[0m

In [7]:
test = UNetR3DBlock(192, 48, (3, 3, 3), (2, 2, 2), is_first_layer=False)
display(test)
skip_block_out = test(torch.randn(2, 192, 8, 16, 16), torch.randn(2, 192, 8, 16, 16))
# The skip input doubles the deconv input to 384 channels; output is 48 channels at twice the resolution.
assert skip_block_out.shape == (2, 48, 16, 32, 32)
display(skip_block_out.shape)


[1;35mUNetR3DBlock[0m[1m([0m
  [1m([0mconv[1m)[0m: [1;35mUNetR3DConvBlock[0m[1m([0m
    [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m192[0m, [1;36m192[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
    [1m([0mbatch_norm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m192[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
    [1m([0mrelu[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
  [1m)[0m
  [1m([0mdeconv[1m)[0m: [1;35mUNetR3DDeConvBlock[0m[1m([0m
    [1m([0mdeconv[1m)[0m: [1;35mConvTranspose3d[0m[1m([0m[1;36m384[0m, [1;36m48[0m, [33mkernel_size[0

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m48[0m, [1;36m16[0m, [1;36m32[0m, [1;36m32[0m[1m][0m[1m)[0m

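### Complete decoder

The decoder takes one embedding per encoder stage (ordered shallowest to deepest) plus the raw scan. It decodes from the deepest embedding upward. Each later block concatenates its conv output with the previous block's upsampled output. The final result is fused with a convolution over the scan.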

In [8]:
# | export


class UNetR3DDecoder(nn.Module, PyTorchModelHubMixin):
    """UNETR-style 3D decoder turning multi-scale encoder embeddings into per-voxel outputs.

    One ``UNetR3DBlock`` is built per entry of ``config["stages"]``, ordered from the deepest
    stage (``stages[-1]``) to the shallowest (``stages[0]``). Each block upsamples by the ratio
    of its stage's ``_out_patch_size`` to ``_in_patch_size``. The last block upsamples by
    ``stages[0]["_out_patch_size"]`` and outputs ``config["in_channels"]`` channels. That result
    is concatenated with a convolution of the raw scan and mapped to
    ``config["decoder"]["num_outputs"]`` channels.

    Config keys read here:
        in_channels: channels of ``scan``.
        decoder.conv_kernel_size: kernel of every block's refining conv.
        decoder.final_layer_kernel_size: kernel (int or 3 values) of the scan/final convs.
        decoder.num_outputs: channels of the final output.
        stages: list of dicts with ``_in_dim``, ``_out_dim``, ``_in_patch_size``, ``_out_patch_size``.
    """

    def __init__(self, config):
        super().__init__()

        conv_kernel_size = config["decoder"]["conv_kernel_size"]
        final_layer_kernel_size = config["decoder"]["final_layer_kernel_size"]
        # Allow a scalar kernel size; the per-axis padding below needs one value per axis.
        if isinstance(final_layer_kernel_size, int):
            final_layer_kernel_size = (final_layer_kernel_size,) * 3
        final_padding = tuple(k // 2 for k in final_layer_kernel_size)

        stages = config["stages"]
        num_stages = len(stages)

        self.blocks = nn.ModuleList()
        # Blocks are built deepest-first; forward() feeds them the embeddings in reverse order.
        for block_idx, stage in enumerate(reversed(stages)):
            if block_idx == num_stages - 1:
                # The shallowest stage upsamples straight back to the scan resolution.
                out_dim = config["in_channels"]
                deconv_kernel_size = stage["_out_patch_size"]
            else:
                out_dim = stage["_in_dim"]
                deconv_kernel_size = self._upsampling_factor(stage["_out_patch_size"], stage["_in_patch_size"])

            self.blocks.append(
                UNetR3DBlock(
                    in_dim=stage["_out_dim"],
                    out_dim=out_dim,
                    conv_kernel_size=conv_kernel_size,
                    deconv_kernel_size=deconv_kernel_size,
                    is_first_layer=(block_idx == 0),
                )
            )
        self.scan_conv = nn.Conv3d(
            config["in_channels"],
            config["in_channels"],
            kernel_size=final_layer_kernel_size,
            padding=final_padding,
        )
        self.final_conv = nn.Conv3d(
            config["in_channels"] * 2,
            config["decoder"]["num_outputs"],
            kernel_size=final_layer_kernel_size,
            padding=final_padding,
        )

    @staticmethod
    def _upsampling_factor(out_patch_size, in_patch_size):
        """Return the per-axis ratio ``out_patch_size // in_patch_size`` as a tuple.

        Raises:
            ValueError: if some axis is not an exact multiple. Integer division would
                otherwise silently floor the ratio.
        """
        factors = []
        for out_size, in_size in zip(out_patch_size, in_patch_size):
            if out_size % in_size != 0:
                raise ValueError(
                    f"Output patch size {tuple(out_patch_size)} is not a multiple of "
                    f"input patch size {tuple(in_patch_size)}"
                )
            factors.append(out_size // in_size)
        return tuple(factors)

    def forward(self, embeddings, scan):
        """Decode the encoder embeddings and fuse them with the scan.

        Args:
            embeddings: Sequence with one tensor per stage, shallowest first, each of shape
                (B, C_layer, D_layer, W_layer, H_layer).
            scan: Input volume with ``config["in_channels"]`` channels. Its spatial size must
                match the last block's output, because the two are concatenated.

        Returns:
            Tensor with ``config["decoder"]["num_outputs"]`` channels.

        Raises:
            ValueError: if ``len(embeddings)`` differs from the number of stages.
        """
        if len(embeddings) != len(self.blocks):
            raise ValueError(f"Expected {len(self.blocks)} embeddings (one per stage), got {len(embeddings)}")

        decoded = None
        # Deepest embedding first. Every later block also receives the previous block's upsampled output.
        for block, embedding in zip(self.blocks, embeddings[::-1]):
            decoded = block(embedding) if decoded is None else block(embedding, decoded)

        high_resolution_embeddings = self.scan_conv(scan)
        final_embeddings = torch.cat([high_resolution_embeddings, decoded], dim=1)
        return self.final_conv(final_embeddings)

    @staticmethod
    def _reduce(loss, reduction):
        """Reduce a per-sample loss: ``None`` keeps it, ``"mean"``/``"sum"`` reduce it, anything else raises."""
        if reduction is None:
            return loss
        elif reduction == "mean":
            return loss.mean()
        elif reduction == "sum":
            return loss.sum()
        else:
            raise NotImplementedError("Please implement the reduction type")

    @staticmethod
    def _flatten_and_mask(prediction, target, ignore_index):
        """Flatten the three spatial axes to (B, N, voxels) and zero entries whose target is ``ignore_index``."""
        prediction = rearrange(prediction, "b n d h w -> b n (d h w)")
        target = rearrange(target, "b n d h w -> b n (d h w)")

        if ignore_index is not None:
            # Multiplying by the mask zeroes both the values and the gradients of ignored entries.
            mask = target != ignore_index
            prediction = prediction * mask
            target = target * mask
        return prediction, target

    @staticmethod
    def soft_dice_loss_fn(
        prediction: torch.Tensor, target: torch.Tensor, reduction="mean", ignore_index: int = -100, smooth: float = 1e-8
    ):
        """
        Soft Dice loss, averaged over classes for each sample.

        Both prediction and target should be of the form (batch_size, num_classes, depth, width, height).

        prediction: probability scores for each class
        target: should be binary masks; entries equal to ``ignore_index`` are excluded.
        reduction: None for a (batch_size,) tensor, or "mean"/"sum".
        smooth: added to numerator and denominator to avoid division by zero.
        """

        num_classes = prediction.shape[1]
        prediction, target = UNetR3DDecoder._flatten_and_mask(prediction, target, ignore_index)

        intersection = (prediction * target).sum(dim=2)
        denominator = (prediction**2).sum(dim=2) + (target**2).sum(dim=2)
        dice_per_class = (2 * intersection + smooth) / (denominator + smooth)
        loss = 1 - (1 / num_classes) * dice_per_class.sum(dim=1)
        return UNetR3DDecoder._reduce(loss, reduction)

    @staticmethod
    def cross_entropy_loss_fn(
        prediction: torch.Tensor, target: torch.Tensor, reduction="mean", ignore_index: int = -100, smooth: float = 1e-8
    ):
        """
        Voxel-averaged cross-entropy, summed over classes for each sample.

        Both prediction and target should be of the form (batch_size, num_classes, depth, width, height).

        prediction: probability scores for each class
        target: should be binary masks; entries equal to ``ignore_index`` are excluded.
        reduction: None for a (batch_size,) tensor, or "mean"/"sum".
        smooth: added inside the log to avoid log(0).
        """

        # A plain Python int derived from the shape; avoids building a CPU tensor on every call.
        num_voxels = prediction.shape[2:].numel()
        prediction, target = UNetR3DDecoder._flatten_and_mask(prediction, target, ignore_index)

        # NOTE(review): this averages over all voxels, including ignored ones, so the loss shrinks as
        # more of the target is ignored (unlike torch.nn.functional.cross_entropy). Confirm this is intended.
        loss = -(1 / num_voxels) * (target * torch.log(prediction + smooth)).sum(dim=(1, 2))
        return UNetR3DDecoder._reduce(loss, reduction)

    @staticmethod
    def loss_fn(
        prediction: torch.Tensor,
        target: torch.Tensor,
        reduction="mean",
        weight_dsc=1.0,
        weight_ce=1.0,
        ignore_index=-100,
        smooth: float = 1e-8,
        return_components=False,
    ):
        """
        Weighted sum ``weight_dsc * soft Dice + weight_ce * cross-entropy``.

        Both prediction and target should be of the form (batch_size, num_classes, depth, width, height).

        prediction: probability scores for each class
        target: should be binary masks; entries equal to ``ignore_index`` are excluded.
        return_components: if True, return ``(loss, [dice_loss, ce_loss])``, each reduced the same way.
        """

        dice_loss = UNetR3DDecoder.soft_dice_loss_fn(
            prediction, target, reduction=None, ignore_index=ignore_index, smooth=smooth
        )
        ce_loss = UNetR3DDecoder.cross_entropy_loss_fn(
            prediction, target, reduction=None, ignore_index=ignore_index, smooth=smooth
        )
        loss = UNetR3DDecoder._reduce(weight_dsc * dice_loss + weight_ce * ce_loss, reduction)

        if return_components:
            return loss, [UNetR3DDecoder._reduce(dice_loss, reduction), UNetR3DDecoder._reduce(ce_loss, reduction)]
        return loss

In [9]:
# Seed so model weights, random inputs and the losses computed from `o` are reproducible on Run All.
torch.manual_seed(0)

test_config = {
    "in_channels": 1,
    "decoder": {
        "conv_kernel_size": (3, 3, 3),
        "final_layer_kernel_size": (5, 5, 5),
        "num_outputs": 5,
    },
    "stages": [
        {
            "_in_dim": 12,
            "_in_patch_size": (1, 4, 4),
            "_out_dim": 12,
            "_out_patch_size": (1, 4, 4),
        },
        {
            "_in_dim": 12,
            "_in_patch_size": (1, 4, 4),
            "_out_dim": 48,
            "_out_patch_size": (2, 8, 8),
        },
        {
            "_in_dim": 48,
            "_in_patch_size": (2, 8, 8),
            "_out_dim": 192,
            "_out_patch_size": (4, 16, 16),
        },
        {
            "_in_dim": 192,
            "_in_patch_size": (4, 16, 16),
            "_out_dim": 768,
            "_out_patch_size": (8, 32, 32),
        },
    ],
}

test = UNetR3DDecoder(test_config)
display(test)
o = test(
    [
        torch.randn(2, 12, 32, 64, 64),
        torch.randn(2, 48, 16, 32, 32),
        torch.randn(2, 192, 8, 16, 16),
        torch.randn(2, 768, 4, 8, 8),
    ],
    torch.randn(2, 1, 32, 256, 256),
)
# The decoder must return num_outputs channels at the scan's spatial resolution.
assert o.shape == (2, test_config["decoder"]["num_outputs"], 32, 256, 256)
display(o.shape)


[1;35mUNetR3DDecoder[0m[1m([0m
  [1m([0mblocks[1m)[0m: [1;35mModuleList[0m[1m([0m
    [1m([0m[1;36m0[0m[1m)[0m: [1;35mUNetR3DBlock[0m[1m([0m
      [1m([0mconv[1m)[0m: [1;35mUNetR3DConvBlock[0m[1m([0m
        [1m([0mconv[1m)[0m: [1;35mConv3d[0m[1m([0m[1;36m768[0m, [1;36m768[0m, [33mkernel_size[0m=[1m([0m[1;36m3[0m, [1;36m3[0m, [1;36m3[0m[1m)[0m, [33mstride[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mpadding[0m=[1m([0m[1;36m1[0m, [1;36m1[0m, [1;36m1[0m[1m)[0m, [33mbias[0m=[3;91mFalse[0m[1m)[0m
        [1m([0mbatch_norm[1m)[0m: [1;35mBatchNorm3d[0m[1m([0m[1;36m768[0m, [33meps[0m=[1;36m1e[0m[1;36m-05[0m, [33mmomentum[0m=[1;36m0[0m[1;36m.1[0m, [33maffine[0m=[3;92mTrue[0m, [33mtrack_running_stats[0m=[3;92mTrue[0m[1m)[0m
        [1m([0mrelu[1m)[0m: [1;35mReLU[0m[1m([0m[33minplace[0m=[3;92mTrue[0m[1m)[0m
      [1m)[0m
      [1m([0mdeconv[1m)[0m: 

[1;35mtorch.Size[0m[1m([0m[1m[[0m[1;36m2[0m, [1;36m5[0m, [1;36m32[0m, [1;36m256[0m, [1;36m256[0m[1m][0m[1m)[0m

In [10]:
# This cell is not marked `# | export`, so the neuro_utils dependency stays out of the exported module.
from neuro_utils.describe import describe_model

describe_model(test)

Total Parameters: 18,327,303
+-----------------------------------+------------+
|               Module              | Parameters |
+-----------------------------------+------------+
|     blocks.0.conv.conv.weight     | 15,925,248 |
|  blocks.0.conv.batch_norm.weight  |    768     |
|   blocks.0.conv.batch_norm.bias   |    768     |
|   blocks.0.deconv.deconv.weight   | 1,179,648  |
| blocks.0.deconv.batch_norm.weight |    192     |
|  blocks.0.deconv.batch_norm.bias  |    192     |
|     blocks.1.conv.conv.weight     |  995,328   |
|  blocks.1.conv.batch_norm.weight  |    192     |
|   blocks.1.conv.batch_norm.bias   |    192     |
|   blocks.1.deconv.deconv.weight   |  147,456   |
| blocks.1.deconv.batch_norm.weight |     48     |
|  blocks.1.deconv.batch_norm.bias  |     48     |
|     blocks.2.conv.conv.weight     |   62,208   |
|  blocks.2.conv.batch_norm.weight  |     48     |
|   blocks.2.conv.batch_norm.bias   |     48     |
|   blocks.2.deconv.deconv.weight   |   9,216    |
| 

In [11]:
pred = torch.softmax(o, dim=1)
gt = torch.randint(0, 2, pred.shape)

print(pred.shape, gt.shape)
random_target_loss = test.loss_fn(pred, gt, return_components=True)
total_loss, (dsc_loss, ce_loss) = random_target_loss
# loss_fn defaults to unit weights, so the total must equal the sum of the Dice and cross-entropy components.
assert torch.isclose(total_loss, dsc_loss + ce_loss)
random_target_loss

torch.Size([2, 5, 32, 256, 256]) torch.Size([2, 5, 32, 256, 256])



[1m([0m
    [1;35mtensor[0m[1m([0m[1;36m4.7629[0m, [33mgrad_fn[0m=[1m<[0m[1;95mMeanBackward0[0m[39m>[0m[1;39m)[0m[39m,[0m
[39m    [0m[1;39m[[0m[1;35mtensor[0m[1;39m([0m[1;36m0.6320[0m[39m, [0m[33mgrad_fn[0m[39m=<MeanBackward0>[0m[1;39m)[0m[39m, [0m[1;35mtensor[0m[1;39m([0m[1;36m4.1309[0m[39m, [0m[33mgrad_fn[0m[39m=<MeanBackward0[0m[1m>[0m[1m)[0m[1m][0m
[1m)[0m

In [12]:
pred = torch.softmax(o, dim=1)
gt = torch.full(pred.shape, -100)

print(pred.shape, gt.shape)
ignored_target_loss = test.loss_fn(pred, gt, return_components=True)
total_loss, (dsc_loss, ce_loss) = ignored_target_loss
# Every target entry equals the default ignore_index (-100), so all three values must be exactly zero.
assert total_loss.item() == 0 and dsc_loss.item() == 0 and ce_loss.item() == 0
ignored_target_loss

torch.Size([2, 5, 32, 256, 256]) torch.Size([2, 5, 32, 256, 256])



[1m([0m
    [1;35mtensor[0m[1m([0m[1;36m0[0m., [33mgrad_fn[0m=[1m<[0m[1;95mMeanBackward0[0m[39m>[0m[1;39m)[0m[39m,[0m
[39m    [0m[1;39m[[0m[1;35mtensor[0m[1;39m([0m[1;36m0[0m[39m., [0m[33mgrad_fn[0m[39m=<MeanBackward0>[0m[1;39m)[0m[39m, [0m[1;35mtensor[0m[1;39m([0m[1;36m0[0m[39m., [0m[33mgrad_fn[0m[39m=<MeanBackward0[0m[1m>[0m[1m)[0m[1m][0m
[1m)[0m

# nbdev

In [13]:
# Write the `# | export` cells to the module declared by `default_exp` at the top of this notebook.
!nbdev_export