To analyze how the amount of mask shrinkage affects reconstruction performance, we provide two ways to shrink the binary mask:

- 1. Following the reviewer's suggestion, shrink the mask by one pixel. The code is in the 4th block (commented out; uncomment it to enable).
- 2. Shrink the mask by a user-defined downsampling rate. The code is in the `val_transforms` part (commented out; uncomment it to enable).

In [None]:
import os
import shutil
import tempfile
import time

import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from torch.nn import L1Loss
from tqdm import tqdm

from generative.inferers import LatentDiffusionInferer,DiffusionInferer
from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler, DDIMScheduler

from mytransforms import *
from utiles import *
from generative.networks.nets import ControlNet
from generative.inferers import ControlNetDiffusionInferer,ControlNetLatentDiffusionInferer
# Print the MONAI configuration report, then seed the RNGs so runs are repeatable.
print_config()
set_determinism(seed=42)

In [None]:
# ---------------------------------------------------------------------------
# Run configuration and validation data pipeline.
#
# NOTE(review): `list_nii_files`, `natsorted`, `Copyd` and `CacheDataset` are
# not imported explicitly in this notebook; they presumably come from the
# `mytransforms` / `utiles` star imports -- confirm.
# ---------------------------------------------------------------------------
img_dim = 64      # edge length of the cubic centre crop
batch_size = 4
img_space = 4     # isotropic `pixdim` passed to Spacingd
train_name = 'your_model_name'
start_time = time.time()

input_data_dir = '/your_data_path'
save_dir = os.path.join('/results/Repaint3D/', train_name)
save_vis_dir = os.path.join(save_dir, "vis")
os.makedirs(save_vis_dir, exist_ok=True)

# Collect the input volumes in natural sort order.
img_list_pattern = list_nii_files(input_data_dir)
img_list = natsorted(img_list_pattern)

# Every listed volume is used for validation (no train/val split here).
val_img_list = img_list
val_files = [{"image": img_path} for img_path in val_img_list]

channel = 0  # 0 = Flair
if channel not in (0, 1, 2, 3):
    raise ValueError("Choose a valid channel")

val_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        # Keep only the selected channel, then re-add a leading channel axis.
        transforms.Lambdad(keys="image", func=lambda x: x[channel, :, :, :]),
        transforms.EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        transforms.EnsureTyped(keys=["image"]),
        transforms.Orientationd(keys=["image"], axcodes="LPI"),  # TODO Here change RAS to LPI
        # Keep a copy of the image before resampling under "ori_img".
        Copyd(keys=["image"], new_key=["ori_img"]),
        transforms.Spacingd(keys=["image"], pixdim=(img_space, img_space, img_space), mode="bilinear"),
        # TODO here may add an affine later
        # The mask is stored under prefix + key, i.e. the new key "maskimage".
        transforms.ForegroundMaskD(keys=["image"], new_key_prefix="mask", threshold=0.999, invert=True),

        # Crop mask and image to the same centred cube.
        transforms.CenterSpatialCropd(keys=["maskimage", "image"], roi_size=(img_dim, img_dim, img_dim)),

        # you can shrink the binary mask by a certain downsampling rate by uncommenting the following code
        # # ***************************************************************************
        # transforms.Spacingd(keys=["maskimage"], pixdim=(1.2, 1.2, 1.2), mode=("nearest")),
        # transforms.ForegroundMaskD(keys = ["maskimage"], threshold = 0.999, invert = True),
        # transforms.SpatialPadd(keys=["maskimage"],spatial_size = (img_dim, img_dim, img_dim)),
        # # ***************************************************************************
    ]
)

val_ds = CacheDataset(data=val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=8, persistent_workers=True)

In [None]:
# ---------------------------------------------------------------------------
# Build the 3D diffusion UNet, load the pretrained weights and create the
# noise scheduler used for sampling.
# ---------------------------------------------------------------------------
device = torch.device("cuda")

# Single-channel 3D UNet; attention is enabled only at the last level.
model = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    num_channels=[256, 256, 512],
    attention_levels=[False, False, True],
    num_head_channels=[0, 0, 512],
    num_res_blocks=2,
)
model.to(device)

model_path = 'pretrained_ddpm_model.pth'
# `map_location` remaps the stored tensors onto `device`, so the checkpoint
# loads even if it was saved from a different device than the current one.
model.load_state_dict(torch.load(model_path, map_location=device))

num_step = 1000
scheduler = DDPMScheduler(num_train_timesteps=num_step, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195)
# scheduler = DDIMScheduler(num_train_timesteps=num_step, schedule="scaled_linear_beta", beta_start=0.0005, beta_end=0.0195)
inferer = DiffusionInferer(scheduler)

# NOTE(review): `optimizer` and `scaler` are not used by the inference code
# visible in this notebook; kept in case other cells rely on them.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-5)
scaler = GradScaler()

total_start = time.time()
model.eval()  # inference only


In [None]:
# ---------------------------------------------------------------------------
# Mask-guided inpainting on the first validation batch.
#
# For every scheduler timestep t (high -> low) and every resampling pass u:
#   1. noise the known (mask == 1) region of the input to level t-1,
#   2. denoise the current sample one step with the model (t -> t-1),
#   3. merge both with the mask,
#   4. unless this is the last pass, re-noise the merged sample back to
#      level t and repeat.
# ---------------------------------------------------------------------------
with torch.no_grad(), autocast(enabled=True):
    first_val_batch = first(val_loader)
    images = first_val_batch["image"].to(device)
    masks = first_val_batch["maskimage"].to(device)

    # you can shrink the binary mask by 1 pixel by uncommenting the following code
    # (assumes `masks` is a floating-point tensor -- TODO confirm before enabling)
    # # ***************************************************************************
    # structuring_element = torch.ones((1, 1, 3, 3, 3), dtype=torch.float32, device=masks.device)
    # padded_masks = F.pad(masks, (1, 1, 1, 1, 1, 1), mode='constant', value=1)
    # eroded_mask = F.conv3d(padded_masks, structuring_element, stride=1, padding=1, groups=1)
    # eroded_mask = (eroded_mask == structuring_element.sum()).float()
    # eroded_mask = eroded_mask[:, :, 1:-1, 1:-1, 1:-1]
    # eroded_mask = eroded_mask.to(torch.uint8)
    # masks = eroded_mask
    # # ***************************************************************************

    # Known content only: everything outside the mask is zeroed.
    masked_images = images * masks  # TODO the condition for kidney outpainting

    num_resample_steps = 2
    progress_bar = tqdm(scheduler.timesteps)
    # Start from pure noise.
    val_image_inpainted = torch.randn_like(images)

    for timestep in progress_bar:
        t = int(timestep)
        timesteps = torch.tensor([t], device=device, dtype=torch.long)

        for u in range(num_resample_steps):
            # get the known portion at t-1
            if t > 0:
                noise = torch.randn_like(images)
                timesteps_prev = torch.tensor([t - 1], device=device, dtype=torch.long)
                val_image_inpainted_prev_known = scheduler.add_noise(
                    original_samples=masked_images, noise=noise, timesteps=timesteps_prev
                )
            else:
                # Final step: the known region is the clean input itself.
                val_image_inpainted_prev_known = masked_images

            # perform a denoising step to get the unknown portion at t-1.
            # This runs at t == 0 as well: previously that step was skipped and
            # the unknown region was filled with the stale tensor left over from
            # the t == 1 pass, i.e. it never received the final denoising step.
            model_output = model(val_image_inpainted, timesteps=timesteps)
            val_image_inpainted_prev_unknown, _ = scheduler.step(model_output, t, val_image_inpainted)

            # combine known and unknown using the mask
            val_image_inpainted = torch.where(
                masks == 1, val_image_inpainted_prev_known, val_image_inpainted_prev_unknown
            )

            if t == 0:
                # Nothing is resampled at the last timestep, so a second pass
                # would only denoise the already finished image again.
                break

            # perform resampling: sample x_t from x_t-1
            # NOTE(review): betas[t - 1] is kept from the original code; a
            # forward step to level t would conventionally use betas[t] -- confirm.
            if u < (num_resample_steps - 1):
                noise = torch.randn_like(images)
                val_image_inpainted = (
                    torch.sqrt(1 - scheduler.betas[t - 1]) * val_image_inpainted
                    + torch.sqrt(scheduler.betas[t - 1]) * noise
                )


In [None]:
def vis_results(*tensors, idx):
    """Show the three central orthogonal slices of one sample for each tensor.

    The figure has three rows (one per slicing axis) and one column per
    tensor. Intensities are displayed in gray scale, clipped to [0, 1].

    Args:
        *tensors: 5-D tensors indexed as ``tensor[idx, 0, a, b, c]``; only
            channel 0 is shown. Columns are titled 'I', 'mask', 'mI',
            'op_image' in positional order; any further tensors get a
            generic ``tensor N`` title.
        idx: index along the first dimension of the sample to display.
    """
    titles = ['I', 'mask', 'mI', 'op_image']
    n_cols = len(tensors)
    fig = plt.figure(figsize=(20, 5))

    for col, tensor in enumerate(tensors):
        # Select the sample first, then move it to CPU once (not once per row).
        volume = tensor[idx, 0].detach().cpu().numpy()

        # Use each axis' own midpoint so non-cubic volumes are sliced correctly.
        mid_0, mid_1, mid_2 = (size // 2 for size in volume.shape)
        planes = (
            volume[:, mid_1, :],  # row 0: cut through the second spatial axis
            volume[mid_0, :, :],  # row 1: cut through the first spatial axis
            volume[:, :, mid_2],  # row 2: cut through the third spatial axis
        )

        title = titles[col] if col < len(titles) else f"tensor {col + 1}"
        for row, plane in enumerate(planes):
            ax = fig.add_subplot(3, n_cols, row * n_cols + col + 1)
            ax.imshow(plane, cmap='gray', vmin=0, vmax=1)
            ax.set_title(title)
            ax.axis('off')
    plt.show()
# Draw one figure per sample of the batch: input, mask, masked input, inpainted output.
for sample_idx in range(images.shape[0]):
    vis_results(images, masks, masked_images, val_image_inpainted, idx=sample_idx)