In [1]:
!pip install PyTDC
!pip install datasets

!pip install transformers

Collecting PyTDC
  Downloading PyTDC-0.4.1.tar.gz (107 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m107.7/107.7 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting rdkit-pypi (from PyTDC)
  Downloading rdkit_pypi-2022.9.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.4/29.4 MB[0m [31m39.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting fuzzywuzzy (from PyTDC)
  Downloading fuzzywuzzy-0.18.0-py2.py3-none-any.whl (18 kB)
Collecting dataclasses (from PyTDC)
  Downloading dataclasses-0.6-py3-none-any.whl (14 kB)
Building wheels for collected packages: PyTDC
  Building wheel for PyTDC (setup.py) ... [?25l[?25hdone
  Created wheel for PyTDC: filename=PyTDC-0.4.1-py3-none-any.whl size=140644 sha256=30ad08387a081b32cee89fe05107b681823efe8986a34cbc9ec6e3064ed4a211
  Stored in directory: /root/.cache/pip/wheels/14/b7/b

Collecting datasets
  Downloading datasets-2.15.0-py3-none-any.whl (521 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m521.2/521.2 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Collecting pyarrow-hotfix (from datasets)
  Downloading pyarrow_hotfix-0.6-py3-none-any.whl (7.9 kB)
Collecting dill<0.3.8,>=0.3.0 (from datasets)
  Downloading dill-0.3.7-py3-none-any.whl (115 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m115.3/115.3 kB[0m [31m12.3 MB/s[0m eta [36m0:00:00[0m
Collecting multiprocess (from datasets)
  Downloading multiprocess-0.70.15-py310-none-any.whl (134 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m134.8/134.8 kB[0m [31m14.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: pyarrow-hotfix, dill, multiprocess, datasets
Successfully installed datasets-2.15.0 dill-0.3.7 multiprocess-0.70.15 pyarrow-hotfix-0.6


In [2]:
import tqdm
import numpy as np
import pandas as pd
import plotly.express as px
from tdc.multi_pred import DTI
from tdc.generation import MolGen
import plotly.figure_factory as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [3]:
data = MolGen(name = 'MOSES')
data=data.get_data()[:100000]

Downloading...
100%|██████████| 75.3M/75.3M [00:04<00:00, 18.1MiB/s]
Loading...
Done!


In [4]:
data=data.sample(frac=1)
data=data.reset_index(drop=True)

## Some Basic Analysis

In [5]:
data['l_smiles']=data.smiles.apply(len)

In [6]:
data['l_smiles'].describe(percentiles=[i/10 for i in range(1,10)])

count    100000.000000
mean         35.150310
std           4.566592
min          15.000000
10%          29.000000
20%          31.000000
30%          33.000000
40%          34.000000
50%          35.000000
60%          36.000000
70%          38.000000
80%          39.000000
90%          41.000000
max          54.000000
Name: l_smiles, dtype: float64

## A solid preset length for drug would be 50

In [7]:
def tokenize(input_string):
  return [ord(char) for char in input_string]
def encode(input_string,max_length=128,padding=True):
  tokens=tokenize(input_string)
  if len(tokens)>max_length:
    tokens=tokens[:max_length]
  if (len(tokens)<max_length) & padding:
    tokens.extend([0 for _ in range(max_length-len(tokens))])
  return tokens
def decode(input_tokens):
  return ''.join(list(map(lambda x:chr(x), input_tokens)))

In [8]:
l_tokenizer=encode('z',padding=False)[0]+1

In [9]:
class Drug_Dataset(Dataset):
    def __init__(self, df,drug_max_length):
        self.df = df
        self.dml=drug_max_length
    def __len__(self):
        return len(self.df)

    def __getitem__(self, idx):
        row=self.df.iloc[idx]
        input_drug=torch.tensor(encode(row['smiles'],max_length=self.dml))
        return {'input_drug':input_drug}

In [10]:
dml=50

In [11]:
l=int(data.shape[0]*0.8)
train_p=Drug_Dataset(data[:l],drug_max_length=dml)
test_p=Drug_Dataset(data[l:],drug_max_length=dml)

In [12]:
train_loader=DataLoader(train_p,batch_size=32,shuffle=True)
test_loader=DataLoader(test_p,batch_size=32)

## Let's create the model

## This is Attention, as Torch's attention does not work with mask.

In [37]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
class Encoder(nn.Module):
    def __init__(self, latent_dim=8,embed_dim=16):
        super(Encoder, self).__init__()
        self.latent_dim=latent_dim
        self.embeddings = nn.Embedding(l_tokenizer, embed_dim)
        self.conv1=nn.Conv1d(embed_dim,embed_dim//2,3)
        self.pool1=nn.MaxPool1d(2)
        self.conv2=nn.Conv1d(embed_dim//2,embed_dim//4,3)
        self.pool2=nn.MaxPool1d(2)
        self.conv3=nn.Conv1d(embed_dim//4,embed_dim//8,2)
        self.pool3=nn.MaxPool1d(2)
        self.lsigma=nn.Linear((embed_dim//8)*5, latent_dim)
        self.llogvar=nn.Linear((embed_dim//8)*5, latent_dim)

        # self.N = torch.distributions.Normal(0, 1)
        # self.N.loc = self.N.loc # hack to get sampling on the GPU
        # self.N.scale = self.N.scale
        self.kl = 0
    def reparameterize(self, mu , logvar):
        std = torch.exp(logvar*0.5)
        eps = torch.randn_like(std).to(device)
        z = mu + eps * std
        return z
    def forward(self, x):
        x=self.embeddings(x)
        bn=x.size(0)
        x=torch.transpose(x,1,2)
        x = self.pool1(F.relu(self.conv1(x)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = self.pool3(F.relu(self.conv3(x)))
        # x=torch.transpose(x,1,2)
        x=x.view(bn,-1)
        mu =  self.lsigma(x)
        logvar = self.llogvar(x)
        z=self.reparameterize(mu , logvar)
        self.kl = 0.5*(logvar.exp() + mu**2 - logvar - 1).sum()/x.size(0)/dml/10
        return z

class Decoder(nn.Module):
    def __init__(self, latent_dim=8,dim1=16):
        super(Decoder, self).__init__()
        self.linear=nn.Linear(latent_dim,24)
        self.conv1 = nn.ConvTranspose1d(1, dim1,3,stride=2)
        self.conv2 = nn.ConvTranspose1d(dim1, 2*dim1,2)
        self.linear2 = nn.Linear(2*dim1,l_tokenizer)

    def forward(self, z):
        z = F.relu(self.linear(z))
        z = z.view(-1,1,24)
        z = F.relu(self.conv1(z))
        z = F.relu(self.conv2(z))
        z = torch.transpose(z,1,2)
        z = self.linear2(z)
        return z

class VariationalAutoencoder(nn.Module):
    def __init__(self, latent_dims=64):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder(latent_dims,64).to(device)
        self.decoder = Decoder(latent_dims,32).to(device)

    def forward(self, x):
        z = self.encoder(x)
        return self.decoder(z)

In [38]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder=VariationalAutoencoder()
opt = torch.optim.Adam(autoencoder.parameters(),lr=0.0005)
loss_fn=torch.nn.CrossEntropyLoss()
def acc_fn(y,y_hat):
  return torch.mean((y==torch.argmax(y_hat,dim=1)).float())
for epoch in range(100):
    train_loss1=0
    train_loss2=0
    train_acc=0
    autoencoder.train()
    for batch in tqdm.tqdm(train_loader):
        x = batch['input_drug'].to(device) # GPU
        opt.zero_grad()
        x_hat = autoencoder(x)
        # x2=torch.transpose(x,1,2)
        x2_hat=torch.transpose(x_hat,1,2)
        loss1=loss_fn(x2_hat,x)
        loss2=autoencoder.encoder.kl
        acc=acc_fn(x,x2_hat)
        loss = loss1 + loss2
        train_acc+=acc.detach().cpu().numpy()
        train_loss1+=loss1.detach().cpu().numpy()
        train_loss2+=loss2.detach().cpu().numpy()
        loss.backward()
        opt.step()

    test_loss1=0
    test_loss2=0
    test_acc=0
    autoencoder.eval()
    with torch.no_grad():
      for batch in tqdm.tqdm(test_loader):
          x = batch['input_drug'].to(device) # GPU
          x_hat = autoencoder(x)
          # x2=torch.transpose(x,1,2)
          x2_hat=torch.transpose(x_hat,1,2)
          loss1=loss_fn(x2_hat,x)
          loss2=autoencoder.encoder.kl
          acc=acc_fn(x,x2_hat)
          loss = loss1 + loss2
          test_acc+=acc.detach().cpu().numpy()
          test_loss1+=loss1.detach().cpu().numpy()
          test_loss2+=loss2.detach().cpu().numpy()

    print(f"TRAIN: EPOCH {epoch}: SSE: {train_loss1/len(train_loader)}, KL_LOSS: {train_loss2/len(train_loader)}, ACC: {train_acc/len(train_loader)}   \nTEST: EPOCH {epoch}: SSE: {test_loss1/len(test_loader)}, KL_LOSS: {test_loss2/len(test_loader)}, ACC: {test_acc/len(test_loader)}")

100%|██████████| 2500/2500 [00:24<00:00, 101.02it/s]
100%|██████████| 625/625 [00:03<00:00, 175.79it/s]


TRAIN: EPOCH 0: SSE: 1.7329068256378173, KL_LOSS: 0.04536705662759487, ACC: 0.4831164893992711   
TEST: EPOCH 0: SSE: 1.557160227203369, KL_LOSS: 0.04142424785494805, ACC: 0.521051988697052


100%|██████████| 2500/2500 [00:24<00:00, 100.61it/s]
100%|██████████| 625/625 [00:03<00:00, 176.78it/s]


TRAIN: EPOCH 1: SSE: 1.4094014980316163, KL_LOSS: 0.06349502699226141, ACC: 0.5593669870257377   
TEST: EPOCH 1: SSE: 1.2341641063690185, KL_LOSS: 0.08825779937505722, ACC: 0.6001389857292175


100%|██████████| 2500/2500 [00:24<00:00, 101.04it/s]
100%|██████████| 625/625 [00:03<00:00, 175.64it/s]


TRAIN: EPOCH 2: SSE: 1.1337827871084214, KL_LOSS: 0.1014793445289135, ACC: 0.6447377353191376   
TEST: EPOCH 2: SSE: 1.0637198895454407, KL_LOSS: 0.10796176235675811, ACC: 0.6651439858436584


100%|██████████| 2500/2500 [00:25<00:00, 99.16it/s] 
100%|██████████| 625/625 [00:03<00:00, 168.31it/s]


TRAIN: EPOCH 3: SSE: 1.0271055659294128, KL_LOSS: 0.10870277592837811, ACC: 0.6782552352428436   
TEST: EPOCH 3: SSE: 0.9945407009124756, KL_LOSS: 0.11257987473011016, ACC: 0.6909109842300415


100%|██████████| 2500/2500 [00:24<00:00, 101.28it/s]
100%|██████████| 625/625 [00:03<00:00, 157.83it/s]


TRAIN: EPOCH 4: SSE: 0.9678002001523972, KL_LOSS: 0.11267001200318337, ACC: 0.7019804843902588   
TEST: EPOCH 4: SSE: 0.9428338751792907, KL_LOSS: 0.11465626987218856, ACC: 0.7121459829330444


100%|██████████| 2500/2500 [00:24<00:00, 101.89it/s]
100%|██████████| 625/625 [00:04<00:00, 141.32it/s]


TRAIN: EPOCH 5: SSE: 0.9277705830335617, KL_LOSS: 0.1147374362140894, ACC: 0.7168684832811355   
TEST: EPOCH 5: SSE: 0.9123746046066284, KL_LOSS: 0.11447037304639816, ACC: 0.7232589825630188


100%|██████████| 2500/2500 [00:24<00:00, 103.61it/s]
100%|██████████| 625/625 [00:04<00:00, 132.67it/s]


TRAIN: EPOCH 6: SSE: 0.9005432815074921, KL_LOSS: 0.11568438049256802, ACC: 0.7257027338266373   
TEST: EPOCH 6: SSE: 0.8842336497306824, KL_LOSS: 0.11747664399147034, ACC: 0.7307519826889038


100%|██████████| 2500/2500 [00:23<00:00, 105.70it/s]
100%|██████████| 625/625 [00:04<00:00, 129.14it/s]


TRAIN: EPOCH 7: SSE: 0.8795268325090408, KL_LOSS: 0.11658605397641658, ACC: 0.7312957329273224   
TEST: EPOCH 7: SSE: 0.877216932106018, KL_LOSS: 0.11717942887544631, ACC: 0.7307409835815429


100%|██████████| 2500/2500 [00:23<00:00, 105.14it/s]
100%|██████████| 625/625 [00:04<00:00, 131.56it/s]


TRAIN: EPOCH 8: SSE: 0.8634593811511994, KL_LOSS: 0.11665506655275822, ACC: 0.7348562330722809   
TEST: EPOCH 8: SSE: 0.865338073348999, KL_LOSS: 0.11528468533754349, ACC: 0.7333449836730958


100%|██████████| 2500/2500 [00:24<00:00, 103.13it/s]
100%|██████████| 625/625 [00:04<00:00, 144.19it/s]


TRAIN: EPOCH 9: SSE: 0.8498242005825043, KL_LOSS: 0.11654270901381969, ACC: 0.7381854835748672   
TEST: EPOCH 9: SSE: 0.8427015717506409, KL_LOSS: 0.11693298007249832, ACC: 0.7413379835128784


100%|██████████| 2500/2500 [00:24<00:00, 102.38it/s]
100%|██████████| 625/625 [00:04<00:00, 156.09it/s]


TRAIN: EPOCH 10: SSE: 0.8389622572898865, KL_LOSS: 0.11616859417557716, ACC: 0.740850983452797   
TEST: EPOCH 10: SSE: 0.8308414995193482, KL_LOSS: 0.11499279592037201, ACC: 0.7434469840049743


100%|██████████| 2500/2500 [00:24<00:00, 102.00it/s]
100%|██████████| 625/625 [00:03<00:00, 170.34it/s]


TRAIN: EPOCH 11: SSE: 0.8283717918395996, KL_LOSS: 0.11563461984395981, ACC: 0.7435132335662842   
TEST: EPOCH 11: SSE: 0.8185758823394775, KL_LOSS: 0.11574002152681351, ACC: 0.7463339835166931


100%|██████████| 2500/2500 [00:24<00:00, 100.91it/s]
100%|██████████| 625/625 [00:03<00:00, 181.01it/s]


TRAIN: EPOCH 12: SSE: 0.8191111178874969, KL_LOSS: 0.11540149481594562, ACC: 0.7461329836845398   
TEST: EPOCH 12: SSE: 0.8209346473693848, KL_LOSS: 0.11565997811555863, ACC: 0.7466039834976196


100%|██████████| 2500/2500 [00:24<00:00, 100.77it/s]
100%|██████████| 625/625 [00:03<00:00, 174.80it/s]


TRAIN: EPOCH 13: SSE: 0.809559364748001, KL_LOSS: 0.11554014559686183, ACC: 0.7489302329063415   
TEST: EPOCH 13: SSE: 0.8056609112739563, KL_LOSS: 0.1154859933257103, ACC: 0.7498419825553894


100%|██████████| 2500/2500 [00:24<00:00, 100.17it/s]
100%|██████████| 625/625 [00:03<00:00, 171.06it/s]


TRAIN: EPOCH 14: SSE: 0.8001186698198318, KL_LOSS: 0.11573752527534961, ACC: 0.7516884828567505   
TEST: EPOCH 14: SSE: 0.7964272310256958, KL_LOSS: 0.11500392144918442, ACC: 0.7538719827651977


100%|██████████| 2500/2500 [00:24<00:00, 100.53it/s]
100%|██████████| 625/625 [00:03<00:00, 171.90it/s]


TRAIN: EPOCH 15: SSE: 0.7928438481330872, KL_LOSS: 0.11605127340853215, ACC: 0.7543254835367202   
TEST: EPOCH 15: SSE: 0.7926189618110657, KL_LOSS: 0.11547503355741501, ACC: 0.753768982887268


100%|██████████| 2500/2500 [00:24<00:00, 101.10it/s]
100%|██████████| 625/625 [00:03<00:00, 177.16it/s]


TRAIN: EPOCH 16: SSE: 0.784993405175209, KL_LOSS: 0.11626617548465729, ACC: 0.7566267332315445   
TEST: EPOCH 16: SSE: 0.7856118793487549, KL_LOSS: 0.11677026090621949, ACC: 0.7575159834861755


100%|██████████| 2500/2500 [00:24<00:00, 101.13it/s]
100%|██████████| 625/625 [00:03<00:00, 174.22it/s]


TRAIN: EPOCH 17: SSE: 0.7781660021305085, KL_LOSS: 0.11622671588361264, ACC: 0.758937483382225   
TEST: EPOCH 17: SSE: 0.7751104142189026, KL_LOSS: 0.11624498453140258, ACC: 0.7602289840698242


100%|██████████| 2500/2500 [00:25<00:00, 99.66it/s] 
100%|██████████| 625/625 [00:03<00:00, 172.68it/s]


TRAIN: EPOCH 18: SSE: 0.7715345689296722, KL_LOSS: 0.11651883058249951, ACC: 0.7612654831171036   
TEST: EPOCH 18: SSE: 0.7640383650779724, KL_LOSS: 0.11729955383539199, ACC: 0.7638679838180542


100%|██████████| 2500/2500 [00:25<00:00, 99.72it/s] 
100%|██████████| 625/625 [00:03<00:00, 172.21it/s]


TRAIN: EPOCH 19: SSE: 0.7640887642145157, KL_LOSS: 0.11662399659752845, ACC: 0.7640754834413529   
TEST: EPOCH 19: SSE: 0.7733195554733276, KL_LOSS: 0.11721762299537658, ACC: 0.7599279832839966


100%|██████████| 2500/2500 [00:25<00:00, 99.36it/s] 
100%|██████████| 625/625 [00:03<00:00, 162.90it/s]


TRAIN: EPOCH 20: SSE: 0.7573182398557663, KL_LOSS: 0.11681853067278862, ACC: 0.7663729828357696   
TEST: EPOCH 20: SSE: 0.7483379940032959, KL_LOSS: 0.11613980050086975, ACC: 0.7698289833068848


100%|██████████| 2500/2500 [00:24<00:00, 100.26it/s]
100%|██████████| 625/625 [00:04<00:00, 148.11it/s]


TRAIN: EPOCH 21: SSE: 0.7505861282348633, KL_LOSS: 0.11696781914234161, ACC: 0.768779733800888   
TEST: EPOCH 21: SSE: 0.7410781036376953, KL_LOSS: 0.11713458781242371, ACC: 0.7719579835891723


100%|██████████| 2500/2500 [00:24<00:00, 101.66it/s]
100%|██████████| 625/625 [00:04<00:00, 129.52it/s]


TRAIN: EPOCH 22: SSE: 0.7453479760885239, KL_LOSS: 0.11726087546944618, ACC: 0.7705187333106994   
TEST: EPOCH 22: SSE: 0.7653258853912354, KL_LOSS: 0.11591403906345367, ACC: 0.7640869829177857


100%|██████████| 2500/2500 [00:24<00:00, 103.44it/s]
100%|██████████| 625/625 [00:04<00:00, 127.42it/s]


TRAIN: EPOCH 23: SSE: 0.7399511790275574, KL_LOSS: 0.11734315926730633, ACC: 0.7720497329950332   
TEST: EPOCH 23: SSE: 0.7363853050231933, KL_LOSS: 0.11707089885473251, ACC: 0.773295982170105


100%|██████████| 2500/2500 [00:24<00:00, 102.04it/s]
100%|██████████| 625/625 [00:04<00:00, 141.11it/s]


TRAIN: EPOCH 24: SSE: 0.7351134815931321, KL_LOSS: 0.11746311227679253, ACC: 0.7734722329616547   
TEST: EPOCH 24: SSE: 0.7369793982505798, KL_LOSS: 0.11887115293741227, ACC: 0.7742539832115173


100%|██████████| 2500/2500 [00:24<00:00, 100.44it/s]
100%|██████████| 625/625 [00:03<00:00, 163.72it/s]


TRAIN: EPOCH 25: SSE: 0.7304807121992111, KL_LOSS: 0.11754181440770627, ACC: 0.7748999833106994   
TEST: EPOCH 25: SSE: 0.7283621737480164, KL_LOSS: 0.11836755788326263, ACC: 0.7761179832458496


100%|██████████| 2500/2500 [00:25<00:00, 99.70it/s]
100%|██████████| 625/625 [00:03<00:00, 177.05it/s]


TRAIN: EPOCH 26: SSE: 0.7262233528375626, KL_LOSS: 0.11760061578750611, ACC: 0.7759209835767746   
TEST: EPOCH 26: SSE: 0.7315339291572571, KL_LOSS: 0.11707573872804641, ACC: 0.7746639841079712


100%|██████████| 2500/2500 [00:25<00:00, 99.43it/s] 
100%|██████████| 625/625 [00:03<00:00, 174.20it/s]


TRAIN: EPOCH 27: SSE: 0.7223341166496277, KL_LOSS: 0.11766115251779556, ACC: 0.777122232913971   
TEST: EPOCH 27: SSE: 0.7201931910514832, KL_LOSS: 0.11814857335090637, ACC: 0.7786949831008911


100%|██████████| 2500/2500 [00:25<00:00, 99.73it/s] 
100%|██████████| 625/625 [00:03<00:00, 176.00it/s]


TRAIN: EPOCH 28: SSE: 0.7188558028697968, KL_LOSS: 0.11778261029422284, ACC: 0.778017733001709   
TEST: EPOCH 28: SSE: 0.715818482875824, KL_LOSS: 0.11828742027282715, ACC: 0.7787429831504822


100%|██████████| 2500/2500 [00:24<00:00, 100.13it/s]
100%|██████████| 625/625 [00:03<00:00, 171.68it/s]


TRAIN: EPOCH 29: SSE: 0.7149104240655899, KL_LOSS: 0.11786732676923276, ACC: 0.7792554833412171   
TEST: EPOCH 29: SSE: 0.7130943150520325, KL_LOSS: 0.11772075281143188, ACC: 0.779673982334137


100%|██████████| 2500/2500 [00:25<00:00, 99.04it/s] 
100%|██████████| 625/625 [00:03<00:00, 173.54it/s]


TRAIN: EPOCH 30: SSE: 0.7117750764131546, KL_LOSS: 0.11791992458105087, ACC: 0.7802517327785492   
TEST: EPOCH 30: SSE: 0.7149127103805542, KL_LOSS: 0.11778641177415848, ACC: 0.781265982913971


100%|██████████| 2500/2500 [00:25<00:00, 99.49it/s] 
100%|██████████| 625/625 [00:03<00:00, 171.45it/s]


TRAIN: EPOCH 31: SSE: 0.7090681595087052, KL_LOSS: 0.11781352780461311, ACC: 0.7813659826993943   
TEST: EPOCH 31: SSE: 0.7101480493545532, KL_LOSS: 0.1179412888765335, ACC: 0.7803739847183228


100%|██████████| 2500/2500 [00:25<00:00, 97.83it/s] 
100%|██████████| 625/625 [00:03<00:00, 166.53it/s]


TRAIN: EPOCH 32: SSE: 0.705400061249733, KL_LOSS: 0.1178754622220993, ACC: 0.7825659831762314   
TEST: EPOCH 32: SSE: 0.7011581179618835, KL_LOSS: 0.11812040815353393, ACC: 0.7858509818077087


100%|██████████| 2500/2500 [00:25<00:00, 99.46it/s] 
100%|██████████| 625/625 [00:04<00:00, 143.46it/s]


TRAIN: EPOCH 33: SSE: 0.7029669037342071, KL_LOSS: 0.11788761659264564, ACC: 0.7835994825601578   
TEST: EPOCH 33: SSE: 0.7216844084739685, KL_LOSS: 0.11783850011825561, ACC: 0.7772479818344116


100%|██████████| 2500/2500 [00:25<00:00, 99.12it/s] 
100%|██████████| 625/625 [00:05<00:00, 124.23it/s]


TRAIN: EPOCH 34: SSE: 0.700458885216713, KL_LOSS: 0.11777619373500348, ACC: 0.7845777321338654   
TEST: EPOCH 34: SSE: 0.7021063164710999, KL_LOSS: 0.11861642158031463, ACC: 0.7857189821243287


100%|██████████| 2500/2500 [00:24<00:00, 100.02it/s]
100%|██████████| 625/625 [00:04<00:00, 137.85it/s]


TRAIN: EPOCH 35: SSE: 0.6974419225931168, KL_LOSS: 0.11764722608029843, ACC: 0.7855197324752807   
TEST: EPOCH 35: SSE: 0.6914135228157043, KL_LOSS: 0.11753517122268677, ACC: 0.7886919817924499


100%|██████████| 2500/2500 [00:25<00:00, 97.68it/s]
100%|██████████| 625/625 [00:03<00:00, 168.31it/s]


TRAIN: EPOCH 36: SSE: 0.6950045680999756, KL_LOSS: 0.11769795615971089, ACC: 0.7864057324647904   
TEST: EPOCH 36: SSE: 0.6905817556381225, KL_LOSS: 0.11432704916000366, ACC: 0.7883109823226929


100%|██████████| 2500/2500 [00:25<00:00, 97.35it/s] 
100%|██████████| 625/625 [00:03<00:00, 176.88it/s]


TRAIN: EPOCH 37: SSE: 0.6932112695217133, KL_LOSS: 0.1176265963613987, ACC: 0.7867784825801849   
TEST: EPOCH 37: SSE: 0.6883880308151246, KL_LOSS: 0.11847392779588699, ACC: 0.7885179825782775


100%|██████████| 2500/2500 [00:25<00:00, 98.05it/s] 
100%|██████████| 625/625 [00:03<00:00, 171.27it/s]


TRAIN: EPOCH 38: SSE: 0.6905246349811553, KL_LOSS: 0.1173207237124443, ACC: 0.7874127329826355   
TEST: EPOCH 38: SSE: 0.6956292452812195, KL_LOSS: 0.11763849881887437, ACC: 0.7854829830169677


100%|██████████| 2500/2500 [00:25<00:00, 97.21it/s] 
100%|██████████| 625/625 [00:03<00:00, 171.67it/s]


TRAIN: EPOCH 39: SSE: 0.687369520497322, KL_LOSS: 0.11670451721251011, ACC: 0.7884322321891785   
TEST: EPOCH 39: SSE: 0.6874946796417236, KL_LOSS: 0.11712166907787323, ACC: 0.7885299825668335


100%|██████████| 2500/2500 [00:25<00:00, 97.51it/s] 
100%|██████████| 625/625 [00:03<00:00, 172.03it/s]


TRAIN: EPOCH 40: SSE: 0.68445865380764, KL_LOSS: 0.1162454935014248, ACC: 0.7894247323036194   
TEST: EPOCH 40: SSE: 0.6912456891059876, KL_LOSS: 0.11719298497438431, ACC: 0.7861189816474915


100%|██████████| 2500/2500 [00:25<00:00, 97.36it/s] 
100%|██████████| 625/625 [00:03<00:00, 160.67it/s]


TRAIN: EPOCH 41: SSE: 0.6819502267599106, KL_LOSS: 0.1162149561226368, ACC: 0.7902917315244675   
TEST: EPOCH 41: SSE: 0.6855641693115234, KL_LOSS: 0.1155694287776947, ACC: 0.7910699822425842


100%|██████████| 2500/2500 [00:25<00:00, 98.90it/s] 
100%|██████████| 625/625 [00:04<00:00, 138.21it/s]


TRAIN: EPOCH 42: SSE: 0.6791555242776871, KL_LOSS: 0.11601162765920162, ACC: 0.7912909823656082   
TEST: EPOCH 42: SSE: 0.6855878747940064, KL_LOSS: 0.1160145749092102, ACC: 0.7895229809761047


100%|██████████| 2500/2500 [00:24<00:00, 101.30it/s]
100%|██████████| 625/625 [00:05<00:00, 123.60it/s]


TRAIN: EPOCH 43: SSE: 0.6763143486022949, KL_LOSS: 0.11606789807379246, ACC: 0.7919947319984436   
TEST: EPOCH 43: SSE: 0.6799611323356628, KL_LOSS: 0.11558189492225647, ACC: 0.7915449821472168


100%|██████████| 2500/2500 [00:25<00:00, 99.26it/s]
100%|██████████| 625/625 [00:04<00:00, 145.40it/s]


TRAIN: EPOCH 44: SSE: 0.6734736032485962, KL_LOSS: 0.11600300777852535, ACC: 0.7927747322320938   
TEST: EPOCH 44: SSE: 0.6698344612121582, KL_LOSS: 0.11584136279821396, ACC: 0.7928689827919007


100%|██████████| 2500/2500 [00:25<00:00, 97.89it/s]
100%|██████████| 625/625 [00:03<00:00, 168.53it/s]


TRAIN: EPOCH 45: SSE: 0.6702210710763932, KL_LOSS: 0.11599781037569046, ACC: 0.79377323179245   
TEST: EPOCH 45: SSE: 0.6655724457740784, KL_LOSS: 0.11710775783061982, ACC: 0.7952689826965332


100%|██████████| 2500/2500 [00:25<00:00, 97.32it/s] 
100%|██████████| 625/625 [00:03<00:00, 168.62it/s]


TRAIN: EPOCH 46: SSE: 0.6669195628881455, KL_LOSS: 0.11598269351422787, ACC: 0.7945127320289612   
TEST: EPOCH 46: SSE: 0.66668550863266, KL_LOSS: 0.11548554999828338, ACC: 0.7947139830589295


100%|██████████| 2500/2500 [00:25<00:00, 97.59it/s] 
100%|██████████| 625/625 [00:03<00:00, 173.16it/s]


TRAIN: EPOCH 47: SSE: 0.6638757270812988, KL_LOSS: 0.11609259955883026, ACC: 0.7952609818458557   
TEST: EPOCH 47: SSE: 0.6572712460517883, KL_LOSS: 0.11660121195316314, ACC: 0.7963539819717407


100%|██████████| 2500/2500 [00:25<00:00, 99.09it/s] 
100%|██████████| 625/625 [00:03<00:00, 172.11it/s]


TRAIN: EPOCH 48: SSE: 0.6615822263002396, KL_LOSS: 0.11612231781482696, ACC: 0.7958812318325043   
TEST: EPOCH 48: SSE: 0.6594481927871704, KL_LOSS: 0.1157727690577507, ACC: 0.797081981754303


100%|██████████| 2500/2500 [00:25<00:00, 97.99it/s] 
100%|██████████| 625/625 [00:03<00:00, 176.87it/s]


TRAIN: EPOCH 49: SSE: 0.6587161833763122, KL_LOSS: 0.1162397931009531, ACC: 0.7965359816074371   
TEST: EPOCH 49: SSE: 0.6488492684364319, KL_LOSS: 0.1169176399588585, ACC: 0.8010229821205139


100%|██████████| 2500/2500 [00:25<00:00, 98.33it/s] 
100%|██████████| 625/625 [00:03<00:00, 173.16it/s]


TRAIN: EPOCH 50: SSE: 0.6565090020179749, KL_LOSS: 0.11608487208485603, ACC: 0.7971004817247391   
TEST: EPOCH 50: SSE: 0.6576165038108825, KL_LOSS: 0.11661431782245636, ACC: 0.7969069820404052


100%|██████████| 2500/2500 [00:25<00:00, 97.78it/s] 
100%|██████████| 625/625 [00:04<00:00, 144.61it/s]


TRAIN: EPOCH 51: SSE: 0.6543844082832336, KL_LOSS: 0.11621552296578884, ACC: 0.7975474819183349   
TEST: EPOCH 51: SSE: 0.6465545546531677, KL_LOSS: 0.11636469593048096, ACC: 0.8009359808921814


100%|██████████| 2500/2500 [00:24<00:00, 100.28it/s]
100%|██████████| 625/625 [00:04<00:00, 126.54it/s]


TRAIN: EPOCH 52: SSE: 0.652092325758934, KL_LOSS: 0.11626948049366474, ACC: 0.7981864820241928   
TEST: EPOCH 52: SSE: 0.6454785597801208, KL_LOSS: 0.11596537050008773, ACC: 0.8006229819297791


100%|██████████| 2500/2500 [00:24<00:00, 101.37it/s]
100%|██████████| 625/625 [00:04<00:00, 131.09it/s]


TRAIN: EPOCH 53: SSE: 0.6499115877151489, KL_LOSS: 0.11637852600216865, ACC: 0.7987509816408157   
TEST: EPOCH 53: SSE: 0.6534251608848571, KL_LOSS: 0.11615170118808746, ACC: 0.7976119820594788


100%|██████████| 2500/2500 [00:25<00:00, 98.53it/s]
100%|██████████| 625/625 [00:03<00:00, 159.57it/s]


TRAIN: EPOCH 54: SSE: 0.6482025257349014, KL_LOSS: 0.1163394667237997, ACC: 0.7992597316026687   
TEST: EPOCH 54: SSE: 0.6534238882064819, KL_LOSS: 0.11694780459403992, ACC: 0.798764981174469


100%|██████████| 2500/2500 [00:25<00:00, 96.77it/s] 
100%|██████████| 625/625 [00:03<00:00, 171.96it/s]


TRAIN: EPOCH 55: SSE: 0.6460598526954651, KL_LOSS: 0.11638799896240234, ACC: 0.7996987314462661   
TEST: EPOCH 55: SSE: 0.6499976044654846, KL_LOSS: 0.11681795052289963, ACC: 0.7986459815979003


100%|██████████| 2500/2500 [00:25<00:00, 97.43it/s] 
100%|██████████| 625/625 [00:03<00:00, 172.39it/s]


TRAIN: EPOCH 56: SSE: 0.6433130456209183, KL_LOSS: 0.11639911673367023, ACC: 0.8005082313776016   
TEST: EPOCH 56: SSE: 0.6584482955932617, KL_LOSS: 0.1158792232632637, ACC: 0.7962919815063476


100%|██████████| 2500/2500 [00:26<00:00, 95.45it/s] 
100%|██████████| 625/625 [00:03<00:00, 173.02it/s]


TRAIN: EPOCH 57: SSE: 0.6425043110966683, KL_LOSS: 0.11658024950325489, ACC: 0.8007574810504914   
TEST: EPOCH 57: SSE: 0.6428528070449829, KL_LOSS: 0.11720048576593399, ACC: 0.8013009818077087


100%|██████████| 2500/2500 [00:25<00:00, 98.89it/s] 
100%|██████████| 625/625 [00:03<00:00, 171.36it/s]


TRAIN: EPOCH 58: SSE: 0.6407325593948364, KL_LOSS: 0.1165599247276783, ACC: 0.8013954815149307   
TEST: EPOCH 58: SSE: 0.6439101280212403, KL_LOSS: 0.11652633564472198, ACC: 0.7992179812431336


100%|██████████| 2500/2500 [00:25<00:00, 99.70it/s] 
100%|██████████| 625/625 [00:03<00:00, 167.73it/s]


TRAIN: EPOCH 59: SSE: 0.6386386773109436, KL_LOSS: 0.11661316037178039, ACC: 0.8018594814777374   
TEST: EPOCH 59: SSE: 0.6435345254898072, KL_LOSS: 0.11863545680046081, ACC: 0.8003129820823669


100%|██████████| 2500/2500 [00:24<00:00, 100.05it/s]
100%|██████████| 625/625 [00:04<00:00, 150.20it/s]


TRAIN: EPOCH 60: SSE: 0.6376457797884941, KL_LOSS: 0.1165140224993229, ACC: 0.8020747310638427   
TEST: EPOCH 60: SSE: 0.6354898194313049, KL_LOSS: 0.11665266495943069, ACC: 0.8031039813995361


100%|██████████| 2500/2500 [00:24<00:00, 101.70it/s]
100%|██████████| 625/625 [00:04<00:00, 135.43it/s]


TRAIN: EPOCH 61: SSE: 0.6351735815048217, KL_LOSS: 0.11652128240764141, ACC: 0.802769731092453   
TEST: EPOCH 61: SSE: 0.641186383152008, KL_LOSS: 0.11576834622621536, ACC: 0.8011249814033509


100%|██████████| 2500/2500 [00:24<00:00, 102.55it/s]
100%|██████████| 625/625 [00:04<00:00, 125.54it/s]


TRAIN: EPOCH 62: SSE: 0.634415907716751, KL_LOSS: 0.11652661941349507, ACC: 0.8031484818458557   
TEST: EPOCH 62: SSE: 0.6321776560783386, KL_LOSS: 0.11954442871809005, ACC: 0.8038139818191529


100%|██████████| 2500/2500 [00:24<00:00, 102.50it/s]
100%|██████████| 625/625 [00:04<00:00, 138.04it/s]


TRAIN: EPOCH 63: SSE: 0.6317395915269852, KL_LOSS: 0.11664162010848522, ACC: 0.8038877310752869   
TEST: EPOCH 63: SSE: 0.6229192549705506, KL_LOSS: 0.11597328445911408, ACC: 0.8074939805984497


100%|██████████| 2500/2500 [00:24<00:00, 100.04it/s]
100%|██████████| 625/625 [00:03<00:00, 160.10it/s]


TRAIN: EPOCH 64: SSE: 0.6318272531151772, KL_LOSS: 0.11666317461431026, ACC: 0.8038179811239242   
TEST: EPOCH 64: SSE: 0.6290526224136352, KL_LOSS: 0.11722398697137833, ACC: 0.8048669819831848


100%|██████████| 2500/2500 [00:25<00:00, 98.80it/s]
100%|██████████| 625/625 [00:03<00:00, 172.21it/s]


TRAIN: EPOCH 65: SSE: 0.6297280779361725, KL_LOSS: 0.11659754391312599, ACC: 0.8044474811315536   
TEST: EPOCH 65: SSE: 0.6235644764900208, KL_LOSS: 0.11519989544153214, ACC: 0.8061109815597535


100%|██████████| 2500/2500 [00:24<00:00, 100.28it/s]
100%|██████████| 625/625 [00:03<00:00, 174.81it/s]


TRAIN: EPOCH 66: SSE: 0.6283803198575973, KL_LOSS: 0.11672065583467484, ACC: 0.8047724809408188   
TEST: EPOCH 66: SSE: 0.6322458711624146, KL_LOSS: 0.11742766519784928, ACC: 0.8039669821739197


100%|██████████| 2500/2500 [00:25<00:00, 99.62it/s] 
100%|██████████| 625/625 [00:03<00:00, 174.99it/s]


TRAIN: EPOCH 67: SSE: 0.6275057754397392, KL_LOSS: 0.1167595045953989, ACC: 0.8050599811792374   
TEST: EPOCH 67: SSE: 0.6257556488037109, KL_LOSS: 0.11570579468011856, ACC: 0.8056209803581238


100%|██████████| 2500/2500 [00:25<00:00, 98.11it/s] 
100%|██████████| 625/625 [00:03<00:00, 176.05it/s]


TRAIN: EPOCH 68: SSE: 0.6252704157471657, KL_LOSS: 0.11675542843937874, ACC: 0.8058479810237884   
TEST: EPOCH 68: SSE: 0.6284410574913025, KL_LOSS: 0.11789172103404999, ACC: 0.8052169814109802


100%|██████████| 2500/2500 [00:25<00:00, 99.51it/s] 
100%|██████████| 625/625 [00:03<00:00, 171.36it/s]


TRAIN: EPOCH 69: SSE: 0.6249612688779831, KL_LOSS: 0.11692740949094295, ACC: 0.8059657313346863   
TEST: EPOCH 69: SSE: 0.6221507454872132, KL_LOSS: 0.11524372462034226, ACC: 0.805762981414795


100%|██████████| 2500/2500 [00:25<00:00, 99.39it/s] 
100%|██████████| 625/625 [00:03<00:00, 177.90it/s]


TRAIN: EPOCH 70: SSE: 0.6238126917481422, KL_LOSS: 0.11686963433921337, ACC: 0.8060604811668396   
TEST: EPOCH 70: SSE: 0.6185556183815002, KL_LOSS: 0.11795735193490982, ACC: 0.808048980808258


100%|██████████| 2500/2500 [00:25<00:00, 99.93it/s] 
100%|██████████| 625/625 [00:03<00:00, 173.78it/s]


TRAIN: EPOCH 71: SSE: 0.6221719986200333, KL_LOSS: 0.11688678766191006, ACC: 0.8063109813690186   
TEST: EPOCH 71: SSE: 0.6121811844348908, KL_LOSS: 0.11749732018709183, ACC: 0.8094899809837341


100%|██████████| 2500/2500 [00:25<00:00, 99.57it/s] 
100%|██████████| 625/625 [00:03<00:00, 159.64it/s]


TRAIN: EPOCH 72: SSE: 0.6216951677203179, KL_LOSS: 0.11691436794698239, ACC: 0.8065844810247421   
TEST: EPOCH 72: SSE: 0.6274884305000306, KL_LOSS: 0.11639401502609253, ACC: 0.8056129816055297


100%|██████████| 2500/2500 [00:24<00:00, 100.53it/s]
100%|██████████| 625/625 [00:04<00:00, 142.01it/s]


TRAIN: EPOCH 73: SSE: 0.6200710716962814, KL_LOSS: 0.11687725624144077, ACC: 0.8069804811239243   
TEST: EPOCH 73: SSE: 0.6142493158817292, KL_LOSS: 0.11536579501628876, ACC: 0.8095229804039001


100%|██████████| 2500/2500 [00:24<00:00, 101.69it/s]
100%|██████████| 625/625 [00:04<00:00, 127.38it/s]


TRAIN: EPOCH 74: SSE: 0.6190272648215294, KL_LOSS: 0.11676805248260498, ACC: 0.8073224805355071   
TEST: EPOCH 74: SSE: 0.6144691015243531, KL_LOSS: 0.11819208245277404, ACC: 0.8084709809303283


100%|██████████| 2500/2500 [00:24<00:00, 102.73it/s]
100%|██████████| 625/625 [00:04<00:00, 130.88it/s]


TRAIN: EPOCH 75: SSE: 0.6172735946536064, KL_LOSS: 0.11691615612506867, ACC: 0.8077134812831879   
TEST: EPOCH 75: SSE: 0.6148780479431152, KL_LOSS: 0.11585054163932801, ACC: 0.8082759819984436


100%|██████████| 2500/2500 [00:25<00:00, 98.90it/s]
100%|██████████| 625/625 [00:04<00:00, 150.93it/s]


TRAIN: EPOCH 76: SSE: 0.6163380808472634, KL_LOSS: 0.11706533734202385, ACC: 0.808034480714798   
TEST: EPOCH 76: SSE: 0.6122622131824493, KL_LOSS: 0.11788603727817536, ACC: 0.8094359807014465


100%|██████████| 2500/2500 [00:25<00:00, 97.03it/s] 
100%|██████████| 625/625 [00:03<00:00, 170.93it/s]


TRAIN: EPOCH 77: SSE: 0.6149301977276802, KL_LOSS: 0.1168828119635582, ACC: 0.8083187307119369   
TEST: EPOCH 77: SSE: 0.6148911553382873, KL_LOSS: 0.11668104555606842, ACC: 0.8078569806098937


100%|██████████| 2500/2500 [00:25<00:00, 97.06it/s] 
100%|██████████| 625/625 [00:03<00:00, 167.60it/s]


TRAIN: EPOCH 78: SSE: 0.6138522547960281, KL_LOSS: 0.11701571427285672, ACC: 0.8087574807882308   
TEST: EPOCH 78: SSE: 0.6227143013954163, KL_LOSS: 0.11641445434093475, ACC: 0.805732980632782


100%|██████████| 2500/2500 [00:25<00:00, 97.45it/s] 
100%|██████████| 625/625 [00:03<00:00, 171.07it/s]


TRAIN: EPOCH 79: SSE: 0.6131038857102394, KL_LOSS: 0.11696640854179859, ACC: 0.8088582318067551   
TEST: EPOCH 79: SSE: 0.6343743365287781, KL_LOSS: 0.11752509219646454, ACC: 0.8013389821052551


100%|██████████| 2500/2500 [00:25<00:00, 96.67it/s] 
100%|██████████| 625/625 [00:03<00:00, 171.24it/s]


TRAIN: EPOCH 80: SSE: 0.6122819123864174, KL_LOSS: 0.11703064700067044, ACC: 0.8092429814338684   
TEST: EPOCH 80: SSE: 0.607630662727356, KL_LOSS: 0.11594589298963547, ACC: 0.8113719824790955


100%|██████████| 2500/2500 [00:25<00:00, 98.08it/s]
100%|██████████| 625/625 [00:03<00:00, 163.99it/s]


TRAIN: EPOCH 81: SSE: 0.6115778997540474, KL_LOSS: 0.11711679244935512, ACC: 0.8092534816026687   
TEST: EPOCH 81: SSE: 0.6113002650260926, KL_LOSS: 0.11540246758460998, ACC: 0.809624979877472


100%|██████████| 2500/2500 [00:25<00:00, 97.91it/s] 
100%|██████████| 625/625 [00:04<00:00, 141.17it/s]


TRAIN: EPOCH 82: SSE: 0.6100941003918647, KL_LOSS: 0.11709222916662693, ACC: 0.8098059810161591   
TEST: EPOCH 82: SSE: 0.602946471452713, KL_LOSS: 0.11692820888757706, ACC: 0.812160980129242


100%|██████████| 2500/2500 [00:25<00:00, 99.42it/s] 
100%|██████████| 625/625 [00:04<00:00, 125.49it/s]


TRAIN: EPOCH 83: SSE: 0.608991111433506, KL_LOSS: 0.11703199808299541, ACC: 0.8101192307949067   
TEST: EPOCH 83: SSE: 0.6213941442489624, KL_LOSS: 0.11623374491930008, ACC: 0.8063319806098938


100%|██████████| 2500/2500 [00:25<00:00, 99.00it/s]
100%|██████████| 625/625 [00:04<00:00, 142.19it/s]


TRAIN: EPOCH 84: SSE: 0.6081390682816505, KL_LOSS: 0.11706687515377999, ACC: 0.8103967316627503   
TEST: EPOCH 84: SSE: 0.6028573438644409, KL_LOSS: 0.11659682074785233, ACC: 0.811552981376648


100%|██████████| 2500/2500 [00:25<00:00, 96.94it/s]
100%|██████████| 625/625 [00:03<00:00, 168.02it/s]


TRAIN: EPOCH 85: SSE: 0.6072362722873688, KL_LOSS: 0.11708347898125648, ACC: 0.8105204813480377   
TEST: EPOCH 85: SSE: 0.6038442438125611, KL_LOSS: 0.11714227417707443, ACC: 0.8120599813461303


100%|██████████| 2500/2500 [00:25<00:00, 97.29it/s] 
100%|██████████| 625/625 [00:03<00:00, 169.28it/s]


TRAIN: EPOCH 86: SSE: 0.6070812373876572, KL_LOSS: 0.11717338350713254, ACC: 0.8106184817075729   
TEST: EPOCH 86: SSE: 0.6046298551082611, KL_LOSS: 0.11679863662719726, ACC: 0.8106169815063476


100%|██████████| 2500/2500 [00:25<00:00, 96.23it/s] 
100%|██████████| 625/625 [00:03<00:00, 168.16it/s]


TRAIN: EPOCH 87: SSE: 0.6059851869821549, KL_LOSS: 0.11703428862988949, ACC: 0.8110009813785553   
TEST: EPOCH 87: SSE: 0.6013735325336457, KL_LOSS: 0.11676893094778061, ACC: 0.8117109815597534


100%|██████████| 2500/2500 [00:25<00:00, 98.53it/s] 
100%|██████████| 625/625 [00:03<00:00, 169.71it/s]


TRAIN: EPOCH 88: SSE: 0.6042182997345924, KL_LOSS: 0.11720406223237514, ACC: 0.8116437311649323   
TEST: EPOCH 88: SSE: 0.6048448748588562, KL_LOSS: 0.11865500626564025, ACC: 0.8107849817276


100%|██████████| 2500/2500 [00:25<00:00, 97.11it/s] 
100%|██████████| 625/625 [00:03<00:00, 167.70it/s]


TRAIN: EPOCH 89: SSE: 0.6043081852078438, KL_LOSS: 0.11716388656198978, ACC: 0.8115697306632995   
TEST: EPOCH 89: SSE: 0.5989034048080444, KL_LOSS: 0.11647504591941833, ACC: 0.8138629817962646


100%|██████████| 2500/2500 [00:26<00:00, 95.83it/s] 
100%|██████████| 625/625 [00:04<00:00, 141.82it/s]


TRAIN: EPOCH 90: SSE: 0.6037131248474121, KL_LOSS: 0.1171868954628706, ACC: 0.8118107311487198   
TEST: EPOCH 90: SSE: 0.6204586577415466, KL_LOSS: 0.1184281340122223, ACC: 0.8064729804039001


100%|██████████| 2500/2500 [00:25<00:00, 99.40it/s] 
100%|██████████| 625/625 [00:05<00:00, 124.13it/s]


TRAIN: EPOCH 91: SSE: 0.6025799540996551, KL_LOSS: 0.11712534169852734, ACC: 0.8120907314300537   
TEST: EPOCH 91: SSE: 0.6001993793487549, KL_LOSS: 0.11679884657859803, ACC: 0.8121789807319642


100%|██████████| 2500/2500 [00:25<00:00, 98.65it/s]
100%|██████████| 625/625 [00:04<00:00, 141.61it/s]


TRAIN: EPOCH 92: SSE: 0.6024697159171104, KL_LOSS: 0.1171352532029152, ACC: 0.8121949813365936   
TEST: EPOCH 92: SSE: 0.6157246738433838, KL_LOSS: 0.11771498396396637, ACC: 0.8085089817047119


100%|██████████| 2500/2500 [00:25<00:00, 97.46it/s]
100%|██████████| 625/625 [00:03<00:00, 169.50it/s]


TRAIN: EPOCH 93: SSE: 0.6016265895009041, KL_LOSS: 0.11719986858963967, ACC: 0.8125704812526703   
TEST: EPOCH 93: SSE: 0.6018652724266053, KL_LOSS: 0.11647849552631379, ACC: 0.8121919811248779


100%|██████████| 2500/2500 [00:25<00:00, 97.79it/s] 
100%|██████████| 625/625 [00:03<00:00, 172.15it/s]


TRAIN: EPOCH 94: SSE: 0.6008564864873887, KL_LOSS: 0.11709973154962063, ACC: 0.812729981136322   
TEST: EPOCH 94: SSE: 0.612155260181427, KL_LOSS: 0.11695669375658035, ACC: 0.8084549820899963


100%|██████████| 2500/2500 [00:25<00:00, 98.13it/s] 
100%|██████████| 625/625 [00:03<00:00, 173.53it/s]


TRAIN: EPOCH 95: SSE: 0.599756815636158, KL_LOSS: 0.11720977382957935, ACC: 0.8128664811611176   
TEST: EPOCH 95: SSE: 0.5991219827651978, KL_LOSS: 0.11817963800430298, ACC: 0.8126069814682007


100%|██████████| 2500/2500 [00:25<00:00, 97.06it/s]
100%|██████████| 625/625 [00:03<00:00, 165.96it/s]


TRAIN: EPOCH 96: SSE: 0.5988728234171867, KL_LOSS: 0.1172390012025833, ACC: 0.8132567316770554   
TEST: EPOCH 96: SSE: 0.6046705497741699, KL_LOSS: 0.11901417447328567, ACC: 0.8112999816894532


100%|██████████| 2500/2500 [00:25<00:00, 98.82it/s] 
100%|██████████| 625/625 [00:03<00:00, 173.13it/s]


TRAIN: EPOCH 97: SSE: 0.5986001903891564, KL_LOSS: 0.11718855932354927, ACC: 0.813389230465889   
TEST: EPOCH 97: SSE: 0.6020316452026367, KL_LOSS: 0.1171969602227211, ACC: 0.8120879807472229


100%|██████████| 2500/2500 [00:25<00:00, 97.01it/s] 
100%|██████████| 625/625 [00:04<00:00, 154.84it/s]


TRAIN: EPOCH 98: SSE: 0.5966527721047401, KL_LOSS: 0.11717203666865826, ACC: 0.8136819811344147   
TEST: EPOCH 98: SSE: 0.590332763671875, KL_LOSS: 0.11704137345552444, ACC: 0.8149979801177979


100%|██████████| 2500/2500 [00:25<00:00, 98.23it/s] 
100%|██████████| 625/625 [00:04<00:00, 130.89it/s]

TRAIN: EPOCH 99: SSE: 0.5967556195020676, KL_LOSS: 0.11728634235560895, ACC: 0.8137077312707901   
TEST: EPOCH 99: SSE: 0.6018684244155884, KL_LOSS: 0.11682128728628159, ACC: 0.8115099796295167





In [39]:
x_hat[0].argmax(-1)

tensor([67, 99, 49, 99, 99, 40, 67, 41, 99, 99, 40, 45, 99, 49, 99, 99, 40, 45,
        99, 50, 99, 99, 99, 40, 99, 41, 99, 99, 99, 41, 99, 99, 99, 41, 40, 50,
        41, 99, 49,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0],
       device='cuda:0')

In [41]:
x[0]

tensor([ 67,  99,  49,  99,  99,  40,  67,  41, 110,  99,  40,  78,  99,  50,
        110,  99,  40,  45,  99,  51,  99,  99,  99,  40,  78,  41,  99,  99,
         51,  41, 110,  91, 110,  72,  93,  50,  41, 110,  49,   0,   0,   0,
          0,   0,   0,   0,   0,   0,   0,   0], device='cuda:0')