In [None]:
import os
import time
import torch
import math
import yaml

import numpy as np
from torchvision.utils import make_grid, save_image
from utils import load_ckpt, linear_schedule

from default_config import get_cfg_defaults

from model import GNM
# Let cuDNN benchmark and pick the fastest convolution algorithms for the input
# shapes it sees; per the PyTorch reproducibility notes this can make results
# vary slightly between runs.
torch.backends.cudnn.benchmark=True

In [None]:
# Experiment to sample from. The name selects both the config file
# (./config/<name>.yaml) and the checkpoint (./pretrained/<name>.pth) used
# in the cells below. Other experiments available here: 'arrow', 'mnist-4'.
exp_name = 'mnist-10'

In [None]:
def get_config(name=None):
    """Build the experiment config: project defaults overlaid with a YAML file.

    Args:
        name: experiment name; the file ``./config/<name>.yaml`` is merged into
            the defaults. When omitted, the notebook-level ``exp_name`` is used,
            so the existing call ``get_config()`` behaves exactly as before.

    Returns:
        The object returned by ``get_cfg_defaults()`` after ``merge_from_file``
        has been applied to it.
    """
    # Resolve the global lazily (at call time) so changing `exp_name` in an
    # earlier cell is still honoured, matching the original behaviour.
    if name is None:
        name = exp_name
    cfg = get_cfg_defaults()
    cfg.merge_from_file(f'./config/{name}.yaml')
    return cfg

In [None]:
# Merged experiment config (defaults + ./config/<exp_name>.yaml); used below to
# build the model and to hold the training/annealing and logging settings.
args = get_config()

In [None]:
# Run on the GPU when PyTorch can see one, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Build the model from the merged config and move it to `device`.
model = GNM(args)
model.to(device)

# Restore pretrained weights from ./pretrained/<exp_name>.pth.
# NOTE(review): the `None` argument presumably stands in for an optimizer that
# sampling does not need — confirm against `load_ckpt` in utils.
# The two return values are stored as the global step and the start epoch.
global_step, args.train.start_epoch = load_ckpt(
    model, None, f"./pretrained/{exp_name}.pth", device
)

In [None]:
# Hyperparameters annealed by `hyperparam_anneal`, in the order they are set.
# For each name `<n>`, the schedule is read from
# `args.train.<n>_anneal_{start_step,end_step,start_value,end_value}` and the
# resulting value is written to `args.train.<n>`.
_ANNEALED_HYPERPARAMS = (
    # `beta_aux_*` terms
    'beta_aux_pres',
    'beta_aux_where',
    'beta_aux_what',
    'beta_aux_depth',
    'beta_aux_global',
    'beta_aux_bg',
    # remaining `beta_*` terms and the presence temperature `tau_pres`
    'beta_pres',
    'beta_where',
    'beta_what',
    'beta_depth',
    'beta_global',
    'tau_pres',
    'beta_bg',
)


def _anneal_one(train_cfg, name, global_step):
    """Return the value of hyperparameter `name` at `global_step`.

    If `<name>_anneal_end_step` is 0 the value is held constant at
    `<name>_anneal_start_value`; otherwise it is
    `linear_schedule(global_step, start_step, end_step, start_value, end_value)`.
    """
    prefix = f'{name}_anneal_'
    end_step = getattr(train_cfg, prefix + 'end_step')
    if end_step == 0:
        return getattr(train_cfg, prefix + 'start_value')
    return linear_schedule(
        global_step,
        getattr(train_cfg, prefix + 'start_step'),
        end_step,
        getattr(train_cfg, prefix + 'start_value'),
        getattr(train_cfg, prefix + 'end_value'),
    )


def hyperparam_anneal(args, global_step):
    """Set every annealed hyperparameter on `args.train` for `global_step`.

    Each entry of `_ANNEALED_HYPERPARAMS` is updated in order, replacing the
    previous thirteen copy-pasted if/else branches with one loop.

    Args:
        args: config object whose `train` attribute holds the
            `<name>_anneal_*` schedule fields; mutated in place.
        global_step: step at which to evaluate the schedules.

    Returns:
        None. Raises AttributeError if a required schedule field is missing.
    """
    train_cfg = args.train
    for name in _ANNEALED_HYPERPARAMS:
        setattr(train_cfg, name, _anneal_one(train_cfg, name, global_step))

In [None]:
# Store the step restored from the checkpoint on the config, then set every
# annealed hyperparameter (beta_*, tau_pres) to its value at that step —
# presumably so sampling uses the schedule state the checkpoint reached.
args.train.global_step = global_step
hyperparam_anneal(args, global_step)
# Disable phase logging for the sampling below.
args.log.phase_log = False

In [None]:
# Output directory for the generated sample grids; created if missing.
# (exist_ok=True still raises if the path exists but is not a directory,
# exactly like the previous isdir/makedirs pair.)
result_dir = f'./generations-new/{exp_name}/'
os.makedirs(result_dir, exist_ok=True)

NUM_SAMPLE_GRIDS = 5  # number of PNG files written (0.png ... 4.png)
GRID_NROW = 5         # images per row in each saved grid

model.eval()
with torch.no_grad():
    for i in range(NUM_SAMPLE_GRIDS):
        # NOTE(review): assumes `model.sample(...)[0]` is the generated batch and
        # `sample[0]` the image tensor to save — confirm against GNM.sample.
        sample = model.sample(phase_use_mode=True)[0]
        # `nrow` must be passed by keyword: recent torchvision declares
        # `save_image(tensor, fp, format=None, **kwargs)`, so a positional 5
        # would be taken as `format` and break the PIL save.
        save_image(
            sample[0].cpu().clamp(0, 1),
            os.path.join(result_dir, f'{i}.png'),
            nrow=GRID_NROW,
            normalize=False,
            pad_value=1,
        )