In [None]:
import torch
import numpy as np
from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
from transformers import BertTokenizer
import argparse
from tqdm import tqdm
from vocab.tokenization import SMILESBPETokenizer
from model.GPT2ModelWithPreFixTuning import GPT2LMHeadMoelWithPrefixTuning
from pytorch_lightning import Trainer
import os
from scipy.spatial.transform import Rotation as R
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger
RDLogger.DisableLog('rdApp.*')

Load Checkpoint and tokenizer

In [None]:
# Locations of the serialized model checkpoint and of the tokenizer file.
checkpoint = "ckpt"
tokenizer_filename = "vocab/tokenizer.json"

# Build the tokenizer from the tokenizer file with model_max_length set to 256,
# then load the prefix-tuning GPT-2 LM-head model from the checkpoint location.
tokenizer = SMILESBPETokenizer.get_hf_tokenizer(
    tokenizer_filename, model_max_length=256)
model = GPT2LMHeadMoelWithPrefixTuning.from_pretrained(checkpoint)

加载 ECloud 数据（这里加载了 7 个）。


In [None]:
ecloud_rootdir = 'generation/eclouds'
# NOTE: sorted() on file names is a lexicographic (string) sort, not a numeric
# one, so e.g. "10.npy" would come before "2.npy".
ecloud_fnames = [
    os.path.join(ecloud_rootdir, entry)
    for entry in sorted(os.listdir(ecloud_rootdir))
]

# One 64x64x64 float16 grid per file, filled in file order.
n_eclouds = len(ecloud_fnames)
eclouds = np.zeros((n_eclouds, 64, 64, 64), dtype=np.float16)
print("Shape of Eclouds: ", eclouds.shape)
for idx, path in enumerate(tqdm(ecloud_fnames, total=n_eclouds)):
    eclouds[idx] = np.load(path)

# Hand the data to torch as float32.
eclouds = torch.from_numpy(eclouds).to(torch.float32)
print(eclouds[0].dtype)

准备生成时使用的 `get_attention_mask_for_generation` 和 `get_prompt` 函数。注意：`get_prompt` 依赖已加载的 `model`，需要先运行上面加载模型的单元。

In [None]:
def get_attention_mask_for_generation(prefix_len=128, bz=1):
    """Build the all-ones attention mask passed to generation.

    The mask spans the ``prefix_len`` prefix positions plus one extra
    position for the BOS token; every entry is 1.

    Args:
        prefix_len: Number of prefix positions covered by the mask.
        bz: Batch size (number of rows in the mask).

    Returns:
        An int64 tensor of ones with shape ``(bz, prefix_len + 1)``.
    """
    mask_shape = (bz, prefix_len + 1)
    return torch.ones(mask_shape, dtype=torch.long)

In [None]:
def get_prompt(ecloud):
    """Encode electron clouds into prefix key/values for generation.

    Runs ``model.prefix_encoder`` on ``ecloud`` and returns its output
    unchanged; the notebook passes that output to ``model.generate`` as
    ``past_key_values``. The global ``model`` must already be loaded.

    The forward pass runs under ``torch.no_grad()`` because the result is
    only used for sampling: without it, autograd would keep the encoder's
    intermediate activations alive for the whole batch.

    Args:
        ecloud: Batch of electron clouds accepted by ``model.prefix_encoder``.

    Returns:
        Whatever ``model.prefix_encoder(ecloud)`` returns (detached from
        the autograd graph).
    """
    with torch.no_grad():
        return model.prefix_encoder(ecloud)

In [None]:
# Encode all loaded electron clouds in a single batch.
prompts = get_prompt(eclouds)
print(len(prompts)) # number of entries in the prefix; the original note says 12 layers
print(prompts[0].shape) # original note: [2 (k&v), batch, num_heads, prefix_len, d_model_per_head], i.e. [2, 1, 12, 128, 64] for a batch of 1 -- TODO confirm

生成过程：根据输入的电子云生成 SMILES。

In [None]:
n_generated = 10
smiles_start = torch.LongTensor([[tokenizer.bos_token_id]])

# Batched sampling: each round draws one SMILES string per electron cloud, so
# after `n_generated` rounds every cloud owns `n_generated` candidates.
# (A sequential variant that called get_prompt on one cloud at a time used to
# be kept here as commented-out code.)
generated_smiles_list = [[] for _ in range(len(eclouds))]
attention_mask = get_attention_mask_for_generation(bz=len(eclouds))

# Sampling settings shared by every round.
sampling_kwargs = dict(
    max_length=512-129,
    top_k=50,
    top_p=0.96,
    repetition_penalty=0.8,
    temperature=0.9,
    do_sample=True,
    attention_mask=attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    past_key_values=prompts,
    num_return_sequences=1,
)

for _ in tqdm(range(n_generated), total=n_generated):
    # Every row of the batch starts from the single BOS token.
    generated_ids = model.generate(
        input_ids=smiles_start.repeat(len(eclouds), 1),
        **sampling_kwargs,
    )
    decoded_batch = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
    print(len(generated_smiles_list))
    # Row `idx` of the batch belongs to electron cloud `idx`.
    for idx, per_cloud_smiles in enumerate(generated_smiles_list):
        per_cloud_smiles.append(decoded_batch[idx])

这里加载参考用的 SDF 文件。

In [None]:
# Paths of the reference SDF files, in lexicographic file-name order.
sdf_root = 'generation/sdfs'
sdf_names = [
    os.path.join(sdf_root, entry)
    for entry in sorted(os.listdir(sdf_root))
]

In [None]:
# Display the SMILES strings generated for the first electron cloud.
generated_smiles_list[0]

In [None]:

# Turn every generated SMILES into a 3D structure and write it to its own SDF
# file, named after the reference SDF with the same index.
write_root = 'generation/gen_res'
os.makedirs(write_root, exist_ok=True)  # make sure the output directory exists

count = 0        # number of generated SMILES examined
valid_count = 0  # number that parsed, embedded and were written to disk
for i, temp_smiles_list in enumerate(generated_smiles_list):
    for j, smiles in enumerate(temp_smiles_list):
        count += 1
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            # RDKit could not parse the SMILES string.
            continue
        mol = Chem.AddHs(mol)
        # EmbedMolecule returns -1 when no conformer could be generated; such
        # molecules have no coordinates to write, so they are skipped (and not
        # counted as valid) instead of crashing the loop.
        if AllChem.EmbedMolecule(mol, randomSeed=42) < 0:
            continue
        # TODO: optionally align the embedded conformer to the reference
        # molecule before writing.
        # File-name prefix: reference SDF base name up to its first '.'.
        ref_stem = os.path.basename(sdf_names[i]).split('.')[0]
        out_path = os.path.join(write_root, ref_stem + f'_gen_res_{j}.sdf')
        writer = Chem.SDWriter(out_path)
        try:
            writer.write(mol)
        finally:
            # Closing flushes the record to disk and releases the file handle.
            writer.close()
        valid_count += 1
# Guard against an empty generation list (count == 0).
print("Validity: ", valid_count / count if count else float('nan'))

In [1]:
from rdkit import Chem
from data.utils import *
# Quick sanity check of the descriptor helpers on one example molecule:
# prints HBA, HBD, TPSA and MW (per the helper names), one per line.
mol = Chem.MolFromSmiles('CCCC(NC(=O)C(C)C)C(=O)Nc1ccc(C)cc1')
for descriptor_fn in (calculate_hba, calculate_hbd, calculate_tpsa, calculate_mw):
    print(descriptor_fn(mol))


2
2
58.2
276.38


In [4]:
import torch
# Scratch check: build a tensor from a one-element list with the legacy
# torch.FloatTensor constructor and display it.
torch.FloatTensor([2])

tensor([2.])