In [1]:
from copy import deepcopy
from typing import Callable, Union
from functools import partial

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.probabilistic import set_interaction_type as set_exploration_type
from tensordict.tensordict import TensorDictBase
from torchrl.collectors import RandomPolicy
from torchrl.envs import EnvBase, TensorDictPrimer, TransformedEnv, CatFrames, InitTracker, StepCounter, UnsqueezeTransform
from torchrl.envs.utils import ExplorationType, step_mdp


from acegen.rl_env.smiles_env import SMILESEnv
from acegen.vocabulary import SMILESVocabulary, SMILESTokenizer, SMILESTokenizer2
from acegen.data.utils import smiles_to_tensordict
from acegen.models import create_gru_actor, adapt_state_dict, create_gpt2_actor

from promptsmiles import ScaffoldDecorator, FragmentLinker
_has_promptsmiles = True

from acegen.rl_env.utils import generate_complete_smiles

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
from acegen.vocabulary import tokenizers

# Round-trip one SMILES string through every tokenizer class listed below and
# print, for each class: the class itself, the token list it produces, and the
# result of untokenizing that list again.
smiles = "CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F"
tokenizer_classes = (
    tokenizers.SMILESTokenizer,
    tokenizers.SMILESTokenizer2,
    tokenizers.DeepSMILESTokenizer,
    tokenizers.AISTokenizer,
    tokenizers.SAFETokenizer,
    tokenizers.SELFIESTokenizer,
)
for tokenizer in tokenizer_classes:
    print(tokenizer)
    # Each class is instantiated with its default arguments.
    TOK = tokenizer()
    tokens = TOK.tokenize(smiles)
    print(tokens)
    untokens = TOK.untokenize(tokens)
    print(untokens)

  from .autonotebook import tqdm as notebook_tqdm


<class 'acegen.vocabulary.tokenizers.SMILESTokenizer'>
['C', 'C', '1', '=', 'C', '(', 'C', '(', '=', 'O', ')', 'N', '2', 'C', 'C', 'C', 'C', 'C', '2', '=', 'N', '1', ')', 'C', 'C', 'N', '3', 'C', 'C', 'C', '(', 'C', 'C', '3', ')', 'C', '4', '=', 'N', 'O', 'C', '5', '=', 'C', '4', 'C', '=', 'C', 'C', '(', '=', 'C', '5', ')', 'F']
CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F
<class 'acegen.vocabulary.tokenizers.SMILESTokenizer2'>
['C', 'C', '1', '=', 'C', '(', 'C', '(', '=', 'O', ')', 'N', '2', 'C', 'C', 'C', 'C', 'C', '2', '=', 'N', '1', ')', 'C', 'C', 'N', '3', 'C', 'C', 'C', '(', 'C', 'C', '3', ')', 'C', '4', '=', 'N', 'O', 'C', '5', '=', 'C', '4', 'C', '=', 'C', 'C', '(', '=', 'C', '5', ')', 'F']
CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F
<class 'acegen.vocabulary.tokenizers.DeepSMILESTokenizer'>
['C', 'C', '=', 'C', 'C', '=', 'O', ')', 'N', 'C', 'C', 'C', 'C', 'C', '6', '=', 'N', '%', '1', '0', ')', ')', ')', ')', ')', ')', ')', ')', 'C', 'C', 'N', 'C', 'C', 

In [3]:
# Generate SMILES with `generate_complete_smiles`, conditioning on a `promptsmiles`
# string, then print the result.
#
# NOTE: `env`, `vocabulary`, `actor_inference` and `actor_training` are not defined
# in this cell; they must already exist in the kernel namespace before it is run.
#
# Alternative inputs kept for reference:
#   prompt:       "c1ccccc1C(=O)"
#   promptsmiles: "N1(*)CCNCC1.C1(*)CC1.c1cncc(*)c1"   (fragment)
#                 "N1(*)CCN(CC1)CCCCN(*)"               (scaffold)
output_data = generate_complete_smiles(
    environment=env,
    vocabulary=vocabulary,
    policy_sample=actor_inference,
    policy_evaluate=actor_training,
    prompt=None,
    promptsmiles="N1(*)CCNCC1.C1(*)CC1.c1cncc(*)c1",
    promptsmiles_optimize=False,
    promptsmiles_shuffle=True,
    promptsmiles_multi=True,
    return_smiles_only=False,
)

# Plain loops are used for printing: a list comprehension run only for its side
# effects builds a throwaway list, and assigning that list to `_` would shadow
# IPython's "last output" variable.
if isinstance(output_data, list):
    # A list result is printed one entry per line.
    for smi in output_data:
        print(smi)
else:
    # Otherwise print the returned object, then decode every row of its "action"
    # entry with the vocabulary and print the decoded strings one per line.
    print(output_data)
    smiles = output_data.get("action").cpu()
    smiles_str = [vocabulary.decode(smi.numpy()) for smi in smiles]
    for decoded in smiles_str:
        print(decoded)

Scan must be used for more than two fragments, Scan will be enabled.


TensorDict(
    fields={
        action: Tensor(shape=torch.Size([192, 100]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([192, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([192, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        mask: Tensor(shape=torch.Size([192, 100]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([192, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([192, 100, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([192, 100]), device=cpu, dtype=torch.int32, is_shared=False),
                reward: Tensor(shape=torch.Size([192, 100, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                sequence: Tensor(shape=torch.Size([192, 100]), de

In [2]:
# Load prereqs
smiles = [
    'c1ccccc1',
    'c1ccccc1C(=O)C',
]

# The device is fixed to the CPU (no accelerator detection is performed here).
device = torch.device("cpu")

# Number of parallel sequences. Shared by the recurrent-state primer and the
# environment kwargs below so the two values cannot drift apart.
BATCH_SIZE = 64

# Prior selection: actor factory, vocabulary file, checkpoint file and tokenizer.
create_actor, voc_path, ckpt_path, tokenizer = (
    create_gru_actor,
    "priors/chembl_filtered_vocabulary.txt",
    "priors/gru_chembl_filtered.ckpt",
    SMILESTokenizer(),
)
# Alternative prior (GPT-2), kept for reference:
#create_actor, voc_path, ckpt_path, tokenizer = (create_gpt2_actor, "priors/enamine_real_vocabulary.txt","/shared/albert/acegen-open/priors/gpt2_enamine_real.ckpt", SMILESTokenizer2())

# Load vocabulary: the file holds one token per line and a token's line position
# becomes its integer index.
with open(voc_path, "r") as f:
    tokens = f.read().splitlines()
tokens_dict = {token: index for index, token in enumerate(tokens)}
vocabulary = SMILESVocabulary.create_from_dict(
    tokens_dict,
    start_token="GO",
    end_token="EOS",
    tokenizer=tokenizer,
)

# Create models
# `map_location` remaps the stored tensors onto `device`, so a checkpoint that was
# saved from another device (e.g. a GPU) still loads on this machine.
ckpt = torch.load(ckpt_path, map_location=device)
actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary))
# Both actors are initialised from the same checkpoint; `adapt_state_dict` is given
# the checkpoint together with each module's own state dict.
actor_inference.load_state_dict(
    adapt_state_dict(ckpt, actor_inference.state_dict())
)
actor_training.load_state_dict(
    adapt_state_dict(ckpt, actor_training.state_dict())
)
actor_inference = actor_inference.to(device)
actor_training = actor_training.to(device)
# Independent copy of the training actor taken right after loading.
prior = deepcopy(actor_training)

# Create RL environment
# For RNNs, create a transform to populate initial tensordict with recurrent states equal to 0.0
rhs_primers = []
if hasattr(actor_training, "rnn_spec"):
    primers = actor_training.rnn_spec.expand(BATCH_SIZE)
    rhs_primers.append(TensorDictPrimer(primers))
env_kwargs = {
    "start_token": vocabulary.start_token_index,
    "end_token": vocabulary.end_token_index,
    "length_vocabulary": len(vocabulary),
    "batch_size": BATCH_SIZE,
    "device": device,
}
def create_env_fn():
    """Build one transformed SMILES environment.

    A `SMILESEnv` is constructed from the module-level `env_kwargs`, wrapped in a
    `TransformedEnv`, and then given a `StepCounter`, an `InitTracker` and every
    transform stored in the module-level `rhs_primers` list, in that order.

    Returns:
        The `TransformedEnv` with all transforms appended.
    """
    transformed_env = TransformedEnv(SMILESEnv(**env_kwargs))
    for transform in (StepCounter(), InitTracker(), *rhs_primers):
        transformed_env.append_transform(transform)
    return transformed_env

# Instantiate the environment once; it is stored in the kernel-global `env`.
env = create_env_fn()
# NOTE(review): commented-out call kept for reference. It passes `policy=`, whereas
# the live `generate_complete_smiles` call in this notebook passes `policy_sample=`
# and `policy_evaluate=` — likely stale, confirm before re-enabling.
#data = generate_complete_smiles(policy=actor_inference, environment=env)