# Train a CVAE Model for Conditional SMILES Generation Using a Predefined Vocabulary and Contrastive Learning

In [1]:
# change working path to the current file
%cd ..

/home/hudongcheng/Desktop/bo_osda_generator


In [2]:
import os
from datetime import datetime

import numpy as np
import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# import custom modules
from datasets.data_loader import *
from models.cvae import *
from models.loss import InfoNCELoss
from models.trfm import *
from utils.utils import *
from utils.plot_figures import *
from utils.metrics import *
from utils.build_vocab import *

In [3]:
cudnn.benchmark = True
cudnn.enabled = True

# Seed the torch and numpy RNGs so DataLoader shuffling, dropout and the Gumbel /
# multinomial sampling used below are repeatable across runs.
# NOTE: cudnn.benchmark=True may still select non-deterministic GPU kernels, and
# Python's `random` module is not seeded here.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Per-epoch metric histories, appended to by the training loop.
train_loss_history = []
train_acc_history = []
test_loss_history = []
test_acc_history = []
train_ce_loss_history = []
test_ce_loss_history = []
train_info_loss_history = []
test_info_loss_history = []

log_dir = './logs/'
save_best_weight_path = './checkpoints/'

# Timestamp string for this run (YYYY-mm-dd-HH-MM-SS).
now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

In [4]:
# Special-token indices (0..4, in this order) and the maximum tokenized length.
# NOTE(review): presumably these must match the special-token order of the loaded
# WordVocab — confirm against utils.build_vocab.
PAD, UNK, EOS, SOS, MASK = range(5)
MAX_LEN = 220

In [5]:
# Load the contrastive train/test splits and the pretrained vocabulary.
train_data = pd.read_csv('./data/train_contrastive_dataset.csv')
test_data = pd.read_csv('./data/test_contrastive_dataset.csv')

# NOTE(review): the vocab is stored as a .pkl file; if load_vocab unpickles it,
# only load trusted files.
vocab = WordVocab.load_vocab('./model_hub/vocab.pkl')
charlen = len(vocab)  # number of token classes; used as the one-hot width below
print('the vocab size is :', charlen)
print('the total num of charset is :', charlen)

the vocab size is : 45
the total num of charset is : 45


In [6]:
# ---- training hyperparameters ----
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")
batch_size = 64          # samples per mini-batch
epoch = 10               # number of training epochs
InfoNCEloss_weight = 0.1  # InfoNCE term weight (same value as the `weight` default of train/evaluate)

In [7]:
# Build the contrastive datasets (random negative-sampling variant) for both splits.
# An alternative class, Contrastive_Seq2seqDataset, takes the same arguments and can be
# swapped in to compare.
train_dataset, test_dataset = (
    Contrastive_Seq2seqDataset_random(split_frame, vocab, MAX_LEN)
    for split_frame in (train_data, test_data)
)

# Shuffle only the training loader; the test loader keeps a fixed order.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

100%|██████████| 144938/144938 [00:00<00:00, 887094.07it/s] 
100%|██████████| 144938/144938 [00:00<00:00, 806217.82it/s]
100%|██████████| 144938/144938 [00:02<00:00, 50973.11it/s]
100%|██████████| 14803/14803 [00:00<00:00, 1686631.59it/s]
100%|██████████| 14803/14803 [00:00<00:00, 1064030.06it/s]
100%|██████████| 14803/14803 [00:00<00:00, 51331.55it/s]


In [8]:
# Hyperparameters of the conditional VAE (encoder + decoder).
params = {
    'NCHARS': charlen,
    'MAX_LEN': MAX_LEN,
    'COND_DIM': 24,
    'hidden_dim': 256,
    'conv_depth': 3,
    'conv_dim_depth': 64,
    'conv_dim_width': 3,
    'conv_d_growth_factor': 2,
    'conv_w_growth_factor': 1,
    'middle_layer': 2,
    'activation': 'tanh',
    'batchnorm_conv': True,
    'batchnorm_mid': True,
    'dropout_rate_mid': 0.2,
    'gru_depth': 2,
    'recurrent_dim': 256,
}

encoder = ConditionalEncoder(
    input_channels=params['NCHARS'],
    max_len=params['MAX_LEN'],
    cond_dim=params['COND_DIM'],
    hidden_dim=params['hidden_dim'],
    conv_depth=params['conv_depth'],
    conv_dim_depth=params['conv_dim_depth'],
    conv_dim_width=params['conv_dim_width'],
    conv_d_growth_factor=params['conv_d_growth_factor'],
    conv_w_growth_factor=params['conv_w_growth_factor'],
    middle_layer=params['middle_layer'],
    activation=params['activation'],
    batchnorm_conv=params['batchnorm_conv'],
    batchnorm_mid=params['batchnorm_mid'],
    dropout_rate_mid=params['dropout_rate_mid']
).to(device)

decoder = ConditionalDecoder(
    hidden_dim=params['hidden_dim'],
    cond_dim=params['COND_DIM'],
    n_chars=params['NCHARS'],
    max_len=params['MAX_LEN'],
    gru_depth=params['gru_depth'],
    recurrent_dim=params['recurrent_dim'],
    dropout=0.2
).to(device)

model = ConditionalVAE(encoder, decoder).to(device)

# Pretrained SMILES transformer; it is passed to InfoNCELoss during train/evaluate.
trfm = TrfmSeq2seq(charlen, 256, charlen, 4).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
trfm.load_state_dict(torch.load('./model_hub/trfm_new_4_130000.pkl', map_location=device))
trfm.eval()

# Freeze trfm: its weights are not updated (it is also not given to the optimizer).
for param in trfm.parameters():
    param.requires_grad = False

# Cross-entropy ignores <pad> targets; Adam optimizes only the CVAE's parameters.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=PAD)
optim = torch.optim.Adam(model.parameters(), lr=6e-4)
total = sum(p.numel() for p in model.parameters())
print('total parameters: %0.2fM' % (total / 1e6))

total parameters: 1.15M


In [9]:
# train function
def train(model, trfm, dataloader, loss_func, optim, device, kl_weight=0.001, pad_idx=0, weight=0.1):
    """
    Run one training epoch of the Conditional VAE with an auxiliary InfoNCE term.

    Per batch the optimized objective is
        recon_loss + kl_weight * kl_loss + weight * loss_infonce
    where recon_loss is `loss_func` on the teacher-forced next-token logits, kl_loss is the
    Gaussian KL term summed over latent dims and divided by the batch size, and
    loss_infonce is InfoNCELoss on Gumbel-sampled sequences vs. positive/negative SMILES.

    Args:
        model: Conditional VAE; called as model(x_smi=..., x_cond=..., teacher_force_inputs=...)
            and returning (logits, z_mean, z_log_var).
        trfm: pretrained transformer passed to InfoNCELoss (frozen in this notebook).
        dataloader: yields (zeo, syn, tgt, positive_smiles, negative_smiles) each step;
            tgt holds integer token indices.
        loss_func: typically nn.CrossEntropyLoss(ignore_index=pad_idx) or similar.
        optim: torch optimizer over the model's parameters.
        device: 'cuda' / 'cpu' or a torch.device.
        kl_weight: scaling factor for the KL loss term.
        pad_idx: <pad> index; pad positions are excluded from the accuracy count.
        weight: weight for the InfoNCE loss term.

    Returns:
        A 4-tuple in this order:
        (mean_total_loss, mean_crossentropy_loss, mean_infonce_loss, token_accuracy).
        Losses are averaged over batches; token_accuracy = correct non-pad tokens /
        non-pad tokens (teacher-forced).
    """
    model.train()
    total_loss = 0.0
    total_loss_crossentropy = 0.0
    total_loss_infonce = 0.0
    total_correct = 0
    total_num = 0

    for zeo, syn, tgt, positive_smiles, negative_smiles in tqdm(dataloader):
        zeo = zeo.to(device)
        syn = syn.to(device)
        tgt = tgt.to(device)
        positive_smiles = positive_smiles.to(device)
        negative_smiles = negative_smiles.to(device)

        # Full condition vector: (B, zeo_dim + syn_dim).
        condition = torch.cat([zeo, syn], dim=-1)
        # Shift by one token: the decoder reads tgt[:, :-1] (one-hot) and predicts tgt[:, 1:].
        tgt_input = F.one_hot(tgt[:, :-1].contiguous(), num_classes=charlen).float()
        tgt_target = tgt[:, 1:].contiguous()

        # 1) Forward pass with teacher forcing.
        optim.zero_grad()
        logits, z_mean, z_log_var = model(
            x_smi=tgt_input,
            x_cond=condition,
            teacher_force_inputs=tgt_input,
        )

        # 2) Reconstruction loss over every (flattened) position.
        recon_loss = loss_func(logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1))

        # 3) KL = -0.5 * sum(1 + log_var - mean^2 - exp(log_var)), averaged over the batch.
        kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) / tgt.size(0)

        # 4) InfoNCE on sampled sequences, with <sos> prepended to match the tgt layout.
        # NOTE(review): torch.argmax returns integer indices, which cuts the straight-through
        # gradient of gumbel_softmax(hard=True); with trfm frozen, loss_infonce likely
        # contributes no gradient to `model`. Confirm against InfoNCELoss.
        samples = F.gumbel_softmax(logits, tau=1.0, hard=True)
        samples_indices = torch.argmax(samples, dim=-1)
        sos_column = torch.full((samples_indices.size(0), 1), SOS, dtype=torch.long, device=device)
        samples_indices = torch.cat([sos_column, samples_indices], dim=-1)
        loss_infonce = InfoNCELoss(samples_indices, positive_smiles, negative_smiles, trfm, temperature=0.07)

        # 5) Total loss, 6) backprop & update.
        loss = recon_loss + kl_weight * kl_loss + weight * loss_infonce
        loss.backward()
        optim.step()

        # 7) Teacher-forced token accuracy, ignoring pad positions.
        non_pad = tgt_target != pad_idx
        preds = torch.argmax(logits, dim=-1)
        total_correct += ((preds == tgt_target) & non_pad).sum().item()
        total_num += non_pad.sum().item()

        # 8) Track stats.
        total_loss += loss.item()
        total_loss_crossentropy += recon_loss.item()
        total_loss_infonce += loss_infonce.item()

    # Guard against an empty loader / all-pad targets instead of dividing by zero.
    num_batches = max(len(dataloader), 1)
    return (
        total_loss / num_batches,
        total_loss_crossentropy / num_batches,
        total_loss_infonce / num_batches,
        total_correct / max(total_num, 1),
    )

In [10]:
# evaluate function
def evaluate(model, trfm, dataloader, loss_func, device, pad_idx=0, weight=0.1, kl_weight=0.001):
    """
    Evaluate the Conditional VAE on `dataloader` with teacher forcing and no parameter updates.

    Computes the same objective as `train` for every batch:
        recon_loss + kl_weight * kl_loss + weight * loss_infonce

    Args:
        model: Conditional VAE; called as model(x_smi=..., x_cond=..., teacher_force_inputs=...)
            and returning (logits, z_mean, z_log_var).
        trfm: pretrained transformer passed to InfoNCELoss (frozen in this notebook).
        dataloader: yields (zeo, syn, tgt, positive_smiles, negative_smiles) each step;
            tgt holds integer token indices.
        loss_func: typically nn.CrossEntropyLoss(ignore_index=pad_idx) or similar.
        device: 'cuda' / 'cpu' or a torch.device.
        pad_idx: <pad> index; pad positions are excluded from the accuracy count.
        weight: weight for the InfoNCE loss term.
        kl_weight: scaling factor for the KL term. The default matches `train` so that
            test loss measures the same objective as training loss.

    Returns:
        A 4-tuple in this order:
        (mean_total_loss, mean_crossentropy_loss, mean_infonce_loss, token_accuracy).
        Losses are averaged over batches; token_accuracy = correct non-pad tokens /
        non-pad tokens (teacher-forced).

    Note:
        The Gumbel sampling used for the InfoNCE term is stochastic, so the InfoNCE
        (and total) loss can vary between calls.
    """
    model.eval()
    total_loss = 0.0
    total_loss_crossentropy = 0.0
    total_loss_infonce = 0.0
    total_correct = 0
    total_num = 0

    with torch.no_grad():  # no gradients needed for evaluation
        # Every step below runs per batch, so all batches contribute to the averages.
        for zeo, syn, tgt, positive_smiles, negative_smiles in tqdm(dataloader):
            zeo = zeo.to(device)
            syn = syn.to(device)
            tgt = tgt.to(device)
            positive_smiles = positive_smiles.to(device)
            negative_smiles = negative_smiles.to(device)

            # Full condition vector: (B, zeo_dim + syn_dim).
            condition = torch.cat([zeo, syn], dim=-1)
            # Shift by one token: decoder reads tgt[:, :-1] (one-hot) and predicts tgt[:, 1:].
            tgt_input = F.one_hot(tgt[:, :-1].contiguous(), num_classes=charlen).float()
            tgt_target = tgt[:, 1:].contiguous()

            # 1) Forward pass with teacher forcing (same as training).
            logits, z_mean, z_log_var = model(
                x_smi=tgt_input,
                x_cond=condition,
                teacher_force_inputs=tgt_input,
            )

            # 2) Reconstruction loss over every (flattened) position.
            recon_loss = loss_func(logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1))

            # 3) KL = -0.5 * sum(1 + log_var - mean^2 - exp(log_var)), averaged over the batch.
            kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) / tgt.size(0)

            # 4) InfoNCE on sampled sequences, with <sos> prepended to match the tgt layout.
            samples = F.gumbel_softmax(logits, tau=1.0, hard=True)
            samples_indices = torch.argmax(samples, dim=-1)
            sos_column = torch.full((samples_indices.size(0), 1), SOS, dtype=torch.long, device=device)
            samples_indices = torch.cat([sos_column, samples_indices], dim=-1)
            loss_infonce = InfoNCELoss(samples_indices, positive_smiles, negative_smiles, trfm, temperature=0.07)

            # 5) Total loss.
            loss = recon_loss + kl_weight * kl_loss + weight * loss_infonce

            # 6) Teacher-forced token accuracy, ignoring pad positions.
            non_pad = tgt_target != pad_idx
            preds = torch.argmax(logits, dim=-1)
            total_correct += ((preds == tgt_target) & non_pad).sum().item()
            total_num += non_pad.sum().item()

            # 7) Track stats.
            total_loss += loss.item()
            total_loss_crossentropy += recon_loss.item()
            total_loss_infonce += loss_infonce.item()

    # Guard against an empty loader / all-pad targets instead of dividing by zero.
    num_batches = max(len(dataloader), 1)
    return (
        total_loss / num_batches,
        total_loss_crossentropy / num_batches,
        total_loss_infonce / num_batches,
        total_correct / max(total_num, 1),
    )
            


In [11]:
# Train for `epoch` epochs, evaluating on the test split after each epoch.
# train()/evaluate() return (total_loss, crossentropy_loss, infonce_loss, accuracy) IN THAT ORDER.
os.makedirs(os.path.join(save_best_weight_path, 'cvae'), exist_ok=True)
best_acc = float('-inf')  # so the first epoch always produces a "best" checkpoint
for i in range(epoch):
    train_loss, train_ce, train_info, train_acc = train(
        model, trfm, train_dataloader, loss_func, optim, device, weight=InfoNCEloss_weight
    )
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    train_ce_loss_history.append(train_ce)
    train_info_loss_history.append(train_info)
    print('epoch: %d, train loss: %.4f, train acc: %.4f, train crossentropy loss: %.4f, train infonce loss: %.4f' % (i, train_loss, train_acc, train_ce, train_info))

    test_loss, test_ce, test_info, test_acc = evaluate(
        model, trfm, test_dataloader, loss_func, device, weight=InfoNCEloss_weight
    )
    test_loss_history.append(test_loss)
    test_acc_history.append(test_acc)
    test_ce_loss_history.append(test_ce)
    test_info_loss_history.append(test_info)
    print('epoch: %d test loss: %.4f, test acc: %.4f, test crossentropy loss: %.4f, test infonce loss: %.4f' % (i, test_loss, test_acc, test_ce, test_info))

    # Keep the checkpoint with the best test (teacher-forced) token accuracy.
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), os.path.join(save_best_weight_path, 'best_cvae_contrastive_model.pth'))
    torch.save(model.state_dict(), os.path.join(save_best_weight_path, 'last_cvae_contrastive_model.pth'))

    # Also keep a per-epoch snapshot under ./checkpoints/cvae/.
    torch.save(model.state_dict(), os.path.join(save_best_weight_path, 'cvae', 'cvae_contrastive_model_epoch_%d.pth' % i))


 18%|█▊        | 403/2265 [03:21<15:28,  2.00it/s]


KeyboardInterrupt: 

In [None]:
def generate_cvae(model, condition, tgt_input, vocab, device='cuda', max_len=MAX_LEN, temperature=1.0):
    """
    Sample SMILES strings from a trained CVAE for an already-merged condition vector.

    Decoding is teacher-forced: the model runs once on `tgt_input`, and every output
    position is sampled independently from softmax(logits / temperature). The result is
    therefore a stochastic reconstruction of `tgt_input`, not free-running generation.

    Args:
        model:       trained ConditionalVAE, called as
                     model(x_smi=..., x_cond=..., teacher_force_inputs=...) and returning
                     (logits, z_mean, z_log_var).
        condition:   (B, cond_dim) merged condition, e.g. torch.cat([zeo, syn], dim=-1).
        tgt_input:   decoder input; the caller in this notebook passes a one-hot tensor
                     (B, seq_len, n_chars) built from tgt[:, :-1].
        vocab:       vocabulary exposing `itos` (index -> token string).
        device:      device that `condition` and `tgt_input` are moved to.
        max_len:     maximum number of sampled positions decoded per sequence.
        temperature: softmax temperature (> 0); lower values sample more greedily.

    Returns:
        smiles_list: list of B strings. Each string drops '<sos>' and '<pad>' tokens and
                     is cut at the first '<eos>'.

    Raises:
        ValueError: if temperature <= 0 (softmax(logits / 0) would be NaN).
    """
    if temperature <= 0:
        raise ValueError('temperature must be > 0, got %r' % (temperature,))

    model.eval()
    with torch.no_grad():
        condition = condition.to(device)
        tgt_input = tgt_input.to(device)  # must live on the same device as the model
        logits, _, _ = model(
            x_smi=tgt_input,
            x_cond=condition,
            teacher_force_inputs=tgt_input,
        )

        # Sample one token per position from the tempered distribution.
        probs = F.softmax(logits / temperature, dim=-1)
        batch_size, seq_len, n_classes = probs.shape
        generated = torch.multinomial(probs.reshape(-1, n_classes), num_samples=1).view(batch_size, seq_len)
        # One device->host transfer instead of a .item() sync per token.
        generated = generated[:, :max_len].tolist()

    smiles_list = []
    for token_ids in generated:
        chars = []
        for idx in token_ids:
            char = vocab.itos[idx]
            if char == '<eos>':  # end of the SMILES string
                break
            if char not in ('<sos>', '<pad>'):
                chars.append(char)
        smiles_list.append(''.join(chars))
    return smiles_list


In [None]:
# Reconstruct SMILES for the *test* split (teacher-forced sampling, see generate_cvae) and
# decode the ground-truth targets. The test loader is not shuffled, and both lists are
# extended batch by batch, so generated_smile[k] corresponds to target_smile[k].
generated_smile = []
target_smile = []
for zeo, syn, tgt, _, _ in tqdm(test_dataloader):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    # Decoder input: one-hot of every token except the last.
    tgt_input = F.one_hot(tgt[:, :-1].contiguous(), num_classes=params['NCHARS']).float()
    generated_smiles = generate_cvae(model=model, condition=condition_synthesis, tgt_input=tgt_input, vocab=vocab, device=device, max_len=MAX_LEN, temperature=1.0)
    generated_smile.extend(generated_smiles)

    # Decode targets on the CPU in one transfer: stop at <eos>, skip <pad>/<sos>.
    for seq in tgt.tolist():
        tokens = []
        for idx in seq:
            if idx == EOS:
                break
            if idx != PAD and idx != SOS:
                tokens.append(vocab.itos[idx])
        target_smile.append(''.join(tokens))

In [None]:
# Calculate the metrics over ALL generated samples (`generated_smile`, accumulated across
# every batch) — not `generated_smiles`, which holds only the last batch.
print('Validity rate:', validity_rate(generated_smile))
print('Uniqueness rate:', uniqueness_rate(generated_smile))
print('Novelty rate:', novelty_rate(generated_smile, target_smile))
print('Reconstructability rate:', reconstructability_rate(generated_smile, target_smile))
print('IntDiv:', IntDiv(generated_smile))
# TODO: the KL-divergence metric (KL_divergence) is not computed here.
print('FCD score:', FCD_score(target_smile, generated_smile))

In [None]:
# Plot the train vs. test total-loss curves (accuracy histories are not plotted here).
plot_loss(train_loss_history, test_loss_history, 'cvae')

In [None]:
# Save the per-epoch histories to a CSV in the log folder (created if missing).
# Each list is wrapped in a Series so shorter histories are NaN-padded: a run interrupted
# between the train and test halves of an epoch would otherwise make pd.DataFrame raise
# on unequal lengths.
# NOTE(review): the 'ddc_' file prefix may be inherited from another notebook — confirm.
os.makedirs(log_dir, exist_ok=True)
history = pd.DataFrame({
    column: pd.Series(values, dtype=float)
    for column, values in [
        ('train_loss', train_loss_history),
        ('train_acc', train_acc_history),
        ('test_loss', test_loss_history),
        ('test_acc', test_acc_history),
        ('train_ce_loss', train_ce_loss_history),
        ('test_ce_loss', test_ce_loss_history),
        ('train_info_loss', train_info_loss_history),
        ('test_info_loss', test_info_loss_history),
    ]
})
history.to_csv(os.path.join(log_dir, 'ddc_contrastive_history.csv'), index=False)