In [1]:
# Put the local project checkout at the front of the module search path.
# Presumably this is what lets the project packages imported in the next cell
# resolve -- TODO confirm, and consider making the location configurable
# instead of hard-coding an absolute home-directory path.
import sys
sys.path.insert(0, '/home/ryuuyou/Project/denoising-diffusion/')

In [2]:
from denoising_diffusion_pytorch import Unet, Sampler, tensor_info, histc
from degradation_model import DenoiseOperator, AnisotropicOperator, GaussialBlurOperator, UnetAnisotropicOperator, SelfNetAnisotropicOperator
from diffusion_posterior_sample import DPS

import torch
from torchvision import utils, transforms as T

import numpy as np
import tifffile as tif
import os

  from .autonotebook import tqdm as notebook_tqdm


# Model

In [3]:
# Sampling configuration; both values are forwarded to the DPS constructor.
# Alternative setting kept for reference:
#   sampling_timesteps = 1000
#   is_ddim_sampling = False
is_ddim_sampling = True
sampling_timesteps = 250

In [4]:
# Single-channel U-Net that is wrapped by the DPS diffusion object.
model = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1)

# The diffusion wrapper takes the sampling settings chosen in the config cell.
diffusion = DPS(
    model,
    timesteps=1000,
    image_size=128,
    is_ddim_sampling=is_ddim_sampling,
    sampling_timesteps=sampling_timesteps,
)

# The sampler is the object used below for checkpoint loading, sampling and saving.
sampler = Sampler(diffusion)

In [5]:
# Directory holding the checkpoints for this experiment and the specific
# checkpoint file to restore.
base_path = '/home/ryuuyou/Project/denoising-diffusion/checkpoints/care_liver/'
model_path = os.path.join(base_path, 'model-15.pt')
print(f'model path: {model_path}')

# Restore the sampler state from the checkpoint file selected above.
sampler.load(path=model_path)

model path: /home/ryuuyou/Project/denoising-diffusion/checkpoints/care_liver/model-15.pt
loading from: [version]:1.5.4; [step]:150000


# Data

In [6]:
# data_path = "/home/share/data/CARE/Isotropic_Liver/train_data/data_label.npz"
# data_file = np.load(data_path)
# file_name = 'X'
# data = data_file[file_name]

In [7]:
# data_index = 0
# measurment = torch.from_numpy(data[data_index:data_index+1]).cuda()
# tensor_info(measurment)

In [8]:
# gt = data_file['Y']
# gt_measurment = torch.from_numpy(gt[data_index:data_index+1]).cuda()
# tensor_info(gt)

# Data 2: test volume (TIFF)

In [9]:
data_path = "/home/share/data/CARE/Isotropic_Liver/test_data/input_subsample_8.tif"

# Read the TIFF stack, move its leading axis to the end, insert a singleton
# axis at position 1 and convert to float32.
data = np.moveaxis(tif.imread(data_path), 0, -1)[:, np.newaxis].astype(np.float32)
data.shape

(752, 1, 752, 301)

In [10]:
data_index = 0

# Take one entry along axis 0 (kept as a length-1 axis) and a 128x128 window
# starting at offset (300, 100) on the last two axes, then move it to the GPU.
patch = data[data_index:data_index + 1, :, 300:300 + 128, 100:100 + 128]
measurment = torch.from_numpy(patch).cuda()
tensor_info(measurment)

shape: torch.Size([1, 1, 128, 128]) 
type: torch.float32 
max: 2611.0           
min: 97.0 
mean: 376.29656982421875 
std: 236.73397827148438


In [11]:
# Load the ground-truth stack and bring it into the same axis layout as the
# input volume: leading TIFF axis moved to the end, singleton axis inserted at
# position 1, values converted to float32.
gt_path = "/home/share/data/CARE/Isotropic_Liver/test_data/input_subsample_1_groundtruth.tif"
gt = tif.imread(gt_path)
gt = np.expand_dims(np.moveaxis(gt, 0, -1), axis=1).astype(np.float32)
# Show the shape of the array loaded in THIS cell (the cell previously echoed
# `data.shape`, i.e. the input volume's shape, by copy-paste mistake).
gt.shape

(752, 1, 752, 301)

In [12]:
data_index = 0

# Same window as used for the input patch: one entry along axis 0 (kept as a
# length-1 axis) and a 128x128 region starting at offset (300, 100).
gt_patch = gt[data_index:data_index + 1, :, 300:300 + 128, 100:100 + 128]
gt_measurment = torch.from_numpy(gt_patch).cuda()
tensor_info(gt_measurment)

shape: torch.Size([1, 1, 128, 128]) 
type: torch.float32 
max: 2957.0           
min: 70.0 
mean: 378.065673828125 
std: 259.2381591796875


# Operator

In [13]:
# Degradation setups handled by the operator-selection cell; the index picks
# the one used for this run.
MEASURMENT_TYPES = ('identity', 'denoise', 'deblur', 'iso_sr', 'iso_sr_dl')
measurment_type = MEASURMENT_TYPES[4]

In [14]:
# Echo the selected degradation type as the cell output.
measurment_type

'iso_sr_dl'

In [15]:
# Build the degradation operator matching `measurment_type`.
#
# NOTE(review): the 'denoise' and 'deblur' branches overwrite `measurment`, so
# re-running this cell degrades the already-degraded tensor again. Re-run the
# cell that creates `measurment` first if this cell has to be executed twice.
print(f'{measurment_type}')

if measurment_type == 'identity':
    # The identity setup uses the same operator class as 'denoise', without
    # adding synthetic noise to the measurement.
    operator = DenoiseOperator()

elif measurment_type == 'denoise':
    # randn_like already samples on the input's device and dtype.
    # NOTE(review): this cell runs before the normalisation cell, so the 0.1
    # noise level is applied to the un-normalised intensities -- confirm that
    # this ordering is intended.
    measurment = measurment + torch.randn_like(measurment) * 0.1
    operator = DenoiseOperator()

elif measurment_type == 'deblur':
    # Blur the measurement, then build the blur operator for its shape.
    blur = T.GaussianBlur(kernel_size=19, sigma=3.0)
    measurment = blur(measurment)
    operator = GaussialBlurOperator(img_shape=measurment.shape)

elif measurment_type == 'iso_sr':
    operator = AnisotropicOperator(img_shape=measurment.shape, scale=(1,3), noise_sigma=0.01)

elif measurment_type == 'iso_sr_dl':
    # d_model_path = 'denoising_diffusion_pytorch/degradation_model/checkpoint/best.pth'
    # operator = UnetAnisotropicOperator(d_model_path, noise_sigma=0.01)
    operator = SelfNetAnisotropicOperator(path='/home/ryuuyou/Project/self_net/data/care_liver/checkpoint/saved_models/netG_A/60_3200.pkl')

else:
    # Fail loudly: without this branch an unrecognised type would leave
    # `operator` undefined, or silently reuse one left over from an earlier
    # run in the same kernel.
    raise ValueError(f'unknown measurment_type: {measurment_type!r}')

iso_sr_dl
initialize network with normal


In [16]:
# Normalisation applied to the measurement before sampling.
# Supported values: 'z_score', 'min_max', or None (no normalisation).
norm_type = 'min_max'
# norm_type = None


def _safe_denominator(denom):
    """Return `denom` unchanged where it is > 0, and 1 elsewhere.

    Prevents a constant input (zero range / zero std) from turning the whole
    tensor into NaN; for any positive denominator the result is unchanged.
    """
    return torch.where(denom > 0, denom, torch.ones_like(denom))


def _z_score(t):
    """Shift `t` to zero mean and scale it by its standard deviation."""
    return (t - t.mean()) / _safe_denominator(t.std())


def _min_max(t):
    """Rescale `t` linearly so that its minimum maps to 0 and its maximum to 1."""
    low = t.min()
    return (t - low) / _safe_denominator(t.max() - low)


def _identity(t):
    """Return `t` unchanged."""
    return t


if norm_type == 'z_score':
    normalize = _z_score
elif norm_type == 'min_max':
    normalize = _min_max
elif not norm_type:
    # None (or any other falsy value) disables normalisation.
    normalize = _identity
else:
    # A mistyped name used to fall through to "no normalisation" silently.
    raise ValueError(f'unknown norm_type: {norm_type!r}')

measurment = normalize(measurment)
tensor_info(measurment)

shape: torch.Size([1, 1, 128, 128]) 
type: torch.float32 
max: 1.0           
min: 0.0 
mean: 0.11109648644924164 
std: 0.09416625648736954


# Inference

In [26]:
# Settings for the sampling run.
return_all_timesteps = False  # forwarded to sampler.dps; also selects the save routine afterwards
num_samples = 16              # forwarded to sampler.dps
scale = 8                     # forwarded to sampler.dps; also part of the output folder name

In [27]:
# Run the sampler on the measurement with the selected operator, then print
# summary statistics of the returned tensor.
res = sampler.dps(
    measurement=measurment,
    operator=operator,
    num_samples=num_samples,
    scale=scale,
    return_all_timesteps=return_all_timesteps,
)
tensor_info(res)

sampling loop time step:   0%|          | 0/250 [00:00<?, ?it/s]

sampling loop time step: 100%|██████████| 250/250 [00:53<00:00,  4.70it/s]


shape: torch.Size([16, 1, 128, 128]) 
type: torch.float32 
max: 1.0           
min: 0.47038328647613525 
mean: 0.6204744577407837 
std: 0.07712245732545853


In [28]:
# Results go to out/<measurment_type>/s<scale>; any '.' in the scale value is
# written as 'd' in the folder name.
scale_tag = 's' + str(scale).replace('.', 'd')
out_folder = os.path.join('out/', measurment_type, scale_tag)
print(out_folder)

# Save the measurement that was fed to the sampler.
sampler.save_tif(measurment, folder=out_folder, file_name='input.tif', make_grid=False)

# Save the samples; the routine depends on whether all timesteps were returned.
if return_all_timesteps:
    sampler.save_tif_with_records(res, folder=out_folder, step=10)
else:
    sampler.save_tif(res, folder=out_folder)

out/iso_sr_dl/s8


In [29]:
# Save the ground-truth patch next to the input and the samples.
# NOTE(review): in the cells visible here `gt_measurment` is never passed
# through `normalize`, whereas `measurment` is -- confirm that saving the
# ground truth on its original intensity scale is intended.
sampler.save_tif(gt_measurment, folder=out_folder, file_name='gt.tif', make_grid=False)

In [21]:
#