In [None]:
import os
import shutil
import tempfile

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from monai import transforms
from monai.apps import DecathlonDataset
from monai.config import print_config
from monai.data import DataLoader
from monai.utils import first, set_determinism
from torch.cuda.amp import GradScaler, autocast
from torch.nn import L1Loss
from tqdm import tqdm

from generative.inferers import LatentDiffusionInferer
from generative.losses import PatchAdversarialLoss, PerceptualLoss
from generative.networks.nets import AutoencoderKL, DiffusionModelUNet, PatchDiscriminator
from generative.networks.schedulers import DDPMScheduler

from mytransforms import *
from utiles import *
from generative.networks.nets import ControlNet
from generative.inferers import ControlNetDiffusionInferer,ControlNetLatentDiffusionInferer
# Report the MONAI configuration so the environment is recorded with the outputs.
print_config()

# Fix the seed up front so repeated runs of the notebook are comparable.
set_determinism(seed=42)

In [2]:
# TODO: prune the settings that this evaluation notebook does not actually need.

# --- data geometry / batching -------------------------------------------------
batch_size = 4
img_dim = 64        # edge length of the centre crop applied to image and mask
img_space = 4       # isotropic pixdim handed to the resampling transform
mask_space = 4
latent_c = 1        # channel count of the autoencoder latent / diffusion UNet

# --- training-time hyper-parameters (kept for reference) -----------------------
n_epochs = 2000
adv_weight = 0.01
perceptual_weight = 0.01
kl_weight = 1e-6
lr = 2.5e-4
ldm_epoch = 1500

# --- run naming ---------------------------------------------------------------
train_name = f"lr{lr}bs{batch_size}ldmepoch{ldm_epoch}"
# NOTE(review): `time` is not imported explicitly anywhere in this notebook;
# presumably it arrives through one of the wildcard imports - confirm.
start_time = time.time()

# --- paths --------------------------------------------------------------------
dataname = 'your_data_name'
input_data_dir = f"/data/{dataname}"
autoencoder_path = '/the_pretrained_autoencoderkl.pth'
unet_path = '/the_pretrained_ldm.pth'


In [None]:

# Gather the input volumes and put them in natural (human) sort order.
# `list_nii_files` comes from the local helper modules; presumably it returns
# the NIfTI paths found under the directory - confirm.
img_list_pattern = list_nii_files(input_data_dir)
img_list = natsorted(img_list_pattern)

# train_img_list = img_list[5*2:]
val_img_list = img_list

# Evaluate on at most the first 10 volumes. Slicing (rather than indexing with
# range(10)) keeps this cell from raising IndexError when fewer files exist.
val_files = [{"image": img_path} for img_path in val_img_list[:10]]


channel = 0  # 0 = Flair
assert channel in [0, 1, 2, 3], "Choose a valid channel"

val_transforms = transforms.Compose(
    [
        transforms.LoadImaged(keys=["image"]),
        transforms.EnsureChannelFirstd(keys=["image"]),
        # Keep only the selected channel, then re-add a singleton channel axis.
        transforms.Lambdad(keys="image", func=lambda x: x[channel, :, :, :]),
        transforms.EnsureChannelFirstd(keys=["image"], channel_dim="no_channel"),
        transforms.EnsureTyped(keys=["image"]),
        transforms.Orientationd(keys=["image"], axcodes="LPI"), # TODO Here change RAS to LPI
        # Keep an untouched copy of the oriented image before resampling.
        Copyd(keys=["image"], new_key=["ori_img"]),
        transforms.Spacingd(keys=["image"], pixdim=(img_space, img_space, img_space), mode=("bilinear")),
        # TODO here may add an affine later
        # Foreground mask of the image; with the prefix it is stored under the key "maskimage".
        transforms.ForegroundMaskD(keys = ["image"], new_key_prefix = "mask", threshold = 0.999, invert = True),
        # Two extra copies of the mask that are turned into latent-resolution masks below.
        Copyd(keys=["maskimage","maskimage"], new_key=["latent_mask","latent_mask2"]),
        transforms.CenterSpatialCropd(keys=["maskimage"], roi_size=(img_dim, img_dim, img_dim)),
        transforms.CenterSpatialCropd(keys=["image"], roi_size=(img_dim, img_dim, img_dim)),

        # latent_mask: resize to 14^3, re-binarise, then pad back out to 16^3
        # (i.e. the mask is shrunk by 2 voxels overall).
        transforms.Resized(keys=["latent_mask"],spatial_size=(14,14,14)),
        # No prefix here, so the result presumably replaces "latent_mask" in place - confirm.
        transforms.ForegroundMaskD(keys = ["latent_mask"], threshold = 0.999, invert = True),
        # transforms.CenterSpatialCropd(keys=["latent_mask"], roi_size=(16, 16, 16)),
        transforms.SpatialPadD(keys=["latent_mask"], spatial_size=(16, 16, 16)),

        # latent_mask2: resize to 18^3, re-binarise, then centre-crop to 16^3.
        transforms.Resized(keys=["latent_mask2"],spatial_size=(18,18,18)),
        transforms.ForegroundMaskD(keys = ["latent_mask2"], threshold = 0.999, invert = True),
        transforms.CenterSpatialCropd(keys=["latent_mask2"], roi_size=(16, 16, 16)),
    ]
)

# NOTE(review): `CacheDataset` and `natsorted` are not imported explicitly in this
# notebook; presumably they arrive through the wildcard imports - confirm.
val_ds = CacheDataset(data = val_files, transform=val_transforms)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=8, persistent_workers=True)

In [None]:
# One batch is pulled up front; it is used below to estimate the latent scale factor.
check_data = first(val_loader)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")

# Pretrained autoencoder that maps volumes to / from the latent space.
autoencoder = AutoencoderKL(
    spatial_dims=3,
    in_channels=1,
    out_channels=1,
    num_channels=(32, 64, 64),
    latent_channels=latent_c,
    num_res_blocks=1,
    norm_num_groups=16,
    attention_levels=(False, False, True),
)
autoencoder.to(device)
# map_location keeps loading working on the CPU fallback selected above, even
# when the checkpoint was written from a GPU.
autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))
autoencoder.eval()

# Pretrained diffusion UNet operating on the latent representation.
unet = DiffusionModelUNet(
    spatial_dims=3,
    in_channels=latent_c,
    out_channels=latent_c,
    num_res_blocks=1,
    num_channels=(32, 64, 64),
    attention_levels=(False, True, True),
    num_head_channels=(0, 64, 64),
)
unet.to(device)

unet.load_state_dict(torch.load(unet_path, map_location=device))
unet.eval()

scheduler = DDPMScheduler(num_train_timesteps=1000, schedule="scaled_linear_beta", beta_start=0.0015, beta_end=0.0195)
with torch.no_grad():
    with autocast(enabled=True):
        z = autoencoder.encode_stage_2_inputs(check_data["image"].to(device))

# Scale factor = 1 / std of the latents of the first batch (computed once, then reused).
scale_factor = 1 / torch.std(z)
print(f"Scaling factor set to {scale_factor}")

inferer = LatentDiffusionInferer(scheduler, scale_factor=scale_factor)
# optimizer_diff = torch.optim.Adam(params=unet.parameters(), lr=1e-4)


In [5]:
from pytorch_ssim import *

class NCC:
    """
    Normalized cross-correlation style similarity between two tensors.

    The statistics are taken over the whole tensor: ``win`` is stored but no
    windowing is applied. The returned value is the similarity itself (it is
    not negated), so higher means more similar.
    """

    def __init__(self, win=None):
        # Stored only; `loss` does not read it.
        self.win = win

    def loss(self, y_true, y_pred):
        """
        Return ``sum(a * b) / sum(sqrt(a^2 * b^2 + eps))`` where ``a`` and ``b``
        are ``y_true`` and ``y_pred`` with their global means removed.

        Note that the denominator sums the element-wise magnitudes of the
        products; it is not ``sqrt(sum(a^2) * sum(b^2))``.
        """
        eps = 1e-10

        # Remove the global mean from each input once and reuse it.
        true_centered = y_true - torch.mean(y_true)
        pred_centered = y_pred - torch.mean(y_pred)

        numerator = torch.sum(true_centered * pred_centered)
        denominator = torch.sum(
            torch.sqrt((true_centered * true_centered) * (pred_centered * pred_centered) + eps)
        )

        return numerator / denominator


In [None]:
def create_centered_tensor(size=16, n=8):
    """Return a ``size x size x size`` tensor of zeros with a centred cube of ones.

    Args:
        size: edge length of the returned cubic tensor.
        n: edge length of the centred block that is set to 1.

    Returns:
        A ``torch`` tensor of shape ``(size, size, size)``.

    Raises:
        ValueError: if ``n`` is not strictly smaller than ``size``.
    """
    if n >= size:
        raise ValueError("n should be less than the size of the tensor")

    # Start from an all-zero cube of shape (size, size, size).
    cube = torch.zeros((size,) * 3)

    # The same start/stop range applies along every axis of the centred block.
    lo = (size - n) // 2
    centre = slice(lo, lo + n)

    # Set the centred block to 1.
    cube[centre, centre, centre] = 1

    return cube

# Per-sample metric accumulators; one entry is appended for every validation volume.
ssim_list = []
ncc_list = []
mse_list = []

# Outpainting evaluation: for every validation batch, keep only the central block
# of the latent, regenerate the rest with the diffusion model, decode, and score
# the generated outer region against the real image.
val_progress_bar = tqdm(enumerate(val_loader), total=len(val_loader), ncols=110)
with torch.no_grad():
    with autocast(enabled=True):
        for step, first_val_batch in val_progress_bar:
        # first_val_batch = first(val_loader)
            images = first_val_batch["image"].to(device)
            ori_masks = first_val_batch["maskimage"].to(device)
    
            k = 8
            # 16^3 latent mask with a centred (k-2)^3 block of ones = the "known" latent region.
            latent_mask = create_centered_tensor(16, k-2).to(device) # LOTUS
            # latent_mask = create_centered_tensor(16, k).to(device) # LOTUS*
            # Image-space counterpart: 64^3 mask with a centred (k*4)^3 block of ones,
            # repeated once per sample in the batch.
            masks = create_centered_tensor(64, k*4).to(device).unsqueeze(0)
            masks = masks.repeat(images.shape[0], 1, 1, 1,1)
            op_roi = masks
            # Ground truth for the metrics: the image with the central block zeroed out.
            op_target = images*(1 - op_roi)
            # masks = 1 - masks
            # Model input: only the central block of the image is kept.
            masked_images = images*masks # TODO the condition for kidney outpainting
            
            # # Note here the image should also be a destroied version:
            # images = masked_images
            print(f'masked_images shape is {masked_images.shape}') # masked_images shape is torch.Size([4, 1, 64, 64, 64])
            # NOTE(review): the latents below are not multiplied by scale_factor after
            # encoding, yet they are divided by scale_factor before decoding - confirm
            # this asymmetry is intended.
            masked_images_latent = autoencoder.encode_stage_2_inputs(masked_images)
            images_latent = autoencoder.encode_stage_2_inputs(images)
            # Autoencoder round trip of the full image, masked by ori_masks; not used for
            # the metrics in this cell, only by the visualisation cell further down.
            RmI_L = autoencoder.decode_stage_2_outputs(images_latent / scale_factor)*ori_masks

            # Known latent content: zero everywhere except the central block.
            masked_latent = masked_images_latent * latent_mask.to(device) 
            # masked_latent = m1L
            
            
            # Number of denoise passes performed at every timestep.
            num_resample_steps = 2
            # This initial value is overwritten inside the loop before it is ever read.
            timesteps = torch.Tensor((999,)).to(device).long()
            progress_bar = tqdm(scheduler.timesteps)
            # Start the sampling from pure noise in latent space.
            val_image_inpainted = torch.randn_like(images_latent).to(device)
            
            for t in progress_bar:
                for u in range(num_resample_steps):
                    # get the known portion at t-1
                    if t > 0:
                        noise = torch.randn_like(images_latent).to(device)
                        timesteps_prev = torch.Tensor((t - 1,)).to(noise.device).long()
                        val_image_inpainted_prev_known = scheduler.add_noise(
                            original_samples=masked_latent, noise=noise, timesteps=timesteps_prev
                        )
                    else:
                        val_image_inpainted_prev_known = masked_latent
    
                    # perform a denoising step to get the unknown portion at t-1
                    # NOTE(review): when t == 0 no denoising step runs, so the "unknown"
                    # tensor used below is the one left over from the previous iteration.
                    if t > 0:
                        timesteps = torch.Tensor((t,)).to(noise.device).long()
                        model_output = unet(val_image_inpainted, timesteps=timesteps)
                        val_image_inpainted_prev_unknown, _ = scheduler.step(model_output, t, val_image_inpainted)
    
                    # combine known and unknown using the mask
                    val_image_inpainted = torch.where(
                        latent_mask == 1, val_image_inpainted_prev_known, val_image_inpainted_prev_unknown
                    )
    
                    # perform resampling
                    if t > 0 and u < (num_resample_steps - 1):
                        # sample x_t from x_t-1
                        noise = torch.randn_like(images_latent).to(device)
                        val_image_inpainted = (
                            torch.sqrt(1 - scheduler.betas[t - 1]) * val_image_inpainted
                            + torch.sqrt(scheduler.betas[t - 1]) * noise
                        )
            rec_image = autoencoder.decode_stage_2_outputs(val_image_inpainted / scale_factor)
            # rec_image = rec_image*ori_masks*(1 - op_roi)
            # Min-max normalise using the min / max of the whole batch (not per sample),
            # then keep only the generated outer region inside ori_masks.
            rec_image = (rec_image - torch.min(rec_image))/(torch.max(rec_image) - torch.min(rec_image))*ori_masks*(1 - op_roi)


            
            # Score every sample of the batch separately.
            for n in range(rec_image.shape[0]):
                # compare in the whole image level
                rec_patch = rec_image[n:n+1, :, :, :, :]
                target_patch = op_target[n:n+1, :, :, :, :]
                
                # `ssim3D` presumably comes from the pytorch_ssim wildcard import - confirm.
                # The per-voxel SSIM map is averaged into one score.
                ssim_volume = ssim3D(rec_patch, target_patch, window_size=9, size_average=False)
                ssim_volume = ssim_volume.squeeze(0).squeeze(0).cpu().numpy()
                ssim_score = np.average(ssim_volume)
                ssim_list.append(ssim_score)
                ncc_score = NCC().loss(target_patch, rec_patch)
                ncc_list.append(ncc_score.cpu().numpy())
                
                # Mean squared error over every voxel of the sample.
                mse_volume = (target_patch - rec_patch) ** 2
                mse_list.append(np.average(mse_volume.cpu().numpy()))
                

In [None]:
# k = 8, 8x4 32 whole image, lotus
# Raw per-sample scores first ...
for scores in (ssim_list, ncc_list, mse_list):
    print(scores)

# ... then mean and standard deviation of each metric.
for metric_name, scores in (("SSIM", ssim_list), ("NCC", ncc_list), ("MSE", mse_list)):
    print(f'{metric_name} ave: {np.mean(scores)}, std: {np.std(scores)}')

In [None]:
def vis_results(*tensors, idx, titles=None):  # TODO: visualize the median result
    """Plot the three central orthogonal slices of one sample for each given volume.

    The figure has three rows (one per slicing axis) and one column per tensor.
    Every slice is drawn in grayscale with a fixed 0..1 intensity range.

    Args:
        *tensors: 5-D tensors indexed as ``[sample, channel, d0, d1, d2]``;
            only channel 0 is shown.
        idx: index of the sample (first dimension) to display.
        titles: optional column titles, one per tensor. Defaults to
            ``['I', 'mask', 'mI', 'op_image']``. A column without a matching
            title is labelled ``"tensor <n>"`` instead of raising IndexError.
    """
    default_titles = ['I', 'mask', 'mI', 'op_image']
    column_titles = list(titles) if titles is not None else default_titles

    n_cols = len(tensors)
    fig = plt.figure(figsize=(20, 5))
    for row in range(3):
        for col, tensor in enumerate(tensors):
            ax = fig.add_subplot(3, n_cols, row * n_cols + col + 1)  # show all 3 panels

            # Select the sample / channel and move it to the CPU for plotting.
            volume = tensor[idx, 0].detach().cpu()

            # Take the middle slice along one axis per row. The midpoint is taken from
            # the axis actually being sliced, so non-cubic volumes are handled too.
            if row == 0:
                tensor_slice = volume[:, volume.shape[1] // 2, :]
            elif row == 1:
                tensor_slice = volume[volume.shape[0] // 2, :, :]
            else:
                tensor_slice = volume[:, :, volume.shape[2] // 2]

            ax.imshow(tensor_slice.numpy(), cmap='gray', vmin=0, vmax=1)
            if col < len(column_titles):
                ax.set_title(column_titles[col])  # Give each tensor image a title
            else:
                ax.set_title(f"tensor {col + 1}")
            ax.axis('off')  # Turn off axis
    plt.show()
# Show every sample of the last batch processed by the evaluation loop:
# input image, centre mask, masked autoencoder reconstruction, generated result.
for sample_idx in range(len(images)):
    # vis_results(images, masks, op_target, rec_image, idx = i)
    vis_results(images, masks, RmI_L, rec_image, idx=sample_idx)