In [1]:
import sys
sys.path.append('..')
from data.data_reader import *
from models.VAEs import *

In [2]:
import tqdm
import scanpy as sc
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.figure_factory as ff

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [3]:
# download_file('https://plus.figshare.com/ndownloader/files/35775512','35775512.h5ad')
# adata_orig = sc.read_h5ad("35775512.h5ad")
download_file('https://plus.figshare.com/ndownloader/files/35773217','35773217.h5ad')
adata_orig = sc.read_h5ad("35773217.h5ad")
adata_orig.X[adata_orig.X == float("inf")]=0

File downloaded successfully to 35773217.h5ad


In [4]:
adata_orig.obs['gene_name']=list(pd.Series(adata_orig.obs.index).apply(lambda x:x.split("_")[1]))
adata_orig.obs['id']=range(adata_orig.obs.shape[0])

In [5]:
def cosine_similarity(A):
  AAt=np.matmul(A,A.transpose())
  n_A=np.sqrt((A**2).sum(axis=1)).reshape(-1,1)
  n_A=np.matmul(n_A,n_A.transpose())
  return AAt/(n_A)

### Let's first train a VAE model

In [6]:
class X_dataset(Dataset):
    def __init__(self,data):
        self.data=data
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx):
        return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}


In [7]:
dataset=X_dataset(adata_orig)
train_loader=DataLoader(dataset,batch_size=32,shuffle=True)

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
autoencoder=VariationalAutoencoder(dataset[0]['x'].shape[0],32,1e-11,1024,device)
opt = torch.optim.Adam(autoencoder.parameters(),lr=0.001)
loss_fn=torch.nn.MSELoss()

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}


In [9]:
train(autoencoder,opt,loss_fn,train_loader,None,device,500)

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 352/352 [00:03<00:00, 101.95it/s]


TRAIN: EPOCH 0: MSE: 0.018752270338485356, KL_LOSS: 3.789141329742356e-08


100%|██████████| 352/352 [00:03<00:00, 109.41it/s]


TRAIN: EPOCH 1: MSE: 0.013030178657076745, KL_LOSS: 4.750603878460198e-08


100%|██████████| 352/352 [00:03<00:00, 108.29it/s]


TRAIN: EPOCH 2: MSE: 0.012605211073109373, KL_LOSS: 4.606061499888948e-08


100%|██████████| 352/352 [00:03<00:00, 108.94it/s]


TRAIN: EPOCH 3: MSE: 0.01237584383819591, KL_LOSS: 4.549681112661377e-08


100%|██████████| 352/352 [00:03<00:00, 105.67it/s]


TRAIN: EPOCH 4: MSE: 0.012116443311573345, KL_LOSS: 4.450629699125453e-08


100%|██████████| 352/352 [00:03<00:00, 106.67it/s]


TRAIN: EPOCH 5: MSE: 0.011939483776752075, KL_LOSS: 4.541345225125584e-08


100%|██████████| 352/352 [00:03<00:00, 107.46it/s]


TRAIN: EPOCH 6: MSE: 0.011811496541750703, KL_LOSS: 4.63290876874144e-08


100%|██████████| 352/352 [00:03<00:00, 105.87it/s]


TRAIN: EPOCH 7: MSE: 0.011755921520737254, KL_LOSS: 5.518599997180538e-08


100%|██████████| 352/352 [00:03<00:00, 105.73it/s]


TRAIN: EPOCH 8: MSE: 0.011692151518549177, KL_LOSS: 5.579258431011944e-08


100%|██████████| 352/352 [00:03<00:00, 107.49it/s]


TRAIN: EPOCH 9: MSE: 0.011712749501880766, KL_LOSS: 6.922653769417506e-08


100%|██████████| 352/352 [00:03<00:00, 105.86it/s]


TRAIN: EPOCH 10: MSE: 0.011591012208637867, KL_LOSS: 7.180952123839936e-08


100%|██████████| 352/352 [00:03<00:00, 105.74it/s]


TRAIN: EPOCH 11: MSE: 0.011476051727764901, KL_LOSS: 6.887905909272715e-08


100%|██████████| 352/352 [00:03<00:00, 107.73it/s]


TRAIN: EPOCH 12: MSE: 0.011429111917757175, KL_LOSS: 6.92190305477554e-08


100%|██████████| 352/352 [00:03<00:00, 109.06it/s]


TRAIN: EPOCH 13: MSE: 0.01143724768718874, KL_LOSS: 7.153784847330707e-08


100%|██████████| 352/352 [00:03<00:00, 112.16it/s]


TRAIN: EPOCH 14: MSE: 0.011334367937789384, KL_LOSS: 7.224644090046568e-08


100%|██████████| 352/352 [00:03<00:00, 111.37it/s]


TRAIN: EPOCH 15: MSE: 0.011306186615415341, KL_LOSS: 8.000045799553974e-08


100%|██████████| 352/352 [00:03<00:00, 112.49it/s]


TRAIN: EPOCH 16: MSE: 0.011337403111445548, KL_LOSS: 8.610007046721596e-08


100%|██████████| 352/352 [00:03<00:00, 111.34it/s]


TRAIN: EPOCH 17: MSE: 0.011237190132801929, KL_LOSS: 8.453495416802716e-08


100%|██████████| 352/352 [00:03<00:00, 112.54it/s]


TRAIN: EPOCH 18: MSE: 0.011258680055933919, KL_LOSS: 8.640340838765455e-08


100%|██████████| 352/352 [00:03<00:00, 111.89it/s]


TRAIN: EPOCH 19: MSE: 0.011215988201332617, KL_LOSS: 8.90496887138361e-08


100%|██████████| 352/352 [00:03<00:00, 112.64it/s]


TRAIN: EPOCH 20: MSE: 0.0111155637542569, KL_LOSS: 8.819097464428761e-08


100%|██████████| 352/352 [00:03<00:00, 112.48it/s]


TRAIN: EPOCH 21: MSE: 0.011239684898596764, KL_LOSS: 1.0442123997801321e-07


100%|██████████| 352/352 [00:03<00:00, 113.10it/s]


TRAIN: EPOCH 22: MSE: 0.011235761308026586, KL_LOSS: 1.129258915064859e-07


100%|██████████| 352/352 [00:03<00:00, 114.23it/s]


TRAIN: EPOCH 23: MSE: 0.011081453387636098, KL_LOSS: 1.0651622024808334e-07


100%|██████████| 352/352 [00:03<00:00, 113.87it/s]


TRAIN: EPOCH 24: MSE: 0.011085876100961204, KL_LOSS: 1.1094544232809705e-07


100%|██████████| 352/352 [00:03<00:00, 100.69it/s]


TRAIN: EPOCH 25: MSE: 0.011012996267791923, KL_LOSS: 1.0888827657721963e-07


100%|██████████| 352/352 [00:03<00:00, 111.77it/s]


TRAIN: EPOCH 26: MSE: 0.010927898860906927, KL_LOSS: 1.044446074475887e-07


100%|██████████| 352/352 [00:03<00:00, 101.04it/s]


TRAIN: EPOCH 27: MSE: 0.010901159366105938, KL_LOSS: 1.0550355564944235e-07


100%|██████████| 352/352 [00:03<00:00, 99.58it/s] 


TRAIN: EPOCH 28: MSE: 0.010959736703873867, KL_LOSS: 1.0942159725776693e-07


100%|██████████| 352/352 [00:07<00:00, 50.02it/s]


TRAIN: EPOCH 29: MSE: 0.01084400760274465, KL_LOSS: 1.1371465912336408e-07


100%|██████████| 352/352 [00:07<00:00, 47.65it/s]


TRAIN: EPOCH 30: MSE: 0.010774733036999929, KL_LOSS: 1.1529938123447633e-07


100%|██████████| 352/352 [00:07<00:00, 44.56it/s]


TRAIN: EPOCH 31: MSE: 0.01089122193463316, KL_LOSS: 1.2036927065185343e-07


100%|██████████| 352/352 [00:07<00:00, 45.58it/s]


TRAIN: EPOCH 32: MSE: 0.010746351883757266, KL_LOSS: 1.1982516811105304e-07


100%|██████████| 352/352 [00:08<00:00, 43.58it/s]


TRAIN: EPOCH 33: MSE: 0.011159195038172502, KL_LOSS: 1.3128575826470978e-07


100%|██████████| 352/352 [00:08<00:00, 40.09it/s]


TRAIN: EPOCH 34: MSE: 0.010854163237540475, KL_LOSS: 1.3503048511654858e-07


100%|██████████| 352/352 [00:06<00:00, 51.46it/s]


TRAIN: EPOCH 35: MSE: 0.010824196112067015, KL_LOSS: 1.3558573260090287e-07


100%|██████████| 352/352 [00:08<00:00, 42.38it/s]


TRAIN: EPOCH 36: MSE: 0.010767274160488423, KL_LOSS: 1.361016609716356e-07


100%|██████████| 352/352 [00:09<00:00, 38.73it/s]


TRAIN: EPOCH 37: MSE: 0.010688431676085615, KL_LOSS: 1.3253036629750686e-07


100%|██████████| 352/352 [00:09<00:00, 39.06it/s]


TRAIN: EPOCH 38: MSE: 0.010677746183567004, KL_LOSS: 1.3575690012550143e-07


100%|██████████| 352/352 [00:07<00:00, 44.97it/s]


TRAIN: EPOCH 39: MSE: 0.010652042594632472, KL_LOSS: 1.403199604187545e-07


100%|██████████| 352/352 [00:07<00:00, 49.22it/s]


TRAIN: EPOCH 40: MSE: 0.010600184901374054, KL_LOSS: 1.4490850378220702e-07


100%|██████████| 352/352 [00:07<00:00, 44.43it/s]


TRAIN: EPOCH 41: MSE: 0.010658046321953985, KL_LOSS: 1.4536419314723695e-07


100%|██████████| 352/352 [00:08<00:00, 41.82it/s]


TRAIN: EPOCH 42: MSE: 0.010604729575623589, KL_LOSS: 1.472427837434509e-07


100%|██████████| 352/352 [00:07<00:00, 46.91it/s]


TRAIN: EPOCH 43: MSE: 0.010485948100474408, KL_LOSS: 1.438336513327375e-07


100%|██████████| 352/352 [00:07<00:00, 45.09it/s]


TRAIN: EPOCH 44: MSE: 0.010494415698551827, KL_LOSS: 1.441852906564572e-07


100%|██████████| 352/352 [00:08<00:00, 41.81it/s]


TRAIN: EPOCH 45: MSE: 0.010531401390008714, KL_LOSS: 1.5039767701936893e-07


100%|██████████| 352/352 [00:08<00:00, 42.36it/s]


TRAIN: EPOCH 46: MSE: 0.010517125967784192, KL_LOSS: 1.556845280975045e-07


100%|██████████| 352/352 [00:08<00:00, 42.69it/s]


TRAIN: EPOCH 47: MSE: 0.010510093335390346, KL_LOSS: 1.5377422200237653e-07


100%|██████████| 352/352 [00:08<00:00, 41.89it/s]


TRAIN: EPOCH 48: MSE: 0.010399418564206413, KL_LOSS: 1.533793189022571e-07


100%|██████████| 352/352 [00:08<00:00, 39.28it/s]


TRAIN: EPOCH 49: MSE: 0.010357432513212023, KL_LOSS: 1.5338416030226523e-07


100%|██████████| 352/352 [00:08<00:00, 43.04it/s]


TRAIN: EPOCH 50: MSE: 0.010426716931926256, KL_LOSS: 1.5897899841021967e-07


100%|██████████| 352/352 [00:08<00:00, 42.91it/s]


TRAIN: EPOCH 51: MSE: 0.010467571625750597, KL_LOSS: 1.6134050483497153e-07


100%|██████████| 352/352 [00:08<00:00, 42.57it/s]


TRAIN: EPOCH 52: MSE: 0.01041730134403968, KL_LOSS: 1.6323494354188803e-07


100%|██████████| 352/352 [00:08<00:00, 41.30it/s]


TRAIN: EPOCH 53: MSE: 0.010308530782243575, KL_LOSS: 1.635318328429104e-07


100%|██████████| 352/352 [00:08<00:00, 43.48it/s]


TRAIN: EPOCH 54: MSE: 0.010299947021518496, KL_LOSS: 1.6999192821064535e-07


100%|██████████| 352/352 [00:07<00:00, 44.90it/s]


TRAIN: EPOCH 55: MSE: 0.010234873411139812, KL_LOSS: 1.6743482975847065e-07


100%|██████████| 352/352 [00:07<00:00, 45.23it/s]


TRAIN: EPOCH 56: MSE: 0.010322782896268605, KL_LOSS: 1.710606368877553e-07


100%|██████████| 352/352 [00:07<00:00, 48.02it/s]


TRAIN: EPOCH 57: MSE: 0.010349760986917483, KL_LOSS: 1.8083835810629848e-07


100%|██████████| 352/352 [00:07<00:00, 48.89it/s]


TRAIN: EPOCH 58: MSE: 0.010186842309353366, KL_LOSS: 1.7747721890095455e-07


100%|██████████| 352/352 [00:08<00:00, 43.73it/s]


TRAIN: EPOCH 59: MSE: 0.010215670478911225, KL_LOSS: 1.8193671583429963e-07


100%|██████████| 352/352 [00:08<00:00, 43.05it/s]


TRAIN: EPOCH 60: MSE: 0.01013486152276693, KL_LOSS: 1.7691378895179355e-07


100%|██████████| 352/352 [00:07<00:00, 44.74it/s]


TRAIN: EPOCH 61: MSE: 0.010194575287063013, KL_LOSS: 1.784992743109564e-07


100%|██████████| 352/352 [00:08<00:00, 42.73it/s]


TRAIN: EPOCH 62: MSE: 0.010364744717001238, KL_LOSS: 1.9336898970934158e-07


100%|██████████| 352/352 [00:07<00:00, 49.06it/s]


TRAIN: EPOCH 63: MSE: 0.010210867211836476, KL_LOSS: 1.9462254615990301e-07


100%|██████████| 352/352 [00:09<00:00, 38.71it/s]


TRAIN: EPOCH 64: MSE: 0.010112388029070147, KL_LOSS: 1.9000026978007026e-07


100%|██████████| 352/352 [00:08<00:00, 42.87it/s]


TRAIN: EPOCH 65: MSE: 0.01007445206976792, KL_LOSS: 1.9080451969925735e-07


100%|██████████| 352/352 [00:07<00:00, 47.96it/s]


TRAIN: EPOCH 66: MSE: 0.010054987884359434, KL_LOSS: 1.9152863622357524e-07


100%|██████████| 352/352 [00:08<00:00, 39.13it/s]


TRAIN: EPOCH 67: MSE: 0.010050457025169057, KL_LOSS: 1.909172554997502e-07


100%|██████████| 352/352 [00:08<00:00, 42.06it/s]


TRAIN: EPOCH 68: MSE: 0.010110596623483369, KL_LOSS: 1.9230725896561426e-07


100%|██████████| 352/352 [00:07<00:00, 47.63it/s]


TRAIN: EPOCH 69: MSE: 0.010113884386638265, KL_LOSS: 1.952971089897398e-07


100%|██████████| 352/352 [00:07<00:00, 47.01it/s]


TRAIN: EPOCH 70: MSE: 0.010061888661849398, KL_LOSS: 1.9182104845941436e-07


100%|██████████| 352/352 [00:07<00:00, 45.31it/s]


TRAIN: EPOCH 71: MSE: 0.009983818925155158, KL_LOSS: 1.9349003575121834e-07


100%|██████████| 352/352 [00:07<00:00, 45.36it/s]


TRAIN: EPOCH 72: MSE: 0.010031417697477578, KL_LOSS: 1.9426303991694607e-07


100%|██████████| 352/352 [00:08<00:00, 39.80it/s]


TRAIN: EPOCH 73: MSE: 0.010103529009608214, KL_LOSS: 1.984217527657832e-07


100%|██████████| 352/352 [00:08<00:00, 41.80it/s]


TRAIN: EPOCH 74: MSE: 0.0100740442086879, KL_LOSS: 2.0123052178113518e-07


100%|██████████| 352/352 [00:08<00:00, 40.65it/s]


TRAIN: EPOCH 75: MSE: 0.010091120104665275, KL_LOSS: 2.0321183681236799e-07


100%|██████████| 352/352 [00:08<00:00, 42.13it/s]


TRAIN: EPOCH 76: MSE: 0.009972003707926806, KL_LOSS: 2.0666766948290222e-07


100%|██████████| 352/352 [00:08<00:00, 41.83it/s]


TRAIN: EPOCH 77: MSE: 0.009955182691390457, KL_LOSS: 2.090263965955201e-07


100%|██████████| 352/352 [00:08<00:00, 41.48it/s]


TRAIN: EPOCH 78: MSE: 0.009913365462456237, KL_LOSS: 2.0940594030511245e-07


100%|██████████| 352/352 [00:08<00:00, 42.71it/s]


TRAIN: EPOCH 79: MSE: 0.00989626731924628, KL_LOSS: 2.1042908522743306e-07


100%|██████████| 352/352 [00:07<00:00, 44.95it/s]


TRAIN: EPOCH 80: MSE: 0.009978519281668758, KL_LOSS: 2.0980682875885878e-07


100%|██████████| 352/352 [00:07<00:00, 44.49it/s]


TRAIN: EPOCH 81: MSE: 0.00992677982272157, KL_LOSS: 2.0719296096997783e-07


100%|██████████| 352/352 [00:08<00:00, 41.96it/s]


TRAIN: EPOCH 82: MSE: 0.009848635861999355, KL_LOSS: 2.0782874823120004e-07


100%|██████████| 352/352 [00:08<00:00, 42.15it/s]


TRAIN: EPOCH 83: MSE: 0.009829589366828177, KL_LOSS: 2.1052929266573983e-07


100%|██████████| 352/352 [00:08<00:00, 43.04it/s]


TRAIN: EPOCH 84: MSE: 0.009865932757913304, KL_LOSS: 2.0897647277286843e-07


100%|██████████| 352/352 [00:08<00:00, 41.12it/s]


TRAIN: EPOCH 85: MSE: 0.009901084287610667, KL_LOSS: 2.143444478877851e-07


100%|██████████| 352/352 [00:08<00:00, 40.21it/s]


TRAIN: EPOCH 86: MSE: 0.0099019786469977, KL_LOSS: 2.1619357479998122e-07


100%|██████████| 352/352 [00:07<00:00, 48.71it/s]


TRAIN: EPOCH 87: MSE: 0.009858092793199996, KL_LOSS: 2.160477193914222e-07


100%|██████████| 352/352 [00:06<00:00, 51.16it/s]


TRAIN: EPOCH 88: MSE: 0.009798333155975508, KL_LOSS: 2.1722548870783953e-07


100%|██████████| 352/352 [00:07<00:00, 49.15it/s]


TRAIN: EPOCH 89: MSE: 0.009771357943960042, KL_LOSS: 2.1790274191085307e-07


100%|██████████| 352/352 [00:08<00:00, 42.93it/s]


TRAIN: EPOCH 90: MSE: 0.009814551469224336, KL_LOSS: 2.2229060459011407e-07


100%|██████████| 352/352 [00:08<00:00, 40.25it/s]


TRAIN: EPOCH 91: MSE: 0.009805096690384247, KL_LOSS: 2.2186592466252622e-07


100%|██████████| 352/352 [00:08<00:00, 42.21it/s]


TRAIN: EPOCH 92: MSE: 0.009739646700009789, KL_LOSS: 2.2006585309597426e-07


100%|██████████| 352/352 [00:07<00:00, 45.02it/s]


TRAIN: EPOCH 93: MSE: 0.009725412384299985, KL_LOSS: 2.1965186651358562e-07


100%|██████████| 352/352 [00:08<00:00, 43.13it/s]


TRAIN: EPOCH 94: MSE: 0.009732450973718765, KL_LOSS: 2.187612372866516e-07


100%|██████████| 352/352 [00:08<00:00, 39.33it/s]


TRAIN: EPOCH 95: MSE: 0.009745157256309705, KL_LOSS: 2.2449026774880435e-07


100%|██████████| 352/352 [00:07<00:00, 44.35it/s]


TRAIN: EPOCH 96: MSE: 0.00976440017587844, KL_LOSS: 2.2401889712949954e-07


100%|██████████| 352/352 [00:07<00:00, 48.54it/s]


TRAIN: EPOCH 97: MSE: 0.00973519745589742, KL_LOSS: 2.2486128268474978e-07


100%|██████████| 352/352 [00:06<00:00, 50.29it/s]


TRAIN: EPOCH 98: MSE: 0.009701935152406804, KL_LOSS: 2.2616381293002777e-07


100%|██████████| 352/352 [00:07<00:00, 47.47it/s]


TRAIN: EPOCH 99: MSE: 0.009741640076919628, KL_LOSS: 2.2825791499349954e-07


100%|██████████| 352/352 [00:07<00:00, 46.04it/s]


TRAIN: EPOCH 100: MSE: 0.009698722842255269, KL_LOSS: 2.3043391364749058e-07


100%|██████████| 352/352 [00:08<00:00, 41.28it/s]


TRAIN: EPOCH 101: MSE: 0.009694571630627086, KL_LOSS: 2.3420464384022405e-07


100%|██████████| 352/352 [00:08<00:00, 42.66it/s]


TRAIN: EPOCH 102: MSE: 0.009605770683265291, KL_LOSS: 2.3400938677718067e-07


100%|██████████| 352/352 [00:08<00:00, 40.44it/s]


TRAIN: EPOCH 103: MSE: 0.009637124041381123, KL_LOSS: 2.3351334054278927e-07


100%|██████████| 352/352 [00:08<00:00, 41.95it/s]


TRAIN: EPOCH 104: MSE: 0.009661988289074296, KL_LOSS: 2.3340786608533842e-07


100%|██████████| 352/352 [00:07<00:00, 46.46it/s]


TRAIN: EPOCH 105: MSE: 0.009661883306266232, KL_LOSS: 2.3330517077853455e-07


100%|██████████| 352/352 [00:08<00:00, 42.29it/s]


TRAIN: EPOCH 106: MSE: 0.009608473263638602, KL_LOSS: 2.3365435009026214e-07


100%|██████████| 352/352 [00:08<00:00, 42.67it/s]


TRAIN: EPOCH 107: MSE: 0.009604727336078544, KL_LOSS: 2.3239321485361716e-07


100%|██████████| 352/352 [00:08<00:00, 43.24it/s]


TRAIN: EPOCH 108: MSE: 0.009558006287658249, KL_LOSS: 2.329271008395427e-07


100%|██████████| 352/352 [00:07<00:00, 44.51it/s]


TRAIN: EPOCH 109: MSE: 0.009572459396705117, KL_LOSS: 2.339067290161009e-07


100%|██████████| 352/352 [00:07<00:00, 46.79it/s]


TRAIN: EPOCH 110: MSE: 0.0095961810690245, KL_LOSS: 2.3210546987425831e-07


100%|██████████| 352/352 [00:07<00:00, 44.74it/s]


TRAIN: EPOCH 111: MSE: 0.009582778038086624, KL_LOSS: 2.3597221849252553e-07


100%|██████████| 352/352 [00:08<00:00, 43.84it/s]


TRAIN: EPOCH 112: MSE: 0.009609712099931627, KL_LOSS: 2.3778716684574874e-07


100%|██████████| 352/352 [00:08<00:00, 43.49it/s]


TRAIN: EPOCH 113: MSE: 0.009680823026097973, KL_LOSS: 2.4965730771597805e-07


100%|██████████| 352/352 [00:07<00:00, 45.02it/s]


TRAIN: EPOCH 114: MSE: 0.009546273712079379, KL_LOSS: 2.4898043570299355e-07


100%|██████████| 352/352 [00:07<00:00, 46.47it/s]


TRAIN: EPOCH 115: MSE: 0.009545137401966547, KL_LOSS: 2.463704859517241e-07


100%|██████████| 352/352 [00:07<00:00, 45.49it/s]


TRAIN: EPOCH 116: MSE: 0.009550540316426619, KL_LOSS: 2.4799816814277425e-07


100%|██████████| 352/352 [00:07<00:00, 45.91it/s]


TRAIN: EPOCH 117: MSE: 0.009570059104589745, KL_LOSS: 2.49538485228775e-07


100%|██████████| 352/352 [00:07<00:00, 46.03it/s]


TRAIN: EPOCH 118: MSE: 0.009506591907914051, KL_LOSS: 2.5001296810600593e-07


100%|██████████| 352/352 [00:07<00:00, 45.93it/s]


TRAIN: EPOCH 119: MSE: 0.00950304825462147, KL_LOSS: 2.4918386093924544e-07


100%|██████████| 352/352 [00:07<00:00, 48.78it/s]


TRAIN: EPOCH 120: MSE: 0.009483877554885112, KL_LOSS: 2.4887744798363087e-07


100%|██████████| 352/352 [00:07<00:00, 49.85it/s]


TRAIN: EPOCH 121: MSE: 0.009537146746879444, KL_LOSS: 2.494562134636169e-07


100%|██████████| 352/352 [00:07<00:00, 44.42it/s]


TRAIN: EPOCH 122: MSE: 0.00951499989870089, KL_LOSS: 2.479302563251906e-07


100%|██████████| 352/352 [00:09<00:00, 38.89it/s]


TRAIN: EPOCH 123: MSE: 0.009502679491891864, KL_LOSS: 2.473986121823399e-07


100%|██████████| 352/352 [00:08<00:00, 39.47it/s]


TRAIN: EPOCH 124: MSE: 0.009471534788396886, KL_LOSS: 2.5054898790021307e-07


100%|██████████| 352/352 [00:08<00:00, 39.18it/s]


TRAIN: EPOCH 125: MSE: 0.009481180881266482, KL_LOSS: 2.5239504643857136e-07


100%|██████████| 352/352 [00:09<00:00, 38.40it/s]


TRAIN: EPOCH 126: MSE: 0.009481287765083835, KL_LOSS: 2.539148426870177e-07


100%|██████████| 352/352 [00:08<00:00, 41.79it/s]


TRAIN: EPOCH 127: MSE: 0.009444107773809017, KL_LOSS: 2.539931588284684e-07


100%|██████████| 352/352 [00:08<00:00, 41.44it/s]


TRAIN: EPOCH 128: MSE: 0.009486747857987542, KL_LOSS: 2.55819249339595e-07


100%|██████████| 352/352 [00:08<00:00, 42.03it/s]


TRAIN: EPOCH 129: MSE: 0.009460865109841425, KL_LOSS: 2.5354858047453954e-07


100%|██████████| 352/352 [00:08<00:00, 40.19it/s]


TRAIN: EPOCH 130: MSE: 0.009419878366646695, KL_LOSS: 2.5499638828757104e-07


100%|██████████| 352/352 [00:08<00:00, 40.16it/s]


TRAIN: EPOCH 131: MSE: 0.009459743317248385, KL_LOSS: 2.560975365106799e-07


100%|██████████| 352/352 [00:08<00:00, 40.31it/s]


TRAIN: EPOCH 132: MSE: 0.009423442732606252, KL_LOSS: 2.5610947043927273e-07


100%|██████████| 352/352 [00:08<00:00, 41.25it/s]


TRAIN: EPOCH 133: MSE: 0.009437926904700527, KL_LOSS: 2.616496761476128e-07


100%|██████████| 352/352 [00:08<00:00, 42.43it/s]


TRAIN: EPOCH 134: MSE: 0.009420142452158458, KL_LOSS: 2.62367916483506e-07


100%|██████████| 352/352 [00:08<00:00, 42.19it/s]


TRAIN: EPOCH 135: MSE: 0.009426897690652615, KL_LOSS: 2.623691659890555e-07


100%|██████████| 352/352 [00:08<00:00, 41.28it/s]


TRAIN: EPOCH 136: MSE: 0.009405115989714184, KL_LOSS: 2.613947296998214e-07


100%|██████████| 352/352 [00:08<00:00, 39.16it/s]


TRAIN: EPOCH 137: MSE: 0.009453454486157914, KL_LOSS: 2.659214616587943e-07


100%|██████████| 352/352 [00:08<00:00, 40.23it/s]


TRAIN: EPOCH 138: MSE: 0.009415452811729418, KL_LOSS: 2.686879099413695e-07


100%|██████████| 352/352 [00:08<00:00, 40.92it/s]


TRAIN: EPOCH 139: MSE: 0.00937325816167603, KL_LOSS: 2.6524134367665975e-07


100%|██████████| 352/352 [00:08<00:00, 39.42it/s]


TRAIN: EPOCH 140: MSE: 0.00940317373368254, KL_LOSS: 2.680858600970479e-07


100%|██████████| 352/352 [00:07<00:00, 45.57it/s]


TRAIN: EPOCH 141: MSE: 0.009391693677488629, KL_LOSS: 2.6513839192852307e-07


100%|██████████| 352/352 [00:07<00:00, 48.48it/s]


TRAIN: EPOCH 142: MSE: 0.009390467082002115, KL_LOSS: 2.6906290674265743e-07


100%|██████████| 352/352 [00:07<00:00, 46.45it/s]


TRAIN: EPOCH 143: MSE: 0.009377818623430689, KL_LOSS: 2.694379561887026e-07


100%|██████████| 352/352 [00:07<00:00, 48.03it/s]


TRAIN: EPOCH 144: MSE: 0.009416753718555397, KL_LOSS: 2.739307332078056e-07


100%|██████████| 352/352 [00:07<00:00, 46.71it/s]


TRAIN: EPOCH 145: MSE: 0.009396959158369678, KL_LOSS: 2.7395207863827925e-07


100%|██████████| 352/352 [00:07<00:00, 49.12it/s]


TRAIN: EPOCH 146: MSE: 0.009369225302097302, KL_LOSS: 2.74976717920791e-07


100%|██████████| 352/352 [00:07<00:00, 49.73it/s]


TRAIN: EPOCH 147: MSE: 0.00935474322415592, KL_LOSS: 2.780201012842941e-07


100%|██████████| 352/352 [00:07<00:00, 47.78it/s]


TRAIN: EPOCH 148: MSE: 0.009409071823359805, KL_LOSS: 2.763204205245209e-07


100%|██████████| 352/352 [00:08<00:00, 41.55it/s]


TRAIN: EPOCH 149: MSE: 0.009324310481050898, KL_LOSS: 2.7723102161400254e-07


100%|██████████| 352/352 [00:08<00:00, 41.33it/s]


TRAIN: EPOCH 150: MSE: 0.009321535672907803, KL_LOSS: 2.781424820968487e-07


100%|██████████| 352/352 [00:08<00:00, 39.58it/s]


TRAIN: EPOCH 151: MSE: 0.009328228795890358, KL_LOSS: 2.768184666541112e-07


100%|██████████| 352/352 [00:08<00:00, 40.10it/s]


TRAIN: EPOCH 152: MSE: 0.009338687000722117, KL_LOSS: 2.771552929785189e-07


100%|██████████| 352/352 [00:07<00:00, 44.95it/s]


TRAIN: EPOCH 153: MSE: 0.009324857578824529, KL_LOSS: 2.7365250610987923e-07


100%|██████████| 352/352 [00:08<00:00, 43.23it/s]


TRAIN: EPOCH 154: MSE: 0.009322709526961924, KL_LOSS: 2.754825094910338e-07


100%|██████████| 352/352 [00:07<00:00, 46.13it/s]


TRAIN: EPOCH 155: MSE: 0.009318120904605497, KL_LOSS: 2.8004795625914303e-07


100%|██████████| 352/352 [00:07<00:00, 44.76it/s]


TRAIN: EPOCH 156: MSE: 0.009298645979883573, KL_LOSS: 2.8148800801375273e-07


100%|██████████| 352/352 [00:08<00:00, 42.73it/s]


TRAIN: EPOCH 157: MSE: 0.009326492667737924, KL_LOSS: 2.8055008982619256e-07


100%|██████████| 352/352 [00:07<00:00, 45.22it/s]


TRAIN: EPOCH 158: MSE: 0.009323124533032322, KL_LOSS: 2.784478618991024e-07


100%|██████████| 352/352 [00:08<00:00, 40.82it/s]


TRAIN: EPOCH 159: MSE: 0.00928340354717379, KL_LOSS: 2.8147116104548786e-07


100%|██████████| 352/352 [00:03<00:00, 104.98it/s]


TRAIN: EPOCH 160: MSE: 0.009276153431115248, KL_LOSS: 2.7956111591332833e-07


100%|██████████| 352/352 [00:03<00:00, 105.60it/s]


TRAIN: EPOCH 161: MSE: 0.00929832740670959, KL_LOSS: 2.82642815334229e-07


100%|██████████| 352/352 [00:03<00:00, 105.87it/s]


TRAIN: EPOCH 162: MSE: 0.009309596925793978, KL_LOSS: 2.828194289670385e-07


100%|██████████| 352/352 [00:03<00:00, 108.80it/s]


TRAIN: EPOCH 163: MSE: 0.00934375178555704, KL_LOSS: 2.900611497254472e-07


100%|██████████| 352/352 [00:03<00:00, 108.93it/s]


TRAIN: EPOCH 164: MSE: 0.009324761839128438, KL_LOSS: 2.8403750624244656e-07


100%|██████████| 352/352 [00:03<00:00, 108.88it/s]


TRAIN: EPOCH 165: MSE: 0.00927158995711414, KL_LOSS: 2.8219484325055855e-07


100%|██████████| 352/352 [00:03<00:00, 109.14it/s]


TRAIN: EPOCH 166: MSE: 0.009271419994712976, KL_LOSS: 2.8156979943587274e-07


100%|██████████| 352/352 [00:03<00:00, 109.72it/s]


TRAIN: EPOCH 167: MSE: 0.009267765921487642, KL_LOSS: 2.815565069576435e-07


100%|██████████| 352/352 [00:03<00:00, 109.42it/s]


TRAIN: EPOCH 168: MSE: 0.009241221815252422, KL_LOSS: 2.810676840764808e-07


100%|██████████| 352/352 [00:03<00:00, 107.88it/s]


TRAIN: EPOCH 169: MSE: 0.009240416648522527, KL_LOSS: 2.83801074963851e-07


100%|██████████| 352/352 [00:03<00:00, 108.93it/s]


TRAIN: EPOCH 170: MSE: 0.009249290606186894, KL_LOSS: 2.829175879429355e-07


100%|██████████| 352/352 [00:03<00:00, 108.93it/s]


TRAIN: EPOCH 171: MSE: 0.009278314225163987, KL_LOSS: 2.826634641098477e-07


100%|██████████| 352/352 [00:03<00:00, 96.99it/s] 


TRAIN: EPOCH 172: MSE: 0.009258398315234279, KL_LOSS: 2.872001468968609e-07


100%|██████████| 352/352 [00:03<00:00, 105.44it/s]


TRAIN: EPOCH 173: MSE: 0.009266105087590404, KL_LOSS: 2.8988666376441396e-07


100%|██████████| 352/352 [00:07<00:00, 48.61it/s] 


TRAIN: EPOCH 174: MSE: 0.009223646242224442, KL_LOSS: 2.8722475916867746e-07


100%|██████████| 352/352 [00:10<00:00, 34.62it/s]


TRAIN: EPOCH 175: MSE: 0.009237260299745354, KL_LOSS: 2.886014065708118e-07


100%|██████████| 352/352 [00:09<00:00, 37.45it/s]


TRAIN: EPOCH 176: MSE: 0.009248662466151554, KL_LOSS: 2.901179007737524e-07


100%|██████████| 352/352 [00:07<00:00, 46.95it/s]


TRAIN: EPOCH 177: MSE: 0.009233745362673124, KL_LOSS: 2.9129855069501526e-07


100%|██████████| 352/352 [00:08<00:00, 43.73it/s]


TRAIN: EPOCH 178: MSE: 0.009241527248046954, KL_LOSS: 2.914241832498796e-07


100%|██████████| 352/352 [00:08<00:00, 43.68it/s]


TRAIN: EPOCH 179: MSE: 0.009223714374291541, KL_LOSS: 2.9523014271665654e-07


100%|██████████| 352/352 [00:08<00:00, 43.41it/s]


TRAIN: EPOCH 180: MSE: 0.00922937558921562, KL_LOSS: 3.0149036592516415e-07


100%|██████████| 352/352 [00:07<00:00, 47.16it/s]


TRAIN: EPOCH 181: MSE: 0.009218065818210809, KL_LOSS: 3.010211190212043e-07


100%|██████████| 352/352 [00:08<00:00, 41.50it/s]


TRAIN: EPOCH 182: MSE: 0.009200343706073578, KL_LOSS: 3.0233780416595707e-07


100%|██████████| 352/352 [00:08<00:00, 40.42it/s]


TRAIN: EPOCH 183: MSE: 0.009181722296158445, KL_LOSS: 3.0172706103312213e-07


100%|██████████| 352/352 [00:08<00:00, 41.82it/s]


TRAIN: EPOCH 184: MSE: 0.009192110334300774, KL_LOSS: 3.022118028959333e-07


100%|██████████| 352/352 [00:07<00:00, 44.72it/s]


TRAIN: EPOCH 185: MSE: 0.009224961792245846, KL_LOSS: 3.0277713992564836e-07


100%|██████████| 352/352 [00:08<00:00, 41.18it/s]


TRAIN: EPOCH 186: MSE: 0.009231857856238175, KL_LOSS: 3.033213204646627e-07


100%|██████████| 352/352 [00:08<00:00, 42.40it/s]


TRAIN: EPOCH 187: MSE: 0.00917730950194792, KL_LOSS: 3.0707491676713e-07


100%|██████████| 352/352 [00:08<00:00, 42.59it/s]


TRAIN: EPOCH 188: MSE: 0.00917558200440412, KL_LOSS: 3.083005782788074e-07


100%|██████████| 352/352 [00:08<00:00, 42.72it/s]


TRAIN: EPOCH 189: MSE: 0.009176825306812216, KL_LOSS: 3.061776107800411e-07


100%|██████████| 352/352 [00:08<00:00, 39.22it/s]


TRAIN: EPOCH 190: MSE: 0.00919904207429764, KL_LOSS: 3.0285390486348907e-07


100%|██████████| 352/352 [00:08<00:00, 41.41it/s]


TRAIN: EPOCH 191: MSE: 0.009185833730787801, KL_LOSS: 3.0843734314086907e-07


100%|██████████| 352/352 [00:08<00:00, 42.04it/s]


TRAIN: EPOCH 192: MSE: 0.009172521687684242, KL_LOSS: 3.080985919723389e-07


100%|██████████| 352/352 [00:08<00:00, 43.73it/s]


TRAIN: EPOCH 193: MSE: 0.009158854876824824, KL_LOSS: 3.11963561327001e-07


100%|██████████| 352/352 [00:07<00:00, 44.44it/s]


TRAIN: EPOCH 194: MSE: 0.00914881115948612, KL_LOSS: 3.091081616930906e-07


100%|██████████| 352/352 [00:07<00:00, 46.58it/s]


TRAIN: EPOCH 195: MSE: 0.00914978433998344, KL_LOSS: 3.123905091778118e-07


100%|██████████| 352/352 [00:08<00:00, 40.70it/s]


TRAIN: EPOCH 196: MSE: 0.009165249325716022, KL_LOSS: 3.114912424964897e-07


100%|██████████| 352/352 [00:08<00:00, 40.42it/s]


TRAIN: EPOCH 197: MSE: 0.00917225999000948, KL_LOSS: 3.144392070356358e-07


100%|██████████| 352/352 [00:08<00:00, 40.48it/s]


TRAIN: EPOCH 198: MSE: 0.009163980921254155, KL_LOSS: 3.168954952993307e-07


100%|██████████| 352/352 [00:07<00:00, 45.71it/s]


TRAIN: EPOCH 199: MSE: 0.009139359963062981, KL_LOSS: 3.174664433558405e-07


100%|██████████| 352/352 [00:09<00:00, 36.74it/s]


TRAIN: EPOCH 200: MSE: 0.009136178545304574, KL_LOSS: 3.1540968269506597e-07


100%|██████████| 352/352 [00:08<00:00, 41.61it/s]


TRAIN: EPOCH 201: MSE: 0.0091772884612014, KL_LOSS: 3.264966525019407e-07


100%|██████████| 352/352 [00:08<00:00, 43.40it/s]


TRAIN: EPOCH 202: MSE: 0.009124941918840208, KL_LOSS: 3.1264587847927444e-07


100%|██████████| 352/352 [00:08<00:00, 41.75it/s]


TRAIN: EPOCH 203: MSE: 0.009124410236836411, KL_LOSS: 3.157707201172467e-07


100%|██████████| 352/352 [00:07<00:00, 45.57it/s]


TRAIN: EPOCH 204: MSE: 0.009129104155528528, KL_LOSS: 3.178757844421201e-07


100%|██████████| 352/352 [00:07<00:00, 46.01it/s]


TRAIN: EPOCH 205: MSE: 0.00913682101633061, KL_LOSS: 3.198626664545724e-07


100%|██████████| 352/352 [00:07<00:00, 47.52it/s]


TRAIN: EPOCH 206: MSE: 0.00911047057624356, KL_LOSS: 3.196492765023994e-07


100%|██████████| 352/352 [00:07<00:00, 44.73it/s]


TRAIN: EPOCH 207: MSE: 0.009142030553035014, KL_LOSS: 3.251184957616973e-07


100%|██████████| 352/352 [00:07<00:00, 45.25it/s]


TRAIN: EPOCH 208: MSE: 0.009146413830669851, KL_LOSS: 3.2260353782282375e-07


100%|██████████| 352/352 [00:07<00:00, 45.72it/s]


TRAIN: EPOCH 209: MSE: 0.009112271746165457, KL_LOSS: 3.2368199894121247e-07


100%|██████████| 352/352 [00:06<00:00, 51.71it/s]


TRAIN: EPOCH 210: MSE: 0.009104426718708552, KL_LOSS: 3.2155451893086967e-07


100%|██████████| 352/352 [00:06<00:00, 53.29it/s]


TRAIN: EPOCH 211: MSE: 0.009082757963121614, KL_LOSS: 3.220205482470171e-07


100%|██████████| 352/352 [00:06<00:00, 51.68it/s]


TRAIN: EPOCH 212: MSE: 0.009094084370934235, KL_LOSS: 3.2161696029607934e-07


100%|██████████| 352/352 [00:06<00:00, 51.82it/s]


TRAIN: EPOCH 213: MSE: 0.009130921300807544, KL_LOSS: 3.2388345565111237e-07


100%|██████████| 352/352 [00:06<00:00, 50.49it/s]


TRAIN: EPOCH 214: MSE: 0.009119759928911331, KL_LOSS: 3.2624563890819175e-07


100%|██████████| 352/352 [00:07<00:00, 45.27it/s]


TRAIN: EPOCH 215: MSE: 0.009109054479102435, KL_LOSS: 3.286177642730143e-07


100%|██████████| 352/352 [00:07<00:00, 47.24it/s]


TRAIN: EPOCH 216: MSE: 0.009091801749574106, KL_LOSS: 3.284023096289895e-07


100%|██████████| 352/352 [00:07<00:00, 46.01it/s]


TRAIN: EPOCH 217: MSE: 0.00909636276430154, KL_LOSS: 3.2409627515050943e-07


100%|██████████| 352/352 [00:07<00:00, 46.95it/s]


TRAIN: EPOCH 218: MSE: 0.009076821313514798, KL_LOSS: 3.2802776042925336e-07


100%|██████████| 352/352 [00:07<00:00, 46.30it/s]


TRAIN: EPOCH 219: MSE: 0.009126154125128365, KL_LOSS: 3.296997684966062e-07


100%|██████████| 352/352 [00:08<00:00, 43.82it/s]


TRAIN: EPOCH 220: MSE: 0.00909008050670805, KL_LOSS: 3.290121707199102e-07


100%|██████████| 352/352 [00:07<00:00, 45.40it/s]


TRAIN: EPOCH 221: MSE: 0.00906365925038699, KL_LOSS: 3.274178470972929e-07


100%|██████████| 352/352 [00:08<00:00, 43.56it/s]


TRAIN: EPOCH 222: MSE: 0.009062993480421772, KL_LOSS: 3.270264263025266e-07


100%|██████████| 352/352 [00:09<00:00, 37.06it/s]


TRAIN: EPOCH 223: MSE: 0.009064783627929335, KL_LOSS: 3.306829174124201e-07


100%|██████████| 352/352 [00:08<00:00, 40.99it/s]


TRAIN: EPOCH 224: MSE: 0.009067601143297825, KL_LOSS: 3.3249074564297004e-07


100%|██████████| 352/352 [00:08<00:00, 43.20it/s]


TRAIN: EPOCH 225: MSE: 0.009078214310151949, KL_LOSS: 3.309860699861712e-07


100%|██████████| 352/352 [00:08<00:00, 43.77it/s]


TRAIN: EPOCH 226: MSE: 0.00906609043489549, KL_LOSS: 3.326065735558876e-07


100%|██████████| 352/352 [00:07<00:00, 47.49it/s]


TRAIN: EPOCH 227: MSE: 0.009068483933912252, KL_LOSS: 3.333962877892883e-07


100%|██████████| 352/352 [00:08<00:00, 40.28it/s]


TRAIN: EPOCH 228: MSE: 0.009067097624541599, KL_LOSS: 3.306345401102559e-07


100%|██████████| 352/352 [00:07<00:00, 46.25it/s]


TRAIN: EPOCH 229: MSE: 0.009063760810849171, KL_LOSS: 3.3292891764670107e-07


100%|██████████| 352/352 [00:05<00:00, 63.77it/s] 


TRAIN: EPOCH 230: MSE: 0.009058434041269886, KL_LOSS: 3.325351085401642e-07


100%|██████████| 352/352 [00:04<00:00, 77.49it/s]


TRAIN: EPOCH 231: MSE: 0.009043078566636805, KL_LOSS: 3.343978475133337e-07


100%|██████████| 352/352 [00:06<00:00, 53.38it/s]


TRAIN: EPOCH 232: MSE: 0.009065254334762523, KL_LOSS: 3.4039080160591145e-07


100%|██████████| 352/352 [00:08<00:00, 41.10it/s]


TRAIN: EPOCH 233: MSE: 0.00909196582770991, KL_LOSS: 3.4282012073560963e-07


100%|██████████| 352/352 [00:09<00:00, 37.18it/s]


TRAIN: EPOCH 234: MSE: 0.009092396863376383, KL_LOSS: 3.41613319934905e-07


100%|██████████| 352/352 [00:08<00:00, 42.04it/s]


TRAIN: EPOCH 235: MSE: 0.009042625996781599, KL_LOSS: 3.3649409246872997e-07


100%|██████████| 352/352 [00:08<00:00, 40.71it/s]


TRAIN: EPOCH 236: MSE: 0.009031601103329607, KL_LOSS: 3.417667120621879e-07


100%|██████████| 352/352 [00:10<00:00, 34.71it/s]


TRAIN: EPOCH 237: MSE: 0.009040604799255643, KL_LOSS: 3.4232008314851536e-07


100%|██████████| 352/352 [00:08<00:00, 39.88it/s]


TRAIN: EPOCH 238: MSE: 0.009038201543841173, KL_LOSS: 3.4524109179559753e-07


100%|██████████| 352/352 [00:07<00:00, 46.39it/s]


TRAIN: EPOCH 239: MSE: 0.009065580445828593, KL_LOSS: 3.4670801211818386e-07


100%|██████████| 352/352 [00:08<00:00, 39.84it/s]


TRAIN: EPOCH 240: MSE: 0.00902031316019764, KL_LOSS: 3.4742391379955445e-07


100%|██████████| 352/352 [00:08<00:00, 43.38it/s]


TRAIN: EPOCH 241: MSE: 0.009071409366284073, KL_LOSS: 3.4844003375540794e-07


100%|██████████| 352/352 [00:08<00:00, 41.75it/s]


TRAIN: EPOCH 242: MSE: 0.009039480424359102, KL_LOSS: 3.4958925160760207e-07


100%|██████████| 352/352 [00:06<00:00, 50.51it/s]


TRAIN: EPOCH 243: MSE: 0.009024231522661548, KL_LOSS: 3.4965171428909386e-07


100%|██████████| 352/352 [00:07<00:00, 44.21it/s]


TRAIN: EPOCH 244: MSE: 0.009041833084087226, KL_LOSS: 3.4960866234324717e-07


100%|██████████| 352/352 [00:07<00:00, 46.31it/s]


TRAIN: EPOCH 245: MSE: 0.009052247983742167, KL_LOSS: 3.4669901325592245e-07


100%|██████████| 352/352 [00:07<00:00, 44.70it/s]


TRAIN: EPOCH 246: MSE: 0.009029693656802092, KL_LOSS: 3.4659924005611936e-07


100%|██████████| 352/352 [00:06<00:00, 53.48it/s]


TRAIN: EPOCH 247: MSE: 0.009037287450734188, KL_LOSS: 3.468794048767553e-07


100%|██████████| 352/352 [00:06<00:00, 53.38it/s]


TRAIN: EPOCH 248: MSE: 0.009025949788999489, KL_LOSS: 3.451661202514732e-07


100%|██████████| 352/352 [00:06<00:00, 52.37it/s]


TRAIN: EPOCH 249: MSE: 0.00902918642333878, KL_LOSS: 3.502596899387165e-07


100%|██████████| 352/352 [00:07<00:00, 49.74it/s]


TRAIN: EPOCH 250: MSE: 0.009019778346886265, KL_LOSS: 3.482598585271355e-07


100%|██████████| 352/352 [00:06<00:00, 54.77it/s]


TRAIN: EPOCH 251: MSE: 0.009011661940762266, KL_LOSS: 3.487517854303881e-07


100%|██████████| 352/352 [00:07<00:00, 48.84it/s]


TRAIN: EPOCH 252: MSE: 0.009027438394365494, KL_LOSS: 3.5244774186018257e-07


100%|██████████| 352/352 [00:06<00:00, 57.33it/s]


TRAIN: EPOCH 253: MSE: 0.009009878988515331, KL_LOSS: 3.534983397718799e-07


100%|██████████| 352/352 [00:06<00:00, 53.92it/s]


TRAIN: EPOCH 254: MSE: 0.009016813337273727, KL_LOSS: 3.5411168494451056e-07


100%|██████████| 352/352 [00:06<00:00, 50.94it/s]


TRAIN: EPOCH 255: MSE: 0.00902182275603991, KL_LOSS: 3.489044725862362e-07


100%|██████████| 352/352 [00:07<00:00, 45.57it/s]


TRAIN: EPOCH 256: MSE: 0.009063203565099022, KL_LOSS: 3.568337847426051e-07


100%|██████████| 352/352 [00:07<00:00, 45.49it/s]


TRAIN: EPOCH 257: MSE: 0.009044110712817532, KL_LOSS: 3.585492255421162e-07


100%|██████████| 352/352 [00:07<00:00, 46.40it/s]


TRAIN: EPOCH 258: MSE: 0.009014826385695911, KL_LOSS: 3.610268756272138e-07


100%|██████████| 352/352 [00:07<00:00, 49.83it/s]


TRAIN: EPOCH 259: MSE: 0.009009812577542934, KL_LOSS: 3.6009789792436986e-07


100%|██████████| 352/352 [00:06<00:00, 50.86it/s]


TRAIN: EPOCH 260: MSE: 0.009020506084198132, KL_LOSS: 3.6396691446586e-07


100%|██████████| 352/352 [00:06<00:00, 50.42it/s]


TRAIN: EPOCH 261: MSE: 0.00901806793990545, KL_LOSS: 3.637911504160808e-07


100%|██████████| 352/352 [00:07<00:00, 50.21it/s]


TRAIN: EPOCH 262: MSE: 0.009001885576501743, KL_LOSS: 3.6587680376307863e-07


100%|██████████| 352/352 [00:06<00:00, 50.80it/s]


TRAIN: EPOCH 263: MSE: 0.009002913506595756, KL_LOSS: 3.658823266179793e-07


100%|██████████| 352/352 [00:06<00:00, 50.33it/s]


TRAIN: EPOCH 264: MSE: 0.009005984574535185, KL_LOSS: 3.657387025154693e-07


100%|██████████| 352/352 [00:06<00:00, 50.52it/s]


TRAIN: EPOCH 265: MSE: 0.009018168000197462, KL_LOSS: 3.6534718507074226e-07


100%|██████████| 352/352 [00:04<00:00, 74.01it/s] 


TRAIN: EPOCH 266: MSE: 0.008998031452806159, KL_LOSS: 3.622846446272764e-07


100%|██████████| 352/352 [00:03<00:00, 96.60it/s] 


TRAIN: EPOCH 267: MSE: 0.008983962481248785, KL_LOSS: 3.6580016962969614e-07


100%|██████████| 352/352 [00:05<00:00, 64.62it/s]


TRAIN: EPOCH 268: MSE: 0.008998117560457269, KL_LOSS: 3.661768155764507e-07


100%|██████████| 352/352 [00:05<00:00, 63.64it/s]


TRAIN: EPOCH 269: MSE: 0.009036648748654195, KL_LOSS: 3.651778384330193e-07


100%|██████████| 352/352 [00:05<00:00, 65.43it/s]


TRAIN: EPOCH 270: MSE: 0.008993073608673347, KL_LOSS: 3.700830568866629e-07


100%|██████████| 352/352 [00:05<00:00, 65.41it/s]


TRAIN: EPOCH 271: MSE: 0.008980943088483235, KL_LOSS: 3.656878579343987e-07


100%|██████████| 352/352 [00:05<00:00, 66.04it/s]


TRAIN: EPOCH 272: MSE: 0.00897829255501909, KL_LOSS: 3.7177110012720505e-07


100%|██████████| 352/352 [00:05<00:00, 65.04it/s]


TRAIN: EPOCH 273: MSE: 0.00897944151471496, KL_LOSS: 3.702598028580569e-07


100%|██████████| 352/352 [00:05<00:00, 66.85it/s]


TRAIN: EPOCH 274: MSE: 0.008980463979903354, KL_LOSS: 3.710158645763396e-07


100%|██████████| 352/352 [00:05<00:00, 65.31it/s]


TRAIN: EPOCH 275: MSE: 0.008992703425147656, KL_LOSS: 3.7785632723729246e-07


100%|██████████| 352/352 [00:05<00:00, 63.24it/s]


TRAIN: EPOCH 276: MSE: 0.009013629563427954, KL_LOSS: 3.763122321228965e-07


100%|██████████| 352/352 [00:05<00:00, 64.35it/s]


TRAIN: EPOCH 277: MSE: 0.008964264585467225, KL_LOSS: 3.7349059949837676e-07


100%|██████████| 352/352 [00:05<00:00, 64.49it/s]


TRAIN: EPOCH 278: MSE: 0.008967809060281566, KL_LOSS: 3.772326343016113e-07


100%|██████████| 352/352 [00:05<00:00, 67.00it/s]


TRAIN: EPOCH 279: MSE: 0.00897857552097941, KL_LOSS: 3.7367047064665565e-07


100%|██████████| 352/352 [00:05<00:00, 62.37it/s]


TRAIN: EPOCH 280: MSE: 0.008995612436344592, KL_LOSS: 3.760086331183743e-07


100%|██████████| 352/352 [00:06<00:00, 55.50it/s]


TRAIN: EPOCH 281: MSE: 0.00896925924197686, KL_LOSS: 3.74301649634519e-07


100%|██████████| 352/352 [00:06<00:00, 53.95it/s]


TRAIN: EPOCH 282: MSE: 0.008966400811914355, KL_LOSS: 3.755343061971995e-07


100%|██████████| 352/352 [00:06<00:00, 54.23it/s]


TRAIN: EPOCH 283: MSE: 0.00898882432242813, KL_LOSS: 3.7527198996785543e-07


100%|██████████| 352/352 [00:06<00:00, 55.27it/s]


TRAIN: EPOCH 284: MSE: 0.008981418982412752, KL_LOSS: 3.77348388172746e-07


100%|██████████| 352/352 [00:06<00:00, 57.07it/s]


TRAIN: EPOCH 285: MSE: 0.008960569523465396, KL_LOSS: 3.7806512022019556e-07


100%|██████████| 352/352 [00:06<00:00, 58.22it/s]


TRAIN: EPOCH 286: MSE: 0.008954842860641127, KL_LOSS: 3.8224384998644917e-07


100%|██████████| 352/352 [00:06<00:00, 56.79it/s]


TRAIN: EPOCH 287: MSE: 0.008993358647620136, KL_LOSS: 3.8727637177967983e-07


100%|██████████| 352/352 [00:06<00:00, 56.95it/s]


TRAIN: EPOCH 288: MSE: 0.008968374817579223, KL_LOSS: 3.822251757929445e-07


100%|██████████| 352/352 [00:05<00:00, 61.40it/s]


TRAIN: EPOCH 289: MSE: 0.008954291562655602, KL_LOSS: 3.8551039750377536e-07


100%|██████████| 352/352 [00:05<00:00, 63.64it/s]


TRAIN: EPOCH 290: MSE: 0.00894700659608299, KL_LOSS: 3.85005494152324e-07


100%|██████████| 352/352 [00:05<00:00, 64.58it/s]


TRAIN: EPOCH 291: MSE: 0.008949805475739677, KL_LOSS: 3.8627866788038484e-07


100%|██████████| 352/352 [00:05<00:00, 59.18it/s]


TRAIN: EPOCH 292: MSE: 0.008942676868173294, KL_LOSS: 3.868398945818605e-07


100%|██████████| 352/352 [00:05<00:00, 60.72it/s]


TRAIN: EPOCH 293: MSE: 0.008950541108127007, KL_LOSS: 3.9105070155768956e-07


100%|██████████| 352/352 [00:06<00:00, 57.59it/s]


TRAIN: EPOCH 294: MSE: 0.008948053200252947, KL_LOSS: 3.9138369628039465e-07


100%|██████████| 352/352 [00:06<00:00, 56.79it/s]


TRAIN: EPOCH 295: MSE: 0.008946757467294281, KL_LOSS: 3.9082625225006495e-07


100%|██████████| 352/352 [00:06<00:00, 58.41it/s]


TRAIN: EPOCH 296: MSE: 0.008949883671117608, KL_LOSS: 3.9246799337694876e-07


100%|██████████| 352/352 [00:06<00:00, 54.96it/s]


TRAIN: EPOCH 297: MSE: 0.008944100185710175, KL_LOSS: 3.9258915011815404e-07


100%|██████████| 352/352 [00:06<00:00, 57.31it/s]


TRAIN: EPOCH 298: MSE: 0.008947896394106052, KL_LOSS: 3.954926354894765e-07


100%|██████████| 352/352 [00:06<00:00, 56.27it/s]


TRAIN: EPOCH 299: MSE: 0.008941672981811942, KL_LOSS: 3.9654475783831605e-07


100%|██████████| 352/352 [00:05<00:00, 62.86it/s]


TRAIN: EPOCH 300: MSE: 0.008957703775112432, KL_LOSS: 3.9674717071265647e-07


100%|██████████| 352/352 [00:05<00:00, 63.58it/s]


TRAIN: EPOCH 301: MSE: 0.008926976158611731, KL_LOSS: 3.9782632380277766e-07


100%|██████████| 352/352 [00:05<00:00, 63.38it/s]


TRAIN: EPOCH 302: MSE: 0.008941822259327058, KL_LOSS: 4.016543391216487e-07


100%|██████████| 352/352 [00:05<00:00, 64.05it/s]


TRAIN: EPOCH 303: MSE: 0.008927051320576786, KL_LOSS: 3.9866308387814797e-07


100%|██████████| 352/352 [00:05<00:00, 61.01it/s]


TRAIN: EPOCH 304: MSE: 0.00891497195830611, KL_LOSS: 3.9575427872180896e-07


100%|██████████| 352/352 [00:07<00:00, 49.95it/s]


TRAIN: EPOCH 305: MSE: 0.008913360397193835, KL_LOSS: 3.969396186288453e-07


100%|██████████| 352/352 [00:03<00:00, 93.90it/s] 


TRAIN: EPOCH 306: MSE: 0.008911841949818401, KL_LOSS: 3.9438379592385525e-07


100%|██████████| 352/352 [00:03<00:00, 108.13it/s]


TRAIN: EPOCH 307: MSE: 0.008941886773258871, KL_LOSS: 3.932194697530246e-07


100%|██████████| 352/352 [00:03<00:00, 107.66it/s]


TRAIN: EPOCH 308: MSE: 0.008927458631329831, KL_LOSS: 3.938252146214309e-07


100%|██████████| 352/352 [00:03<00:00, 108.49it/s]


TRAIN: EPOCH 309: MSE: 0.00893036580898545, KL_LOSS: 3.9483077776933134e-07


100%|██████████| 352/352 [00:03<00:00, 108.77it/s]


TRAIN: EPOCH 310: MSE: 0.008924999487300573, KL_LOSS: 3.948068881114027e-07


100%|██████████| 352/352 [00:03<00:00, 108.88it/s]


TRAIN: EPOCH 311: MSE: 0.008910449227021838, KL_LOSS: 3.977218598905097e-07


100%|██████████| 352/352 [00:03<00:00, 109.08it/s]


TRAIN: EPOCH 312: MSE: 0.00891107889897698, KL_LOSS: 3.944383979033587e-07


100%|██████████| 352/352 [00:03<00:00, 107.51it/s]


TRAIN: EPOCH 313: MSE: 0.008906552745346826, KL_LOSS: 3.972859547603362e-07


100%|██████████| 352/352 [00:03<00:00, 107.27it/s]


TRAIN: EPOCH 314: MSE: 0.00893666648798072, KL_LOSS: 4.0154668543770167e-07


100%|██████████| 352/352 [00:03<00:00, 107.60it/s]


TRAIN: EPOCH 315: MSE: 0.008907683011123234, KL_LOSS: 4.019688676460863e-07


100%|██████████| 352/352 [00:03<00:00, 106.50it/s]


TRAIN: EPOCH 316: MSE: 0.00889818724350665, KL_LOSS: 4.0206906248840824e-07


100%|██████████| 352/352 [00:03<00:00, 107.46it/s]


TRAIN: EPOCH 317: MSE: 0.00890430871434298, KL_LOSS: 4.0026894711851147e-07


100%|██████████| 352/352 [00:03<00:00, 107.09it/s]


TRAIN: EPOCH 318: MSE: 0.008893970140806314, KL_LOSS: 3.9926727790629e-07


100%|██████████| 352/352 [00:03<00:00, 108.07it/s]


TRAIN: EPOCH 319: MSE: 0.008926080516010354, KL_LOSS: 4.0079327888843286e-07


100%|██████████| 352/352 [00:03<00:00, 108.59it/s]


TRAIN: EPOCH 320: MSE: 0.008901948216424154, KL_LOSS: 3.9763367798428817e-07


100%|██████████| 352/352 [00:03<00:00, 109.15it/s]


TRAIN: EPOCH 321: MSE: 0.008900927080751651, KL_LOSS: 4.020868858069868e-07


100%|██████████| 352/352 [00:03<00:00, 107.18it/s]


TRAIN: EPOCH 322: MSE: 0.008887064795337872, KL_LOSS: 3.99692177940639e-07


100%|██████████| 352/352 [00:03<00:00, 112.22it/s]


TRAIN: EPOCH 323: MSE: 0.008897765511805615, KL_LOSS: 4.075761291295004e-07


100%|██████████| 352/352 [00:03<00:00, 112.46it/s]


TRAIN: EPOCH 324: MSE: 0.008897030036843551, KL_LOSS: 4.045299004056863e-07


100%|██████████| 352/352 [00:03<00:00, 111.70it/s]


TRAIN: EPOCH 325: MSE: 0.008894481966060332, KL_LOSS: 4.0770324901987204e-07


100%|██████████| 352/352 [00:03<00:00, 113.22it/s]


TRAIN: EPOCH 326: MSE: 0.008896842282741669, KL_LOSS: 4.085596564900646e-07


100%|██████████| 352/352 [00:03<00:00, 113.13it/s]


TRAIN: EPOCH 327: MSE: 0.00890314243786799, KL_LOSS: 4.132268529779346e-07


100%|██████████| 352/352 [00:03<00:00, 112.08it/s]


TRAIN: EPOCH 328: MSE: 0.008893860154785216, KL_LOSS: 4.143378220724323e-07


100%|██████████| 352/352 [00:03<00:00, 111.88it/s]


TRAIN: EPOCH 329: MSE: 0.008881437212535688, KL_LOSS: 4.130282214828443e-07


100%|██████████| 352/352 [00:03<00:00, 113.46it/s]


TRAIN: EPOCH 330: MSE: 0.008902074732098052, KL_LOSS: 4.1190247269572533e-07


100%|██████████| 352/352 [00:03<00:00, 112.26it/s]


TRAIN: EPOCH 331: MSE: 0.0088961403433297, KL_LOSS: 4.1893689471030693e-07


100%|██████████| 352/352 [00:03<00:00, 111.84it/s]


TRAIN: EPOCH 332: MSE: 0.008892605370916002, KL_LOSS: 4.1475471114115303e-07


100%|██████████| 352/352 [00:03<00:00, 112.29it/s]


TRAIN: EPOCH 333: MSE: 0.008887924657425505, KL_LOSS: 4.1631930113407056e-07


100%|██████████| 352/352 [00:03<00:00, 112.76it/s]


TRAIN: EPOCH 334: MSE: 0.008898830532085743, KL_LOSS: 4.1578141027090704e-07


100%|██████████| 352/352 [00:03<00:00, 112.33it/s]


TRAIN: EPOCH 335: MSE: 0.008896963058346459, KL_LOSS: 4.185294148422775e-07


100%|██████████| 352/352 [00:03<00:00, 111.74it/s]


TRAIN: EPOCH 336: MSE: 0.008886079088287343, KL_LOSS: 4.2070357460359403e-07


100%|██████████| 352/352 [00:03<00:00, 109.28it/s]


TRAIN: EPOCH 337: MSE: 0.008864907395697876, KL_LOSS: 4.222131909547112e-07


100%|██████████| 352/352 [00:03<00:00, 112.00it/s]


TRAIN: EPOCH 338: MSE: 0.008886061709331856, KL_LOSS: 4.260742743673001e-07


100%|██████████| 352/352 [00:03<00:00, 109.82it/s]


TRAIN: EPOCH 339: MSE: 0.008888888400344347, KL_LOSS: 4.2392493870835324e-07


100%|██████████| 352/352 [00:03<00:00, 109.56it/s]


TRAIN: EPOCH 340: MSE: 0.008862080331627194, KL_LOSS: 4.25411956651477e-07


100%|██████████| 352/352 [00:03<00:00, 111.05it/s]


TRAIN: EPOCH 341: MSE: 0.008869771472580562, KL_LOSS: 4.2200277252455654e-07


100%|██████████| 352/352 [00:03<00:00, 115.80it/s]


TRAIN: EPOCH 342: MSE: 0.008878882598682221, KL_LOSS: 4.245762208880551e-07


100%|██████████| 352/352 [00:03<00:00, 116.08it/s]


TRAIN: EPOCH 343: MSE: 0.008865836577818052, KL_LOSS: 4.2528823501248264e-07


100%|██████████| 352/352 [00:03<00:00, 115.44it/s]


TRAIN: EPOCH 344: MSE: 0.008866436999100684, KL_LOSS: 4.2467636300487627e-07


100%|██████████| 352/352 [00:03<00:00, 115.65it/s]


TRAIN: EPOCH 345: MSE: 0.008867710073404438, KL_LOSS: 4.2300500758717485e-07


100%|██████████| 352/352 [00:03<00:00, 115.34it/s]


TRAIN: EPOCH 346: MSE: 0.008856344245776365, KL_LOSS: 4.268554736341912e-07


100%|██████████| 352/352 [00:03<00:00, 115.99it/s]


TRAIN: EPOCH 347: MSE: 0.008853028894835998, KL_LOSS: 4.256062913816033e-07


100%|██████████| 352/352 [00:03<00:00, 115.86it/s]


TRAIN: EPOCH 348: MSE: 0.00886451264060187, KL_LOSS: 4.2312650102874466e-07


100%|██████████| 352/352 [00:03<00:00, 115.95it/s]


TRAIN: EPOCH 349: MSE: 0.008865875053081916, KL_LOSS: 4.2056359744241883e-07


100%|██████████| 352/352 [00:03<00:00, 112.53it/s]


TRAIN: EPOCH 350: MSE: 0.008860249086865224, KL_LOSS: 4.2128704670251305e-07


100%|██████████| 352/352 [00:03<00:00, 115.43it/s]


TRAIN: EPOCH 351: MSE: 0.008865531147669324, KL_LOSS: 4.2072931885880654e-07


100%|██████████| 352/352 [00:03<00:00, 115.03it/s]


TRAIN: EPOCH 352: MSE: 0.008857742156992159, KL_LOSS: 4.245228831094168e-07


100%|██████████| 352/352 [00:03<00:00, 115.01it/s]


TRAIN: EPOCH 353: MSE: 0.008868334236798215, KL_LOSS: 4.289808444630177e-07


100%|██████████| 352/352 [00:03<00:00, 115.23it/s]


TRAIN: EPOCH 354: MSE: 0.008855338678007352, KL_LOSS: 4.2464429968318157e-07


100%|██████████| 352/352 [00:03<00:00, 114.95it/s]


TRAIN: EPOCH 355: MSE: 0.008924349744285626, KL_LOSS: 4.4645162064230443e-07


100%|██████████| 352/352 [00:03<00:00, 115.14it/s]


TRAIN: EPOCH 356: MSE: 0.008886306889259933, KL_LOSS: 4.269060310106583e-07


100%|██████████| 352/352 [00:03<00:00, 115.36it/s]


TRAIN: EPOCH 357: MSE: 0.008874293150835332, KL_LOSS: 4.339866060864107e-07


100%|██████████| 352/352 [00:03<00:00, 115.01it/s]


TRAIN: EPOCH 358: MSE: 0.008923604383281518, KL_LOSS: 4.3763979042730417e-07


100%|██████████| 352/352 [00:03<00:00, 115.03it/s]


TRAIN: EPOCH 359: MSE: 0.008852105768958361, KL_LOSS: 4.231264411170731e-07


100%|██████████| 352/352 [00:03<00:00, 114.92it/s]


TRAIN: EPOCH 360: MSE: 0.008839284916493025, KL_LOSS: 4.203227742426531e-07


100%|██████████| 352/352 [00:03<00:00, 113.98it/s]


TRAIN: EPOCH 361: MSE: 0.008837542405341414, KL_LOSS: 4.286285157917291e-07


100%|██████████| 352/352 [00:03<00:00, 111.04it/s]


TRAIN: EPOCH 362: MSE: 0.008851480014114217, KL_LOSS: 4.2914900707816367e-07


100%|██████████| 352/352 [00:03<00:00, 111.44it/s]


TRAIN: EPOCH 363: MSE: 0.008838904266171581, KL_LOSS: 4.249577842750021e-07


100%|██████████| 352/352 [00:03<00:00, 111.45it/s]


TRAIN: EPOCH 364: MSE: 0.008844507045895707, KL_LOSS: 4.2971440450401133e-07


100%|██████████| 352/352 [00:03<00:00, 111.35it/s]


TRAIN: EPOCH 365: MSE: 0.00886704222119244, KL_LOSS: 4.3504895492157464e-07


100%|██████████| 352/352 [00:03<00:00, 113.44it/s]


TRAIN: EPOCH 366: MSE: 0.00885315316025464, KL_LOSS: 4.335806659160436e-07


100%|██████████| 352/352 [00:03<00:00, 115.75it/s]


TRAIN: EPOCH 367: MSE: 0.008836669992888346, KL_LOSS: 4.305115850074385e-07


100%|██████████| 352/352 [00:03<00:00, 115.71it/s]


TRAIN: EPOCH 368: MSE: 0.008834077725292776, KL_LOSS: 4.282795564656474e-07


100%|██████████| 352/352 [00:03<00:00, 115.91it/s]


TRAIN: EPOCH 369: MSE: 0.00883047958169217, KL_LOSS: 4.2942348731468305e-07


100%|██████████| 352/352 [00:03<00:00, 115.58it/s]


TRAIN: EPOCH 370: MSE: 0.00882360514184587, KL_LOSS: 4.3326552955459296e-07


100%|██████████| 352/352 [00:03<00:00, 107.36it/s]


TRAIN: EPOCH 371: MSE: 0.008821819404891523, KL_LOSS: 4.298043785927968e-07


100%|██████████| 352/352 [00:03<00:00, 105.48it/s]


TRAIN: EPOCH 372: MSE: 0.00881829129701311, KL_LOSS: 4.277628942778671e-07


100%|██████████| 352/352 [00:03<00:00, 105.32it/s]


TRAIN: EPOCH 373: MSE: 0.008822167575338179, KL_LOSS: 4.3346687080129825e-07


100%|██████████| 352/352 [00:03<00:00, 105.75it/s]


TRAIN: EPOCH 374: MSE: 0.008822372939903289, KL_LOSS: 4.337907101808531e-07


100%|██████████| 352/352 [00:03<00:00, 105.50it/s]


TRAIN: EPOCH 375: MSE: 0.008835149300549265, KL_LOSS: 4.38165776469904e-07


100%|██████████| 352/352 [00:03<00:00, 106.10it/s]


TRAIN: EPOCH 376: MSE: 0.00883279890562831, KL_LOSS: 4.3939614040185276e-07


100%|██████████| 352/352 [00:03<00:00, 105.94it/s]


TRAIN: EPOCH 377: MSE: 0.00882734469698996, KL_LOSS: 4.464607185768314e-07


100%|██████████| 352/352 [00:03<00:00, 105.55it/s]


TRAIN: EPOCH 378: MSE: 0.008838681735638105, KL_LOSS: 4.461949142286027e-07


100%|██████████| 352/352 [00:03<00:00, 105.56it/s]


TRAIN: EPOCH 379: MSE: 0.008838294239301997, KL_LOSS: 4.448537394613988e-07


100%|██████████| 352/352 [00:03<00:00, 106.34it/s]


TRAIN: EPOCH 380: MSE: 0.008815672052812508, KL_LOSS: 4.4292145636019725e-07


 18%|█▊        | 64/352 [00:00<00:02, 106.47it/s]


KeyboardInterrupt: 

In [10]:
df_to_be_shown=encode(autoencoder,dataset,device)

  return {'x':torch.tensor(self.data.X[idx]),'c':torch.tensor(self.data.obs.iloc[idx]['core_control'])}
100%|██████████| 11258/11258 [00:08<00:00, 1371.61it/s]


### We get the cosine similarity of every two perturabation

In [11]:
cos_sim_f=cosine_similarity(np.array(df_to_be_shown.drop(['control'], axis=1)))

### Now we want to know which two perturabations are similar

In [12]:
similarity_matrix=np.zeros(cos_sim_f.shape)
similarity_db=hu_data_loader()

for gene_name in tqdm.tqdm(adata_orig.obs.gene_name.unique()):
    query=query_hu_data(similarity_db,gene_name)
    for q in query:
        if q in adata_orig.obs.gene_name.values:
            y_indices=adata_orig.obs[adata_orig.obs.gene_name==q].id
            x_indices=adata_orig.obs[adata_orig.obs.gene_name==gene_name].id
            for x_id in x_indices:
                for y_id in y_indices:
                    similarity_matrix[y_id,x_id]=1
                    similarity_matrix[x_id,y_id]=1

cos_sim_f_flatten=cos_sim_f.reshape(-1,)
similarity_matrix_flatten=similarity_matrix.reshape(-1,)
cos_sim_f_flatten1=cos_sim_f_flatten[similarity_matrix_flatten==1]
cos_sim_f_flatten0=cos_sim_f_flatten[similarity_matrix_flatten==0]

File downloaded successfully to humap2_complexes_20200809.txt


100%|██████████| 9867/9867 [10:09<00:00, 16.20it/s]  


### We want to visualize the value of recall with respect to different quantiles as thresholds for similarities

In [15]:
def get_recall(rate):
    qrate_down=np.quantile(cos_sim_f_flatten,rate)
    qrate_up=np.quantile(cos_sim_f_flatten,1-rate)
    pred_p=np.logical_or(cos_sim_f_flatten>qrate_up,cos_sim_f_flatten<qrate_down)
    pred_n=np.logical_and(cos_sim_f_flatten<qrate_up,cos_sim_f_flatten>qrate_down)
    tp=np.logical_and(pred_p,similarity_matrix_flatten==1).sum()
    fp=np.logical_and(pred_p,similarity_matrix_flatten==0).sum()
    fn=np.logical_and(pred_n,similarity_matrix_flatten==1).sum()
    return tp/(tp+fn)
def visualize_recal_vs_quantile():
    values=[]
    xs=[i*0.05 for i in range(10)]
    for i in xs:
        values.append(get_recall(i))
    temp_df=pd.DataFrame({'quantile':xs,'recall':values})
    fig=px.line(temp_df,x='quantile',y='recall',title='recall_vs_quantile',width=1000, height=400)
    fig.update_traces(mode='lines+text', text=list(map(lambda x:round(x,2),values)), textposition='top center')
    fig.update_layout(
    font=dict(
        family="Arial, sans-serif",
        size=10,  # Set the desired font size
        color="black"
    )
)
    fig.show()
    

In [16]:
visualize_recal_vs_quantile()

### Okay. now we plot the distributions of similarities divided into classes. first, those pairs that are already know to be similar. Second those that are not.

In [None]:
cos_sim_f_flatten1

Unnamed: 0,correlations
0,-0.055764
1,-0.435485
2,0.935924
3,0.010324
4,0.725961
...,...
2043,0.726085
2044,-0.377220
2045,0.592346
2046,-0.030000


In [None]:
choice=np.random.choice(cos_sim_f_flatten1.shape[0], 2048)
cos_sim_f_flatten1=cos_sim_f_flatten1[choice]
cos_sim_f_flatten1=pd.DataFrame(cos_sim_f_flatten1,columns=['correlations'])
fig=px.violin(cos_sim_f_flatten1, y='correlations',width=500, height=400,title="SIMILARS")
fig.show()


choice=np.random.choice(cos_sim_f_flatten0.shape[0], 2048)
cos_sim_f_flatten0=cos_sim_f_flatten0[choice]
cos_sim_f_flatten0=pd.DataFrame(cos_sim_f_flatten0,columns=['correlations'])
fig=px.violin(cos_sim_f_flatten0, y='correlations',width=500, height=400,title="Not SIMILARS")
fig.show()

In [None]:
print("Not SIMILARS MEAN:",cos_sim_f_flatten0.mean())
print("SIMILARS MEAN:",cos_sim_f_flatten1.mean())

Not SIMILARS MEAN: correlations    0.015807
dtype: float32
SIMILARS MEAN: correlations    0.175297
dtype: float32


### Here is the visualiztion of feature vectors

In [None]:
fig=px.scatter(df_to_be_shown,x='f0',y='f1',color='control',width=500, height=400)
fig.show()
fig=px.scatter(df_to_be_shown,x='f2',y='f3',color='control',width=500, height=400)
fig.show()
fig=px.scatter(df_to_be_shown,x='f4',y='f5',color='control',width=500, height=400)
fig.show()











