In [1]:
from DPS import DDIM, DDPM, GaussianNoise, InverseProblemOperator, AnisotropicOperator, DenoiseOperator

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Sampler

import torch
import numpy as np
from torch.cuda.amp import autocast as autocast

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import os
# Process-wide CUDA debugging switches; both variables are set to the string '1'.
# NOTE(review): presumably left over from chasing a device-side error — confirm
# they are still wanted before timing anything in this notebook.
os.environ.update({
    'CUDA_LAUNCH_BLOCKING': '1',
    'TORCH_USE_CUDA_DSA': '1',
})

# model HR

In [3]:
# U-Net used as the HR model: base width 64, one image channel, and four width
# multipliers. Its pretrained weights are loaded from a checkpoint further down.
model_hr = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1)

In [4]:
# Restore pretrained weights for model_hr. The file is consumed directly by
# load_state_dict, i.e. it is expected to hold a plain state dict.
# map_location='cpu' deserialises the tensors on the CPU no matter which device
# they were saved from; load_state_dict then copies them into model_hr's parameters.
# NOTE(review): torch.load unpickles the file — only point it at trusted checkpoints.
hr_checkpoint_path = "/home/share_ssd/ryuuyou/denoising-diffusion/unet_checkpoints/y_150k.pt"
hr_state_dict = torch.load(hr_checkpoint_path, map_location='cpu')
# Kept as the last expression so the cell displays the key-matching report.
model_hr.load_state_dict(hr_state_dict)

<All keys matched successfully>

# model LR

In [5]:
# Second U-Net, built with the same hyper-parameters as model_hr.
# No checkpoint is loaded into it in this cell.
model_lr = Unet(dim=64, dim_mults=(1, 2, 4, 8), channels=1)

In [6]:
# Diffusion wrapper for the LR model.
# Fix: wrap model_lr rather than model_hr. This wrapper is later handed to
# InverseProblemOperator together with the LR checkpoint path; building it around
# model_hr made the HR and LR pipelines share a single network (so anything loaded
# into the wrapper would presumably end up in model_hr too) and left model_lr unused.
diffusion_lr = GaussianDiffusion(
    model_lr,
    image_size=128,
    timesteps=1000,
    sampling_timesteps=250,
    loss_type='l1',
)

# Data

In [7]:
# Read the 'X' array from the .npz archive; it is used below as the LR data.
# np.load on an .npz returns an archive object that keeps the underlying file
# open, so it is used as a context manager to make sure the handle is released.
# Indexing the archive reads the array into memory, so lr_data stays valid after
# the archive has been closed.
data_path = "/home/share/CARE/Isotropic_Liver/train_data/data_label.npz"
with np.load(data_path) as data_file:
    lr_data = data_file['X']

In [8]:
# First 16 entries of lr_data, wrapped as a tensor and moved to the current CUDA device.
one_batch_data = torch.from_numpy(lr_data[:16]).to('cuda')

In [9]:
# Sanity check: display the dimensions of the batch.
one_batch_data.shape

torch.Size([16, 1, 128, 128])

# Inference

In [10]:
# Noise model constructed with sigma=0.05; it is applied to the measurement y
# below and is also handed to the sampler.
noiser = GaussianNoise(sigma=0.05)

In [11]:
lr_checkpoint_path = "pretrained_x/model_150k_steps_lr1e-5.pt"
# NOTE(review): this InverseProblemOperator is constructed (with whatever side
# effects its constructor has, e.g. reading lr_checkpoint_path) and then
# discarded, because `operator` is rebound to a DenoiseOperator a few lines down.
# Confirm whether this construction is still needed.
operator = InverseProblemOperator(
    diffusion_model=diffusion_lr,
    checkpoint_path=lr_checkpoint_path
)
# operator = AnisotropicOperator()
# Active choice: this is the operator that is handed to the sampler below.
operator = DenoiseOperator(device=one_batch_data.device)

loading from version 1.5.4


In [12]:
# DDIM sampler built around model_hr, with the operator and noise model chosen above.
sampler = DDIM(
    model_hr=model_hr,
    operator=operator,
    noiser=noiser,
    timesteps=1000,
    ddim_steps=250,
    beta_schedule='sigmoid',
)
# Alternative kept for reference (DDPM sampler):
# sampler = DDPM(model_hr=model_hr, operator=operator, noiser=noiser,
#                timesteps=900,
#                beta_schedule='sigmoid')

In [13]:
# Measurement: push the clean batch through the selected operator.
# Restored from a commented-out line: the following cells read `y`, so with this
# line disabled the notebook fails with NameError on a fresh kernel.
# NOTE(review): t=999 is kept exactly as it was written; confirm that the active
# operator's forward() accepts it.
y = operator.forward(one_batch_data, t=999)

In [14]:
# Pass the measurement through the noise model; y_n is what the sampler
# receives as `measurement`.
y_n = noiser.forward(y)

In [15]:
# Sanity check: display the dtype of the measurement.
y.dtype

torch.float32

In [16]:
# Starting point for sampling: standard-normal noise with the same shape and on
# the same device as the data batch.
# It is drawn from a dedicated, explicitly seeded generator so that x_start is
# identical on every run; the global RNG state is left untouched.
X_START_SEED = 0
x_start_rng = torch.Generator(device=one_batch_data.device).manual_seed(X_START_SEED)
x_start = torch.randn(one_batch_data.shape, generator=x_start_rng, device=one_batch_data.device)

In [17]:
# Run the sampling loop under CUDA autocast, starting from x_start and
# conditioning on the noisy measurement y_n.
with autocast():
    res = sampler.p_sample_loop(
        x_start=x_start,
        measurement=y_n,
        record=False,
        save_root='./out/',
    )

100%|██████████| 250/250 [00:38<00:00,  6.50it/s]


In [18]:
from torchvision import utils

In [19]:
# Write the sampled batch to ./res1.png, laid out with nrow=4.
utils.save_image(res, './res1.png', nrow=4)

In [20]:
# Write the input batch to ./res2.png with the same layout, for side-by-side
# comparison with the sampled result.
utils.save_image(one_batch_data, './res2.png', nrow=4)