In [1]:
# Move the working directory one level up, to the parent of the folder this notebook
# is started from (the project root, per the path echoed in this cell's output).
# The relative paths used by later cells (./data_AFI, ./checkpoints, ./model_hub)
# are resolved against that directory.
# NOTE: `..` is relative, so re-running this cell climbs one more level each time;
# restart the kernel before re-executing it.
%cd ..

/home/hudongcheng/Desktop/bo_osda_generator


In [2]:
import os
import random
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# import custom modules
from datasets.data_loader import *
from models.GPT import *
from models.loss import InfoNCELoss
from models.trfm import *
from utils.utils import *
from utils.plot_figures import *
from utils.metrics import *
from utils.build_vocab import *

In [3]:
# cuDNN switches: turn on the autotuner and make sure cuDNN itself is enabled.
cudnn.benchmark = True
cudnn.enabled = True

# Per-epoch history buffers (not referenced by the cells visible in this notebook).
train_loss_history, train_acc_history = [], []
test_loss_history, test_acc_history = [], []

# Timestamp of this run, formatted as YYYY-mm-dd-HH-MM-SS.
now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

# Special-token indices used when decoding sequences below: EOS terminates a sequence,
# PAD and SOS are skipped. Presumably these must match the indices stored in the
# loaded vocabulary -- TODO confirm against WordVocab.
PAD, UNK, EOS, SOS, MASK = range(5)

# Maximum sequence length passed to generate_gpt.
MAX_LEN = 220

In [4]:
def _decode_token_ids(token_ids, vocab):
    """Convert one sequence of token indices (plain ints) into a SMILES string.

    Decoding stops at the first EOS token; PAD and SOS tokens are dropped.
    """
    chars = []
    for token in token_ids:
        if token == EOS:
            break
        if token != PAD and token != SOS:
            chars.append(vocab.itos[token])
    return ''.join(chars)


def generate_gpt(model, start_sequence, condition_props, max_length, vocab, device, temperature=1.0, top_k=0):
    """
    Autoregressive sampling from a conditional GPT model.

    Tokens are drawn one position at a time from the temperature-scaled (and optionally
    top-k filtered) distribution over the next token. A row is considered finished once
    it contains EOS; finished rows are padded with PAD from then on, and generation
    stops as soon as every row is finished or `max_length` is reached.

    Args:
        model (GPT): The pre-trained GPT model; called as ``model(tokens, condition_props)``
            and expected to return logits of shape (batch_size, seq_length, vocab_size).
            It is switched to eval mode and left in that mode.
        start_sequence (torch.Tensor): The initial token indices (batch_size, seq_length).
        condition_props (torch.Tensor): The conditional property vector (batch_size, num_props).
            Moved to `device` if it is not None.
        max_length (int): The maximum total length (prompt included) of a generated sequence.
            Presumably this must not exceed the model's block size -- confirm against the config.
        vocab: The vocabulary object; only ``vocab.itos`` (index -> token string) is used.
        device (torch.device): The device on which to run the generation.
        temperature (float): Strictly positive sampling temperature; higher values increase randomness.
        top_k (int): Limits sampling to the top-k logits; if 0, no top-k filtering is applied.
            Values larger than the vocabulary size are clamped to it.

    Returns:
        List[str]: One decoded SMILES string per row of `start_sequence`.

    Raises:
        ValueError: If `temperature` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f'temperature must be > 0, got {temperature}')

    model.eval()
    generated_sequences = start_sequence.clone().to(device)  # Clone and move to device
    if condition_props is not None:
        condition_props = condition_props.to(device)

    # Rows whose prompt already contains EOS have nothing left to generate.
    finished = (generated_sequences == EOS).any(dim=1)

    # Inference only: do not build an autograd graph while sampling.
    with torch.no_grad():
        for _ in range(max_length - generated_sequences.size(1)):
            if finished.all():
                break

            # Forward pass; keep only the logits of the last position.
            logits = model(generated_sequences, condition_props)  # (batch_size, seq_length, vocab_size)
            next_token_logits = logits[:, -1, :] / temperature  # (batch_size, vocab_size)

            # Top-k filtering: everything outside the k best logits gets probability 0.
            if top_k > 0:
                k = min(top_k, next_token_logits.size(-1))
                top_k_logits, top_k_indices = torch.topk(next_token_logits, k, dim=-1)
                filtered = torch.full_like(next_token_logits, float('-inf'))
                filtered.scatter_(dim=-1, index=top_k_indices, src=top_k_logits)
                next_token_logits = filtered

            # Sample the next token from the resulting distribution.
            next_token_probs = F.softmax(next_token_logits, dim=-1)
            next_token = torch.multinomial(next_token_probs, num_samples=1)  # (batch_size, 1)

            # Rows that already ended only receive padding; decoding ignores it anyway
            # because it stops at the first EOS.
            next_token = next_token.masked_fill(finished.unsqueeze(1), PAD)

            generated_sequences = torch.cat([generated_sequences, next_token], dim=1)
            finished = finished | (next_token.squeeze(1) == EOS)

    # One host transfer for the whole batch instead of one `.item()` call per token.
    return [_decode_token_ids(token_ids, vocab) for token_ids in generated_sequences.tolist()]

In [5]:
# Load, for each framework (AFI, CHA, AEI), the three CSV files that feed Seq2seqDataset.
# read_strings / read_vec are project helpers brought in by the star imports; presumably
# they return the SMILES strings and the numeric zeolite / synthesis feature rows
# -- confirm in utils.
AFI_smiles = read_strings('./data_AFI/AFI_smiles.csv', idx=False)
AFI_zeo = read_vec('./data_AFI/AFI_zeo.csv', idx=False)
AFI_syn = read_vec('./data_AFI/AFI_syn.csv', idx=False)
CHA_smiles = read_strings('./data_CHA/CHA_smiles.csv', idx=False)
CHA_zeo = read_vec('./data_CHA/CHA_zeo.csv', idx=False)
CHA_syn = read_vec('./data_CHA/CHA_syn.csv', idx=False)
AEI_smiles = read_strings('./data_AEI/AEI_smiles.csv', idx=False)
AEI_zeo = read_vec('./data_AEI/AEI_zeo.csv', idx=False)
AEI_syn = read_vec('./data_AEI/AEI_syn.csv', idx=False)

# NOTE(review): load_vocab presumably unpickles the file -- only load trusted vocab files.
vocab = WordVocab.load_vocab('./model_hub/vocab.pkl')
print('the vocab size is :', len(vocab))

charlen = len(vocab)
print('the total num of charset is :', charlen)

cudnn.benchmark = True
batch_size = 64

# Seed every RNG in play so the sampled SMILES are repeatable after a kernel restart.
manual_seed = 42
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create the dataset and dataloader
AFI_dataset = Seq2seqDataset(AFI_zeo, AFI_syn, AFI_smiles, vocab)
CHA_dataset = Seq2seqDataset(CHA_zeo, CHA_syn, CHA_smiles, vocab)
AEI_dataset = Seq2seqDataset(AEI_zeo, AEI_syn, AEI_smiles, vocab)
# Evaluation loaders must NOT shuffle: every loader is iterated once per model and the
# results are later written side by side by position, so both passes have to visit the
# samples in the same order.
AFI_dataloader = DataLoader(AFI_dataset, batch_size=batch_size, shuffle=False)
CHA_dataloader = DataLoader(CHA_dataset, batch_size=batch_size, shuffle=False)
AEI_dataloader = DataLoader(AEI_dataset, batch_size=batch_size, shuffle=False)

the vocab size is : 45
the total num of charset is : 45


In [6]:
# Build two GPT instances with the same architecture and load the two checkpoints:
# the baseline model and the contrastive-learning model.
# block_size equals MAX_LEN (220); num_props=24 is presumably the width of the
# concatenated [zeo, syn] conditioning vector -- confirm against the data files.
config = GPTConfig(vocab_size=charlen, block_size=220, num_props=24)

# map_location makes the checkpoints loadable on the CPU fallback as well, even if they
# were saved from a CUDA device.
# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.
model_origin = GPT(config)
model_origin.load_state_dict(torch.load('./checkpoints/best_GPT_model.pth', map_location=device))
model_origin = model_origin.to(device)

model_cl = GPT(config)
model_cl.load_state_dict(torch.load('./checkpoints/best_GPT_contrastive_model.pth', map_location=device))
model_cl = model_cl.to(device)

# Parameter count (both models share the same config, so one count is reported).
total = sum(p.numel() for p in model_cl.parameters())
print('total parameters: %0.2fM' % (total / 1e6))

total parameters: 3.25M


In [7]:
# AFI / baseline model: sample one SMILES per record (temperature 0.5), prompting each
# sequence with only its first target token, and keep the decoded references alongside.
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AFI_dataloader)):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # conditioning vector: zeolite features followed by synthesis features
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_gpt(model_origin, tgt[:, :1], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # decode the reference sequences: stop at EOS, skip PAD and SOS
    tgt_smiles = []
    for token_ids in tgt.tolist():
        chars = []
        for token in token_ids:
            if token == EOS:
                break
            if token not in (PAD, SOS):
                chars.append(vocab.itos[token])
        tgt_smiles.append(''.join(chars))
    target_smile_origin.extend(tgt_smiles)

# Score the generated set against the reference set with the project metric helpers.
AFI_validity_rate_origin = validity_rate(generated_smile_origin)
AFI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AFI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AFI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AFI_IntDiv_origin = IntDiv(generated_smile_origin)
AFI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)

# Report.
print('AFI_validity_rate_origin: ', AFI_validity_rate_origin)
print('AFI_uniqueness_rate_origin: ', AFI_uniqueness_rate_origin)
print('AFI_novelty_rate_origin: ', AFI_novelty_rate_origin)
print('AFI_reconstructability_rate_origin: ', AFI_reconstructability_rate_origin)
print('AFI_IntDiv_origin: ', AFI_IntDiv_origin)
print('AFI_FCD_score_origin: ', AFI_FCD_score_origin)

100%|██████████| 16/16 [00:12<00:00,  1.30it/s]
[19:40:34] Can't kekulize mol.  Unkekulized atoms: 0 1 2 5 6
[19:40:34] Can't kekulize mol.  Unkekulized atoms: 0 1 2 6 8
[19:40:34] Can't kekulize mol.  Unkekulized atoms: 0 1 2 5 6
[19:40:34] Can't kekulize mol.  Unkekulized atoms: 0 1 2 6 8
[19:40:35] Can't kekulize mol.  Unkekulized atoms: 0 1 2 5 6
[19:40:35] Can't kekulize mol.  Unkekulized atoms: 0 1 2 6 8


AFI_validity_rate_origin:  0.998
AFI_uniqueness_rate_origin:  0.33
AFI_novelty_rate_origin:  0.7333333333333333
AFI_reconstructability_rate_origin:  0.26666666666666666
AFI_IntDiv_origin:  0.7774910774229551
AFI_FCD_score_origin:  3.9085507782014908


In [8]:
# AFI / contrastive-learning model: sample one SMILES per record (temperature 0.5).
# NOTE(review): the prompt here is the first 5 target tokens, whereas the baseline cell
# uses only the first one -- confirm this difference is intended when comparing models.
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AFI_dataloader)):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # conditioning vector: zeolite features followed by synthesis features
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_gpt(model_cl, tgt[:, :5], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # decode the reference sequences: stop at EOS, skip PAD and SOS
    tgt_smiles = []
    for token_ids in tgt.tolist():
        chars = []
        for token in token_ids:
            if token == EOS:
                break
            if token not in (PAD, SOS):
                chars.append(vocab.itos[token])
        tgt_smiles.append(''.join(chars))
    target_smile_cl.extend(tgt_smiles)

# Score the generated set against the reference set with the project metric helpers.
AFI_validity_rate_cl = validity_rate(generated_smile_cl)
AFI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AFI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AFI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AFI_IntDiv_cl = IntDiv(generated_smile_cl)
AFI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)

# Report.
print('AFI_validity_rate_cl: ', AFI_validity_rate_cl)
print('AFI_uniqueness_rate_cl: ', AFI_uniqueness_rate_cl)
print('AFI_novelty_rate_cl: ', AFI_novelty_rate_cl)
print('AFI_reconstructability_rate_cl: ', AFI_reconstructability_rate_cl)
print('AFI_IntDiv_cl: ', AFI_IntDiv_cl)
print('AFI_FCD_score_cl: ', AFI_FCD_score_cl)

100%|██████████| 16/16 [00:11<00:00,  1.40it/s]
[19:40:48] SMILES Parse Error: syntax error while parsing: [C@H2CCCCC2CCCCC[N+]1(C)C2CCCC1
[19:40:48] SMILES Parse Error: Failed parsing SMILES '[C@H2CCCCC2CCCCC[N+]1(C)C2CCCC1' for input: '[C@H2CCCCC2CCCCC[N+]1(C)C2CCCC1'
[19:40:48] SMILES Parse Error: syntax error while parsing: [C@@(C)(C)C
[19:40:48] SMILES Parse Error: Failed parsing SMILES '[C@@(C)(C)C' for input: '[C@@(C)(C)C'
[19:40:48] SMILES Parse Error: syntax error while parsing: N([C(C)C(c1ccccc1)O)(C)C
[19:40:48] SMILES Parse Error: Failed parsing SMILES 'N([C(C)C(c1ccccc1)O)(C)C' for input: 'N([C(C)C(c1ccccc1)O)(C)C'
[19:40:48] SMILES Parse Error: syntax error while parsing: C[C@(C(O)c1ccccc1)C
[19:40:48] SMILES Parse Error: Failed parsing SMILES 'C[C@(C(O)c1ccccc1)C' for input: 'C[C@(C(O)c1ccccc1)C'
[19:40:48] SMILES Parse Error: syntax error while parsing: [C@@12CC3CC(C2)CC(C1)C3
[19:40:48] SMILES Parse Error: Failed parsing SMILES '[C@@12CC3CC(C2)CC(C1)C3' for input: '[C@

AFI_validity_rate_cl:  0.964
AFI_uniqueness_rate_cl:  0.459
AFI_novelty_rate_cl:  0.7995642701525054
AFI_reconstructability_rate_cl:  0.20043572984749455
AFI_IntDiv_cl:  0.8337992111897545
AFI_FCD_score_cl:  1.0821486295469178


In [9]:
# Write the AFI metrics to data_AFI, one metric per line (baseline vs. contrastive).
with open('./data_AFI/AFI_generated_GPT_metrics.txt', 'w') as f:
    f.write(f'AFI_validity_rate_origin: {AFI_validity_rate_origin}, AFI_validity_rate_cl: {AFI_validity_rate_cl}\n')
    f.write(f'AFI_uniqueness_rate_origin: {AFI_uniqueness_rate_origin}, AFI_uniqueness_rate_cl: {AFI_uniqueness_rate_cl}\n')
    f.write(f'AFI_novelty_rate_origin: {AFI_novelty_rate_origin}, AFI_novelty_rate_cl: {AFI_novelty_rate_cl}\n')
    f.write(f'AFI_reconstructability_rate_origin: {AFI_reconstructability_rate_origin}, AFI_reconstructability_rate_cl: {AFI_reconstructability_rate_cl}\n')
    f.write(f'AFI_IntDiv_origin: {AFI_IntDiv_origin}, AFI_IntDiv_cl: {AFI_IntDiv_cl}\n')
    f.write(f'AFI_FCD_score_origin: {AFI_FCD_score_origin}, AFI_FCD_score_cl: {AFI_FCD_score_cl}\n')

# Each row below pairs the baseline sample, the contrastive sample and the target by
# position. That is only meaningful if both generation passes visited the records in the
# same order (i.e. the dataloader must not shuffle), so flag it when they did not.
if target_smile_origin != target_smile_cl:
    print('WARNING: the baseline and contrastive passes saw the AFI targets in a different order; '
          'rows of AFI_generated_GPT_smiles_origin.txt are not aligned.')

# Write the generated SMILES (origin and cl) and the target SMILES to data_AFI.
with open('./data_AFI/AFI_generated_GPT_smiles_origin.txt', 'w') as f:
    for origin_smiles, cl_smiles, target_smiles in zip(generated_smile_origin, generated_smile_cl, target_smile_origin):
        f.write(f'origin: {origin_smiles}, cl: {cl_smiles}, target: {target_smiles}\n')

In [10]:
# CHA / baseline model: sample one SMILES per record (temperature 0.5), prompting each
# sequence with only its first target token, and keep the decoded references alongside.
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(CHA_dataloader)):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # conditioning vector: zeolite features followed by synthesis features
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_gpt(model_origin, tgt[:, :1], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # decode the reference sequences: stop at EOS, skip PAD and SOS
    tgt_smiles = []
    for token_ids in tgt.tolist():
        chars = []
        for token in token_ids:
            if token == EOS:
                break
            if token not in (PAD, SOS):
                chars.append(vocab.itos[token])
        tgt_smiles.append(''.join(chars))
    target_smile_origin.extend(tgt_smiles)

# Score the generated set against the reference set with the project metric helpers.
CHA_validity_rate_origin = validity_rate(generated_smile_origin)
CHA_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
CHA_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
CHA_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
CHA_IntDiv_origin = IntDiv(generated_smile_origin)
CHA_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)

# Report.
print('CHA_validity_rate_origin: ', CHA_validity_rate_origin)
print('CHA_uniqueness_rate_origin: ', CHA_uniqueness_rate_origin)
print('CHA_novelty_rate_origin: ', CHA_novelty_rate_origin)
print('CHA_reconstructability_rate_origin: ', CHA_reconstructability_rate_origin)
print('CHA_IntDiv_origin: ', CHA_IntDiv_origin)
print('CHA_FCD_score_origin: ', CHA_FCD_score_origin)

100%|██████████| 16/16 [00:12<00:00,  1.32it/s]
[19:41:03] Explicit valence for atom # 0 C, 6, is greater than permitted
[19:41:03] Explicit valence for atom # 0 C, 5, is greater than permitted
[19:41:03] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 5
[19:41:03] Explicit valence for atom # 1 C, 5, is greater than permitted
[19:41:03] Explicit valence for atom # 0 C, 6, is greater than permitted
[19:41:03] Explicit valence for atom # 0 C, 5, is greater than permitted
[19:41:03] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 5
[19:41:03] Explicit valence for atom # 1 C, 5, is greater than permitted
[19:41:03] Explicit valence for atom # 0 C, 6, is greater than permitted
[19:41:03] Explicit valence for atom # 0 C, 5, is greater than permitted
[19:41:03] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 5
[19:41:03] Explicit valence for atom # 1 C, 5, is greater than permitted


CHA_validity_rate_origin:  0.996
CHA_uniqueness_rate_origin:  0.412
CHA_novelty_rate_origin:  0.6868932038834952
CHA_reconstructability_rate_origin:  0.3131067961165049
CHA_IntDiv_origin:  0.8830681144108509
CHA_FCD_score_origin:  2.0578216341368716


In [11]:
# CHA / contrastive-learning model: sample one SMILES per record (temperature 0.5).
# NOTE(review): the prompt here is the first 5 target tokens, whereas the baseline cell
# uses only the first one -- confirm this difference is intended when comparing models.
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(CHA_dataloader)):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # conditioning vector: zeolite features followed by synthesis features
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_gpt(model_cl, tgt[:, :5], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # decode the reference sequences: stop at EOS, skip PAD and SOS
    tgt_smiles = []
    for token_ids in tgt.tolist():
        chars = []
        for token in token_ids:
            if token == EOS:
                break
            if token not in (PAD, SOS):
                chars.append(vocab.itos[token])
        tgt_smiles.append(''.join(chars))
    target_smile_cl.extend(tgt_smiles)

# Score the generated set against the reference set with the project metric helpers.
CHA_validity_rate_cl = validity_rate(generated_smile_cl)
CHA_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
CHA_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
CHA_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
CHA_IntDiv_cl = IntDiv(generated_smile_cl)
CHA_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)

# Report.
print('CHA_validity_rate_cl: ', CHA_validity_rate_cl)
print('CHA_uniqueness_rate_cl: ', CHA_uniqueness_rate_cl)
print('CHA_novelty_rate_cl: ', CHA_novelty_rate_cl)
print('CHA_reconstructability_rate_cl: ', CHA_reconstructability_rate_cl)
print('CHA_IntDiv_cl: ', CHA_IntDiv_cl)
print('CHA_FCD_score_cl: ', CHA_FCD_score_cl)

100%|██████████| 16/16 [00:11<00:00,  1.41it/s]
[19:41:17] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 4
[19:41:17] SMILES Parse Error: syntax error while parsing: C1[C2CC3CC(CC1C3)C2
[19:41:17] SMILES Parse Error: Failed parsing SMILES 'C1[C2CC3CC(CC1C3)C2' for input: 'C1[C2CC3CC(CC1C3)C2'
[19:41:17] SMILES Parse Error: extra open parentheses for input: 'N1CCNC(C)(CC(C)NCCNC(C)(CC1(C)C)C'
[19:41:17] SMILES Parse Error: unclosed ring for input: 'C1OCCN(CCCN)CCN'
[19:41:17] Can't kekulize mol.  Unkekulized atoms: 0 1 2 3 4
[19:41:17] SMILES Parse Error: extra open parentheses for input: 'C(c1n(c([n+](c1)C)C)C'
[19:41:17] Explicit valence for atom # 1 N, 4, is greater than permitted
[19:41:17] Explicit valence for atom # 1 C, 5, is greater than permitted
[19:41:17] SMILES Parse Error: syntax error while parsing: [C@@CC(C)CCC(C)[N+]1(CCCCC1)C
[19:41:17] SMILES Parse Error: Failed parsing SMILES '[C@@CC(C)CCC(C)[N+]1(CCCCC1)C' for input: '[C@@CC(C)CCC(C)[N+]1(CCCCC1)C'
[19:41:17] Expli

CHA_validity_rate_cl:  0.985
CHA_uniqueness_rate_cl:  0.588
CHA_novelty_rate_cl:  0.7976190476190477
CHA_reconstructability_rate_cl:  0.20238095238095238
CHA_IntDiv_cl:  0.8941642246936903
CHA_FCD_score_cl:  1.3336202100087853


In [12]:
# Write the CHA metrics to data_CHA, one metric per line (baseline vs. contrastive).
with open('./data_CHA/CHA_generated_GPT_metrics.txt', 'w') as f:
    f.write(f'CHA_validity_rate_origin: {CHA_validity_rate_origin}, CHA_validity_rate_cl: {CHA_validity_rate_cl}\n')
    f.write(f'CHA_uniqueness_rate_origin: {CHA_uniqueness_rate_origin}, CHA_uniqueness_rate_cl: {CHA_uniqueness_rate_cl}\n')
    f.write(f'CHA_novelty_rate_origin: {CHA_novelty_rate_origin}, CHA_novelty_rate_cl: {CHA_novelty_rate_cl}\n')
    f.write(f'CHA_reconstructability_rate_origin: {CHA_reconstructability_rate_origin}, CHA_reconstructability_rate_cl: {CHA_reconstructability_rate_cl}\n')
    f.write(f'CHA_IntDiv_origin: {CHA_IntDiv_origin}, CHA_IntDiv_cl: {CHA_IntDiv_cl}\n')
    f.write(f'CHA_FCD_score_origin: {CHA_FCD_score_origin}, CHA_FCD_score_cl: {CHA_FCD_score_cl}\n')

# Each row below pairs the baseline sample, the contrastive sample and the target by
# position. That is only meaningful if both generation passes visited the records in the
# same order (i.e. the dataloader must not shuffle), so flag it when they did not.
if target_smile_origin != target_smile_cl:
    print('WARNING: the baseline and contrastive passes saw the CHA targets in a different order; '
          'rows of CHA_generated_GPT_smiles_origin.txt are not aligned.')

# Write the generated SMILES (origin and cl) and the target SMILES to data_CHA.
with open('./data_CHA/CHA_generated_GPT_smiles_origin.txt', 'w') as f:
    for origin_smiles, cl_smiles, target_smiles in zip(generated_smile_origin, generated_smile_cl, target_smile_origin):
        f.write(f'origin: {origin_smiles}, cl: {cl_smiles}, target: {target_smiles}\n')

In [13]:
# AEI / baseline model: sample one SMILES per record (temperature 0.5), prompting each
# sequence with only its first target token, and keep the decoded references alongside.
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AEI_dataloader)):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # conditioning vector: zeolite features followed by synthesis features
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_gpt(model_origin, tgt[:, :1], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_origin.extend(generated_smiles)
    # decode the reference sequences: stop at EOS, skip PAD and SOS
    tgt_smiles = []
    for token_ids in tgt.tolist():
        chars = []
        for token in token_ids:
            if token == EOS:
                break
            if token not in (PAD, SOS):
                chars.append(vocab.itos[token])
        tgt_smiles.append(''.join(chars))
    target_smile_origin.extend(tgt_smiles)

# Score the generated set against the reference set with the project metric helpers.
AEI_validity_rate_origin = validity_rate(generated_smile_origin)
AEI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AEI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AEI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AEI_IntDiv_origin = IntDiv(generated_smile_origin)
AEI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)

# Report.
print('AEI_validity_rate_origin: ', AEI_validity_rate_origin)
print('AEI_uniqueness_rate_origin: ', AEI_uniqueness_rate_origin)
print('AEI_novelty_rate_origin: ', AEI_novelty_rate_origin)
print('AEI_reconstructability_rate_origin: ', AEI_reconstructability_rate_origin)
print('AEI_IntDiv_origin: ', AEI_IntDiv_origin)
print('AEI_FCD_score_origin: ', AEI_FCD_score_origin)

100%|██████████| 16/16 [00:12<00:00,  1.32it/s]


AEI_validity_rate_origin:  1.0
AEI_uniqueness_rate_origin:  0.348
AEI_novelty_rate_origin:  0.7959770114942529
AEI_reconstructability_rate_origin:  0.20402298850574713
AEI_IntDiv_origin:  0.8278493052245057
AEI_FCD_score_origin:  3.5808793040934503


In [14]:
# AEI / contrastive-learning model: sample one SMILES per record (temperature 0.5).
# NOTE(review): the prompt here is the first 5 target tokens, whereas the baseline cell
# uses only the first one -- confirm this difference is intended when comparing models.
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AEI_dataloader)):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # conditioning vector: zeolite features followed by synthesis features
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    generated_smiles = generate_gpt(model_cl, tgt[:, :5], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # decode the reference sequences: stop at EOS, skip PAD and SOS
    tgt_smiles = []
    for token_ids in tgt.tolist():
        chars = []
        for token in token_ids:
            if token == EOS:
                break
            if token not in (PAD, SOS):
                chars.append(vocab.itos[token])
        tgt_smiles.append(''.join(chars))
    target_smile_cl.extend(tgt_smiles)

# Score the generated set against the reference set with the project metric helpers.
AEI_validity_rate_cl = validity_rate(generated_smile_cl)
AEI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AEI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AEI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AEI_IntDiv_cl = IntDiv(generated_smile_cl)
AEI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)

# Report.
print('AEI_validity_rate_cl: ', AEI_validity_rate_cl)
print('AEI_uniqueness_rate_cl: ', AEI_uniqueness_rate_cl)
print('AEI_novelty_rate_cl: ', AEI_novelty_rate_cl)
print('AEI_reconstructability_rate_cl: ', AEI_reconstructability_rate_cl)
print('AEI_IntDiv_cl: ', AEI_IntDiv_cl)
print('AEI_FCD_score_cl: ', AEI_FCD_score_cl)

100%|██████████| 16/16 [00:08<00:00,  1.85it/s]
[19:41:43] SMILES Parse Error: syntax error while parsing: [C@@[N+]1(C2CCCC1CC2)C
[19:41:43] SMILES Parse Error: Failed parsing SMILES '[C@@[N+]1(C2CCCC1CC2)C' for input: '[C@@[N+]1(C2CCCC1CC2)C'
[19:41:43] SMILES Parse Error: syntax error while parsing: [C@H2C(CCCC[N+]1(C)C)CCCC2
[19:41:43] SMILES Parse Error: Failed parsing SMILES '[C@H2C(CCCC[N+]1(C)C)CCCC2' for input: '[C@H2C(CCCC[N+]1(C)C)CCCC2'
[19:41:43] SMILES Parse Error: syntax error while parsing: C[C@1CCCC([N+]1(C)C)C
[19:41:43] SMILES Parse Error: Failed parsing SMILES 'C[C@1CCCC([N+]1(C)C)C' for input: 'C[C@1CCCC([N+]1(C)C)C'
[19:41:43] SMILES Parse Error: syntax error while parsing: [C@@CC[n+]1ccn(c1)C
[19:41:43] SMILES Parse Error: Failed parsing SMILES '[C@@CC[n+]1ccn(c1)C' for input: '[C@@CC[n+]1ccn(c1)C'
[19:41:43] SMILES Parse Error: unclosed ring for input: 'C1[N+](C)(C)CC2C3C4C5C(C[N+](C)(C)C5)C(C=C4)C21'
[19:41:43] SMILES Parse Error: unclosed ring for input: 'C1[N+

AEI_validity_rate_cl:  0.899
AEI_uniqueness_rate_cl:  0.502
AEI_novelty_rate_cl:  0.896414342629482
AEI_reconstructability_rate_cl:  0.10358565737051793
AEI_IntDiv_cl:  0.8226898947204558
AEI_FCD_score_cl:  1.7340445329114953


In [15]:
# Write the AEI metrics to data_AEI, one metric per line (baseline vs. contrastive).
with open('./data_AEI/AEI_generated_GPT_metrics.txt', 'w') as f:
    f.write(f'AEI_validity_rate_origin: {AEI_validity_rate_origin}, AEI_validity_rate_cl: {AEI_validity_rate_cl}\n')
    f.write(f'AEI_uniqueness_rate_origin: {AEI_uniqueness_rate_origin}, AEI_uniqueness_rate_cl: {AEI_uniqueness_rate_cl}\n')
    f.write(f'AEI_novelty_rate_origin: {AEI_novelty_rate_origin}, AEI_novelty_rate_cl: {AEI_novelty_rate_cl}\n')
    f.write(f'AEI_reconstructability_rate_origin: {AEI_reconstructability_rate_origin}, AEI_reconstructability_rate_cl: {AEI_reconstructability_rate_cl}\n')
    f.write(f'AEI_IntDiv_origin: {AEI_IntDiv_origin}, AEI_IntDiv_cl: {AEI_IntDiv_cl}\n')
    f.write(f'AEI_FCD_score_origin: {AEI_FCD_score_origin}, AEI_FCD_score_cl: {AEI_FCD_score_cl}\n')

# Each row below pairs the baseline sample, the contrastive sample and the target by
# position. That is only meaningful if both generation passes visited the records in the
# same order (i.e. the dataloader must not shuffle), so flag it when they did not.
if target_smile_origin != target_smile_cl:
    print('WARNING: the baseline and contrastive passes saw the AEI targets in a different order; '
          'rows of AEI_generated_GPT_smiles_origin.txt are not aligned.')

# Write the generated SMILES (origin and cl) and the target SMILES to data_AEI.
with open('./data_AEI/AEI_generated_GPT_smiles_origin.txt', 'w') as f:
    for origin_smiles, cl_smiles, target_smiles in zip(generated_smile_origin, generated_smile_cl, target_smile_origin):
        f.write(f'origin: {origin_smiles}, cl: {cl_smiles}, target: {target_smiles}\n')

: 