In [None]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '1'
import sys
sys.path.append('../')
import tqdm
import json
import pickle
import numpy as np
import torch
import dnnlib
from pathlib import Path
from torch_utils import distributed as dist
from torch_utils.misc import modify_network_pkl
from training.sampler import StackedRandomGenerator, samplers_to_kwargs
from training.structure import Structure, StructuredDataBatch
from training.networks.egnn import EGNNMultiHeadJump
from training.loss import JumpLossFinalDim
import matplotlib.pyplot as plt
import datetime
import yaml
from training.dataset.qm9 import plot_data3d
from training.dataset import datasets_to_kwargs
import time

# ---- Configuration ---------------------------------------------------------
# Directory containing the pretrained checkpoint and the `training_options.json`
# that it was trained with.
model_path = Path('../models/unconditional')
device = 'cuda'
sampler_class = 'JumpSampler'

# Candidate options for the sampler. Only the keys that `sampler_class` lists in
# `samplers_to_kwargs` are forwarded when the sampler is constructed.
# NOTE(review): semantics of the dt_schedule_* / corrector_* knobs live in
# training.sampler -- confirm there before tuning.
sampler_kwargs = dict(
    dt=0.001,
    corrector_steps=3,
    corrector_snr=0.1,
    corrector_start_time=1,
    corrector_finish_time=0.003,
    do_conditioning=True,
    condition_type='sweep',
    condition_sweep_idx=9,  # selects which conditioning task to run (0-9)
    condition_sweep_path='../data/mol_conditions/mol_conds.npy',
    guidance_weight=1.0,
    do_jump_corrector=True,
    sample_near_atom=True,
    dt_schedule='uniform',
    dt_schedule_h=0.05,
    dt_schedule_l=0.001,
    dt_schedule_tc=0.5,
    no_noise_final_step=True,
)

# Reload the options recorded at training time; they drive how the dataset,
# structure, network and loss are rebuilt below.
with (model_path / 'training_options.json').open("r") as stream:
    c = dnnlib.util.EasyDict(json.load(stream))

def convert_inner_dicts_to_easydicts(input_dict):
    """Recursively wrap a mapping and every nested mapping in ``dnnlib.util.EasyDict``.

    Nested dicts are converted wherever they appear, including inside lists,
    so attribute access (e.g. ``c.dataset_kwargs.class_name``) works at every
    level of a JSON-loaded configuration.

    Args:
        input_dict: the mapping to convert. It is NOT modified; a new
            structure is built and returned.

    Returns:
        A ``dnnlib.util.EasyDict`` with the same keys, whose nested dict values
        (and dicts found inside list values) are EasyDicts as well.
    """
    def _convert(value):
        # isinstance (not ``type(...) == dict``) so dict subclasses -- e.g. an
        # already-built EasyDict -- are also walked and their children converted.
        if isinstance(value, dict):
            return dnnlib.util.EasyDict({key: _convert(item) for key, item in value.items()})
        # JSON arrays may contain objects; convert those too.
        if isinstance(value, list):
            return [_convert(item) for item in value]
        return value

    return _convert(input_dict)

c = convert_inner_dicts_to_easydicts(c)

# Rebuild dataset -> structure -> network from the recorded training options.
dataset_obj = dnnlib.util.construct_class_by_name(**c.dataset_kwargs, train_or_valid='valid')  # subclass of training.dataset.Dataset

structure = Structure(**c.structure_kwargs, dataset=dataset_obj)
net = dnnlib.util.construct_class_by_name(**c.network_kwargs, structure=structure) # subclass of torch.nn.Module

# map_location='cpu': without it torch.load restores each tensor on the device
# it was saved from, which fails if that GPU index is not visible here
# (CUDA_VISIBLE_DEVICES is remapped at the top of the notebook). The network is
# still on CPU at this point and is moved to `device` right after.
# NOTE(review): torch.load unpickles -- only load checkpoints from trusted sources.
net.load_state_dict(torch.load(model_path / 'state_dict_unconditional.pt', map_location='cpu'))
net = net.eval().requires_grad_(False).to(device)

# Setup sampler.
# Forward only the options that `sampler_class` declares in `samplers_to_kwargs`
# (the first element of each entry is the keyword name), then instantiate the
# sampler from its dotted class path.
sampler_class_name = f'training.sampler.{sampler_class}'
usable_sampler_kwargs = dnnlib.EasyDict(class_name=sampler_class_name)
usable_sampler_kwargs.update(
    (kwarg_name, sampler_kwargs[kwarg_name])
    for kwarg_name, _, _ in samplers_to_kwargs[sampler_class]
)
sampler = dnnlib.util.construct_class_by_name(**usable_sampler_kwargs, structure=structure)

# Infer the task from the dataset and refuse anything this notebook doesn't handle.
# Use the last component of the dotted class path so that any module prefix
# (e.g. 'training.dataset.QM9Dataset') resolves to the bare class name.
SUPPORTED_DATASETS = ('QM9Dataset',)
dataset_class_name = c.dataset_kwargs['class_name'].rsplit('.', 1)[-1]
if dataset_class_name not in SUPPORTED_DATASETS:
    raise ValueError(f"Unknown dataset: {c.dataset_kwargs['class_name']!r}")

# The loss class is fixed to JumpLossFinalDim here, so the recorded 'class_name'
# is dropped from a *copy* of the kwargs; `c` itself is left untouched.
loss_kwargs = {key: value for key, value in c.loss_kwargs.items() if key != 'class_name'}
loss = JumpLossFinalDim(**loss_kwargs, structure=structure)


In [None]:
batch_size = 32
# Generator seeded with 0..batch_size-1 (presumably one stream per sample --
# see StackedRandomGenerator); it is used both to pick dataset items and by the sampler.
seeds = torch.arange(batch_size)
rnd = StackedRandomGenerator(device, seeds)

# Pick `batch_size` validation examples (no augmentation) as the input batch.
indices = rnd.randint(len(dataset_obj), size=[batch_size, 1], device=device)
unstacked_data = [dataset_obj.__getitem__(index.item(), will_augment=False) for index in indices]

# Each item is (dims, field_0, field_1, ...): peel off the dims, then stack
# every remaining field across the batch and move it to `device`.
dims = torch.tensor([item[0] for item in unstacked_data])
unstacked_data_no_dims = [item[1:] for item in unstacked_data]
data = tuple(
    torch.stack(field_across_batch).to(device)
    for field_across_batch in zip(*unstacked_data_no_dims)
)

st_batch = StructuredDataBatch(
    data, dims, structure.observed, structure.exist,
    dataset_obj.is_onehot, structure.graphical_structure,
)

# No dimensions are supplied to the sampler up front.
known_dims = None
x0_st_batch = sampler.sample(
    net, st_batch, loss, rnd,
    known_dims=known_dims,
    dataset_obj=dataset_obj,
)

In [None]:
# Render up to MAX_MOLECULES_TO_PLOT generated molecules. The count is taken from
# the sampled batch itself (not the `batch_size` global), so this cell stays
# correct even if `batch_size` was edited after sampling.
MAX_MOLECULES_TO_PLOT = 8
sampled_dims = x0_st_batch.get_dims()  # per-sample atom counts, fetched once
positions_batch = x0_st_batch.tuple_batch[0]
atom_type_scores = x0_st_batch.tuple_batch[1]  # argmax over the last dim gives the atom-type index

num_to_plot = min(len(sampled_dims), MAX_MOLECULES_TO_PLOT)
for idx in range(num_to_plot):
    num_atoms = sampled_dims[idx].item()
    # Only the first num_atoms rows are plotted (later rows presumably padding).
    positions = positions_batch[idx, :num_atoms, :].detach().cpu()
    atom_types = torch.argmax(atom_type_scores[idx, :num_atoms, :], dim=1).detach().cpu()
    plot_data3d(positions, atom_types, dataset_obj.dataset_info, spheres_3d=False)
    plt.show()