In [1]:
import json
import os
import pdb

import numpy as np
import torch.distributed as dist
from torch.utils.data import DataLoader, DistributedSampler
from torch.utils.data._utils.collate import default_collate

from monai.data import (
    Dataset,
    Dataset,
)
import math

from monai.transforms import (
    EnsureChannelFirstd,
    RandSpatialCropd,
    Compose,
    LoadImaged,
    Orientationd,
    ScaleIntensityRanged,
    ToTensord,
    RandAffined,
    RandZoomd,
    RandRotated
)


def datafold_read(datalist, basedir, fold=0, key="training"):
    """Load a datalist JSON file and split the entries under ``key`` into train/validation.

    Every non-empty string value, and every element of a list value, in each entry
    is prefixed with ``basedir`` via ``os.path.join``; other values are left as-is.
    An entry goes to validation when it has a ``"fold"`` key equal to ``fold``;
    everything else goes to training.

    Args:
        datalist: path to the JSON datalist.
        basedir: directory prepended to the path-like values.
        fold: fold id selecting the validation entries.
        key: top-level JSON key holding the list of entry dicts.

    Returns:
        ``(train_entries, val_entries)``: two lists of the (rewritten) entry dicts,
        each in file order.
    """
    with open(datalist) as handle:
        entries = json.load(handle)[key]

    for entry in entries:
        # Only values of existing keys are replaced, so iterating items() is safe.
        for name, value in entry.items():
            if isinstance(value, list):
                entry[name] = [os.path.join(basedir, item) for item in value]
            elif isinstance(value, str) and value:
                entry[name] = os.path.join(basedir, value)

    train_entries, val_entries = [], []
    for entry in entries:
        in_val_fold = "fold" in entry and entry["fold"] == fold
        (val_entries if in_val_fold else train_entries).append(entry)
    return train_entries, val_entries


class MaskGenerator:
    """Random block-mask generator for SimMIM-style masked image modelling.

    The volume is covered by a grid of ``rand_size x rand_size x rand_size // 2``
    mask patches (the third axis carries half as many patches as the first two).
    A ``mask_ratio`` fraction of them is chosen at random, and the result is
    upsampled so that there is one entry per model patch token.

    Args:
        input_size: edge length of the first two axes of the input volume, in voxels.
        mask_patch_size: edge length of one mask patch, in voxels.
        model_patch_size: model patch size; only the first entry is used.
        mask_ratio: fraction of mask patches set to 1 (rounded up).
    """

    def __init__(self, input_size=192, mask_patch_size=32, model_patch_size=(2, 2, 2), mask_ratio=0.6):
        self.input_size = input_size
        self.mask_patch_size = mask_patch_size
        # Only the first entry is used, i.e. patches are treated as cubic.
        self.model_patch_size = model_patch_size[0]
        self.mask_ratio = mask_ratio
        assert self.input_size % self.mask_patch_size == 0
        assert self.mask_patch_size % self.model_patch_size == 0
        # Mask-patch grid size along the first two axes (6 for the defaults).
        self.rand_size = self.input_size // self.mask_patch_size
        # The third axis has half as many mask patches; with an odd grid size the
        # reshape in __call__ cannot succeed, so fail fast here with a clear message.
        assert self.rand_size % 2 == 0, "input_size // mask_patch_size must be even (third axis is halved)"
        self.rand_depth = self.rand_size // 2
        # Model tokens per mask patch along each axis (32 // 2 = 16 for the defaults).
        self.scale = self.mask_patch_size // self.model_patch_size
        # Total number of mask patches (6 * 6 * 3 = 108 for the defaults).
        self.token_count = self.rand_size * self.rand_size * self.rand_depth
        self.mask_count = int(np.ceil(self.token_count * self.mask_ratio))

    def __call__(self):
        """Draw a fresh random mask from NumPy's global RNG.

        Returns:
            int ndarray of shape
            ``(rand_size * scale, rand_size * scale, rand_depth * scale)``
            ((96, 96, 48) for the defaults), 1 where masked and 0 elsewhere.
        """
        mask_idx = np.random.permutation(self.token_count)[: self.mask_count]
        mask = np.zeros(self.token_count, dtype=int)
        mask[mask_idx] = 1
        mask = mask.reshape((self.rand_size, self.rand_size, self.rand_depth))
        # Nearest-neighbour upsample every mask patch to model-token resolution.
        mask = mask.repeat(self.scale, axis=0).repeat(self.scale, axis=1).repeat(self.scale, axis=2)
        return mask


class Transform:
    """Per-sample transform used by the SimMIM loader.

    Loads the ``"image"`` entry of a data dict, moves channels first, reorients to
    ``"LPS"``, linearly rescales intensities from [0, 22] to [0, 1] with clipping,
    and converts the result with ``ToTensord(track_meta=False)``. Every call also
    draws a new random mask from a default ``MaskGenerator``.
    """

    def __init__(self):
        pet_steps = [
            LoadImaged(keys=["image"], image_only=True),
            EnsureChannelFirstd(keys=["image"]),
            Orientationd(keys=["image"], axcodes="LPS"),
            ScaleIntensityRanged(keys=["image"], a_min=0.0, a_max=22.0, b_min=0.0, b_max=1.0, clip=True),
            ToTensord(keys=["image"], track_meta=False),
        ]
        self.transform_pet = Compose(pet_steps)
        self.mask_generator = MaskGenerator()

    def __call__(self, img):
        """Transform the data dict ``img`` and pair it with a freshly drawn mask.

        Returns:
            ``(transformed, mask)``: the output of the MONAI pipeline and the array
            returned by ``self.mask_generator()``.
        """
        transformed = self.transform_pet(img)
        return transformed, self.mask_generator()

def collate_fn(batch):
    """Collate a list of ``(data, mask)`` samples.

    If the first sample's ``data`` is not a tuple, the whole batch is handed to
    ``default_collate``. Otherwise ``data`` is treated as a tuple of views. Each
    view position is collated across the batch, or kept as ``None`` when the first
    sample holds ``None`` there. The collated masks are appended last.

    Returns:
        ``default_collate(batch)`` in the first case; otherwise the list
        ``[view_0, ..., view_k, masks]``.
    """
    first_data = batch[0][0]
    if not isinstance(first_data, tuple):
        return default_collate(batch)

    collated = [
        None if first_data[pos] is None else default_collate([sample[0][pos] for sample in batch])
        for pos in range(len(first_data))
    ]
    collated.append(default_collate([sample[1] for sample in batch]))
    return collated

import natsort
import glob

def build_loader_simmim(
    data_glob="/workspace/Ablation/ABLATION_PD/RECONSTRUCTION/DATA/*/*.nii.gz",
    batch_size=1,
    num_workers=1,
):
    """Build an unshuffled DataLoader over every volume matched by ``data_glob``.

    Paths are natural-sorted and the loader uses ``shuffle=False``, so batch ``i``
    (with ``batch_size=1``) comes from the ``i``-th sorted path. Code that pairs
    outputs with file paths by index relies on this ordering.

    Args:
        data_glob: glob pattern selecting the input volumes.
        batch_size: samples per batch; ``drop_last=True`` drops a trailing partial batch.
        num_workers: number of DataLoader worker processes.

    Returns:
        A ``torch.utils.data.DataLoader`` whose samples come from ``Transform``
        (``(data_dict, mask)``) and are batched by ``collate_fn``.
    """
    train_real = natsort.natsorted(glob.glob(data_glob))

    print("Train [Total]  number = ", len(train_real))

    # One data dict per volume, with "image" mapped to the path string itself.
    files_tr = [{"image": path} for path in train_real]

    dataset_train = Dataset(data=files_tr, transform=Transform())

    return DataLoader(
        dataset_train,
        batch_size=batch_size,
        num_workers=num_workers,
        shuffle=False,  # keep file order; outputs are matched to paths by index
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )


In [2]:
import natsort
import glob
# Destination path for every input volume: the same path with "DATA" replaced by
# "7_SimMIM". This uses the same glob + natsort as build_loader_simmim(), whose
# loader is unshuffled, so path_list[i] is where the i-th reconstruction is saved.
ori_list = natsort.natsorted(glob.glob('/workspace/Ablation/ABLATION_PD/RECONSTRUCTION/DATA/*/*.nii.gz'))
# NOTE(review): str.replace rewrites every occurrence of "DATA" in the path,
# including any inside a file name; confirm no input file name contains it.
path_list = [src.replace("DATA", "7_SimMIM") for src in ori_list]

In [3]:
import pdb

import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.layers import trunc_normal_

from monai.networks.blocks.dynunet_block import UnetBasicBlock, UnetResBlock, get_conv_layer

from swin_transformer_3d import SwinTransformer3D


class PixelShuffle3D(nn.Module):
    """Volumetric analogue of ``nn.PixelShuffle``.

    Rearranges a ``(B, C * r**3, D, H, W)`` tensor into ``(B, C, D * r, H * r, W * r)``,
    with ``r = upscale_factor``. Input channel ``c * r**3 + i * r**2 + j * r + k``
    supplies output voxel offset ``(i, j, k)`` of channel ``c``. The channel count
    must be divisible by ``r**3``; otherwise the internal ``view`` raises.

    Adapted from https://github.com/assassint2017/PixelShuffle3D/blob/master/PixelShuffle3D.py
    """

    def __init__(self, upscale_factor):
        super().__init__()
        self.upscale_factor = upscale_factor

    def forward(self, inputs):
        r = self.upscale_factor
        n, c_in, d, h, w = inputs.size()
        c_out = c_in // r**3
        # Split the channel axis into (c_out, r, r, r) ...
        split = inputs.contiguous().view(n, c_out, r, r, r, d, h, w)
        # ... and interleave each r-factor right after its spatial axis.
        interleaved = split.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        return interleaved.view(n, c_out, d * r, h * r, w * r)


class SwinTransformerForSimMIM(SwinTransformer3D):
    """``SwinTransformer3D`` encoder with SimMIM token masking.

    Patch tokens selected by ``mask`` are replaced by a learned ``mask_token``
    before the transformer stages run. All constructor keyword arguments are
    forwarded unchanged to ``SwinTransformer3D``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # Learned embedding substituted for every masked patch token.
        self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        trunc_normal_(self.mask_token, mean=0.0, std=0.02)
        # Indexable list of the four stages used by forward(). The same modules
        # become reachable under both ``layersK`` and ``layers.N``, so state_dicts
        # carry both key sets. Checkpoints saved from this class presumably depend
        # on that, so keep this line.
        self.layers = nn.ModuleList([self.layers1, self.layers2, self.layers3, self.layers4])

    def forward(self, x, mask):
        """Encode ``x`` with masked patch tokens replaced by ``mask_token``.

        Args:
            x: input volume, unpacked as ``(B, C, D, H, W)``.
            mask: token mask; ``mask.flatten(1)`` must provide one value per patch
                token (1 = masked). Presumably ``(B, D/p0, H/p1, W/p2)``; confirm
                against the mask producer.

        Returns:
            ``(features, flattened_patched_x, mask_tokens, x_input, w)``:
            - ``features``: ``(B, 2 * num_features, D // r, H // r, W // r)`` with
              ``r = patch_size[0] * 16``.
            - ``flattened_patched_x``: the unmasked patch tokens, ``(B, L, embed_dim)``.
            - ``mask_tokens``: ``mask_token`` expanded to ``(B, L, embed_dim)``.
            - ``x_input``: the 5-D tensor fed to the first stage.
            - ``w``: the ``(B, L, 1)`` float mask weights.
        """
        _, _, D, H, W = x.size()
        patched_x = self.patch_embed(x)
        # (B, C, D', H', W') -> (B, L, C) token sequence, L = D' * H' * W'.
        flattened_patched_x = patched_x.flatten(2).transpose(1, 2)
        assert mask is not None
        B, L, _ = flattened_patched_x.shape
        mask_tokens = self.mask_token.expand(B, L, -1)
        # w is 1.0 at masked tokens and 0.0 elsewhere: swap in the mask token there.
        w = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
        x = flattened_patched_x * (1.0 - w) + mask_tokens * w
        x = self.pos_drop(x)
        # NOTE(review): ``view`` reinterprets the row-major (B, L, C) token buffer as
        # (B, C, D', H', W'). It is NOT a transpose, so channel and spatial positions
        # get interleaved. The layout-correct form is
        # ``x.transpose(1, 2).reshape(B, C, D', H', W')``, but the checkpoint loaded
        # in this notebook was presumably trained with the current layout, so
        # changing it requires retraining.
        x = x.view(-1, self.embed_dim, D // self.patch_size[0], H // self.patch_size[1], W // self.patch_size[2])
        x_input = x
        for layer in self.layers:
            # Applies the first sub-module of each stage (presumably each stage
            # container holds a single stage module; confirm in SwinTransformer3D).
            x = layer[0](x)
        # Total downsampling: patch embedding times 16. Only patch_size[0] is used,
        # so anisotropic patch sizes are not supported here.
        reduction = self.patch_size[0] * 16
        # NOTE(review): like the view above, this reshape reinterprets memory
        # rather than moving channels last. It is kept as-is for checkpoint compatibility.
        x = x.reshape(-1, (D // reduction) * (H // reduction) * (W // reduction), 2 * self.num_features)
        x = self.norm(x)
        x = x.transpose(1, 2)
        # Use the same reduction as the reshape above so the grid sizes agree for
        # any patch size.
        x = x.view(-1, 2 * self.num_features, D // reduction, H // reduction, W // reduction)
        return x, flattened_patched_x, mask_tokens, x_input, w

    @torch.jit.ignore
    def no_weight_decay(self):
        """Parameter names exempt from weight decay: the base set plus ``mask_token``."""
        return super().no_weight_decay() | {"mask_token"}

class SimMIM(nn.Module):
    """SimMIM model: masked encoder + 1x1x1 conv + 3-D pixel-shuffle decoder.

    Args:
        encoder: masked encoder. It must expose ``num_features``, ``in_chans`` and
            ``patch_size``. ``encoder(x, mask)`` must return a 5-tuple whose first
            item is a feature map with ``2 * num_features`` channels.
        encoder_stride: upsampling factor per axis applied by the decoder. It
            should equal the encoder's total downsampling so ``x_rec`` matches ``x``.
        decoder: stored only; not read by this class.
        loss: stored only; not read by this class.
    """

    def __init__(self, encoder, encoder_stride, decoder="pixel_shuffle", loss="mask_only"):
        super().__init__()
        self.encoder = encoder
        self.encoder_stride = encoder_stride
        self.decoder = decoder
        self.loss = loss

        # encoder_stride**3 channels per input channel; the pixel shuffle folds
        # them into the upsampled voxels, so x_rec has encoder.in_chans channels.
        self.conv1 = nn.Conv3d(
            in_channels=2 * self.encoder.num_features,
            out_channels=self.encoder_stride**3 * self.encoder.in_chans,
            kernel_size=1,
        )
        self.pixel_shuffle = PixelShuffle3D(self.encoder_stride)

        self.in_chans = self.encoder.in_chans
        self.patch_size = self.encoder.patch_size

    def forward(self, x, mask):
        """Reconstruct ``x`` and compute the masked L1 loss.

        Args:
            x: input volume ``(B, C, D, H, W)``.
            mask: token-level mask (1 = masked). It is expanded by ``patch_size``
                along dims 1-3 and must then broadcast against ``x`` after a
                channel axis is added, i.e. presumably ``(B, D/p0, H/p1, W/p2)``.

        Returns:
            ``(loss, flattened_patched_x, mask_tokens, x_input, mask_flatten, x_rec)``:
            the scalar masked-L1 loss, four encoder outputs passed through
            unchanged, and the reconstruction ``x_rec``.
        """
        z, flattened_patched_x, mask_tokens, x_input, mask_flatten = self.encoder(x, mask)
        x_rec = self.pixel_shuffle(self.conv1(z))

        # Upsample the token mask to voxel resolution and add a channel axis so it
        # broadcasts against the per-voxel loss.
        mask = (
            mask.repeat_interleave(self.patch_size[0], 1)
            .repeat_interleave(self.patch_size[1], 2)
            .repeat_interleave(self.patch_size[2], 3)
            .unsqueeze(1)
            .contiguous()
        )
        loss_recon = F.l1_loss(x, x_rec, reduction="none")
        # Mean L1 over masked voxels, additionally divided by the channel count.
        loss = (loss_recon * mask).sum() / (mask.sum() + 1e-5) / self.in_chans

        return loss, flattened_patched_x, mask_tokens, x_input, mask_flatten, x_rec

    @torch.jit.ignore
    def no_weight_decay(self):
        """Encoder no-weight-decay names prefixed with ``encoder.``; empty set if unsupported."""
        if hasattr(self.encoder, "no_weight_decay"):
            return {"encoder." + i for i in self.encoder.no_weight_decay()}
        return set()

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        """Encoder no-weight-decay keywords prefixed with ``encoder.``; empty set if unsupported."""
        if hasattr(self.encoder, "no_weight_decay_keywords"):
            return {"encoder." + i for i in self.encoder.no_weight_decay_keywords()}
        return set()




In [4]:
# Pretrained SimMIM checkpoint used for the reconstructions below.
ckpt_path = '/workspace/Ablation/ABLATION_PD/FINE_TUNING/WEIGHTS/SIMMIM_1900.pt'
print(f"ckpt_path : {ckpt_path}")

# Encoder hyper-parameters; these must match the configuration the checkpoint was
# trained with.
encoder_kwargs = dict(
    num_classes=1,  # change this when fine-tuning
    img_size=192,
    patch_size=(2, 2, 2),
    in_chans=1,
    embed_dim=48,
    depths=[2, 2, 2, 2],
    num_heads=[3, 6, 12, 24],
    window_size=(7, 7, 7),
    mlp_ratio=4.0,
    qkv_bias=True,
    qk_scale=None,
    drop_rate=0.0,
    drop_path_rate=0.1,
    patch_norm=True,
)
encoder = SwinTransformerForSimMIM(**encoder_kwargs)
# Decoder upsampling factor: equals the encoder's total downsampling
# (patch_size[0] * 16 = 32), so the reconstruction has the input's size.
encoder_stride = 32
model = SimMIM(encoder=encoder, encoder_stride=encoder_stride)

# NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location='cpu')
model.load_state_dict(ckpt['model'])


ckpt_path : /workspace/Ablation/ABLATION_PD/FINE_TUNING/WEIGHTS/SIMMIM_1900.pt


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
  ckpt = torch.load(ckpt_path, map_location='cpu')


<All keys matched successfully>

In [5]:
import SimpleITK as sitk
def save_img(img, save_path):
    """Write a volume tensor to ``save_path`` with SimpleITK.

    The tensor is moved to CPU and detached. Axes are reordered with
    ``transpose(0, 4, 3, 2, 1)``, which reverses the three spatial axes because
    ``sitk.GetImageFromArray`` reads arrays as (z, y, x). Size-1 axes are squeezed
    away and the data is cast to float32. The parent directory of ``save_path`` is
    created if it does not exist yet.

    NOTE(review): no spacing/origin/direction is set, so the written file carries
    SimpleITK defaults rather than the source volume's geometry.

    Args:
        img: 5-D tensor (required by the transpose), presumably ``(1, 1, X, Y, Z)``.
        save_path: output file path.
    """
    volume = img.cpu().detach().numpy().transpose(0, 4, 3, 2, 1).squeeze().astype(np.float32)
    out_dir = os.path.dirname(save_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)
    save_pred = sitk.GetImageFromArray(volume)
    sitk.WriteImage(save_pred, save_path)


In [6]:
# Reconstruct every volume and save it to the matching entry of path_list.
# NOTE: masks come from NumPy's global RNG inside the loader's Transform, so the
# reconstructions differ between runs unless the RNG is seeded.
loader = build_loader_simmim()
# Outputs are paired with path_list by position (batch_size=1, unshuffled), so
# fail before writing anything if the two disagree.
assert len(loader) == len(path_list), (
    f"loader yields {len(loader)} batches but path_list has {len(path_list)} entries"
)
model.to("cuda")
model.eval()
with torch.no_grad():
    for idx, (img, mask) in enumerate(loader):
        img = img["image"].cuda(non_blocking=True)
        mask = mask.cuda(non_blocking=True)
        # Only x_rec is saved; the other outputs remain bound after the loop for
        # inspection in later cells.
        loss, flattened_patched_x, mask_tokens, x_input, mask_flatten, x_rec = model(img, mask)
        save_img(x_rec, path_list[idx])

Train [Total]  number =  30


In [8]:
# Shape of the reconstruction from the last iteration of the inference loop
# (x_rec is left over from that loop).
x_rec.shape

torch.Size([1, 1, 192, 192, 96])

In [9]:
# Destination of the last saved reconstruction (idx is left over from the loop).
# NOTE(review): the saved output of this cell ('.../ABLATION_PD/7_SimMIM/7_SimMIM_DATA/...')
# cannot come from the current path_list construction, which keeps '/RECONSTRUCTION/'
# in the path. It is stale from an earlier run; re-run the notebook from the top.
path_list[idx]

'/workspace/Ablation/ABLATION_PD/7_SimMIM/7_SimMIM_DATA/NC/NM_0073_centered_normalized_occipital.nii.gz'

In [10]:
# Spatial shape of the tensor fed to the first encoder stage: first batch item,
# first of its embed_dim channels (x_input is left over from the inference loop).
x_input[0][0].shape

torch.Size([96, 96, 48])