In [2]:
from monai import data, transforms
import glob
import numpy as np
import os
import re
import natsort
import SimpleITK as sitk

def get_loader(
    data_pattern='/workspace/Ablation/ABLATION_PD/RECONSTRUCTION/DATA/*/*.nii.gz',
    batch_size=1,
    num_workers=2,
):
    """Build an unshuffled MONAI DataLoader over the volumes matched by a glob.

    Args:
        data_pattern: glob pattern selecting the image files. Defaults to the
            reconstruction data directory used by this notebook.
        batch_size: number of volumes per batch.
        num_workers: number of DataLoader worker processes.

    Returns:
        A ``(loader, train_real)`` tuple. ``train_real`` is the naturally
        sorted list of matched file paths and ``loader`` iterates over them in
        that order (shuffling is disabled). Indexing ``train_real`` with the
        batch index is therefore only valid while ``batch_size`` is 1.
    """
    train_real = natsort.natsorted(glob.glob(data_pattern))

    print("Train [Total]  number = ", len(train_real))

    # Each dataset item is a 1-tuple holding a single path, exactly as produced
    # by zipping over one list; the transform chain receives that tuple.
    files_tr = [(path,) for path in train_real]

    tr_transforms = transforms.Compose(
        [
            transforms.LoadImage(image_only=True),
            transforms.EnsureChannelFirst(),
            transforms.Orientation(axcodes="LPS"),
            # Linearly rescale intensities from [0, 22] to [0, 1] and clip.
            transforms.ScaleIntensityRange(a_min=0.0, a_max=22.0, b_min=0.0, b_max=1.0, clip=True),
            transforms.EnsureType(),
            transforms.ToTensor(track_meta=False),
        ]
    )

    train_ds = data.Dataset(data=files_tr, transform=tr_transforms)

    train_loader = data.DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=False,  # keep loader order aligned with train_real
        num_workers=num_workers,
        pin_memory=False,
    )

    print("loader is ver (train)")

    return train_loader, train_real

In [7]:
import pdb
import random
from typing import Optional, Sequence, Tuple, Type, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

from monai.networks.blocks import UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer
from swin_transformer_3d import SwinTransformer3D

class SwinTransformerSkipForSimMIM(SwinTransformer3D):
    """Swin Transformer 3D encoder that also returns its intermediate feature maps.

    On every forward pass a random subset of the embedding channels is zeroed
    before the transformer stages run. The ``mask`` argument of ``forward`` is
    not consulted; it is replaced by that randomly generated channel mask.

    Args:
        channel_mask_ratio: fraction of embedding channels zeroed on each
            forward pass (keyword-only, defaults to 0.6).
        **kwargs: forwarded unchanged to ``SwinTransformer3D``.
    """

    def __init__(self, *, channel_mask_ratio: float = 0.6, **kwargs):
        super().__init__(**kwargs)
        self.channel_mask_ratio = channel_mask_ratio

        # Registered as a parameter (and excluded from weight decay below) but
        # not used by forward().
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.mask_token, mean=0.0, std=0.02)

        # Each stage is wrapped in an extra ModuleList, which is why forward()
        # indexes ``layer[0]``. Presumably this matches the parameter names of
        # existing checkpoints -- do not flatten without checking them.
        self.layers1 = nn.ModuleList([self.layers1])
        self.layers2 = nn.ModuleList([self.layers2])
        self.layers3 = nn.ModuleList([self.layers3])
        self.layers4 = nn.ModuleList([self.layers4])

    def forward(self, x, mask, choice):
        """Encode ``x`` and collect the feature map produced at every stage.

        Args:
            x: 5D input tensor; its last three dimensions are read as D, H, W.
            mask: ignored; overwritten by the random channel mask.
            choice: unused; kept for interface compatibility.

        Returns:
            ``(x, x_out)`` where ``x`` is the normalised output of the last
            stage reshaped to 5D and ``x_out`` is a list of five tensors: the
            channel-masked patch embedding followed by the output of each of
            the four stages.
        """
        x_out = []

        _, _, D, H, W = x.size()
        x = self.patch_embed(x)
        x = x.flatten(2).transpose(1, 2)

        x = self.pos_drop(x)
        # NOTE(review): after the transpose above the tensor is laid out as
        # (batch, tokens, channels), yet it is viewed directly as
        # (batch, embed_dim, D', H', W') without transposing back. Confirm this
        # is intended; it is left unchanged so results stay consistent with
        # weights that were presumably trained through this exact code path.
        x = x.view(-1,
                   self.embed_dim,
                   D // self.patch_size[0],
                   H // self.patch_size[1],
                   W // self.patch_size[2])

        # Channel masking: zero a random subset of the embedding channels.
        num_channels = x.shape[1]
        rand_choice = random.sample(range(0, num_channels), int(self.channel_mask_ratio * num_channels))
        # Build the mask on the input's device instead of hard-coding CUDA so
        # the encoder also works on CPU and on non-default GPUs.
        mask = torch.ones(x.shape, device=x.device)
        mask[:, rand_choice, :, :, :] = 0
        x = x * mask
        x_out.append(x)

        for stage in (self.layers1, self.layers2, self.layers3, self.layers4):
            for layer in stage:
                x = layer[0](x)
            x_out.append(x)

        # Total spatial reduction of the encoder; used consistently for both
        # the token count and the final 5D view (32 when patch_size[0] is 2).
        reduction = self.patch_size[0] * 16
        x = x.reshape(-1, (D // reduction) * (H // reduction) * (W // reduction), 2 * self.num_features)
        x = self.norm(x)
        x = x.transpose(1, 2)
        x = x.view(-1, 2 * self.num_features, D // reduction, H // reduction, W // reduction)

        return x, x_out

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the parent's no-weight-decay names plus ``mask_token``."""
        return super().no_weight_decay() | {"mask_token"}

In [8]:
class SimMIMSkip(nn.Module):
    """Encoder plus UNETR-style convolutional decoder trained with an L1 reconstruction loss.

    The supplied ``encoder`` must return ``(z, hidden_states)`` where
    ``hidden_states`` holds five feature maps. ``feature_size`` has to match
    the channel count of the first hidden state (``embed_dim`` for
    ``SwinTransformerSkipForSimMIM``), because the skip blocks are sized as
    1x, 2x, 4x and 16x ``feature_size``.

    Args:
        in_channels: channels of the input image.
        out_channels: channels of the reconstructed image.
        encoder: backbone exposing ``in_chans`` and ``patch_size`` and callable
            as ``encoder(x, mask, choice)``.
        encoder_stride: stored on the module; not used by ``forward``.
        img_size: accepted for interface compatibility; unused here.
        depths, num_heads, drop_rate, attn_drop_rate, dropout_path_rate,
            normalize, use_checkpoint: accepted for interface compatibility;
            unused here.
        feature_size: base channel width of the decoder.
        norm_name: normalisation passed to every UNETR block.
        spatial_dims: number of spatial dimensions of the blocks.
        decoder, loss, inf, temperature: stored on the module; not used by
            ``forward``.
        choice: stored and forwarded to the encoder on every call.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        encoder,
        encoder_stride,
        img_size: Union[Sequence[int], int],
        depths: Sequence[int] = (2, 2, 2, 2),
        num_heads: Sequence[int] = (3, 6, 12, 24),
        feature_size: int = 24,
        norm_name: Union[Tuple, str] = "instance",
        drop_rate: float = 0.0,
        attn_drop_rate: float = 0.0,
        dropout_path_rate: float = 0.0,
        normalize: bool = True,
        use_checkpoint: bool = False,
        spatial_dims: int = 3,
        decoder="deconv",
        loss="mask_only",
        choice="mae",
        inf="notsim",
        temperature=0.07,
    ):
        super().__init__()
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.decoder = decoder
        self.loss = loss

        self.in_chans = self.encoder.in_chans
        self.patch_size = self.encoder.patch_size

        def skip_block(channels_in, channels_out):
            # Stride-1 residual block that refines one skip connection.
            return UnetrBasicBlock(
                spatial_dims=spatial_dims,
                in_channels=channels_in,
                out_channels=channels_out,
                kernel_size=3,
                stride=1,
                norm_name=norm_name,
                res_block=True,
            )

        def up_block(channels_in, channels_out):
            # Residual up-sampling block (factor 2) that merges a skip input.
            return UnetrUpBlock(
                spatial_dims=spatial_dims,
                in_channels=channels_in,
                out_channels=channels_out,
                kernel_size=3,
                upsample_kernel_size=2,
                norm_name=norm_name,
                res_block=True,
            )

        # Sub-modules are created in the same order as before so parameter
        # registration order and default initialisation order are unchanged.
        self.encoder1 = skip_block(in_channels, feature_size)
        self.encoder2 = skip_block(feature_size, feature_size)
        self.encoder3 = skip_block(2 * feature_size, 2 * feature_size)
        self.encoder4 = skip_block(4 * feature_size, 4 * feature_size)
        self.encoder10 = skip_block(16 * feature_size, 16 * feature_size)

        self.decoder5 = up_block(16 * feature_size, 8 * feature_size)
        self.decoder4 = up_block(feature_size * 8, feature_size * 4)
        self.decoder3 = up_block(feature_size * 4, feature_size * 2)
        self.decoder2 = up_block(feature_size * 2, feature_size)
        self.decoder1 = up_block(feature_size, feature_size)

        self.choice = choice
        self.inf = inf
        self.temp = temperature
        self.out = UnetOutBlock(
            spatial_dims=spatial_dims, in_channels=feature_size, out_channels=out_channels
        )  # type: ignore

    def forward(self, x, mask, x_org):
        """Reconstruct ``x_org`` from the (degraded) input ``x``.

        Args:
            x: input fed to the encoder and to the full-resolution skip block.
            mask: passed through to the encoder and returned unchanged.
            x_org: reconstruction target for the L1 loss.

        Returns:
            ``(loss_recon, x_rec, mask)``: mean L1 loss between the
            reconstruction and ``x_org``, the reconstruction itself, and the
            ``mask`` argument as received.
        """
        # Only the per-stage feature maps are consumed; the encoder's pooled
        # output is not needed by the decoder.
        _, hidden_states_out = self.encoder(x, mask, self.choice)

        enc0 = self.encoder1(x)
        enc1 = self.encoder2(hidden_states_out[0])
        enc2 = self.encoder3(hidden_states_out[1])
        enc3 = self.encoder4(hidden_states_out[2])
        dec4 = self.encoder10(hidden_states_out[4])
        # hidden_states_out[3] is used directly as a skip, without a refinement block.
        dec3 = self.decoder5(dec4, hidden_states_out[3])
        dec2 = self.decoder4(dec3, enc3)
        dec1 = self.decoder3(dec2, enc2)
        dec0 = self.decoder2(dec1, enc1)
        out = self.decoder1(dec0, enc0)
        x_rec = self.out(out)

        # (input, target) order as documented for l1_loss; the value is the
        # same as with the arguments swapped because |a - b| is symmetric.
        loss_recon = F.l1_loss(x_rec, x_org, reduction="mean")

        return loss_recon, x_rec, mask

    @torch.jit.ignore
    def no_weight_decay(self):
        """Return the encoder's no-weight-decay names prefixed with ``encoder.``.

        Returns an empty set when the encoder does not define the hook, so the
        return type is a set on both paths.
        """
        if hasattr(self.encoder, "no_weight_decay"):
            return {"encoder." + i for i in self.encoder.no_weight_decay()}
        return set()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Return the encoder's no-weight-decay keywords prefixed with ``encoder.``.

        Returns an empty set when the encoder does not define the hook.
        """
        if hasattr(self.encoder, "no_weight_decay_keywords"):
            return {"encoder." + i for i in self.encoder.no_weight_decay_keywords()}
        return set()

class UnetUpBlock(nn.Module):
    """
    Skip-free upsampling stage: a transposed convolution followed by a basic
    convolution block. The layout follows the DynUNet up block described in:
    `Automated Design of Deep Learning Methods for Biomedical Image Segmentation <https://arxiv.org/abs/1904.08128>`_.
    `nnU-Net: Self-adapting Framework for U-Net-Based Medical Image Segmentation <https://arxiv.org/abs/1809.10486>`_.

    Unlike a regular U-Net up block, ``forward`` takes a single tensor; no skip
    feature map is concatenated.

    Args:
        spatial_dims: number of spatial dimensions.
        in_channels: channels entering the transposed convolution.
        out_channels: channels produced by both layers.
        kernel_size: kernel size of the convolution block.
        stride: accepted for interface compatibility; the convolution block
            always runs with stride 1.
        upsample_kernel_size: kernel size of the transposed convolution, also
            used as its stride.
        norm_name: feature normalization type and arguments.
        act_name: activation layer type and arguments.
        dropout: dropout probability.
        trans_bias: whether the transposed convolution has a bias term.
    """

    def __init__(
        self,
        spatial_dims: int,
        in_channels: int,
        out_channels: int,
        kernel_size: Union[Sequence[int], int],
        stride: Union[Sequence[int], int],
        upsample_kernel_size: Union[Sequence[int], int],
        norm_name: Union[Tuple, str],
        act_name: Union[Tuple, str] = ("leakyrelu", {"inplace": True, "negative_slope": 0.01}),
        dropout: Optional[Union[Tuple, str, float]] = None,
        trans_bias: bool = False,
    ):
        super().__init__()

        # Up-sampling layer: stride equals the kernel size.
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_kernel_size,
            dropout=dropout,
            bias=trans_bias,
            act=None,
            norm=None,
            conv_only=False,
            is_transposed=True,
        )

        # Refinement layer operating at the up-sampled resolution.
        self.conv_block = UnetBasicBlock(
            spatial_dims,
            out_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=1,
            dropout=dropout,
            norm_name=norm_name,
            act_name=act_name,
        )

    def forward(self, inp):
        """Up-sample ``inp`` and refine it with the convolution block."""
        upsampled = self.transp_conv(inp)
        return self.conv_block(upsampled)


In [9]:
# Checkpoint whose 'model' entry holds the state dict loaded at the end of this cell.
ckpt_path = '/workspace/Ablation/ABLATION_PD/FINE_TUNING/WEIGHTS/DAE_1900.pt'
print(f"ckpt_path : {ckpt_path}")
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. Switching to weights_only=True is worth trying once it is
# confirmed that the file contains nothing but tensors and plain containers.
ckpt = torch.load(ckpt_path, map_location='cpu')

# Swin encoder; embed_dim must equal the decoder's feature_size below.
encoder = SwinTransformerSkipForSimMIM(
    # input / patching
    img_size=96,
    patch_size=(2, 2, 2),
    in_chans=1,
    num_classes=0,
    patch_norm=True,
    # transformer layout
    embed_dim=48,
    depths=[2, 2, 2, 2],
    num_heads=[3, 6, 12, 24],
    window_size=(7, 7, 7),
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    # regularisation / memory
    drop_rate=0.0,
    drop_path_rate=0.0,
    use_checkpoint=False,
)
encoder_stride = 32

model = SimMIMSkip(
    # backbone
    encoder=encoder,
    encoder_stride=encoder_stride,
    # image / channel layout
    img_size=(192, 192, 96),
    in_channels=1,
    out_channels=1,
    feature_size=48,
    # regularisation / memory
    drop_rate=0.0,
    attn_drop_rate=0.0,
    dropout_path_rate=0.0,
    use_checkpoint=False,
    # training-mode switches stored on the module
    loss="all_img",
    choice="all",
    temperature=0.07,
)
# Last expression of the cell: its return value reports the key-matching result.
model.load_state_dict(ckpt['model'])

ckpt_path : /workspace/Ablation/ABLATION_PD/FINE_TUNING/WEIGHTS/DAE_1900.pt


  ckpt = torch.load(ckpt_path, map_location='cpu')
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


<All keys matched successfully>

In [11]:
import SimpleITK as sitk
def save_img(img, save_path):
    """Write a tensor to ``save_path`` as an image via SimpleITK.

    The tensor is detached, moved to the CPU, its last four axes are reversed
    (axis order 0, 4, 3, 2, 1), all singleton dimensions are squeezed away and
    the result is cast to float32 before being handed to SimpleITK. No
    spacing, origin or direction information is set on the written image.

    Args:
        img: 5D torch tensor to export.
        save_path: destination file path; missing parent directories are
            created so a fresh output folder does not make the write fail.
    """
    out_dir = os.path.dirname(save_path)
    if out_dir:  # a bare file name has no directory component to create
        os.makedirs(out_dir, exist_ok=True)

    # detach() first so no autograd bookkeeping is carried to the CPU copy.
    volume = img.detach().cpu().numpy()
    volume = volume.transpose(0, 4, 3, 2, 1).squeeze().astype(np.float32)

    sitk.WriteImage(sitk.GetImageFromArray(volume), save_path)


In [13]:
import torch
import torch.nn.functional as F

# Degradation applied to every volume before it is fed to the model.
NOISE_VARIANCE = 0.1    # variance of the additive Gaussian noise
DOWNSAMPLE_FACTOR = 4   # per-axis factor of the down/up-sampling round trip

data_loader, train_real = get_loader()
model.to('cuda')
# Inference should run in eval mode. Note that the encoder's random channel
# masking is applied in forward() regardless of the module's mode.
model.eval()
with torch.no_grad():
    for idx, img in enumerate(data_loader):
        img = img.cuda(non_blocking=True)
        B, C, H, W, Z = img.shape

        # Sample the noise directly on the image's device and dtype instead of
        # creating it on the CPU and copying it over.
        noise = (NOISE_VARIANCE ** 0.5) * torch.randn_like(img)
        img_noisy = img + noise

        # Sizes are derived from the batch itself rather than hard-coded, so
        # the resampled input always matches the reconstruction target.
        low_size = (H // DOWNSAMPLE_FACTOR, W // DOWNSAMPLE_FACTOR, Z // DOWNSAMPLE_FACTOR)
        img_lowres = F.interpolate(img_noisy, size=low_size)
        img_resam = F.interpolate(img_lowres, size=(H, W, Z))

        _, x_rec, _ = model(img_resam, None, img)

        # idx maps onto train_real because the loader yields one unshuffled
        # volume per batch.
        save_img(x_rec, train_real[idx].replace("DATA", "8_DisAE"))


Train [Total]  number =  30
loader is ver (train)


In [50]:
# Export the intermediate tensors for inspection; presumably these are the
# values left behind by the final iteration of the inference loop.
for _volume, _target in (
    (img_resam, "/workspace/Ablation/ABLATION_PD/RECONSTRUCTION/resamp.nii.gz"),
    (img, "/workspace/Ablation/ABLATION_PD/RECONSTRUCTION/img.nii.gz"),
    (img_noisy, "/workspace/Ablation/ABLATION_PD/RECONSTRUCTION/noisy.nii.gz"),
):
    save_img(_volume, _target)


In [51]:
# Export the current value of x_rec (presumably the reconstruction from the
# last processed volume) for inspection.
save_img(x_rec, "/workspace/Ablation/ABLATION_PD/RECONSTRUCTION/x_rec.nii.gz")


In [52]:
# NOTE(review): this writes img to the same img.nii.gz path as an earlier cell
# of this notebook; the call looks redundant -- confirm before removing.
save_img(img, "/workspace/Ablation/ABLATION_PD/RECONSTRUCTION/img.nii.gz")
