In [None]:
import torch
from torchvision import utils

import math
import os

from denoising_diffusion_pytorch import Unet, GaussianDiffusion, Sampler, tensor_info, histc

from matplotlib import pyplot as plt

import numpy as np

# Model

In [None]:
# Sampler configuration: DDIM sampling with 250 steps.
# Alternative previously used in this notebook: sampling_timesteps = 1000 with
# is_ddim_sampling = False (1000 equals `timesteps` passed to GaussianDiffusion below).
sampling_timesteps, is_ddim_sampling = 250, True

In [None]:
# Denoising network: U-Net with base width 64, one input channel and
# four resolution levels given by the channel multipliers in dim_mults.
model = Unet(dim=64, channels=1, dim_mults=(1, 2, 4, 8))

# Diffusion wrapper around the network; the sampling settings come from the
# configuration cell above.
diffusion = GaussianDiffusion(
    model,
    image_size=128,
    timesteps=1000,
    sampling_timesteps=sampling_timesteps,
    loss_type='l1',
    is_ddim_sampling=is_ddim_sampling,
)

# Project helper used below to load weights, sample and save results.
sampler = Sampler(diffusion)

In [None]:
# Checkpoint to restore: file 'model-<milestone>.pt' inside base_path.
# Change `milestone` to load a different saved checkpoint.
base_path = 'checkpoints/'
milestone = 15
model_path = os.path.join(base_path, f'model-{milestone}.pt')
print(f'model path: {model_path}')

sampler.load(path=model_path)

# Sample

In [None]:
# How many images to draw, and whether the sampler should return every
# intermediate timestep (the save cell branches on this flag).
num_samples, return_all_timesteps = 16, False

In [None]:
# Optional seed. Set it to an int to make re-runs of this cell reproducible
# (assuming the sampler draws its noise from torch's global RNG — confirm).
# None keeps the original, unseeded behavior.
seed = None
if seed is not None:
    # torch.manual_seed seeds the RNG on all devices (CPU and CUDA).
    torch.manual_seed(seed)

res = sampler.sample(num_samples=num_samples, return_all_timesteps=return_all_timesteps)
tensor_info(res)

# Save

In [None]:
# Root output folder for the saved samples (trailing slash kept as before).
save_folder = 'sample_out/'
print(save_folder)

In [None]:
# Write the samples to disk.
# - Final samples only: save_tif into save_folder, plus save_histc (shown inline).
# - All timesteps: save_tif_with_records into save_folder/record/
#   (step=10 presumably sub-samples the saved timesteps — confirm in Sampler).
# The target folders are created up front so a fresh checkout does not fail
# if the save helpers do not create them themselves.
if not return_all_timesteps:
    os.makedirs(save_folder, exist_ok=True)
    sampler.save_tif(res, folder=save_folder)
    sampler.save_histc(res, save_path=save_folder, if_show=True)
else:
    record_folder = os.path.join(save_folder, 'record/')
    os.makedirs(record_folder, exist_ok=True)
    sampler.save_tif_with_records(res, folder=record_folder, step=10)

In [None]:
#