# Prepare colab environment

In [1]:
import os
import sys

IN_COLAB = 'google.colab' in sys.modules
print(f'Running in google colab: {IN_COLAB}')

Running in google colab: False


In [2]:
if IN_COLAB:
    import gdown
    # Prepare colab environment
    !git clone https://github.com/Sleepon1805/PepeGeneration.git
    print('Successfully cloned PepeGeneration repo.')

    os.chdir("/content/PepeGeneration")
    print(f'Changed working directory to {os.getcwd()}')

    !pip install --upgrade pip
    !pip install -r requirements_demo.txt
    print('Successfully installed all requirements.')

    # Download checkpoint from shared Drive file
    drive_link_id = '13byaG4vybYpgWvdYo9NScNm0bbKLxnQY'
    gdown.download_folder(id=drive_link_id, quiet=False)
    print('Successfully downloaded checkpoint folder.')

# Load checkpoint and config

In [3]:
import torch
import pickle
from pprint import pprint
from matplotlib import pyplot as plt

from model.pepe_generator import PepeGenerator

DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Using {DEVICE} device')

Using device cuda


In [4]:
# Specify path to model checkpoint and config here!!!
ckpt_path = '/home/sleepon/repos/PepeGenerator/lightning_logs/celeba/version_11/checkpoints/epoch=07-fid_metric=1.83-val_loss=0.0254.ckpt'
config_path = '/home/sleepon/repos/PepeGenerator/lightning_logs/celeba/version_11/config.pkl'

assert os.path.exists(ckpt_path), 'Did not found path with model checkpoint. Check that checkpoint is downloaded and path is correct.'
assert os.path.exists(config_path), 'Did not found path with model config. Check that config.pkl file is downloaded and path is correct.'
print('Specified paths with model checkpoint and config exist.')
print(f'Model checkpoint path: {ckpt_path}')
print(f'Model config path: {ckpt_path}')

Specified paths with model checkpoint and config exist.
Model checkpoint path: /home/sleepon/repos/PepeGenerator/lightning_logs/celeba/version_11/checkpoints/epoch=07-fid_metric=1.83-val_loss=0.0254.ckpt
Model config path: /home/sleepon/repos/PepeGenerator/lightning_logs/celeba/version_11/checkpoints/epoch=07-fid_metric=1.83-val_loss=0.0254.ckpt


In [5]:
# load config
with open(config_path, 'rb') as config_file:
    config = pickle.load(config_file)
pprint(config)

Config(git_hash='7616ce5+',
       sde_training=False,
       batch_size=64,
       image_size=64,
       lr=0.0001,
       scheduler='MultiStepLR',
       gradient_clip_algorithm='norm',
       gradient_clip_val=0.5,
       dataset_split=(0.8, 0.2),
       dataset_name='celeba',
       use_condition=False,
       condition_size=40,
       pretrained_ckpt='./lightning_logs/celeba/version_6/checkpoints/last.ckpt',
       diffusion_steps=1000,
       beta_min=0.0001,
       beta_max=0.02,
       init_channels=128,
       channel_mult=(1, 2, 4, 4),
       conv_resample=True,
       num_heads=1,
       dropout=0.3,
       use_second_attention=True)


In [6]:
# load checkpoint
model = PepeGenerator.load_from_checkpoint(ckpt_path, config=config, strict=True)
model.eval(), model.freeze(), model.to(DEVICE)
print('Loaded model')

Loaded model


# Setup Sampler

In [7]:
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn, BarColumn, TextColumn, TimeRemainingColumn, MofNCompleteColumn

from config import SamplingConfig
from SDE_sampling.sde_samplers import PC_Sampler, ODE_Sampler

def evaluate_model(sampling_config: SamplingConfig, grid_shape=(4, 4)):
    # rich progress bar
    progress = Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        TimeRemainingColumn()
    )

    # set up sampler
    if sampling_config.sampler.lower() in ('ddpm', 'default'):
        print('Using default DDPM Sampler as evaluation sampler.')
    elif sampling_config.sampler.lower() == 'ode_solver':
        print('Using ODE Solver as evaluation sampler.')
        model.sampler = ODE_Sampler(config, sampling_config)
    elif sampling_config.sampler.lower() == 'pc_sampler':
        print('Using PC Sampler as evaluation sampler.')
        model.sampler = PC_Sampler(config, sampling_config)
    else:
        raise ValueError
    model.sampler.to(DEVICE)

    # create fake batch (actually needed only for shapes)
    num_samples = grid_shape[0] * grid_shape[1]
    fake_image_batch = torch.zeros((num_samples, 3, config.image_size, config.image_size))
    fake_cond_batch = torch.ones(num_samples, config.condition_size)
    fake_batch = (fake_image_batch, fake_cond_batch)

    # generate images
    with progress:
        # [grid_shape[0] * grid_shape[1] x 3 x cfg.image_size x cfg.image_size]
        gen_samples = model.generate_samples(fake_batch, progress=progress)
    gen_images = model.sampler.generated_samples_to_images(gen_samples, grid_shape)

    # show generated images
    plt.imshow(gen_images)
    plt.show()

In [None]:
# Choose values you want to try out
sampling_config = SamplingConfig(
    sampler = 'pc_sampler',  # ddpm = default, pc_sampler, ode_solver
    sde_name = 'VPSDE',  # VPSDE, subVPSDE, VESDE
    beta_min = 0.1,  # VPSDE, subVPSDE param
    beta_max = 20.,  # VPSDE, subVPSDE param
    sigma_min = 0.01,  # VESDE param
    sigma_max = 50.,  # VESDE param
    num_scales = 1000,
    predictor_name = 'euler_maruyama',  # none, ancestral_sampling, reverse_diffusion, euler_maruyama
    corrector_name = 'langevin',  # none, langevin, ald
    snr = 0.01,
    num_corrector_steps = 1,
    probability_flow = False,
    denoise = False,
)

evaluate_model(sampling_config=sampling_config, grid_shape=(4, 4))

Output()