# Train CVAE Model for Conditional Generation

In [None]:
# Move the working directory one level up (relative to wherever the kernel
# currently is). Presumably this lands in the project root so the package
# imports (models/, utils/, datasets/) and the './data', './model_hub',
# './checkpoints' paths below resolve — confirm this notebook lives one
# directory below the repo root. Re-running this cell moves up again.
%cd ..

In [None]:
import numpy as np
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.nn import functional as F
from tqdm import tqdm
from datetime import datetime
from torch.utils.tensorboard import SummaryWriter
import torch.backends.cudnn as cudnn

# import custom modules
from models.cvae import *
from utils.utils import *
from datasets.data_loader import *
from utils.plot_figures import *
from utils.metrics import *
from utils.build_vocab import *

In [None]:
# cuDNN: make sure it is on and let it benchmark/auto-select convolution kernels.
cudnn.enabled = True
cudnn.benchmark = True

# Per-epoch metric histories; the training loop appends one value per epoch.
train_loss_history, train_acc_history = [], []
test_loss_history, test_acc_history = [], []

# Output locations for logs and model checkpoints.
log_dir = './logs/'
save_best_weight_path = './checkpoints/'

# Run timestamp, formatted like '2024-01-31-13-45-07'.
now = datetime.now().strftime('%Y-%m-%d-%H-%M-%S')

In [None]:
# Special-token ids, in index order (presumably matching the ordering stored
# in vocab.pkl — confirm if the vocabulary is rebuilt).
PAD, UNK, EOS, SOS, MASK = range(5)
# Sequence length passed to the encoder/decoder as `max_len`.
MAX_LEN = 220

In [None]:
# Load the train/test splits. Each split provides SMILES strings, the `zeo`
# and `syn` condition vectors (concatenated into the model condition later),
# and `codes` (read here but not used by the cells below).
train_smiles, train_zeo, train_syn, train_codes = (
    read_strings('./data/train_smiles.csv', idx=False),
    read_vec('./data/train_zeo.csv', idx=False),
    read_vec('./data/train_syn.csv', idx=False),
    read_strings('./data/train_codes.csv', idx=False),
)
test_smiles, test_zeo, test_syn, test_codes = (
    read_strings('./data/test_smiles.csv', idx=False),
    read_vec('./data/test_zeo.csv', idx=False),
    read_vec('./data/test_syn.csv', idx=False),
    read_strings('./data/test_codes.csv', idx=False),
)

# Token vocabulary; its size is the one-hot width of encoder input and decoder output.
vocab = WordVocab.load_vocab('./model_hub/vocab.pkl')
charlen = len(vocab)
print('the vocab size is :', charlen)
print('the total num of charset is :', charlen)

In [None]:
# hyperparameters: compute device (first GPU when available), batch size, epoch count
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
batch_size, epoch = 512, 50

In [None]:
# Wrap each split in a Seq2seqDataset (condition vectors + tokenised SMILES);
# only the training loader shuffles, so evaluation order is deterministic.
train_dataset = Seq2seqDataset(train_zeo, train_syn, train_smiles, vocab)
test_dataset = Seq2seqDataset(test_zeo, test_syn, test_smiles, vocab)

train_dataloader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

In [None]:
# Architecture hyperparameters shared by the encoder and decoder.
params = {
    'NCHARS': charlen,  # vocabulary size: one-hot width of encoder input and decoder output
    'MAX_LEN': MAX_LEN,
    'COND_DIM': 24,  # NOTE(review): presumably the width of torch.cat([zeo, syn], dim=-1) — confirm
    'hidden_dim': 256,
    'conv_depth': 3,
    'conv_dim_depth': 64,
    'conv_dim_width': 3,
    'conv_d_growth_factor': 2,
    'conv_w_growth_factor': 1,
    'middle_layer': 2,
    'activation': 'tanh',
    'batchnorm_conv': True,
    'batchnorm_mid': True,
    'dropout_rate_mid': 0.2,
    'gru_depth': 2,
    'recurrent_dim': 256,
}

# Encoder: apart from the three renamed arguments, each constructor keyword
# has the same name as its key in `params`.
encoder = ConditionalEncoder(
    input_channels=params['NCHARS'],
    max_len=params['MAX_LEN'],
    cond_dim=params['COND_DIM'],
    **{
        key: params[key]
        for key in (
            'hidden_dim', 'conv_depth', 'conv_dim_depth', 'conv_dim_width',
            'conv_d_growth_factor', 'conv_w_growth_factor', 'middle_layer',
            'activation', 'batchnorm_conv', 'batchnorm_mid', 'dropout_rate_mid',
        )
    },
).to(device)

# Decoder: same naming scheme; its dropout is fixed at 0.2.
decoder = ConditionalDecoder(
    cond_dim=params['COND_DIM'],
    n_chars=params['NCHARS'],
    max_len=params['MAX_LEN'],
    dropout=0.2,
    **{key: params[key] for key in ('hidden_dim', 'gru_depth', 'recurrent_dim')},
).to(device)

# Full CVAE wrapping both halves.
model = ConditionalVAE(encoder, decoder).to(device)

# Token-level reconstruction loss (PAD targets ignored) and optimizer.
loss_func = torch.nn.CrossEntropyLoss(ignore_index=PAD)
optim = torch.optim.Adam(model.parameters(), lr=6e-4)

# Report the model size in millions of parameters.
total = sum(param.numel() for param in model.parameters())
print('total parameters: %0.2fM' % (total / 1e6))

In [None]:
# train function
def train(model, dataloader, loss_func, optim, device, kl_weight=0.001, pad_idx=0, n_chars=None):
    """
    Run one training epoch of the conditional VAE with teacher forcing.

    For each batch ``(zeo, syn, tgt)`` the condition is
    ``torch.cat([zeo, syn], dim=-1)``. ``tgt`` holds integer token ids and is
    shifted by one position. The decoder is fed ``tgt[:, :-1]`` (one-hot
    encoded, also used as the encoder input) and trained to predict
    ``tgt[:, 1:]``.

    Args:
        model: the ConditionalVAE. It is called as
            ``model(x_smi=..., x_cond=..., teacher_force_inputs=...)`` and must
            return ``(logits, z_mean, z_log_var)``.
        dataloader: iterable of ``(zeo, syn, tgt)`` batches.
        loss_func: reconstruction loss over flattened logits/targets, typically
            ``nn.CrossEntropyLoss(ignore_index=pad_idx)``.
        optim: torch optimizer over ``model.parameters()``.
        device: device the batch tensors are moved to.
        kl_weight: weight of the KL term in the total loss.
        pad_idx: padding token id. It is excluded from the accuracy count, so
            keep it equal to the loss's ``ignore_index``.
        n_chars: one-hot width of the decoder input. Defaults to the
            notebook-level ``charlen`` (vocabulary size).

    Returns:
        ``(avg_loss, avg_acc)``. ``avg_loss`` is the mean per-batch total loss
        (reconstruction + ``kl_weight`` * KL). ``avg_acc`` is the token accuracy
        over non-pad target positions. Each is ``nan`` when undefined (no
        batches / no non-pad tokens).
    """
    if n_chars is None:
        n_chars = charlen  # fall back to the notebook-level vocabulary size

    model.train()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0
    num_batches = 0

    for zeo, syn, tgt in tqdm(dataloader):
        zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
        condition = torch.cat([zeo, syn], dim=-1)  # (B, zeo_dim + syn_dim)

        # Shift by one: the decoder reads tokens [0, T-1) and predicts [1, T).
        tgt_input = F.one_hot(tgt[:, :-1].contiguous(), num_classes=n_chars).float()
        tgt_target = tgt[:, 1:].contiguous()

        # 1) Forward pass (teacher forcing).
        logits, z_mean, z_log_var = model(
            x_smi=tgt_input,
            x_cond=condition,
            teacher_force_inputs=tgt_input,
        )

        # 2) Reconstruction loss over flattened (B * T, n_chars) logits.
        recon_loss = loss_func(logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1))

        # 3) Closed-form KL between N(z_mean, exp(z_log_var)) and N(0, I),
        #    summed over latent dims and averaged over the batch.
        kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) / tgt.size(0)

        # 4) Total objective.
        loss = recon_loss + kl_weight * kl_loss

        # 5) Backprop and update.
        optim.zero_grad()
        loss.backward()
        optim.step()

        # 6) Token accuracy, counting only non-pad target positions.
        preds = logits.argmax(dim=-1)  # (B, T)
        non_pad = tgt_target != pad_idx
        total_correct += ((preds == tgt_target) & non_pad).sum().item()
        total_tokens += non_pad.sum().item()

        # 7) Track stats.
        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / num_batches if num_batches else float('nan')
    avg_acc = total_correct / total_tokens if total_tokens else float('nan')
    return avg_loss, avg_acc

In [None]:
# evaluate function
def evaluate(model, dataloader, loss_func, device, pad_idx=0, kl_weight=0.001, n_chars=None):
    """
    Evaluate the conditional VAE with teacher forcing, without gradient updates.

    The objective is the same one ``train`` optimizes: reconstruction +
    ``kl_weight`` * KL. This keeps the train and test loss curves comparable.
    Pass ``kl_weight=0.0`` to report the reconstruction loss alone.

    Args:
        model: the ConditionalVAE. It is called as
            ``model(x_smi=..., x_cond=..., teacher_force_inputs=...)`` and must
            return ``(logits, z_mean, z_log_var)``.
        dataloader: iterable of ``(zeo, syn, tgt)`` batches. ``tgt`` holds
            integer token ids.
        loss_func: reconstruction loss over flattened logits/targets, typically
            ``nn.CrossEntropyLoss(ignore_index=pad_idx)``.
        device: device the batch tensors are moved to.
        pad_idx: padding token id, excluded from the accuracy count.
        kl_weight: weight of the KL term in the reported loss.
        n_chars: one-hot width of the decoder input. Defaults to the
            notebook-level ``charlen`` (vocabulary size).

    Returns:
        ``(avg_loss, avg_acc)``. ``avg_loss`` is the mean per-batch loss;
        ``avg_acc`` is the token accuracy over non-pad target positions. Each is
        ``nan`` when undefined (no batches / no non-pad tokens).
    """
    if n_chars is None:
        n_chars = charlen  # fall back to the notebook-level vocabulary size

    model.eval()
    total_loss = 0.0
    total_correct = 0
    total_tokens = 0
    num_batches = 0

    with torch.no_grad():
        for zeo, syn, tgt in tqdm(dataloader):
            zeo, syn, tgt = zeo.to(device), syn.to(device), tgt.to(device)
            condition = torch.cat([zeo, syn], dim=-1)

            # Shift by one: the decoder reads tokens [0, T-1) and predicts [1, T).
            tgt_input = F.one_hot(tgt[:, :-1].contiguous(), num_classes=n_chars).float()
            tgt_target = tgt[:, 1:].contiguous()

            # Teacher forcing is used here as well, so these numbers measure
            # next-token reconstruction, not free-running generation.
            logits, z_mean, z_log_var = model(
                x_smi=tgt_input,
                x_cond=condition,
                teacher_force_inputs=tgt_input,
            )

            recon_loss = loss_func(logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1))
            # Same closed-form KL term as in train(), averaged over the batch.
            kl_loss = -0.5 * torch.sum(1 + z_log_var - z_mean.pow(2) - z_log_var.exp()) / tgt.size(0)
            loss = recon_loss + kl_weight * kl_loss

            preds = logits.argmax(dim=-1)
            non_pad = tgt_target != pad_idx
            total_correct += ((preds == tgt_target) & non_pad).sum().item()
            total_tokens += non_pad.sum().item()
            total_loss += loss.item()
            num_batches += 1

    avg_loss = total_loss / num_batches if num_batches else float('nan')
    avg_acc = total_correct / total_tokens if total_tokens else float('nan')
    return avg_loss, avg_acc

In [None]:
# Train for `epoch` epochs. Save the weights with the best test accuracy seen so
# far to best_cvae_model.pth, and the latest weights to last_cvae_model.pth.
os.makedirs(save_best_weight_path, exist_ok=True)  # torch.save does not create parent dirs
best_acc = float('-inf')  # guarantees the first epoch's weights are saved as "best"
for i in range(epoch):
    train_loss, train_acc = train(model, train_dataloader, loss_func, optim, device)
    test_loss, test_acc = evaluate(model, test_dataloader, loss_func, device)
    print('epoch %d, train loss %.4f, train acc %.4f, test loss %.4f, test acc %.4f' % (i, train_loss, train_acc, test_loss, test_acc))
    train_loss_history.append(train_loss)
    train_acc_history.append(train_acc)
    test_loss_history.append(test_loss)
    test_acc_history.append(test_acc)
    if test_acc > best_acc:
        best_acc = test_acc
        torch.save(model.state_dict(), os.path.join(save_best_weight_path, 'best_cvae_model.pth'))
    torch.save(model.state_dict(), os.path.join(save_best_weight_path, 'last_cvae_model.pth'))


In [None]:
def generate_cvae(model, condition, tgt_input, vocab, device='cuda', max_len=MAX_LEN, temperature=1.0):
    """
    Sample SMILES strings from a CVAE's teacher-forced output distribution.

    NOTE: this is not autoregressive decoding. The model runs once with
    ``tgt_input`` as both the encoder input and the teacher-forcing input. A
    token is then sampled independently at every position from
    ``softmax(logits / temperature)``. The result is a stochastic
    reconstruction of ``tgt_input`` under ``condition``, not free generation
    from the condition alone.

    Args:
        model: trained ConditionalVAE, called as
            ``model(x_smi=..., x_cond=..., teacher_force_inputs=...)``.
        condition: (B, cond_dim) merged condition vector, e.g.
            ``torch.cat([zeo, syn], dim=-1)``.
        tgt_input: decoder input. Presumably a one-hot float tensor of shape
            (B, seq_len, n_chars), as built with ``F.one_hot(...).float()`` in
            this notebook. Confirm before passing other formats.
        vocab: vocabulary exposing ``itos`` (index -> token string).
        device: device that ``condition`` and ``tgt_input`` are moved to.
        max_len: maximum number of positions decoded per sequence.
        temperature: softmax temperature. ``0`` picks the argmax token greedily.

    Returns:
        A list of B strings. Decoding of each row stops at the first
        ``'<eos>'``; ``'<sos>'`` and ``'<pad>'`` tokens are dropped.

    Raises:
        ValueError: if ``temperature`` is negative.
    """
    if temperature < 0:
        raise ValueError(f'temperature must be non-negative, got {temperature}')

    model.eval()
    with torch.no_grad():
        # Keep both inputs on the same device as the model expects.
        condition = condition.to(device)
        tgt_input = tgt_input.to(device)

        logits, _, _ = model(
            x_smi=tgt_input,
            x_cond=condition,
            teacher_force_inputs=tgt_input,
        )
        logits = logits[:, :max_len]  # honour max_len (no-op when seq_len <= max_len)

        if temperature == 0:
            generated = logits.argmax(dim=-1)  # greedy: avoids dividing by zero
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            generated = torch.multinomial(
                probs.reshape(-1, probs.size(-1)), num_samples=1
            ).view(probs.size(0), probs.size(1))

    # One device->host transfer instead of a .item() sync per token.
    itos = vocab.itos
    smiles_list = []
    for row in generated.tolist():
        tokens = []
        for idx in row:
            token = itos[idx]
            if token == '<eos>':  # end of this SMILES string
                break
            if token not in ('<sos>', '<pad>'):  # drop non-SMILES tokens
                tokens.append(token)
        smiles_list.append(''.join(tokens))

    return smiles_list


In [None]:
# Generate (teacher-forced sampling via generate_cvae) SMILES for the test
# dataset, and collect the matching ground-truth SMILES in the same order.
generated_smile = []  # all generated SMILES, accumulated over batches
target_smile = []     # ground-truth SMILES aligned with generated_smile
for zeo, syn, tgt in tqdm(test_dataloader):
    zeo = zeo.to(device)
    syn = syn.to(device)
    tgt = tgt.to(device)
    condition_synthesis = torch.cat([zeo, syn], dim=1)
    # Decoder input: all tokens but the last, one-hot encoded.
    tgt_input = F.one_hot(tgt[:, :-1].contiguous(), num_classes=params['NCHARS']).float()
    generated_smiles = generate_cvae(model=model, condition=condition_synthesis, tgt_input=tgt_input, vocab=vocab, device=device, max_len=MAX_LEN, temperature=1.0)
    generated_smile.extend(generated_smiles)

    # Decode the ground-truth ids (one host transfer): stop at EOS, skip PAD and SOS.
    tgt_smiles = []
    for seq in tgt.tolist():
        tokens = []
        for idx in seq:
            if idx == EOS:
                break
            if idx not in (PAD, SOS):
                tokens.append(vocab.itos[idx])
        tgt_smiles.append(''.join(tokens))
    target_smile.extend(tgt_smiles)

In [None]:
# Calculate the metrics over ALL generated SMILES. generated_smile accumulates
# every batch, whereas generated_smiles only holds the last batch, which would
# also misalign the generated/target pairs used by reconstructability.
print('Validity rate:', validity_rate(generated_smile))
print('Uniqueness rate:', uniqueness_rate(generated_smile))
print('Novelty rate:', novelty_rate(generated_smile, target_smile))
print('Reconstructability rate:', reconstructability_rate(generated_smile, target_smile))
print('IntDiv:', IntDiv(generated_smile))
# TODO: confirm KL_divergence's signature before enabling:
# print('KL-divergence:', KL_divergence(target_smile, generated_smile))
print('FCD score:', FCD_score(target_smile, generated_smile))

In [None]:
# Plot the train/test loss curves under the 'cvae' label. The accuracy
# histories (train_acc_history, test_acc_history) are collected during
# training but are not plotted by this call.
plot_loss(train_loss_history, test_loss_history, 'cvae')