In [None]:
# move the working directory up one level (presumably to the project root) so the relative paths used below resolve
%cd ..

In [None]:
import os
import random
from datetime import datetime

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# import custom modules
from models.ddc import SMILESGenerator
from utils.utils import *
from datasets.data_loader import *
from utils.plot_figures import *
from utils.metrics import *
from utils.build_vocab import *

In [None]:
# cuDNN backend switches
cudnn.enabled = True
cudnn.benchmark = True

# Per-epoch history buffers (each name gets its own independent list)
train_loss_history, train_acc_history = [], []
test_loss_history, test_acc_history = [], []

# Timestamp string for this session, e.g. 2024-01-31-12-00-00
now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

# Special-token ids used when decoding sequences.
# NOTE(review): presumably these must match the indices in the loaded vocab — confirm.
PAD, UNK, EOS, SOS, MASK = range(5)
# Upper bound on the generated sequence length (in tokens)
MAX_LEN = 220

In [None]:
def generate_ddc(model, start_sequence, condition, max_length, vocab,  device, temperature=0.5, top_k=0):
    """
    Autoregressively sample SMILES strings from a conditional generator.

    At every step the current token sequence is one-hot encoded (with
    ``len(vocab)`` classes), the model is called as ``model(condition, one_hot)``
    and the logits at the last position are used to sample the next token.

    Args:
        model: The trained generator; called as ``model(condition, one_hot_seq)``
            and expected to return logits indexable as ``[:, -1, :]``.
        start_sequence (torch.Tensor): Integer token ids used as the prompt,
            shape (batch_size, start_len).
        condition (torch.Tensor): The conditioning input forwarded unchanged to
            the model (the caller is responsible for its device placement).
        max_length (int): The maximum total length (prompt included) of the
            generated sequence.
        vocab: The vocabulary object; must support ``len()`` and ``itos[idx]``.
        device (torch.device): The device on which the token sequence is kept.
        temperature (float): Softmax temperature, must be > 0; higher values
            increase randomness.
        top_k (int): If > 0, sampling is restricted to the ``top_k`` most likely
            tokens (clamped to the vocabulary size); if 0, the full distribution
            is sampled.

    Returns:
        List[str]: One decoded SMILES string per batch row. Decoding stops at
        the first EOS token and skips PAD and SOS tokens.

    Raises:
        ValueError: If ``temperature`` is not strictly positive.
    """
    if temperature <= 0:
        raise ValueError(f'temperature must be > 0, got {temperature}')

    model.eval()
    vocab_size = len(vocab)
    generated_seq = start_sequence.clone().to(device)
    # Rows whose prompt already contains EOS are complete from the start
    # (decoding below stops at the first EOS anyway).
    finished = (generated_seq == EOS).any(dim=1)

    with torch.no_grad():
        for _ in range(max_length - generated_seq.size(1)):
            # stop as soon as every row has produced an EOS token
            if bool(finished.all()):
                break

            generated_seq_hot = F.one_hot(generated_seq, num_classes=vocab_size).float()
            # forward pass to get logits
            logits = model(condition, generated_seq_hot)
            # logits for the next token, temperature-scaled (out of place so the
            # model output tensor is left untouched)
            next_token_logits = logits[:, -1, :] / temperature

            if top_k > 0:
                # torch.topk rejects k larger than the dimension size
                k = min(top_k, next_token_logits.size(-1))
                top_k_logits, top_k_indices = torch.topk(next_token_logits, k, dim=-1)
                # multinomial returns a position inside the top-k list ...
                sampled_pos = torch.multinomial(F.softmax(top_k_logits, dim=-1), num_samples=1)
                # ... which must be mapped back to the real vocabulary id
                next_token = top_k_indices.gather(-1, sampled_pos)
            else:
                # sample from the full distribution
                next_token = torch.multinomial(F.softmax(next_token_logits, dim=-1), num_samples=1)

            # rows that already ended only receive padding from now on
            next_token = next_token.masked_fill(finished.unsqueeze(-1), PAD)
            # append the new token to the sequence
            generated_seq = torch.cat([generated_seq, next_token], dim=-1)
            finished = finished | (next_token.squeeze(-1) == EOS)

    # Decode the generated sequences into SMILES strings
    generated_smiles = []
    for token_ids in generated_seq.tolist():
        # keep only what precedes the first EOS token
        if EOS in token_ids:
            token_ids = token_ids[:token_ids.index(EOS)]
        # drop padding and start tokens, map the rest to characters
        generated_smiles.append(
            ''.join(vocab.itos[t] for t in token_ids if t != PAD and t != SOS)
        )

    return generated_smiles


In [None]:
# load the data (SMILES strings plus the zeo / syn vectors for each dataset)
AFI_smiles = read_strings('./data_AFI/AFI_smiles.csv', idx=False)
AFI_zeo = read_vec('./data_AFI/AFI_zeo.csv', idx=False)
AFI_syn = read_vec('./data_AFI/AFI_syn.csv', idx=False)
CHA_smiles = read_strings('./data_CHA/CHA_smiles.csv', idx=False)
CHA_zeo = read_vec('./data_CHA/CHA_zeo.csv', idx=False)
CHA_syn = read_vec('./data_CHA/CHA_syn.csv', idx=False)

# NOTE(review): the .pkl extension suggests load_vocab unpickles this file —
# only load vocabulary files from a trusted source.
vocab = WordVocab.load_vocab('./model_hub/vocab.pkl')
print('the vocab size is :', len(vocab))
charlen = len(vocab)
print('the total num of charset is :', charlen)

cudnn.benchmark = True
batch_size = 64

# seed every RNG in use so sampling is repeatable across runs
manual_seed = 42
random.seed(manual_seed)
np.random.seed(manual_seed)
torch.manual_seed(manual_seed)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# create the dataset and dataloader
AFI_dataset = Seq2seqDataset(AFI_zeo, AFI_syn, AFI_smiles, vocab)
CHA_dataset = Seq2seqDataset(CHA_zeo, CHA_syn, CHA_smiles, vocab)
# Evaluation only: keep a fixed order. The results of the "origin" and "cl"
# passes are later written side by side by row index, which is only meaningful
# when both passes iterate the dataset in the same order.
AFI_dataloader = DataLoader(AFI_dataset, batch_size=batch_size, shuffle=False)
CHA_dataloader = DataLoader(CHA_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# load the models
def load_smiles_generator(checkpoint_path):
    """Build a SMILESGenerator, restore its weights from `checkpoint_path`,
    move it to `device` and return it in eval mode."""
    generator = SMILESGenerator(condition_dim=24, lstm_dim=256, dec_layers=3, charset_size=charlen)
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU can also be restored on a CPU-only machine
    generator.load_state_dict(torch.load(checkpoint_path, map_location=device))
    generator.to(device)
    generator.eval()
    return generator


# baseline model and the contrastive-learning variant share one architecture
model_origin = load_smiles_generator('./checkpoints/best_ddc_model.pth')
model_cl = load_smiles_generator('./checkpoints/best_ddc_contrastive_model.pth')

total = sum(p.numel() for p in model_origin.parameters())
print('total parameters: %0.2fM' % (total / 1e6))  # print the total parameters

In [None]:
# Evaluate the original model on AFI: sample one SMILES per dataset entry and
# keep the decoded target string next to it.
generated_smile_origin, target_smile_origin = [], []
for zeo, syn, tgt in tqdm(AFI_dataloader):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # the condition is zeo and syn concatenated along dim 1
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    # the first 10 target tokens serve as the generation prompt
    generated_smile_origin.extend(
        generate_ddc(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    )
    # decode every target row: cut at the first EOS, skip PAD and SOS
    for token_ids in tgt.tolist():
        if EOS in token_ids:
            token_ids = token_ids[:token_ids.index(EOS)]
        target_smile_origin.append(
            ''.join(vocab.itos[t] for t in token_ids if t != PAD and t != SOS)
        )

# calculate the metrics
AFI_validity_rate_origin = validity_rate(generated_smile_origin)
AFI_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
AFI_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
AFI_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
AFI_IntDiv_origin = IntDiv(generated_smile_origin)
AFI_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)

# print the metrics
for label, value in (
    ('AFI_validity_rate_origin: ', AFI_validity_rate_origin),
    ('AFI_uniqueness_rate_origin: ', AFI_uniqueness_rate_origin),
    ('AFI_novelty_rate_origin: ', AFI_novelty_rate_origin),
    ('AFI_reconstructability_rate_origin: ', AFI_reconstructability_rate_origin),
    ('AFI_IntDiv_origin: ', AFI_IntDiv_origin),
    ('AFI_FCD_score_origin: ', AFI_FCD_score_origin),
):
    print(label, value)

In [None]:
# generate the AFI smiles with the contrastive learning model
generated_smile_cl = []
target_smile_cl = []
for zeo, syn, tgt in tqdm(AFI_dataloader):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    # the condition is zeo and syn concatenated along dim 1
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    # Sample with the contrastive model (model_cl). Using model_origin here
    # would only re-evaluate the baseline under the "cl" label.
    # The first 10 target tokens serve as the generation prompt.
    generated_smiles = generate_ddc(model_cl, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    generated_smile_cl.extend(generated_smiles)
    # convert the tgt to smiles: cut at the first EOS, skip PAD and SOS
    tgt_smiles = []
    for token_ids in tgt.tolist():
        if EOS in token_ids:
            token_ids = token_ids[:token_ids.index(EOS)]
        tgt_smiles.append(''.join(vocab.itos[t] for t in token_ids if t != PAD and t != SOS))
    target_smile_cl.extend(tgt_smiles)

# calculate the metrics
AFI_validity_rate_cl = validity_rate(generated_smile_cl)
AFI_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
AFI_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
AFI_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
AFI_IntDiv_cl = IntDiv(generated_smile_cl)
AFI_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)
# print the metrics
print('AFI_validity_rate_cl: ', AFI_validity_rate_cl)
print('AFI_uniqueness_rate_cl: ', AFI_uniqueness_rate_cl)
print('AFI_novelty_rate_cl: ', AFI_novelty_rate_cl)
print('AFI_reconstructability_rate_cl: ', AFI_reconstructability_rate_cl)
print('AFI_IntDiv_cl: ', AFI_IntDiv_cl)
print('AFI_FCD_score_cl: ', AFI_FCD_score_cl)

In [None]:
# Persist the AFI metrics: one line per metric, original model vs. contrastive model
with open('./data_AFI/AFI_generated_ddc_metrics.txt', 'w') as f:
    f.writelines([
        f'AFI_validity_rate_origin: {AFI_validity_rate_origin}, AFI_validity_rate_cl: {AFI_validity_rate_cl}\n',
        f'AFI_uniqueness_rate_origin: {AFI_uniqueness_rate_origin}, AFI_uniqueness_rate_cl: {AFI_uniqueness_rate_cl}\n',
        f'AFI_novelty_rate_origin: {AFI_novelty_rate_origin}, AFI_novelty_rate_cl: {AFI_novelty_rate_cl}\n',
        f'AFI_reconstructability_rate_origin: {AFI_reconstructability_rate_origin}, AFI_reconstructability_rate_cl: {AFI_reconstructability_rate_cl}\n',
        f'AFI_IntDiv_origin: {AFI_IntDiv_origin}, AFI_IntDiv_cl: {AFI_IntDiv_cl}\n',
        f'AFI_FCD_score_origin: {AFI_FCD_score_origin}, AFI_FCD_score_cl: {AFI_FCD_score_cl}\n',
    ])

# Persist the generated SMILES of both models together with the target, one row per sample
with open('./data_AFI/AFI_generated_ddc_smiles_origin.txt', 'w') as f:
    for row, origin_smiles in enumerate(generated_smile_origin):
        f.write(f'origin: {origin_smiles}, cl: {generated_smile_cl[row]}, target: {target_smile_origin[row]}\n')

In [None]:
# Evaluate the original model on CHA: sample one SMILES per dataset entry and
# keep the decoded target string next to it.
generated_smile_origin, target_smile_origin = [], []
for zeo, syn, tgt in tqdm(CHA_dataloader):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # the condition is zeo and syn concatenated along dim 1
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    # the first 10 target tokens serve as the generation prompt
    generated_smile_origin.extend(
        generate_ddc(model_origin, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    )
    # decode every target row: cut at the first EOS, skip PAD and SOS
    for token_ids in tgt.tolist():
        if EOS in token_ids:
            token_ids = token_ids[:token_ids.index(EOS)]
        target_smile_origin.append(
            ''.join(vocab.itos[t] for t in token_ids if t != PAD and t != SOS)
        )

# calculate the metrics
CHA_validity_rate_origin = validity_rate(generated_smile_origin)
CHA_uniqueness_rate_origin = uniqueness_rate(generated_smile_origin)
CHA_novelty_rate_origin = novelty_rate(generated_smile_origin, target_smile_origin)
CHA_reconstructability_rate_origin = reconstructability_rate(generated_smile_origin, target_smile_origin)
CHA_IntDiv_origin = IntDiv(generated_smile_origin)
CHA_FCD_score_origin = FCD_score(target_smile_origin, generated_smile_origin)

# print the metrics
for label, value in (
    ('CHA_validity_rate_origin: ', CHA_validity_rate_origin),
    ('CHA_uniqueness_rate_origin: ', CHA_uniqueness_rate_origin),
    ('CHA_novelty_rate_origin: ', CHA_novelty_rate_origin),
    ('CHA_reconstructability_rate_origin: ', CHA_reconstructability_rate_origin),
    ('CHA_IntDiv_origin: ', CHA_IntDiv_origin),
    ('CHA_FCD_score_origin: ', CHA_FCD_score_origin),
):
    print(label, value)

In [None]:
# Evaluate the contrastive-learning model on CHA: sample one SMILES per dataset
# entry and keep the decoded target string next to it.
generated_smile_cl, target_smile_cl = [], []
for zeo, syn, tgt in tqdm(CHA_dataloader):
    zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
    # the condition is zeo and syn concatenated along dim 1
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    # the first 10 target tokens serve as the generation prompt
    generated_smile_cl.extend(
        generate_ddc(model_cl, tgt[:, :10], condition_synthesis, MAX_LEN, vocab, device, 0.5)
    )
    # decode every target row: cut at the first EOS, skip PAD and SOS
    for token_ids in tgt.tolist():
        if EOS in token_ids:
            token_ids = token_ids[:token_ids.index(EOS)]
        target_smile_cl.append(
            ''.join(vocab.itos[t] for t in token_ids if t != PAD and t != SOS)
        )

# calculate the metrics
CHA_validity_rate_cl = validity_rate(generated_smile_cl)
CHA_uniqueness_rate_cl = uniqueness_rate(generated_smile_cl)
CHA_novelty_rate_cl = novelty_rate(generated_smile_cl, target_smile_cl)
CHA_reconstructability_rate_cl = reconstructability_rate(generated_smile_cl, target_smile_cl)
CHA_IntDiv_cl = IntDiv(generated_smile_cl)
CHA_FCD_score_cl = FCD_score(target_smile_cl, generated_smile_cl)

# print the metrics
for label, value in (
    ('CHA_validity_rate_cl: ', CHA_validity_rate_cl),
    ('CHA_uniqueness_rate_cl: ', CHA_uniqueness_rate_cl),
    ('CHA_novelty_rate_cl: ', CHA_novelty_rate_cl),
    ('CHA_reconstructability_rate_cl: ', CHA_reconstructability_rate_cl),
    ('CHA_IntDiv_cl: ', CHA_IntDiv_cl),
    ('CHA_FCD_score_cl: ', CHA_FCD_score_cl),
):
    print(label, value)

In [None]:
# Persist the CHA metrics: one line per metric, original model vs. contrastive model
with open('./data_CHA/CHA_generated_ddc_metrics.txt', 'w') as f:
    f.writelines([
        f'CHA_validity_rate_origin: {CHA_validity_rate_origin}, CHA_validity_rate_cl: {CHA_validity_rate_cl}\n',
        f'CHA_uniqueness_rate_origin: {CHA_uniqueness_rate_origin}, CHA_uniqueness_rate_cl: {CHA_uniqueness_rate_cl}\n',
        f'CHA_novelty_rate_origin: {CHA_novelty_rate_origin}, CHA_novelty_rate_cl: {CHA_novelty_rate_cl}\n',
        f'CHA_reconstructability_rate_origin: {CHA_reconstructability_rate_origin}, CHA_reconstructability_rate_cl: {CHA_reconstructability_rate_cl}\n',
        f'CHA_IntDiv_origin: {CHA_IntDiv_origin}, CHA_IntDiv_cl: {CHA_IntDiv_cl}\n',
        f'CHA_FCD_score_origin: {CHA_FCD_score_origin}, CHA_FCD_score_cl: {CHA_FCD_score_cl}\n',
    ])

# Persist the generated SMILES of both models together with the target, one row per sample
with open('./data_CHA/CHA_generated_ddc_smiles_origin.txt', 'w') as f:
    for row, origin_smiles in enumerate(generated_smile_origin):
        f.write(f'origin: {origin_smiles}, cl: {generated_smile_cl[row]}, target: {target_smile_origin[row]}\n')