info
**UNETR: Transformers for 3D Medical Image Segmentation**    
*Ali Hatamizadeh, Yucheng Tang, Vishwesh Nath, Dong Yang, Andriy Myronenko, Bennett Landman, Holger Roth, Daguang Xu*   
[[paper](https://arxiv.org/abs/2103.10504)]   
WACV 2022     

In [4]:
import torch
import torch.nn as nn

from monai.networks.nets import UNETR

In [26]:
class PatchEmbedding_UNETR(nn.Module):
    def __init__(self, image_size, patch_size, in_dim, latent_dim) -> None:
        super(PatchEmbedding_UNETR, self).__init__()

        # N = H*W*D / P^3
        self.num_patches = image_size // patch_size
        print(self.num_patches)

        # project into a K dim
        self.project = nn.Linear(in_features= (patch_size**3)*in_dim, out_features=latent_dim)

        # positional embedding
        self.pos_embed = nn.Parameter(torch.zeros(self.num_patches, latent_dim))

    def forward(self, img):
        # Flatten -> Projection -> Position
        B, _, _, _, _ = img.shape # (B C D W H)

        patches = img.reshape(B, self.num_patch, -1) # (B C D W H) -> (B, N, P^3xC) --- N = (HxWxD)/P^3
        patches = self.project(patches) # (B, N, P*P*P*C) -> (B, N, D)
        patches = patches + self.pos_embed

        return patches

In [33]:
class Transformer_block(nn.Module):
    def __init__(self, latent_dim, num_heads) -> None:
        super(Transformer_block, self).__init__()

        self.self_attn = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.MultiheadAttention(embed_dim=latent_dim, num_heads=num_heads)
        )

        self.ff = nn.Sequential(
            nn.LayerNorm(latent_dim),
            nn.Linear(in_features=latent_dim, out_features=latent_dim),
            nn.GELU(),
            nn.Linear(in_features=latent_dim, out_features=latent_dim),
        )

    def forward(self, patches):

        residual = patches
        patches = self.self_attn(patches)
        patches = patches + residual

        residual = patches
        patches = self.ff(patches)
        patches = patches + residual

        return patches

In [34]:
class UNETR_encoder(nn.Module):
    def __init__(self, in_dim, d_model, image_size, patch_size, num_heads) -> None:
        super(UNETR_encoder, self).__init__()
        self.num_patches = image_size // patch_size

        self.patch_emb = PatchEmbedding_UNETR(in_dim=in_dim, latent_dim=d_model, image_size=image_size, patch_size=patch_size)

        self.layer1 = nn.Sequential(
            *[Transformer_block(latent_dim=d_model, num_heads=num_heads) for _ in range(3)]
        )

        self.layer2 = nn.Sequential(
            *[Transformer_block(latent_dim=d_model, num_heads=num_heads) for _ in range(3)]
        )

        self.layer3 = nn.Sequential(
            *[Transformer_block(latent_dim=d_model, num_heads=num_heads) for _ in range(3)]
        )

        self.layer4 = nn.Sequential(
            *[Transformer_block(latent_dim=d_model, num_heads=num_heads) for _ in range(3)]
        )

    def forward(self, x:torch.tensor):
        B, C, D, H, W = x.shape
        stage_outputs = {}

        x = self.patch_emb(x)

        x = self.layer1(x)
        stage_outputs['z3'] = x.reshape(B, -1, D/self.num_patches, H/self.num_patches, W/self.num_patches)

        x = self.layer2(x)
        stage_outputs['z6'] = x.reshape(B, -1, D/self.num_patches, H/self.num_patches, W/self.num_patches)

        x = self.layer3(x)
        stage_outputs['z9'] = x.reshape(B, -1, D/self.num_patches, H/self.num_patches, W/self.num_patches)

        x = self.layer4(x)
        stage_outputs['z12'] = x.reshape(B, -1, D/self.num_patches, H/self.num_patches, W/self.num_patches)

        return stage_outputs



In [35]:
from typing import Sequence

class UNETR_skipconnection(nn.Module):
    def __init__(self, in_dim:int, hidden_dim:Sequence[int]):
        super(UNETR_skipconnection, self).__init__()
        
        self.skip1 = nn.Sequential(
            nn.Conv3d(in_channels=in_dim, out_channels=hidden_dim[0], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[0]),
            nn.ReLU(),
            nn.Conv3d(in_channels=hidden_dim[0], out_channels=hidden_dim[0], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[0]),
            nn.ReLU()
        )

        self.skip2 = nn.Sequential(
            # 768 -> 512
            nn.ConvTranspose3d(in_channels=hidden_dim[4], out_channels=hidden_dim[3], kernel_size=2, stride=2),
            nn.Conv3d(in_channels=hidden_dim[3], out_channels=hidden_dim[3], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[3]),
            nn.ReLU(),
            # 512 -> 256
            nn.ConvTranspose3d(in_channels=hidden_dim[3], out_channels=hidden_dim[2], kernel_size=2, stride=2),
            nn.Conv3d(in_channels=hidden_dim[2], out_channels=hidden_dim[2], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[2]),
            nn.ReLU(),
            # 256 -> 128
            nn.ConvTranspose3d(in_channels=hidden_dim[2], out_channels=hidden_dim[1], kernel_size=2, stride=2),
            nn.Conv3d(in_channels=hidden_dim[1], out_channels=hidden_dim[1], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[1]),
            nn.ReLU(),
        )

        self.skip3 = nn.Sequential(
            # 768 -> 512
            nn.ConvTranspose3d(in_channels=hidden_dim[4], out_channels=hidden_dim[3], kernel_size=2, stride=2),
            nn.Conv3d(in_channels=hidden_dim[3], out_channels=hidden_dim[3], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[3]),
            nn.ReLU(),
            # 512 -> 256
            nn.ConvTranspose3d(in_channels=hidden_dim[3], out_channels=hidden_dim[2], kernel_size=2, stride=2),
            nn.Conv3d(in_channels=hidden_dim[2], out_channels=hidden_dim[2], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[2]),
            nn.ReLU(),
        )

        self.skip4 = nn.Sequential(
            # 768 -> 512
            nn.ConvTranspose3d(in_channels=hidden_dim[4], out_channels=hidden_dim[3], kernel_size=2, stride=2),
            nn.Conv3d(in_channels=hidden_dim[3], out_channels=hidden_dim[3], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[3]),
            nn.ReLU(),
        )

    def forward(self, x, stage_outputs:dict):

        stage_outputs['z0'] = self.skip1(x)
        stage_outputs['z3'] = self.skip2(stage_outputs['z3'])
        stage_outputs['z6'] = self.skip3(stage_outputs['z6'])
        stage_outputs['z9'] = self.skip4(stage_outputs['z9'])

        return stage_outputs

In [36]:
class UNETR_decoder(nn.Module):
    def __init__(self, out_dim:int, hidden_dim:Sequence[int]) -> None:
        super(UNETR_decoder, self).__init__()

        self.up1 = nn.ConvTranspose3d(in_channels=hidden_dim[4], out_channels=hidden_dim[3], kernel_size=2, stride=2)
        self.conv1 = nn.Sequential(
            nn.Conv3d(in_channels=hidden_dim[3]*2, out_channels=hidden_dim[3], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[3]),
            nn.ReLU(),
            nn.Conv3d(in_channels=hidden_dim[3], out_channels=hidden_dim[3], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[3]),
            nn.ReLU(),
        )

        self.up2 = nn.ConvTranspose3d(in_channels=hidden_dim[3], out_channels=hidden_dim[2], kernel_size=2, stride=2)
        self.conv2 = nn.Sequential(
            nn.Conv3d(in_channels=hidden_dim[2]*2, out_channels=hidden_dim[2], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[2]),
            nn.ReLU(),
            nn.Conv3d(in_channels=hidden_dim[2], out_channels=hidden_dim[2], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[2]),
            nn.ReLU(),
        )

        self.up3 = nn.ConvTranspose3d(in_channels=hidden_dim[2], out_channels=hidden_dim[1], kernel_size=2, stride=2)
        self.conv3 = nn.Sequential(
            nn.Conv3d(in_channels=hidden_dim[1]*2, out_channels=hidden_dim[1], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[1]),
            nn.ReLU(),
            nn.Conv3d(in_channels=hidden_dim[1], out_channels=hidden_dim[1], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[1]),
            nn.ReLU(),
        )

        self.up4 = nn.ConvTranspose3d(in_channels=hidden_dim[1], out_channels=hidden_dim[0], kernel_size=2, stride=2)
        self.conv4 = nn.Sequential(
            nn.Conv3d(in_channels=hidden_dim[0]*2, out_channels=hidden_dim[0], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[0]),
            nn.ReLU(),
            nn.Conv3d(in_channels=hidden_dim[0], out_channels=hidden_dim[0], kernel_size=3, stride=1, padding=1, bias=False),
            nn.BatchNorm3d(hidden_dim[0]),
            nn.ReLU(),
        )

        self.fc = nn.Conv3d(in_channels=hidden_dim[0], out_channels=out_dim, kernel_size=1)

    def forward(self, stage_outputs:dict):

        z = stage_outputs['z12']
        z = self.up1(z)
        z = torch.concat([z, stage_outputs['z9']], dim=1)
        z = self.conv1(z)

        z = self.up2(z)
        z = torch.concat([z, stage_outputs['z6']], dim=1)
        z = self.conv2(z)

        z = self.up3(z)
        z = torch.concat([z, stage_outputs['z3']], dim=1)
        z = self.conv3(z)

        z = self.up4(z)
        z = torch.concat([z, stage_outputs['z0']], dim=1)
        z = self.conv4(z)

        z = self.conv5(z)

        return z

In [37]:
class UNETR(nn.Module):
    def __init__(self, in_dim, out_dim, hidden_dim:Sequence[int], img_size, patch_size, num_heads=16, d_model=768) -> None:
        super(UNETR, self).__init__()

        self.encoder = UNETR_encoder(in_dim=in_dim, d_model=d_model, image_size=img_size, patch_size=patch_size, num_heads=num_heads)
        self.skipnet = UNETR_skipconnection(in_dim=in_dim, hidden_dim=hidden_dim)
        self.docoder = UNETR_decoder(out_dim=out_dim, hidden_dim=hidden_dim)
    
    def forward(self, x):

        stage_outputs = self.encoder(x)
        stage_outputs = self.skipnet(x, stage_outputs)
        out = self.decoder(stage_outputs)

        return out

In [38]:
model = UNETR(
    in_dim=4,
    out_dim=3,
    hidden_dim=[64, 128, 256, 512, 768],
    img_size=128,
    patch_size=8,
    num_heads=16,
    d_model=768
)

16


AttributeError: cannot assign module before Module.__init__() call