In [1]:
%load_ext autoreload
%autoreload 2
import torch, rdkit
import sys, pathlib
from pathlib import Path
PROJECT_ROOT = Path.home()/"바탕화면"/"torch"/"Chem"
sys.path.insert(0, str(PROJECT_ROOT))
from rdkit import Chem
from utils.utils import *


from tqdm import trange
from pathlib import Path
# Device string used by the `.to(device)` calls in the training/evaluation cells.
device = "cuda"

# `dataset` is not defined in this cell — presumably it comes from the
# `utils.utils` star import; verify.
vocab = dataset.vocab
# Inverse lookup of `vocab` (values become keys). Judging by the original loop
# variable names, vocab maps token -> index — TODO confirm.
index_to_token = dict(zip(vocab.values(), vocab.keys()))

cuda


30


In [2]:
def select_model(choice, latent):
    """Build the CVAE and its conditional prior network for one architecture.

    Args:
        choice: Architecture key, one of "Trans_MHA", "Trans", "LSTM" or
            "LSTM_MHA". The key is also the submodule name under ``models``
            from which ``CVAE`` and ``PriorNet`` are imported.
        latent: Latent dimensionality. It is always forwarded to ``PriorNet``;
            it is forwarded to ``CVAE`` only for the Transformer variants
            (the LSTM variants are constructed with no arguments).

    Returns:
        Tuple ``(model, prior)``, both moved to CUDA.

    Raises:
        ValueError: If ``choice`` is not one of the supported keys. Without
            this check the function would fail with an UnboundLocalError at
            the ``return`` statement.
    """
    # Architecture key -> does CVAE receive ``latent_dim`` explicitly?
    cvae_takes_latent = {
        "Trans_MHA": True,
        "Trans": True,
        "LSTM": False,
        "LSTM_MHA": False,
    }
    if choice not in cvae_takes_latent:
        raise ValueError(
            f"Unknown model choice {choice!r}; expected one of {sorted(cvae_takes_latent)}"
        )

    # Imported lazily so that only the selected architecture's module is loaded.
    import importlib

    module = importlib.import_module(f"models.{choice}")
    CVAE, PriorNet = module.CVAE, module.PriorNet

    if cvae_takes_latent[choice]:
        model = CVAE(latent_dim=latent).cuda()
    else:
        model = CVAE().cuda()
    # Kept from the original: explicitly move the decoder as well.
    # NOTE(review): redundant if the decoder is a registered submodule — confirm.
    model.decoder.cuda()
    prior = PriorNet(y_dim=3, latent_dim=latent).cuda()

    return model, prior

def get_loss_fn(choice, latent):
    """Build the loss module with the hyper-parameters tuned for one architecture.

    Args:
        choice: Architecture key, one of "Trans_MHA", "Trans", "LSTM" or
            "LSTM_MHA".
        latent: Latent dimensionality; forwarded as ``latent_dim`` only for
            the Transformer variants.

    Returns:
        The loss module moved to CUDA. "LSTM" uses ``ConditionalVAELoss_LSTM``;
        every other key uses ``ConditionalVAELoss``.

    Raises:
        ValueError: If ``choice`` is not one of the supported keys. Without
            this check the function would fail with an UnboundLocalError at
            the ``return`` statement.
    """
    # Per-architecture hyper-parameters, copied verbatim from the tuned settings.
    hparams_by_choice = {
        "Trans_MHA": dict(
            max_beta=0.7, anneal_steps=1200, free_bits=0.01,
            capacity_max=10, capacity_inc=6e-4, gamma=5.0,
            prop_w=2.2, nce=0.2, sig_pen_p=0.02, sig_pen_q=0.05, imb=0.1,
        ),
        "Trans": dict(
            max_beta=1.0, anneal_steps=1200, free_bits=0.1,
            capacity_max=8, capacity_inc=1e-3, gamma=3.0,
            prop_w=3.0, nce=0.03, sig_pen_p=0.003, sig_pen_q=0.0, imb=0.05,
        ),
        "LSTM": dict(
            max_beta=0.08, anneal_steps=2400, free_bits=0.03,
            capacity_max=0.0, capacity_inc=0.0, gamma=0.0,
            prop_w=3.0, nce=0.2, sig_pen_p=0.05, sig_pen_q=0.1, imb=0.5,
        ),
        "LSTM_MHA": dict(
            max_beta=1.0, anneal_steps=2400, free_bits=0.03,
            capacity_max=0.0, capacity_inc=0.0, gamma=0.0,
            prop_w=3.0, nce=0.2, sig_pen_p=0.3, sig_pen_q=0.0, imb=0.5,
        ),
    }
    if choice not in hparams_by_choice:
        raise ValueError(
            f"Unknown model choice {choice!r}; expected one of {sorted(hparams_by_choice)}"
        )

    hparams = dict(hparams_by_choice[choice])
    # Only the Transformer variants pass the latent size to the loss.
    if choice in ("Trans_MHA", "Trans"):
        hparams["latent_dim"] = latent

    loss_cls = ConditionalVAELoss_LSTM if choice == "LSTM" else ConditionalVAELoss
    loss_fn = loss_cls(vocab_size=dataset.vocab_size, **hparams).cuda()

    return loss_fn

In [3]:
from torch.optim import AdamW

# Run configuration: architecture key, latent size and the two learning rates.
mode = "Trans"
latent_dim = 96
lr = 3e-5
lr_prior = 1e-4

# Networks and loss for the selected architecture.
model, prior = select_model(mode, latent_dim)
loss_fn = get_loss_fn(mode, latent_dim)

# One optimiser per network; each uses its own learning rate.
optim = AdamW(params=model.parameters(), lr=lr)
optim2 = AdamW(params=prior.parameters(), lr=lr_prior)

In [None]:
import datetime
from rdkit import Chem, RDLogger
RDLogger.DisableLog('rdApp.error')
status_out = widgets.Output()

display(status_out)
epoch = 1200
model.train()
prior.train()
progress = tqdm(range(epoch), desc="Training")
loss_arr=[]

# Transformer variants receive an extra encoder input and use `train_dataloader`;
# the LSTM variants iterate `train_dataloader_LSTM`.
uses_encoder = mode in ("Trans", "Trans_MHA")
loader = train_dataloader if uses_encoder else train_dataloader_LSTM
# Normalise by the loader that is actually iterated (the original always used
# len(train_dataloader), even on the LSTM path).
num_batches = len(loader)

for i in progress:
    log_var_extract=[]
    log_var_p_extract = []
    batchloss = 0.0
    mean_extract = []
    results = []
    kld_raw_batch = 0.0
    bce = 0.0
    kld = 0.0
    # Teacher-forcing ratio (LSTM variants only): decays linearly with the epoch
    # index and is floored at 0.1. It is constant within an epoch, so compute once.
    tf_ratio = max(0.1, 1.0 - i / 300)

    for batch in loader:
        optim.zero_grad()
        optim2.zero_grad()

        if uses_encoder:
            smiles_enc, smiles_dec_input, smiles_dec_output, properties = (
                t.to(device) for t in batch
            )
            output, tgt, means, log_var, tgt_z = model(smiles_enc, smiles_dec_input, properties)
        else:
            smiles_dec_input, smiles_dec_output, properties = (
                t.to(device) for t in batch
            )
            output, tgt, means, log_var, tgt_z = model(
                smiles_dec_input, smiles_dec_output, properties, tf_ratio
            )

        # NOTE(review): squeeze() drops every size-1 dim, including the batch dim
        # when a batch holds a single sample — confirm the loaders never yield one.
        mu_p, logvar_p = prior(properties.squeeze())

        # Monitoring statistics only: detach and keep them on the CPU so they
        # neither hold the autograd graph nor occupy GPU memory.
        with torch.no_grad():
            log_var_extract.append(torch.exp(0.5 * log_var.detach()).cpu())
            log_var_p_extract.append(torch.exp(0.5 * logvar_p.detach()).cpu())
            mean_extract.append(means.detach().cpu())

        loss, BCE, KLD, prop, kld_raw = loss_fn(output.float(),
                                smiles_dec_output,
                                means, log_var, mu_p, logvar_p,
                                tgt, properties.float().squeeze(), tgt_z,
                                i)

        # Collected for the validity check on BOTH paths (the LSTM path used to
        # skip this, so torch.cat() below failed on an empty list). Detached so
        # the per-batch graphs can be freed.
        results.append(output.detach().cpu())

        loss.backward()
        optim.step()
        optim2.step()

        batchloss += loss.item()
        kld_raw_batch += kld_raw.item()
        bce += BCE.item()
        kld += KLD.item()

    # Record the mean loss of this epoch.
    loss = batchloss / num_batches
    loss_arr.append(loss)

    # Validity: greedy-decode every training prediction and count non-empty SMILES.
    results = torch.cat(results, dim=0)
    results = torch.nn.functional.softmax(results, dim=-1)
    argmax_indices = torch.argmax(results, dim=-1)

    valid_smiles = [tok_ids_to_smiles(row.tolist()) or "" for row in argmax_indices]

    valid_count = sum(bool(s) for s in valid_smiles)
    valid_frac  = valid_count / len(valid_smiles)

    # Pull timing information out of the progress bar.
    elapsed = int(progress.format_dict.get("elapsed", 0))
    formatted_elap = str(datetime.timedelta(seconds=elapsed))
    rate = progress.format_dict.get("rate", None)
    sec_per_iter = 1 / rate if rate and rate != 0 else 0
    total = int(sec_per_iter * progress.total)
    formatted_total = str(datetime.timedelta(seconds=total))

    posterior_sigma = torch.cat(log_var_extract)
    prior_sigma = torch.cat(log_var_p_extract)
    posterior_mu = torch.cat(mean_extract, dim=0).view(-1, latent_dim)

    # Refresh the fixed status panel (printed into the Output widget).
    with status_out:
        clear_output(wait=True)
        print(f"🔹 Elapsed: {formatted_elap} | sec/iter: {sec_per_iter:.3f}s")
        print("🔹 Total time: ", formatted_total)
        print(f"🔹 Step: {i+1}/{progress.total}")
        print("🔹 loss: {:0.6f}".format(loss))
        print("[Posterior] sigma mean : {:0.6f}, ".format(posterior_sigma.mean().item()),
              "sigma std : {:0.6f}".format(posterior_sigma.std().item()))
        print("[Prior]     sigma mean : {:0.6f}, ".format(prior_sigma.mean().item()),
              "sigma std : {:0.6f}".format(prior_sigma.std().item()))
        print("raw KLD : ", kld_raw_batch / num_batches)
        print("Posterior μ per-dim std:", posterior_mu.std(axis=0).detach().cpu().numpy().mean())
        print("Posterior μ overall std:", posterior_mu.std().item())
        # NOTE(review): 40 is a hard-coded normaliser — presumably a sequence
        # length; the validation tensors printed in this notebook have length 37,
        # so confirm. `prop` is the value from the last batch only.
        print("BCE : {:0.6f},".format(bce / 40 / num_batches),
              "KLD : {:0.6f},".format(kld / 40 / num_batches),
              "prop : {:0.6f},".format(prop))
        print(f"Validity: {valid_frac:.2%}  ({valid_count}/{len(valid_smiles)})")

Output()

Training:   0%|          | 0/1200 [00:00<?, ?it/s]

In [None]:
model.eval()
prior.eval()
results = []
origin = []

properties_results=[]
properties_origin=[]
print(len(val_dataset))

# Choose the loader explicitly from `mode` instead of a bare `except:` fallback.
# The fallback hid real errors (even KeyboardInterrupt) and could leave partial
# results from the failed first attempt mixed into the lists.
uses_encoder = mode in ("Trans", "Trans_MHA")
eval_loader = val_dataloader if uses_encoder else val_dataloader_LSTM

with torch.no_grad():
    for batch in eval_loader:
        if uses_encoder:
            smiles_enc, smiles_dec_input, smiles_dec_output, properties = (
                t.to(device) for t in batch
            )
            result, tgt, means, log_var, z = model(smiles_enc, smiles_dec_input, properties)
        else:
            smiles_dec_input, smiles_dec_output, properties = (
                t.to(device) for t in batch
            )
            result, tgt, means, log_var, z = model(smiles_dec_input, smiles_dec_output, properties)

        results.append(result)
        origin.append(smiles_dec_output)

        # `tgt` is scored against the input properties below.
        properties_results.append(tgt)
        properties_origin.append(properties)

results = torch.cat(results, dim=0)
origin = torch.cat(origin, dim=0)
results = torch.nn.functional.softmax(results, dim=-1)
# Greedy decoding: most probable token at every position.
argmax_indices = torch.argmax(results, dim=-1)
output = torch.nn.functional.one_hot(argmax_indices, num_classes=results.size(-1))
print(argmax_indices)
print(results.shape)
print(origin.shape)

from sklearn.metrics import mean_absolute_error
properties_origin=torch.cat(properties_origin,dim=0)
properties_results=torch.cat(properties_results,dim=0)
MAE_2 = mean_absolute_error(properties_origin.squeeze().cpu(), properties_results.squeeze().cpu())
print("MAE(properties) : ", MAE_2)


# Decode token ids to SMILES; a falsy decoder result is stored as "".
results_smiles = [tok_ids_to_smiles(row.tolist()) or "" for row in argmax_indices]
origin_smiles = [tok_ids_to_smiles(row.tolist()) or "" for row in origin]

1567
tensor([[11, 14, 21,  ..., 22, 22, 22],
        [11, 14, 14,  ..., 22, 22, 22],
        [11, 14, 14,  ..., 22, 22, 22],
        ...,
        [11, 14, 14,  ..., 22, 22, 22],
        [11, 14, 21,  ..., 22, 22, 22],
        [11, 14, 21,  ..., 22, 22, 22]], device='cuda:0')
torch.Size([1567, 37, 30])
torch.Size([1567, 37])
MAE(properties) :  0.24400407524830678


In [None]:
# Strip a trailing "EOS" marker and surrounding whitespace from both lists.
origin_smiles = [text.removesuffix("EOS").strip() for text in origin_smiles]
results_smiles = [text.removesuffix("EOS").strip() for text in results_smiles]

# Print every reference/prediction pair, flagging the ones that differ.
for idx, predicted in enumerate(results_smiles):
    real = origin_smiles[idx]
    if real != predicted:
        print(idx, "번째 다름!")
    print("real smiles      : ", real)
    print("predicted smiles : ", predicted)


# Mean absolute error between reference token ids and greedy-decoded token ids.
predicted_ids = torch.argmax(results.cpu(), dim=-1)
MAE = mean_absolute_error(origin.cpu(), predicted_ids)
print("MAE : ", MAE)



0 번째 다름!
real smiles      :  [*]CC(C)OC(=O)OCC(C)N([*])CCC
predicted smiles :  
1 번째 다름!
real smiles      :  [*]NC(=O)OC(CC)C([*])C(=O)OCC
predicted smiles :  [*]CC(=O)OC(C)CCCCCCCC=O
2 번째 다름!
real smiles      :  [*]OC(=O)OCC(C)NC(=O)C(C)C([*])CC
predicted smiles :  
3 번째 다름!
real smiles      :  [*]CNC(=O)CCCCNC(=O)NC([*])CC
predicted smiles :  [*]CNC(=O)NCCNNC(=O)NC([*])C
4 번째 다름!
real smiles      :  [*]COC(=O)NC([*])C(=O)NCCCC
predicted smiles :  [*]COC(=O)OC([*])CCC=O
5 번째 다름!
real smiles      :  [*]CNC(=O)COC(=O)OCC([*])C(C)C
predicted smiles :  
6 번째 다름!
real smiles      :  [*]OC(=O)OCC(O)=C(O)C([*])C=O
predicted smiles :  [*]CC(=O)OC(CCCC)CNC([*])CC
7 번째 다름!
real smiles      :  [*]COC(=O)OCCCCC(=O)NC([*])C=C
predicted smiles :  [*]COC(=O)OCCCNNNC=O
8 번째 다름!
real smiles      :  [*]CN(C)CCCNC(=O)OC([*])CSC
predicted smiles :  [*]COCCCCCCCC(=O)OC([*])CCC
9 번째 다름!
real smiles      :  [*]CNC(=O)OC([*])COC(C)(C)C
predicted smiles :  
10 번째 다름!
real smiles      :  [*]CCOC(=O)OC(C)C(C)NC

In [None]:
from rdkit import Chem, RDLogger
from rdkit.Chem import DataStructs, rdFingerprintGenerator
RDLogger.DisableLog('rdApp.error')

# Shared Morgan fingerprint generator (radius 2, 2048-bit fingerprints).
generator = rdFingerprintGenerator.GetMorganGenerator(radius=2, fpSize=2048)

def tanimoto_similarity(smiles1, smiles2):
    """Return the Tanimoto similarity of the Morgan fingerprints of two SMILES.

    Both strings are parsed first, then fingerprinted with the module-level
    ``generator``. No validity check is done here; callers are expected to
    pass SMILES that parse successfully.
    """
    mols = [Chem.MolFromSmiles(text) for text in (smiles1, smiles2)]
    fingerprints = [generator.GetFingerprint(mol) for mol in mols]
    return DataStructs.TanimotoSimilarity(*fingerprints)

def is_valid(smiles):
    """Return True when ``Chem.MolFromSmiles`` yields a molecule (not None)."""
    mol = Chem.MolFromSmiles(smiles)
    return mol is not None

TS = 0.0      # running sum of similarities over comparable pairs
canbe = 0     # pairs where both prediction and reference parsed
notbe = 0     # pairs skipped (empty or unparsable)

for sm, orig in zip(results_smiles, origin_smiles):
    # An empty prediction means decoding produced nothing; count it as invalid.
    if not sm:
        notbe += 1
        continue
    if is_valid(sm) and is_valid(orig):
        TS += tanimoto_similarity(sm, orig)
        canbe += 1
    else:
        notbe += 1

if canbe > 0:
    print("Tanimoto Similarity : ", TS / canbe)
else:
    print("No valid molecules to compare.")

print("가능한 분자 개수 :", canbe)
print("불가능한 분자 개수 :", notbe)
# Guard against an empty prediction list, which used to raise ZeroDivisionError.
total_predictions = len(results_smiles)
valid_fraction = canbe / total_predictions if total_predictions else 0.0
print("Valid fraction      :", valid_fraction)


Tanimoto Similarity :  0.23365548452873136
가능한 분자 개수 : 1169
불가능한 분자 개수 : 398
Valid fraction      : 0.746011486917677


In [None]:
def save_weights(choice):
    """Save the state dicts of the global ``model`` and ``prior`` to disk.

    Files are written under ``PROJECT_ROOT / "models/weights"``; the file names
    depend on the architecture key.

    Args:
        choice: Architecture key, one of "Trans_MHA", "Trans", "LSTM" or
            "LSTM_MHA".

    Raises:
        ValueError: If ``choice`` is not a supported key. Previously an unknown
            key returned silently without saving anything.
    """
    # Architecture key -> (model weights file, prior weights file).
    filenames = {
        "Trans_MHA": ("model_weights_dmodel256.pth", "model_weights_prior.pth"),
        "Trans": ("model_weights_dmodel256_no_mha.pth", "model_weights_prior_no_mha.pth"),
        "LSTM": ("model_weights_LSTM.pth", "model_weights_LSTM_prior.pth"),
        "LSTM_MHA": ("model_weights_LSTM_MHA.pth", "model_weights_LSTM_MHA_prior.pth"),
    }
    if choice not in filenames:
        raise ValueError(
            f"Unknown model choice {choice!r}; expected one of {sorted(filenames)}"
        )

    weights_dir = PROJECT_ROOT / "models/weights"
    # torch.save does not create missing directories.
    weights_dir.mkdir(parents=True, exist_ok=True)

    model_file, prior_file = filenames[choice]
    torch.save(model.state_dict(), weights_dir / model_file)
    torch.save(prior.state_dict(), weights_dir / prior_file)

In [None]:
# Save the model and prior weights under the file names associated with `mode`.
save_weights(mode)