In [1]:
from copy import deepcopy
from typing import Callable, Union
from functools import partial

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.probabilistic import set_interaction_type as set_exploration_type
from tensordict.tensordict import TensorDictBase
from torchrl.collectors import RandomPolicy
from torchrl.envs import EnvBase, TensorDictPrimer, TransformedEnv, CatFrames, InitTracker, StepCounter, UnsqueezeTransform
from torchrl.envs.utils import ExplorationType, step_mdp
from torchrl.modules.utils import get_primers_from_module


from acegen.rl_env.token_env import TokenEnv
from acegen.vocabulary import Vocabulary, SMILESTokenizerChEMBL
from acegen.data.utils import smiles_to_tensordict
from acegen.models import create_gru_actor, adapt_state_dict, create_gpt2_actor, models

from promptsmiles import ScaffoldDecorator, FragmentLinker
_has_promptsmiles = True

from acegen.rl_env.utils import generate_complete_smiles

 - make sure ninja and cmake were installed
 - make sure you ran `python setup.py clean && python setup.py develop` and that no error was raised
 - make sure the version of PyTorch you are using matches the one that was present in your virtual env during setup.
  from .autonotebook import tqdm as notebook_tqdm
  def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weight,
  def backward(ctx, dout):
  def forward(
  def backward(ctx, dout, *args):


In [1]:
from acegen.vocabulary import tokenizers

for tokenizer in [tokenizers.SMILESTokenizer, tokenizers.SMILESTokenizer2, tokenizers.DeepSMILESTokenizer, tokenizers.AISTokenizer, tokenizers.SAFETokenizer, tokenizers.SELFIESTokenizer]:
    print(tokenizer)
    smiles = "CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F"
    TOK = tokenizer()
    tokens = TOK.tokenize(smiles)
    print(tokens)
    untokens = TOK.untokenize(tokens)
    print(untokens)

  from .autonotebook import tqdm as notebook_tqdm


<class 'acegen.vocabulary.tokenizers.SMILESTokenizer'>
['C', 'C', '1', '=', 'C', '(', 'C', '(', '=', 'O', ')', 'N', '2', 'C', 'C', 'C', 'C', 'C', '2', '=', 'N', '1', ')', 'C', 'C', 'N', '3', 'C', 'C', 'C', '(', 'C', 'C', '3', ')', 'C', '4', '=', 'N', 'O', 'C', '5', '=', 'C', '4', 'C', '=', 'C', 'C', '(', '=', 'C', '5', ')', 'F']
CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F
<class 'acegen.vocabulary.tokenizers.SMILESTokenizer2'>
['C', 'C', '1', '=', 'C', '(', 'C', '(', '=', 'O', ')', 'N', '2', 'C', 'C', 'C', 'C', 'C', '2', '=', 'N', '1', ')', 'C', 'C', 'N', '3', 'C', 'C', 'C', '(', 'C', 'C', '3', ')', 'C', '4', '=', 'N', 'O', 'C', '5', '=', 'C', '4', 'C', '=', 'C', 'C', '(', '=', 'C', '5', ')', 'F']
CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F
<class 'acegen.vocabulary.tokenizers.DeepSMILESTokenizer'>
['C', 'C', '=', 'C', 'C', '=', 'O', ')', 'N', 'C', 'C', 'C', 'C', 'C', '6', '=', 'N', '%', '1', '0', ')', ')', ')', ')', ')', ')', ')', ')', 'C', 'C', 'N', 'C', 'C', 

In [5]:
output_data = generate_complete_smiles(
    environment=env,
    vocabulary=vocabulary,
    policy_sample=actor_inference,
    policy_evaluate=actor_training,
    prompt=None, #"c%11ccccc%11", # "c1ccccc1C(=O)",
    promptsmiles="N=C(N)c1c(*)cc2[nH]c(-c3ccc(Cn4cc(COc5ccc(-c6c(*)cc(OCc7cn(Cc8ccc(-c9nc%10cc(C(=N)N)ccc%10[nH]9)o8)nn7)cc6)cc5)nn4)o3)nc2c1", #"c%12ccc(*)cc%12", # "N1(*)CCN(CC1)CCCCN(*)", # fragment "N1(*)CCNCC1.C1(*)CC1.c1cncc(*)c1", # scaffold "N1(*)CCN(CC1)CCCCN(*)"
    promptsmiles_optimize=False,
    promptsmiles_shuffle=True,
    promptsmiles_multi=False,
    return_smiles_only=False
)
if isinstance(output_data, list):
    _ = [print(s) for s in output_data]
else:
    print(output_data)
    smiles = output_data.get("action").cpu()
    smiles_str = [vocabulary.decode(smi.numpy()) for smi in smiles]
    _ = [print(s) for s in smiles_str]

> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(268)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    266 [0;31m        [0;32mif[0m [0mprompt[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    267 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 268 [0;31m            [0;32mif[0m [0misinstance[0m[0;34m([0m[0mprompt[0m[0;34m,[0m [0mstr[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    269 [0;31m                [0mprompt[0m [0;34m=[0m [0;34m[[0m[0mprompt[0m[0;34m][0m [0;34m*[0m [0mbatch_size[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    270 [0;31m            [0;31m# Encode the prompt(s)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(355)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    353 [0;31m            [0mtensordicts[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m[[0m[0;34m([0m[0;34m"next"[0m[0;34m,[0m [0;34m"done"[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m~[0m[0mfinished[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    354 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 355 [0;31m        [0moutput_data[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mstack[0m[0;34m([0m[0mtensordicts[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    356 [0;31m        [0moutput_data[0m[0;34m.[0m[0mrefine_names[0m[0;34m([0m[0;34m...[0m[0;34m

ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(360)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    358 [0;31m    [0;32mif[0m [0mreturn_smiles_only[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    359 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 360 [0;31m        [0msmiles[0m [0;34m=[0m [0moutput_data[0m[0;34m.[0m[0mget[0m[0;34m([0m[0;34m"action"[0m[0;34m)[0m[0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    361 [0;31m        [0msmiles_str[0m [0;34m=[0m [0;34m[[0m[0mvocabulary[0m[0;34m.[0m[0mdecode[0m[0;34m([0m[0msmi[0m[0;34m.[0m[0mnumpy[0m[0;34m([0m[0;34m)[0m[0;34m)[0m [0;32mfor[0m [0msmi[0m [0;32min[0m [0msmiles[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    362 [0;31m        [0;31m# Replace failed encodings with original prompt for P

ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(268)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    266 [0;31m        [0;32mif[0m [0mprompt[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    267 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 268 [0;31m            [0;32mif[0m [0misinstance[0m[0;34m([0m[0mprompt[0m[0;34m,[0m [0mstr[0m[0;34m)[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    269 [0;31m                [0mprompt[0m [0;34m=[0m [0;34m[[0m[0mprompt[0m[0;34m][0m [0;34m*[0m [0mbatch_size[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    270 [0;31m            [0;31m# Encode the prompt(s)[0m[0;34m[0m[0;34m[0m[0m
[0m


ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(355)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    353 [0;31m            [0mtensordicts[0m[0;34m[[0m[0;34m-[0m[0;36m1[0m[0;34m][0m[0;34m[[0m[0;34m([0m[0;34m"next"[0m[0;34m,[0m [0;34m"done"[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0;34m~[0m[0mfinished[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    354 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 355 [0;31m        [0moutput_data[0m [0;34m=[0m [0mtorch[0m[0;34m.[0m[0mstack[0m[0;34m([0m[0mtensordicts[0m[0;34m,[0m [0mdim[0m[0;34m=[0m[0;34m-[0m[0;36m1[0m[0;34m)[0m[0;34m.[0m[0mcontiguous[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    356 [0;31m        [0moutput_data[0m[0;34m.[0m[0mrefine_names[0m[0;34m([0m[0;34m...[0m[0;34m

ipdb>  c


> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(360)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    358 [0;31m    [0;32mif[0m [0mreturn_smiles_only[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    359 [0;31m        [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 360 [0;31m        [0msmiles[0m [0;34m=[0m [0moutput_data[0m[0;34m.[0m[0mget[0m[0;34m([0m[0;34m"action"[0m[0;34m)[0m[0;34m.[0m[0mcpu[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    361 [0;31m        [0msmiles_str[0m [0;34m=[0m [0;34m[[0m[0mvocabulary[0m[0;34m.[0m[0mdecode[0m[0;34m([0m[0msmi[0m[0;34m.[0m[0mnumpy[0m[0;34m([0m[0;34m)[0m[0;34m)[0m [0;32mfor[0m [0msmi[0m [0;32min[0m [0msmiles[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    362 [0;31m        [0;31m# Replace failed encodings with original prompt for P

ipdb>  c


TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 200]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        is_init: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        mask: Tensor(shape=torch.Size([2, 200]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                is_init: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 200]), device=cpu, dtype=torch.int32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 200, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                sequence: Tensor(shape=torch.Size([2, 200]), device=cpu, dtype=to

In [2]:
# Load prereqs
smiles = [
    'c1ccccc1',
    'c1ccccc1C(=O)C'
]

# Get available device
device = torch.device("cpu")

create_actor, _, _, voc_path, ckpt_path, tokenizer = models["gru"]
#create_actor, voc_path, ckpt_path, tokenizer = (create_gru_actor, "acegen/priors/chembl_filtered_vocabulary.txt", "acegen/priors/gru_chembl_filtered.ckpt", SMILESTokenizerChEMBL())
#create_actor, voc_path, ckpt_path, tokenizer = (create_gpt2_actor, "priors/enamine_real_vocabulary.txt","/shared/albert/acegen-open/priors/gpt2_enamine_real.ckpt", SMILESTokenizer2())

# Load vocabulary
vocabulary = Vocabulary.load(voc_path, tokenizer=tokenizer)

# Create models
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary), temperature=1.0)
actor_inference.load_state_dict(
    adapt_state_dict(deepcopy(ckpt), actor_inference.state_dict())
)
actor_inference = actor_inference.to(device)
actor_training = actor_training.to(device)

# Create RL environment
env_kwargs = {
    "start_token": vocabulary.start_token_index,
    "end_token": vocabulary.end_token_index,
    "length_vocabulary": len(vocabulary),
    "batch_size": 10,
    "device": device,
}

def create_env_fn():
    """Create a single RL rl_env."""
    env = TokenEnv(**env_kwargs)
    env = TransformedEnv(env)
    env.append_transform(InitTracker())
    if primers := get_primers_from_module(actor_inference):
        env.append_transform(primers)
    return env

env = create_env_fn()
#data = generate_complete_smiles(policy=actor_inference, environment=env)

In [4]:
data = generate_complete_smiles(
    environment=env,
    vocabulary=vocabulary,
    policy_sample=actor_inference,
    return_smiles_only=True,
    temperature=0.1
)
for d in data:
    print(d)

c1c(C(=O)Nc2ccc(C(=O)N3CCCC3)cc2)ccc(C)c1
c1cccc(C(NC(=O)Nc2cc3c(c(NC(C)C)n[nH]3)cn2)C)c1
c1cccc(C(C)NC(=O)Nc2cc3c(c(NC(C)C)n[nH]3)cn2)c1
c1cc(C(=O)NCC(NC2CCCCC2)=O)ccc1C(C)(C)C
c1cc(C(=O)Nc2ccc(C3CNCCO3)cc2)ccc1
c1cc(C(=O)NCC(=O)NC2CCCCC2)ccc1S(=O)(N1CCCCC1)=O
c1cc(C(=O)NCC(=O)NC2CCCCC2)ccc1S(=O)(=O)N
c1cc(C(=O)Nc2ccc(C3CNCCO3)cc2)ccc1
c1cc(C(=O)NCC(NC2CCCCC2)=O)ccc1C
c1cc(C(=O)NCC(NC2CCCCC2)=O)ccc1OC


In [41]:
#inf_logits = data["logits"].clone()
inf_logits

tensor([[[ -4.0747,  -7.4466,  -2.8203,  ...,   4.2903,   2.9706,   3.0404],
         [ -5.6927,  -5.5361,  -1.8578,  ...,   2.8094,  -2.2807,  -2.0903],
         [ -4.1017,  -8.4821,  -2.5741,  ...,   8.3017,   5.0653,   4.8057],
         ...,
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

        [[ -4.0747,  -7.4466,  -2.8203,  ...,   4.2903,   2.9706,   3.0404],
         [ -5.5144,  -9.1012,   1.2390,  ...,   2.5216,  -2.8495,  -3.3048],
         [ -6.1429,  -5.8778,   1.8811,  ...,   2.0452,  -3.1004,  -2.8386],
         ...,
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000],
         [  0.0000,   0.0000,   0.0000,  ...,   0.0000,   0.0000,   0.0000]],

        [[ -4.0747,  -7.4466,  -2.8203,  ...

In [32]:
og_log_prob = data["sample_log_prob"].clone()

In [36]:
og_log_prob * mask.sum(-1)

RuntimeError: The size of tensor a (100) must match the size of tensor b (10) at non-singleton dimension 1

#### Actor training

In [33]:
model = actor_training
mask = data.get("mask").squeeze(-1)
actions = data.get("action")
model_in = data.select(*model.in_keys, strict=False)
dist = model.get_dist(model_in)
log_prob = dist.log_prob(actions)
log_prob = (log_prob * mask).sum(-1)

In [44]:
dist.logits

tensor([[[-19.5038, -26.2475, -16.9949,  ...,  -2.7737,  -5.4131,  -5.2736],
         [-30.1359, -29.8229, -22.4661,  ..., -13.1318, -23.3120, -22.9311],
         [-26.6587, -35.4195, -23.6034,  ...,  -1.8519,  -8.3246,  -8.8439],
         ...,
         [  0.0000, -30.5823, -26.8065,  ..., -21.7027, -21.8613, -22.5407],
         [  0.0000, -30.5771, -26.8114,  ..., -21.6964, -21.8537, -22.5398],
         [  0.0000, -30.5720, -26.8164,  ..., -21.6901, -21.8462, -22.5389]],

        [[-19.5038, -26.2475, -16.9949,  ...,  -2.7737,  -5.4131,  -5.2736],
         [-20.6632, -27.8367,  -7.1562,  ...,  -4.5912, -15.3333, -16.2439],
         [-34.1759, -33.6457, -18.1279,  ..., -17.7996, -28.0910, -27.5672],
         ...,
         [  0.0000, -31.1820, -26.3252,  ..., -20.8977, -21.2366, -22.2002],
         [  0.0000, -31.1848, -26.3162,  ..., -20.8840, -21.2339, -22.2012],
         [  0.0000, -31.1876, -26.3076,  ..., -20.8703, -21.2310, -22.2020]],

        [[-19.5038, -26.2475, -16.9949,  ...