In [1]:
import torch
import numpy as np
import torch.nn as nn
import pandas as pd
from psmiles import PolymerSmiles as PS
from torch.utils.data import DataLoader, Dataset, ConcatDataset
from rdkit import Chem
from rdkit.Chem import Draw
from PIL import Image
from scipy.sparse import csr_matrix, lil_matrix
import atomInSmiles
from collections import Counter
from IPython.display import clear_output, display
import ipywidgets as widgets
from tqdm.notebook import tqdm
import os
import sys
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device=torch.device('cpu')
print(device)

cuda


In [2]:
from sklearn.decomposition import PCA
class load_data(Dataset):
    def __init__(self, path):
        #csv 읽기
        self.raw = pd.read_csv(path)

        #SMILES
        self.SMILES = self.raw.iloc[:, 1:2].values
        self.SMILES = np.squeeze(self.SMILES)

        #특성 데이터
        self.properties = self.raw.iloc[:, 6:11].values
        self.properties = np.squeeze(self.properties)

        #Degree of Polymerization
        self.DP = self.raw.iloc[:, 4:5].values
        self.DP = np.squeeze(self.DP)

        #PSMILES 변환
        psmiles = []
        for smiles in self.SMILES:
            ps = PS(smiles)
            ps.canonicalize
            psmiles.append(ps.psmiles)

        #Atom-In-SMILES Encoding
        ais_encoding = []
        for smiles in psmiles:
            ais_encoding.append(atomInSmiles.encode(smiles))
        


        #Atom-In-SMILES Tokenization (Encoder)
        ais_tokens = []
        for smiles in ais_encoding:
            ais_tokens.append(atomInSmiles.smiles_tokenizer("[SOS] " + smiles +" [EOS]"))

        max_len = len(max(ais_encoding, key=len))
        self.max_len = max_len
        print("max sequence length : ", max_len)

        #vocab 구성
        corpus = []
        for frags in ais_tokens:
            corpus.extend(frags)
        token_count = Counter(corpus)
        vocab = { token:i for i, (token, count) in enumerate(sorted(token_count.items(), key=lambda x: x[1], reverse=True))}
        vocab_size = len(vocab)

        num_data = len(ais_tokens)
        print(vocab)

        ais_tokens_enc = ais_tokens
        ais_tokens_enc = [[tok for tok in tokens if tok not in ['[SOS]', '[EOS]']] for tokens in ais_tokens_enc]

        ais_tokens_dec_input = ais_tokens
        ais_tokens_dec_input = [[tok for tok in tokens if tok not in ['[EOS]']] for tokens in ais_tokens_dec_input]

        ais_tokens_dec_output = ais_tokens
        ais_tokens_dec_output = [[tok for tok in tokens if tok not in ['[SOS]']] for tokens in ais_tokens_dec_output]

        #Tokens to number (encoder)
        ais_token_num_enc = torch.zeros((num_data, max_len), dtype=torch.long)
        i=0
        for tokens in ais_tokens_enc:
            for length in range((len(tokens))):
                ais_token_num_enc[i, length] = vocab[tokens[length]]
            i += 1

        #Tokens to number (Decoder Input)
        ais_token_num_dec_input = torch.zeros((num_data, max_len), dtype=torch.long)
        i=0
        for tokens in ais_tokens_dec_input:
            for length in range((len(tokens))):
                ais_token_num_dec_input[i, length] = vocab[tokens[length]]
            i += 1

        #Tokens to number (Decoder Output)
        ais_token_num_dec_output = torch.zeros((num_data, max_len), dtype=torch.long)
        i=0
        for tokens in ais_tokens_dec_output:
            for length in range((len(tokens))):
                ais_token_num_dec_output[i, length] = vocab[tokens[length]]
            i += 1

        self.SMILES_enc = ais_token_num_enc
        self.SMILES_dec_input = ais_token_num_dec_input
        self.SMILES_dec_output = ais_token_num_dec_output

        
        # self.SMILES = torch.ones((num_data, max_len, vocab_size), dtype=torch.float) * 0.1
        # for smiles in range(ais_token_num_enc.shape[0]):
        #     for length in range(ais_token_num_enc.shape[1]):
        #         if ais_token_num_enc[smiles, length] != 0:
        #             self.SMILES[smiles, length, ais_token_num_enc[smiles, length]] = 0.9

        # self.SMILES_dec_input = torch.ones((num_data, max_len, vocab_size), dtype=torch.float) * 0.1
        # for smiles in range(ais_token_num_dec_input.shape[0]):
        #     for length in range(ais_token_num_dec_input.shape[1]):
        #         if ais_token_num_dec_input[smiles, length] != 0:
        #             self.SMILES_dec_input[smiles, length, ais_token_num_dec_input[smiles, length]] = 0.9

        # self.SMILES_dec_output = torch.zeros((num_data, max_len, vocab_size), dtype=torch.float) * 0.1
        # for smiles in range(ais_token_num_dec_output.shape[0]):
        #     for length in range(ais_token_num_dec_output.shape[1]):
        #         if ais_token_num_dec_output[smiles, length] != 0:
        #             self.SMILES_dec_output[smiles, length, ais_token_num_dec_output[smiles, length]] = 0.9
        
        # self.SMILES_enc = self.SMILES
        
        # print(self.SMILES.shape, self.SMILES_dec_input.shape, self.SMILES_dec_output.shape)
        
        vocab_size, num_data
        print("vocab size : ", vocab_size,"\nnumber of data : ",num_data)
        
        self.vocab = vocab
        self.vocab_size = vocab_size

        

        #PCA
        self.pca = PCA(n_components=1).fit_transform(self.properties[:, 0:4])
        self.pca = torch.tensor(self.pca, dtype=torch.float).to(device)


        print(self.SMILES_enc.shape)
        self.DP = torch.tensor(self.DP, dtype=torch.float).to(device)
        self.properties = torch.tensor(self.properties[:, 4:5], dtype=torch.float).to(device)
        self.properties = torch.cat((self.pca, self.properties, self.DP.unsqueeze(-1)), dim=-1).unsqueeze(-1)
        print(self.properties.shape)

        self.test_data = self.SMILES_enc[50]

        print("PSMILES : ",psmiles[50])
        print("After AIS encoding : ", ais_encoding[50])
        print("After AIS Tokenization : ", ais_tokens_enc[50])
        print("After to number : ", ais_token_num_enc[50])
    
    def __getitem__(self, i):
        return self.SMILES_enc[i], self.SMILES_dec_input[i], self.SMILES_dec_output[i], self.properties[i]
    
    def __len__(self):
        return self.SMILES_enc.shape[0]
    
    def vocab_len(self):
        return self.vocab_size

In [3]:
Polymers = "simulation-trajectory-aggregate.csv"
dataset = load_data(Polymers)

train_dataloader = DataLoader(dataset, batch_size=256, shuffle=True, drop_last=False)


max sequence length :  264
{'(': 0, ')': 1, '=': 2, '[O;!R;C]': 3, '[CH3;!R;C]': 4, '[CH2;!R;CN]': 5, '[CH2;!R;CC]': 6, '[CH2;!R;CO]': 7, '[*;!R;C]': 8, '[SOS]': 9, '[EOS]': 10, '[O;!R;CC]': 11, '[NH;!R;CC]': 12, '[*;!R;O]': 13, '[O;!R;*C]': 14, '[C;!R;*OO]': 15, '[CH;!R;CCO]': 16, '[CH;!R;CCN]': 17, '[C;!R;CNO]': 18, '[N;!R;CCC]': 19, '[C;!R;*NO]': 20, '[*;!R;N]': 21, '[NH;!R;*C]': 22, '[CH3;!R;N]': 23, '[CH;!R;CCC]': 24, '[C;!R;CCCO]': 25, '[CH;!R;CC]': 26, '[F;!R;C]': 27, '[CH2;!R;C]': 28, '[C;!R;COO]': 29, '[CH3;!R;O]': 30, '[C;!R;CCCN]': 31, '#': 32, '[C;!R;CCCC]': 33, '[CH2;!R;CS]': 34, '[C;!R;CC]': 35, '[OH;!R;C]': 36, '[S;!R;CC]': 37, '[N;!R;C]': 38, '[C;!R;CN]': 39, '[CH;!R;C]': 40, '[C;!R;CCC]': 41, '[NH2;!R;C]': 42, '[CH;!R;CFF]': 43, '[O;!R;S]': 44, '[O;!R;CN]': 45, '[CH;!R;CCS]': 46, '[CH2;!R;CF]': 47, '[CH3;!R;S]': 48, '[C;!R;CCO]': 49, '[NH;!R;CO]': 50, '[C;!R;NNO]': 51, '[CH2;!R;*C]': 52, '[C;!R;CFFF]': 53, '[C;!R;CCFF]': 54, '[C;!R;NOO]': 55, '[C;!R;OOO]': 56, '[S;!R;C

In [4]:
import math
#from torch_pca import PCA
from torch.nn import TransformerDecoder, TransformerDecoderLayer, TransformerEncoder, TransformerEncoderLayer
from fast_transformers.masking import TriangularCausalMask

class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-12):
        super(LayerNorm, self).__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.var(-1, unbiased=False, keepdim=True)
        # '-1' means last dimension. 

        out = (x - mean) / torch.sqrt(var + self.eps)
        out = self.gamma * out + self.beta
        return out

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = torch.nn.Dropout(p=dropout)
        self.d_model = d_model
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        pos_embedding =  self.pe[:, :x.size(1), :]
        pos_embedding = torch.repeat_interleave(pos_embedding, x.shape[0], dim=0)
        x =  torch.cat([x, pos_embedding], dim=2)
        return self.dropout(x)

class TFEncoder(nn.Module):
    def __init__(self, d_model=512, n_heads=4, d_ff=32, enc_seq_len=5000, dropout=0.2):
        super().__init__()
        self.normLayer_0 = LayerNorm(d_model=d_model)
        self.normLayer_1 = LayerNorm(d_model=d_model // 2)
        self.normLayer_2 = LayerNorm(d_model=d_model // 4)
        self.normLayer_3 = LayerNorm(d_model=d_model // 8)
        
        self.encoderLayer_0 = TransformerEncoderLayer(batch_first=True,
                                               d_model=d_model,
                                               nhead=n_heads,
                                               dim_feedforward=d_ff,
                                               dropout=dropout,
                                               activation="gelu")
        self.encoderLayer_1 = TransformerEncoderLayer(batch_first=True,
                                               d_model=d_model // 2,
                                               nhead=n_heads,
                                               dim_feedforward=d_ff,
                                               dropout=dropout,
                                               activation="gelu")
        self.encoderLayer_2 = TransformerEncoderLayer(batch_first=True,
                                               d_model=d_model // 4,
                                               nhead=n_heads,
                                               dim_feedforward=d_ff,
                                               dropout=dropout,
                                               activation="gelu")
        self.encoderLayer_3 = TransformerEncoderLayer(batch_first=True,
                                               d_model=d_model // 8,
                                               nhead=n_heads,
                                               dim_feedforward=d_ff,
                                               dropout=dropout,
                                               activation="gelu")
        self.encoder_0 = TransformerEncoder(encoder_layer=self.encoderLayer_0, num_layers=1,
                                          norm=self.normLayer_0)
        self.encoder_1 = TransformerEncoder(encoder_layer=self.encoderLayer_1,num_layers=1,
                                          norm=self.normLayer_1)
        self.encoder_2 = TransformerEncoder(encoder_layer=self.encoderLayer_2,num_layers=1,
                                          norm=self.normLayer_2)
        self.encoder_3 = TransformerEncoder(encoder_layer=self.encoderLayer_3,num_layers=1,
                                          norm=self.normLayer_3)
        self.input_embedding_smiles = nn.Embedding(dataset.vocab_size, d_model // 2)
        self.input_embedding = nn.Sequential(
            nn.Linear(1, d_model // 8),
            nn.GELU(),
            nn.Linear(d_model // 8, d_model // 4),
            nn.GELU(),
            nn.Linear(d_model // 4, d_model // 2),
        )
        self.pos_encoding = PositionalEncoding(d_model // 2, dropout, max_len=enc_seq_len)

        self.to_encoder_1 = nn.Conv1d(in_channels=d_model, out_channels=d_model // 2, kernel_size=1)
        self.to_encoder_2 = nn.Conv1d(in_channels=d_model // 2, out_channels=d_model // 4, kernel_size=1)
        self.to_encoder_3 = nn.Conv1d(in_channels=d_model // 4, out_channels=d_model // 8, kernel_size=1)

    def forward(self, smiles_enc):
        smiles_enc = self.input_embedding_smiles(smiles_enc)
        enc_input_0 = self.pos_encoding(smiles_enc)

        encoded_0 = self.encoder_0(enc_input_0)

        enc_input_1 = self.to_encoder_1(encoded_0.permute(0, 2, 1)).permute(0, 2, 1)
        encoded_1 = self.encoder_1(enc_input_1)
    
        enc_input_2 = self.to_encoder_2(encoded_1.permute(0, 2, 1)).permute(0, 2, 1)
        encoded_2 = self.encoder_2(enc_input_2)

        enc_input_3 = self.to_encoder_3(encoded_2.permute(0, 2, 1)).permute(0, 2, 1)
        encoded_3 = self.encoder_3(enc_input_3)
        return encoded_3

class TFDecoder(nn.Module):
    def __init__(self, d_model=512, n_heads=4, d_ff=32, enc_seq_len=5000, dropout=0.2, ):
        super().__init__()
        self.normLayer_0 = LayerNorm(d_model=d_model // 8)
        self.normLayer_1 = LayerNorm(d_model=d_model // 4)
        self.normLayer_2 = LayerNorm(d_model=d_model // 2)
        self.normLayer_3 = LayerNorm(d_model=d_model)

        
        self.decoderLayer_0 = TransformerDecoderLayer(batch_first=True,
                                               d_model=d_model // 8,
                                               nhead=n_heads,
                                               dim_feedforward=d_ff,
                                               dropout=dropout,
                                               activation="gelu")
        self.decoderLayer_1 = TransformerDecoderLayer(batch_first=True,
                                               d_model=d_model // 4,
                                               nhead=n_heads,
                                               dim_feedforward=d_ff,
                                               dropout=dropout,
                                               activation="gelu")
        self.decoderLayer_2 = TransformerDecoderLayer(batch_first=True,
                                               d_model=d_model // 2,
                                               nhead=n_heads,
                                               dim_feedforward=d_ff,
                                               dropout=dropout,
                                               activation="gelu")
        self.decoderLayer_3 = TransformerDecoderLayer(batch_first=True,
                                               d_model=d_model,
                                               nhead=n_heads,
                                               dim_feedforward=d_ff,
                                               dropout=dropout,
                                               activation="gelu")
        self.decoder_0 = TransformerDecoder(decoder_layer=self.decoderLayer_0,num_layers=1,
                                          norm=self.normLayer_0)
        self.decoder_1 = TransformerDecoder(decoder_layer=self.decoderLayer_1,num_layers=1,
                                          norm=self.normLayer_1)
        self.decoder_2 = TransformerDecoder(decoder_layer=self.decoderLayer_2,num_layers=1,
                                          norm=self.normLayer_2)
        self.decoder_3 = TransformerDecoder(decoder_layer=self.decoderLayer_3,num_layers=1,
                                          norm=self.normLayer_3)
        
        self.input_embedding_smiles = nn.Embedding(dataset.vocab_size, d_model // 16)

        self.pos_encoding = PositionalEncoding(d_model // 16, dropout, max_len=enc_seq_len)

        self.to_decoder_1 = nn.Conv1d(in_channels=d_model // 8, out_channels=d_model // 4, kernel_size=1)
        self.to_decoder_2 = nn.Conv1d(in_channels=d_model // 4, out_channels=d_model // 2, kernel_size=1)
        self.to_decoder_3 = nn.Conv1d(in_channels=d_model // 2, out_channels=d_model, kernel_size=1)

    def forward(self, dec_input, latent):
        dec_input = self.input_embedding_smiles(dec_input)
        dec_input_0 = self.pos_encoding(dec_input)

        x_mask = TriangularCausalMask(dec_input.shape[1], device=device)
        x_mask = x_mask.bool_matrix
        memory_mask = TriangularCausalMask(latent.shape[1], device=device)
        memory_mask = memory_mask.bool_matrix

        decoded_0 = self.decoder_0(dec_input_0, latent, tgt_mask=x_mask, memory_mask = memory_mask)

        dec_input_1 = self.to_decoder_1(decoded_0.permute(0, 2, 1)).permute(0, 2, 1)
        latent = self.to_decoder_1(latent.permute(0, 2, 1)).permute(0, 2, 1)
        decoded_1 = self.decoder_1(dec_input_1, latent, tgt_mask=x_mask, memory_mask = memory_mask)

        dec_input_2 = self.to_decoder_2(decoded_1.permute(0, 2, 1)).permute(0, 2, 1)
        latent = self.to_decoder_2(latent.permute(0, 2, 1)).permute(0, 2, 1)
        decoded_2 = self.decoder_2(dec_input_2, latent, tgt_mask=x_mask, memory_mask = memory_mask)

        dec_input_3 = self.to_decoder_3(decoded_2.permute(0, 2, 1)).permute(0, 2, 1)
        latent = self.to_decoder_3(latent.permute(0, 2, 1)).permute(0, 2, 1)
        decoded_3 = self.decoder_3(dec_input_3, latent, tgt_mask=x_mask, memory_mask = memory_mask)

        return decoded_3



In [5]:
class CVAE(nn.Module):
    def __init__(self, d_model=512, n_layers=4, n_heads=4, d_ff=32, enc_seq_len=5000,
                 d_query=128, dropout=0.2, softmax_temp = None, attention_dropout=0.2, latent_dim = 64):
        super().__init__()

        self.to_means = nn.Linear(latent_dim, latent_dim)
        self.to_var = nn.Linear(latent_dim, latent_dim)

        self.encoder = TFEncoder()
        self.decoder = TFDecoder()

        self.predict = nn.Linear(d_model, dataset.vocab_size)
        
        self.softmax = nn.Softmax(dim=-1)
    
    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var).to(device)
        eps = torch.rand_like(std).to(device)
        return mu + eps * std

    def forward(self, smiles_enc, smiles_dec_input, properties):

        encoded = self.encoder(smiles_enc) # (batch_size, seq_len, d_model // 8) 

        means = self.to_means(encoded).permute(0, 2, 1)
        log_var = self.to_var(encoded).permute(0, 2, 1)

        z = self.reparameterize(means, log_var).permute(0, 2, 1)
        
        output = self.decoder(smiles_dec_input, z)
        output = self.predict(output)

        return output, means, log_var, z
    

In [6]:
def loss_fn(output, input, mean, log_var):
    output = output.view(-1, dataset.vocab_size)
    input = input.view(-1)
    BCE = torch.nn.functional.cross_entropy(
        output, input, reduction='sum'
    )
    KLD = -0.5*torch.sum(1+log_var-mean.pow(2) - log_var.exp())

    return (BCE+KLD) / input.size(0)

In [7]:
def reverse_one_hot_encoding(one_hot_tensor, vocab):
    # 인덱스 → 토큰 매핑 생성
    index_to_token = {idx: token for token, idx in vocab.items()}
    
    #print(index_to_token)
    # 복원된 토큰 시퀀스를 저장할 리스트
    original_tokens_list = []
    
    # 텐서가 GPU에 있다면 CPU로 변환
    if one_hot_tensor.is_cuda:
        one_hot_tensor = one_hot_tensor.cpu()
    try:
        one_hot_tensor.shape[1]
        one_hot_tensor = torch.argmax(one_hot_tensor, dim=-1)
    except:
        one_hot_tensor.shape[0]

    nonzero_indices = torch.nonzero(one_hot_tensor, as_tuple=True)[0]

    try:
        for i in range(nonzero_indices[-1]+1):
            
            # 인덱스를 토큰으로 변환
            tokens = index_to_token[one_hot_tensor[i].item()]

            original_tokens_list.append(tokens)
    
    except:
        original_tokens_list = "not a polymer!"

    return original_tokens_list
print(dataset.test_data)
test = reverse_one_hot_encoding(dataset.test_data, dataset.vocab)
print(test)
print(atomInSmiles.decode(' '.join(test)))
#print(dataset.vocab)


tensor([13, 14, 16,  0,  4,  1,  6,  6,  5, 12, 18,  0,  2,  3,  1, 16,  0,  4,
         1, 11, 15,  0,  8,  1,  2,  3,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,
         0,  0,  0,  0,  0,  0,  0,  0, 

In [8]:
model = CVAE()
model.cuda()
lr = 3e-5
optim = torch.optim.AdamW(model.parameters(), lr=lr)

In [9]:
from torchinfo import summary
smiles = torch.ones([128, dataset.max_len], dtype=torch.long).to(device)
pp = torch.ones([128, 3, 1], dtype=torch.float).to(device)
summary(model, input_data=(smiles, smiles, pp))

Layer (type:depth-idx)                             Output Shape              Param #
CVAE                                               [128, 264, 166]           --
├─TFEncoder: 1-1                                   [128, 264, 64]            1,504,320
│    └─Embedding: 2-1                              [128, 264, 256]           42,496
│    └─PositionalEncoding: 2-2                     [128, 264, 512]           --
│    │    └─Dropout: 3-1                           [128, 264, 512]           --
│    └─TransformerEncoder: 2-3                     [128, 264, 512]           --
│    │    └─ModuleList: 3-2                        --                        1,085,984
│    │    └─LayerNorm: 3-3                         [128, 264, 512]           1,024
│    └─Conv1d: 2-4                                 [128, 256, 264]           131,328
│    └─TransformerEncoder: 2-5                     [128, 264, 256]           --
│    │    └─ModuleList: 3-4                        --                        280,864
│   

In [10]:
def print_fixed_lines(line1, line2, line3, line4, line5):
    # 커서를 위로 3줄 올리고, 3줄을 덮어쓰기
    sys.stdout.write("\033[3F")  # 커서를 위로 3줄 이동
    sys.stdout.write("\033[K" + line1 + "\n")  # 줄 지우고 새로 쓰기
    sys.stdout.write("\033[K" + line2 + "\n")
    sys.stdout.write("\033[K" + line3 + "\n")
    sys.stdout.write("\033[K" + line4 + "\n")
    sys.stdout.write("\033[K" + line5 + "\n")
    sys.stdout.flush()

In [None]:
status_out = widgets.Output()
display(status_out)

epoch = 4000
model.train()
progress = tqdm(range(epoch), desc="Training")

loss_arr = list()
real = list()
predict = list()

for i in progress:
    batchloss = 0.0
    for (smiles_enc, smiles_dec_input, smiles_dec_output, properties) in train_dataloader:
        optim.zero_grad()

        smiles_enc = smiles_enc.to(device)
        smiles_dec_input = smiles_dec_input.to(device)
        smiles_dec_output = smiles_dec_output.to(device)
        properties = properties.to(device)


        # smiles_dec_input = model.softmax(smiles_dec_input)
        #smiles_dec_output = model.softmax(smiles_dec_output)

        result, means, log_var, z = model(smiles_enc, smiles_dec_input, properties)
        
        loss = loss_fn(result.float(), smiles_dec_output, means, log_var)
        loss.backward()
        optim.step()
        batchloss += loss
    
    loss = batchloss.cpu().item() / len(train_dataloader)
    loss_arr.append(loss)
    
    #progress.set_description("loss: {:0.6f}".format(loss))


    argmax_indices = torch.argmax(result, dim=-1)
    output = torch.nn.functional.one_hot(argmax_indices, num_classes=result.size(-1))

    original_tokens = reverse_one_hot_encoding(smiles_dec_output[50], dataset.vocab)
    predicted_tokens = reverse_one_hot_encoding(output[50], dataset.vocab)

    original_str = atomInSmiles.decode(' '.join(original_tokens))
    predicted_str = atomInSmiles.decode(' '.join(predicted_tokens))


    # 진행 바의 속성으로부터 필요한 값들 추출 (예시)
    elapsed = progress.format_dict.get("elapsed", 0)
    rate = progress.format_dict.get("rate", None)
    sec_per_iter = 1 / rate if rate and rate != 0 else 0
    
    # 고정된 2줄 상태 정보를 업데이트 (Output 위젯에 출력)
    with status_out:
        clear_output(wait=True)
        print(f"🔹 Elapsed: {elapsed:.1f}s | sec/iter: {sec_per_iter:.3f}s")
        print(f"🔹 Step: {i+1}/{progress.total}")
        print("🔹 loss: {:0.6f}".format(loss))
        print(f"[Epoch {i}] Original : {original_str}")
        print(f"[Epoch {i}] Predict  : {predicted_str}")

    


Output()

Training:   0%|          | 0/4000 [00:00<?, ?it/s]