In [1]:
import os
os.chdir('../')

In [None]:
import pandas as pd
import torch
import os
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from dataset import SmileDataset, SmileCollator
from tokenizer import SmilesTokenizer
from model import GPTConfig, GPT
import time
# from test_connet import reconstruct
from fragment_utils import reconstruct
from tqdm import tqdm
from utils.train_utils import seed_all
from tdc import Oracle


def calculate_tanimoto_distance(fingerprint1, fingerprint2):
    """
    计算两个指纹之间的 Tanimoto 距离。
    """
    return 1 - DataStructs.TanimotoSimilarity(fingerprint1, fingerprint2)

def calculate_morgan_fingerprint(mol, radius=2, nBits=2048):
    """
    计算分子的 Morgan 指纹。
    Args:
        mol: RDKit 分子对象。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        分子指纹，或者如果分子无效则返回 None。
    """
    try:
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return fp
    except:
        return None



def calculate_diversity(molecules, radius=2, nBits=2048):
    """
    计算生成分子的多样性（平均成对 Tanimoto 距离）。
    Args:
        molecules: RDKit 分子对象的列表。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        多样性值。
    """
    fingerprints = []
    valid_molecules = []
    for mol in molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            fingerprints.append(fp)
            valid_molecules.append(mol)
    if not fingerprints:
        return 0.0  # 如果没有有效分子，返回 0.0
    n = len(fingerprints)
    total_distance = 0.0
    count = 0
    for i in range(n):
        for j in range(i + 1, n):
            distance = calculate_tanimoto_distance(fingerprints[i], fingerprints[j])
            total_distance += distance
            count += 1
    if count == 0:
        return 0.0
    return total_distance / count

def calculate_distance(generated_molecules, original_molecules, radius=2, nBits=2048):
    """
    计算生成分子与原始分子之间的平均 Tanimoto 距离。
    Args:
        generated_molecules: 生成的 RDKit 分子对象的列表。
        original_molecules: 原始 RDKit 分子对象的列表。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        平均距离值。
    """
    generated_fingerprints = []
    original_fingerprints = []
    # 计算生成分子的指纹
    for mol in generated_molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            generated_fingerprints.append(fp)
    # 计算原始分子的指纹
    for mol in original_molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            original_fingerprints.append(fp)
    if not generated_fingerprints or not original_fingerprints:
        return 0.0
    total_distance = 0.0
    count = 0
    for gen_fp in generated_fingerprints:
        for orig_fp in original_fingerprints:
            distance = calculate_tanimoto_distance(gen_fp, orig_fp)
            total_distance += distance
            count += 1
    if count == 0:
        return 0.0
    return total_distance / count


def cal_QED(smiles):
    oracle = Oracle(name = 'QED')
    return oracle(smiles)

def cal_SA(smiles):
    oracle = Oracle(name = 'SA')
    return oracle(smiles)

def cal_all(smiles):
    results = {}
    results['QED'] = cal_QED(smiles)
    results['SA'] = cal_SA(smiles)
    return results


def Test(model, smiles, tokenizer, max_seq_len, temperature, top_k, stream, rp, num_samples, kv_cache, is_simulation,
         device, scaffold=False, linker=False):
    complete_answer_list = []
    valid_answer_list = []
    model.eval()
    # place data on the correct device
    src_smiles = tokenizer.bos_token + smiles
    x = torch.tensor(tokenizer.encode(src_smiles, add_special_tokens=False), dtype=torch.long).unsqueeze(0)
    x = x.to(device)
    with torch.no_grad():
        res_y = model.generate(x, tokenizer, max_new_tokens=max_seq_len,
                               temperature=temperature, top_k=top_k, stream=stream, rp=rp, kv_cache=kv_cache,
                               is_simulation=is_simulation)
        try:
            y = next(res_y)
        except StopIteration:
            print("No answer")

        history_idx = 0
        complete_answer = f"{tokenizer.decode(x[0])}"  # 用于保存整个生成的句子

        while y != None:
            answer = tokenizer.decode(y[0].tolist())
            if answer and answer[-1] == '�':
                try:
                    y = next(res_y)
                except:
                    break
                continue
            if not len(answer):
                try:
                    y = next(res_y)
                except:
                    break
                continue

            # 保存生成的片段到完整回答中
            complete_answer += answer[history_idx:]

            try:
                y = next(res_y)
            except:
                break
            history_idx = len(answer)
            if not stream:
                break

        complete_answer = complete_answer.replace(" ", "").replace("[BOS]", "").replace("[EOS]", "")
        frag_list = complete_answer.replace(" ", "").split('[SEP]')
        try:
            if linker:
                last_frag = frag_list[0].split('.')[1]
                first_frag = frag_list[0].split('.')[0]
                frag_list[0] = first_frag
                frag_list[len(frag_list) - 1] = last_frag
            frag_mol = [Chem.MolFromSmiles(s) for s in frag_list]
            mol = reconstruct(frag_mol)[0]
            if type(mol) == list:
                mol = mol[0]
            if mol:
                generate_smiles = Chem.MolToSmiles(mol)
                valid_answer_list.append(generate_smiles)
                answer = frag_list
            else:
                answer = frag_list
        except:
            answer = frag_list
        complete_answer_list.append(answer)

    return complete_answer_list, valid_answer_list

def main_motif():
    motif_lst = ['*N1CC2(C[C@H]1C(=O)O)SCCS2[SEP]', '*C1Nc2cc(Cl)c(S(N)(=O)=O)cc2S(=O)(=O)N1[SEP]',
     '*[C@H]1O[C@@H](CO)[C@H](O)[C@@H]1O[SEP]', '*[C@H]1CCN(C(=O)C=C)C1[SEP]', '*C1(CC#N)CN(S(=O)(=O)CC)C1[SEP]',
     '*c1ccc2c(c1)OCCO2[SEP]', '*C[C@H](N)C(=O)O[SEP]',
     '*[C@@H]1[C@@H]2C(=C[C@H](C)C[C@@H]2OC(=O)[C@@H](C)CC)C=C[C@@H]1C[SEP]', '*n1c(Br)nnc1SCC(=O)O[SEP]', '*OCCOC[SEP]']
    original_smiles = ['CCOC(=O)[C@H](CCc1ccccc1)N[C@@H](C)C(=O)N1CC2(C[C@H]1C(=O)O)SCCS2',
     'NS(=O)(=O)c1cc2c(cc1Cl)NC(C1CC3C=CC1C3)NS2(=O)=O', 'CC(C)Nc1nc2cc(Cl)c(Cl)cc2n1[C@H]1O[C@@H](CO)[C@H](O)[C@@H]1O',
     'C=CC(=O)N1CC[C@H](n2nc(C#Cc3cc(OC)cc(OC)c3)c3c(N)ncnc32)C1',
     'CCS(=O)(=O)N1CC(CC#N)(n2cc(-c3ncnc4[nH]ccc34)cn2)C1', 'CCCCCCCC(=O)N[C@H](CN1CCCC1)[C@H](O)c1ccc2c(c1)OCCO2',
     'N[C@@H](Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1)C(=O)O',
     'CC[C@H](C)C(=O)O[C@H]1C[C@@H](C)C=C2C=C[C@H](C)[C@H](CC[C@@H]3C[C@@H](O)CC(=O)O3)[C@H]21',
     'O=C(O)CSc1nnc(Br)n1-c1ccc(C2CC2)c2ccccc12', 'C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1']

    # 设置随机种子的值
    seed_value = 42
    seed_all(seed_value)
    # device = torch.device(f'cuda:{0}')  # 逻辑编号 cuda:0 对应 os.environ["CUDA_VISIBLE_DEVICES"]中的第一个gpu
    device = 'cuda:8'
    batch_size = 1

    test_names = "test"

    tokenizer = SmilesTokenizer('./vocabs/vocab.txt')
    tokenizer.bos_token = "[BOS]"
    tokenizer.bos_token_id = tokenizer.convert_tokens_to_ids("[BOS]")
    tokenizer.eos_token = "[EOS]"
    tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("[EOS]")

    mconf = GPTConfig(vocab_size=tokenizer.vocab_size, n_layer=12, n_head=12, n_embd=768)
    model = GPT(mconf).to(device)
    checkpoint = torch.load(f'./weights/fragpt.pt', weights_only=True)
    model.load_state_dict(checkpoint)

    start_time = time.time()
    valid_ratio_sum = 0
    uniqueness_sum = 0
    quality_sum = 0
    sa_sum = 0
    qed_sum = 0
    div_sum = 0
    dist_sum = 0
    for i in motif_lst:
        complete_answer_list, valid_answer_list = [], []
        for j in tqdm(range(100)):
            l1, l2 = Test(model, i, tokenizer, max_seq_len=512, temperature=1.2, top_k=8, stream=False, rp=1., num_samples=1,
                 kv_cache=True, is_simulation=True, device=device)
            if (len(l2) != 0):
                valid_answer_list.append(l2[0])
            if (len(l1) != 0):
                complete_answer_list.append(l1[0])
        unique_smiles = set(smile for smile in valid_answer_list if smile is not None)
        unique_smiles_lst = list(unique_smiles)
        num_unique_molecules = len(unique_smiles)
        uniqueness = num_unique_molecules / len(valid_answer_list)
        valid_ratio = len(valid_answer_list) / 100
        results = cal_all(unique_smiles_lst)
        SA_score = 0
        QED_score = 0
        sum = 0
        for k in range(len(unique_smiles_lst)):
            SA_score += results['SA'][k]
            QED_score += results['QED'][k]
            if (results['QED'][k] >= 0.6 and results['SA'][k] <= 4):
                sum += 1

        generated_molecules = [Chem.MolFromSmiles(s) for s in valid_answer_list]
        original_molecules = [Chem.MolFromSmiles(s) for s in original_smiles]
        # 计算多样性
        diversity = calculate_diversity(generated_molecules)
        # 计算距离
        distance = calculate_distance(generated_molecules, original_molecules)

        print('valid_ratio:', valid_ratio, 'uniqueness:', uniqueness, 'Quality:', sum / 100, 'SA:',
              SA_score / len(unique_smiles_lst), 'QED:', QED_score / len(unique_smiles_lst), 'diversity:', diversity,
              'distance:', distance)
        valid_ratio_sum += valid_ratio
        uniqueness_sum += uniqueness
        quality_sum += sum / len(unique_smiles_lst)
        sa_sum += SA_score / len(unique_smiles_lst)
        qed_sum += QED_score / len(unique_smiles_lst)
        div_sum += diversity
        dist_sum += distance
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"运行时间: {elapsed_time:.4f} 秒")
    print(f"valid_ratio_avg: {valid_ratio_sum / len(motif_lst)}, uniqueness_avg: {uniqueness_sum / len(motif_lst)}, "
          f"quality_avg: {quality_sum / len(motif_lst)}, sa_avg: {sa_sum / len(motif_lst)}, "
          f"qed_avg: {qed_sum / len(motif_lst)}, div_avg: {div_sum / len(motif_lst)}, dist_avg: {dist_sum / len(motif_lst)}")



if __name__ == '__main__':

    main_motif()


  1%|█▋                                                                                                                                                                 | 1/100 [00:00<00:48,  2.03it/s][18:33:44] SMILES Parse Error: ring closure 2 duplicates bond between atom 4 and atom 5 for input: '*c1c[nH]n2c2c(F)cccc12'
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:29<00:00,  3.43it/s]
Found local copy...


valid_ratio: 0.99 uniqueness: 0.8484848484848485 Quality: 0.41 SA: 4.082602312947524 QED: 0.7548358677335902 diversity: 0.5818713929911064 distance: 0.8770998645193947


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00,  5.76it/s]


valid_ratio: 1.0 uniqueness: 0.63 Quality: 0.32 SA: 3.5770179738832515 QED: 0.5876661552588345 diversity: 0.4314305540518022 distance: 0.8798861931212343


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:30<00:00,  3.24it/s]


valid_ratio: 0.98 uniqueness: 0.9489795918367347 Quality: 0.22 SA: 4.073123692669797 QED: 0.48954089654286553 diversity: 0.7215945300918065 distance: 0.8996018386891813


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:46<00:00,  2.17it/s]


valid_ratio: 0.95 uniqueness: 0.9789473684210527 Quality: 0.29 SA: 3.824050751402876 QED: 0.5663470609936114 diversity: 0.7254525779134393 distance: 0.8864769578345574


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.85it/s]


valid_ratio: 1.0 uniqueness: 0.28 Quality: 0.21 SA: 3.426947587413814 QED: 0.6441973739660686 diversity: 0.266586589768544 distance: 0.9038154540013067


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00,  2.99it/s]


valid_ratio: 0.97 uniqueness: 1.0 Quality: 0.84 SA: 2.911012857556073 QED: 0.7610462211755915 diversity: 0.7199658560547575 distance: 0.9044153334228633


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:27<00:00,  3.59it/s]


valid_ratio: 0.99 uniqueness: 0.98989898989899 Quality: 0.59 SA: 3.2409333606627007 QED: 0.6192221836286511 diversity: 0.741413627654963 distance: 0.8830181443746858


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:34<00:00,  2.94it/s]


valid_ratio: 0.98 uniqueness: 0.9387755102040817 Quality: 0.0 SA: 4.881448244351062 QED: 0.40276351752939127 diversity: 0.47239707635483164 distance: 0.8481383638570643


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:18<00:00,  5.34it/s]


valid_ratio: 1.0 uniqueness: 0.5 Quality: 0.3 SA: 2.9627928098030543 QED: 0.6241576441950124 diversity: 0.48271744221120644 distance: 0.8685142773081145


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:53<00:00,  1.88it/s]


valid_ratio: 0.91 uniqueness: 1.0 Quality: 0.16 SA: 3.6138539754333867 QED: 0.4124903346694743 diversity: 0.8308326803379792 distance: 0.9000781222485721
运行时间: 303.6905 秒
valid_ratio_avg: 0.977, uniqueness_avg: 0.8115086308845708, quality_avg: 0.4538263216399945, sa_avg: 3.659378356612354, qed_avg: 0.586226725569309, div_avg: 0.5974262327430436, dist_avg: 0.8851044549376974


In [None]:
import pandas as pd
import torch
import os
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from dataset import SmileDataset, SmileCollator
from tokenizer import SmilesTokenizer
from model import GPTConfig, GPT
import time
# from test_connet import reconstruct
from fragment_utils import reconstruct
from tqdm import tqdm
from utils.train_utils import seed_all
from tdc import Oracle


def calculate_tanimoto_distance(fingerprint1, fingerprint2):
    """
    计算两个指纹之间的 Tanimoto 距离。
    """
    return 1 - DataStructs.TanimotoSimilarity(fingerprint1, fingerprint2)

def calculate_morgan_fingerprint(mol, radius=2, nBits=2048):
    """
    计算分子的 Morgan 指纹。
    Args:
        mol: RDKit 分子对象。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        分子指纹，或者如果分子无效则返回 None。
    """
    try:
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return fp
    except:
        return None



def calculate_diversity(molecules, radius=2, nBits=2048):
    """
    计算生成分子的多样性（平均成对 Tanimoto 距离）。
    Args:
        molecules: RDKit 分子对象的列表。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        多样性值。
    """
    fingerprints = []
    valid_molecules = []
    for mol in molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            fingerprints.append(fp)
            valid_molecules.append(mol)
    if not fingerprints:
        return 0.0  # 如果没有有效分子，返回 0.0
    n = len(fingerprints)
    total_distance = 0.0
    count = 0
    for i in range(n):
        for j in range(i + 1, n):
            distance = calculate_tanimoto_distance(fingerprints[i], fingerprints[j])
            total_distance += distance
            count += 1
    if count == 0:
        return 0.0
    return total_distance / count

def calculate_distance(generated_molecules, original_molecules, radius=2, nBits=2048):
    """
    计算生成分子与原始分子之间的平均 Tanimoto 距离。
    Args:
        generated_molecules: 生成的 RDKit 分子对象的列表。
        original_molecules: 原始 RDKit 分子对象的列表。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        平均距离值。
    """
    generated_fingerprints = []
    original_fingerprints = []
    # 计算生成分子的指纹
    for mol in generated_molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            generated_fingerprints.append(fp)
    # 计算原始分子的指纹
    for mol in original_molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            original_fingerprints.append(fp)
    if not generated_fingerprints or not original_fingerprints:
        return 0.0
    total_distance = 0.0
    count = 0
    for gen_fp in generated_fingerprints:
        for orig_fp in original_fingerprints:
            distance = calculate_tanimoto_distance(gen_fp, orig_fp)
            total_distance += distance
            count += 1
    if count == 0:
        return 0.0
    return total_distance / count


def cal_QED(smiles):
    oracle = Oracle(name = 'QED')
    return oracle(smiles)

def cal_SA(smiles):
    oracle = Oracle(name = 'SA')
    return oracle(smiles)

def cal_all(smiles):
    results = {}
    results['QED'] = cal_QED(smiles)
    results['SA'] = cal_SA(smiles)
    return results


def Test(model, smiles, tokenizer, max_seq_len, temperature, top_k, stream, rp, num_samples, kv_cache, is_simulation,
         device, scaffold=False, linker=False):
    complete_answer_list = []
    valid_answer_list = []
    model.eval()
    # place data on the correct device
    src_smiles = tokenizer.bos_token + smiles
    x = torch.tensor(tokenizer.encode(src_smiles, add_special_tokens=False), dtype=torch.long).unsqueeze(0)
    x = x.to(device)
    with torch.no_grad():
        res_y = model.generate(x, tokenizer, max_new_tokens=max_seq_len,
                               temperature=temperature, top_k=top_k, stream=stream, rp=rp, kv_cache=kv_cache,
                               is_simulation=is_simulation)
        try:
            y = next(res_y)
        except StopIteration:
            print("No answer")

        history_idx = 0
        complete_answer = f"{tokenizer.decode(x[0])}"  # 用于保存整个生成的句子

        while y != None:
            answer = tokenizer.decode(y[0].tolist())
            if answer and answer[-1] == '�':
                try:
                    y = next(res_y)
                except:
                    break
                continue
            if not len(answer):
                try:
                    y = next(res_y)
                except:
                    break
                continue

            # 保存生成的片段到完整回答中
            complete_answer += answer[history_idx:]

            try:
                y = next(res_y)
            except:
                break
            history_idx = len(answer)
            if not stream:
                break

        complete_answer = complete_answer.replace(" ", "").replace("[BOS]", "").replace("[EOS]", "")
        frag_list = complete_answer.replace(" ", "").split('[SEP]')
        try:
            if linker:
                last_frag = frag_list[0].split('.')[1]
                first_frag = frag_list[0].split('.')[0]
                frag_list[0] = first_frag
                frag_list[len(frag_list) - 1] = last_frag
            frag_mol = [Chem.MolFromSmiles(s) for s in frag_list]
            mol = reconstruct(frag_mol)[0]
            if type(mol) == list:
                mol = mol[0]
            if mol:
                generate_smiles = Chem.MolToSmiles(mol)
                valid_answer_list.append(generate_smiles)
                answer = frag_list
            else:
                answer = frag_list
        except:
            answer = frag_list
        complete_answer_list.append(answer)

    return complete_answer_list, valid_answer_list

def main_motif():
    motif_lst = ['*N1CC2(C[C@H]1C(=O)O)SCCS2[SEP]', '*C1Nc2cc(Cl)c(S(N)(=O)=O)cc2S(=O)(=O)N1[SEP]',
     '*[C@H]1O[C@@H](CO)[C@H](O)[C@@H]1O[SEP]', '*[C@H]1CCN(C(=O)C=C)C1[SEP]', '*C1(CC#N)CN(S(=O)(=O)CC)C1[SEP]',
     '*c1ccc2c(c1)OCCO2[SEP]', '*C[C@H](N)C(=O)O[SEP]',
     '*[C@@H]1[C@@H]2C(=C[C@H](C)C[C@@H]2OC(=O)[C@@H](C)CC)C=C[C@@H]1C[SEP]', '*n1c(Br)nnc1SCC(=O)O[SEP]', '*OCCOC[SEP]']
    original_smiles = ['CCOC(=O)[C@H](CCc1ccccc1)N[C@@H](C)C(=O)N1CC2(C[C@H]1C(=O)O)SCCS2',
     'NS(=O)(=O)c1cc2c(cc1Cl)NC(C1CC3C=CC1C3)NS2(=O)=O', 'CC(C)Nc1nc2cc(Cl)c(Cl)cc2n1[C@H]1O[C@@H](CO)[C@H](O)[C@@H]1O',
     'C=CC(=O)N1CC[C@H](n2nc(C#Cc3cc(OC)cc(OC)c3)c3c(N)ncnc32)C1',
     'CCS(=O)(=O)N1CC(CC#N)(n2cc(-c3ncnc4[nH]ccc34)cn2)C1', 'CCCCCCCC(=O)N[C@H](CN1CCCC1)[C@H](O)c1ccc2c(c1)OCCO2',
     'N[C@@H](Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1)C(=O)O',
     'CC[C@H](C)C(=O)O[C@H]1C[C@@H](C)C=C2C=C[C@H](C)[C@H](CC[C@@H]3C[C@@H](O)CC(=O)O3)[C@H]21',
     'O=C(O)CSc1nnc(Br)n1-c1ccc(C2CC2)c2ccccc12', 'C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1']

    # 设置随机种子的值
    seed_value = 43
    seed_all(seed_value)
    # device = torch.device(f'cuda:{0}')  # 逻辑编号 cuda:0 对应 os.environ["CUDA_VISIBLE_DEVICES"]中的第一个gpu
    device = 'cuda:8'
    batch_size = 1

    test_names = "test"

    tokenizer = SmilesTokenizer('./vocabs/vocab.txt')
    tokenizer.bos_token = "[BOS]"
    tokenizer.bos_token_id = tokenizer.convert_tokens_to_ids("[BOS]")
    tokenizer.eos_token = "[EOS]"
    tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("[EOS]")

    mconf = GPTConfig(vocab_size=tokenizer.vocab_size, n_layer=12, n_head=12, n_embd=768)
    model = GPT(mconf).to(device)
    checkpoint = torch.load(f'./weights/fragpt.pt', weights_only=True)
    model.load_state_dict(checkpoint)

    start_time = time.time()
    valid_ratio_sum = 0
    uniqueness_sum = 0
    quality_sum = 0
    sa_sum = 0
    qed_sum = 0
    div_sum = 0
    dist_sum = 0
    for i in motif_lst:
        complete_answer_list, valid_answer_list = [], []
        for j in tqdm(range(100)):
            l1, l2 = Test(model, i, tokenizer, max_seq_len=512, temperature=1.2, top_k=8, stream=False, rp=1., num_samples=1,
                 kv_cache=True, is_simulation=True, device=device)
            if (len(l2) != 0):
                valid_answer_list.append(l2[0])
            if (len(l1) != 0):
                complete_answer_list.append(l1[0])
        unique_smiles = set(smile for smile in valid_answer_list if smile is not None)
        unique_smiles_lst = list(unique_smiles)
        num_unique_molecules = len(unique_smiles)
        uniqueness = num_unique_molecules / len(valid_answer_list)
        valid_ratio = len(valid_answer_list) / 100
        results = cal_all(unique_smiles_lst)
        SA_score = 0
        QED_score = 0
        sum = 0
        for k in range(len(unique_smiles_lst)):
            SA_score += results['SA'][k]
            QED_score += results['QED'][k]
            if (results['QED'][k] >= 0.6 and results['SA'][k] <= 4):
                sum += 1

        generated_molecules = [Chem.MolFromSmiles(s) for s in valid_answer_list]
        original_molecules = [Chem.MolFromSmiles(s) for s in original_smiles]
        # 计算多样性
        diversity = calculate_diversity(generated_molecules)
        # 计算距离
        distance = calculate_distance(generated_molecules, original_molecules)

        print('valid_ratio:', valid_ratio, 'uniqueness:', uniqueness, 'Quality:', sum / 100, 'SA:',
              SA_score / len(unique_smiles_lst), 'QED:', QED_score / len(unique_smiles_lst), 'diversity:', diversity,
              'distance:', distance)
        valid_ratio_sum += valid_ratio
        uniqueness_sum += uniqueness
        quality_sum += sum / len(unique_smiles_lst)
        sa_sum += SA_score / len(unique_smiles_lst)
        qed_sum += QED_score / len(unique_smiles_lst)
        div_sum += diversity
        dist_sum += distance
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"运行时间: {elapsed_time:.4f} 秒")
    print(f"valid_ratio_avg: {valid_ratio_sum / len(motif_lst)}, uniqueness_avg: {uniqueness_sum / len(motif_lst)}, "
          f"quality_avg: {quality_sum / len(motif_lst)}, sa_avg: {sa_sum / len(motif_lst)}, "
          f"qed_avg: {qed_sum / len(motif_lst)}, div_avg: {div_sum / len(motif_lst)}, dist_avg: {dist_sum / len(motif_lst)}")



if __name__ == '__main__':

    main_motif()


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:27<00:00,  3.70it/s]


valid_ratio: 0.98 uniqueness: 0.7653061224489796 Quality: 0.37 SA: 4.0471873052531775 QED: 0.7545465261416803 diversity: 0.5633223096235295 distance: 0.8755049484706365


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:19<00:00,  5.13it/s]


valid_ratio: 1.0 uniqueness: 0.75 Quality: 0.33 SA: 3.6392827088843402 QED: 0.5561704435118435 diversity: 0.4691635497757367 distance: 0.8790862775417675


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:34<00:00,  2.87it/s]


valid_ratio: 0.95 uniqueness: 0.9368421052631579 Quality: 0.16 SA: 4.167056595675648 QED: 0.43220426643167703 diversity: 0.7238316485244194 distance: 0.9004734953375509


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:44<00:00,  2.24it/s]


valid_ratio: 0.97 uniqueness: 0.9896907216494846 Quality: 0.33 SA: 3.790744183686355 QED: 0.5666686542736127 diversity: 0.7204393772245676 distance: 0.8851340754706231


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:14<00:00,  7.06it/s]


valid_ratio: 1.0 uniqueness: 0.29 Quality: 0.17 SA: 3.5619251530493647 QED: 0.5910098977371305 diversity: 0.32714394711850625 distance: 0.9019588589778192


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:30<00:00,  3.24it/s]


valid_ratio: 0.96 uniqueness: 1.0 Quality: 0.75 SA: 2.937193663879391 QED: 0.7333701140492966 diversity: 0.7308327962811164 distance: 0.9061918642383304


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:29<00:00,  3.33it/s]


valid_ratio: 0.98 uniqueness: 1.0 Quality: 0.6 SA: 3.2830916427822614 QED: 0.6376612707650426 diversity: 0.748315603562503 distance: 0.8809779597420602


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:32<00:00,  3.10it/s]


valid_ratio: 0.99 uniqueness: 0.9393939393939394 Quality: 0.0 SA: 4.779432262350212 QED: 0.41672534977475667 diversity: 0.4424315743066368 distance: 0.8460635319532098


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:16<00:00,  6.10it/s]


valid_ratio: 1.0 uniqueness: 0.53 Quality: 0.39 SA: 2.76748766119774 QED: 0.7265237963365047 diversity: 0.474277436289222 distance: 0.8669336765950608


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:51<00:00,  1.92it/s]


valid_ratio: 0.88 uniqueness: 1.0 Quality: 0.3 SA: 3.5823658020208713 QED: 0.46680602974331475 diversity: 0.825303516549678 distance: 0.89642818776501
运行时间: 304.0237 秒
valid_ratio_avg: 0.9710000000000001, uniqueness_avg: 0.8201232888755563, quality_avg: 0.4513318556255982, sa_avg: 3.655576697877936, qed_avg: 0.5881686348764859, div_avg: 0.6025061759255916, dist_avg: 0.8838752876092066


In [None]:
import pandas as pd
import torch
import os
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from dataset import SmileDataset, SmileCollator
from tokenizer import SmilesTokenizer
from model import GPTConfig, GPT
import time
# from test_connet import reconstruct
from fragment_utils import reconstruct
from tqdm import tqdm
from utils.train_utils import seed_all
from tdc import Oracle


def calculate_tanimoto_distance(fingerprint1, fingerprint2):
    """
    计算两个指纹之间的 Tanimoto 距离。
    """
    return 1 - DataStructs.TanimotoSimilarity(fingerprint1, fingerprint2)

def calculate_morgan_fingerprint(mol, radius=2, nBits=2048):
    """
    计算分子的 Morgan 指纹。
    Args:
        mol: RDKit 分子对象。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        分子指纹，或者如果分子无效则返回 None。
    """
    try:
        fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nBits)
        return fp
    except:
        return None



def calculate_diversity(molecules, radius=2, nBits=2048):
    """
    计算生成分子的多样性（平均成对 Tanimoto 距离）。
    Args:
        molecules: RDKit 分子对象的列表。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        多样性值。
    """
    fingerprints = []
    valid_molecules = []
    for mol in molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            fingerprints.append(fp)
            valid_molecules.append(mol)
    if not fingerprints:
        return 0.0  # 如果没有有效分子，返回 0.0
    n = len(fingerprints)
    total_distance = 0.0
    count = 0
    for i in range(n):
        for j in range(i + 1, n):
            distance = calculate_tanimoto_distance(fingerprints[i], fingerprints[j])
            total_distance += distance
            count += 1
    if count == 0:
        return 0.0
    return total_distance / count

def calculate_distance(generated_molecules, original_molecules, radius=2, nBits=2048):
    """
    计算生成分子与原始分子之间的平均 Tanimoto 距离。
    Args:
        generated_molecules: 生成的 RDKit 分子对象的列表。
        original_molecules: 原始 RDKit 分子对象的列表。
        radius: Morgan 指纹的半径。
        nBits: 指纹的位数。
    Returns:
        平均距离值。
    """
    generated_fingerprints = []
    original_fingerprints = []
    # 计算生成分子的指纹
    for mol in generated_molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            generated_fingerprints.append(fp)
    # 计算原始分子的指纹
    for mol in original_molecules:
        fp = calculate_morgan_fingerprint(mol, radius, nBits)
        if fp is not None:
            original_fingerprints.append(fp)
    if not generated_fingerprints or not original_fingerprints:
        return 0.0
    total_distance = 0.0
    count = 0
    for gen_fp in generated_fingerprints:
        for orig_fp in original_fingerprints:
            distance = calculate_tanimoto_distance(gen_fp, orig_fp)
            total_distance += distance
            count += 1
    if count == 0:
        return 0.0
    return total_distance / count


def cal_QED(smiles):
    oracle = Oracle(name = 'QED')
    return oracle(smiles)

def cal_SA(smiles):
    oracle = Oracle(name = 'SA')
    return oracle(smiles)

def cal_all(smiles):
    results = {}
    results['QED'] = cal_QED(smiles)
    results['SA'] = cal_SA(smiles)
    return results


def Test(model, smiles, tokenizer, max_seq_len, temperature, top_k, stream, rp, num_samples, kv_cache, is_simulation,
         device, scaffold=False, linker=False):
    complete_answer_list = []
    valid_answer_list = []
    model.eval()
    # place data on the correct device
    src_smiles = tokenizer.bos_token + smiles
    x = torch.tensor(tokenizer.encode(src_smiles, add_special_tokens=False), dtype=torch.long).unsqueeze(0)
    x = x.to(device)
    with torch.no_grad():
        res_y = model.generate(x, tokenizer, max_new_tokens=max_seq_len,
                               temperature=temperature, top_k=top_k, stream=stream, rp=rp, kv_cache=kv_cache,
                               is_simulation=is_simulation)
        try:
            y = next(res_y)
        except StopIteration:
            print("No answer")

        history_idx = 0
        complete_answer = f"{tokenizer.decode(x[0])}"  # 用于保存整个生成的句子

        while y != None:
            answer = tokenizer.decode(y[0].tolist())
            if answer and answer[-1] == '�':
                try:
                    y = next(res_y)
                except:
                    break
                continue
            if not len(answer):
                try:
                    y = next(res_y)
                except:
                    break
                continue

            # 保存生成的片段到完整回答中
            complete_answer += answer[history_idx:]

            try:
                y = next(res_y)
            except:
                break
            history_idx = len(answer)
            if not stream:
                break

        complete_answer = complete_answer.replace(" ", "").replace("[BOS]", "").replace("[EOS]", "")
        frag_list = complete_answer.replace(" ", "").split('[SEP]')
        try:
            if linker:
                last_frag = frag_list[0].split('.')[1]
                first_frag = frag_list[0].split('.')[0]
                frag_list[0] = first_frag
                frag_list[len(frag_list) - 1] = last_frag
            frag_mol = [Chem.MolFromSmiles(s) for s in frag_list]
            mol = reconstruct(frag_mol)[0]
            if type(mol) == list:
                mol = mol[0]
            if mol:
                generate_smiles = Chem.MolToSmiles(mol)
                valid_answer_list.append(generate_smiles)
                answer = frag_list
            else:
                answer = frag_list
        except:
            answer = frag_list
        complete_answer_list.append(answer)

    return complete_answer_list, valid_answer_list

def main_motif():
    motif_lst = ['*N1CC2(C[C@H]1C(=O)O)SCCS2[SEP]', '*C1Nc2cc(Cl)c(S(N)(=O)=O)cc2S(=O)(=O)N1[SEP]',
     '*[C@H]1O[C@@H](CO)[C@H](O)[C@@H]1O[SEP]', '*[C@H]1CCN(C(=O)C=C)C1[SEP]', '*C1(CC#N)CN(S(=O)(=O)CC)C1[SEP]',
     '*c1ccc2c(c1)OCCO2[SEP]', '*C[C@H](N)C(=O)O[SEP]',
     '*[C@@H]1[C@@H]2C(=C[C@H](C)C[C@@H]2OC(=O)[C@@H](C)CC)C=C[C@@H]1C[SEP]', '*n1c(Br)nnc1SCC(=O)O[SEP]', '*OCCOC[SEP]']
    original_smiles = ['CCOC(=O)[C@H](CCc1ccccc1)N[C@@H](C)C(=O)N1CC2(C[C@H]1C(=O)O)SCCS2',
     'NS(=O)(=O)c1cc2c(cc1Cl)NC(C1CC3C=CC1C3)NS2(=O)=O', 'CC(C)Nc1nc2cc(Cl)c(Cl)cc2n1[C@H]1O[C@@H](CO)[C@H](O)[C@@H]1O',
     'C=CC(=O)N1CC[C@H](n2nc(C#Cc3cc(OC)cc(OC)c3)c3c(N)ncnc32)C1',
     'CCS(=O)(=O)N1CC(CC#N)(n2cc(-c3ncnc4[nH]ccc34)cn2)C1', 'CCCCCCCC(=O)N[C@H](CN1CCCC1)[C@H](O)c1ccc2c(c1)OCCO2',
     'N[C@@H](Cc1cc(I)c(Oc2ccc(O)c(I)c2)c(I)c1)C(=O)O',
     'CC[C@H](C)C(=O)O[C@H]1C[C@@H](C)C=C2C=C[C@H](C)[C@H](CC[C@@H]3C[C@@H](O)CC(=O)O3)[C@H]21',
     'O=C(O)CSc1nnc(Br)n1-c1ccc(C2CC2)c2ccccc12', 'C#Cc1cccc(Nc2ncnc3cc(OCCOC)c(OCCOC)cc23)c1']

    # 设置随机种子的值
    seed_value = 44
    seed_all(seed_value)
    # device = torch.device(f'cuda:{0}')  # 逻辑编号 cuda:0 对应 os.environ["CUDA_VISIBLE_DEVICES"]中的第一个gpu
    device = 'cuda:8'
    batch_size = 1

    test_names = "test"

    tokenizer = SmilesTokenizer('./vocabs/vocab.txt')
    tokenizer.bos_token = "[BOS]"
    tokenizer.bos_token_id = tokenizer.convert_tokens_to_ids("[BOS]")
    tokenizer.eos_token = "[EOS]"
    tokenizer.eos_token_id = tokenizer.convert_tokens_to_ids("[EOS]")

    mconf = GPTConfig(vocab_size=tokenizer.vocab_size, n_layer=12, n_head=12, n_embd=768)
    model = GPT(mconf).to(device)
    checkpoint = torch.load(f'./weights/fragpt.pt', weights_only=True)
    model.load_state_dict(checkpoint)

    start_time = time.time()
    valid_ratio_sum = 0
    uniqueness_sum = 0
    quality_sum = 0
    sa_sum = 0
    qed_sum = 0
    div_sum = 0
    dist_sum = 0
    for i in motif_lst:
        complete_answer_list, valid_answer_list = [], []
        for j in tqdm(range(100)):
            l1, l2 = Test(model, i, tokenizer, max_seq_len=512, temperature=1.2, top_k=8, stream=False, rp=1., num_samples=1,
                 kv_cache=True, is_simulation=True, device=device)
            if (len(l2) != 0):
                valid_answer_list.append(l2[0])
            if (len(l1) != 0):
                complete_answer_list.append(l1[0])
        unique_smiles = set(smile for smile in valid_answer_list if smile is not None)
        unique_smiles_lst = list(unique_smiles)
        num_unique_molecules = len(unique_smiles)
        uniqueness = num_unique_molecules / len(valid_answer_list)
        valid_ratio = len(valid_answer_list) / 100
        results = cal_all(unique_smiles_lst)
        SA_score = 0
        QED_score = 0
        sum = 0
        for k in range(len(unique_smiles_lst)):
            SA_score += results['SA'][k]
            QED_score += results['QED'][k]
            if (results['QED'][k] >= 0.6 and results['SA'][k] <= 4):
                sum += 1

        generated_molecules = [Chem.MolFromSmiles(s) for s in valid_answer_list]
        original_molecules = [Chem.MolFromSmiles(s) for s in original_smiles]
        # 计算多样性
        diversity = calculate_diversity(generated_molecules)
        # 计算距离
        distance = calculate_distance(generated_molecules, original_molecules)

        print('valid_ratio:', valid_ratio, 'uniqueness:', uniqueness, 'Quality:', sum / 100, 'SA:',
              SA_score / len(unique_smiles_lst), 'QED:', QED_score / len(unique_smiles_lst), 'diversity:', diversity,
              'distance:', distance)
        valid_ratio_sum += valid_ratio
        uniqueness_sum += uniqueness
        quality_sum += sum / len(unique_smiles_lst)
        sa_sum += SA_score / len(unique_smiles_lst)
        qed_sum += QED_score / len(unique_smiles_lst)
        div_sum += diversity
        dist_sum += distance
    end_time = time.time()
    elapsed_time = end_time - start_time

    print(f"运行时间: {elapsed_time:.4f} 秒")
    print(f"valid_ratio_avg: {valid_ratio_sum / len(motif_lst)}, uniqueness_avg: {uniqueness_sum / len(motif_lst)}, "
          f"quality_avg: {quality_sum / len(motif_lst)}, sa_avg: {sa_sum / len(motif_lst)}, "
          f"qed_avg: {qed_sum / len(motif_lst)}, div_avg: {div_sum / len(motif_lst)}, dist_avg: {dist_sum / len(motif_lst)}")



if __name__ == '__main__':

    main_motif()


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:26<00:00,  3.71it/s]


valid_ratio: 0.99 uniqueness: 0.797979797979798 Quality: 0.44 SA: 4.02953527212658 QED: 0.7706266673604354 diversity: 0.5665239052519188 distance: 0.8765397306797168


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:18<00:00,  5.26it/s]


valid_ratio: 0.99 uniqueness: 0.5858585858585859 Quality: 0.29 SA: 3.692897223129867 QED: 0.5564129098905336 diversity: 0.4522680830612557 distance: 0.8796205112529146


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:37<00:00,  2.70it/s]


valid_ratio: 0.95 uniqueness: 0.9894736842105263 Quality: 0.2 SA: 4.2302249966876575 QED: 0.44244132125298463 diversity: 0.7321740009755279 distance: 0.9032678747022594


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:49<00:00,  2.03it/s]


valid_ratio: 0.96 uniqueness: 1.0 Quality: 0.29 SA: 3.909284414340761 QED: 0.5145598633832891 diversity: 0.7281164348814452 distance: 0.8851942291535593


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:10<00:00,  9.53it/s]


valid_ratio: 1.0 uniqueness: 0.34 Quality: 0.22 SA: 3.3381867833441383 QED: 0.6315214271902195 diversity: 0.3094609441460307 distance: 0.9028072979327085


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:31<00:00,  3.14it/s]


valid_ratio: 0.98 uniqueness: 0.9897959183673469 Quality: 0.73 SA: 3.087199066975929 QED: 0.7380253427114603 diversity: 0.7324647781745027 distance: 0.9044391216283378


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:33<00:00,  3.03it/s]


valid_ratio: 0.98 uniqueness: 1.0 Quality: 0.65 SA: 3.1311924616387086 QED: 0.6532572868479127 diversity: 0.7417959299918974 distance: 0.8796495182813083


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:34<00:00,  2.91it/s]


valid_ratio: 0.97 uniqueness: 0.865979381443299 Quality: 0.0 SA: 4.866521582973061 QED: 0.42469777810189996 diversity: 0.4459507276827063 distance: 0.8446740500871041


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:17<00:00,  5.81it/s]


valid_ratio: 1.0 uniqueness: 0.49 Quality: 0.36 SA: 2.773157701975621 QED: 0.7017392861488146 diversity: 0.4689384616999663 distance: 0.8667595494526619


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100/100 [00:58<00:00,  1.72it/s]


valid_ratio: 0.95 uniqueness: 1.0 Quality: 0.13 SA: 3.931251396652935 QED: 0.3863054897263462 diversity: 0.8365118745502933 distance: 0.8980199061132468
运行时间: 319.6354 秒
valid_ratio_avg: 0.977, uniqueness_avg: 0.8059087367859556, quality_avg: 0.45062487481502655, sa_avg: 3.6989450899845253, qed_avg: 0.5819587372613896, div_avg: 0.6014205140415544, dist_avg: 0.8840971789283817
