In [None]:
from denoising_diffusion_pytorch import Unet, Sampler, DPS, DenoiseOperator, AnisotropicOperator, DLAnisotropicOperator, GaussialBlurOperator, tensor_info, histc

import torch
from torch.cuda.amp import autocast as autocast
from torchvision import utils

import numpy as np
import os

# Data

In [None]:
# Path to the CARE training archive; can be overridden via the CARE_LIVER_DATA
# environment variable instead of editing this absolute path.
data_path = os.environ.get(
    "CARE_LIVER_DATA",
    "/home/share/CARE/Isotropic_Liver/train_data/data_label.npz",
)

# np.load on an .npz returns a lazily-read NpzFile that holds the archive open;
# use it as a context manager so the file handle is released. Indexing a member
# reads that array fully into memory, so `lr_data` stays valid after closing.
with np.load(data_path) as data_file:
    # NOTE(review): 'Y' is presumably the target/label stack of the CARE archive
    # (despite the `lr_` prefix on the variable) — confirm against how it was built.
    lr_data = data_file['Y']

In [None]:
# Select the first sample while keeping a leading axis of length 1 (slice, not
# index), wrap it as a tensor without copying, and move it to the CUDA device.
first_sample = lr_data[:1]
one_data = torch.from_numpy(first_sample).cuda()
tensor_info(one_data)

# Model

In [None]:
# U-Net backbone configured for single-channel (channels=1) images.
unet_kwargs = dict(dim=64, channels=1, dim_mults=(1, 2, 4, 8))
model = Unet(**unet_kwargs)

# DPS diffusion wrapper around the backbone; settings are forwarded verbatim.
dps_kwargs = dict(image_size=128, timesteps=1000, sampling_timesteps=250)
diffusion = DPS(model, **dps_kwargs)

# Checkpoint to restore into the sampler (presumably the pretrained weights
# for this backbone — confirm it matches the configuration above).
checkpoint_path = "/home/share_ssd/ryuuyou/denoising-diffusion/pretrained_y/model_150k_steps_lr1e-5.pt"
sampler = Sampler(diffusion)
sampler.load(path=checkpoint_path)

# Degradation operator

In [None]:
# The operator is used below both to synthesise the measurement
# (operator.forward) and as the forward model passed to sampler.dps.
# Pick one by name rather than (un)commenting alternative constructors.
OPERATOR_NAME = "gaussian_blur"  # "gaussian_blur" | "anisotropic" | "denoise" | "dl_anisotropic"

# Factories are lambdas so only the selected operator is constructed (the
# DL variant loads a checkpoint from disk).
operator_factories = {
    "gaussian_blur": lambda: GaussialBlurOperator(img_shape=one_data.shape),
    "anisotropic": lambda: AnisotropicOperator(img_shape=one_data.shape, scale=(1, 3), noise_sigma=0.01),
    "denoise": lambda: DenoiseOperator(),
    "dl_anisotropic": lambda: DLAnisotropicOperator(
        'denoising_diffusion_pytorch/degradation_model/checkpoint/best.pth', noise_sigma=0.01
    ),
}
operator = operator_factories[OPERATOR_NAME]()

In [None]:
# Synthesise the measurement by degrading the clean sample with `operator`.
# Start again from lr_data[0:1] (the same slice the data cell selects) rather
# than from `one_data`, so re-running this cell does not degrade an
# already-degraded image a second time.
input_noise_sigma = 0.0  # set > 0 (e.g. 0.1) to add Gaussian noise before degradation

clean_data = torch.from_numpy(lr_data[0:1]).cuda()
if input_noise_sigma > 0:
    clean_data = clean_data + torch.randn_like(clean_data) * input_noise_sigma

one_data = operator.forward(clean_data)

# Min-max rescale to [0, 1]. Clamp the range so a constant image yields zeros
# instead of NaNs from 0/0.
data_min = one_data.min()
data_range = (one_data.max() - data_min).clamp_min(1e-8)
one_data = (one_data - data_min) / data_range

tensor_info(one_data)

# Inference

In [None]:
# Run DPS sampling conditioned on the degraded measurement `one_data`.
# NOTE(review): `scale` is presumably the DPS guidance step size — see Sampler.dps.
return_all_timesteps = False
scale = 5
num_samples = 16

# Sampling is stochastic; seed all devices so repeated runs start from the
# same random state.
seed = 0
torch.manual_seed(seed)

res = sampler.dps(measurement=one_data, operator=operator, num_samples=num_samples, scale=scale, return_all_timesteps=return_all_timesteps)
tensor_info(res)

In [None]:
# Write the measurement and the DPS result(s) as TIFF files under `out_folder`.
# Alternative location used previously: '/home/share_ssd/ryuuyou/dps_out'
out_folder = './dps_out/'

# Create output directories up front so a fresh checkout does not fail on a
# missing folder (harmless if the save helpers also create them).
os.makedirs(out_folder, exist_ok=True)
sampler.save_tif(one_data, path=os.path.join(out_folder, 'input.tif'))

if return_all_timesteps:
    record_folder = os.path.join(out_folder, 'record/')
    os.makedirs(record_folder, exist_ok=True)
    sampler.save_tif_with_records(res, folder=record_folder, step=1)
else:
    sampler.save_tif(res, path=os.path.join(out_folder, 'res.tif'))

In [None]:
#