In [1]:
%load_ext autoreload
%autoreload 2

import qm9.visualizer as vis
from eval_sample import *
import argparse
import warnings

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# --- Evaluation options -------------------------------------------------------
# parse_known_args (not parse_args) so unrecognised entries on sys.argv, such as
# the flags a notebook kernel may be launched with, do not abort the cell; they
# are collected in `unparsed_args` and otherwise ignored.
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str,
                    default="outputs/edm_qm9",
                    help='Specify model path')
parser.add_argument(
    '--n_tries', type=int, default=10,
    help='N tries to find stable molecule for gif animation')
parser.add_argument('--n_nodes', type=int, default=19,
                    help='number of atoms in molecule for gif animation')

eval_args, unparsed_args = parser.parse_known_args()

assert eval_args.model_path is not None

# --- Training-time arguments --------------------------------------------------
# The run directory holds a pickled `args` object from training. Build the path
# once so the path that is printed is exactly the path that is opened.
args_path = join(eval_args.model_path, 'args.pickle')
print(args_path)

# SECURITY NOTE(review): pickle.load can execute arbitrary code embedded in the
# file. Only point --model_path at run directories you created or trust.
with open(args_path, 'rb') as f:
    args = pickle.load(f)

# CAREFUL with this -->
# Backfill options that older pickles may lack, so later code can read them.
# NOTE(review): presumably consumed inside get_model — confirm.
if not hasattr(args, 'normalization_factor'):
    args.normalization_factor = 1
if not hasattr(args, 'aggregation_method'):
    args.aggregation_method = 'sum'

# --- Device -------------------------------------------------------------------
# Use CUDA only when the training config did not disable it AND a GPU is
# available in this kernel.
args.cuda = not args.no_cuda and torch.cuda.is_available()
device = torch.device("cuda" if args.cuda else "cpu")
args.device = device
# NOTE(review): `dtype` is not referenced by the cells shown in this notebook.
dtype = torch.float32
# Presumably creates the run's output folders — confirm in utils.
utils.create_folders(args)
print(args)

# --- Dataset ------------------------------------------------------------------
dataset_info = get_dataset_info(args.dataset, args.remove_h)

print(dataset_info)

# Only the 'train' loader is used below (passed to get_model).
dataloaders, charge_scale = dataset.retrieve_dataloaders(args)

# --- Model --------------------------------------------------------------------
flow, nodes_dist, prop_dist = get_model(
    args, device, dataset_info, dataloaders['train'])
flow.to(device)

# Load the '_ema' checkpoint when the run was configured with ema_decay > 0,
# otherwise the plain checkpoint; tensors are mapped onto the selected device.
fn = 'generative_model_ema.npy' if args.ema_decay > 0 else 'generative_model.npy'
flow_state_dict = torch.load(join(eval_args.model_path, fn),
                             map_location=device)

flow.load_state_dict(flow_state_dict)

In [None]:
# Sampling configuration for this cell.
# Set SAMPLE_ALL_SIZES = True to additionally run sample_different_sizes_and_save
# (disabled by default). Previously this call was commented out while its
# progress message was still printed, which made the output misleading.
SAMPLE_ALL_SIZES = False
# Passed as n_samples / n_tries to sample_only_stable_different_sizes_and_save;
# presumably the number of stable molecules wanted and the attempt budget —
# confirm in eval_sample.
N_STABLE_SAMPLES = 3
N_STABLE_TRIES = 2 * 10

if SAMPLE_ALL_SIZES:
    print('Sampling handful of molecules.')
    sample_different_sizes_and_save(
        args, eval_args, device, flow, nodes_dist,
        dataset_info=dataset_info, n_samples=29)

print('Sampling stable molecules.')
sample_only_stable_different_sizes_and_save(
    args, eval_args, device, flow, nodes_dist,
    dataset_info=dataset_info, n_samples=N_STABLE_SAMPLES,
    n_tries=N_STABLE_TRIES)

In [40]:
# Visualize up to 100 molecules from the run's eval/molecules/ directory with
# spheres_3d enabled. NOTE(review): presumably the files written by the
# stable-sampling cell — confirm against eval_sample.
molecules_dir = join(eval_args.model_path, 'eval/molecules/')
vis.visualize(molecules_dir, dataset_info, spheres_3d=True, max_num=100)

In [None]:
print('Sampling visualization chain.')
# Chain size and retry budget come from the --n_nodes / --n_tries CLI options.
chain_kwargs = dict(
    n_tries=eval_args.n_tries,
    n_nodes=eval_args.n_nodes,
    dataset_info=dataset_info,
)
save_and_sample_chain(args, eval_args, device, flow, **chain_kwargs)