In [None]:
import yaml
import os
import torch

from models.unet import Unet
from models.vae import VAE
from utils.scheduler import LinearNoiseSchedule
from utils.sample_ddpm import sample

# --- Configuration files -------------------------------------------------
# YAML files read by the config-loading cell below.
DDPM_CONFIG = "configs/ddpm.yaml"
VAE_CONFIG = "configs/vae.yaml"

# --- Pretrained checkpoints ----------------------------------------------
# Paths handed to torch.load() when the models are built further down.
# They are intentionally blank placeholders: fill them in before running.
DDPM_PATH = ""
VAE_PATH = ""

# --- Compute device ------------------------------------------------------
# Default to the CPU and switch to the GPU only when CUDA is usable.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

In [None]:
def _load_yaml_config(path):
    """Parse the YAML file at ``path`` and return its top-level mapping.

    Errors are deliberately NOT swallowed: a missing file (``OSError``) or a
    malformed document (``yaml.YAMLError``) propagates to the caller, so a
    broken config stops the notebook at the point of failure instead of
    surfacing later as an unrelated ``NameError`` or, worse, as values
    silently taken from a previously loaded config.

    Args:
        path: Filesystem path of the YAML file to read.

    Returns:
        The parsed document as a ``dict``.

    Raises:
        OSError: if the file cannot be opened.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the document's top level is not a mapping (for
            example, an empty file parses to ``None``).
    """
    with open(path, 'r') as stream:
        loaded = yaml.safe_load(stream)
    if not isinstance(loaded, dict):
        raise ValueError(
            f"{path}: expected a top-level mapping, got {type(loaded).__name__}"
        )
    return loaded


# Each file gets its own variable so that one config can never be mistaken
# for the other if a load fails.
ddpm_config = _load_yaml_config(DDPM_CONFIG)
ddpm_model_config = ddpm_config['model_config']
ddpm_dataset_config = ddpm_config['dataset_config']
ddpm_training_config = ddpm_config['training_config']
ddpm_inference_config = ddpm_config['inference_config']

vae_config = _load_yaml_config(VAE_CONFIG)
vae_model_config = vae_config['model_config']

In [None]:
# Create the noise scheduler
scheduler = LinearNoiseSchedule(num_timesteps=ddpm_training_config['NUM_TIMESTEPS'])

# Build the UNet. Its channel count is taken from the VAE's Z_CHANNELS --
# presumably because the UNet operates on VAE latents; confirm against
# models/unet.py.
model = Unet(im_channels=vae_model_config['Z_CHANNELS'], model_config=ddpm_model_config).to(DEVICE)
model.eval()

# Load the pretrained DDPM weights if the checkpoint file exists.
# NOTE: torch.load() unpickles the file, so only point DDPM_PATH / VAE_PATH
# at checkpoints you trust.
if os.path.exists(DDPM_PATH):
    model.load_state_dict(torch.load(DDPM_PATH, map_location=DEVICE))
    # Report success only after the weights have actually been loaded.
    print('Loaded unet checkpoint')
else:
    # Sampling still proceeds, but say so rather than silently using a
    # model whose weights were never loaded.
    print(f'No unet checkpoint found at {DDPM_PATH!r}; the model keeps its initial weights')

# Create output directories
# NOTE(review): the original called os.path.exists() / os.mkdir() with no
# argument, which raises TypeError. Where sample() writes its results is not
# visible from this notebook, so the directory name below is a placeholder --
# point it at the location utils.sample_ddpm.sample actually uses.
OUTPUT_DIR = "samples"
os.makedirs(OUTPUT_DIR, exist_ok=True)  # no-op when the directory already exists

vae = VAE(model_config=vae_model_config).to(DEVICE)
vae.eval()

# Load vae if found
if os.path.exists(VAE_PATH):
    vae.load_state_dict(torch.load(VAE_PATH, map_location=DEVICE), strict=True)
    print('Loaded vae checkpoint')
else:
    print(f'No vae checkpoint found at {VAE_PATH!r}; the VAE keeps its initial weights')

In [None]:
# Pull the sampling parameters out of the inference config up front so the
# call below reads cleanly.
n_samples = ddpm_inference_config['NUM_SAMPLES']
n_rows = ddpm_inference_config['NROWS']

# Gradients are not needed for generation, so run the sampler with autograd
# disabled.
with torch.no_grad():
    sample(model, vae, scheduler, n_samples, n_rows)