In [1]:
from copy import deepcopy
from typing import Callable, Union
from functools import partial

import torch
from tensordict.nn import TensorDictModule
from tensordict.nn.probabilistic import set_interaction_type as set_exploration_type
from tensordict.tensordict import TensorDictBase
from torchrl.collectors import RandomPolicy
from torchrl.envs import EnvBase, TensorDictPrimer, TransformedEnv, CatFrames, InitTracker, StepCounter, UnsqueezeTransform
from torchrl.envs.utils import ExplorationType, step_mdp


from acegen.rl_env.smiles_env import SMILESEnv
from acegen.vocabulary import SMILESVocabulary, SMILESTokenizer, SMILESTokenizer2
from acegen.data.utils import smiles_to_tensordict
from acegen.models import create_gru_actor, adapt_state_dict

from promptsmiles import ScaffoldDecorator, FragmentLinker
_has_promptsmiles = True

from acegen.rl_env.utils import generate_complete_smiles, get_log_prob

  from .autonotebook import tqdm as notebook_tqdm


In [1]:
from acegen.vocabulary import tokenizers

for tokenizer in [tokenizers.SMILESTokenizer, tokenizers.SMILESTokenizer2, tokenizers.DeepSMILESTokenizer, tokenizers.AISTokenizer, tokenizers.SAFETokenizer, tokenizers.SELFIESTokenizer]:
    print(tokenizer)
    smiles = "CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F"
    TOK = tokenizer()
    tokens = TOK.tokenize(smiles)
    print(tokens)
    untokens = TOK.untokenize(tokens)
    print(untokens)

  from .autonotebook import tqdm as notebook_tqdm


<class 'acegen.vocabulary.tokenizers.SMILESTokenizer'>
['C', 'C', '1', '=', 'C', '(', 'C', '(', '=', 'O', ')', 'N', '2', 'C', 'C', 'C', 'C', 'C', '2', '=', 'N', '1', ')', 'C', 'C', 'N', '3', 'C', 'C', 'C', '(', 'C', 'C', '3', ')', 'C', '4', '=', 'N', 'O', 'C', '5', '=', 'C', '4', 'C', '=', 'C', 'C', '(', '=', 'C', '5', ')', 'F']
CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F
<class 'acegen.vocabulary.tokenizers.SMILESTokenizer2'>
['C', 'C', '1', '=', 'C', '(', 'C', '(', '=', 'O', ')', 'N', '2', 'C', 'C', 'C', 'C', 'C', '2', '=', 'N', '1', ')', 'C', 'C', 'N', '3', 'C', 'C', 'C', '(', 'C', 'C', '3', ')', 'C', '4', '=', 'N', 'O', 'C', '5', '=', 'C', '4', 'C', '=', 'C', 'C', '(', '=', 'C', '5', ')', 'F']
CC1=C(C(=O)N2CCCCC2=N1)CCN3CCC(CC3)C4=NOC5=C4C=CC(=C5)F
<class 'acegen.vocabulary.tokenizers.DeepSMILESTokenizer'>
['C', 'C', '=', 'C', 'C', '=', 'O', ')', 'N', 'C', 'C', 'C', 'C', 'C', '6', '=', 'N', '%', '1', '0', ')', ')', ')', ')', ')', ')', ')', ')', 'C', 'C', 'N', 'C', 'C', 

In [3]:
output_data = generate_complete_smiles(
    environment=env,
    vocabulary=vocabulary,
    policy=actor_inference,
    prompt="c1ccccc",
    promptsmiles=None,#"N1(*)CCN(CC1)CCCCN(*)",
    promptsmiles_optimize=False,
    promptsmiles_shuffle=False,
    return_smiles_only=True
)
print(output_data)
#smiles = output_data.select("action").cpu()
#smiles_str = [vocabulary.decode(smi.numpy()) for smi in smiles["action"]]
#_ = [print(s) for s in smiles_str]

> [0;32m/shared/morgan/acegen-open/acegen/rl_env/utils.py[0m(140)[0;36mgenerate_complete_smiles[0;34m()[0m
[0;32m    138 [0;31m        [0;32mif[0m [0mprompt[0m[0;34m:[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    139 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m--> 140 [0;31m            [0;32mif[0m [0misinstance[0m[0;34m([0m[0mprompt[0m[0;34m,[0m [0mstr[0m[0;34m)[0m[0;34m:[0m [0mprompt[0m [0;34m=[0m [0;34m[[0m[0mprompt[0m[0;34m][0m[0;34m*[0m[0mbatch_size[0m[0;34m[[0m[0;36m0[0m[0;34m][0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    141 [0;31m            [0mtokens[0m [0;34m=[0m [0;34m[[0m[0mtorch[0m[0;34m.[0m[0mtensor[0m[0;34m([0m[0mvocabulary[0m[0;34m.[0m[0mencode[0m[0;34m([0m[0msmi[0m[0;34m,[0m [0mwith_end[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m)[0m [0;32mfor[0m [0msmi[0m [0;32min[0

ipdb>  c


['c1ccccc1n(-c1ccc(OC)cc1)c(=O)c1n2c(C)nn1)N', 'c1ccccc1oc(C)nn1)=O)cc2', 'c1ccccc1ccon1)=O)cc2', 'c1ccccc1n[nH]c(-c2cc([N+](=O)[O-])c(C)o2)n1)=O', 'c1ccccc1nccs1)(O)c1ccccc1)cc2', 'c1ccccc1nc2cccc(Cl)c2[nH]1)=NO', 'c1ccccc1ccccc1)=O)cc2', 'c1ccccc1nc2c(cc(N)cc2)[nH]1)=O', 'c1ccccc1[nH]ccn1)[O-])cc2', 'c1ccccc1[nH]c2cc(Br)ccc2n1)=O', 'c1ccccc1ccc(Cl)cc1)CO)cn2', 'c1ccccc1occc1)O)cc2NC', 'c1ccccc1ocnc1C)=O)cc2', 'c1ccccc1[nH]ncc1)CC2(C)C)C', 'c1ccccc1noc(-c2ccccc2)c1)=O', 'c1ccccc1ccc(F)cc1)=O)cc2', 'c1ccccc1cnc(-c2ccc(F)cc2)o1)O', 'c1ccccc1ccc2[nH]c(=O)oc2c1)=O)C', 'c1ccccc1ccccc1)=O)cn2', 'c1ccccc1nc2ccccc2[nH]1)=O', 'c1ccccc1c(F)cc(Cl)cc1)=O)nc2', 'c1ccccc1onc(C)c1)=O)cc2', 'c1ccccc1nonc1O)(C(F)(F)F)O)cc2', 'c1ccccc1ccccc1)=O)cc2', 'c1ccccc3cccc(F)c3)(N)CC1)s2', 'c1ccccc1nn(C)cc1)c1ccccc12)N', 'c1ccccc1ccccc1)=O)cc2', 'c1ccccc1ccccc1)=O)cc2', 'c1ccccc1[nH]cc(-c2ccccc2)n1)=O', 'c1ccccc1ccc(Cl)cc1)=O)nc2', 'c1ccccc1nc2nc(C)cc(C)n2n1)C#N', 'c1ccccc1ccc(OC)cc1F)=O)cc2', 'c1ccccc1sc2cc(CN

In [2]:
# Load prereqs
smiles = [
    'c1ccccc1',
    'c1ccccc1C(=O)C'
]

# Get available device
device = torch.device("cpu")
create_actor, voc_path, ckpt_path, tokenizer = (create_gru_actor,"priors/chembl_filtered_vocabulary.txt","priors/gru_chembl_filtered.ckpt", SMILESTokenizer())
# Load vocabulary
with open(voc_path, "r") as f:
        tokens = f.read().splitlines()
tokens_dict = dict(zip(tokens, range(len(tokens))))
vocabulary = SMILESVocabulary.create_from_dict(tokens_dict, start_token="GO", end_token="EOS", tokenizer=tokenizer)
# Create models
ckpt = torch.load(ckpt_path)
actor_training, actor_inference = create_actor(vocabulary_size=len(vocabulary))
actor_inference.load_state_dict(
    adapt_state_dict(ckpt, actor_inference.state_dict())
)
actor_training.load_state_dict(adapt_state_dict(ckpt, actor_training.state_dict()))
actor_inference = actor_inference.to(device)
actor_training = actor_training.to(device)
prior = deepcopy(actor_training)
 # Create RL environment
# For RNNs, create a transform to populate initial tensordict with recurrent states equal to 0.0
rhs_primers = []
if hasattr(actor_training, "rnn_spec"):
    primers = actor_training.rnn_spec.expand(64)
    rhs_primers.append(TensorDictPrimer(primers))
env_kwargs = {
    "start_token": vocabulary.start_token_index,
    "end_token": vocabulary.end_token_index,
    "length_vocabulary": len(vocabulary),
    "batch_size": 64,
    "device": device,
}
def create_env_fn():
    """Create a single RL rl_env."""
    env = SMILESEnv(**env_kwargs)
    env = TransformedEnv(env)
    env.append_transform(StepCounter())
    env.append_transform(InitTracker())
    for rhs_primer in rhs_primers:
        env.append_transform(rhs_primer)
    return env

env = create_env_fn()
#data = generate_complete_smiles(policy=actor_inference, environment=env)

In [17]:
@torch.no_grad()
def get_log_prob(smiles: list, policy: Union[TensorDictModule, Callable[[TensorDictBase], TensorDictBase]]):
    data = smiles_to_tensordict(smiles, mask_value=0)
    data.set('is_init', torch.zeros_like(data.get('done')))

    actions = data.get("action").clone()

    # For transformers-based policies
    data.set("sequence", data.get("observation"))

    policy_in = data.select(*policy.in_keys, strict=False)
    log_prob = policy.get_dist(policy_in).log_prob(actions)
    log_prob = log_prob.sum(-1)
    return log_prob

In [18]:
# TODO Pad and encode smiles
tokens = [torch.tensor(vocabulary.encode(smi)) for smi in smiles]
enc_smiles = torch.vstack([torch.nn.functional.pad(tok, (0, 100-tok.size()[0])) for tok in tokens])
get_log_prob(enc_smiles, actor_inference)

tensor([-55.1180, -81.1181])

In [None]:
torch.tensorz()

In [12]:
smiles_td = smiles_to_tensordict(enc_smiles, mask_value=0)
smiles_td

TensorDict(
    fields={
        action: Tensor(shape=torch.Size([2, 99]), device=cpu, dtype=torch.float32, is_shared=False),
        done: Tensor(shape=torch.Size([2, 99, 1]), device=cpu, dtype=torch.bool, is_shared=False),
        mask: Tensor(shape=torch.Size([2, 99]), device=cpu, dtype=torch.bool, is_shared=False),
        next: TensorDict(
            fields={
                done: Tensor(shape=torch.Size([2, 99, 1]), device=cpu, dtype=torch.bool, is_shared=False),
                observation: Tensor(shape=torch.Size([2, 99]), device=cpu, dtype=torch.int32, is_shared=False),
                reward: Tensor(shape=torch.Size([2, 99, 1]), device=cpu, dtype=torch.float32, is_shared=False),
                terminated: Tensor(shape=torch.Size([2, 99, 1]), device=cpu, dtype=torch.bool, is_shared=False)},
            batch_size=torch.Size([2, 99]),
            device=cpu,
            is_shared=False),
        observation: Tensor(shape=torch.Size([2, 99]), device=cpu, dtype=torch.int32, is_