In [None]:
import os

import hydra
import numpy as np
import torch
import yaml
from omegaconf import OmegaConf
from pytorch_lightning import (
    LightningDataModule,
    seed_everything,
)
from tqdm import tqdm

from multimodal_contrastive.analysis.utils import *
from multimodal_contrastive.utils import utils
from gflownet.config import Config, init_from_dict
from gflownet.data.data_source import DataSource
from gflownet.tasks.morph_frag import MorphSimilarityTrainer

# Register the "sum" resolver used by the Hydra configs. replace=True overwrites
# any earlier registration, so re-running this cell is safe.
OmegaConf.register_new_resolver("sum", lambda values: np.sum(values), replace=True)

# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# TODO: point these at a finished GFlowNet run directory and at the GMC proxy checkpoint.
run_dir, ckpt_path = (
    "path.to/gfn_run_dir",
    "path.to/gmc_proxy.ckpt",
)

In [3]:
def load_trainer(run_dir, device=None, log_dir="~/your_log_dir"):
    """Rebuild a MorphSimilarityTrainer from a saved GFlowNet run and load its weights.

    Args:
        run_dir: Directory of a saved run; it must contain ``config.yaml`` and
            ``model_state.pt``.
        device: Device the model is placed on. Defaults to ``"cuda"`` when a GPU
            is available and ``"cpu"`` otherwise (previously ``.cuda()`` was
            hard-coded, which crashed on CPU-only machines).
        log_dir: Log directory for the re-created trainer. ``~`` is expanded so a
            literal ``~`` directory is not created in the working directory.
            ``overwrite_existing_exp`` is set to True, so existing contents of
            this directory may presumably be overwritten by the trainer.

    Returns:
        The trainer, with the first entry of the saved ``models_state_dict``
        loaded into ``trainer.model`` and moved to ``device``.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Saved run configuration. FullLoader is kept for compatibility with how the
    # config was written; only load run directories you trust.
    with open(os.path.join(run_dir, "config.yaml"), "r") as f:
        run_config = yaml.load(f, Loader=yaml.FullLoader)

    config = init_from_dict(Config(), run_config)
    config.log_dir = os.path.expanduser(log_dir)
    config.overwrite_existing_exp = True
    # NOTE(review): assumes the trainer builds its modules on `config.device`
    # when that field exists; keep it consistent with where the model ends up.
    if hasattr(config, "device"):
        config.device = device

    # Load gflownet model trainer setup for sampling
    trainer = MorphSimilarityTrainer(config)

    # map_location lets a checkpoint saved on a GPU be restored on a CPU-only host.
    model_state = torch.load(os.path.join(run_dir, "model_state.pt"), map_location=device)
    # Only the first saved model is restored.
    trainer.model.load_state_dict(model_state["models_state_dict"][0])
    trainer.model = trainer.model.to(device)

    return trainer

def sample_from_trained_model(run_dir, n_samples, n_samples_per_iter=64, trainer=None, random_action_prob=0.01):
    """Generate molecules from a trained GFlowNet together with their rewards.

    Args:
        run_dir: Run directory handed to ``load_trainer`` when ``trainer`` is None.
        n_samples: Total number of trajectories to generate. The last batch is
            shrunk so exactly ``n_samples`` trajectories are requested (floor
            division previously dropped the remainder, and produced nothing when
            ``n_samples < n_samples_per_iter``).
        n_samples_per_iter: Maximum number of trajectories generated per batch.
        trainer: An already loaded trainer to reuse; loaded from ``run_dir`` if None.
        random_action_prob: Value passed as the last argument of
            ``create_training_data_from_own_samples`` (presumably the probability
            of taking a random action; previously hard-coded to 0.01).

    Returns:
        List of ``(mol, flat_rewards)`` tuples in generation order. Trajectories
        whose ``'mol'`` is None are skipped, so the list may be shorter than
        ``n_samples``.

    Raises:
        ValueError: if ``n_samples_per_iter`` is not positive.
    """
    if n_samples_per_iter <= 0:
        raise ValueError(f"n_samples_per_iter must be positive, got {n_samples_per_iter}")
    if trainer is None:
        trainer = load_trainer(run_dir)

    # Ceil-divide so a trailing partial batch is not silently dropped.
    n_iterations = -(-n_samples // n_samples_per_iter)
    src = DataSource(trainer.cfg, trainer.ctx, trainer.algo, trainer.task, replay_buffer=None)

    samples = []
    with torch.no_grad():
        for t in range(n_iterations):
            batch_size = min(n_samples_per_iter, n_samples - t * n_samples_per_iter)
            cond_info = trainer.task.sample_conditional_information(batch_size, t)
            trajs = trainer.algo.create_training_data_from_own_samples(
                trainer.model, batch_size, cond_info["encoding"], random_action_prob
            )
            # Attach conditioning info, then compute properties and rewards in place.
            src.set_traj_cond_info(trajs, cond_info)
            src.compute_properties(trajs, mark_as_online=True)
            src.compute_log_rewards(trajs)

            samples.extend(
                (traj["mol"], traj["flat_rewards"]) for traj in trajs if traj["mol"] is not None
            )

    return samples
        

## Load the GMC Proxy Model

In [None]:
# Load the Hydra config for the GMC proxy model
config_name = "puma_sm_gmc"
configs_path = "../../configs"

with hydra.initialize(version_base=None, config_path=configs_path):
    cfg = hydra.compose(config_name=config_name)

print(cfg.datamodule.split_type)

# Seed pytorch, numpy and python.random, in particular so the same data splits
# (and test set) are produced. Compare against None so that seed=0 is honoured.
if cfg.get("seed") is not None:
    seed_everything(cfg.seed, workers=True)

# Load model from checkpoint. load_from_checkpoint is a classmethod that returns
# a new model, so call it on the class; newer Lightning releases raise when it is
# called on an instance.
model = utils.instantiate_model(cfg)
model = type(model).load_from_checkpoint(ckpt_path, map_location=device)
model = model.eval()

## Sample from the GFlowNet

In [None]:
n_samples = 1000
# Trainer is loaded from run_dir inside the helper; trajectories are generated in batches of 128.
samples = sample_from_trained_model(run_dir, n_samples, n_samples_per_iter=128)
rewards = [flat_reward.item() for _mol, flat_reward in samples]