In [None]:
from denoising_diffusion_pytorch import Unet, Sampler, DPS, DenoiseOperator, AnisotropicOperator, DLAnisotropicOperator, GaussialBlurOperator, tensor_info, histc

import torch
from torch.cuda.amp import autocast as autocast
from torchvision import utils

import numpy as np
import os

# Model

In [None]:
# Sampling schedule. Set `is_ddim_sampling` to True to switch to the
# 250-step DDIM configuration; otherwise 1000 sampling steps are used.
is_ddim_sampling = False
sampling_timesteps = 250 if is_ddim_sampling else 1000

In [None]:
# Backbone network: Unet with one input channel.
unet_kwargs = dict(
    dim=64,
    channels=1,
    dim_mults=(1, 2, 4, 8),
)
model = Unet(**unet_kwargs)

# DPS wrapper around the backbone. `sampling_timesteps` and
# `is_ddim_sampling` come from the sampling configuration cell.
diffusion_kwargs = dict(
    image_size=128,
    timesteps=1000,
    sampling_timesteps=sampling_timesteps,
    is_ddim_sampling=is_ddim_sampling,
)
diffusion = DPS(model, **diffusion_kwargs)

# Sampler object used below for checkpoint loading, sampling and saving.
sampler = Sampler(diffusion)

In [None]:
# Checkpoint layout: <base_path>/<model_type>/model-15.pt
# Change the index below to pick a different checkpoint variant.
# `model_type` keeps its trailing slash because it is reused for the output folder.
base_path = 'pretrained_y/'
model_types = ('no_normalize/', 'z_score/', 'min_max/')
model_type = model_types[0]

model_path = os.path.join(base_path, model_type, 'model-15.pt')
sampler.load(path=model_path)

# Data

In [None]:
# Read the array stored under key 'Y' from the .npz archive.
# np.load returns an NpzFile for .npz archives, which keeps the underlying
# file open until it is closed; the context manager releases the handle as
# soon as the array has been read into memory.
data_path = "/home/share/CARE/Isotropic_Liver/train_data/data_label.npz"
with np.load(data_path) as data_file:
    lr_data = data_file['Y']

In [None]:
# Index of the sample to process. Slicing with i:i+1 (instead of indexing
# with i) keeps the leading axis, so the result still has a length-1 first
# dimension.
i = 0
sample_np = lr_data[i:i + 1]
one_data = torch.from_numpy(sample_np).cuda()

In [None]:
# Optional z-score normalization of the selected sample (currently disabled):
# one_data = (one_data-one_data.mean())/one_data.std()
# Inspect the selected sample with the project helper `tensor_info`
# (presumably reports summary information about the tensor — confirm).
tensor_info(one_data)

# Operator

In [None]:
# Measurement operator used for sampling. Exactly one of the alternatives
# below should be active; the others are kept commented for quick switching.

# Alternative: anisotropic operator with additive noise.
# operator = AnisotropicOperator(img_shape=one_data.shape, scale=(1,3), noise_sigma=0.01)

# Active choice: blur operator sized to the selected sample.
# ("Gaussial" is the spelling of the imported class name.)
operator = GaussialBlurOperator(img_shape=one_data.shape)

# Alternative: denoising operator.
# operator = DenoiseOperator()

# Alternative: anisotropic operator backed by a saved degradation-model checkpoint.
# d_model_path = 'denoising_diffusion_pytorch/degradation_model/checkpoint/best.pth'
# operator = DLAnisotropicOperator(d_model_path, noise_sigma=0.01)

In [None]:
# Optional additive Gaussian noise on the sample (currently disabled):
# one_data = one_data+torch.randn_like(one_data, device=one_data.device)*0.1

# Min-max rescale the sample, then pass it through the operator to obtain
# the measurement handed to the sampler.
# NOTE: this cell overwrites `one_data`, so re-running it without first
# re-running the sample-selection cell rescales and degrades the already
# degraded tensor again.
data_min = one_data.min()
data_max = one_data.max()
one_data = operator.forward((one_data - data_min) / (data_max - data_min))

tensor_info(one_data)

# Inference

In [None]:
# Inference settings, all forwarded to `sampler.dps` / the save step below.
num_samples = 16              # passed as `num_samples` to sampler.dps
scale = 15                    # passed as `scale` to sampler.dps (presumably the guidance strength — confirm)
return_all_timesteps = False  # when True, the save step uses save_tif_with_records

# Results are written under ./dps_out/<model_type>.
out_folder = os.path.join('./dps_out/', model_type)

In [None]:
# Run DPS sampling on the measurement with the selected operator and the
# inference settings defined above, then inspect the result.
res = sampler.dps(
    measurement=one_data,
    operator=operator,
    num_samples=num_samples,
    scale=scale,
    return_all_timesteps=return_all_timesteps,
)
tensor_info(res)

In [None]:
# Save the measurement that was fed to the sampler.
# NOTE(review): a path ending in 'input.tif' is passed as `folder=` — confirm
# that save_tif treats it as intended.
input_target = os.path.join(out_folder, 'input.tif')
sampler.save_tif(one_data, folder=input_target)

# Save the sampling result: with per-step records (every 10th step) when all
# timesteps were returned, otherwise only the final output.
if return_all_timesteps:
    sampler.save_tif_with_records(res, folder=out_folder, step=10)
else:
    sampler.save_tif(res, folder=out_folder)

In [None]:
#