In [1]:
# change working path to the current file
%cd ..

/home/hudongcheng/Desktop/bo_osda_generator


In [2]:
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import functional as F
from tqdm import tqdm
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn

# import custom modules
from models.ddc import SMILESGenerator
from utils.utils import *
from datasets.data_loader import *
from utils.plot_figures import *
from utils.metrics import *
from utils.build_vocab import *

In [3]:
cudnn.benchmark = True
cudnn.enabled = True

train_loss_history = []
train_acc_history = []
test_loss_history = []
test_acc_history = []

now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

PAD = 0
UNK = 1
EOS = 2
SOS = 3
MASK = 4
MAX_LEN = 220

In [4]:
def generate_ddc(model, start_sequence, condition, max_length, vocab,  device, temperature=0.5, top_k=0):
    """
    Autoregressive generation process for a GPT model.

    Args:
        model (ddc): The pre-trained GPT model for token generation.
        start_sequence (torch.Tensor): The initial sequence to start generation (batch_size, seq_length).
        condition_props (torch.Tensor): The conditional property vector (batch_size, num_props).
        max_length (int): The maximum length of the generated sequence.
        vocab: The vocabulary object for encoding and decoding SMILES strings.
        device (torch.device): The device on which to run the generation.
        temperature (float): Temperature parameter for sampling; higher values increase randomness.
        top_k (int): Limits sampling to top-k logits; if 0, no top-k sampling is applied.

    Returns:
        List[str]: A list of generated SMILES strings.
    """

    model.eval()
    batch_size = start_sequence.size(0)
    generated_seq = start_sequence.clone().to(device)
    
    with torch.no_grad():
        # We will iteratively fill positions from [start_len ... seqlen-1]
        for cur_len in range(max_length - start_sequence.size(1)):
            current_len = generated_seq.size(1)
            generated_seq_hot = F.one_hot(generated_seq, num_classes=len(vocab)).float()
            # forward pass to get logits
            logits = model(condition, generated_seq_hot)
            # extract the logits for the next token
            next_token_logits = logits[:, -1, :]
            # temperature scaling
            next_token_logits /= temperature
            # top-k sampling
            if top_k > 0:
                next_token_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
                next_token_probs = F.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(next_token_probs, num_samples=1)
            else:
                # sample from the distribution
                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)
            # append the new token to the sequence
            generated_seq = torch.cat([generated_seq, next_token], dim=-1)
            
            # check if all sequences have reached EOS
            if all(next_token[i].item() == EOS for i in range(batch_size)):
                break
    
    # Decode the generated sequences into SMILES strings
    generated_smiles = []
    for seq in generated_seq:
        # Convert indices to characters, ignoring padding and start tokens
        # check if the generated sequence contains the end token, if meet, stop decoding
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        generated_smiles.append(smiles)

    return generated_smiles


In [5]:
# load the data
AFI_smiles = read_strings('./data_AFI/AFI_smiles.csv', idx=False)
AFI_zeo = read_vec('./data_AFI/AFI_zeo.csv', idx=False)
AFI_syn = read_vec('./data_AFI/AFI_syn.csv', idx=False)
CHA_smiles = read_strings('./data_CHA/CHA_smiles.csv', idx=False)
CHA_zeo = read_vec('./data_CHA/CHA_zeo.csv', idx=False)
CHA_syn = read_vec('./data_CHA/CHA_syn.csv', idx=False)
AEI_smiles = read_strings('./data_AEI/AEI_smiles.csv', idx=False)
AEI_zeo = read_vec('./data_AEI/AEI_zeo.csv', idx=False)
AEI_syn = read_vec('./data_AEI/AEI_syn.csv', idx=False)

vocab = WordVocab.load_vocab('./model_hub/vocab.pkl')
print('the vocab size is :', len(vocab))
charlen = len(vocab)
print('the total num of charset is :', charlen)

cudnn.benchmark = True
batch_size = 64

manual_seed = 42
random.seed(manual_seed)
torch.manual_seed(manual_seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create the dataset and dataloader
AFI_dataset = Seq2seqDataset(AFI_zeo, AFI_syn, AFI_smiles, vocab)
CHA_dataset = Seq2seqDataset(CHA_zeo, CHA_syn, CHA_smiles, vocab)
AEI_dataset = Seq2seqDataset(AEI_zeo, AEI_syn, AEI_smiles, vocab)
AFI_dataloader = DataLoader(AFI_dataset, batch_size=batch_size, shuffle=True)
CHA_dataloader = DataLoader(CHA_dataset, batch_size=batch_size, shuffle=False)
AEI_dataloader = DataLoader(AEI_dataset, batch_size=batch_size, shuffle=False)

the vocab size is : 45
the total num of charset is : 45


In [6]:
# load the model
model_origin = SMILESGenerator(condition_dim=24, lstm_dim=256, dec_layers=3, charset_size=charlen)
model_origin.load_state_dict(torch.load('./checkpoints/best_ddc_model.pth'))
model_origin.to(device)
model_origin.eval()

model_cl = SMILESGenerator(condition_dim=24, lstm_dim=256, dec_layers=3, charset_size=charlen)
model_cl.load_state_dict(torch.load('./checkpoints/best_ddc_contrastive_model.pth'))
model_cl.to(device)
model_cl.eval()

total = sum(p.numel() for p in model_origin.parameters())
print('total parameters: %0.2fM' % (total / 1e6))  # print the total parameters

total parameters: 1.41M


In [7]:
# generate the AFI smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AFI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_ddc(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)
    
# calculate the metrics
AFI_validity_rate_origin = validity_rate(generated_smile_origin)
AFI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AFI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AFI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AFI_IntDiv_origin = IntDiv(generated_smile_origin)
AFI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('AFI_validity_rate_origin: ', AFI_validity_rate_origin)
print('AFI_uniqueness_rate_origin: ', AFI_uniqueness_rate_origin)
print('AFI_novelty_rate_origin: ', AFI_novelty_rate_origin)
print('AFI_reconstructability_rate_origin: ', AFI_reconstructability_rate_origin)
print('AFI_IntDiv_origin: ', AFI_IntDiv_origin)
print('AFI_FCD_score_origin: ', AFI_FCD_score_origin)

100%|██████████| 16/16 [00:22<00:00,  1.39s/it]
[19:38:07] SMILES Parse Error: syntax error while parsing: C([N+](CC(()(C)C)O)CcccccF
[19:38:07] SMILES Parse Error: Failed parsing SMILES 'C([N+](CC(()(C)C)O)CcccccF' for input: 'C([N+](CC(()(C)C)O)CcccccF'
[19:38:07] SMILES Parse Error: extra open parentheses for input: 'C1CN(Cc2ccccc2FCCO1'
[19:38:07] SMILES Parse Error: unclosed ring for input: 'C1CN2CCOCCOCCN(CCOCCOCC2)CCOCCOC'
[19:38:07] SMILES Parse Error: extra close parentheses while parsing: C[N+](C)(C)21C3CC(C4CCCC33CC))1CCCC2
[19:38:07] SMILES Parse Error: Failed parsing SMILES 'C[N+](C)(C)21C3CC(C4CCCC33CC))1CCCC2' for input: 'C[N+](C)(C)21C3CC(C4CCCC33CC))1CCCC2'
[19:38:07] SMILES Parse Error: extra close parentheses while parsing: C1N2CCN(CC))CC
[19:38:07] SMILES Parse Error: Failed parsing SMILES 'C1N2CCN(CC))CC' for input: 'C1N2CCN(CC))CC'
[19:38:07] SMILES Parse Error: syntax error while parsing: c1([C@@H]+](ccccc2)(c)cc1)C11CCCC1
[19:38:07] SMILES Parse Error: Failed pa

AFI_validity_rate_origin:  0.702
AFI_uniqueness_rate_origin:  0.507
AFI_novelty_rate_origin:  0.7495069033530573
AFI_reconstructability_rate_origin:  0.2504930966469428
AFI_IntDiv_origin:  0.7335750937471875
AFI_FCD_score_origin:  1.1722012762897123


In [8]:
# generate the AFI smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AFI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_ddc(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
AFI_validity_rate_cl = validity_rate(generated_smile_cl)
AFI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AFI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AFI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AFI_IntDiv_cl = IntDiv(generated_smile_cl)
AFI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('AFI_validity_rate_cl: ', AFI_validity_rate_cl)
print('AFI_uniqueness_rate_cl: ', AFI_uniqueness_rate_cl)
print('AFI_novelty_rate_cl: ', AFI_novelty_rate_cl)
print('AFI_reconstructability_rate_cl: ', AFI_reconstructability_rate_cl)
print('AFI_IntDiv_cl: ', AFI_IntDiv_cl)
print('AFI_FCD_score_cl: ', AFI_FCD_score_cl)

100%|██████████| 16/16 [00:22<00:00,  1.38s/it]
[19:38:31] SMILES Parse Error: syntax error while parsing: [C@H]1(COn(ccccc1)c)ccc1-(1)C)1
[19:38:31] SMILES Parse Error: Failed parsing SMILES '[C@H]1(COn(ccccc1)c)ccc1-(1)C)1' for input: '[C@H]1(COn(ccccc1)c)ccc1-(1)C)1'
[19:38:31] SMILES Parse Error: syntax error while parsing: C(CC)CCC[[+]12cCc(cccc2ccccc2)cc1
[19:38:31] SMILES Parse Error: Failed parsing SMILES 'C(CC)CCC[[+]12cCc(cccc2ccccc2)cc1' for input: 'C(CC)CCC[[+]12cCc(cccc2ccccc2)cc1'
[19:38:31] SMILES Parse Error: extra close parentheses while parsing: C(CO)N(CCCCCCCCCCCCCC))C
[19:38:31] SMILES Parse Error: Failed parsing SMILES 'C(CO)N(CCCCCCCCCCCCCC))C' for input: 'C(CO)N(CCCCCCCCCCCCCC))C'
[19:38:31] SMILES Parse Error: syntax error while parsing: CCN(CC)C(
[19:38:31] SMILES Parse Error: Failed parsing SMILES 'CCN(CC)C(' for input: 'CCN(CC)C('
[19:38:31] SMILES Parse Error: extra close parentheses while parsing: C1OCCN2CCOCCOCCN(CCOCC)CCC)CC1
[19:38:31] SMILES Parse Error

AFI_validity_rate_cl:  0.707
AFI_uniqueness_rate_cl:  0.506
AFI_novelty_rate_cl:  0.7608695652173914
AFI_reconstructability_rate_cl:  0.2391304347826087
AFI_IntDiv_cl:  0.7334706760075068
AFI_FCD_score_cl:  1.241578763786226


In [9]:
# write the metrics to the folder data_AFI
with open('./data_AFI/AFI_generated_ddc_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'AFI_validity_rate_origin: {AFI_validity_rate_origin}, AFI_validity_rate_cl: {AFI_validity_rate_cl}\n')
    f.write(f'AFI_uniqueness_rate_origin: {AFI_uniqueness_rate_origin}, AFI_uniqueness_rate_cl: {AFI_uniqueness_rate_cl}\n')
    f.write(f'AFI_novelty_rate_origin: {AFI_novelty_rate_origin}, AFI_novelty_rate_cl: {AFI_novelty_rate_cl}\n')
    f.write(f'AFI_reconstructability_rate_origin: {AFI_reconstructability_rate_origin}, AFI_reconstructability_rate_cl: {AFI_reconstructability_rate_cl}\n')
    f.write(f'AFI_IntDiv_origin: {AFI_IntDiv_origin}, AFI_IntDiv_cl: {AFI_IntDiv_cl}\n')
    f.write(f'AFI_FCD_score_origin: {AFI_FCD_score_origin}, AFI_FCD_score_cl: {AFI_FCD_score_cl}\n')

# write the generated smiles (origin and cl) and target smiles to the folder data_AFI
with open('./data_AFI/AFI_generated_ddc_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')

In [10]:
# generate the CHA smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(CHA_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_ddc(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)

# calculate the metrics
CHA_validity_rate_origin = validity_rate(generated_smile_origin)
CHA_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
CHA_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
CHA_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
CHA_IntDiv_origin = IntDiv(generated_smile_origin)
CHA_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('CHA_validity_rate_origin: ', CHA_validity_rate_origin)
print('CHA_uniqueness_rate_origin: ', CHA_uniqueness_rate_origin)
print('CHA_novelty_rate_origin: ', CHA_novelty_rate_origin)
print('CHA_reconstructability_rate_origin: ', CHA_reconstructability_rate_origin)
print('CHA_IntDiv_origin: ', CHA_IntDiv_origin)
print('CHA_FCD_score_origin: ', CHA_FCD_score_origin)

100%|██████████| 16/16 [00:21<00:00,  1.37s/it]
[19:38:55] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 4
[19:38:55] SMILES Parse Error: syntax error while parsing: C1=NC=C[NH]+1
[19:38:55] SMILES Parse Error: Failed parsing SMILES 'C1=NC=C[NH]+1' for input: 'C1=NC=C[NH]+1'
[19:38:55] SMILES Parse Error: extra open parentheses for input: 'C1[N+](C)(C)CC2C(C3C(C)CC2CC(C3)C2'
[19:38:55] SMILES Parse Error: ring closure 1 duplicates bond between atom 10 and atom 11 for input: 'C[N+]1(C)CCCC21CCCC221C1CCC2'
[19:38:55] SMILES Parse Error: extra close parentheses while parsing: C1CC2C(CCC))CC[[+]12CC(C)CCC2
[19:38:55] SMILES Parse Error: Failed parsing SMILES 'C1CC2C(CCC))CC[[+]12CC(C)CCC2' for input: 'C1CC2C(CCC))CC[[+]12CC(C)CCC2'
[19:38:55] SMILES Parse Error: unclosed ring for input: '[N+]1(C)(C)C2CCCCCCC1CCCC'
[19:38:55] SMILES Parse Error: syntax error while parsing: C1C2CC[N+](CC(()CC))CC2CCCC
[19:38:55] SMILES Parse Error: Failed parsing SMILES 'C1C2CC[N+](CC(()CC))CC2CCCC' for in

CHA_validity_rate_origin:  0.714
CHA_uniqueness_rate_origin:  0.576
CHA_novelty_rate_origin:  0.7361111111111112
CHA_reconstructability_rate_origin:  0.2638888888888889
CHA_IntDiv_origin:  0.8904189356647727
CHA_FCD_score_origin:  0.544384674478767


In [11]:
# generate the CHA smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(CHA_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_ddc(model_cl, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
CHA_validity_rate_cl = validity_rate(generated_smile_cl)
CHA_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
CHA_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
CHA_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
CHA_IntDiv_cl = IntDiv(generated_smile_cl)
CHA_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('CHA_validity_rate_cl: ', CHA_validity_rate_cl)
print('CHA_uniqueness_rate_cl: ', CHA_uniqueness_rate_cl)
print('CHA_novelty_rate_cl: ', CHA_novelty_rate_cl)
print('CHA_reconstructability_rate_cl: ', CHA_reconstructability_rate_cl)
print('CHA_IntDiv_cl: ', CHA_IntDiv_cl)
print('CHA_FCD_score_cl: ', CHA_FCD_score_cl)

100%|██████████| 16/16 [00:21<00:00,  1.37s/it]
[19:39:19] Can't kekulize mol.  Unkekulized atoms: 0 3 4 5 6
[19:39:19] SMILES Parse Error: syntax error while parsing: C1=CN=C[NH+]1C=
[19:39:19] SMILES Parse Error: Failed parsing SMILES 'C1=CN=C[NH+]1C=' for input: 'C1=CN=C[NH+]1C='
[19:39:19] SMILES Parse Error: syntax error while parsing: C1N(C)CCC((CCN)C1)N
[19:39:19] SMILES Parse Error: Failed parsing SMILES 'C1N(C)CCC((CCN)C1)N' for input: 'C1N(C)CCC((CCN)C1)N'
[19:39:19] SMILES Parse Error: extra close parentheses while parsing: C(C)N(CC))C
[19:39:19] SMILES Parse Error: Failed parsing SMILES 'C(C)N(CC))C' for input: 'C(C)N(CC))C'
[19:39:19] SMILES Parse Error: syntax error while parsing: C1N(C)CCN(C)CCN(CCC(C)NCC(()11CCCC11CC)C)C
[19:39:19] SMILES Parse Error: Failed parsing SMILES 'C1N(C)CCN(C)CCN(CCC(C)NCC(()11CCCC11CC)C)C' for input: 'C1N(C)CCN(C)CCN(CCC(C)NCC(()11CCCC11CC)C)C'
[19:39:19] SMILES Parse Error: syntax error while parsing: C[N+]1(C)CCCC(C1(()C)C)CC
[19:39:19] SMI

CHA_validity_rate_cl:  0.706
CHA_uniqueness_rate_cl:  0.597
CHA_novelty_rate_cl:  0.7537688442211056
CHA_reconstructability_rate_cl:  0.24623115577889448
CHA_IntDiv_cl:  0.891493239863967
CHA_FCD_score_cl:  0.4983394725029946


In [12]:
# write the metrics to the folder data_CHA
with open('./data_CHA/CHA_generated_ddc_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'CHA_validity_rate_origin: {CHA_validity_rate_origin}, CHA_validity_rate_cl: {CHA_validity_rate_cl}\n')
    f.write(f'CHA_uniqueness_rate_origin: {CHA_uniqueness_rate_origin}, CHA_uniqueness_rate_cl: {CHA_uniqueness_rate_cl}\n')
    f.write(f'CHA_novelty_rate_origin: {CHA_novelty_rate_origin}, CHA_novelty_rate_cl: {CHA_novelty_rate_cl}\n')
    f.write(f'CHA_reconstructability_rate_origin: {CHA_reconstructability_rate_origin}, CHA_reconstructability_rate_cl: {CHA_reconstructability_rate_cl}\n')
    f.write(f'CHA_IntDiv_origin: {CHA_IntDiv_origin}, CHA_IntDiv_cl: {CHA_IntDiv_cl}\n')
    f.write(f'CHA_FCD_score_origin: {CHA_FCD_score_origin}, CHA_FCD_score_cl: {CHA_FCD_score_cl}\n')
# write the generated smiles (origin and cl) and target smiles to the folder data_CHA
with open('./data_CHA/CHA_generated_ddc_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')

In [13]:
# generate the AEI smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AEI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_ddc(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)

# calculate the metrics
AEI_validity_rate_origin = validity_rate(generated_smile_origin)
AEI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AEI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AEI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AEI_IntDiv_origin = IntDiv(generated_smile_origin)
AEI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('AEI_validity_rate_origin: ', AEI_validity_rate_origin)
print('AEI_uniqueness_rate_origin: ', AEI_uniqueness_rate_origin)
print('AEI_novelty_rate_origin: ', AEI_novelty_rate_origin)
print('AEI_reconstructability_rate_origin: ', AEI_reconstructability_rate_origin)
print('AEI_IntDiv_origin: ', AEI_IntDiv_origin)
print('AEI_FCD_score_origin: ', AEI_FCD_score_origin)

100%|██████████| 16/16 [00:22<00:00,  1.38s/it]
[19:39:43] SMILES Parse Error: unclosed ring for input: 'N1(C)CCN(C)CCN(C)CC1(C)CCN(C)C2CCC1C'
[19:39:43] SMILES Parse Error: syntax error while parsing: C[C@@H]1[(+]1Cc3ccccc2)CCC(()(C)C)c1
[19:39:43] SMILES Parse Error: Failed parsing SMILES 'C[C@@H]1[(+]1Cc3ccccc2)CCC(()(C)C)c1' for input: 'C[C@@H]1[(+]1Cc3ccccc2)CCC(()(C)C)c1'
[19:39:43] SMILES Parse Error: syntax error while parsing: C1CC[C@@HC+]2c(c)ccc1)CC
[19:39:43] SMILES Parse Error: Failed parsing SMILES 'C1CC[C@@HC+]2c(c)ccc1)CC' for input: 'C1CC[C@@HC+]2c(c)ccc1)CC'
[19:39:43] SMILES Parse Error: unclosed ring for input: '[N+]1(C)(C)CCCCCCC2'
[19:39:43] Explicit valence for atom # 12 F, 2, is greater than permitted
[19:39:43] SMILES Parse Error: syntax error while parsing: C(C)[P+](CCC(())C)CCCCC[N+](C)(C)C
[19:39:43] SMILES Parse Error: Failed parsing SMILES 'C(C)[P+](CCC(())C)CCCCC[N+](C)(C)C' for input: 'C(C)[P+](CCC(())C)CCCCC[N+](C)(C)C'
[19:39:43] SMILES Parse Error: du

AEI_validity_rate_origin:  0.634
AEI_uniqueness_rate_origin:  0.495
AEI_novelty_rate_origin:  0.8666666666666667
AEI_reconstructability_rate_origin:  0.13333333333333333
AEI_IntDiv_origin:  0.7881203151305758
AEI_FCD_score_origin:  1.6035418763118123


In [14]:
# generate the AEI smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AEI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_ddc(model_cl, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
AEI_validity_rate_cl = validity_rate(generated_smile_cl)
AEI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AEI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AEI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AEI_IntDiv_cl = IntDiv(generated_smile_cl)
AEI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('AEI_validity_rate_cl: ', AEI_validity_rate_cl)
print('AEI_uniqueness_rate_cl: ', AEI_uniqueness_rate_cl)
print('AEI_novelty_rate_cl: ', AEI_novelty_rate_cl)
print('AEI_reconstructability_rate_cl: ', AEI_reconstructability_rate_cl)
print('AEI_IntDiv_cl: ', AEI_IntDiv_cl)
print('AEI_FCD_score_cl: ', AEI_FCD_score_cl)

100%|██████████| 16/16 [00:22<00:00,  1.38s/it]
[19:40:07] SMILES Parse Error: syntax error while parsing: C1CN(C)CCN(
[19:40:07] SMILES Parse Error: Failed parsing SMILES 'C1CN(C)CCN(' for input: 'C1CN(C)CCN('
[19:40:07] SMILES Parse Error: syntax error while parsing: C[C@@H]1[(+](C)(C)C)(
[19:40:07] SMILES Parse Error: Failed parsing SMILES 'C[C@@H]1[(+](C)(C)C)(' for input: 'C[C@@H]1[(+](C)(C)C)('
[19:40:07] SMILES Parse Error: unclosed ring for input: 'C[N+]1(C)CCCC2CCCCC211'
[19:40:07] SMILES Parse Error: unclosed ring for input: 'C[C@@H]1C(C)(C)C'
[19:40:07] SMILES Parse Error: unclosed ring for input: 'C1[C@H](C2(CCCC2)C2)C'
[19:40:07] SMILES Parse Error: syntax error while parsing: C1[C@@H](]2CCCC22CCCCC2)cCc1cc<unk>
[19:40:07] SMILES Parse Error: Failed parsing SMILES 'C1[C@@H](]2CCCC22CCCCC2)cCc1cc<unk>' for input: 'C1[C@@H](]2CCCC22CCCCC2)cCc1cc<unk>'
[19:40:07] Explicit valence for atom # 3 N, 4, is greater than permitted
[19:40:07] SMILES Parse Error: syntax error while pa

AEI_validity_rate_cl:  0.565
AEI_uniqueness_rate_cl:  0.551
AEI_novelty_rate_cl:  0.9292196007259528
AEI_reconstructability_rate_cl:  0.07078039927404718
AEI_IntDiv_cl:  0.7941249302914077
AEI_FCD_score_cl:  1.3362764722287217


In [15]:
# write the metrics to the folder data_AEI
with open('./data_AEI/AEI_generated_ddc_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'AEI_validity_rate_origin: {AEI_validity_rate_origin}, AEI_validity_rate_cl: {AEI_validity_rate_cl}\n')
    f.write(f'AEI_uniqueness_rate_origin: {AEI_uniqueness_rate_origin}, AEI_uniqueness_rate_cl: {AEI_uniqueness_rate_cl}\n')
    f.write(f'AEI_novelty_rate_origin: {AEI_novelty_rate_origin}, AEI_novelty_rate_cl: {AEI_novelty_rate_cl}\n')
    f.write(f'AEI_reconstructability_rate_origin: {AEI_reconstructability_rate_origin}, AEI_reconstructability_rate_cl: {AEI_reconstructability_rate_cl}\n')
    f.write(f'AEI_IntDiv_origin: {AEI_IntDiv_origin}, AEI_IntDiv_cl: {AEI_IntDiv_cl}\n')
    f.write(f'AEI_FCD_score_origin: {AEI_FCD_score_origin}, AEI_FCD_score_cl: {AEI_FCD_score_cl}\n')
# write the generated smiles (origin and cl) and target smiles to the folder data_AEI
with open('./data_AEI/AEI_generated_ddc_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')

: 