In [1]:
import torch
from torchvision import utils
from torch.cuda.amp import autocast

import math
import os

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Sampler, tensor_info, histc

from matplotlib import pyplot as plt

import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Denoising U-Net backbone; configured for 1-channel images
# (the sampled tensors below have a channel axis of size 1).
model = Unet(
    dim=64,                  # base width passed to Unet
    channels=1,              # number of image channels
    dim_mults=(1, 2, 4, 8),  # per-stage multipliers applied to `dim`
)


In [3]:
# Wrap the U-Net in the Gaussian diffusion process: 1000 diffusion timesteps,
# with sampling reduced to 250 steps (the sampling loops below run 250 iterations).
diffusion = GaussianDiffusion(
    model,
    image_size=128,
    timesteps=1000,
    sampling_timesteps=250,
    loss_type='l1',
)

In [4]:
# Wrap the diffusion model in the project's Sampler; later cells use its
# load / sample / save_tif / save_tif_with_records helpers.
sampler = Sampler(diffusion)

In [5]:
# Restore trained weights into the sampler. The path is relative to the
# notebook's working directory; on load the checkpoint reports step 150000.
path = 'pretrained_y/model_150k_steps_lr1e-5.pt'
sampler.load(path=path)

loading from: [version]:1.5.4; [step]:150000


In [6]:
# Sampling / output configuration.
save_folder = 'sample_out'   # directory the save cells write TIFFs into
num_samples = 16             # batch size passed to sampler.sample
return_all_timesteps = True  # return the whole sampling trajectory, not only final images

# Create the output directory up front: the save cells write into it (and into
# subfolders of it), and save_tif / save_tif_with_records may not create missing
# parent directories on a fresh checkout. exist_ok keeps this cell idempotent.
os.makedirs(save_folder, exist_ok=True)

In [7]:
# Full-precision sampling. Seed the global torch RNG (all devices) first so the
# drawn noise is reproducible across Restart & Run All. Use the same seed for
# any run that is meant to be compared against this one, e.g. the autocast run;
# otherwise the differences mix precision effects with different random draws.
torch.manual_seed(0)
res = sampler.sample(num_samples=num_samples, return_all_timesteps=return_all_timesteps)
tensor_info(res)

sampling loop time step: 100%|██████████| 250/250 [00:19<00:00, 12.83it/s]


shape: torch.Size([251, 16, 1, 128, 128]) 
type: torch.float32 
max: 2.910015106201172 
min: -1.734142780303955 
mean: 0.30561894178390503


In [8]:
# Persist the fp32 samples: a record folder when the full trajectory was
# returned, otherwise a single TIFF of the final images.
# Histogram output (disabled): histc(res, os.path.join(save_folder, 'res_histc.png'))
if return_all_timesteps:
    sampler.save_tif_with_records(
        res,
        folder=os.path.join(save_folder, 'record/'),
        step=1,
    )
else:
    sampler.save_tif(res, path=os.path.join(save_folder, 'res.tif'))

In [9]:
# Same sampling under CUDA mixed precision. Seed 0 should match the seed of any
# full-precision run this is compared against, so both start from identical
# noise and the differences are not just different random draws.
# torch.autocast(device_type='cuda') replaces the deprecated
# torch.cuda.amp.autocast() with the same CUDA autocast behavior.
torch.manual_seed(0)
with torch.autocast(device_type='cuda'):
    res_amp = sampler.sample(num_samples=num_samples, return_all_timesteps=return_all_timesteps)
tensor_info(res_amp)

sampling loop time step: 100%|██████████| 250/250 [00:14<00:00, 17.09it/s]

shape: torch.Size([251, 16, 1, 128, 128]) 
type: torch.float32 
max: 2.743675470352173 
min: -1.8430755138397217 
mean: 0.4030703902244568





In [10]:
# Persist the autocast samples alongside the fp32 ones, using separate
# names ('record_amp/', 'res_amp.tif') so the two runs don't overwrite each other.
# Histogram output (disabled): histc(res_amp, os.path.join(save_folder, 'res_amp_histc.png'))
if return_all_timesteps:
    sampler.save_tif_with_records(
        res_amp,
        folder=os.path.join(save_folder, 'record_amp/'),
        step=1,
    )
else:
    sampler.save_tif(res_amp, path=os.path.join(save_folder, 'res_amp.tif'))

In [11]:
#