In [1]:
# Load IPython's autoreload extension so that edits to imported project
# modules are picked up without restarting the kernel.
%load_ext autoreload
# Alternative to the line above when the extension is already loaded:
# %reload_ext autoreload
# Mode 2: reload modules automatically before each cell is executed.
%autoreload 2


In [2]:
from methylVA.mnist.dataset import get_methyl_data_loaders

# Identifier embedded in the filtered-data file names below; it selects which
# pre-filtered feature set is loaded (its exact meaning is defined by the
# upstream filtering step — TODO confirm).
data_id = 0.05
batch_size = 128

# All four pickles live in the same directory: name it once instead of
# repeating the literal in every path.
data_dir = "../data/dimension_reduction/highly_variable_features"

train_data_path = f"{data_dir}/train_data_filtered_{data_id}.pkl"
train_metadata_path = f"{data_dir}/train_metadata_with_labels.pkl"
test_data_path = f"{data_dir}/test_data_filtered_{data_id}.pkl"
test_metadata_path = f"{data_dir}/test_metadata_with_labels.pkl"

# Build the train/test loaders from the data and metadata pickles.
# NOTE(review): this call prints "Found NaN values in the data after
# conversion." twice (presumably once per split) — confirm how the loader
# treats those NaNs before relying on the reported losses.
train_loader, test_loader = get_methyl_data_loaders(
    train_data_path,
    train_metadata_path,
    test_data_path,
    test_metadata_path,
    batch_size=batch_size,
)


Found NaN values in the data after conversion.
Found NaN values in the data after conversion.


In [3]:
# Pull one batch only to inspect the feature dimension; the second element of
# the batch (presumably the labels) is discarded.
data_batch, _ = next(iter(train_loader))

# Row counts come from the underlying datasets, not from the number of batches.
num_train_rows, num_test_rows = map(
    len, (train_loader.dataset, test_loader.dataset)
)

print("Number of features in each dataset:", data_batch.shape[1])
print("Number of rows in the training dataset:", num_train_rows)
print("Number of rows in the test dataset:", num_test_rows)

Number of features in each dataset: 30579
Number of rows in the training dataset: 33360
Number of rows in the test dataset: 3707


In [4]:
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

from methylVA.mnist.model import VAE
from methylVA.mnist.training import train, test

# --- Model architecture ---
# Input width is read off the batch sampled earlier rather than hard-coded.
input_dim = data_batch.shape[1]
hidden_dim = 2048
latent_dim = 32

# --- Optimisation ---
num_epochs = 100
learning_rate = 1e-3
weight_decay = 1e-2

# --- Loss ---
# NOTE(review): within the visible cells kl_weight is only used to build the
# run name below; it is not passed to train()/test() — confirm it actually
# reaches the loss.
kl_weight = 1.0

# Run identifier used as the experiment directory name for the log writers.
name = f'VAE_methyl_data_{data_id}_latent_{latent_dim}_kl_{kl_weight}'



In [5]:

# Compute the timestamp once so the train and test logs of the same run always
# share an identical directory name. Calling datetime.now() separately for each
# writer can straddle a second boundary and split one run across two stamps.
run_timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")

# Separate writers so train and test curves are logged under sibling
# directories of the same experiment.
writer_train = SummaryWriter(f'../experiments/{name}/train/{run_timestamp}')
writer_test = SummaryWriter(f'../experiments/{name}/test/{run_timestamp}')

In [6]:
# Seed torch's random number generators before the model is created so that
# weight initialisation (and later random draws from torch's global generator)
# repeat on Restart & Run All. This alone does not guarantee bit-identical
# results on GPU, where some kernels are non-deterministic.
seed = 42
torch.manual_seed(seed)

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Build the VAE with the configured sizes and move its parameters to `device`.
model = VAE(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim).to(device)

# AdamW applies weight decay decoupled from the gradient update.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [7]:
# Bare last expression: the cell output is the module's repr, i.e. the layer
# structure of the encoder and decoder.
model

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=30579, out_features=2048, bias=True)
    (1): SiLU()
    (2): Linear(in_features=2048, out_features=1024, bias=True)
    (3): SiLU()
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): SiLU()
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SiLU()
    (8): Linear(in_features=256, out_features=64, bias=True)
  )
  (softplus): Softplus(beta=1.0, threshold=20.0)
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): SiLU()
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): SiLU()
    (6): Linear(in_features=1024, out_features=2048, bias=True)
    (7): SiLU()
    (8): Linear(in_features=2048, out_features=30579, bias=True)
    (9): Sigmoid()
  )
)

In [None]:
from methylVA.mnist.training import train, test


# Counter handed to train() and replaced by its return value each epoch
# (presumably the global step, so logging continues across epoch boundaries).
prev_updates = 0
try:
    for epoch in range(num_epochs):
        print(f'Epoch {epoch + 1}/{num_epochs}')
        prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer_train)
        # Evaluate on the held-out set after every training epoch, logged at
        # the same step count as the end of that epoch.
        test(model, test_loader, prev_updates, writer=writer_test)
finally:
    # Push buffered TensorBoard events to disk even if the loop is interrupted
    # or raises, so a partial run is still inspectable. flush() (not close())
    # keeps the writers usable if this cell is re-run.
    writer_train.flush()
    writer_test.flush()

Epoch 1/100


  0%|          | 1/261 [00:00<01:56,  2.24it/s]

Step 0, (N samples: 0), Loss: 21200.0156, (Recon: 21196.5586, KLD: 3.4579), Gradient norm: 37.3112


 39%|███▉      | 102/261 [00:16<00:24,  6.47it/s]

Step 100, (N samples: 12,800), Loss: 16289.8965, (Recon: 16254.0186, KLD: 35.8778), Gradient norm: 2319.5700


 77%|███████▋  | 202/261 [00:31<00:09,  6.49it/s]

Step 200, (N samples: 25,600), Loss: 15100.0098, (Recon: 15058.5586, KLD: 41.4514), Gradient norm: 2870.6572


100%|██████████| 261/261 [00:40<00:00,  6.45it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.86it/s]


====> Test set loss: 15098.5752, (BCE: 15062.2726, KLD: 36.3025)
Epoch 2/100


 16%|█▌        | 41/261 [00:06<00:33,  6.50it/s]

Step 300, (N samples: 38,400), Loss: 15228.2764, (Recon: 15191.3359, KLD: 36.9402), Gradient norm: 1434.0466


 54%|█████▍    | 141/261 [00:21<00:18,  6.49it/s]

Step 400, (N samples: 51,200), Loss: 14710.9736, (Recon: 14671.6191, KLD: 39.3544), Gradient norm: 2204.1827


 92%|█████████▏| 241/261 [00:37<00:03,  6.49it/s]

Step 500, (N samples: 64,000), Loss: 14711.6113, (Recon: 14676.3555, KLD: 35.2556), Gradient norm: 1922.0993


100%|██████████| 261/261 [00:40<00:00,  6.45it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.91it/s]


====> Test set loss: 14771.5330, (BCE: 14737.1849, KLD: 34.3480)
Epoch 3/100


 31%|███       | 80/261 [00:12<00:28,  6.46it/s]

Step 600, (N samples: 76,800), Loss: 14679.0947, (Recon: 14644.3359, KLD: 34.7591), Gradient norm: 1361.1615


 69%|██████▉   | 180/261 [00:27<00:12,  6.49it/s]

Step 700, (N samples: 89,600), Loss: 14888.4512, (Recon: 14858.7051, KLD: 29.7460), Gradient norm: 1340.6814


100%|██████████| 261/261 [00:40<00:00,  6.50it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.86it/s]


====> Test set loss: 14617.3708, (BCE: 14587.0468, KLD: 30.3240)
Epoch 4/100


  7%|▋         | 19/261 [00:02<00:37,  6.48it/s]

Step 800, (N samples: 102,400), Loss: 14742.1992, (Recon: 14712.4688, KLD: 29.7300), Gradient norm: 1701.4440


 46%|████▌     | 119/261 [00:18<00:22,  6.44it/s]

Step 900, (N samples: 115,200), Loss: 14208.9551, (Recon: 14175.7383, KLD: 33.2164), Gradient norm: 1260.0564


 84%|████████▍ | 219/261 [00:33<00:06,  6.48it/s]

Step 1,000, (N samples: 128,000), Loss: 14624.5791, (Recon: 14594.8867, KLD: 29.6925), Gradient norm: 1233.6325


100%|██████████| 261/261 [00:40<00:00,  6.49it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.85it/s]


====> Test set loss: 14538.4405, (BCE: 14509.7443, KLD: 28.6962)
Epoch 5/100


 22%|██▏       | 58/261 [00:08<00:31,  6.50it/s]

Step 1,100, (N samples: 140,800), Loss: 14567.2803, (Recon: 14537.9414, KLD: 29.3385), Gradient norm: 961.6088


 61%|██████    | 158/261 [00:24<00:15,  6.49it/s]

Step 1,200, (N samples: 153,600), Loss: 14936.1758, (Recon: 14907.7695, KLD: 28.4063), Gradient norm: 1391.5014


 99%|█████████▉| 258/261 [00:40<00:00,  6.47it/s]

Step 1,300, (N samples: 166,400), Loss: 13998.5293, (Recon: 13971.0439, KLD: 27.4849), Gradient norm: 1270.1564


100%|██████████| 261/261 [00:40<00:00,  6.45it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.87it/s]


====> Test set loss: 14485.8215, (BCE: 14459.1301, KLD: 26.6914)
Epoch 6/100


 37%|███▋      | 97/261 [00:14<00:25,  6.47it/s]

Step 1,400, (N samples: 179,200), Loss: 14625.8760, (Recon: 14599.0332, KLD: 26.8426), Gradient norm: 1056.5050


 75%|███████▌  | 197/261 [00:30<00:09,  6.51it/s]

Step 1,500, (N samples: 192,000), Loss: 14254.8340, (Recon: 14229.4180, KLD: 25.4163), Gradient norm: 1560.3576


100%|██████████| 261/261 [00:40<00:00,  6.51it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.85it/s]


====> Test set loss: 14450.7533, (BCE: 14425.2298, KLD: 25.5234)
Epoch 7/100


 14%|█▍        | 36/261 [00:05<00:34,  6.50it/s]

Step 1,600, (N samples: 204,800), Loss: 14407.5840, (Recon: 14382.8203, KLD: 24.7638), Gradient norm: 973.9191


 52%|█████▏    | 136/261 [00:20<00:19,  6.48it/s]

Step 1,700, (N samples: 217,600), Loss: 14421.2344, (Recon: 14396.0381, KLD: 25.1968), Gradient norm: 1593.5534


 90%|█████████ | 236/261 [00:36<00:03,  6.42it/s]

Step 1,800, (N samples: 230,400), Loss: 14296.9082, (Recon: 14272.3535, KLD: 24.5546), Gradient norm: 1090.9139


100%|██████████| 261/261 [00:40<00:00,  6.51it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.85it/s]


====> Test set loss: 14431.6342, (BCE: 14407.2107, KLD: 24.4235)
Epoch 8/100


 29%|██▊       | 75/261 [00:11<00:28,  6.46it/s]

Step 1,900, (N samples: 243,200), Loss: 14415.4160, (Recon: 14391.0537, KLD: 24.3619), Gradient norm: 1013.2750


 67%|██████▋   | 175/261 [00:27<00:13,  6.48it/s]

Step 2,000, (N samples: 256,000), Loss: 14268.1631, (Recon: 14242.8076, KLD: 25.3550), Gradient norm: 1112.4516


100%|██████████| 261/261 [00:40<00:00,  6.45it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.89it/s]


====> Test set loss: 14404.8281, (BCE: 14380.7396, KLD: 24.0885)
Epoch 9/100


  5%|▌         | 14/261 [00:02<00:38,  6.37it/s]

Step 2,100, (N samples: 268,800), Loss: 14169.2568, (Recon: 14140.6523, KLD: 28.6044), Gradient norm: 1493.3987


 44%|████▎     | 114/261 [00:17<00:22,  6.47it/s]

Step 2,200, (N samples: 281,600), Loss: 14160.8271, (Recon: 14136.8320, KLD: 23.9950), Gradient norm: 1241.2682


 82%|████████▏ | 214/261 [00:32<00:07,  6.42it/s]

Step 2,300, (N samples: 294,400), Loss: 14646.7812, (Recon: 14623.7539, KLD: 23.0278), Gradient norm: 710.2986


100%|██████████| 261/261 [00:40<00:00,  6.50it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.85it/s]


====> Test set loss: 14389.8270, (BCE: 14366.4934, KLD: 23.3336)
Epoch 10/100


 20%|██        | 53/261 [00:08<00:32,  6.48it/s]

Step 2,400, (N samples: 307,200), Loss: 14511.2598, (Recon: 14487.7871, KLD: 23.4723), Gradient norm: 1299.9365


 59%|█████▊    | 153/261 [00:23<00:16,  6.42it/s]

Step 2,500, (N samples: 320,000), Loss: 14415.5068, (Recon: 14391.7305, KLD: 23.7761), Gradient norm: 1254.2073


 97%|█████████▋| 253/261 [00:38<00:01,  6.48it/s]

Step 2,600, (N samples: 332,800), Loss: 14340.1465, (Recon: 14316.1660, KLD: 23.9804), Gradient norm: 1172.1759


100%|██████████| 261/261 [00:40<00:00,  6.49it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.82it/s]


====> Test set loss: 14373.3942, (BCE: 14349.7942, KLD: 23.6000)
Epoch 11/100


 35%|███▌      | 92/261 [00:14<00:26,  6.47it/s]

Step 2,700, (N samples: 345,600), Loss: 14028.0791, (Recon: 14005.0625, KLD: 23.0170), Gradient norm: 1402.8624


 74%|███████▎  | 192/261 [00:29<00:10,  6.45it/s]

Step 2,800, (N samples: 358,400), Loss: 14211.0254, (Recon: 14187.6025, KLD: 23.4228), Gradient norm: 1306.8476


100%|██████████| 261/261 [00:40<00:00,  6.44it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.90it/s]


====> Test set loss: 14359.7331, (BCE: 14336.3708, KLD: 23.3623)
Epoch 12/100


 12%|█▏        | 31/261 [00:04<00:35,  6.48it/s]

Step 2,900, (N samples: 371,200), Loss: 14290.2568, (Recon: 14264.1689, KLD: 26.0876), Gradient norm: 1069.3081


 50%|█████     | 131/261 [00:20<00:20,  6.48it/s]

Step 3,000, (N samples: 384,000), Loss: 14149.4199, (Recon: 14126.2031, KLD: 23.2163), Gradient norm: 954.2256


 89%|████████▊ | 231/261 [00:35<00:04,  6.50it/s]

Step 3,100, (N samples: 396,800), Loss: 14247.5605, (Recon: 14223.6729, KLD: 23.8879), Gradient norm: 1410.9940


100%|██████████| 261/261 [00:40<00:00,  6.52it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.85it/s]


====> Test set loss: 14351.7991, (BCE: 14328.1671, KLD: 23.6319)
Epoch 13/100


 27%|██▋       | 70/261 [00:10<00:29,  6.49it/s]

Step 3,200, (N samples: 409,600), Loss: 14217.9824, (Recon: 14193.4004, KLD: 24.5818), Gradient norm: 920.8006


 65%|██████▌   | 170/261 [00:26<00:13,  6.51it/s]

Step 3,300, (N samples: 422,400), Loss: 14502.3086, (Recon: 14480.0312, KLD: 22.2775), Gradient norm: 1129.3687


100%|██████████| 261/261 [00:40<00:00,  6.51it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.85it/s]


====> Test set loss: 14345.9541, (BCE: 14322.6302, KLD: 23.3240)
Epoch 14/100


  3%|▎         | 9/261 [00:01<00:39,  6.45it/s]

Step 3,400, (N samples: 435,200), Loss: 14476.5254, (Recon: 14451.5732, KLD: 24.9517), Gradient norm: 1048.5199


 42%|████▏     | 109/261 [00:16<00:23,  6.47it/s]

Step 3,500, (N samples: 448,000), Loss: 14320.5195, (Recon: 14295.0430, KLD: 25.4768), Gradient norm: 1400.2084


 80%|████████  | 209/261 [00:32<00:08,  6.50it/s]

Step 3,600, (N samples: 460,800), Loss: 14291.1650, (Recon: 14266.4951, KLD: 24.6700), Gradient norm: 1404.3489


100%|██████████| 261/261 [00:40<00:00,  6.51it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.92it/s]


====> Test set loss: 14325.3167, (BCE: 14302.3831, KLD: 22.9337)
Epoch 15/100


 18%|█▊        | 48/261 [00:07<00:32,  6.49it/s]

Step 3,700, (N samples: 473,600), Loss: 14522.2500, (Recon: 14499.0674, KLD: 23.1822), Gradient norm: 1648.7490


 57%|█████▋    | 148/261 [00:22<00:17,  6.52it/s]

Step 3,800, (N samples: 486,400), Loss: 14376.1650, (Recon: 14353.1953, KLD: 22.9697), Gradient norm: 1099.8560


 95%|█████████▌| 248/261 [00:38<00:01,  6.51it/s]

Step 3,900, (N samples: 499,200), Loss: 14003.1650, (Recon: 13978.4141, KLD: 24.7513), Gradient norm: 1659.1760


100%|██████████| 261/261 [00:40<00:00,  6.51it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.92it/s]


====> Test set loss: 14322.2313, (BCE: 14298.9269, KLD: 23.3044)
Epoch 16/100


 33%|███▎      | 87/261 [00:13<00:26,  6.45it/s]

Step 4,000, (N samples: 512,000), Loss: 14180.3564, (Recon: 14156.7646, KLD: 23.5919), Gradient norm: 1037.3952


 72%|███████▏  | 187/261 [00:28<00:11,  6.45it/s]

Step 4,100, (N samples: 524,800), Loss: 14618.0840, (Recon: 14594.1621, KLD: 23.9222), Gradient norm: 1277.1258


100%|██████████| 261/261 [00:40<00:00,  6.50it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.82it/s]


====> Test set loss: 14317.7447, (BCE: 14294.1358, KLD: 23.6089)
Epoch 17/100


 10%|▉         | 26/261 [00:04<00:36,  6.47it/s]

Step 4,200, (N samples: 537,600), Loss: 14174.1152, (Recon: 14151.1895, KLD: 22.9262), Gradient norm: 1212.4997


 48%|████▊     | 126/261 [00:19<00:20,  6.49it/s]

Step 4,300, (N samples: 550,400), Loss: 14724.9600, (Recon: 14700.5215, KLD: 24.4386), Gradient norm: 1113.9334


 87%|████████▋ | 226/261 [00:34<00:05,  6.38it/s]

Step 4,400, (N samples: 563,200), Loss: 13864.9658, (Recon: 13839.6191, KLD: 25.3471), Gradient norm: 1697.1073


100%|██████████| 261/261 [00:40<00:00,  6.49it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.81it/s]


====> Test set loss: 14308.4990, (BCE: 14285.2694, KLD: 23.2295)
Epoch 18/100


 25%|██▍       | 65/261 [00:10<00:30,  6.49it/s]

Step 4,500, (N samples: 576,000), Loss: 14346.5771, (Recon: 14323.3330, KLD: 23.2437), Gradient norm: 1579.0071


 63%|██████▎   | 165/261 [00:25<00:14,  6.50it/s]

Step 4,600, (N samples: 588,800), Loss: 13986.8789, (Recon: 13963.0693, KLD: 23.8101), Gradient norm: 879.3552


100%|██████████| 261/261 [00:40<00:00,  6.52it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.93it/s]


====> Test set loss: 14316.4295, (BCE: 14293.3835, KLD: 23.0459)
Epoch 19/100


  2%|▏         | 4/261 [00:00<00:41,  6.18it/s]

Step 4,700, (N samples: 601,600), Loss: 14151.2314, (Recon: 14128.5840, KLD: 22.6476), Gradient norm: 1635.2482


 40%|███▉      | 104/261 [00:15<00:24,  6.50it/s]

Step 4,800, (N samples: 614,400), Loss: 14185.4902, (Recon: 14161.9824, KLD: 23.5075), Gradient norm: 1179.8525


 78%|███████▊  | 204/261 [00:31<00:08,  6.49it/s]

Step 4,900, (N samples: 627,200), Loss: 14552.5107, (Recon: 14530.0820, KLD: 22.4285), Gradient norm: 812.8032


100%|██████████| 261/261 [00:40<00:00,  6.51it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.87it/s]


====> Test set loss: 14305.5335, (BCE: 14282.5684, KLD: 22.9652)
Epoch 20/100


 16%|█▋        | 43/261 [00:06<00:33,  6.48it/s]

Step 5,000, (N samples: 640,000), Loss: 14407.0908, (Recon: 14384.0273, KLD: 23.0635), Gradient norm: 1528.0885


 55%|█████▍    | 143/261 [00:21<00:18,  6.49it/s]

Step 5,100, (N samples: 652,800), Loss: 14346.7852, (Recon: 14323.0918, KLD: 23.6933), Gradient norm: 1170.3813


 93%|█████████▎| 243/261 [00:37<00:02,  6.49it/s]

Step 5,200, (N samples: 665,600), Loss: 13973.0977, (Recon: 13949.0322, KLD: 24.0650), Gradient norm: 1058.3553


100%|██████████| 261/261 [00:40<00:00,  6.51it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.88it/s]


====> Test set loss: 14304.0950, (BCE: 14281.3221, KLD: 22.7728)
Epoch 21/100


 31%|███▏      | 82/261 [00:12<00:27,  6.49it/s]

Step 5,300, (N samples: 678,400), Loss: 14401.4014, (Recon: 14378.1406, KLD: 23.2609), Gradient norm: 1605.3578


 70%|██████▉   | 182/261 [00:27<00:12,  6.51it/s]

Step 5,400, (N samples: 691,200), Loss: 14220.4697, (Recon: 14198.1016, KLD: 22.3684), Gradient norm: 1080.0271


100%|██████████| 261/261 [00:40<00:00,  6.43it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.89it/s]


====> Test set loss: 14298.9077, (BCE: 14276.0905, KLD: 22.8173)
Epoch 22/100


  8%|▊         | 21/261 [00:03<00:37,  6.48it/s]

Step 5,500, (N samples: 704,000), Loss: 14432.8359, (Recon: 14411.3301, KLD: 21.5063), Gradient norm: 1409.5625


 46%|████▋     | 121/261 [00:19<00:21,  6.42it/s]

Step 5,600, (N samples: 716,800), Loss: 14275.2354, (Recon: 14251.7480, KLD: 23.4873), Gradient norm: 1357.4637


 85%|████████▍ | 221/261 [00:36<00:06,  6.49it/s]

Step 5,700, (N samples: 729,600), Loss: 14358.3105, (Recon: 14335.7656, KLD: 22.5450), Gradient norm: 936.2809


100%|██████████| 261/261 [00:42<00:00,  6.20it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.86it/s]


====> Test set loss: 14293.9190, (BCE: 14270.6634, KLD: 23.2556)
Epoch 23/100


 23%|██▎       | 60/261 [00:09<00:31,  6.44it/s]

Step 5,800, (N samples: 742,400), Loss: 14199.1738, (Recon: 14176.2324, KLD: 22.9413), Gradient norm: 1442.9929


 61%|██████▏   | 160/261 [00:25<00:15,  6.48it/s]

Step 5,900, (N samples: 755,200), Loss: 14097.7686, (Recon: 14074.7012, KLD: 23.0671), Gradient norm: 1197.0034


100%|█████████▉| 260/261 [00:41<00:00,  6.33it/s]

Step 6,000, (N samples: 768,000), Loss: 14146.7002, (Recon: 14124.0371, KLD: 22.6628), Gradient norm: 1482.5997


100%|██████████| 261/261 [00:41<00:00,  6.25it/s]
Testing: 100%|██████████| 29/29 [00:03<00:00,  8.93it/s]


====> Test set loss: 14293.5010, (BCE: 14270.2527, KLD: 23.2482)
Epoch 24/100


 38%|███▊      | 99/261 [00:19<00:25,  6.35it/s]

Step 6,100, (N samples: 780,800), Loss: 14121.5742, (Recon: 14099.0703, KLD: 22.5043), Gradient norm: 864.9671


 76%|███████▌  | 199/261 [00:35<00:09,  6.50it/s]

Step 6,200, (N samples: 793,600), Loss: 14520.1699, (Recon: 14497.0762, KLD: 23.0941), Gradient norm: 1193.9546


100%|██████████| 261/261 [00:45<00:00,  5.79it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.86it/s]


====> Test set loss: 14287.6251, (BCE: 14264.5550, KLD: 23.0701)
Epoch 25/100


 15%|█▍        | 38/261 [00:05<00:34,  6.46it/s]

Step 6,300, (N samples: 806,400), Loss: 13667.1270, (Recon: 13642.4590, KLD: 24.6680), Gradient norm: 1486.8700


 53%|█████▎    | 138/261 [00:21<00:19,  6.45it/s]

Step 6,400, (N samples: 819,200), Loss: 14323.4717, (Recon: 14300.0879, KLD: 23.3841), Gradient norm: 1472.6873


 91%|█████████ | 237/261 [00:37<00:03,  6.41it/s]

Step 6,500, (N samples: 832,000), Loss: 13993.5029, (Recon: 13970.3174, KLD: 23.1852), Gradient norm: 1306.9491


100%|██████████| 261/261 [00:42<00:00,  6.12it/s]
Testing: 100%|██████████| 29/29 [00:02<00:00,  9.86it/s]


====> Test set loss: 14281.1640, (BCE: 14258.0865, KLD: 23.0776)
Epoch 26/100


 30%|██▉       | 77/261 [00:11<00:28,  6.49it/s]

Step 6,600, (N samples: 844,800), Loss: 14308.9922, (Recon: 14286.3730, KLD: 22.6189), Gradient norm: 1295.1746


 68%|██████▊   | 177/261 [00:27<00:12,  6.49it/s]

Step 6,700, (N samples: 857,600), Loss: 14372.8135, (Recon: 14349.2793, KLD: 23.5341), Gradient norm: 1615.4812


100%|██████████| 261/261 [00:40<00:00,  6.45it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.92it/s]


====> Test set loss: 14289.5816, (BCE: 14266.7221, KLD: 22.8594)
Epoch 27/100


  6%|▌         | 16/261 [00:02<00:38,  6.40it/s]

Step 6,800, (N samples: 870,400), Loss: 13813.7041, (Recon: 13788.9990, KLD: 24.7048), Gradient norm: 1609.9620


 44%|████▍     | 115/261 [00:24<00:38,  3.76it/s]

Step 6,900, (N samples: 883,200), Loss: 14214.3027, (Recon: 14191.5361, KLD: 22.7667), Gradient norm: 1578.5917


 83%|████████▎ | 216/261 [00:49<00:07,  6.18it/s]

Step 7,000, (N samples: 896,000), Loss: 14427.3857, (Recon: 14404.3398, KLD: 23.0461), Gradient norm: 1441.0975


100%|██████████| 261/261 [00:56<00:00,  4.63it/s]
Testing: 100%|██████████| 29/29 [00:03<00:00,  9.47it/s]


====> Test set loss: 14275.9235, (BCE: 14252.9022, KLD: 23.0212)
Epoch 28/100


 21%|██        | 55/261 [00:08<00:31,  6.47it/s]

Step 7,100, (N samples: 908,800), Loss: 14130.1318, (Recon: 14106.5742, KLD: 23.5574), Gradient norm: 1722.5246


 59%|█████▉    | 155/261 [00:23<00:16,  6.48it/s]

Step 7,200, (N samples: 921,600), Loss: 14017.7188, (Recon: 13994.5195, KLD: 23.1992), Gradient norm: 1196.7773


 98%|█████████▊| 255/261 [00:40<00:00,  6.31it/s]

Step 7,300, (N samples: 934,400), Loss: 14119.3975, (Recon: 14093.1328, KLD: 26.2647), Gradient norm: 1356.5771


100%|██████████| 261/261 [00:41<00:00,  6.34it/s]
Testing: 100%|██████████| 29/29 [00:03<00:00,  9.03it/s]


====> Test set loss: 14279.4518, (BCE: 14255.8902, KLD: 23.5616)
Epoch 29/100


 36%|███▌      | 93/261 [00:24<00:44,  3.74it/s]

Step 7,400, (N samples: 947,200), Loss: 14147.3418, (Recon: 14123.1797, KLD: 24.1621), Gradient norm: 953.2465


 74%|███████▍  | 193/261 [00:51<00:17,  3.98it/s]

Step 7,500, (N samples: 960,000), Loss: 14005.8799, (Recon: 13982.7266, KLD: 23.1530), Gradient norm: 1397.5509


100%|██████████| 261/261 [01:10<00:00,  3.73it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.02it/s]


====> Test set loss: 14272.7229, (BCE: 14249.2970, KLD: 23.4258)
Epoch 30/100


 12%|█▏        | 32/261 [00:08<01:02,  3.66it/s]

Step 7,600, (N samples: 972,800), Loss: 14000.2959, (Recon: 13977.3301, KLD: 22.9658), Gradient norm: 1026.8315


 51%|█████     | 132/261 [00:35<00:35,  3.66it/s]

Step 7,700, (N samples: 985,600), Loss: 14049.9688, (Recon: 14025.9951, KLD: 23.9739), Gradient norm: 1190.7928


 89%|████████▉ | 232/261 [01:03<00:07,  3.73it/s]

Step 7,800, (N samples: 998,400), Loss: 14641.8994, (Recon: 14619.7402, KLD: 22.1591), Gradient norm: 1272.2811


100%|██████████| 261/261 [01:10<00:00,  3.69it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.98it/s]


====> Test set loss: 14269.0412, (BCE: 14245.9806, KLD: 23.0606)
Epoch 31/100


 27%|██▋       | 71/261 [00:19<00:52,  3.61it/s]

Step 7,900, (N samples: 1,011,200), Loss: 14104.5000, (Recon: 14080.8428, KLD: 23.6575), Gradient norm: 1278.7938


 66%|██████▌   | 171/261 [00:46<00:24,  3.64it/s]

Step 8,000, (N samples: 1,024,000), Loss: 14450.0850, (Recon: 14426.6309, KLD: 23.4538), Gradient norm: 1147.4875


100%|██████████| 261/261 [01:10<00:00,  3.73it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.98it/s]


====> Test set loss: 14273.3968, (BCE: 14250.1104, KLD: 23.2864)
Epoch 32/100


  4%|▍         | 10/261 [00:02<01:07,  3.71it/s]

Step 8,100, (N samples: 1,036,800), Loss: 14225.1367, (Recon: 14202.1611, KLD: 22.9759), Gradient norm: 2001.6484


 42%|████▏     | 110/261 [00:30<00:42,  3.55it/s]

Step 8,200, (N samples: 1,049,600), Loss: 13777.8379, (Recon: 13754.4395, KLD: 23.3980), Gradient norm: 1984.8965


 80%|████████  | 210/261 [00:56<00:13,  3.70it/s]

Step 8,300, (N samples: 1,062,400), Loss: 14268.5195, (Recon: 14245.5840, KLD: 22.9351), Gradient norm: 1234.6162


100%|██████████| 261/261 [01:10<00:00,  3.71it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.94it/s]


====> Test set loss: 14263.8334, (BCE: 14240.4959, KLD: 23.3376)
Epoch 33/100


 19%|█▉        | 49/261 [00:13<00:57,  3.67it/s]

Step 8,400, (N samples: 1,075,200), Loss: 14421.2578, (Recon: 14399.0410, KLD: 22.2164), Gradient norm: 1404.6164


 57%|█████▋    | 149/261 [00:40<00:30,  3.71it/s]

Step 8,500, (N samples: 1,088,000), Loss: 14119.3975, (Recon: 14097.2969, KLD: 22.1007), Gradient norm: 927.8664


 95%|█████████▌| 249/261 [01:07<00:03,  3.60it/s]

Step 8,600, (N samples: 1,100,800), Loss: 13997.3730, (Recon: 13973.5996, KLD: 23.7736), Gradient norm: 1154.0133


100%|██████████| 261/261 [01:11<00:00,  3.67it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


====> Test set loss: 14271.2150, (BCE: 14247.8390, KLD: 23.3760)
Epoch 34/100


 34%|███▎      | 88/261 [00:24<00:47,  3.64it/s]

Step 8,700, (N samples: 1,113,600), Loss: 14342.1797, (Recon: 14318.1172, KLD: 24.0629), Gradient norm: 1464.0035


 72%|███████▏  | 188/261 [00:50<00:18,  4.01it/s]

Step 8,800, (N samples: 1,126,400), Loss: 14110.9512, (Recon: 14087.4580, KLD: 23.4930), Gradient norm: 1618.0514


100%|██████████| 261/261 [01:10<00:00,  3.72it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.98it/s]


====> Test set loss: 14265.7306, (BCE: 14242.1643, KLD: 23.5663)
Epoch 35/100


 10%|█         | 27/261 [00:07<01:04,  3.64it/s]

Step 8,900, (N samples: 1,139,200), Loss: 13985.7109, (Recon: 13962.9180, KLD: 22.7930), Gradient norm: 1178.5970


 49%|████▊     | 127/261 [00:34<00:36,  3.66it/s]

Step 9,000, (N samples: 1,152,000), Loss: 14173.6992, (Recon: 14150.5820, KLD: 23.1168), Gradient norm: 1282.8820


 87%|████████▋ | 227/261 [01:01<00:09,  3.73it/s]

Step 9,100, (N samples: 1,164,800), Loss: 13715.9229, (Recon: 13692.0605, KLD: 23.8625), Gradient norm: 1259.0526


100%|██████████| 261/261 [01:10<00:00,  3.69it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.98it/s]


====> Test set loss: 14263.6966, (BCE: 14240.5058, KLD: 23.1908)
Epoch 36/100


 25%|██▌       | 66/261 [00:17<00:52,  3.71it/s]

Step 9,200, (N samples: 1,177,600), Loss: 13990.7275, (Recon: 13965.8789, KLD: 24.8487), Gradient norm: 1145.6767


 64%|██████▎   | 166/261 [00:44<00:25,  3.73it/s]

Step 9,300, (N samples: 1,190,400), Loss: 14141.9717, (Recon: 14118.2402, KLD: 23.7319), Gradient norm: 914.9707


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.97it/s]


====> Test set loss: 14259.7860, (BCE: 14236.4116, KLD: 23.3743)
Epoch 37/100


  2%|▏         | 5/261 [00:01<01:09,  3.66it/s]

Step 9,400, (N samples: 1,203,200), Loss: 13910.6650, (Recon: 13886.6484, KLD: 24.0164), Gradient norm: 1115.9366


 40%|████      | 105/261 [00:28<00:42,  3.67it/s]

Step 9,500, (N samples: 1,216,000), Loss: 14373.8594, (Recon: 14349.5127, KLD: 24.3467), Gradient norm: 1637.7987


 79%|███████▊  | 205/261 [00:55<00:15,  3.68it/s]

Step 9,600, (N samples: 1,228,800), Loss: 14236.3252, (Recon: 14213.2520, KLD: 23.0730), Gradient norm: 1289.9335


100%|██████████| 261/261 [01:10<00:00,  3.68it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.06it/s]


====> Test set loss: 14259.0343, (BCE: 14235.2579, KLD: 23.7764)
Epoch 38/100


 17%|█▋        | 44/261 [00:11<00:58,  3.70it/s]

Step 9,700, (N samples: 1,241,600), Loss: 14593.1221, (Recon: 14569.5703, KLD: 23.5519), Gradient norm: 1420.7484


 55%|█████▌    | 144/261 [00:38<00:31,  3.67it/s]

Step 9,800, (N samples: 1,254,400), Loss: 14400.9932, (Recon: 14376.8301, KLD: 24.1635), Gradient norm: 1171.9926


 93%|█████████▎| 244/261 [01:05<00:04,  3.60it/s]

Step 9,900, (N samples: 1,267,200), Loss: 14291.8359, (Recon: 14267.1895, KLD: 24.6466), Gradient norm: 1722.7005


100%|██████████| 261/261 [01:10<00:00,  3.72it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.31it/s]


====> Test set loss: 14256.9523, (BCE: 14233.8163, KLD: 23.1360)
Epoch 39/100


 32%|███▏      | 83/261 [00:22<00:48,  3.70it/s]

Step 10,000, (N samples: 1,280,000), Loss: 14278.5879, (Recon: 14254.4922, KLD: 24.0959), Gradient norm: 1344.1665


 70%|███████   | 183/261 [00:49<00:19,  4.02it/s]

Step 10,100, (N samples: 1,292,800), Loss: 14469.4121, (Recon: 14447.0254, KLD: 22.3865), Gradient norm: 1405.0723


100%|██████████| 261/261 [01:10<00:00,  3.72it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.03it/s]


====> Test set loss: 14262.5562, (BCE: 14239.1589, KLD: 23.3974)
Epoch 40/100


  8%|▊         | 22/261 [00:05<01:06,  3.62it/s]

Step 10,200, (N samples: 1,305,600), Loss: 13847.9697, (Recon: 13824.5215, KLD: 23.4484), Gradient norm: 1294.6578


 47%|████▋     | 122/261 [00:32<00:37,  3.67it/s]

Step 10,300, (N samples: 1,318,400), Loss: 14396.9580, (Recon: 14373.2744, KLD: 23.6838), Gradient norm: 976.0204


 85%|████████▌ | 222/261 [00:59<00:10,  3.72it/s]

Step 10,400, (N samples: 1,331,200), Loss: 14177.9609, (Recon: 14154.1484, KLD: 23.8129), Gradient norm: 1284.9248


100%|██████████| 261/261 [01:09<00:00,  3.74it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.06it/s]


====> Test set loss: 14266.1942, (BCE: 14242.9602, KLD: 23.2340)
Epoch 41/100


 23%|██▎       | 61/261 [00:16<00:54,  3.65it/s]

Step 10,500, (N samples: 1,344,000), Loss: 14111.6016, (Recon: 14087.6504, KLD: 23.9515), Gradient norm: 1327.8374


 62%|██████▏   | 161/261 [00:43<00:27,  3.63it/s]

Step 10,600, (N samples: 1,356,800), Loss: 13721.2178, (Recon: 13697.8320, KLD: 23.3856), Gradient norm: 1270.0140


100%|██████████| 261/261 [01:09<00:00,  3.73it/s]


Step 10,700, (N samples: 1,369,600), Loss: 14361.0615, (Recon: 14338.0908, KLD: 22.9706), Gradient norm: 1776.7959


Testing: 100%|██████████| 29/29 [00:04<00:00,  6.01it/s]


====> Test set loss: 14256.1276, (BCE: 14232.9203, KLD: 23.2073)
Epoch 42/100


 38%|███▊      | 100/261 [00:26<00:43,  3.73it/s]

Step 10,800, (N samples: 1,382,400), Loss: 14146.3457, (Recon: 14124.1719, KLD: 22.1740), Gradient norm: 852.3873


 77%|███████▋  | 200/261 [00:53<00:16,  3.69it/s]

Step 10,900, (N samples: 1,395,200), Loss: 14238.7930, (Recon: 14213.7979, KLD: 24.9955), Gradient norm: 1234.0289


100%|██████████| 261/261 [01:09<00:00,  3.73it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.04it/s]


====> Test set loss: 14253.8900, (BCE: 14230.9149, KLD: 22.9751)
Epoch 43/100


 15%|█▍        | 39/261 [00:10<00:59,  3.75it/s]

Step 11,000, (N samples: 1,408,000), Loss: 14183.7402, (Recon: 14160.3789, KLD: 23.3609), Gradient norm: 1823.1391


 53%|█████▎    | 139/261 [00:37<00:32,  3.72it/s]

Step 11,100, (N samples: 1,420,800), Loss: 13963.3828, (Recon: 13939.5420, KLD: 23.8411), Gradient norm: 1966.3367


 92%|█████████▏| 239/261 [01:04<00:05,  3.68it/s]

Step 11,200, (N samples: 1,433,600), Loss: 13888.1074, (Recon: 13864.7109, KLD: 23.3960), Gradient norm: 1064.8830


100%|██████████| 261/261 [01:09<00:00,  3.73it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.22it/s]


====> Test set loss: 14256.1614, (BCE: 14232.9371, KLD: 23.2243)
Epoch 44/100


 30%|██▉       | 78/261 [00:20<00:49,  3.67it/s]

Step 11,300, (N samples: 1,446,400), Loss: 14346.7969, (Recon: 14323.2969, KLD: 23.4996), Gradient norm: 1543.8055


 68%|██████▊   | 178/261 [00:47<00:21,  3.86it/s]

Step 11,400, (N samples: 1,459,200), Loss: 14314.3027, (Recon: 14290.0156, KLD: 24.2871), Gradient norm: 1033.1370


100%|██████████| 261/261 [01:10<00:00,  3.72it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.08it/s]


====> Test set loss: 14256.6855, (BCE: 14233.5603, KLD: 23.1252)
Epoch 45/100


  7%|▋         | 17/261 [00:04<01:05,  3.72it/s]

Step 11,500, (N samples: 1,472,000), Loss: 13956.5850, (Recon: 13931.7832, KLD: 24.8013), Gradient norm: 1444.9546


 45%|████▍     | 117/261 [00:31<00:39,  3.68it/s]

Step 11,600, (N samples: 1,484,800), Loss: 14232.3389, (Recon: 14210.4961, KLD: 21.8429), Gradient norm: 863.8293


 83%|████████▎ | 217/261 [00:58<00:12,  3.66it/s]

Step 11,700, (N samples: 1,497,600), Loss: 14249.1621, (Recon: 14225.2793, KLD: 23.8830), Gradient norm: 2114.8857


100%|██████████| 261/261 [01:10<00:00,  3.73it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.00it/s]


====> Test set loss: 14257.1760, (BCE: 14233.9053, KLD: 23.2707)
Epoch 46/100


 21%|██▏       | 56/261 [00:14<00:56,  3.64it/s]

Step 11,800, (N samples: 1,510,400), Loss: 14093.8848, (Recon: 14070.6992, KLD: 23.1857), Gradient norm: 1448.0247


 60%|█████▉    | 156/261 [00:41<00:28,  3.73it/s]

Step 11,900, (N samples: 1,523,200), Loss: 14533.7861, (Recon: 14509.9326, KLD: 23.8538), Gradient norm: 2107.5083


 98%|█████████▊| 256/261 [01:08<00:01,  3.75it/s]

Step 12,000, (N samples: 1,536,000), Loss: 14225.2939, (Recon: 14201.8896, KLD: 23.4039), Gradient norm: 1408.5097


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.11it/s]


====> Test set loss: 14250.9887, (BCE: 14227.6738, KLD: 23.3150)
Epoch 47/100


 36%|███▋      | 95/261 [00:25<00:43,  3.79it/s]

Step 12,100, (N samples: 1,548,800), Loss: 14466.8936, (Recon: 14443.3389, KLD: 23.5551), Gradient norm: 1332.1281


 75%|███████▍  | 195/261 [00:52<00:17,  3.75it/s]

Step 12,200, (N samples: 1,561,600), Loss: 14159.4297, (Recon: 14136.4258, KLD: 23.0040), Gradient norm: 1601.9446


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.97it/s]


====> Test set loss: 14254.8863, (BCE: 14231.9349, KLD: 22.9514)
Epoch 48/100


 13%|█▎        | 34/261 [00:09<01:01,  3.68it/s]

Step 12,300, (N samples: 1,574,400), Loss: 14294.1934, (Recon: 14271.0889, KLD: 23.1043), Gradient norm: 1313.0302


 51%|█████▏    | 134/261 [00:35<00:33,  3.75it/s]

Step 12,400, (N samples: 1,587,200), Loss: 14458.5586, (Recon: 14434.9180, KLD: 23.6408), Gradient norm: 1547.5295


 90%|████████▉ | 234/261 [01:02<00:07,  3.70it/s]

Step 12,500, (N samples: 1,600,000), Loss: 13893.9375, (Recon: 13869.9102, KLD: 24.0274), Gradient norm: 1391.7645


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.20it/s]


====> Test set loss: 14252.4515, (BCE: 14229.1490, KLD: 23.3024)
Epoch 49/100


 28%|██▊       | 73/261 [00:19<00:51,  3.68it/s]

Step 12,600, (N samples: 1,612,800), Loss: 14326.8066, (Recon: 14302.6455, KLD: 24.1609), Gradient norm: 1331.6723


 66%|██████▋   | 173/261 [00:46<00:23,  3.76it/s]

Step 12,700, (N samples: 1,625,600), Loss: 14063.7012, (Recon: 14040.9355, KLD: 22.7653), Gradient norm: 1858.4011


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.96it/s]


====> Test set loss: 14264.3802, (BCE: 14240.8559, KLD: 23.5243)
Epoch 50/100


  5%|▍         | 12/261 [00:03<01:06,  3.73it/s]

Step 12,800, (N samples: 1,638,400), Loss: 13888.9102, (Recon: 13865.6543, KLD: 23.2558), Gradient norm: 1426.6598


 43%|████▎     | 112/261 [00:29<00:40,  3.65it/s]

Step 12,900, (N samples: 1,651,200), Loss: 14150.7930, (Recon: 14128.2275, KLD: 22.5655), Gradient norm: 1127.8444


 81%|████████  | 212/261 [00:56<00:13,  3.69it/s]

Step 13,000, (N samples: 1,664,000), Loss: 14325.3447, (Recon: 14301.1543, KLD: 24.1902), Gradient norm: 1980.5484


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.09it/s]


====> Test set loss: 14255.4893, (BCE: 14232.0448, KLD: 23.4445)
Epoch 51/100


 20%|█▉        | 51/261 [00:13<00:57,  3.68it/s]

Step 13,100, (N samples: 1,676,800), Loss: 14605.4980, (Recon: 14582.5977, KLD: 22.9005), Gradient norm: 1923.9494


 58%|█████▊    | 151/261 [00:40<00:29,  3.68it/s]

Step 13,200, (N samples: 1,689,600), Loss: 14255.4492, (Recon: 14232.9434, KLD: 22.5059), Gradient norm: 1523.5663


 96%|█████████▌| 251/261 [01:06<00:02,  3.74it/s]

Step 13,300, (N samples: 1,702,400), Loss: 14006.0000, (Recon: 13983.3066, KLD: 22.6932), Gradient norm: 1319.7534


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.04it/s]


====> Test set loss: 14253.3511, (BCE: 14229.7411, KLD: 23.6101)
Epoch 52/100


 34%|███▍      | 90/261 [00:24<00:46,  3.70it/s]

Step 13,400, (N samples: 1,715,200), Loss: 14534.2676, (Recon: 14509.3711, KLD: 24.8961), Gradient norm: 2271.1076


 73%|███████▎  | 190/261 [00:50<00:19,  3.71it/s]

Step 13,500, (N samples: 1,728,000), Loss: 14348.9199, (Recon: 14323.9180, KLD: 25.0016), Gradient norm: 1364.4062


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.02it/s]


====> Test set loss: 14253.9273, (BCE: 14230.3426, KLD: 23.5847)
Epoch 53/100


 11%|█         | 29/261 [00:07<01:02,  3.71it/s]

Step 13,600, (N samples: 1,740,800), Loss: 14102.3789, (Recon: 14078.4775, KLD: 23.9015), Gradient norm: 1857.3698


 49%|████▉     | 129/261 [00:34<00:36,  3.63it/s]

Step 13,700, (N samples: 1,753,600), Loss: 14224.3252, (Recon: 14200.6992, KLD: 23.6259), Gradient norm: 1514.9350


 88%|████████▊ | 229/261 [01:01<00:08,  3.65it/s]

Step 13,800, (N samples: 1,766,400), Loss: 13901.3164, (Recon: 13877.4229, KLD: 23.8934), Gradient norm: 1492.0340


100%|██████████| 261/261 [01:09<00:00,  3.74it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.27it/s]


====> Test set loss: 14251.0206, (BCE: 14227.7185, KLD: 23.3022)
Epoch 54/100


 26%|██▌       | 68/261 [00:18<00:53,  3.63it/s]

Step 13,900, (N samples: 1,779,200), Loss: 13731.1523, (Recon: 13706.3281, KLD: 24.8241), Gradient norm: 1404.7425


 64%|██████▍   | 168/261 [00:44<00:24,  3.76it/s]

Step 14,000, (N samples: 1,792,000), Loss: 13949.7744, (Recon: 13925.6689, KLD: 24.1052), Gradient norm: 1454.0230


100%|██████████| 261/261 [01:09<00:00,  3.77it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.08it/s]


====> Test set loss: 14254.0487, (BCE: 14230.6351, KLD: 23.4135)
Epoch 55/100


  3%|▎         | 7/261 [00:01<01:08,  3.73it/s]

Step 14,100, (N samples: 1,804,800), Loss: 14108.3242, (Recon: 14085.5762, KLD: 22.7481), Gradient norm: 1612.6305


 41%|████      | 107/261 [00:28<00:38,  4.03it/s]

Step 14,200, (N samples: 1,817,600), Loss: 14261.5479, (Recon: 14238.4102, KLD: 23.1381), Gradient norm: 1243.1281


 79%|███████▉  | 207/261 [00:55<00:14,  3.69it/s]

Step 14,300, (N samples: 1,830,400), Loss: 13907.7305, (Recon: 13884.2100, KLD: 23.5201), Gradient norm: 1559.2708


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.03it/s]


====> Test set loss: 14249.2663, (BCE: 14225.9043, KLD: 23.3621)
Epoch 56/100


 18%|█▊        | 46/261 [00:12<00:58,  3.66it/s]

Step 14,400, (N samples: 1,843,200), Loss: 14119.9150, (Recon: 14093.2705, KLD: 26.6450), Gradient norm: 1185.8496


 56%|█████▌    | 146/261 [00:38<00:31,  3.70it/s]

Step 14,500, (N samples: 1,856,000), Loss: 14482.8027, (Recon: 14459.4902, KLD: 23.3124), Gradient norm: 1722.8903


 94%|█████████▍| 246/261 [01:05<00:04,  3.68it/s]

Step 14,600, (N samples: 1,868,800), Loss: 14052.2441, (Recon: 14028.0645, KLD: 24.1795), Gradient norm: 1909.1122


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.11it/s]


====> Test set loss: 14247.2248, (BCE: 14224.1687, KLD: 23.0561)
Epoch 57/100


 33%|███▎      | 85/261 [00:22<00:47,  3.67it/s]

Step 14,700, (N samples: 1,881,600), Loss: 13970.1963, (Recon: 13946.4824, KLD: 23.7141), Gradient norm: 1100.4778


 71%|███████   | 185/261 [00:49<00:20,  3.67it/s]

Step 14,800, (N samples: 1,894,400), Loss: 14060.8896, (Recon: 14037.6680, KLD: 23.2220), Gradient norm: 1267.7920


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  5.96it/s]


====> Test set loss: 14253.4209, (BCE: 14230.2373, KLD: 23.1837)
Epoch 58/100


  9%|▉         | 24/261 [00:06<01:04,  3.69it/s]

Step 14,900, (N samples: 1,907,200), Loss: 13893.8691, (Recon: 13869.8760, KLD: 23.9929), Gradient norm: 1343.2397


 48%|████▊     | 124/261 [00:33<00:37,  3.68it/s]

Step 15,000, (N samples: 1,920,000), Loss: 14193.7852, (Recon: 14170.7773, KLD: 23.0078), Gradient norm: 1624.2335


 86%|████████▌ | 224/261 [01:00<00:10,  3.68it/s]

Step 15,100, (N samples: 1,932,800), Loss: 13699.0645, (Recon: 13675.5098, KLD: 23.5549), Gradient norm: 1039.6275


100%|██████████| 261/261 [01:10<00:00,  3.71it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.26it/s]


====> Test set loss: 14254.1111, (BCE: 14231.0585, KLD: 23.0525)
Epoch 59/100


 24%|██▍       | 63/261 [00:16<00:52,  3.75it/s]

Step 15,200, (N samples: 1,945,600), Loss: 14228.8379, (Recon: 14205.4648, KLD: 23.3727), Gradient norm: 1422.0349


 62%|██████▏   | 163/261 [00:43<00:26,  3.68it/s]

Step 15,300, (N samples: 1,958,400), Loss: 14166.3623, (Recon: 14143.0947, KLD: 23.2678), Gradient norm: 722.9296


100%|██████████| 261/261 [01:09<00:00,  3.73it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.11it/s]


====> Test set loss: 14254.4764, (BCE: 14231.0988, KLD: 23.3777)
Epoch 60/100


  1%|          | 2/261 [00:00<01:14,  3.49it/s]

Step 15,400, (N samples: 1,971,200), Loss: 13877.5781, (Recon: 13853.3994, KLD: 24.1786), Gradient norm: 1276.2764


 39%|███▉      | 102/261 [00:27<00:40,  3.95it/s]

Step 15,500, (N samples: 1,984,000), Loss: 14318.4746, (Recon: 14294.0156, KLD: 24.4591), Gradient norm: 1678.7824


 77%|███████▋  | 202/261 [00:53<00:15,  3.72it/s]

Step 15,600, (N samples: 1,996,800), Loss: 13911.3779, (Recon: 13888.6289, KLD: 22.7489), Gradient norm: 1952.9903


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.05it/s]


====> Test set loss: 14261.8477, (BCE: 14238.5530, KLD: 23.2947)
Epoch 61/100


 16%|█▌        | 41/261 [00:10<00:58,  3.74it/s]

Step 15,700, (N samples: 2,009,600), Loss: 13873.8926, (Recon: 13850.9443, KLD: 22.9483), Gradient norm: 2279.0559


 54%|█████▍    | 141/261 [00:37<00:33,  3.58it/s]

Step 15,800, (N samples: 2,022,400), Loss: 14111.6885, (Recon: 14088.6680, KLD: 23.0205), Gradient norm: 1328.3934


 92%|█████████▏| 241/261 [01:03<00:05,  3.70it/s]

Step 15,900, (N samples: 2,035,200), Loss: 14462.4268, (Recon: 14439.5771, KLD: 22.8494), Gradient norm: 923.2703


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.02it/s]


====> Test set loss: 14253.3440, (BCE: 14230.1985, KLD: 23.1454)
Epoch 62/100


 31%|███       | 80/261 [00:21<00:48,  3.74it/s]

Step 16,000, (N samples: 2,048,000), Loss: 14391.7451, (Recon: 14367.5215, KLD: 24.2232), Gradient norm: 1601.7434


 69%|██████▉   | 180/261 [00:47<00:22,  3.66it/s]

Step 16,100, (N samples: 2,060,800), Loss: 13958.6846, (Recon: 13934.9766, KLD: 23.7077), Gradient norm: 1601.2298


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.03it/s]


====> Test set loss: 14253.1067, (BCE: 14229.8567, KLD: 23.2499)
Epoch 63/100


  7%|▋         | 19/261 [00:05<01:04,  3.74it/s]

Step 16,200, (N samples: 2,073,600), Loss: 14178.6758, (Recon: 14155.5840, KLD: 23.0913), Gradient norm: 1137.3511


 46%|████▌     | 119/261 [00:31<00:38,  3.73it/s]

Step 16,300, (N samples: 2,086,400), Loss: 13961.7002, (Recon: 13938.8438, KLD: 22.8561), Gradient norm: 1245.8984


 84%|████████▍ | 219/261 [00:58<00:11,  3.71it/s]

Step 16,400, (N samples: 2,099,200), Loss: 14088.5166, (Recon: 14065.6660, KLD: 22.8506), Gradient norm: 1177.6980


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


====> Test set loss: 14250.6529, (BCE: 14227.5767, KLD: 23.0762)
Epoch 64/100


 22%|██▏       | 58/261 [00:15<00:54,  3.73it/s]

Step 16,500, (N samples: 2,112,000), Loss: 13782.0488, (Recon: 13758.6680, KLD: 23.3812), Gradient norm: 1611.3873


 61%|██████    | 158/261 [00:42<00:27,  3.74it/s]

Step 16,600, (N samples: 2,124,800), Loss: 14264.2568, (Recon: 14240.3564, KLD: 23.9003), Gradient norm: 1905.7249


 99%|█████████▉| 258/261 [01:08<00:00,  3.77it/s]

Step 16,700, (N samples: 2,137,600), Loss: 14189.0283, (Recon: 14166.3770, KLD: 22.6513), Gradient norm: 2005.8770


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.43it/s]


====> Test set loss: 14249.6367, (BCE: 14226.6681, KLD: 22.9685)
Epoch 65/100


 37%|███▋      | 97/261 [00:25<00:44,  3.72it/s]

Step 16,800, (N samples: 2,150,400), Loss: 14013.5166, (Recon: 13988.4141, KLD: 25.1022), Gradient norm: 2581.1423


 75%|███████▌  | 197/261 [00:52<00:17,  3.62it/s]

Step 16,900, (N samples: 2,163,200), Loss: 13771.0400, (Recon: 13747.6426, KLD: 23.3974), Gradient norm: 1589.9868


100%|██████████| 261/261 [01:09<00:00,  3.74it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


====> Test set loss: 14246.2657, (BCE: 14223.0503, KLD: 23.2154)
Epoch 66/100


 14%|█▍        | 36/261 [00:09<00:55,  4.06it/s]

Step 17,000, (N samples: 2,176,000), Loss: 14177.8428, (Recon: 14153.5625, KLD: 24.2806), Gradient norm: 1494.6502


 52%|█████▏    | 136/261 [00:35<00:31,  3.91it/s]

Step 17,100, (N samples: 2,188,800), Loss: 14018.6328, (Recon: 13993.1445, KLD: 25.4888), Gradient norm: 1909.3292


 90%|█████████ | 236/261 [01:02<00:06,  4.09it/s]

Step 17,200, (N samples: 2,201,600), Loss: 14065.5039, (Recon: 14042.5557, KLD: 22.9478), Gradient norm: 2616.8258


100%|██████████| 261/261 [01:09<00:00,  3.78it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


====> Test set loss: 14253.6578, (BCE: 14230.5587, KLD: 23.0990)
Epoch 67/100


 29%|██▊       | 75/261 [00:20<00:49,  3.75it/s]

Step 17,300, (N samples: 2,214,400), Loss: 14474.5068, (Recon: 14449.9785, KLD: 24.5285), Gradient norm: 1700.0057


 67%|██████▋   | 175/261 [00:46<00:23,  3.65it/s]

Step 17,400, (N samples: 2,227,200), Loss: 14259.1611, (Recon: 14235.4902, KLD: 23.6714), Gradient norm: 1370.4851


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.44it/s]


====> Test set loss: 14256.3647, (BCE: 14233.0445, KLD: 23.3202)
Epoch 68/100


  5%|▌         | 14/261 [00:03<01:08,  3.63it/s]

Step 17,500, (N samples: 2,240,000), Loss: 13919.9717, (Recon: 13895.7793, KLD: 24.1923), Gradient norm: 1893.3539


 44%|████▎     | 114/261 [00:30<00:39,  3.69it/s]

Step 17,600, (N samples: 2,252,800), Loss: 13711.9053, (Recon: 13689.3672, KLD: 22.5383), Gradient norm: 1699.0153


 82%|████████▏ | 214/261 [00:57<00:12,  3.70it/s]

Step 17,700, (N samples: 2,265,600), Loss: 14110.4648, (Recon: 14086.8105, KLD: 23.6540), Gradient norm: 3195.0337


100%|██████████| 261/261 [01:09<00:00,  3.74it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.61it/s]


====> Test set loss: 14250.9345, (BCE: 14227.8649, KLD: 23.0697)
Epoch 69/100


 20%|██        | 53/261 [00:14<00:56,  3.69it/s]

Step 17,800, (N samples: 2,278,400), Loss: 14350.6436, (Recon: 14326.3350, KLD: 24.3083), Gradient norm: 1586.1978


 59%|█████▊    | 153/261 [00:40<00:28,  3.76it/s]

Step 17,900, (N samples: 2,291,200), Loss: 14275.9736, (Recon: 14252.8896, KLD: 23.0838), Gradient norm: 1650.9620


 97%|█████████▋| 253/261 [01:07<00:02,  3.68it/s]

Step 18,000, (N samples: 2,304,000), Loss: 13892.9824, (Recon: 13869.5146, KLD: 23.4681), Gradient norm: 1313.8833


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


====> Test set loss: 14251.5883, (BCE: 14228.3540, KLD: 23.2342)
Epoch 70/100


 35%|███▌      | 92/261 [00:24<00:46,  3.65it/s]

Step 18,100, (N samples: 2,316,800), Loss: 14047.8242, (Recon: 14023.8184, KLD: 24.0058), Gradient norm: 1783.7122


 74%|███████▎  | 192/261 [00:51<00:18,  3.71it/s]

Step 18,200, (N samples: 2,329,600), Loss: 13721.9375, (Recon: 13697.7949, KLD: 24.1426), Gradient norm: 1513.4632


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


====> Test set loss: 14250.1835, (BCE: 14226.4897, KLD: 23.6938)
Epoch 71/100


 12%|█▏        | 31/261 [00:08<01:02,  3.69it/s]

Step 18,300, (N samples: 2,342,400), Loss: 14014.2881, (Recon: 13991.5215, KLD: 22.7669), Gradient norm: 1843.2087


 50%|█████     | 131/261 [00:34<00:34,  3.74it/s]

Step 18,400, (N samples: 2,355,200), Loss: 13773.8037, (Recon: 13752.0664, KLD: 21.7376), Gradient norm: 1302.2265


 89%|████████▊ | 231/261 [01:01<00:08,  3.72it/s]

Step 18,500, (N samples: 2,368,000), Loss: 14006.0762, (Recon: 13983.0137, KLD: 23.0626), Gradient norm: 1410.5518


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.51it/s]


====> Test set loss: 14257.6177, (BCE: 14234.4301, KLD: 23.1876)
Epoch 72/100


 27%|██▋       | 70/261 [00:18<00:51,  3.68it/s]

Step 18,600, (N samples: 2,380,800), Loss: 14340.1406, (Recon: 14316.3809, KLD: 23.7598), Gradient norm: 1182.0484


 65%|██████▌   | 170/261 [00:44<00:22,  4.07it/s]

Step 18,700, (N samples: 2,393,600), Loss: 14329.9326, (Recon: 14306.1055, KLD: 23.8274), Gradient norm: 1267.2236


100%|██████████| 261/261 [01:08<00:00,  3.79it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.42it/s]


====> Test set loss: 14254.0199, (BCE: 14230.9397, KLD: 23.0803)
Epoch 73/100


  3%|▎         | 9/261 [00:02<01:07,  3.72it/s]

Step 18,800, (N samples: 2,406,400), Loss: 13986.8604, (Recon: 13962.9922, KLD: 23.8681), Gradient norm: 1981.1237


 42%|████▏     | 109/261 [00:28<00:41,  3.65it/s]

Step 18,900, (N samples: 2,419,200), Loss: 14126.8721, (Recon: 14103.4883, KLD: 23.3834), Gradient norm: 1850.1107


 80%|████████  | 209/261 [00:55<00:14,  3.70it/s]

Step 19,000, (N samples: 2,432,000), Loss: 13633.0596, (Recon: 13609.8682, KLD: 23.1918), Gradient norm: 1454.5585


100%|██████████| 261/261 [01:09<00:00,  3.75it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


====> Test set loss: 14258.1053, (BCE: 14234.3119, KLD: 23.7934)
Epoch 74/100


 18%|█▊        | 48/261 [00:12<00:57,  3.72it/s]

Step 19,100, (N samples: 2,444,800), Loss: 13775.1768, (Recon: 13752.3076, KLD: 22.8693), Gradient norm: 2123.8233


 57%|█████▋    | 148/261 [00:39<00:30,  3.70it/s]

Step 19,200, (N samples: 2,457,600), Loss: 14209.7676, (Recon: 14185.0635, KLD: 24.7039), Gradient norm: 1392.2517


 95%|█████████▌| 248/261 [01:05<00:03,  3.71it/s]

Step 19,300, (N samples: 2,470,400), Loss: 14135.7588, (Recon: 14112.0840, KLD: 23.6744), Gradient norm: 1668.5635


100%|██████████| 261/261 [01:09<00:00,  3.77it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


====> Test set loss: 14260.1237, (BCE: 14236.5993, KLD: 23.5243)
Epoch 75/100


 33%|███▎      | 87/261 [00:23<00:47,  3.67it/s]

Step 19,400, (N samples: 2,483,200), Loss: 14061.6240, (Recon: 14038.5078, KLD: 23.1165), Gradient norm: 1535.1468


 72%|███████▏  | 187/261 [00:49<00:19,  3.73it/s]

Step 19,500, (N samples: 2,496,000), Loss: 14324.7266, (Recon: 14300.7852, KLD: 23.9419), Gradient norm: 1139.5938


100%|██████████| 261/261 [01:09<00:00,  3.77it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.45it/s]


====> Test set loss: 14249.8755, (BCE: 14226.0492, KLD: 23.8263)
Epoch 76/100


 10%|▉         | 26/261 [00:06<01:02,  3.75it/s]

Step 19,600, (N samples: 2,508,800), Loss: 14077.9941, (Recon: 14054.1367, KLD: 23.8572), Gradient norm: 1798.4630


 48%|████▊     | 126/261 [00:33<00:35,  3.76it/s]

Step 19,700, (N samples: 2,521,600), Loss: 14210.3633, (Recon: 14186.2373, KLD: 24.1257), Gradient norm: 1511.9112


 87%|████████▋ | 226/261 [00:59<00:09,  3.76it/s]

Step 19,800, (N samples: 2,534,400), Loss: 14076.2559, (Recon: 14052.3574, KLD: 23.8981), Gradient norm: 1607.3944


100%|██████████| 261/261 [01:08<00:00,  3.79it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.80it/s]


====> Test set loss: 14254.1041, (BCE: 14230.5727, KLD: 23.5314)
Epoch 77/100


 25%|██▍       | 65/261 [00:17<00:52,  3.70it/s]

Step 19,900, (N samples: 2,547,200), Loss: 13845.8340, (Recon: 13821.3779, KLD: 24.4556), Gradient norm: 1145.0050


 63%|██████▎   | 165/261 [00:44<00:25,  3.76it/s]

Step 20,000, (N samples: 2,560,000), Loss: 13961.7881, (Recon: 13937.7617, KLD: 24.0264), Gradient norm: 2441.0227


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.37it/s]


====> Test set loss: 14266.7521, (BCE: 14243.4776, KLD: 23.2744)
Epoch 78/100


  2%|▏         | 4/261 [00:01<01:10,  3.65it/s]

Step 20,100, (N samples: 2,572,800), Loss: 14118.4580, (Recon: 14094.6270, KLD: 23.8309), Gradient norm: 1640.5125


 40%|███▉      | 104/261 [00:27<00:42,  3.69it/s]

Step 20,200, (N samples: 2,585,600), Loss: 13839.2559, (Recon: 13814.9375, KLD: 24.3182), Gradient norm: 1627.2665


 78%|███████▊  | 204/261 [00:54<00:15,  3.70it/s]

Step 20,300, (N samples: 2,598,400), Loss: 13894.4482, (Recon: 13871.2305, KLD: 23.2180), Gradient norm: 1173.7950


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.36it/s]


====> Test set loss: 14256.2886, (BCE: 14232.7444, KLD: 23.5441)
Epoch 79/100


 16%|█▋        | 43/261 [00:11<00:57,  3.82it/s]

Step 20,400, (N samples: 2,611,200), Loss: 13949.7939, (Recon: 13927.4150, KLD: 22.3788), Gradient norm: 1134.5901


 55%|█████▍    | 143/261 [00:37<00:32,  3.68it/s]

Step 20,500, (N samples: 2,624,000), Loss: 13996.3203, (Recon: 13973.0117, KLD: 23.3084), Gradient norm: 1999.1416


 93%|█████████▎| 243/261 [01:04<00:04,  4.07it/s]

Step 20,600, (N samples: 2,636,800), Loss: 13610.4658, (Recon: 13586.0566, KLD: 24.4088), Gradient norm: 1507.5102


100%|██████████| 261/261 [01:08<00:00,  3.79it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.44it/s]


====> Test set loss: 14256.3897, (BCE: 14232.7336, KLD: 23.6562)
Epoch 80/100


 31%|███▏      | 82/261 [00:21<00:49,  3.60it/s]

Step 20,700, (N samples: 2,649,600), Loss: 14215.3906, (Recon: 14191.6045, KLD: 23.7862), Gradient norm: 1544.5005


 70%|██████▉   | 182/261 [00:48<00:21,  3.63it/s]

Step 20,800, (N samples: 2,662,400), Loss: 14141.4609, (Recon: 14117.7910, KLD: 23.6697), Gradient norm: 1127.8797


100%|██████████| 261/261 [01:09<00:00,  3.77it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


====> Test set loss: 14255.9616, (BCE: 14232.4671, KLD: 23.4946)
Epoch 81/100


  8%|▊         | 21/261 [00:05<01:04,  3.72it/s]

Step 20,900, (N samples: 2,675,200), Loss: 14353.7109, (Recon: 14330.2012, KLD: 23.5098), Gradient norm: 983.3693


 46%|████▋     | 121/261 [00:32<00:37,  3.78it/s]

Step 21,000, (N samples: 2,688,000), Loss: 14202.6816, (Recon: 14179.1230, KLD: 23.5589), Gradient norm: 1032.1319


 85%|████████▍ | 221/261 [00:58<00:10,  3.75it/s]

Step 21,100, (N samples: 2,700,800), Loss: 14233.7607, (Recon: 14210.4688, KLD: 23.2920), Gradient norm: 1143.4619


100%|██████████| 261/261 [01:09<00:00,  3.77it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.30it/s]


====> Test set loss: 14260.7792, (BCE: 14237.2127, KLD: 23.5666)
Epoch 82/100


 23%|██▎       | 60/261 [00:15<00:54,  3.71it/s]

Step 21,200, (N samples: 2,713,600), Loss: 14021.5635, (Recon: 13997.5684, KLD: 23.9952), Gradient norm: 2520.2437


 61%|██████▏   | 160/261 [00:42<00:27,  3.70it/s]

Step 21,300, (N samples: 2,726,400), Loss: 14160.1885, (Recon: 14134.2246, KLD: 25.9637), Gradient norm: 1874.6853


100%|█████████▉| 260/261 [01:08<00:00,  3.76it/s]

Step 21,400, (N samples: 2,739,200), Loss: 13638.6367, (Recon: 13614.4473, KLD: 24.1892), Gradient norm: 1356.0840


100%|██████████| 261/261 [01:08<00:00,  3.79it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.35it/s]


====> Test set loss: 14251.9526, (BCE: 14228.4806, KLD: 23.4721)
Epoch 83/100


 38%|███▊      | 99/261 [00:26<00:43,  3.72it/s]

Step 21,500, (N samples: 2,752,000), Loss: 14180.8994, (Recon: 14156.8457, KLD: 24.0539), Gradient norm: 1632.4279


 76%|███████▌  | 199/261 [00:52<00:16,  3.71it/s]

Step 21,600, (N samples: 2,764,800), Loss: 14083.8721, (Recon: 14059.1582, KLD: 24.7138), Gradient norm: 1986.6458


100%|██████████| 261/261 [01:09<00:00,  3.77it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.47it/s]


====> Test set loss: 14248.5366, (BCE: 14225.3090, KLD: 23.2276)
Epoch 84/100


 15%|█▍        | 38/261 [00:10<01:00,  3.71it/s]

Step 21,700, (N samples: 2,777,600), Loss: 14131.9863, (Recon: 14107.7969, KLD: 24.1898), Gradient norm: 1629.4118


 53%|█████▎    | 138/261 [00:36<00:32,  3.79it/s]

Step 21,800, (N samples: 2,790,400), Loss: 14147.2285, (Recon: 14123.3018, KLD: 23.9264), Gradient norm: 1547.5000


 91%|█████████ | 238/261 [01:03<00:06,  3.78it/s]

Step 21,900, (N samples: 2,803,200), Loss: 14064.2959, (Recon: 14039.8984, KLD: 24.3979), Gradient norm: 1921.9872


100%|██████████| 261/261 [01:08<00:00,  3.78it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


====> Test set loss: 14253.8802, (BCE: 14230.5305, KLD: 23.3497)
Epoch 85/100


 30%|██▉       | 77/261 [00:20<00:48,  3.76it/s]

Step 22,000, (N samples: 2,816,000), Loss: 13762.2764, (Recon: 13737.5410, KLD: 24.7357), Gradient norm: 1291.3877


 68%|██████▊   | 177/261 [00:47<00:22,  3.80it/s]

Step 22,100, (N samples: 2,828,800), Loss: 13843.5439, (Recon: 13818.8809, KLD: 24.6631), Gradient norm: 1705.9511


100%|██████████| 261/261 [01:09<00:00,  3.78it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.43it/s]


====> Test set loss: 14263.7529, (BCE: 14240.2220, KLD: 23.5308)
Epoch 86/100


  6%|▌         | 16/261 [00:04<01:06,  3.66it/s]

Step 22,200, (N samples: 2,841,600), Loss: 14237.1836, (Recon: 14213.3301, KLD: 23.8538), Gradient norm: 1442.9433


 44%|████▍     | 116/261 [00:30<00:37,  3.84it/s]

Step 22,300, (N samples: 2,854,400), Loss: 14190.2295, (Recon: 14166.5098, KLD: 23.7197), Gradient norm: 1337.4116


 83%|████████▎ | 216/261 [00:57<00:12,  3.73it/s]

Step 22,400, (N samples: 2,867,200), Loss: 14053.0117, (Recon: 14030.6191, KLD: 22.3927), Gradient norm: 1990.2092


100%|██████████| 261/261 [01:09<00:00,  3.78it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.44it/s]


====> Test set loss: 14252.7101, (BCE: 14229.2412, KLD: 23.4689)
Epoch 87/100


 21%|██        | 55/261 [00:14<00:54,  3.76it/s]

Step 22,500, (N samples: 2,880,000), Loss: 14229.7275, (Recon: 14206.8164, KLD: 22.9115), Gradient norm: 1460.6601


 59%|█████▉    | 155/261 [00:40<00:28,  3.71it/s]

Step 22,600, (N samples: 2,892,800), Loss: 13743.8291, (Recon: 13718.5938, KLD: 25.2351), Gradient norm: 2044.8203


 98%|█████████▊| 255/261 [01:07<00:01,  3.68it/s]

Step 22,700, (N samples: 2,905,600), Loss: 14106.5674, (Recon: 14082.7119, KLD: 23.8551), Gradient norm: 1864.1689


100%|██████████| 261/261 [01:08<00:00,  3.79it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


====> Test set loss: 14250.7457, (BCE: 14227.2605, KLD: 23.4851)
Epoch 88/100


 36%|███▌      | 94/261 [00:25<00:45,  3.68it/s]

Step 22,800, (N samples: 2,918,400), Loss: 14221.6221, (Recon: 14199.2031, KLD: 22.4192), Gradient norm: 1029.3226


 74%|███████▍  | 194/261 [00:51<00:17,  3.73it/s]

Step 22,900, (N samples: 2,931,200), Loss: 13766.7021, (Recon: 13742.4424, KLD: 24.2602), Gradient norm: 1159.9696


100%|██████████| 261/261 [01:09<00:00,  3.74it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


====> Test set loss: 14257.8449, (BCE: 14234.1987, KLD: 23.6463)
Epoch 89/100


 13%|█▎        | 33/261 [00:08<01:02,  3.64it/s]

Step 23,000, (N samples: 2,944,000), Loss: 14166.1816, (Recon: 14141.6143, KLD: 24.5670), Gradient norm: 2025.4883


 51%|█████     | 133/261 [00:35<00:34,  3.74it/s]

Step 23,100, (N samples: 2,956,800), Loss: 13755.7393, (Recon: 13730.7109, KLD: 25.0282), Gradient norm: 1274.2627


 89%|████████▉ | 233/261 [01:01<00:07,  3.72it/s]

Step 23,200, (N samples: 2,969,600), Loss: 13901.4521, (Recon: 13876.0645, KLD: 25.3877), Gradient norm: 1731.7361


100%|██████████| 261/261 [01:09<00:00,  3.77it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.67it/s]


====> Test set loss: 14255.1325, (BCE: 14231.5789, KLD: 23.5537)
Epoch 90/100


 28%|██▊       | 72/261 [00:19<00:50,  3.75it/s]

Step 23,300, (N samples: 2,982,400), Loss: 14350.2793, (Recon: 14326.2559, KLD: 24.0237), Gradient norm: 1870.6803


 66%|██████▌   | 172/261 [00:45<00:23,  3.75it/s]

Step 23,400, (N samples: 2,995,200), Loss: 14255.1162, (Recon: 14231.1035, KLD: 24.0126), Gradient norm: 1088.5235


100%|██████████| 261/261 [01:09<00:00,  3.77it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.38it/s]


====> Test set loss: 14255.6827, (BCE: 14232.1764, KLD: 23.5063)
Epoch 91/100


  4%|▍         | 11/261 [00:02<01:06,  3.75it/s]

Step 23,500, (N samples: 3,008,000), Loss: 14011.0107, (Recon: 13988.1758, KLD: 22.8353), Gradient norm: 1251.3397


 43%|████▎     | 111/261 [00:29<00:41,  3.65it/s]

Step 23,600, (N samples: 3,020,800), Loss: 14211.3936, (Recon: 14187.8672, KLD: 23.5264), Gradient norm: 1271.0570


 81%|████████  | 211/261 [00:56<00:13,  3.71it/s]

Step 23,700, (N samples: 3,033,600), Loss: 13826.3018, (Recon: 13801.7656, KLD: 24.5360), Gradient norm: 1863.5604


100%|██████████| 261/261 [01:09<00:00,  3.76it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.34it/s]


====> Test set loss: 14257.4455, (BCE: 14233.7250, KLD: 23.7206)
Epoch 92/100


 19%|█▉        | 50/261 [00:13<00:52,  4.00it/s]

Step 23,800, (N samples: 3,046,400), Loss: 13949.9102, (Recon: 13925.2207, KLD: 24.6899), Gradient norm: 1530.0180


 57%|█████▋    | 150/261 [00:39<00:29,  3.77it/s]

Step 23,900, (N samples: 3,059,200), Loss: 14285.5293, (Recon: 14262.9092, KLD: 22.6200), Gradient norm: 1182.9748


 96%|█████████▌| 250/261 [01:05<00:02,  4.07it/s]

Step 24,000, (N samples: 3,072,000), Loss: 14388.3438, (Recon: 14365.3027, KLD: 23.0409), Gradient norm: 1435.8615


100%|██████████| 261/261 [01:08<00:00,  3.79it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.49it/s]


====> Test set loss: 14255.8647, (BCE: 14232.4982, KLD: 23.3666)
Epoch 93/100


 34%|███▍      | 89/261 [00:23<00:45,  3.77it/s]

Step 24,100, (N samples: 3,084,800), Loss: 13762.9434, (Recon: 13738.5000, KLD: 24.4438), Gradient norm: 1481.7386


 72%|███████▏  | 189/261 [00:49<00:19,  3.78it/s]

Step 24,200, (N samples: 3,097,600), Loss: 14178.3135, (Recon: 14154.5176, KLD: 23.7957), Gradient norm: 1662.8145


100%|██████████| 261/261 [01:09<00:00,  3.78it/s]
Testing: 100%|██████████| 29/29 [00:04<00:00,  6.39it/s]


====> Test set loss: 14255.8241, (BCE: 14232.3304, KLD: 23.4938)
Epoch 94/100


 11%|█         | 28/261 [00:07<01:02,  3.72it/s]

Step 24,300, (N samples: 3,110,400), Loss: 14049.5293, (Recon: 14025.6250, KLD: 23.9046), Gradient norm: 1284.9076


 31%|███       | 81/261 [00:21<00:47,  3.79it/s]

In [11]:
def pearson_correlation(original_x, x_hat, mask=None):
    """Return the Pearson correlation between two tensors, pooled over all elements.

    Every element of ``original_x`` is paired with the element at the same
    position in ``x_hat`` and a single global coefficient is computed; no
    per-row or per-feature correlation is produced.

    Args:
        original_x: Floating-point tensor of reference values.
        x_hat: Floating-point tensor to compare against ``original_x``;
            must be broadcast-compatible with it (normally the same shape).
        mask: Accepted for call compatibility but currently ignored.
            NOTE(review): presumably meant to restrict the statistic to valid
            (non-missing) entries -- confirm the intended semantics before
            implementing.

    Returns:
        The correlation coefficient as a Python float. The value is ``nan``
        when either input has zero variance (0 / 0).
    """
    # Centre both inputs around their own global mean.
    centered_x = original_x - torch.mean(original_x)
    centered_x_hat = x_hat - torch.mean(x_hat)

    # Covariance and both variances must share one normalisation (1/N here).
    # Mixing a 1/N covariance with torch.std's default 1/(N-1) estimator
    # scales the result by (N-1)/N, so corr(x, x) would fall short of 1.
    covariance = torch.mean(centered_x * centered_x_hat)
    std_x = torch.sqrt(torch.mean(centered_x ** 2))
    std_x_hat = torch.sqrt(torch.mean(centered_x_hat ** 2))

    correlation = covariance / (std_x * std_x_hat)

    return correlation.item()

In [14]:
import torch

# Reconstruction quality on the test set, one Pearson coefficient per batch.
# Relies on `model`, `test_loader`, `device` and `pearson_correlation`
# from earlier cells.
model.eval()  # evaluation mode

correlations = []

# No gradients are needed for evaluation
with torch.no_grad():
    for inputs, _ in test_loader:  # second item of each batch is not used here
        inputs = inputs.to(device)

        # Forward pass; keep the `x_recon` field of the model output
        x_hat = model(inputs).x_recon

        # One coefficient over every element of the batch
        batch_correlation = pearson_correlation(inputs, x_hat)
        correlations.append(batch_correlation)

        print("Pearson Correlation:", batch_correlation)


Pearson Correlation: 0.957661509513855
Pearson Correlation: 0.950019896030426
Pearson Correlation: 0.9605104327201843
Pearson Correlation: 0.9542161226272583
Pearson Correlation: 0.951418936252594
Pearson Correlation: 0.9480416178703308
Pearson Correlation: 0.9556251168251038
Pearson Correlation: 0.9498641490936279
Pearson Correlation: 0.9456736445426941
Pearson Correlation: 0.9548927545547485
Pearson Correlation: 0.9457770586013794
Pearson Correlation: 0.9527332186698914
Pearson Correlation: 0.9417027831077576
Pearson Correlation: 0.9517192840576172
Pearson Correlation: 0.9572741985321045
Pearson Correlation: 0.9492014646530151
Pearson Correlation: 0.945146918296814
Pearson Correlation: 0.9474121332168579
Pearson Correlation: 0.9577855467796326
Pearson Correlation: 0.9552962779998779
Pearson Correlation: 0.9518298506736755
Pearson Correlation: 0.9525393843650818
Pearson Correlation: 0.9414599537849426
Pearson Correlation: 0.951117753982544
Pearson Correlation: 0.948650598526001
Pearso

In [17]:
# Per-sample reconstruction quality on the training set: one Pearson
# coefficient per row, between the input and the model's `x_recon` output.
# Set evaluation mode explicitly so this cell gives the same result whether
# or not an earlier cell already called model.eval().
model.eval()

correlations = []

# Disable gradient calculation for efficiency
with torch.no_grad():
    for batch in train_loader:
        # Unpack the batch; the second item is not used here
        original_x, _ = batch
        original_x = original_x.to(device)  # Move input to device if using GPU

        # Get the model output for the entire batch
        out = model(original_x)
        x_hat_batch = out.x_recon

        # Iterating a tensor yields its rows, so this pairs each input
        # sample with its reconstruction.
        for original_sample, x_hat_sample in zip(original_x, x_hat_batch):
            correlation = pearson_correlation(original_sample, x_hat_sample)
            correlations.append(correlation)

# Mean of the per-sample correlations over the whole training set
average_correlation = sum(correlations) / len(correlations)
print("Average Pearson Correlation across all samples:", average_correlation)


Average Pearson Correlation across all samples: 0.9529571182722096
