In [1]:
import os
gpu_ids = 0
os.environ["CUDA_VISIBLE_DEVICES"] = f"{gpu_ids}"
# import libraries
import numpy as np
from termcolor import colored, cprint
# for display
from IPython.display import Image as ipy_image
from IPython.display import display
from torch.utils.data import random_split
import numpy as np
import SimpleITK as sitk
import os
from torch.utils.data import DataLoader, Dataset
import torch
import torch.backends.cudnn as cudnn
cudnn.benchmark = True
import torchvision.utils as vutils
from torch.utils.tensorboard import SummaryWriter
import os
import datetime
from tensorflow_mri import ssim3d
import SimpleITK as sitk
from tqdm import tqdm
from models.base_model import create_model
from utils.demo_util import VQVAEOpt

def resize_image_itk(itkimage, newSize, resamplemethod=sitk.sitkNearestNeighbor):
    """Resample ``itkimage`` onto a grid with ``newSize`` voxels per axis.

    The output spacing is scaled so the image keeps its physical extent:
    ``new_spacing = old_spacing * old_size / new_size``. Origin and direction
    are taken from the input image via ``SetReferenceImage``.

    Args:
        itkimage: Input ``SimpleITK.Image`` (any dimension SimpleITK supports).
        newSize: Target number of voxels per axis, in SimpleITK (x, y, z)
            index order, i.e. the reverse of the NumPy array axis order.
            Non-integer entries are truncated toward zero, as before.
        resamplemethod: SimpleITK interpolator constant
            (default ``sitk.sitkNearestNeighbor``).

    Returns:
        The resampled ``SimpleITK.Image``.

    Raises:
        ValueError: If ``newSize`` does not have exactly one entry >= 1 per
            image dimension.
    """
    dimension = itkimage.GetDimension()
    target_size = np.asarray(newSize, dtype=float)
    if target_size.shape != (dimension,):
        raise ValueError(
            f"newSize must have {dimension} entries for a {dimension}-D image, "
            f"got {newSize!r}"
        )
    if np.any(target_size < 1):
        raise ValueError(f"newSize entries must be >= 1, got {newSize!r}")

    origin_size = np.asarray(itkimage.GetSize(), dtype=float)
    origin_spacing = np.asarray(itkimage.GetSpacing(), dtype=float)
    # Spacing is computed from the untruncated float target, as originally.
    new_spacing = origin_spacing * (origin_size / target_size)

    resampler = sitk.ResampleImageFilter()
    # Copies origin/direction (and size/spacing, overridden just below).
    # NOTE(review): the origin is kept while the spacing changes, so the new
    # grid is not centre-aligned with the input (offset of roughly
    # (new_spacing - old_spacing) / 2 per axis). Left unchanged so outputs
    # match whatever preprocessing the checkpoints were trained with — confirm.
    resampler.SetReferenceImage(itkimage)
    # Plain Python ints: the previous int16 cast overflowed above 32767.
    resampler.SetSize([int(s) for s in target_size])
    resampler.SetOutputSpacing(new_spacing.tolist())
    # Identity transform matching the image's own dimension (was fixed to 3).
    resampler.SetTransform(sitk.Transform(dimension, sitk.sitkIdentity))
    resampler.SetInterpolator(resamplemethod)
    return resampler.Execute(itkimage)

def load_data(nii_path, size=(128, 128, 128)):
    """Load a volume, z-score normalise it, and resample it to ``size``.

    Args:
        nii_path: Path to an image file readable by ``sitk.ReadImage``
            (e.g. ``.nii.gz``).
        size: Target grid passed to ``resize_image_itk``, in SimpleITK
            (x, y, z) order. Defaults to the 128^3 grid previously hard-coded.

    Returns:
        ``(volume, std_dev, mean_value)``:

        - ``volume`` is a float32 torch tensor in NumPy (z, y, x) axis order.
        - ``std_dev`` and ``mean_value`` are the statistics of the raw volume,
          computed before resampling. They were used for normalisation, so
          the caller can undo it.
    """
    raw = sitk.GetArrayFromImage(sitk.ReadImage(nii_path))
    mean_value = raw.mean()
    std_dev = raw.std()
    # A constant volume has zero std; dividing by it would fill the result
    # with NaN/inf, so in that case the volume is only centred.
    scale = std_dev if std_dev > 0 else 1.0
    normalized = (raw - mean_value) / scale
    # GetImageFromArray gives unit spacing and zero origin, so the resample
    # below works in voxel units; the file's physical metadata is not kept.
    resampled = resize_image_itk(
        sitk.GetImageFromArray(normalized), size, resamplemethod=sitk.sitkLinear
    )
    volume = torch.tensor(sitk.GetArrayFromImage(resampled)).float()
    return volume, std_dev, mean_value

seed = 2023
opt = VQVAEOpt(gpu_ids=gpu_ids, seed=seed)
device = opt.device
data_root_path = "/root/autodl-tmp/vqvae/data"
# VQ-VAE decoder checkpoints, one per latent channel chunk (index order must
# match the channel order of the saved latents).
ckpt_path_list = ["./ckpt/flair.pth", "./ckpt/t1.pth", "./ckpt/t1ce.pth", "./ckpt/t2.pth"]
# ckpt_path_list = ["./ckpt/flair.pth","./ckpt/t1.pth","./ckpt/t1ce.pth","/root/autodl-tmp/vqvae/results/t2.pth",]
# Indices into ckpt_path_list / latent chunks that are actually decoded.
# Only t2 (index 3) is decoded, as before; add indices to decode more.
decode_indices = (3,)
# Constant subtracted from every decoder output before evaluation and saving.
# NOTE(review): the reason for this empirical offset is not visible here — confirm.
recon_offset = 0.4
#data_path = "/root/autodl-tmp/dit/BraTS2021_00017.npy"
data_path = "/root/autodl-tmp/dit_repaint/result/teset_figt1ce_reserve/BraTS2021_01150.npy_mse0.001663263188675046_ssim[0.9936348].npy"
data = torch.tensor(np.load(data_path))
# ".../<folder>/BraTS2021_01150.npy_mse..._ssim[...].npy" -> "BraTS2021_01150"
iter_num = os.path.basename(data_path).split(".")[0]
folder = os.path.basename(os.path.dirname(data_path))
recons_dir = f"./sest_fig/{folder}/{iter_num}"
print(f"Save to {recons_dir}")
os.makedirs(recons_dir, exist_ok=True)
num_samples = data.shape[0]
print(f"data has {num_samples} samples")
print(f"data shape : {data.shape}")

with torch.no_grad():
    for j in decode_indices:
        torch.cuda.empty_cache()
        modality = os.path.basename(ckpt_path_list[j]).replace(".pth", "")
        print(f"========================={j}_{modality}=========================")
        # The ground truth depends only on iter_num/modality: load it once.
        data_ori, std_dev, mean_value = load_data(
            f"{data_root_path}/{modality}/{iter_num}_{modality}.nii.gz"
        )
        print(data_ori.shape)
        data_ori = data_ori[None, None]  # add batch and channel axes
        # Build the decoder once per modality rather than once per sample.
        opt.init_model_args(ckpt_path_list[j], isTrain=False)
        opt.init_dset_args(dataset_mode="snet")
        vqvae = create_model(opt)
        cprint(f'[*] "{vqvae.name()}" loaded.', 'cyan')

        # torch.split (unlike chunk with chunks=0) also copes with empty data.
        for sample_idx, sample in enumerate(torch.split(data, 1, dim=0)):
            latent = torch.chunk(sample, chunks=4, dim=1)[j]
            reconstructions = (
                vqvae.vqvae_module.decode_no_quant(latent.to("cuda")) - recon_offset
            )
            print("reconstructions shape : ", reconstructions.shape)
            recon_cpu = reconstructions.detach().cpu()
            mse = torch.nn.functional.mse_loss(recon_cpu, data_ori)
            ssim = ssim3d(data_ori.contiguous(), recon_cpu.contiguous(), filter_size=1).numpy()
            # NOTE(review): every sample is compared with the same ground truth.
            print(f'filename {iter_num}_{modality}.nii.gz, ssim: {ssim} , mse: {mse.item()}')

            # Save the 3-D volume only (drop batch/channel axes). Samples get an
            # index suffix only when there are several, so they don't overwrite
            # each other while the single-sample filename stays unchanged.
            suffix = f"_{sample_idx}" if num_samples > 1 else ""
            out_img = sitk.GetImageFromArray(recon_cpu[0, 0].numpy())
            output_filename = f"{recons_dir}/sample_{iter_num}_{modality}{suffix}.nii.gz"
            sitk.WriteImage(out_img, output_filename)
            torch.cuda.empty_cache()
print('Done!')

2024-03-13 19:29:56.228820: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-13 19:29:56.390097: E tensorflow/stream_executor/cuda/cuda_blas.cc:2981] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-03-13 19:29:56.969689: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /usr/local/nvidia/lib:/usr/local/nvidia/lib64
2024-03-13 19:29:56.969762: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinf

[*] VQVAETestOpt initialized.
Save to ./sest_fig/teset_figt1ce_reserve/BraTS2021_01150
data has 1 samples
data shape : torch.Size([1, 4, 32, 32, 32])
torch.Size([128, 128, 128])
Working with z of shape (1, 1, 16, 16, 16) = 4096 dimensions.
[34m[*] weight successfully load from: ./ckpt/t2.pth[0m
initialize
[34m[*] Model has been created: VQVAE-Model[0m
[36m[*] "VQVAE-Model" loaded.[0m
reconstructions shape :  torch.Size([1, 1, 128, 128, 128])


2024-03-13 19:30:13.262646: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2024-03-13 19:30:13.267006: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1616] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 45126 MB memory:  -> device: 0, name: NVIDIA A40, pci bus id: 0000:01:00.0, compute capability: 8.6
2024-03-13 19:30:13.887970: I tensorflow/stream_executor/cuda/cuda_dnn.cc:384] Loaded cuDNN version 8700


filename BraTS2021_01150_t2.nii.gz, ssim: [0.8416451] , mse: 0.4071943461894989
Done!
