# Prepare Colab environment

In [1]:
import os
import sys

# The notebook is considered to run on Colab when the `google.colab` module
# has already been loaded into this kernel.
IN_COLAB = 'google.colab' in sys.modules
print('Running in google colab:', IN_COLAB)

Running in google colab: False


In [2]:
if IN_COLAB:
    import gdown
    # Prepare colab environment
    !git clone https://github.com/Sleepon1805/PepeGeneration.git
    print('Successfully cloned PepeGeneration repo.')

    os.chdir("/content/PepeGeneration")
    print(f'Changed working directory to {os.getcwd()}')

    !pip install --upgrade pip
    !pip install -r requirements_demo.txt
    print('Successfully installed all requirements.')

    # Download checkpoint from shared Drive file
    drive_link_id = '13byaG4vybYpgWvdYo9NScNm0bbKLxnQY'
    gdown.download_folder(id=drive_link_id, quiet=False)
    print('Successfully downloaded folder with checkpoint.')

# Specify model checkpoint

In [4]:
# Specify path to model checkpoint here!!!
# The default below is a path on the author's local machine; on Colab, point it at
# the .ckpt file inside the folder downloaded from Google Drive.
#####################################################
ckpt_path = '/home/sleepon/repos/PepeGenerator/lightning_logs/celeba/version_11/checkpoints/epoch=07-fid_metric=1.83-val_loss=0.0254.ckpt'
#####################################################

# Use isfile rather than exists: a directory (e.g. the downloaded checkpoint folder
# itself) must be rejected here instead of failing later inside inference.
# An explicit raise also keeps the check active even when asserts are disabled.
if not os.path.isfile(ckpt_path):
    raise FileNotFoundError(
        f'Did not find model checkpoint file at {ckpt_path!r}. '
        'Check that the checkpoint is downloaded and the path is correct.'
    )
print('Specified path with model checkpoint exists.')
print(f'Model checkpoint path: {ckpt_path}')

Specified path with model checkpoint exists.
Model checkpoint path: /home/sleepon/repos/PepeGenerator/lightning_logs/celeba/version_11/checkpoints/epoch=07-fid_metric=1.83-val_loss=0.0254.ckpt


# Set up samplers

In [None]:
from config import DDPMSamplingConfig, PCSamplingConfig, ODESamplingConfig

# --- Default DDPM sampler ----------------------------------------------------
ddpm_sampling_config = DDPMSamplingConfig(
    # DDPM noise schedule and number of diffusion steps
    beta_min=1e-4,
    beta_max=0.02,
    diffusion_steps=1000,
)

# --- Predictor-Corrector SDE sampler ----------------------------------------
pc_sampling_config = PCSamplingConfig(
    # SDE parameters
    sde_name='VPSDE',         # one of VPSDE, subVPSDE, VESDE
    num_scales=1000,          # number of discretization timesteps
    beta_min=0.1,             # used by VPSDE / subVPSDE
    beta_max=20.0,            # used by VPSDE / subVPSDE
    sigma_min=0.01,           # used by VESDE
    sigma_max=50.0,           # used by VESDE
    # predictor: none, ancestral_sampling, reverse_diffusion, euler_maruyama
    predictor_name='euler_maruyama',
    # corrector: none, langevin, ald
    corrector_name='langevin',
    snr=0.01,                 # signal-to-noise ratio
    num_corrector_steps=1,
    # sampler options
    probability_flow=False,
    denoise=False,
)

# --- ODE solver sampler ------------------------------------------------------
ode_sampling_config = ODESamplingConfig(
    # SDE parameters
    sde_name='VPSDE',         # VPSDE, subVPSDE, VESDE
    num_scales=1000,          # number of discretization timesteps
    beta_min=0.1,             # used by VPSDE / subVPSDE
    beta_max=20.0,            # used by VPSDE / subVPSDE
    sigma_min=0.01,           # used by VESDE
    sigma_max=50.0,           # used by VESDE
    # ODE solver options
    method='RK45',
    rtol=1e-5,
    atol=1e-5,
    # sampler options
    denoise=False,
)

# Generate images

In [None]:
import torch

from evaluate import inference

sampling_cfg = None  # None, ddpm_sampling_config, pc_sampling_config, ode_sampling_config

# Request the GPU only when a CUDA device is actually visible (e.g. a Colab runtime
# without a GPU accelerator has none); otherwise fall back to the slower CPU path.
use_gpu = torch.cuda.is_available()
print(f'Running inference on {"GPU" if use_gpu else "CPU (slower)"}.')

inference(
    ckpt_path,
    sampling_config=sampling_cfg,
    grid_shape=(3, 3),
    save_images=False,  # saves generated images in the checkpoint folder (and overwrites previous ones)
    on_gpu=use_gpu,  # GPU when available, otherwise CPU
)