In [1]:
%cd ..

/home/hudongcheng/Desktop/bo_osda_generator


In [2]:
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import functional as F
from tqdm import tqdm
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn

# import custom modules
from datasets.data_loader import *
from models.cvae import *
from models.loss import InfoNCELoss
from models.trfm import *
from utils.utils import *
from utils.plot_figures import *
from utils.metrics import *
from utils.build_vocab import *

In [3]:
cudnn.benchmark = True
cudnn.enabled = True

train_loss_history = []
train_acc_history = []
test_loss_history = []
test_acc_history = []

now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

PAD = 0
UNK = 1
EOS = 2
SOS = 3
MASK = 4
MAX_LEN = 220

In [4]:
def generate_cvae(model, condition, tgt_input, vocab, device='cuda', max_len=MAX_LEN, temperature=1.0):
    """
    Generates SMILES strings from a trained CVAE model, given an already-merged synthesis condition.

    Args:
        model:       The trained ConditionalVAE model (must implement model.generate).
        condition:   (B, cond_dim) - the already-merged condition vector (e.g., torch.cat([zeo, syn], dim=-1)).
        tgt_input:   (B, seq_len) - the input SMILES sequence (e.g., the '<sos>' token).
        vocab:       The vocabulary object (must implement vocab.idx2char).
        device:      'cuda' or 'cpu'.
        max_len:     Maximum length of generated SMILES.

    Returns:
        smiles_list: A list of generated SMILES strings (length = batch size).
                     Each string excludes '<sos>', ignores '<pad>', and is cut after '<eos>'.
    """
    model.eval()
    smiles_list = []

    with torch.no_grad():
        # Move condition to device
        condition = condition.to(device)
        
        # Generate SMILES strings
        # Assuming the model has a generate function that takes condition as input
        # generated = model.generate(condition, max_len=max_len, device=device, teacher_force_inputs=tgt_input)  # shape: (B, seq_len)
        logits, z_mean, z_log = model(
            x_smi=tgt_input,         # SMILES input
            x_cond=condition,        # condition
            teacher_force_inputs=tgt_input  # Use teacher forcing during inference
        )

        # sample from the logits using multinomial
        probs = F.softmax(logits / temperature, dim=-1)
        generated = torch.multinomial(probs.view(-1, probs.size(-1)), num_samples=1).view(probs.size(0), probs.size(1))
        
        
        for i in range(generated.size(0)):
            # Convert the generated token IDs to SMILES string
            smi = []
            for idx in generated[i]:
                # Convert index to character, skipping padding or other invalid tokens
                char = vocab.itos[idx.item()]
                if char == '<eos>':  # End of the SMILES string
                    break
                if char not in ('<sos>', '<pad>'):  # Remove unwanted characters
                    smi.append(char)
            smiles_list.append(''.join(smi))

    return smiles_list


In [5]:
# load the data
AFI_smiles = read_strings('./data_AFI/AFI_smiles.csv', idx=False)
AFI_zeo = read_vec('./data_AFI/AFI_zeo.csv', idx=False)
AFI_syn = read_vec('./data_AFI/AFI_syn.csv', idx=False)
CHA_smiles = read_strings('./data_CHA/CHA_smiles.csv', idx=False)
CHA_zeo = read_vec('./data_CHA/CHA_zeo.csv', idx=False)
CHA_syn = read_vec('./data_CHA/CHA_syn.csv', idx=False)
AEI_smiles = read_strings('./data_AEI/AEI_smiles.csv', idx=False)
AEI_zeo = read_vec('./data_AEI/AEI_zeo.csv', idx=False)
AEI_syn = read_vec('./data_AEI/AEI_syn.csv', idx=False)

vocab = WordVocab.load_vocab('./model_hub/vocab.pkl')
print('the vocab size is :', len(vocab))
charlen = len(vocab)
print('the total num of charset is :', charlen)

cudnn.benchmark = True
batch_size = 64

manual_seed = 42
random.seed(manual_seed)
torch.manual_seed(manual_seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create the dataset and dataloader
AFI_dataset = Seq2seqDataset(AFI_zeo, AFI_syn, AFI_smiles, vocab)
CHA_dataset = Seq2seqDataset(CHA_zeo, CHA_syn, CHA_smiles, vocab)
AEI_dataset = Seq2seqDataset(AEI_zeo, AEI_syn, AEI_smiles, vocab)
AFI_dataloader = DataLoader(AFI_dataset, batch_size=batch_size, shuffle=True)
CHA_dataloader = DataLoader(CHA_dataset, batch_size=batch_size, shuffle=False)
AEI_dataloader = DataLoader(AEI_dataset, batch_size=batch_size, shuffle=False)

the vocab size is : 45
the total num of charset is : 45


In [6]:
# load the original model and contrastive model
# Suppose we have:
params = {
    'NCHARS': charlen,
    'MAX_LEN': MAX_LEN,
    'COND_DIM': 24,
    'hidden_dim': 256,
    'conv_depth': 3,
    'conv_dim_depth': 64,
    'conv_dim_width': 3,
    'conv_d_growth_factor': 2,
    'conv_w_growth_factor': 1,
    'middle_layer': 2,
    'activation': 'tanh',
    'batchnorm_conv': True,
    'batchnorm_mid': True,
    'dropout_rate_mid': 0.2,
    'gru_depth': 2,
    'recurrent_dim': 256,
    # ...
}

encoder = ConditionalEncoder(
    input_channels=params['NCHARS'],
    max_len=params['MAX_LEN'],
    cond_dim=params['COND_DIM'],
    hidden_dim=params['hidden_dim'],
    conv_depth=params['conv_depth'],
    conv_dim_depth=params['conv_dim_depth'],
    conv_dim_width=params['conv_dim_width'],
    conv_d_growth_factor=params['conv_d_growth_factor'],
    conv_w_growth_factor=params['conv_w_growth_factor'],
    middle_layer=params['middle_layer'],
    activation=params['activation'],
    batchnorm_conv=params['batchnorm_conv'],
    batchnorm_mid=params['batchnorm_mid'],
    dropout_rate_mid=params['dropout_rate_mid']
).to(device)

decoder = ConditionalDecoder(
    hidden_dim=params['hidden_dim'],
    cond_dim=params['COND_DIM'],
    n_chars=params['NCHARS'],
    max_len=params['MAX_LEN'],
    gru_depth=params['gru_depth'],
    recurrent_dim=params['recurrent_dim'],
    dropout=0.2
).to(device)

model_origin = ConditionalVAE(encoder, decoder)
# 加载 state_dict
state_dict = torch.load('./checkpoints/best_cvae_model.pth')
# 获取模型的 state_dict 的键
model_state_dict = model_origin.state_dict()
# 过滤出模型中存在的键
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
# 加载过滤后的 state_dict
model_origin.load_state_dict(filtered_state_dict, strict=False)

model_cl = ConditionalVAE(encoder, decoder)
# 加载 state_dict
state_dict = torch.load('./checkpoints/last_cvae_contrastive_model.pth')
# 获取模型的 state_dict 的键
model_state_dict = model_cl.state_dict()
# 过滤出模型中存在的键
filtered_state_dict = {k: v for k, v in state_dict.items() if k in model_state_dict}
# 加载过滤后的 state_dict
model_cl.load_state_dict(filtered_state_dict, strict=False)


# # model_origin.load_state_dict(torch.load('./checkpoints/best_cvae_model.pth'))
# # model_origin = torch.load('./checkpoints/best_cvae_model.pth')
# model_origin = model_origin.to(device)
# model_origin.eval()

# model_cl = ConditionalVAE(encoder, decoder)
# model_cl.load_state_dict(torch.load('./checkpoints/best_cvae_contrastive_model.pth'))
# # model_cl = torch.load('./checkpoints/best_cvae_contrastive_model.pth')
# model_cl = model_cl.to(device)
# model_cl.eval()

total = sum(p.numel() for p in model_origin.parameters())
print('total parameters: %0.2fM' % (total / 1e6))  # print the total parameters

total parameters: 1.15M


In [11]:
# generate the AFI smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AFI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    tgt_input = tgt[:, :-1].contiguous()
    tgt_input = F.one_hot(tgt_input, num_classes=charlen).float()
    generated_smiles = generate_cvae(model=model_origin, condition=condition_synthesis, tgt_input=tgt_input, vocab=vocab, device=device, max_len=MAX_LEN, temperature=1.0)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)
    
# calculate the metrics
AFI_validity_rate_origin = validity_rate(generated_smile_origin)
AFI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AFI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AFI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AFI_IntDiv_origin = IntDiv(generated_smile_origin)
AFI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('AFI_validity_rate_origin: ', AFI_validity_rate_origin)
print('AFI_uniqueness_rate_origin: ', AFI_uniqueness_rate_origin)
print('AFI_novelty_rate_origin: ', AFI_novelty_rate_origin)
print('AFI_reconstructability_rate_origin: ', AFI_reconstructability_rate_origin)
print('AFI_IntDiv_origin: ', AFI_IntDiv_origin)
print('AFI_FCD_score_origin: ', AFI_FCD_score_origin)

100%|██████████| 16/16 [00:00<00:00, 21.39it/s]
[20:01:12] SMILES Parse Error: syntax error while parsing: C1CC((NNCCCC1CCCCC2)C1
[20:01:12] SMILES Parse Error: Failed parsing SMILES 'C1CC((NNCCCC1CCCCC2)C1' for input: 'C1CC((NNCCCC1CCCCC2)C1'
[20:01:12] SMILES Parse Error: syntax error while parsing: C1[CC3CCN+](CC(CCC)CC(CC)CC(C2)C3
[20:01:12] SMILES Parse Error: Failed parsing SMILES 'C1[CC3CCN+](CC(CCC)CC(CC)CC(C2)C3' for input: 'C1[CC3CCN+](CC(CCC)CC(CC)CC(C2)C3'
[20:01:12] SMILES Parse Error: syntax error while parsing: C1CCC((CCN+](C)(CCCCCC1C1)CC1C3
[20:01:12] SMILES Parse Error: Failed parsing SMILES 'C1CCC((CCN+](C)(CCCCCC1C1)CC1C3' for input: 'C1CCC((CCN+](C)(CCCCCC1C1)CC1C3'
[20:01:12] SMILES Parse Error: syntax error while parsing: C11N+](C)1(Cc1C11CcccCC(
[20:01:12] SMILES Parse Error: Failed parsing SMILES 'C11N+](C)1(Cc1C11CcccCC(' for input: 'C11N+](C)1(Cc1C11CcccCC('
[20:01:12] SMILES Parse Error: syntax error while parsing: C1[CC2N+]2(CCCCN+]C(CN+1CC(C((CC1C2N++C(C2)

AFI_validity_rate_origin:  0.001
AFI_uniqueness_rate_origin:  1.0
AFI_novelty_rate_origin:  1.0
AFI_reconstructability_rate_origin:  0.0
AFI_IntDiv_origin:  0.0
AFI_FCD_score_origin:  11.387737683282882


In [8]:
# generate the AFI smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AFI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    tgt_input = tgt[:, :-1].contiguous()
    tgt_input = F.one_hot(tgt_input, num_classes=charlen).float()
    generated_smiles = generate_cvae(model=model_cl, condition=condition_synthesis, tgt_input=tgt_input, vocab=vocab, device=device, max_len=MAX_LEN, temperature=1.0)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
AFI_validity_rate_cl = validity_rate(generated_smile_cl)
AFI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AFI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AFI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AFI_IntDiv_cl = IntDiv(generated_smile_cl)
AFI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('AFI_validity_rate_cl: ', AFI_validity_rate_cl)
print('AFI_uniqueness_rate_cl: ', AFI_uniqueness_rate_cl)
print('AFI_novelty_rate_cl: ', AFI_novelty_rate_cl)
print('AFI_reconstructability_rate_cl: ', AFI_reconstructability_rate_cl)
print('AFI_IntDiv_cl: ', AFI_IntDiv_cl)
print('AFI_FCD_score_cl: ', AFI_FCD_score_cl)

100%|██████████| 16/16 [00:00<00:00, 21.85it/s]
[20:00:55] SMILES Parse Error: ring closure 1 duplicates bond between atom 0 and atom 1 for input: 'C1C1C(C)NCNNCCcNC'
[20:00:55] SMILES Parse Error: syntax error while parsing: C1[(CCCN+]2CC2(CcCCCCCCC1
[20:00:55] SMILES Parse Error: Failed parsing SMILES 'C1[(CCCN+]2CC2(CcCCCCCCC1' for input: 'C1[(CCCN+]2CC2(CcCCCCCCC1'
[20:00:55] SMILES Parse Error: syntax error while parsing: C((C[[1ccc(c1
[20:00:55] SMILES Parse Error: Failed parsing SMILES 'C((C[[1ccc(c1' for input: 'C((C[[1ccc(c1'
[20:00:55] SMILES Parse Error: extra close parentheses while parsing: c1O1)[NN+]1ccn(c)c1
[20:00:55] SMILES Parse Error: Failed parsing SMILES 'c1O1)[NN+]1ccn(c)c1' for input: 'c1O1)[NN+]1ccn(c)c1'
[20:00:55] SMILES Parse Error: syntax error while parsing: CN+]n(CCCCC(C)2ccccc2)C(C1
[20:00:55] SMILES Parse Error: Failed parsing SMILES 'CN+]n(CCCCC(C)2ccccc2)C(C1' for input: 'CN+]n(CCCCC(C)2ccccc2)C(C1'
[20:00:55] SMILES Parse Error: syntax error while par

AFI_validity_rate_cl:  0.002
AFI_uniqueness_rate_cl:  1.0
AFI_novelty_rate_cl:  1.0
AFI_reconstructability_rate_cl:  0.0
AFI_IntDiv_cl:  1.0
AFI_FCD_score_cl:  11.030199787134464


In [9]:
# write the metrics to the folder data_AFI
with open('./data_AFI/AFI_generated_cvae_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'AFI_validity_rate_origin: {AFI_validity_rate_origin}, AFI_validity_rate_cl: {AFI_validity_rate_cl}\n')
    f.write(f'AFI_uniqueness_rate_origin: {AFI_uniqueness_rate_origin}, AFI_uniqueness_rate_cl: {AFI_uniqueness_rate_cl}\n')
    f.write(f'AFI_novelty_rate_origin: {AFI_novelty_rate_origin}, AFI_novelty_rate_cl: {AFI_novelty_rate_cl}\n')
    f.write(f'AFI_reconstructability_rate_origin: {AFI_reconstructability_rate_origin}, AFI_reconstructability_rate_cl: {AFI_reconstructability_rate_cl}\n')
    f.write(f'AFI_IntDiv_origin: {AFI_IntDiv_origin}, AFI_IntDiv_cl: {AFI_IntDiv_cl}\n')
    f.write(f'AFI_FCD_score_origin: {AFI_FCD_score_origin}, AFI_FCD_score_cl: {AFI_FCD_score_cl}\n')

# write the generated smiles (origin and cl) and target smiles to the folder data_AFI
with open('./data_AFI/AFI_generated_cvae_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')

In [10]:
# generate the CHA smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(CHA_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    tgt_input = F.one_hot(tgt_input, num_classes=charlen).float()
    generated_smiles = generate_cvae(model=model_origin, condition=condition_synthesis, tgt_input=tgt_input, vocab=vocab, device=device, max_len=MAX_LEN, temperature=1.0)
    generated_smile_origin.extend(generated_smiles)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)

# calculate the metrics
CHA_validity_rate_origin = validity_rate(generated_smile_origin)
CHA_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
CHA_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
CHA_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
CHA_IntDiv_origin = IntDiv(generated_smile_origin)
CHA_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('CHA_validity_rate_origin: ', CHA_validity_rate_origin)
print('CHA_uniqueness_rate_origin: ', CHA_uniqueness_rate_origin)
print('CHA_novelty_rate_origin: ', CHA_novelty_rate_origin)
print('CHA_reconstructability_rate_origin: ', CHA_reconstructability_rate_origin)
print('CHA_IntDiv_origin: ', CHA_IntDiv_origin)
print('CHA_FCD_score_origin: ', CHA_FCD_score_origin)

  0%|          | 0/16 [00:00<?, ?it/s]


RuntimeError: one_hot is only applicable to index tensor.

In [None]:
# generate the CHA smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(CHA_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    tgt_input = tgt[:, :-1].contiguous()
    tgt_input = F.one_hot(tgt_input, num_classes=charlen).float()
    generated_smiles = generate_cvae(model=model_cl, condition=condition_synthesis, tgt_input=tgt_input, vocab=vocab, device=device, max_len=MAX_LEN, temperature=1.0)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
CHA_validity_rate_cl = validity_rate(generated_smile_cl)
CHA_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
CHA_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
CHA_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
CHA_IntDiv_cl = IntDiv(generated_smile_cl)
CHA_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('CHA_validity_rate_cl: ', CHA_validity_rate_cl)
print('CHA_uniqueness_rate_cl: ', CHA_uniqueness_rate_cl)
print('CHA_novelty_rate_cl: ', CHA_novelty_rate_cl)
print('CHA_reconstructability_rate_cl: ', CHA_reconstructability_rate_cl)
print('CHA_IntDiv_cl: ', CHA_IntDiv_cl)
print('CHA_FCD_score_cl: ', CHA_FCD_score_cl)

In [None]:
# write the metrics to the folder data_CHA
with open('./data_CHA/CHA_generated_cvae_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'CHA_validity_rate_origin: {CHA_validity_rate_origin}, CHA_validity_rate_cl: {CHA_validity_rate_cl}\n')
    f.write(f'CHA_uniqueness_rate_origin: {CHA_uniqueness_rate_origin}, CHA_uniqueness_rate_cl: {CHA_uniqueness_rate_cl}\n')
    f.write(f'CHA_novelty_rate_origin: {CHA_novelty_rate_origin}, CHA_novelty_rate_cl: {CHA_novelty_rate_cl}\n')
    f.write(f'CHA_reconstructability_rate_origin: {CHA_reconstructability_rate_origin}, CHA_reconstructability_rate_cl: {CHA_reconstructability_rate_cl}\n')
    f.write(f'CHA_IntDiv_origin: {CHA_IntDiv_origin}, CHA_IntDiv_cl: {CHA_IntDiv_cl}\n')
    f.write(f'CHA_FCD_score_origin: {CHA_FCD_score_origin}, CHA_FCD_score_cl: {CHA_FCD_score_cl}\n')
# write the generated smiles (origin and cl) and target smiles to the folder data_CHA
with open('./data_CHA/CHA_generated_cvae_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')

In [None]:
# generate the AEI smiles with the original model
generated_smile_origin = []
target_smile_origin = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AEI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    tgt_input = F.one_hot(tgt_input, num_classes=charlen).float()
    generated_smiles = generate_cvae(model=model_origin, condition=condition_synthesis, tgt_input=tgt_input, vocab=vocab, device=device, max_len=MAX_LEN, temperature=1.0)
    generated_smile_origin.extend(generated_smiles)
    generated_smile_origin.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_origin.extend(tgt_smiles)

# calculate the metrics
AEI_validity_rate_origin = validity_rate(generated_smile_origin)
AEI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AEI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AEI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AEI_IntDiv_origin = IntDiv(generated_smile_origin)
AEI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)
# print the metrics
print('AEI_validity_rate_origin: ', AEI_validity_rate_origin)
print('AEI_uniqueness_rate_origin: ', AEI_uniqueness_rate_origin)
print('AEI_novelty_rate_origin: ', AEI_novelty_rate_origin)
print('AEI_reconstructability_rate_origin: ', AEI_reconstructability_rate_origin)
print('AEI_IntDiv_origin: ', AEI_IntDiv_origin)
print('AEI_FCD_score_origin: ', AEI_FCD_score_origin)

In [None]:
# generate the AEI smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for i, (zeo, syn, tgt) in enumerate(tqdm(AEI_dataloader)):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    tgt_input = tgt[:, :-1].contiguous()
    tgt_input = F.one_hot(tgt_input, num_classes=charlen).float()
    generated_smiles = generate_cvae(model=model_cl, condition=condition_synthesis, tgt_input=tgt_input, vocab=vocab, device=device, max_len=MAX_LEN, temperature=1.0)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles
    tgt_smiles = []
    for seq in tgt:
        smiles = ''
        for idx in seq:
            if idx.item() == EOS:
                break
            elif idx.item() != PAD and idx.item() != SOS:
                smiles += vocab.itos[idx.item()]
        tgt_smiles.append(smiles)
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
AEI_validity_rate_cl = validity_rate(generated_smile_cl)
AEI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AEI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AEI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AEI_IntDiv_cl = IntDiv(generated_smile_cl)
AEI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('AEI_validity_rate_cl: ', AEI_validity_rate_cl)
print('AEI_uniqueness_rate_cl: ', AEI_uniqueness_rate_cl)
print('AEI_novelty_rate_cl: ', AEI_novelty_rate_cl)
print('AEI_reconstructability_rate_cl: ', AEI_reconstructability_rate_cl)
print('AEI_IntDiv_cl: ', AEI_IntDiv_cl)
print('AEI_FCD_score_cl: ', AEI_FCD_score_cl)

In [None]:
# write the metrics to the folder data_AEI
with open('./data_AEI/AEI_generated_cvae_metrics.txt', 'w') as f:
    # write the mertics
    f.write(f'AEI_validity_rate_origin: {AEI_validity_rate_origin}, AEI_validity_rate_cl: {AEI_validity_rate_cl}\n')
    f.write(f'AEI_uniqueness_rate_origin: {AEI_uniqueness_rate_origin}, AEI_uniqueness_rate_cl: {AEI_uniqueness_rate_cl}\n')
    f.write(f'AEI_novelty_rate_origin: {AEI_novelty_rate_origin}, AEI_novelty_rate_cl: {AEI_novelty_rate_cl}\n')
    f.write(f'AEI_reconstructability_rate_origin: {AEI_reconstructability_rate_origin}, AEI_reconstructability_rate_cl: {AEI_reconstructability_rate_cl}\n')
    f.write(f'AEI_IntDiv_origin: {AEI_IntDiv_origin}, AEI_IntDiv_cl: {AEI_IntDiv_cl}\n')
    f.write(f'AEI_FCD_score_origin: {AEI_FCD_score_origin}, AEI_FCD_score_cl: {AEI_FCD_score_cl}\n')
# write the generated smiles (origin and cl) and target smiles to the folder data_AEI
with open('./data_AEI/AEI_generated_cvae_smiles_origin.txt', 'w') as f:
    for smiles in range(len(generated_smile_origin)):
        f.write(f'origin: {generated_smile_origin[smiles]}, cl: {generated_smile_cl[smiles]}, target: {target_smile_origin[smiles]}\n')