In [1]:
from copy import deepcopy
from typing import Callable, Union
from functools import partial

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.probabilistic import set_interaction_type as set_exploration_type
from tensordict.tensordict import TensorDictBase
from torchrl.collectors import RandomPolicy
from torchrl.envs import EnvBase, TensorDictPrimer, TransformedEnv, CatFrames, InitTracker, StepCounter, UnsqueezeTransform
from torchrl.envs.utils import ExplorationType, step_mdp


from acegen.rl_env.smiles_env import SMILESEnv
from acegen.vocabulary import SMILESVocabulary, SMILESTokenizer, SMILESTokenizer2
from acegen.data.utils import smiles_to_tensordict
from acegen.models import create_gru_actor, adapt_state_dict, create_gpt2_actor

from promptsmiles import ScaffoldDecorator, FragmentLinker
_has_promptsmiles = True

from acegen.rl_env.utils import generate_complete_smiles

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
from acegen.vocabulary import tokenizers

# Round-trip a single reference SMILES string through each tokenizer class:
# print the class, the token list it produces, and the string obtained by
# untokenizing those tokens again.
smiles = "CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F"
tokenizer_classes = (
    tokenizers.SMILESTokenizer,
    tokenizers.SMILESTokenizer2,
    tokenizers.DeepSMILESTokenizer,
    tokenizers.AISTokenizer,
    tokenizers.SAFETokenizer,
    tokenizers.SELFIESTokenizer,
)
for tokenizer in tokenizer_classes:
    print(tokenizer)
    TOK = tokenizer()
    tokens = TOK.tokenize(smiles)
    print(tokens)
    untokens = TOK.untokenize(tokens)
    print(untokens)

  from .autonotebook import tqdm as notebook_tqdm


<class 'acegen.vocabulary.tokenizers.SMILESTokenizer'>
['C', 'C', '1', '=', 'C', '(', 'C', '(', '=', 'O', ')', 'N', '2', 'C', 'C', 'C', 'C', 'C', '2', '=', 'N', '1', ')', 'C', 'C', 'N', '3', 'C', 'C', 'C', '(', 'C', 'C', '3', ')', 'C', '4', '=', 'N', 'O', 'C', '5', '=', 'C', '4', 'C', '=', 'C', 'C', '(', '=', 'C', '5', ')', 'F']
CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F
<class 'acegen.vocabulary.tokenizers.SMILESTokenizer2'>
['C', 'C', '1', '=', 'C', '(', 'C', '(', '=', 'O', ')', 'N', '2', 'C', 'C', 'C', 'C', 'C', '2', '=', 'N', '1', ')', 'C', 'C', 'N', '3', 'C', 'C', 'C', '(', 'C', 'C', '3', ')', 'C', '4', '=', 'N', 'O', 'C', '5', '=', 'C', '4', 'C', '=', 'C', 'C', '(', '=', 'C', '5', ')', 'F']
CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F
<class 'acegen.vocabulary.tokenizers.DeepSMILESTokenizer'>
['C', 'C', '=', 'C', 'C', '=', 'O', ')', 'N', 'C', 'C', 'C', 'C', 'C', '6', '=', 'N', '%', '1', '0', ')', ')', ')', ')', ')', ')', ')', ')', 'C', 'C', 'N', 'C', 'C', 

In [5]:
# Generate a batch of molecules with PromptSMILES and print the results.
# NOTE: this cell relies on `env`, `vocabulary`, `actor_inference` and
# `actor_training`, which are created by the "Load prereqs" cell; run that
# cell first on a fresh kernel.
output_data = generate_complete_smiles(
    environment=env,
    vocabulary=vocabulary,
    policy_sample=actor_inference,
    policy_evaluate=actor_training,
    prompt=None, #"c%11ccccc%11", # "c1ccccc1C(=O)",
    promptsmiles="N=C(N)c1c(*)cc2[nH]c(-c3ccc(Cn4cc(COc5ccc(-c6c(*)cc(OCc7cn(Cc8ccc(-c9nc%10cc(C(=N)N)ccc%10[nH]9)o8)nn7)cc6)cc5)nn4)o3)nc2c1", #"c%12ccc(*)cc%12", # "N1(*)CCN(CC1)CCCCN(*)", # fragment "N1(*)CCNCC1.C1(*)CC1.c1cncc(*)c1", # scaffold "N1(*)CCN(CC1)CCCCN(*)"
    promptsmiles_optimize=False,
    promptsmiles_shuffle=True,
    promptsmiles_multi=False,
    return_smiles_only=False
)

# The result is either a plain list (printed item by item) or a
# TensorDict-like object whose "action" entry holds the sampled token ids.
# Plain loops are used instead of `_ = [print(...) ...]` so that no throwaway
# list is built and IPython's `_` (last output) variable is not overwritten.
if isinstance(output_data, list):
    for s in output_data:
        print(s)
else:
    print(output_data)
    smiles = output_data.get("action").cpu()
    # Decode each row of token ids back into a SMILES string.
    smiles_str = [vocabulary.decode(smi.numpy()) for smi in smiles]
    for s in smiles_str:
        print(s)

> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(268)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    266 [0;31m        [0;32mif[0m [0mprompt[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    267 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 268 [0;31m            [0;32mif[0m [0misinstance[0m[0;34m([0m[0mprompt[0m[0;34m,[0m [0mstr[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    269 [0;31m                [0mprompt[0m [0;34m=[0m [0;34m[[0m[0mprompt[0m[0;34m][0m [0;34m*[0m [0mbatch_size[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    270 [0;31m            [0;31m# Encode the prompt(s)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(355)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    353 [0;31m            [0mtensordicts[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m[[0m[0;34m([0m[0;34m"next"[0m[0;34m,[0m [0;34m"done"[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m~[0m[0mfinished[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    354 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 355 [0;31m        [0moutput_data[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mstack[0m[0;34m([0m[0mtensordicts[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    356 [0;31m        [0moutput_data[0m[0;34m.[0m[0mrefine_names[0m[0;34m([0m[0;34m...[0m[0;34m

ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(360)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    358 [0;31m    [0;32mif[0m [0mreturn_smiles_only[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    359 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 360 [0;31m        [0msmiles[0m [0;34m=[0m [0moutput_data[0m[0;34m.[0m[0mget[0m[0;34m([0m[0;34m"action"[0m[0;34m)[0m[0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    361 [0;31m        [0msmiles_str[0m [0;34m=[0m [0;34m[[0m[0mvocabulary[0m[0;34m.[0m[0mdecode[0m[0;34m([0m[0msmi[0m[0;34m.[0m[0mnumpy[0m[0;34m([0m[0;34m)[0m[0;34m)[0m [0;32mfor[0m [0msmi[0m [0;32min[0m [0msmiles[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    362 [0;31m        [0;31m# Replace failed encodings with original prompt for P

ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(268)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    266 [0;31m        [0;32mif[0m [0mprompt[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    267 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 268 [0;31m            [0;32mif[0m [0misinstance[0m[0;34m([0m[0mprompt[0m[0;34m,[0m [0mstr[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    269 [0;31m                [0mprompt[0m [0;34m=[0m [0;34m[[0m[0mprompt[0m[0;34m][0m [0;34m*[0m [0mbatch_size[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    270 [0;31m            [0;31m# Encode the prompt(s)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(355)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    353 [0;31m            [0mtensordicts[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m[[0m[0;34m([0m[0;34m"next"[0m[0;34m,[0m [0;34m"done"[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m~[0m[0mfinished[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    354 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 355 [0;31m        [0moutput_data[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mstack[0m[0;34m([0m[0mtensordicts[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    356 [0;31m        [0moutput_data[0m[0;34m.[0m[0mrefine_names[0m[0;34m([0m[0;34m...[0m[0;34m

ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(360)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    358 [0;31m    [0;32mif[0m [0mreturn_smiles_only[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    359 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 360 [0;31m        [0msmiles[0m [0;34m=[0m [0moutput_data[0m[0;34m.[0m[0mget[0m[0;34m([0m[0;34m"action"[0m[0;34m)[0m[0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    361 [0;31m        [0msmiles_str[0m [0;34m=[0m [0;34m[[0m[0mvocabulary[0m[0;34m.[0m[0mdecode[0m[0;34m([0m[0msmi[0m[0;34m.[0m[0mnumpy[0m[0;34m([0m[0;34m)[0m[0;34m)[0m [0;32mfor[0m [0msmi[0m [0;32min[0m [0msmiles[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    362 [0;31m        [0;31m# Replace failed encodings with original prompt for P

ipdb>  c


TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 200]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        mask: Tensor(shape=torch.Size([2, 200]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 200]), device=cpu, dtype=torch.int32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                sequence: Tensor(shape=torch.Size([2, 200]), device=cpu, dtype=to

In [4]:
# Load prereqs
smiles = [
    'c1ccccc1',
    'c1ccccc1C(=O)C'
]

# Everything in this notebook runs on the CPU.
device = torch.device("cpu")
# Number of molecules generated in parallel. Used both for the environment and
# for the recurrent-state primer so the two can never disagree.
BATCH_SIZE = 2
create_actor, voc_path, ckpt_path, tokenizer = (create_gru_actor, "priors/chembl_filtered_vocabulary.txt", "priors/gru_chembl_filtered.ckpt", SMILESTokenizer())
#create_actor, voc_path, ckpt_path, tokenizer = (create_gpt2_actor, "priors/enamine_real_vocabulary.txt","/shared/albert/acegen-open/priors/gpt2_enamine_real.ckpt", SMILESTokenizer2())

# Load vocabulary: one token per line, mapped to its line index.
with open(voc_path, "r") as f:
    tokens = f.read().splitlines()
tokens_dict = {token: index for index, token in enumerate(tokens)}
vocabulary = SMILESVocabulary.create_from_dict(tokens_dict, start_token="GO", end_token="EOS", tokenizer=tokenizer)

# Create models
# map_location ensures a checkpoint saved from a GPU run can still be loaded
# when only the CPU is available.
ckpt = torch.load(ckpt_path, map_location=device)
actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary))
actor_inference.load_state_dict(
    adapt_state_dict(ckpt, actor_inference.state_dict())
)
actor_training.load_state_dict(adapt_state_dict(ckpt, actor_training.state_dict()))
actor_inference = actor_inference.to(device)
actor_training = actor_training.to(device)
# Independent copy of the loaded training actor, kept as the prior.
prior = deepcopy(actor_training)

# Create RL environment
# For RNNs, create a transform to populate initial tensordict with recurrent states equal to 0.0
rhs_primers = []
if hasattr(actor_training, "rnn_spec"):
    primers = actor_training.rnn_spec.expand(BATCH_SIZE)
    rhs_primers.append(TensorDictPrimer(primers))
env_kwargs = {
    "start_token": vocabulary.start_token_index,
    "end_token": vocabulary.end_token_index,
    "length_vocabulary": len(vocabulary),
    "batch_size": BATCH_SIZE,
    "max_length": 200,
    "device": device,
}
def create_env_fn():
    """Build one transformed SMILES environment.

    A ``SMILESEnv`` is constructed from the module-level ``env_kwargs`` and
    wrapped in a ``TransformedEnv``. A ``StepCounter``, an ``InitTracker`` and
    then every transform in the module-level ``rhs_primers`` list are appended
    to it, in that order.

    Returns:
        The ``TransformedEnv`` wrapping the newly created ``SMILESEnv``.
    """
    transformed_env = TransformedEnv(SMILESEnv(**env_kwargs))
    # Order matters: step counting and init tracking first, primers last.
    ordered_transforms = [StepCounter(), InitTracker(), *rhs_primers]
    for transform in ordered_transforms:
        transformed_env.append_transform(transform)
    return transformed_env

# Instantiate the environment that is passed as `environment=env` to
# generate_complete_smiles in the generation cell.
env = create_env_fn()
# NOTE(review): the commented-out call below uses a `policy=` keyword, whereas
# the live call in this notebook passes `policy_sample=` / `policy_evaluate=`
# and `vocabulary=`; it looks stale — confirm before re-enabling.
#data = generate_complete_smiles(policy=actor_inference, environment=env)