In [5]:
!ls

 code.ipynb	     download_data.ipynb
 code.ipynb.amltmp  'swinUNETR implementation.ipynb'
 dev_note.ipynb     'swinunetr implementation.ipynb.amltmp'


In [6]:
%cd Users/jose.d.berlin/3D-MRI-Segmentation/

[Errno 2] No such file or directory: 'Users/jose.d.berlin/3D-MRI-Segmentation/'
/mnt/batch/tasks/shared/LS_root/mounts/clusters/e4ds-v4/code/Users/jose.d.berlin/3D-MRI-Segmentation/Notebooks


In [7]:
ls

 [0m[01;32mcode.ipynb[0m*          [01;32mdownload_data.ipynb[0m*
 [01;32mcode.ipynb.amltmp[0m*  [01;32m'swinUNETR implementation.ipynb'[0m*
 [01;32mdev_note.ipynb[0m*     [01;32m'swinunetr implementation.ipynb.amltmp'[0m*


In [3]:
# Importing necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import nibabel as nib

In [8]:
# Read one volume from disk and turn it into a float32 tensor.
sample = nib.load("../data/mpr-1.nifti.img")
sample_in = torch.tensor(sample.get_fdata()).float()
# Drop axis 3 (only removed if it has size 1), then prepend two singleton axes
# so the tensor has the 5-D layout the model's forward pass unpacks.
sample_in = sample_in.squeeze(3)[None, None]
print(sample_in.shape)

torch.Size([1, 1, 256, 256, 128])


## Window Partition

In [28]:
def window_partition(x, window_size):
    """Cut a channel-last volume into non-overlapping cubic windows.

    Args:
        x: Tensor of shape (B, D, H, W, C). ``Tensor.view`` is used, so the
            tensor must be viewable with this layout, and D, H and W must each
            be a multiple of ``window_size`` (otherwise ``view`` raises).
        window_size: Edge length of each cubic window.

    Returns:
        Tensor of shape (B * num_windows, window_size**3, C), where windows
        are ordered by batch, then by their position along D, H and W.
    """
    batch, depth, height, width, channels = x.shape
    ws = window_size

    # Split every spatial axis into (window index, offset inside the window).
    blocks = x.view(
        batch,
        depth // ws, ws,
        height // ws, ws,
        width // ws, ws,
        channels,
    )

    # Move the three window-index axes in front of the three offset axes so
    # that each window's voxels become contiguous in memory.
    grouped = blocks.permute(0, 1, 3, 5, 2, 4, 6, 7).contiguous()

    # Flatten: one row per window, one entry per voxel inside that window.
    return grouped.view(-1, ws**3, channels)


def window_reverse(windows, window_size, D, H, W):
    """Reassemble cubic windows into a channel-last volume.

    Args:
        windows: Tensor whose first axis enumerates windows (batch-major, then
            window position along D, H, W) and whose remaining elements hold
            the ``window_size**3`` voxels of each window and their channels.
        window_size: Edge length of each cubic window.
        D, H, W: Spatial size of the volume to rebuild.

    Returns:
        Tensor of shape (B, D, H, W, C), with B derived from the number of
        windows and the number of windows that fit in one volume.
    """
    ws = window_size

    # Number of windows contributed by a single volume, then the batch size.
    windows_per_volume = D * H * W / ws**3
    batch = int(windows.shape[0] / windows_per_volume)

    # Restore the (window index..., in-window offset...) axis layout.
    grid = windows.view(batch, D // ws, H // ws, W // ws, ws, ws, ws, -1)

    # Interleave each window-index axis with its matching offset axis.
    volume = grid.permute(0, 1, 4, 2, 5, 3, 6, 7).contiguous()

    return volume.view(batch, D, H, W, -1)

### Window Attention

In [29]:
class WindowAttention(nn.Module):
    """Multi-head self-attention computed independently for each window.

    Each row along the first axis of the input is treated as its own token
    sequence; tokens only attend to other tokens in the same row. No
    positional bias and no attention mask are applied.

    Args:
        dim: Channel width of the input and output tokens.
        num_heads: Number of attention heads. Must be positive and divide
            ``dim`` evenly, because the channels are split equally across heads.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, num_heads):
        super().__init__()
        # Fail at construction time instead of with an opaque reshape error
        # inside forward().
        if num_heads <= 0 or dim % num_heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive num_heads ({num_heads})"
            )

        self.num_heads = num_heads
        # Standard 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = (dim // num_heads) ** -0.5

        self.qkv = nn.Linear(dim, dim * 3)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention.

        Args:
            x: Tensor of shape (B, N, C) with C equal to ``dim``.

        Returns:
            Tensor of shape (B, N, C).
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # One projection produces q, k and v; split them and move heads forward:
        # (B, N, 3*C) -> (3, B, heads, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4)

        # Scaled dot-product attention over the N tokens of each row.
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)

        # Merge heads back into the channel axis: (B, heads, N, hd) -> (B, N, C).
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

### Swin Transformer Block

In [30]:
class SwinBlock(nn.Module):
    """Pre-norm transformer block with window-local self-attention.

    The block computes, on a channel-last view of the input::

        x = x + Attention(LayerNorm(x))   # attention inside each window
        x = x + MLP(LayerNorm(x))

    Windows are non-overlapping and not shifted.

    Args:
        dim: Channel width of the feature map.
        num_heads: Number of attention heads.
        window_size: Edge length of the cubic attention window. The spatial
            sizes of the input must be multiples of this value, because
            ``window_partition`` splits each axis with integer division.
    """

    def __init__(self, dim, num_heads, window_size):
        super().__init__()
        self.window_size = window_size

        self.norm1 = nn.LayerNorm(dim)
        self.attn = WindowAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)

        # Position-wise feed-forward network with a 4x hidden expansion.
        self.mlp = nn.Sequential(
            nn.Linear(dim, dim * 4),
            nn.GELU(),
            nn.Linear(dim * 4, dim)
        )

    def forward(self, x):
        """Run the block.

        Args:
            x: Channel-first tensor of shape (B, C, D, H, W).

        Returns:
            Tensor of shape (B, C, D, H, W).
        """
        _, _, D, H, W = x.shape
        ws = self.window_size

        # Channel last, because LayerNorm / Linear act on the final axis.
        x = x.permute(0, 2, 3, 4, 1).contiguous()  # [B, D, H, W, C]
        shortcut = x

        # Window partition
        x_windows = window_partition(x, ws)  # [num_windows*B, ws**3, C]

        # Attention inside each window
        x_windows = self.norm1(x_windows)
        x_windows = self.attn(x_windows)

        # Reverse windows and close the first residual connection *before*
        # the MLP, so norm2/mlp operate on the full residual stream.
        x = shortcut + window_reverse(x_windows, ws, D, H, W)  # [B, D, H, W, C]

        # FFN with its own residual connection
        x = x + self.mlp(self.norm2(x))

        # Channel first
        return x.permute(0, 4, 1, 2, 3).contiguous()


### Patch Embedding

In [31]:
class PatchEmbedding(nn.Module):
    """Project a volume onto ``embed_dim`` channels while halving its size.

    A single 3-D convolution with kernel size 2 and stride 2 maps each
    non-overlapping 2x2x2 patch to one output voxel.

    Args:
        in_channels: Number of channels in the input volume.
        embed_dim: Number of channels produced for every patch.
    """

    def __init__(self, in_channels, embed_dim):
        super().__init__()
        self.proj = nn.Conv3d(in_channels, embed_dim, kernel_size=2, stride=2)

    def forward(self, x):
        """Return the patch embedding of ``x`` (channel-first 5-D tensor)."""
        return self.proj(x)

### Encoder Stage

In [32]:
class SwinStage(nn.Module):
    """A stack of ``depth`` SwinBlock layers applied one after another.

    Every block is built with the same ``dim``, ``num_heads`` and
    ``window_size``.

    Args:
        dim: Channel width of the feature map processed by the stage.
        depth: Number of blocks in the stage.
        num_heads: Number of attention heads in every block.
        window_size: Edge length of the cubic attention window.
    """

    def __init__(self, dim, depth, num_heads, window_size):
        super().__init__()
        layers = []
        for _ in range(depth):
            layers.append(SwinBlock(dim, num_heads, window_size))
        self.blocks = nn.ModuleList(layers)

    def forward(self, x):
        """Feed ``x`` through each block in order and return the result."""
        out = x
        for layer in self.blocks:
            out = layer(out)
        return out

### Decoder Stage

In [33]:
class DecoderBlock(nn.Module):
    """Upsample a feature map, fuse it with a skip connection, and refine it.

    Args:
        in_ch: Channel count *after* concatenation, i.e. the channels of the
            upsampled input plus the channels of the skip tensor.
        out_ch: Channel count produced by the block.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        # Two Conv3d -> InstanceNorm3d -> ReLU stages; padding=1 keeps the
        # spatial size unchanged.
        self.conv = nn.Sequential(
            nn.Conv3d(in_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_ch, out_ch, 3, padding=1),
            nn.InstanceNorm3d(out_ch),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, skip):
        """Fuse ``x`` with ``skip``.

        Args:
            x: Coarse feature map, channel-first 5-D tensor.
            skip: Finer feature map from the encoder, channel-first 5-D tensor.

        Returns:
            Tensor with ``out_ch`` channels and the spatial size of ``skip``.
        """
        # Resize to the skip's exact spatial size rather than a fixed factor
        # of 2: identical when the skip is exactly twice as large, and still
        # valid when a strided downsampling floored an odd dimension.
        x = F.interpolate(
            x, size=skip.shape[2:], mode="trilinear", align_corners=False
        )
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)

### SwinUNETR Model

In [34]:
class SwinUNETR(nn.Module):
    """U-shaped segmentation network with a windowed-attention encoder.

    The input is patch-embedded (spatial size halved), passed through four
    encoder stages separated by stride-2 convolutions that double the channel
    width, and decoded by three convolutional blocks that use the outputs of
    encoder stages 3, 2 and 1 as skip connections. A 1x1x1 convolution then
    produces ``out_channels`` logits.

    Because the decoder ends at the resolution of the patch embedding, the
    logits are by default at half the input resolution in every spatial
    axis. Set ``upsample_to_input=True`` to resize them back to the input grid.

    Args:
        in_channels: Number of channels in the input volume.
        out_channels: Number of output channels (e.g. segmentation classes).
        embed_dim: Channel width after patch embedding; stage ``i`` uses
            ``embed_dim * 2**i`` channels.
        depths: Number of blocks in each of the four encoder stages.
        num_heads: Number of attention heads in each of the four stages.
        window_size: Edge length of the cubic attention window. The feature
            map of every stage must have spatial sizes that are multiples of
            this value.
        upsample_to_input: If True, trilinearly resize the logits to the
            spatial size of the input. Defaults to False, which keeps the
            half-resolution output.
    """

    def __init__(
        self,
        in_channels=1,
        out_channels=14,
        embed_dim=48,
        depths=(2,2,2,2),
        num_heads=(3,6,12,24),
        window_size=8,
        upsample_to_input=False,
    ):
        super().__init__()
        self.upsample_to_input = upsample_to_input

        self.patch_embed = PatchEmbedding(in_channels, embed_dim)

        # Encoder stages: channel width doubles at every stage.
        self.enc1 = SwinStage(embed_dim, depths[0], num_heads[0], window_size)
        self.enc2 = SwinStage(embed_dim*2, depths[1], num_heads[1], window_size)
        self.enc3 = SwinStage(embed_dim*4, depths[2], num_heads[2], window_size)
        self.enc4 = SwinStage(embed_dim*8, depths[3], num_heads[3], window_size)

        # Stride-2 convolutions between stages: halve the size, double channels.
        self.down1 = nn.Conv3d(embed_dim, embed_dim*2, 2, 2)
        self.down2 = nn.Conv3d(embed_dim*2, embed_dim*4, 2, 2)
        self.down3 = nn.Conv3d(embed_dim*4, embed_dim*8, 2, 2)

        # Decoder input width = upsampled channels + skip channels.
        self.dec3 = DecoderBlock(embed_dim*8 + embed_dim*4, embed_dim*4)
        self.dec2 = DecoderBlock(embed_dim*4 + embed_dim*2, embed_dim*2)
        self.dec1 = DecoderBlock(embed_dim*2 + embed_dim, embed_dim)

        self.out = nn.Conv3d(embed_dim, out_channels, 1)

    def forward(self, x):
        """Compute logits for a channel-first 5-D input volume.

        Returns:
            Logits with ``out_channels`` channels. Spatial size is that of
            the patch embedding (half the input) unless ``upsample_to_input``
            is set, in which case it equals the input's spatial size.
        """
        input_size = x.shape[2:]

        x1 = self.patch_embed(x)
        x1 = self.enc1(x1)

        x2 = self.enc2(self.down1(x1))
        x3 = self.enc3(self.down2(x2))
        x4 = self.enc4(self.down3(x3))

        x = self.dec3(x4, x3)
        x = self.dec2(x, x2)
        x = self.dec1(x, x1)

        logits = self.out(x)
        if self.upsample_to_input:
            # Bring the prediction back onto the voxel grid of the input.
            logits = F.interpolate(
                logits, size=input_size, mode="trilinear", align_corners=False
            )
        return logits


In [18]:
sample_in.shape

torch.Size([1, 1, 256, 256, 128])

In [36]:
model = SwinUNETR(in_channels=1, out_channels=3, window_size=8)

# Smoke-test forward pass on the sample volume. The output is only inspected,
# so autograd is disabled to avoid keeping every intermediate activation
# (including the per-window attention matrices) alive for a backward pass.
with torch.no_grad():
    y = model(sample_in)

In [39]:
y.shape

torch.Size([1, 3, 128, 128, 64])

In [40]:
y

tensor([[[[[-5.3905e-03, -3.2963e-02, -1.5296e-02,  ..., -1.9657e-02,
             3.4582e-02, -4.1594e-03],
           [ 6.7382e-02, -7.4139e-02,  5.7053e-02,  ..., -8.5702e-02,
            -1.2613e-01,  1.1440e-01],
           [ 1.1550e-01, -7.8226e-02, -1.2644e-02,  ...,  5.4226e-02,
            -9.6189e-02,  7.1722e-02],
           ...,
           [ 1.3955e-01,  6.6642e-02, -1.0558e-01,  ..., -4.4179e-02,
            -5.9765e-02,  6.3264e-02],
           [ 1.0033e-01, -1.0772e-01, -5.0648e-02,  ...,  3.6984e-02,
            -3.7740e-02, -6.5374e-02],
           [-8.1794e-03,  2.6061e-02, -2.7844e-01,  ...,  2.6291e-02,
            -5.4583e-02, -8.2234e-02]],

          [[ 1.0688e-01,  1.7772e-02, -5.4158e-02,  ..., -4.7550e-02,
             5.6923e-02, -3.8186e-02],
           [ 6.1769e-02, -7.7323e-02, -1.8357e-01,  ..., -8.6725e-02,
            -1.6145e-01, -9.5584e-02],
           [ 2.0068e-01, -3.9839e-02, -1.6370e-01,  ...,  1.1261e-01,
            -1.8334e-01, -5.8337e-02],
 