In [1]:
!ls

Logs  Users  projects


In [2]:
%cd Users/jose.d.berlin/3D-MRI-Segmentation/

/mnt/batch/tasks/shared/LS_root/mounts/clusters/e4ds-v4/code/Users/jose.d.berlin/3D-MRI-Segmentation


In [3]:
ls

[0m[34;42mNotebooks[0m/  [01;32mREADME.md[0m*  [34;42mdata[0m/  [34;42mmodels[0m/  [34;42msrc[0m/


In [4]:
# Importing necessary libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import nibabel as nib

In [35]:
sample = nib.load("data/mpr-1.nifti.img")
sample_in = torch.tensor(sample.get_fdata()).float().squeeze(3).unsqueeze(0).unsqueeze(0)
print(sample_in.shape)

torch.Size([1, 1, 256, 256, 128])


### Patch Embed

In [36]:
class PatchEmbed(nn.Module):
    def __init__(self, img_size, patch_size, in_chans, embed_dim):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.D, self.H, self.W = img_size
        self.pD, self.pH, self.pW = patch_size

        assert self.D % self.pD == 0 and self.H % self.pH == 0 and self.W % self.pW == 0, "Image dimensions must be divisible by the patch size."

        self.num_patches = (self.D // self.pD) * (self.H // self.pH) * (self.W // self.pW)
        self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)
    
    def forward(self, x):
        print("input:",x.shape)
        x = self.proj(x)  # B, embed_dim, D//pD, H//pH, W//pW
        B, C, Dp, Hp, Wp = x.shape
        print("bef:",x.shape)

        x = x.permute(0, 2, 3, 4, 1).contiguous()
        print("aft:",x.shape)
        x = x.view(B, -1, C)  # B, N, C
        print("view:",x.shape)
        x = self.norm(x)
        print("norm:",x.shape)
        return x

In [37]:
class SwinUNETR3D(nn.Module):
    def __init__(
        self, 
        in_ch=1, 
        out_ch=1, 
        img_size=(128, 128, 128), 
        patch_size=(4, 4, 4), 
        window_size=(7, 7, 7), 
        embed_dim=96, 
        depths=(2, 2, 2, 2), 
        num_heads=(2, 4, 8, 16),
        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.0
        ):
        super().__init__()

        self.patch_embed = PatchEmbed(
            img_size=img_size,
            patch_size=patch_size,
            in_chans=in_ch,
            embed_dim=embed_dim
        )
    
    def forward(self, x):
        B = x.shape[0] # Batch size

        x_tokens = self.patch_embed(x)
        return x_tokens

In [41]:
swin = SwinUNETR3D(in_ch=1, out_ch=1, img_size=sample_in.shape[2:], embed_dim=96)

In [42]:
swin(sample_in)

input: torch.Size([1, 1, 256, 256, 128])
bef: torch.Size([1, 96, 64, 64, 32])
aft: torch.Size([1, 64, 64, 32, 96])
view: torch.Size([1, 131072, 96])
norm: torch.Size([1, 131072, 96])


tensor([[[-1.5086,  0.5042, -0.4807,  ...,  0.5815, -0.0622,  0.1525],
         [-1.4675,  0.4533,  0.6392,  ...,  0.4482, -0.0679,  0.0435],
         [-1.4159,  0.8678, -0.7656,  ...,  0.2521,  0.7943,  0.7551],
         ...,
         [-0.9434,  1.2161,  0.2173,  ...,  0.0936, -0.1224,  0.6493],
         [-1.6088,  0.5546,  0.6410,  ..., -0.0180, -0.1719,  0.7827],
         [-0.4031, -0.0071, -0.5993,  ...,  0.5524, -0.3631,  1.0740]]],
       grad_fn=<NativeLayerNormBackward0>)