In [None]:
import os
import pickle
import shutil
from pathlib import Path
from synplan.utils.loading import download_all_data

from CGRtools import smiles
from IPython.display import SVG, display
from synplan.utils.visualisation import get_route_svg
from synplan.mcts.tree import Tree
from synplan.mcts.evaluation import ValueNetworkFunction
from synplan.mcts.expansion import PolicyNetworkFunction
from synplan.utils.config import PolicyNetworkConfig
from synplan.utils.config import TreeConfig
from synplan.utils.loading import load_reaction_rules, load_building_blocks

import pandas as pd
from collections import Counter, defaultdict
from time import time

### Configure the search and load rules, building blocks and policy

In [None]:
# MCTS hyperparameters shared by every target searched below.
tree_config = TreeConfig(
    # search budget / limits
    max_iterations=300,
    max_time=600,
    max_depth=9,
    min_mol_size=6,
    # search behaviour
    search_strategy="expansion_first",
    evaluation_type="rollout",  # the search loop passes evaluation_function=None to Tree
    ucb_type="uct",
    c_ucb=0.1,
    init_node_value=0.5,
)

# Input locations (relative paths resolve against the notebook's working directory).
policy_network = "training_hybrid/ranking_policy_network/policy_network.ckpt"
rules_path = "training_hybrid/reaction_rules.pickle"
building_blocks_path = "synplan_data/building_blocks/building_blocks_em_sa_ln.smi"

# NOTE(review): the rules file is a .pickle — presumably unpickled, so only load trusted files.
reaction_rules = load_reaction_rules(rules_path)
building_blocks = load_building_blocks(building_blocks_path, standardize=False)

# Wrap the trained policy checkpoint as the tree's expansion function.
policy_config = PolicyNetworkConfig(weights_path=policy_network)
policy_function = PolicyNetworkFunction(policy_config=policy_config)

### Run the tree search for each target

In [None]:
OUTPUT_FILE = "tree_list_hybrid.pickle"  # checkpoint: pickled dict {target SMILES: finished Tree}

In [None]:
# Headerless target file: the first column holds one target SMILES per row.
targets_df = pd.read_csv("synplan_data/benchmarks/sascore/targets_with_sascore_1.5_2.5.smi", header=None)
target_list = targets_df[0].tolist()

In [None]:
# Search every target and checkpoint all finished trees after each success.
# Any failure (unparsable SMILES, tree construction or search error, pickling
# error) is reported and skipped so one bad target does not stop the batch.
tree_config.silent = True  # loop-invariant: configure once, not per target

tree_list = {}
for n, smi in enumerate(target_list):
    try:
        # 1. Parse and prepare the target molecule
        target = smiles(smi)
        target.canonicalize()
        target.clean2d()

        # 2. Build the search tree and run it by exhausting its iterator
        tree = Tree(
            target=target,
            config=tree_config,
            reaction_rules=reaction_rules,
            building_blocks=building_blocks,
            expansion_function=policy_function,
            evaluation_function=None,
        )
        for _ in tree:
            pass

        # 3. Drop references to shared/heavy objects so the checkpoint does not
        #    serialize the rules, building blocks or policy model.
        tree._tqdm = None
        tree.reaction_rules = None
        tree.building_blocks = None
        tree.policy_network = None
        # NOTE(review): the policy is passed as `expansion_function`;
        # `policy_network` looks like a legacy attribute name, so clear the
        # current one too when the tree has it.
        if hasattr(tree, "expansion_function"):
            tree.expansion_function = None

        tree_list[smi] = tree

        # 4. Checkpoint atomically: write a temp file, then rename it over the
        #    old one, so an interrupted dump never corrupts earlier results.
        tmp_file = f"{OUTPUT_FILE}.tmp"
        with open(tmp_file, "wb") as f:
            pickle.dump(tree_list, f)
        os.replace(tmp_file, OUTPUT_FILE)

        print(f"{n} / {smi} / {len(tree.winning_nodes)}")

    except Exception as e:
        print(f"{n} ERROR -> {e}")