In [1]:
%cd ..

/home/hudongcheng/Desktop/bo_osda_generator


In [2]:
import pandas as pd
from rdkit import Chem
import numpy as np
import random
import pickle
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from tqdm import tqdm
from datetime import datetime

# import the necessary modules
from datasets.data_loader import *
from utils.utils import *
from models.RNN import *
from utils.metrics import *
from utils.plot_figures import *
from utils.build_vocab import *

In [3]:
cudnn.benchmark = True
cudnn.enabled = True

train_loss_history = []
train_acc_history = []
test_loss_history = []
test_acc_history = []

now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

PAD = 0
UNK = 1
EOS = 2
SOS = 3
MASK = 4
MAX_LEN = 220

In [4]:
def generate_rnn(model, start_sequence, condition_props, max_length, vocab, device, temperature=1.0, top_k=0):
    """
    Autoregressive generation process for a GPT model.

    Args:
        model (RNN): The pre-trained GPT model for token generation.
        start_sequence (torch.Tensor): The initial sequence to start generation (batch_size, seq_length).
        condition_props (torch.Tensor): The conditional property vector (batch_size, num_props).
        max_length (int): The maximum length of the generated sequence.
        vocab: The vocabulary object for encoding and decoding SMILES strings.
        device (torch.device): The device on which to run the generation.
        temperature (float): Temperature parameter for sampling; higher values increase randomness.
        top_k (int): Limits sampling to top-k logits; if 0, no top-k sampling is applied.

    Returns:
        List[str]: A list of generated SMILES strings.
    """
    model.eval()
    batch_size = start_sequence.size(0)
    generated_sequences = start_sequence.clone().to(device)  # Clone and move to device

    for _ in range(max_length - start_sequence.size(1)):
        # Get the current sequence length
        current_length = generated_sequences.size(1)

        # Forward pass through the model
        logits, _ = model(condition_props, generated_sequences)  # (batch_size, seq_length, vocab_size)

        # Extract the logits for the last time step
        next_token_logits = logits[:, -1, :]  # (batch_size, vocab_size)

        # Apply temperature scaling
        next_token_logits = next_token_logits / temperature

        # Apply top-k filtering
        if top_k > 0:
            top_k_logits, top_k_indices = torch.topk(next_token_logits, top_k, dim=-1)
            mask = torch.full_like(next_token_logits, float('-inf'))
            mask.scatter_(dim=-1, index=top_k_indices, src=top_k_logits)
            next_token_logits = mask

        # Convert logits to probabilities
        next_token_probs = F.softmax(next_token_logits, dim=-1)

        # Sample from the probability distribution
        next_token = torch.multinomial(next_token_probs, num_samples=1)  # (batch_size, 1)
        
        # Get the most likely next token
        # next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

        # Append the generated token to the sequence
        generated_sequences = torch.cat([generated_sequences, next_token], dim=1)

        # Check if all sequences have reached the end token
        if all(next_token[i].item() == EOS for i in range(batch_size)):
            break

    # Decode the generated sequences into SMILES strings
    generated_smiles = []
    for seq in generated_sequences:
        # Convert indices to characters, ignoring padding and start tokens
        # check if the generated sequence contains the end token, if meet, stop decoding
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        generated_smiles.append(smiles)

    return generated_smiles

In [5]:
# load the data
AFI_smiles = read_strings('./data_AFI/AFI_smiles.csv', idx=False)
AFI_zeo = read_vec('./data_AFI/AFI_zeo.csv', idx=False)
AFI_syn = read_vec('./data_AFI/AFI_syn.csv', idx=False)
CHA_smiles = read_strings('./data_CHA/CHA_smiles.csv', idx=False)
CHA_zeo = read_vec('./data_CHA/CHA_zeo.csv', idx=False)
CHA_syn = read_vec('./data_CHA/CHA_syn.csv', idx=False)
AEI_smiles = read_strings('./data_AEI/AEI_smiles.csv', idx=False)
AEI_zeo = read_vec('./data_AEI/AEI_zeo.csv', idx=False)
AEI_syn = read_vec('./data_AEI/AEI_syn.csv', idx=False)

vocab = WordVocab.load_vocab('./model_hub/vocab.pkl')
print('the vocab size is :', len(vocab))
charlen = len(vocab)
print('the total num of charset is :', charlen)

cudnn.benchmark = True
batch_size = 64

manual_seed = 42
random.seed(manual_seed)
torch.manual_seed(manual_seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create the dataset and dataloader
AFI_dataset = Seq2seqDataset(AFI_zeo, AFI_syn, AFI_smiles, vocab)
CHA_dataset = Seq2seqDataset(CHA_zeo, CHA_syn, CHA_smiles, vocab)
AEI_dataset = Seq2seqDataset(AEI_zeo, AEI_syn, AEI_smiles, vocab)
AFI_dataloader = DataLoader(AFI_dataset, batch_size=batch_size, shuffle=True)
CHA_dataloader = DataLoader(CHA_dataset, batch_size=batch_size, shuffle=False)
AEI_dataloader = DataLoader(AEI_dataset, batch_size=batch_size, shuffle=False)

the vocab size is : 45
the total num of charset is : 45


In [6]:
# load the contrastive learning model and original model from checkpoints
model_origin = RNNModel(input_size=charlen, 
                 synthesis_dim=24, 
                 embedding_dim=128,
                 hidden_size=256,
                 num_layers=3,
                 dropout=0,
                 vocab_size=charlen)
model_origin.load_state_dict(torch.load('./checkpoints/best_charnn_model.pth'))
model_origin.eval()
model_origin.to(device)

model_cl = RNNModel(input_size=charlen,
                    synthesis_dim=24,
                    embedding_dim=128,
                    hidden_size=256,
                    num_layers=3,
                    dropout=0,
                    vocab_size=charlen)
model_cl.load_state_dict(torch.load('./checkpoints/best_charnn_contrastive_model.pth'))
model_cl.eval()
model_cl.to(device)

total = sum(p.numel() for p in model_origin.parameters())
print('total parameters: %0.2fM' % (total / 1e6))  # print the total parameters

total parameters: 0.39M


In [7]:
# generate the AFI smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AFI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_rnn(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)
    
# calculate the metrics
AFI_validity_rate_origin = validity_rate(generated_smile_origin)
AFI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AFI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AFI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AFI_IntDiv_origin = IntDiv(generated_smile_origin)
AFI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('AFI_validity_rate_origin: ', AFI_validity_rate_origin)
print('AFI_uniqueness_rate_origin: ', AFI_uniqueness_rate_origin)
print('AFI_novelty_rate_origin: ', AFI_novelty_rate_origin)
print('AFI_reconstructability_rate_origin: ', AFI_reconstructability_rate_origin)
print('AFI_IntDiv_origin: ', AFI_IntDiv_origin)
print('AFI_FCD_score_origin: ', AFI_FCD_score_origin)

100%|██████████| 16/16 [00:11<00:00,  1.36it/s]
[19:18:42] SMILES Parse Error: extra open parentheses for input: 'c1c[n+](CCCCCC[n+]2cccccc1'
[19:18:42] SMILES Parse Error: unclosed ring for input: 'C([N+]1(C)CCCCCCCCCCCCC[N+]1(C)C)CCCCCCCCC1'
[19:18:42] Can't kekulize mol.  Unkekulized atoms: 0 13 14 15 20
[19:18:42] SMILES Parse Error: unclosed ring for input: 'CCN(CC)CCCC1'
[19:18:42] SMILES Parse Error: unclosed ring for input: 'CCN(CC)CCCC1'
[19:18:42] SMILES Parse Error: unclosed ring for input: 'c1(C)[n+](CCCCCCCCC[N+]12CCCCCCCCCC[n+]1c(C)n(C)cc1)C'
[19:18:42] SMILES Parse Error: unclosed ring for input: 'C(C[N+](C)(C)C)CCCC(C)CC1'
[19:18:42] SMILES Parse Error: unclosed ring for input: 'C1C[C@@H]3(C1CCCCCCCCN(C)C)CCC'
[19:18:42] SMILES Parse Error: unclosed ring for input: 'CCN(CC)CCCCCC1'
[19:18:42] Can't kekulize mol.  Unkekulized atoms: 0 1 19 20 21 22 23
[19:18:42] SMILES Parse Error: unclosed ring for input: 'CC(CNCC(C)CC)c2ccccc1'
[19:18:42] SMILES Parse Error: unclosed r

AFI_validity_rate_origin:  0.242
AFI_uniqueness_rate_origin:  0.872
AFI_novelty_rate_origin:  0.9827981651376146
AFI_reconstructability_rate_origin:  0.017201834862385322
AFI_IntDiv_origin:  0.8509377086947255
AFI_FCD_score_origin:  10.923637546885224


In [8]:
# generate the AFI smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AFI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_rnn(model_cl, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
AFI_validity_rate_cl = validity_rate(generated_smile_cl)
AFI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AFI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AFI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AFI_IntDiv_cl = IntDiv(generated_smile_cl)
AFI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('AFI_validity_rate_cl: ', AFI_validity_rate_cl)
print('AFI_uniqueness_rate_cl: ', AFI_uniqueness_rate_cl)
print('AFI_novelty_rate_cl: ', AFI_novelty_rate_cl)
print('AFI_reconstructability_rate_cl: ', AFI_reconstructability_rate_cl)
print('AFI_IntDiv_cl: ', AFI_IntDiv_cl)
print('AFI_FCD_score_cl: ', AFI_FCD_score_cl)

100%|██████████| 16/16 [00:11<00:00,  1.37it/s]
[19:18:56] Can't kekulize mol.  Unkekulized atoms: 0 1 3 4 5 12 13
[19:18:56] SMILES Parse Error: extra open parentheses for input: 'C(O)[C@H](=C(O)CCCCCN(C)C(C)C'
[19:18:56] SMILES Parse Error: extra close parentheses while parsing: N([C@@H](C)C)(C)C)(C)C
[19:18:56] SMILES Parse Error: Failed parsing SMILES 'N([C@@H](C)C)(C)C)(C)C' for input: 'N([C@@H](C)C)(C)C)(C)C'
[19:18:56] SMILES Parse Error: unclosed ring for input: 'c1cc(CN2[NH3+]CCC(CC3)C)Cc2cccc1'
[19:18:56] SMILES Parse Error: syntax error while parsing: C1CN2[C@HN+](C)(C1)CC1CCCCN1C2
[19:18:56] SMILES Parse Error: Failed parsing SMILES 'C1CN2[C@HN+](C)(C1)CC1CCCCN1C2' for input: 'C1CN2[C@HN+](C)(C1)CC1CCCCN1C2'
[19:18:56] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 12 13 14
[19:18:56] Can't kekulize mol.  Unkekulized atoms: 0 12 13 14 15
[19:18:56] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 12 13 14
[19:18:56] SMILES Parse Error: syntax error while parsing: C1N2[C@H]+

AFI_validity_rate_cl:  0.882
AFI_uniqueness_rate_cl:  0.456
AFI_novelty_rate_cl:  0.7302631578947368
AFI_reconstructability_rate_cl:  0.26973684210526316
AFI_IntDiv_cl:  0.795176816489249
AFI_FCD_score_cl:  0.8061547880749487


In [9]:
# write the metrics to the folder data_AFI
with open('./data_AFI/AFI_generated_charnn_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'AFI_validity_rate_origin: {AFI_validity_rate_origin}, AFI_validity_rate_cl: {AFI_validity_rate_cl}\n')
    f.write(f'AFI_uniqueness_rate_origin: {AFI_uniqueness_rate_origin}, AFI_uniqueness_rate_cl: {AFI_uniqueness_rate_cl}\n')
    f.write(f'AFI_novelty_rate_origin: {AFI_novelty_rate_origin}, AFI_novelty_rate_cl: {AFI_novelty_rate_cl}\n')
    f.write(f'AFI_reconstructability_rate_origin: {AFI_reconstructability_rate_origin}, AFI_reconstructability_rate_cl: {AFI_reconstructability_rate_cl}\n')
    f.write(f'AFI_IntDiv_origin: {AFI_IntDiv_origin}, AFI_IntDiv_cl: {AFI_IntDiv_cl}\n')
    f.write(f'AFI_FCD_score_origin: {AFI_FCD_score_origin}, AFI_FCD_score_cl: {AFI_FCD_score_cl}\n')

# write the generated smiles (origin and cl) and target smiles to the folder data_AFI
with open('./data_AFI/AFI_generated_charnn_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')

In [10]:
# generate the CHA smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(CHA_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_rnn(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)

# calculate the metrics
CHA_validity_rate_origin = validity_rate(generated_smile_origin)
CHA_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
CHA_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
CHA_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
CHA_IntDiv_origin = IntDiv(generated_smile_origin)
CHA_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('CHA_validity_rate_origin: ', CHA_validity_rate_origin)
print('CHA_uniqueness_rate_origin: ', CHA_uniqueness_rate_origin)
print('CHA_novelty_rate_origin: ', CHA_novelty_rate_origin)
print('CHA_reconstructability_rate_origin: ', CHA_reconstructability_rate_origin)
print('CHA_IntDiv_origin: ', CHA_IntDiv_origin)
print('CHA_FCD_score_origin: ', CHA_FCD_score_origin)

100%|██████████| 16/16 [00:11<00:00,  1.37it/s]
[19:19:10] SMILES Parse Error: unclosed ring for input: 'C1CCCN(C2)CCC1'
[19:19:10] SMILES Parse Error: unclosed ring for input: 'C(CC)NCCCCCCC1'
[19:19:10] SMILES Parse Error: unclosed ring for input: 'C(C)CNCCCCCCC1'
[19:19:10] SMILES Parse Error: unclosed ring for input: '[NH+]1(C)CCCC3CC1C2'
[19:19:10] SMILES Parse Error: unclosed ring for input: 'CN1CCCCC1CCN(C)CCN(CC3)C1'
[19:19:10] SMILES Parse Error: unclosed ring for input: 'C[N+](C)(C)CCCCC1'
[19:19:10] SMILES Parse Error: unclosed ring for input: 'C[N+](C12CCCCCCCCCC[N+](C)(C)C)(C)C1'
[19:19:10] SMILES Parse Error: unclosed ring for input: 'C1C2([N+]2(C)CCCC3)CCCN1CCCC1'
[19:19:10] SMILES Parse Error: extra close parentheses while parsing: C1(N)CCCC2)CC
[19:19:10] SMILES Parse Error: Failed parsing SMILES 'C1(N)CCCC2)CC' for input: 'C1(N)CCCC2)CC'
[19:19:10] SMILES Parse Error: unclosed ring for input: 'C1CCCCC1N(C)CCN(C)CC2'
[19:19:10] SMILES Parse Error: unclosed ring for inp

CHA_validity_rate_origin:  0.358
CHA_uniqueness_rate_origin:  0.819
CHA_novelty_rate_origin:  0.9438339438339438
CHA_reconstructability_rate_origin:  0.05616605616605617
CHA_IntDiv_origin:  0.8891243641726648
CHA_FCD_score_origin:  4.23145289635346


In [11]:
# generate the CHA smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(CHA_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_rnn(model_cl, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
CHA_validity_rate_cl = validity_rate(generated_smile_cl)
CHA_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
CHA_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
CHA_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
CHA_IntDiv_cl = IntDiv(generated_smile_cl)
CHA_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('CHA_validity_rate_cl: ', CHA_validity_rate_cl)
print('CHA_uniqueness_rate_cl: ', CHA_uniqueness_rate_cl)
print('CHA_novelty_rate_cl: ', CHA_novelty_rate_cl)
print('CHA_reconstructability_rate_cl: ', CHA_reconstructability_rate_cl)
print('CHA_IntDiv_cl: ', CHA_IntDiv_cl)
print('CHA_FCD_score_cl: ', CHA_FCD_score_cl)

100%|██████████| 16/16 [00:11<00:00,  1.38it/s]
[19:19:23] Explicit valence for atom # 3 N, 4, is greater than permitted
[19:19:23] Can't kekulize mol.  Unkekulized atoms: 0 3 4
[19:19:23] SMILES Parse Error: extra close parentheses while parsing: C(C([O-])O)[P+](N(C)C)(N(C)C)CCCC[N+](CC)(CC)CC)C
[19:19:23] SMILES Parse Error: Failed parsing SMILES 'C(C([O-])O)[P+](N(C)C)(N(C)C)CCCC[N+](CC)(CC)CC)C' for input: 'C(C([O-])O)[P+](N(C)C)(N(C)C)CCCC[N+](CC)(CC)CC)C'
[19:19:23] Explicit valence for atom # 3 O, 3, is greater than permitted
[19:19:23] SMILES Parse Error: unclosed ring for input: 'C1[N@@+](CCCCC1)(C)CC2'
[19:19:23] SMILES Parse Error: syntax error while parsing: C1C[C@@H]2]=C[N+](CC)(CC)CCC2CC1
[19:19:23] SMILES Parse Error: Failed parsing SMILES 'C1C[C@@H]2]=C[N+](CC)(CC)CCC2CC1' for input: 'C1C[C@@H]2]=C[N+](CC)(CC)CCC2CC1'
[19:19:23] Explicit valence for atom # 1 N, 4, is greater than permitted
[19:19:23] SMILES Parse Error: unclosed ring for input: 'C12C3C[N+](C)(C)CC3CC(C2

CHA_validity_rate_cl:  0.969
CHA_uniqueness_rate_cl:  0.529
CHA_novelty_rate_cl:  0.720226843100189
CHA_reconstructability_rate_cl:  0.27977315689981097
CHA_IntDiv_cl:  0.8936066782485486
CHA_FCD_score_cl:  0.39907973016443776


In [12]:
# write the metrics to the folder data_CHA
with open('./data_CHA/CHA_generated_charnn_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'CHA_validity_rate_origin: {CHA_validity_rate_origin}, CHA_validity_rate_cl: {CHA_validity_rate_cl}\n')
    f.write(f'CHA_uniqueness_rate_origin: {CHA_uniqueness_rate_origin}, CHA_uniqueness_rate_cl: {CHA_uniqueness_rate_cl}\n')
    f.write(f'CHA_novelty_rate_origin: {CHA_novelty_rate_origin}, CHA_novelty_rate_cl: {CHA_novelty_rate_cl}\n')
    f.write(f'CHA_reconstructability_rate_origin: {CHA_reconstructability_rate_origin}, CHA_reconstructability_rate_cl: {CHA_reconstructability_rate_cl}\n')
    f.write(f'CHA_IntDiv_origin: {CHA_IntDiv_origin}, CHA_IntDiv_cl: {CHA_IntDiv_cl}\n')
    f.write(f'CHA_FCD_score_origin: {CHA_FCD_score_origin}, CHA_FCD_score_cl: {CHA_FCD_score_cl}\n')
# write the generated smiles (origin and cl) and target smiles to the folder data_CHA
with open('./data_CHA/CHA_generated_charnn_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')

In [13]:
# generate the AEI smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AEI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_rnn(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)
    
# calculate the metrics
AEI_validity_rate_origin = validity_rate(generated_smile_origin)
AEI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AEI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AEI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AEI_IntDiv_origin = IntDiv(generated_smile_origin)
AEI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('AEI_validity_rate_origin: ', AEI_validity_rate_origin)
print('AEI_uniqueness_rate_origin: ', AEI_uniqueness_rate_origin)
print('AEI_novelty_rate_origin: ', AEI_novelty_rate_origin)
print('AEI_reconstructability_rate_origin: ', AEI_reconstructability_rate_origin)
print('AEI_IntDiv_origin: ', AEI_IntDiv_origin)
print('AEI_FCD_score_origin: ', AEI_FCD_score_origin)

100%|██████████| 16/16 [00:11<00:00,  1.35it/s]
[19:19:37] SMILES Parse Error: unclosed ring for input: 'C1N(C)CCN(C)CC2'
[19:19:37] SMILES Parse Error: unclosed ring for input: 'C1[C@H](C)(C)CC2CC2(C)CC1CC2C'
[19:19:37] SMILES Parse Error: unclosed ring for input: '[N+]1(C)(CCCCCCCC3)CC[N+]1(C)CCCC2C2CCCCC1'
[19:19:37] Explicit valence for atom # 2 C, 5, is greater than permitted
[19:19:37] SMILES Parse Error: unclosed ring for input: 'C[C@H]1C[N+]1(CCCCCCCCCCCCCCC[N+](CC)(C)C)CCCCC2'
[19:19:37] SMILES Parse Error: unclosed ring for input: 'C1[C@@H](C)(CCCCC(C)C)CC2CCCCC1CC[N+](C)(C)C'
[19:19:37] SMILES Parse Error: unclosed ring for input: 'CC[P+](CCCCCCC2)cccc1'
[19:19:37] SMILES Parse Error: unclosed ring for input: 'C(C)[N+](C)(C)CC2'
[19:19:37] SMILES Parse Error: unclosed ring for input: 'C1C(C)C[N+]3(CCCC2)CCCC1'
[19:19:37] SMILES Parse Error: unclosed ring for input: 'CC1CC(C)CCCC[N+](C)(C)CCC(C1)CC2'
[19:19:37] SMILES Parse Error: unclosed ring for input: 'C([P+](CC)(C)C)CCCC

AEI_validity_rate_origin:  0.272
AEI_uniqueness_rate_origin:  0.936
AEI_novelty_rate_origin:  0.9978632478632479
AEI_reconstructability_rate_origin:  0.002136752136752137
AEI_IntDiv_origin:  0.8328920949606096
AEI_FCD_score_origin:  5.770402010250862


In [14]:
# generate the AEI smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AEI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_rnn(model_cl, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
AEI_validity_rate_cl = validity_rate(generated_smile_cl)
AEI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AEI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AEI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AEI_IntDiv_cl = IntDiv(generated_smile_cl)
AEI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('AEI_validity_rate_cl: ', AEI_validity_rate_cl)
print('AEI_uniqueness_rate_cl: ', AEI_uniqueness_rate_cl)
print('AEI_novelty_rate_cl: ', AEI_novelty_rate_cl)
print('AEI_reconstructability_rate_cl: ', AEI_reconstructability_rate_cl)
print('AEI_IntDiv_cl: ', AEI_IntDiv_cl)
print('AEI_FCD_score_cl: ', AEI_FCD_score_cl)

100%|██████████| 16/16 [00:11<00:00,  1.36it/s]
[19:19:51] SMILES Parse Error: extra close parentheses while parsing: C1CN(C)CCN(C)CCN(CCN(CCN(CCN1C)C)C)C)C
[19:19:51] SMILES Parse Error: Failed parsing SMILES 'C1CN(C)CCN(C)CCN(CCN(CCN(CCN1C)C)C)C)C' for input: 'C1CN(C)CCN(C)CCN(CCN(CCN(CCN1C)C)C)C)C'
[19:19:51] SMILES Parse Error: extra close parentheses while parsing: C([P+](CCCCCCCCCCCCCCCCCCCCCC)(C)C)C)CCCCC
[19:19:51] SMILES Parse Error: Failed parsing SMILES 'C([P+](CCCCCCCCCCCCCCCCCCCCCC)(C)C)C)CCCCC' for input: 'C([P+](CCCCCCCCCCCCCCCCCCCCCC)(C)C)C)CCCCC'
[19:19:51] SMILES Parse Error: unclosed ring for input: 'C1[N+](C)(C)CC2C3C4C(C(C2C1)C=C3)C[N+](C(C)C)(C2)CC'
[19:19:51] SMILES Parse Error: unclosed ring for input: 'C1[NH2+]C(C)(C)CCCC1C2'
[19:19:51] SMILES Parse Error: unclosed ring for input: 'C1[C@@H](-O)C'
[19:19:51] SMILES Parse Error: syntax error while parsing: [NH2+]1[CCC(CC1)CCCC1CC[N+](C)(CC1)C
[19:19:51] SMILES Parse Error: Failed parsing SMILES '[NH2+]1[CCC(CC1)C

AEI_validity_rate_cl:  0.944
AEI_uniqueness_rate_cl:  0.519
AEI_novelty_rate_cl:  0.8978805394990366
AEI_reconstructability_rate_cl:  0.10211946050096339
AEI_IntDiv_cl:  0.8151251602826046
AEI_FCD_score_cl:  1.2114492080787613


In [15]:
# write the metrics to the folder data_AEI
with open('./data_AEI/AEI_generated_charnn_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'AEI_validity_rate_origin: {AEI_validity_rate_origin}, AEI_validity_rate_cl: {AEI_validity_rate_cl}\n')
    f.write(f'AEI_uniqueness_rate_origin: {AEI_uniqueness_rate_origin}, AEI_uniqueness_rate_cl: {AEI_uniqueness_rate_cl}\n')
    f.write(f'AEI_novelty_rate_origin: {AEI_novelty_rate_origin}, AEI_novelty_rate_cl: {AEI_novelty_rate_cl}\n')
    f.write(f'AEI_reconstructability_rate_origin: {AEI_reconstructability_rate_origin}, AEI_reconstructability_rate_cl: {AEI_reconstructability_rate_cl}\n')
    f.write(f'AEI_IntDiv_origin: {AEI_IntDiv_origin}, AEI_IntDiv_cl: {AEI_IntDiv_cl}\n')
    f.write(f'AEI_FCD_score_origin: {AEI_FCD_score_origin}, AEI_FCD_score_cl: {AEI_FCD_score_cl}\n')

# write the generated smiles (origin and cl) and target smiles to the folder data_AEI
with open('./data_AEI/AEI_generated_charnn_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')