In [1]:
%load_ext autoreload
# %reload_ext autoreload
%autoreload 2


In [2]:
from methylVA.mnist.dataset import get_methyl_data_loaders

data_id = 0.1
batch_size = 128

train_data_path = f"../data/dimension_reduction/highly_variable_features/train_data_filtered_{data_id}.pkl"
train_metadata_path = f"../data/dimension_reduction/highly_variable_features/train_metadata_with_labels.pkl"
test_data_path = f"../data/dimension_reduction/highly_variable_features/test_data_filtered_{data_id}.pkl"
test_metadata_path = f"../data/dimension_reduction/highly_variable_features/test_metadata_with_labels.pkl"


train_loader, test_loader = get_methyl_data_loaders(
    train_data_path,
    train_metadata_path,
    test_data_path,
    test_metadata_path,
    batch_size=batch_size
)


Found NaN values in the data after conversion.
Found NaN values in the data after conversion.


In [3]:
data_batch, _ = next(iter(train_loader))


num_train_rows = len(train_loader.dataset)
num_test_rows = len(test_loader.dataset)

print("Number of features in each dataset:", data_batch.shape[1])
print("Number of rows in the training dataset:", num_train_rows)
print("Number of rows in the test dataset:", num_test_rows)

Number of features in each dataset: 2605
Number of rows in the training dataset: 33360
Number of rows in the test dataset: 3707


In [4]:
from datetime import datetime

import torch
from torch.utils.tensorboard import SummaryWriter

from methylVA.mnist.model import VAE
from methylVA.mnist.training import train, test

input_dim = data_batch.shape[1]
learning_rate = 1e-3
weight_decay = 1e-2
num_epochs = 100
latent_dim = 32
hidden_dim = 2048
kl_weight = 1.0
name = f'VAE_methyl_data_{data_id}_latent_{latent_dim}_kl_{kl_weight}'



In [5]:

writer_train = SummaryWriter(f'../experiments/{name}/train/{datetime.now().strftime("%Y%m%d-%H%M%S")}')
writer_test = SummaryWriter(f'../experiments/{name}/test/{datetime.now().strftime("%Y%m%d-%H%M%S")}')

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = VAE(input_dim=input_dim, latent_dim=latent_dim, hidden_dim=hidden_dim).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)

In [7]:
model

VAE(
  (encoder): Sequential(
    (0): Linear(in_features=2605, out_features=2048, bias=True)
    (1): SiLU()
    (2): Linear(in_features=2048, out_features=1024, bias=True)
    (3): SiLU()
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): SiLU()
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): SiLU()
    (8): Linear(in_features=256, out_features=64, bias=True)
  )
  (softplus): Softplus(beta=1.0, threshold=20.0)
  (decoder): Sequential(
    (0): Linear(in_features=32, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=512, bias=True)
    (3): SiLU()
    (4): Linear(in_features=512, out_features=1024, bias=True)
    (5): SiLU()
    (6): Linear(in_features=1024, out_features=2048, bias=True)
    (7): SiLU()
    (8): Linear(in_features=2048, out_features=2605, bias=True)
    (9): Sigmoid()
  )
)

In [8]:
from methylVA.mnist.training import train, test


prev_updates = 0
for epoch in range(num_epochs):
    print(f'Epoch {epoch + 1}/{num_epochs}')
    prev_updates = train(model, train_loader, optimizer, prev_updates, writer=writer_train)
    test(model, test_loader, prev_updates, writer=writer_test)

Epoch 1/100


  2%|▏         | 4/261 [00:00<00:25, 10.01it/s]

Step 0, (N samples: 0), Loss: 1808.9576, (Recon: 1805.5591, KLD: 3.3986), Gradient norm: 8.7981


 40%|███▉      | 104/261 [00:03<00:04, 31.89it/s]

Step 100, (N samples: 12,800), Loss: 1176.6036, (Recon: 1167.6359, KLD: 8.9678), Gradient norm: 164.2641


 78%|███████▊  | 204/261 [00:06<00:01, 31.95it/s]

Step 200, (N samples: 25,600), Loss: 1184.8728, (Recon: 1177.5050, KLD: 7.3678), Gradient norm: 201.4290


100%|██████████| 261/261 [00:08<00:00, 30.61it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.38it/s]


====> Test set loss: 1164.8302, (BCE: 1157.5696, KLD: 7.2606)
Epoch 2/100


 18%|█▊        | 46/261 [00:01<00:06, 31.62it/s]

Step 300, (N samples: 38,400), Loss: 1129.2145, (Recon: 1121.0381, KLD: 8.1764), Gradient norm: 198.4668


 56%|█████▌    | 146/261 [00:04<00:03, 31.77it/s]

Step 400, (N samples: 51,200), Loss: 1166.5367, (Recon: 1157.8538, KLD: 8.6830), Gradient norm: 95.9660


 94%|█████████▍| 246/261 [00:07<00:00, 31.82it/s]

Step 500, (N samples: 64,000), Loss: 1152.6552, (Recon: 1144.2500, KLD: 8.4051), Gradient norm: 143.7696


100%|██████████| 261/261 [00:08<00:00, 31.66it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.26it/s]


====> Test set loss: 1131.2259, (BCE: 1123.2659, KLD: 7.9600)
Epoch 3/100


 31%|███▏      | 82/261 [00:02<00:05, 31.03it/s]

Step 600, (N samples: 76,800), Loss: 1141.0159, (Recon: 1133.9119, KLD: 7.1040), Gradient norm: 76.4691


 70%|██████▉   | 182/261 [00:05<00:02, 31.64it/s]

Step 700, (N samples: 89,600), Loss: 1115.6455, (Recon: 1107.8289, KLD: 7.8167), Gradient norm: 130.2562


100%|██████████| 261/261 [00:08<00:00, 31.49it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.21it/s]


====> Test set loss: 1114.5670, (BCE: 1106.5743, KLD: 7.9928)
Epoch 4/100


  8%|▊         | 22/261 [00:00<00:07, 31.06it/s]

Step 800, (N samples: 102,400), Loss: 1157.1346, (Recon: 1149.1667, KLD: 7.9679), Gradient norm: 332.7639


 47%|████▋     | 122/261 [00:03<00:04, 31.77it/s]

Step 900, (N samples: 115,200), Loss: 1156.7570, (Recon: 1148.5571, KLD: 8.1999), Gradient norm: 262.7542


 85%|████████▌ | 222/261 [00:07<00:01, 31.70it/s]

Step 1,000, (N samples: 128,000), Loss: 1083.2399, (Recon: 1074.6145, KLD: 8.6254), Gradient norm: 151.0644


100%|██████████| 261/261 [00:08<00:00, 31.64it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.49it/s]


====> Test set loss: 1107.5741, (BCE: 1099.7691, KLD: 7.8050)
Epoch 5/100


 24%|██▍       | 62/261 [00:01<00:06, 31.96it/s]

Step 1,100, (N samples: 140,800), Loss: 1157.9989, (Recon: 1149.3650, KLD: 8.6340), Gradient norm: 250.5745


 62%|██████▏   | 162/261 [00:05<00:03, 31.70it/s]

Step 1,200, (N samples: 153,600), Loss: 1063.0671, (Recon: 1055.3125, KLD: 7.7546), Gradient norm: 107.6011


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]


Step 1,300, (N samples: 166,400), Loss: 1109.6223, (Recon: 1101.2014, KLD: 8.4209), Gradient norm: 180.5019


Testing: 100%|██████████| 29/29 [00:00<00:00, 36.24it/s]


====> Test set loss: 1103.7831, (BCE: 1095.5488, KLD: 8.2343)
Epoch 6/100


 39%|███▉      | 102/261 [00:03<00:04, 31.95it/s]

Step 1,400, (N samples: 179,200), Loss: 1066.4933, (Recon: 1058.0835, KLD: 8.4097), Gradient norm: 262.1366


 77%|███████▋  | 202/261 [00:06<00:01, 31.90it/s]

Step 1,500, (N samples: 192,000), Loss: 1093.3066, (Recon: 1085.3350, KLD: 7.9717), Gradient norm: 238.7835


100%|██████████| 261/261 [00:08<00:00, 31.72it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.40it/s]


====> Test set loss: 1099.6348, (BCE: 1091.4121, KLD: 8.2227)
Epoch 7/100


 15%|█▍        | 38/261 [00:01<00:11, 18.77it/s]

Step 1,600, (N samples: 204,800), Loss: 1085.5027, (Recon: 1077.0054, KLD: 8.4973), Gradient norm: 167.9048


 54%|█████▍    | 141/261 [00:05<00:03, 31.94it/s]

Step 1,700, (N samples: 217,600), Loss: 1061.4531, (Recon: 1053.0554, KLD: 8.3977), Gradient norm: 126.8604


 92%|█████████▏| 241/261 [00:08<00:00, 31.84it/s]

Step 1,800, (N samples: 230,400), Loss: 1064.2242, (Recon: 1055.9874, KLD: 8.2368), Gradient norm: 121.6729


100%|██████████| 261/261 [00:08<00:00, 29.79it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.22it/s]


====> Test set loss: 1096.7883, (BCE: 1088.4895, KLD: 8.2988)
Epoch 8/100


 30%|███       | 79/261 [00:02<00:05, 31.91it/s]

Step 1,900, (N samples: 243,200), Loss: 1149.1608, (Recon: 1140.5195, KLD: 8.6413), Gradient norm: 166.4810


 69%|██████▊   | 179/261 [00:05<00:02, 31.90it/s]

Step 2,000, (N samples: 256,000), Loss: 1113.8269, (Recon: 1105.3188, KLD: 8.5080), Gradient norm: 121.7407


100%|██████████| 261/261 [00:08<00:00, 31.59it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.31it/s]


====> Test set loss: 1094.3450, (BCE: 1085.9384, KLD: 8.4066)
Epoch 9/100


  7%|▋         | 18/261 [00:00<00:08, 30.14it/s]

Step 2,100, (N samples: 268,800), Loss: 1101.2528, (Recon: 1092.6870, KLD: 8.5658), Gradient norm: 101.1436


 45%|████▌     | 118/261 [00:03<00:04, 30.92it/s]

Step 2,200, (N samples: 281,600), Loss: 1076.0811, (Recon: 1067.7804, KLD: 8.3006), Gradient norm: 171.3490


 84%|████████▎ | 218/261 [00:07<00:01, 31.46it/s]

Step 2,300, (N samples: 294,400), Loss: 1073.1945, (Recon: 1064.9292, KLD: 8.2652), Gradient norm: 171.9806


100%|██████████| 261/261 [00:08<00:00, 30.90it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.36it/s]


====> Test set loss: 1091.4802, (BCE: 1083.0713, KLD: 8.4089)
Epoch 10/100


 22%|██▏       | 58/261 [00:01<00:06, 31.88it/s]

Step 2,400, (N samples: 307,200), Loss: 1129.8928, (Recon: 1121.2594, KLD: 8.6334), Gradient norm: 88.4636


 61%|██████    | 158/261 [00:04<00:03, 31.92it/s]

Step 2,500, (N samples: 320,000), Loss: 1099.6427, (Recon: 1091.1992, KLD: 8.4435), Gradient norm: 154.3376


 99%|█████████▉| 258/261 [00:08<00:00, 31.88it/s]

Step 2,600, (N samples: 332,800), Loss: 1105.4508, (Recon: 1097.0076, KLD: 8.4432), Gradient norm: 91.8854


100%|██████████| 261/261 [00:08<00:00, 31.75it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.33it/s]


====> Test set loss: 1091.3020, (BCE: 1082.7804, KLD: 8.5216)
Epoch 11/100


 36%|███▌      | 94/261 [00:02<00:05, 31.82it/s]

Step 2,700, (N samples: 345,600), Loss: 1104.3147, (Recon: 1096.1174, KLD: 8.1973), Gradient norm: 146.6124


 74%|███████▍  | 194/261 [00:06<00:02, 31.85it/s]

Step 2,800, (N samples: 358,400), Loss: 1074.4056, (Recon: 1065.6262, KLD: 8.7794), Gradient norm: 147.4809


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.15it/s]


====> Test set loss: 1087.6181, (BCE: 1078.9516, KLD: 8.6665)
Epoch 12/100


 13%|█▎        | 34/261 [00:01<00:07, 31.51it/s]

Step 2,900, (N samples: 371,200), Loss: 1116.1412, (Recon: 1107.4863, KLD: 8.6549), Gradient norm: 266.1220


 51%|█████▏    | 134/261 [00:04<00:03, 31.80it/s]

Step 3,000, (N samples: 384,000), Loss: 1085.3467, (Recon: 1076.5883, KLD: 8.7585), Gradient norm: 213.8395


 90%|████████▉ | 234/261 [00:07<00:00, 31.83it/s]

Step 3,100, (N samples: 396,800), Loss: 1047.9875, (Recon: 1039.5020, KLD: 8.4856), Gradient norm: 143.8784


100%|██████████| 261/261 [00:08<00:00, 31.65it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.41it/s]


====> Test set loss: 1086.9994, (BCE: 1078.3961, KLD: 8.6033)
Epoch 13/100


 28%|██▊       | 74/261 [00:02<00:05, 31.76it/s]

Step 3,200, (N samples: 409,600), Loss: 1079.2344, (Recon: 1070.3053, KLD: 8.9291), Gradient norm: 116.0502


 67%|██████▋   | 174/261 [00:05<00:02, 31.90it/s]

Step 3,300, (N samples: 422,400), Loss: 1109.2826, (Recon: 1100.1987, KLD: 9.0839), Gradient norm: 164.5036


100%|██████████| 261/261 [00:08<00:00, 31.69it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.78it/s]


====> Test set loss: 1083.7425, (BCE: 1075.1243, KLD: 8.6181)
Epoch 14/100


  5%|▌         | 14/261 [00:00<00:08, 29.71it/s]

Step 3,400, (N samples: 435,200), Loss: 1058.3530, (Recon: 1049.2025, KLD: 9.1505), Gradient norm: 159.7014


 44%|████▎     | 114/261 [00:03<00:04, 31.86it/s]

Step 3,500, (N samples: 448,000), Loss: 1107.5525, (Recon: 1098.6681, KLD: 8.8845), Gradient norm: 152.6606


 82%|████████▏ | 214/261 [00:06<00:01, 31.79it/s]

Step 3,600, (N samples: 460,800), Loss: 1028.9172, (Recon: 1020.1940, KLD: 8.7233), Gradient norm: 167.1744


100%|██████████| 261/261 [00:08<00:00, 31.72it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.20it/s]


====> Test set loss: 1084.0439, (BCE: 1075.3403, KLD: 8.7036)
Epoch 15/100


 19%|█▉        | 50/261 [00:01<00:06, 31.74it/s]

Step 3,700, (N samples: 473,600), Loss: 1080.3589, (Recon: 1071.8145, KLD: 8.5445), Gradient norm: 140.0326


 57%|█████▋    | 150/261 [00:04<00:03, 31.74it/s]

Step 3,800, (N samples: 486,400), Loss: 1125.9435, (Recon: 1117.0481, KLD: 8.8954), Gradient norm: 254.4804


 96%|█████████▌| 250/261 [00:07<00:00, 31.82it/s]

Step 3,900, (N samples: 499,200), Loss: 1088.8229, (Recon: 1080.2290, KLD: 8.5939), Gradient norm: 193.1263


100%|██████████| 261/261 [00:08<00:00, 31.70it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.57it/s]


====> Test set loss: 1082.6858, (BCE: 1073.9360, KLD: 8.7498)
Epoch 16/100


 34%|███▍      | 90/261 [00:02<00:05, 31.81it/s]

Step 4,000, (N samples: 512,000), Loss: 1104.2659, (Recon: 1095.5659, KLD: 8.6999), Gradient norm: 129.2834


 73%|███████▎  | 190/261 [00:06<00:02, 31.84it/s]

Step 4,100, (N samples: 524,800), Loss: 1050.7992, (Recon: 1042.3505, KLD: 8.4487), Gradient norm: 183.2830


100%|██████████| 261/261 [00:08<00:00, 31.66it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.59it/s]


====> Test set loss: 1081.3493, (BCE: 1072.6603, KLD: 8.6890)
Epoch 17/100


 11%|█▏        | 30/261 [00:00<00:07, 31.38it/s]

Step 4,200, (N samples: 537,600), Loss: 1098.0349, (Recon: 1089.4183, KLD: 8.6166), Gradient norm: 95.9175


 50%|████▉     | 130/261 [00:04<00:04, 31.81it/s]

Step 4,300, (N samples: 550,400), Loss: 1138.2327, (Recon: 1129.4148, KLD: 8.8179), Gradient norm: 228.0301


 88%|████████▊ | 230/261 [00:07<00:00, 31.85it/s]

Step 4,400, (N samples: 563,200), Loss: 1062.9972, (Recon: 1054.6055, KLD: 8.3918), Gradient norm: 187.3075


100%|██████████| 261/261 [00:08<00:00, 31.75it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.64it/s]


====> Test set loss: 1080.4330, (BCE: 1071.8694, KLD: 8.5635)
Epoch 18/100


 27%|██▋       | 70/261 [00:02<00:06, 31.78it/s]

Step 4,500, (N samples: 576,000), Loss: 1087.2975, (Recon: 1078.4695, KLD: 8.8280), Gradient norm: 176.0333


 65%|██████▌   | 170/261 [00:05<00:02, 31.80it/s]

Step 4,600, (N samples: 588,800), Loss: 1053.4048, (Recon: 1044.7689, KLD: 8.6359), Gradient norm: 262.4464


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.59it/s]


====> Test set loss: 1079.0889, (BCE: 1070.1728, KLD: 8.9161)
Epoch 19/100


  1%|          | 3/261 [00:00<00:11, 22.58it/s]

Step 4,700, (N samples: 601,600), Loss: 1117.2529, (Recon: 1107.9760, KLD: 9.2770), Gradient norm: 159.5842


 41%|████      | 106/261 [00:03<00:04, 31.87it/s]

Step 4,800, (N samples: 614,400), Loss: 1057.3187, (Recon: 1048.5215, KLD: 8.7973), Gradient norm: 203.0696


 79%|███████▉  | 206/261 [00:06<00:01, 31.80it/s]

Step 4,900, (N samples: 627,200), Loss: 1073.3871, (Recon: 1064.7321, KLD: 8.6550), Gradient norm: 117.1240


100%|██████████| 261/261 [00:08<00:00, 31.74it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.46it/s]


====> Test set loss: 1079.2755, (BCE: 1070.3683, KLD: 8.9072)
Epoch 20/100


 18%|█▊        | 46/261 [00:01<00:06, 31.89it/s]

Step 5,000, (N samples: 640,000), Loss: 1067.0367, (Recon: 1058.2207, KLD: 8.8161), Gradient norm: 108.0642


 56%|█████▌    | 145/261 [00:04<00:04, 28.51it/s]

Step 5,100, (N samples: 652,800), Loss: 1105.9298, (Recon: 1097.3271, KLD: 8.6027), Gradient norm: 171.7590


 94%|█████████▍| 246/261 [00:08<00:00, 31.95it/s]

Step 5,200, (N samples: 665,600), Loss: 1056.0176, (Recon: 1047.2493, KLD: 8.7683), Gradient norm: 237.1717


100%|██████████| 261/261 [00:08<00:00, 29.77it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.34it/s]


====> Test set loss: 1079.4657, (BCE: 1070.7673, KLD: 8.6984)
Epoch 21/100


 32%|███▏      | 84/261 [00:02<00:05, 30.88it/s]

Step 5,300, (N samples: 678,400), Loss: 1116.7343, (Recon: 1107.9536, KLD: 8.7806), Gradient norm: 178.4067


 70%|███████   | 184/261 [00:05<00:02, 31.78it/s]

Step 5,400, (N samples: 691,200), Loss: 1049.4299, (Recon: 1039.6096, KLD: 9.8203), Gradient norm: 312.2889


100%|██████████| 261/261 [00:08<00:00, 31.25it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.43it/s]


====> Test set loss: 1077.6824, (BCE: 1068.9015, KLD: 8.7809)
Epoch 22/100


 10%|▉         | 26/261 [00:00<00:07, 31.23it/s]

Step 5,500, (N samples: 704,000), Loss: 1059.6881, (Recon: 1051.3276, KLD: 8.3605), Gradient norm: 168.3038


 47%|████▋     | 122/261 [00:03<00:04, 31.77it/s]

Step 5,600, (N samples: 716,800), Loss: 1114.8624, (Recon: 1105.6926, KLD: 9.1697), Gradient norm: 192.6576


 87%|████████▋ | 226/261 [00:07<00:01, 31.76it/s]

Step 5,700, (N samples: 729,600), Loss: 1144.4174, (Recon: 1135.3921, KLD: 9.0252), Gradient norm: 165.1487


100%|██████████| 261/261 [00:08<00:00, 31.17it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.11it/s]


====> Test set loss: 1077.0969, (BCE: 1068.3782, KLD: 8.7187)
Epoch 23/100


 24%|██▍       | 62/261 [00:02<00:06, 31.25it/s]

Step 5,800, (N samples: 742,400), Loss: 1050.8704, (Recon: 1041.7567, KLD: 9.1136), Gradient norm: 166.2137


 62%|██████▏   | 162/261 [00:05<00:03, 31.40it/s]

Step 5,900, (N samples: 755,200), Loss: 1096.8896, (Recon: 1087.7922, KLD: 9.0974), Gradient norm: 153.1931


100%|██████████| 261/261 [00:08<00:00, 31.53it/s]


Step 6,000, (N samples: 768,000), Loss: 1103.5823, (Recon: 1095.1226, KLD: 8.4598), Gradient norm: 184.1568


Testing: 100%|██████████| 29/29 [00:00<00:00, 36.50it/s]


====> Test set loss: 1076.0025, (BCE: 1067.2525, KLD: 8.7499)
Epoch 24/100


 39%|███▉      | 102/261 [00:03<00:04, 32.07it/s]

Step 6,100, (N samples: 780,800), Loss: 1101.7297, (Recon: 1092.8296, KLD: 8.9002), Gradient norm: 167.5251


 77%|███████▋  | 202/261 [00:06<00:01, 30.97it/s]

Step 6,200, (N samples: 793,600), Loss: 1067.2596, (Recon: 1058.3826, KLD: 8.8771), Gradient norm: 146.3494


100%|██████████| 261/261 [00:08<00:00, 31.60it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.47it/s]


====> Test set loss: 1078.1970, (BCE: 1069.1184, KLD: 9.0787)
Epoch 25/100


 16%|█▌        | 42/261 [00:01<00:06, 31.66it/s]

Step 6,300, (N samples: 806,400), Loss: 1034.1930, (Recon: 1025.2772, KLD: 8.9158), Gradient norm: 194.3470


 54%|█████▍    | 142/261 [00:04<00:03, 31.82it/s]

Step 6,400, (N samples: 819,200), Loss: 1049.9574, (Recon: 1040.9358, KLD: 9.0216), Gradient norm: 550.3437


 93%|█████████▎| 242/261 [00:07<00:00, 31.73it/s]

Step 6,500, (N samples: 832,000), Loss: 1059.2546, (Recon: 1050.2332, KLD: 9.0215), Gradient norm: 133.6934


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.59it/s]


====> Test set loss: 1075.7082, (BCE: 1066.9411, KLD: 8.7671)
Epoch 26/100


 31%|███▏      | 82/261 [00:02<00:05, 31.97it/s]

Step 6,600, (N samples: 844,800), Loss: 1109.1047, (Recon: 1100.1062, KLD: 8.9986), Gradient norm: 226.4870


 70%|██████▉   | 182/261 [00:05<00:02, 32.00it/s]

Step 6,700, (N samples: 857,600), Loss: 1061.8248, (Recon: 1052.7739, KLD: 9.0509), Gradient norm: 135.9042


100%|██████████| 261/261 [00:08<00:00, 31.40it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.38it/s]


====> Test set loss: 1076.4542, (BCE: 1067.5465, KLD: 8.9077)
Epoch 27/100


  8%|▊         | 20/261 [00:00<00:09, 25.48it/s]

Step 6,800, (N samples: 870,400), Loss: 1060.6570, (Recon: 1051.4906, KLD: 9.1664), Gradient norm: 143.7711


 46%|████▌     | 120/261 [00:04<00:04, 31.94it/s]

Step 6,900, (N samples: 883,200), Loss: 1085.0663, (Recon: 1076.2144, KLD: 8.8519), Gradient norm: 138.8383


 84%|████████▍ | 220/261 [00:07<00:01, 31.90it/s]

Step 7,000, (N samples: 896,000), Loss: 1102.6660, (Recon: 1093.5929, KLD: 9.0731), Gradient norm: 254.8917


100%|██████████| 261/261 [00:08<00:00, 30.80it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 35.85it/s]


====> Test set loss: 1077.6979, (BCE: 1068.7322, KLD: 8.9658)
Epoch 28/100


 22%|██▏       | 58/261 [00:01<00:06, 31.89it/s]

Step 7,100, (N samples: 908,800), Loss: 1104.3092, (Recon: 1095.7390, KLD: 8.5702), Gradient norm: 292.7425


 61%|██████    | 158/261 [00:04<00:03, 31.90it/s]

Step 7,200, (N samples: 921,600), Loss: 1116.7540, (Recon: 1107.4912, KLD: 9.2628), Gradient norm: 199.7289


 99%|█████████▉| 258/261 [00:08<00:00, 31.91it/s]

Step 7,300, (N samples: 934,400), Loss: 1084.9531, (Recon: 1076.0194, KLD: 8.9337), Gradient norm: 123.3423


100%|██████████| 261/261 [00:08<00:00, 31.74it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.40it/s]


====> Test set loss: 1075.3850, (BCE: 1066.4563, KLD: 8.9287)
Epoch 29/100


 38%|███▊      | 98/261 [00:03<00:05, 31.89it/s]

Step 7,400, (N samples: 947,200), Loss: 1056.7124, (Recon: 1047.9006, KLD: 8.8118), Gradient norm: 230.1637


 76%|███████▌  | 198/261 [00:06<00:01, 31.79it/s]

Step 7,500, (N samples: 960,000), Loss: 1092.2080, (Recon: 1083.0573, KLD: 9.1508), Gradient norm: 153.6240


100%|██████████| 261/261 [00:08<00:00, 31.58it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.46it/s]


====> Test set loss: 1075.4310, (BCE: 1066.5150, KLD: 8.9160)
Epoch 30/100


 15%|█▍        | 38/261 [00:01<00:07, 31.71it/s]

Step 7,600, (N samples: 972,800), Loss: 1135.4301, (Recon: 1126.1812, KLD: 9.2489), Gradient norm: 175.5708


 53%|█████▎    | 138/261 [00:04<00:03, 31.98it/s]

Step 7,700, (N samples: 985,600), Loss: 1081.5917, (Recon: 1072.7717, KLD: 8.8200), Gradient norm: 157.6149


 91%|█████████ | 238/261 [00:07<00:00, 31.85it/s]

Step 7,800, (N samples: 998,400), Loss: 1077.9934, (Recon: 1069.0461, KLD: 8.9472), Gradient norm: 302.7595


100%|██████████| 261/261 [00:08<00:00, 31.78it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.40it/s]


====> Test set loss: 1074.7886, (BCE: 1065.8796, KLD: 8.9091)
Epoch 31/100


 28%|██▊       | 74/261 [00:02<00:05, 31.82it/s]

Step 7,900, (N samples: 1,011,200), Loss: 1041.0648, (Recon: 1032.3727, KLD: 8.6922), Gradient norm: 207.2799


 67%|██████▋   | 174/261 [00:05<00:02, 31.81it/s]

Step 8,000, (N samples: 1,024,000), Loss: 1101.0492, (Recon: 1092.1770, KLD: 8.8722), Gradient norm: 150.0440


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.36it/s]


====> Test set loss: 1074.0947, (BCE: 1065.0977, KLD: 8.9970)
Epoch 32/100


  5%|▌         | 14/261 [00:00<00:08, 29.79it/s]

Step 8,100, (N samples: 1,036,800), Loss: 1092.2311, (Recon: 1082.9172, KLD: 9.3138), Gradient norm: 234.5857


 44%|████▎     | 114/261 [00:03<00:04, 31.90it/s]

Step 8,200, (N samples: 1,049,600), Loss: 1027.7059, (Recon: 1019.0571, KLD: 8.6488), Gradient norm: 103.9045


 82%|████████▏ | 214/261 [00:06<00:01, 31.88it/s]

Step 8,300, (N samples: 1,062,400), Loss: 1071.2145, (Recon: 1062.5983, KLD: 8.6162), Gradient norm: 314.9172


100%|██████████| 261/261 [00:08<00:00, 31.72it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.60it/s]


====> Test set loss: 1074.6780, (BCE: 1065.7456, KLD: 8.9324)
Epoch 33/100


 21%|██        | 54/261 [00:01<00:06, 31.70it/s]

Step 8,400, (N samples: 1,075,200), Loss: 1001.7894, (Recon: 993.3480, KLD: 8.4414), Gradient norm: 98.2992


 59%|█████▉    | 154/261 [00:04<00:03, 31.61it/s]

Step 8,500, (N samples: 1,088,000), Loss: 1088.9679, (Recon: 1079.9916, KLD: 8.9763), Gradient norm: 326.3835


 97%|█████████▋| 254/261 [00:08<00:00, 31.78it/s]

Step 8,600, (N samples: 1,100,800), Loss: 1088.8702, (Recon: 1079.9332, KLD: 8.9371), Gradient norm: 151.8610


100%|██████████| 261/261 [00:08<00:00, 31.04it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 25.95it/s]


====> Test set loss: 1072.6237, (BCE: 1063.6226, KLD: 9.0011)
Epoch 34/100


 36%|███▌      | 94/261 [00:02<00:05, 31.97it/s]

Step 8,700, (N samples: 1,113,600), Loss: 1043.7521, (Recon: 1035.0605, KLD: 8.6916), Gradient norm: 170.1803


 74%|███████▍  | 194/261 [00:06<00:02, 31.97it/s]

Step 8,800, (N samples: 1,126,400), Loss: 1132.7764, (Recon: 1124.0554, KLD: 8.7209), Gradient norm: 229.2586


100%|██████████| 261/261 [00:08<00:00, 31.31it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.51it/s]


====> Test set loss: 1074.3475, (BCE: 1065.4072, KLD: 8.9403)
Epoch 35/100


 11%|█▏        | 30/261 [00:00<00:07, 31.36it/s]

Step 8,900, (N samples: 1,139,200), Loss: 1079.9116, (Recon: 1070.9171, KLD: 8.9945), Gradient norm: 292.5820


 50%|████▉     | 130/261 [00:04<00:04, 31.75it/s]

Step 9,000, (N samples: 1,152,000), Loss: 1082.2531, (Recon: 1073.0352, KLD: 9.2179), Gradient norm: 167.2822


 88%|████████▊ | 230/261 [00:07<00:00, 31.79it/s]

Step 9,100, (N samples: 1,164,800), Loss: 1050.6488, (Recon: 1041.4905, KLD: 9.1584), Gradient norm: 176.8391


100%|██████████| 261/261 [00:08<00:00, 31.63it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.58it/s]


====> Test set loss: 1073.9066, (BCE: 1064.9600, KLD: 8.9466)
Epoch 36/100


 27%|██▋       | 70/261 [00:02<00:06, 31.77it/s]

Step 9,200, (N samples: 1,177,600), Loss: 1050.7412, (Recon: 1041.6105, KLD: 9.1308), Gradient norm: 125.0764


 65%|██████▌   | 170/261 [00:05<00:02, 31.83it/s]

Step 9,300, (N samples: 1,190,400), Loss: 1096.9641, (Recon: 1088.3716, KLD: 8.5925), Gradient norm: 203.8297


100%|██████████| 261/261 [00:08<00:00, 31.07it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.50it/s]


====> Test set loss: 1074.1699, (BCE: 1065.1419, KLD: 9.0280)
Epoch 37/100


  4%|▍         | 10/261 [00:00<00:08, 28.38it/s]

Step 9,400, (N samples: 1,203,200), Loss: 1066.2576, (Recon: 1056.8123, KLD: 9.4453), Gradient norm: 250.8930


 42%|████▏     | 110/261 [00:03<00:04, 31.90it/s]

Step 9,500, (N samples: 1,216,000), Loss: 1060.2921, (Recon: 1050.8772, KLD: 9.4150), Gradient norm: 220.4706


 80%|████████  | 210/261 [00:06<00:01, 31.79it/s]

Step 9,600, (N samples: 1,228,800), Loss: 1053.6338, (Recon: 1044.8823, KLD: 8.7514), Gradient norm: 149.0627


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.47it/s]


====> Test set loss: 1072.9991, (BCE: 1064.3190, KLD: 8.6801)
Epoch 38/100


 19%|█▉        | 50/261 [00:01<00:06, 31.46it/s]

Step 9,700, (N samples: 1,241,600), Loss: 1078.5841, (Recon: 1069.0881, KLD: 9.4959), Gradient norm: 140.2286


 57%|█████▋    | 150/261 [00:04<00:03, 31.93it/s]

Step 9,800, (N samples: 1,254,400), Loss: 1090.8679, (Recon: 1081.7495, KLD: 9.1184), Gradient norm: 187.1757


 96%|█████████▌| 250/261 [00:07<00:00, 31.95it/s]

Step 9,900, (N samples: 1,267,200), Loss: 1065.4265, (Recon: 1056.3381, KLD: 9.0884), Gradient norm: 143.3500


100%|██████████| 261/261 [00:08<00:00, 31.73it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.35it/s]


====> Test set loss: 1072.8065, (BCE: 1063.9801, KLD: 8.8264)
Epoch 39/100


 33%|███▎      | 86/261 [00:02<00:05, 31.78it/s]

Step 10,000, (N samples: 1,280,000), Loss: 1072.2682, (Recon: 1063.3783, KLD: 8.8899), Gradient norm: 322.5131


 71%|███████▏  | 186/261 [00:05<00:02, 31.82it/s]

Step 10,100, (N samples: 1,292,800), Loss: 1091.4187, (Recon: 1082.7457, KLD: 8.6730), Gradient norm: 146.5891


100%|██████████| 261/261 [00:08<00:00, 31.72it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.16it/s]


====> Test set loss: 1071.8884, (BCE: 1063.0089, KLD: 8.8794)
Epoch 40/100


 10%|▉         | 25/261 [00:00<00:07, 30.41it/s]

Step 10,200, (N samples: 1,305,600), Loss: 1039.3523, (Recon: 1029.9917, KLD: 9.3606), Gradient norm: 163.9115


 48%|████▊     | 125/261 [00:03<00:04, 31.87it/s]

Step 10,300, (N samples: 1,318,400), Loss: 1109.9847, (Recon: 1100.9027, KLD: 9.0821), Gradient norm: 181.5310


 86%|████████▌ | 225/261 [00:07<00:01, 31.91it/s]

Step 10,400, (N samples: 1,331,200), Loss: 1152.5388, (Recon: 1143.2107, KLD: 9.3281), Gradient norm: 288.2010


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.64it/s]


====> Test set loss: 1072.2656, (BCE: 1063.2503, KLD: 9.0152)
Epoch 41/100


 25%|██▌       | 66/261 [00:02<00:06, 31.29it/s]

Step 10,500, (N samples: 1,344,000), Loss: 1038.7324, (Recon: 1029.7092, KLD: 9.0232), Gradient norm: 194.8498


 64%|██████▎   | 166/261 [00:05<00:02, 32.04it/s]

Step 10,600, (N samples: 1,356,800), Loss: 1090.9749, (Recon: 1082.0867, KLD: 8.8882), Gradient norm: 211.9373


100%|██████████| 261/261 [00:08<00:00, 31.74it/s]


Step 10,700, (N samples: 1,369,600), Loss: 990.0766, (Recon: 981.1561, KLD: 8.9205), Gradient norm: 193.9620


Testing: 100%|██████████| 29/29 [00:00<00:00, 36.62it/s]


====> Test set loss: 1072.0642, (BCE: 1063.1595, KLD: 8.9047)
Epoch 42/100


 41%|████      | 106/261 [00:03<00:04, 31.57it/s]

Step 10,800, (N samples: 1,382,400), Loss: 1062.7881, (Recon: 1054.0752, KLD: 8.7129), Gradient norm: 230.0238


 79%|███████▉  | 206/261 [00:06<00:01, 31.73it/s]

Step 10,900, (N samples: 1,395,200), Loss: 1041.3575, (Recon: 1032.0831, KLD: 9.2744), Gradient norm: 198.1921


100%|██████████| 261/261 [00:08<00:00, 31.69it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.75it/s]


====> Test set loss: 1072.3984, (BCE: 1063.4611, KLD: 8.9373)
Epoch 43/100


 16%|█▌        | 42/261 [00:01<00:06, 31.71it/s]

Step 11,000, (N samples: 1,408,000), Loss: 1045.6610, (Recon: 1036.8831, KLD: 8.7780), Gradient norm: 229.2937


 54%|█████▍    | 142/261 [00:04<00:03, 31.77it/s]

Step 11,100, (N samples: 1,420,800), Loss: 1043.4899, (Recon: 1034.4930, KLD: 8.9968), Gradient norm: 233.6545


 93%|█████████▎| 242/261 [00:07<00:00, 31.98it/s]

Step 11,200, (N samples: 1,433,600), Loss: 1065.9470, (Recon: 1056.6298, KLD: 9.3172), Gradient norm: 290.0391


100%|██████████| 261/261 [00:08<00:00, 31.72it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.22it/s]


====> Test set loss: 1071.0077, (BCE: 1062.0331, KLD: 8.9746)
Epoch 44/100


 31%|███▏      | 82/261 [00:02<00:05, 31.87it/s]

Step 11,300, (N samples: 1,446,400), Loss: 1092.4434, (Recon: 1083.4548, KLD: 8.9885), Gradient norm: 251.6033


 70%|██████▉   | 182/261 [00:05<00:02, 31.87it/s]

Step 11,400, (N samples: 1,459,200), Loss: 1112.2678, (Recon: 1103.4642, KLD: 8.8036), Gradient norm: 193.1708


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.46it/s]


====> Test set loss: 1071.2679, (BCE: 1062.1720, KLD: 9.0958)
Epoch 45/100


  7%|▋         | 18/261 [00:00<00:11, 21.03it/s]

Step 11,500, (N samples: 1,472,000), Loss: 985.3007, (Recon: 976.3895, KLD: 8.9111), Gradient norm: 197.1888


 46%|████▋     | 121/261 [00:05<00:04, 31.63it/s]

Step 11,600, (N samples: 1,484,800), Loss: 1085.3485, (Recon: 1076.1252, KLD: 9.2233), Gradient norm: 139.3244


 85%|████████▍ | 221/261 [00:08<00:01, 31.85it/s]

Step 11,700, (N samples: 1,497,600), Loss: 1041.6643, (Recon: 1032.4960, KLD: 9.1684), Gradient norm: 540.3185


100%|██████████| 261/261 [00:09<00:00, 26.25it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.52it/s]


====> Test set loss: 1070.2065, (BCE: 1061.1072, KLD: 9.0993)
Epoch 46/100


 24%|██▍       | 62/261 [00:01<00:06, 31.71it/s]

Step 11,800, (N samples: 1,510,400), Loss: 1084.4810, (Recon: 1075.6582, KLD: 8.8227), Gradient norm: 251.8497


 62%|██████▏   | 162/261 [00:05<00:03, 31.29it/s]

Step 11,900, (N samples: 1,523,200), Loss: 1131.3888, (Recon: 1122.3313, KLD: 9.0575), Gradient norm: 203.9381


100%|██████████| 261/261 [00:08<00:00, 31.52it/s]


Step 12,000, (N samples: 1,536,000), Loss: 1070.7761, (Recon: 1061.6195, KLD: 9.1566), Gradient norm: 220.6975


Testing: 100%|██████████| 29/29 [00:00<00:00, 36.29it/s]


====> Test set loss: 1070.4735, (BCE: 1061.4523, KLD: 9.0212)
Epoch 47/100


 38%|███▊      | 99/261 [00:03<00:05, 31.60it/s]

Step 12,100, (N samples: 1,548,800), Loss: 1066.0414, (Recon: 1057.0608, KLD: 8.9806), Gradient norm: 174.2094


 76%|███████▌  | 199/261 [00:06<00:01, 31.88it/s]

Step 12,200, (N samples: 1,561,600), Loss: 1064.5482, (Recon: 1055.4347, KLD: 9.1136), Gradient norm: 329.5057


100%|██████████| 261/261 [00:08<00:00, 29.96it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.55it/s]


====> Test set loss: 1073.0502, (BCE: 1064.0597, KLD: 8.9905)
Epoch 48/100


 15%|█▍        | 38/261 [00:01<00:07, 31.72it/s]

Step 12,300, (N samples: 1,574,400), Loss: 1055.2490, (Recon: 1046.3776, KLD: 8.8715), Gradient norm: 508.6112


 53%|█████▎    | 138/261 [00:04<00:03, 31.80it/s]

Step 12,400, (N samples: 1,587,200), Loss: 1058.6774, (Recon: 1049.5430, KLD: 9.1344), Gradient norm: 339.3650


 91%|█████████ | 238/261 [00:07<00:00, 31.93it/s]

Step 12,500, (N samples: 1,600,000), Loss: 1051.1663, (Recon: 1041.9976, KLD: 9.1688), Gradient norm: 266.4609


100%|██████████| 261/261 [00:08<00:00, 31.74it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.35it/s]


====> Test set loss: 1070.7759, (BCE: 1061.8195, KLD: 8.9565)
Epoch 49/100


 30%|██▉       | 78/261 [00:02<00:05, 31.73it/s]

Step 12,600, (N samples: 1,612,800), Loss: 1036.4352, (Recon: 1027.4930, KLD: 8.9421), Gradient norm: 181.7664


 68%|██████▊   | 178/261 [00:05<00:02, 31.87it/s]

Step 12,700, (N samples: 1,625,600), Loss: 1086.0581, (Recon: 1076.9453, KLD: 9.1128), Gradient norm: 330.0681


100%|██████████| 261/261 [00:08<00:00, 31.64it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.26it/s]


====> Test set loss: 1070.2092, (BCE: 1061.2564, KLD: 8.9528)
Epoch 50/100


  7%|▋         | 18/261 [00:00<00:07, 30.46it/s]

Step 12,800, (N samples: 1,638,400), Loss: 1049.7051, (Recon: 1040.6055, KLD: 9.0996), Gradient norm: 215.6988


 45%|████▌     | 118/261 [00:03<00:04, 31.65it/s]

Step 12,900, (N samples: 1,651,200), Loss: 1098.2366, (Recon: 1089.2369, KLD: 8.9997), Gradient norm: 449.5430


 84%|████████▎ | 218/261 [00:07<00:01, 31.85it/s]

Step 13,000, (N samples: 1,664,000), Loss: 1064.3289, (Recon: 1055.6039, KLD: 8.7250), Gradient norm: 154.8548


100%|██████████| 261/261 [00:08<00:00, 30.89it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.52it/s]


====> Test set loss: 1070.5463, (BCE: 1061.5417, KLD: 9.0046)
Epoch 51/100


 21%|██        | 54/261 [00:01<00:06, 31.73it/s]

Step 13,100, (N samples: 1,676,800), Loss: 1038.3499, (Recon: 1029.5286, KLD: 8.8213), Gradient norm: 231.6018


 59%|█████▉    | 154/261 [00:04<00:03, 31.82it/s]

Step 13,200, (N samples: 1,689,600), Loss: 1017.3072, (Recon: 1007.8790, KLD: 9.4282), Gradient norm: 231.8264


 97%|█████████▋| 254/261 [00:07<00:00, 31.87it/s]

Step 13,300, (N samples: 1,702,400), Loss: 1083.0985, (Recon: 1073.9452, KLD: 9.1533), Gradient norm: 260.4849


100%|██████████| 261/261 [00:08<00:00, 31.76it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.57it/s]


====> Test set loss: 1070.0528, (BCE: 1060.8625, KLD: 9.1903)
Epoch 52/100


 36%|███▌      | 94/261 [00:02<00:05, 31.99it/s]

Step 13,400, (N samples: 1,715,200), Loss: 1072.6967, (Recon: 1063.6010, KLD: 9.0957), Gradient norm: 294.9614


 74%|███████▍  | 194/261 [00:06<00:02, 31.70it/s]

Step 13,500, (N samples: 1,728,000), Loss: 1044.9006, (Recon: 1035.9237, KLD: 8.9769), Gradient norm: 132.9278


100%|██████████| 261/261 [00:08<00:00, 31.71it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.47it/s]


====> Test set loss: 1069.2125, (BCE: 1060.0618, KLD: 9.1506)
Epoch 53/100


 13%|█▎        | 34/261 [00:01<00:07, 31.59it/s]

Step 13,600, (N samples: 1,740,800), Loss: 1065.8955, (Recon: 1056.7617, KLD: 9.1337), Gradient norm: 257.4120


 51%|█████▏    | 134/261 [00:04<00:03, 32.01it/s]

Step 13,700, (N samples: 1,753,600), Loss: 1065.8219, (Recon: 1056.9551, KLD: 8.8668), Gradient norm: 253.8530


 90%|████████▉ | 234/261 [00:07<00:00, 31.58it/s]

Step 13,800, (N samples: 1,766,400), Loss: 1045.4846, (Recon: 1036.0933, KLD: 9.3914), Gradient norm: 254.9091


100%|██████████| 261/261 [00:08<00:00, 31.69it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.61it/s]


====> Test set loss: 1069.0988, (BCE: 1059.9866, KLD: 9.1122)
Epoch 54/100


 27%|██▋       | 70/261 [00:02<00:06, 30.91it/s]

Step 13,900, (N samples: 1,779,200), Loss: 1042.3759, (Recon: 1033.3044, KLD: 9.0714), Gradient norm: 227.0981


 67%|██████▋   | 174/261 [00:05<00:02, 31.92it/s]

Step 14,000, (N samples: 1,792,000), Loss: 1020.4737, (Recon: 1011.5286, KLD: 8.9451), Gradient norm: 181.4224


100%|██████████| 261/261 [00:08<00:00, 31.77it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.69it/s]


====> Test set loss: 1069.3938, (BCE: 1060.4309, KLD: 8.9629)
Epoch 55/100


  4%|▍         | 10/261 [00:00<00:08, 28.26it/s]

Step 14,100, (N samples: 1,804,800), Loss: 1083.8696, (Recon: 1074.7798, KLD: 9.0898), Gradient norm: 276.0213


 42%|████▏     | 110/261 [00:03<00:04, 31.24it/s]

Step 14,200, (N samples: 1,817,600), Loss: 1068.4312, (Recon: 1059.5380, KLD: 8.8932), Gradient norm: 139.8832


 80%|████████  | 210/261 [00:06<00:01, 31.79it/s]

Step 14,300, (N samples: 1,830,400), Loss: 1019.1329, (Recon: 1010.2322, KLD: 8.9007), Gradient norm: 218.3954


100%|██████████| 261/261 [00:09<00:00, 27.69it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 21.53it/s]


====> Test set loss: 1068.8975, (BCE: 1059.8314, KLD: 9.0660)
Epoch 56/100


 20%|█▉        | 52/261 [00:02<00:06, 31.38it/s]

Step 14,400, (N samples: 1,843,200), Loss: 1130.6160, (Recon: 1121.2710, KLD: 9.3450), Gradient norm: 201.7571


 58%|█████▊    | 152/261 [00:05<00:03, 31.73it/s]

Step 14,500, (N samples: 1,856,000), Loss: 1069.6879, (Recon: 1060.5691, KLD: 9.1187), Gradient norm: 195.9717


 97%|█████████▋| 252/261 [00:08<00:00, 31.99it/s]

Step 14,600, (N samples: 1,868,800), Loss: 1055.6035, (Recon: 1046.4430, KLD: 9.1605), Gradient norm: 225.1021


100%|██████████| 261/261 [00:08<00:00, 30.39it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.51it/s]


====> Test set loss: 1068.7670, (BCE: 1059.5918, KLD: 9.1752)
Epoch 57/100


 34%|███▎      | 88/261 [00:04<00:09, 17.39it/s]

Step 14,700, (N samples: 1,881,600), Loss: 1042.4529, (Recon: 1033.3806, KLD: 9.0722), Gradient norm: 169.0845


 72%|███████▏  | 189/261 [00:08<00:02, 31.77it/s]

Step 14,800, (N samples: 1,894,400), Loss: 1057.2838, (Recon: 1048.0721, KLD: 9.2116), Gradient norm: 134.7793


100%|██████████| 261/261 [00:10<00:00, 25.34it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 35.67it/s]


====> Test set loss: 1069.5911, (BCE: 1060.4908, KLD: 9.1003)
Epoch 58/100


 11%|█▏        | 30/261 [00:00<00:07, 31.31it/s]

Step 14,900, (N samples: 1,907,200), Loss: 1072.8734, (Recon: 1063.6199, KLD: 9.2536), Gradient norm: 144.3749


 48%|████▊     | 126/261 [00:05<00:06, 19.36it/s]

Step 15,000, (N samples: 1,920,000), Loss: 1008.3265, (Recon: 999.1052, KLD: 9.2214), Gradient norm: 244.3560


 87%|████████▋ | 226/261 [00:10<00:01, 19.45it/s]

Step 15,100, (N samples: 1,932,800), Loss: 1059.1986, (Recon: 1050.0288, KLD: 9.1698), Gradient norm: 315.7866


100%|██████████| 261/261 [00:12<00:00, 20.87it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 23.86it/s]


====> Test set loss: 1068.6362, (BCE: 1059.5167, KLD: 9.1194)
Epoch 59/100


 25%|██▍       | 65/261 [00:03<00:10, 19.34it/s]

Step 15,200, (N samples: 1,945,600), Loss: 1102.8134, (Recon: 1093.6469, KLD: 9.1665), Gradient norm: 213.8651


 63%|██████▎   | 165/261 [00:09<00:05, 18.96it/s]

Step 15,300, (N samples: 1,958,400), Loss: 1061.6556, (Recon: 1052.4771, KLD: 9.1786), Gradient norm: 144.1103


100%|██████████| 261/261 [00:14<00:00, 18.33it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 23.36it/s]


====> Test set loss: 1067.9330, (BCE: 1058.7238, KLD: 9.2093)
Epoch 60/100


  1%|          | 2/261 [00:00<00:13, 19.00it/s]

Step 15,400, (N samples: 1,971,200), Loss: 1139.5524, (Recon: 1130.1918, KLD: 9.3606), Gradient norm: 148.3319


 40%|███▉      | 104/261 [00:05<00:08, 19.04it/s]

Step 15,500, (N samples: 1,984,000), Loss: 1131.7290, (Recon: 1122.1396, KLD: 9.5893), Gradient norm: 415.5551


 78%|███████▊  | 204/261 [00:10<00:02, 19.16it/s]

Step 15,600, (N samples: 1,996,800), Loss: 1034.0214, (Recon: 1024.7852, KLD: 9.2362), Gradient norm: 266.5248


100%|██████████| 261/261 [00:13<00:00, 19.30it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 22.22it/s]


====> Test set loss: 1069.2124, (BCE: 1060.0444, KLD: 9.1680)
Epoch 61/100


 16%|█▋        | 43/261 [00:02<00:11, 19.08it/s]

Step 15,700, (N samples: 2,009,600), Loss: 1039.9583, (Recon: 1031.0463, KLD: 8.9120), Gradient norm: 237.6914


 55%|█████▍    | 143/261 [00:07<00:06, 19.36it/s]

Step 15,800, (N samples: 2,022,400), Loss: 1075.5826, (Recon: 1065.9587, KLD: 9.6239), Gradient norm: 134.0361


 93%|█████████▎| 243/261 [00:12<00:00, 18.85it/s]

Step 15,900, (N samples: 2,035,200), Loss: 1024.0559, (Recon: 1015.2498, KLD: 8.8062), Gradient norm: 164.1609


100%|██████████| 261/261 [00:13<00:00, 19.00it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 23.19it/s]


====> Test set loss: 1069.7414, (BCE: 1060.7533, KLD: 8.9882)
Epoch 62/100


 31%|███▏      | 82/261 [00:04<00:09, 19.26it/s]

Step 16,000, (N samples: 2,048,000), Loss: 1059.8409, (Recon: 1050.9833, KLD: 8.8576), Gradient norm: 292.9154


 71%|███████▏  | 186/261 [00:10<00:03, 23.52it/s]

Step 16,100, (N samples: 2,060,800), Loss: 1005.3193, (Recon: 996.5327, KLD: 8.7866), Gradient norm: 242.4932


100%|██████████| 261/261 [00:12<00:00, 20.83it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.13it/s]


====> Test set loss: 1068.6766, (BCE: 1059.7142, KLD: 8.9624)
Epoch 63/100


  8%|▊         | 22/261 [00:00<00:07, 30.99it/s]

Step 16,200, (N samples: 2,073,600), Loss: 1080.4774, (Recon: 1071.1025, KLD: 9.3749), Gradient norm: 316.6802


 47%|████▋     | 122/261 [00:03<00:04, 31.85it/s]

Step 16,300, (N samples: 2,086,400), Loss: 1066.1571, (Recon: 1057.0951, KLD: 9.0620), Gradient norm: 205.3755


 85%|████████▌ | 222/261 [00:07<00:01, 31.84it/s]

Step 16,400, (N samples: 2,099,200), Loss: 1102.8119, (Recon: 1093.4850, KLD: 9.3269), Gradient norm: 623.8082


100%|██████████| 261/261 [00:08<00:00, 31.05it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.47it/s]


====> Test set loss: 1069.6388, (BCE: 1060.5787, KLD: 9.0601)
Epoch 64/100


 24%|██▍       | 62/261 [00:01<00:06, 31.88it/s]

Step 16,500, (N samples: 2,112,000), Loss: 1125.5455, (Recon: 1116.2019, KLD: 9.3436), Gradient norm: 159.4859


 62%|██████▏   | 161/261 [00:05<00:04, 20.01it/s]

Step 16,600, (N samples: 2,124,800), Loss: 1056.5293, (Recon: 1047.6338, KLD: 8.8955), Gradient norm: 257.9051


 99%|█████████▉| 259/261 [00:11<00:00, 16.18it/s]

Step 16,700, (N samples: 2,137,600), Loss: 1028.8373, (Recon: 1019.9470, KLD: 8.8903), Gradient norm: 115.9873


100%|██████████| 261/261 [00:11<00:00, 22.66it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 20.07it/s]


====> Test set loss: 1068.9403, (BCE: 1059.8870, KLD: 9.0534)
Epoch 65/100


 38%|███▊      | 99/261 [00:06<00:09, 16.23it/s]

Step 16,800, (N samples: 2,150,400), Loss: 1045.5773, (Recon: 1036.5085, KLD: 9.0687), Gradient norm: 370.8193


 76%|███████▌  | 199/261 [00:12<00:03, 16.50it/s]

Step 16,900, (N samples: 2,163,200), Loss: 1097.6626, (Recon: 1088.5035, KLD: 9.1590), Gradient norm: 210.7608


100%|██████████| 261/261 [00:15<00:00, 16.33it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 19.27it/s]


====> Test set loss: 1068.5729, (BCE: 1059.4463, KLD: 9.1265)
Epoch 66/100


 15%|█▍        | 38/261 [00:02<00:13, 16.45it/s]

Step 17,000, (N samples: 2,176,000), Loss: 1090.9656, (Recon: 1081.7815, KLD: 9.1841), Gradient norm: 240.2179


 53%|█████▎    | 138/261 [00:08<00:07, 16.29it/s]

Step 17,100, (N samples: 2,188,800), Loss: 1090.8777, (Recon: 1081.5967, KLD: 9.2810), Gradient norm: 448.3353


 91%|█████████ | 238/261 [00:14<00:01, 16.16it/s]

Step 17,200, (N samples: 2,201,600), Loss: 1031.8108, (Recon: 1022.9659, KLD: 8.8449), Gradient norm: 263.1410


100%|██████████| 261/261 [00:16<00:00, 16.15it/s]
Testing: 100%|██████████| 29/29 [00:01<00:00, 19.92it/s]


====> Test set loss: 1068.3090, (BCE: 1059.0664, KLD: 9.2426)
Epoch 67/100


 30%|██▉       | 78/261 [00:04<00:11, 16.14it/s]

Step 17,300, (N samples: 2,214,400), Loss: 1068.1138, (Recon: 1058.8372, KLD: 9.2766), Gradient norm: 212.8619


 68%|██████▊   | 178/261 [00:10<00:04, 17.49it/s]

Step 17,400, (N samples: 2,227,200), Loss: 1092.2209, (Recon: 1082.7402, KLD: 9.4807), Gradient norm: 230.9608


100%|██████████| 261/261 [00:14<00:00, 18.55it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 35.99it/s]


====> Test set loss: 1070.0129, (BCE: 1060.8341, KLD: 9.1788)
Epoch 68/100


  7%|▋         | 18/261 [00:00<00:07, 30.39it/s]

Step 17,500, (N samples: 2,240,000), Loss: 1119.7089, (Recon: 1110.3105, KLD: 9.3983), Gradient norm: 313.1538


 45%|████▌     | 118/261 [00:03<00:04, 31.41it/s]

Step 17,600, (N samples: 2,252,800), Loss: 1011.7076, (Recon: 1002.3870, KLD: 9.3205), Gradient norm: 284.8948


 84%|████████▎ | 218/261 [00:06<00:01, 31.95it/s]

Step 17,700, (N samples: 2,265,600), Loss: 1089.0309, (Recon: 1079.9180, KLD: 9.1129), Gradient norm: 226.0367


100%|██████████| 261/261 [00:08<00:00, 31.62it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.62it/s]


====> Test set loss: 1067.6689, (BCE: 1058.5472, KLD: 9.1217)
Epoch 69/100


 22%|██▏       | 58/261 [00:01<00:06, 31.56it/s]

Step 17,800, (N samples: 2,278,400), Loss: 1069.4562, (Recon: 1060.2660, KLD: 9.1902), Gradient norm: 145.0503


 61%|██████    | 158/261 [00:05<00:03, 31.90it/s]

Step 17,900, (N samples: 2,291,200), Loss: 1023.8298, (Recon: 1014.6566, KLD: 9.1733), Gradient norm: 180.0557


 99%|█████████▉| 258/261 [00:08<00:00, 31.72it/s]

Step 18,000, (N samples: 2,304,000), Loss: 1086.5204, (Recon: 1077.3280, KLD: 9.1924), Gradient norm: 558.9226


100%|██████████| 261/261 [00:08<00:00, 31.62it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.73it/s]


====> Test set loss: 1068.6845, (BCE: 1059.5498, KLD: 9.1347)
Epoch 70/100


 38%|███▊      | 98/261 [00:03<00:05, 31.74it/s]

Step 18,100, (N samples: 2,316,800), Loss: 1038.2145, (Recon: 1029.1987, KLD: 9.0157), Gradient norm: 246.8073


 76%|███████▌  | 198/261 [00:06<00:01, 31.90it/s]

Step 18,200, (N samples: 2,329,600), Loss: 1080.6926, (Recon: 1071.6663, KLD: 9.0263), Gradient norm: 218.3039


100%|██████████| 261/261 [00:08<00:00, 31.55it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.51it/s]


====> Test set loss: 1068.2524, (BCE: 1059.2361, KLD: 9.0163)
Epoch 71/100


 14%|█▍        | 37/261 [00:01<00:07, 31.61it/s]

Step 18,300, (N samples: 2,342,400), Loss: 1011.5510, (Recon: 1002.7725, KLD: 8.7786), Gradient norm: 214.2093


 52%|█████▏    | 137/261 [00:04<00:03, 31.40it/s]

Step 18,400, (N samples: 2,355,200), Loss: 1054.5007, (Recon: 1045.3542, KLD: 9.1465), Gradient norm: 210.4111


 91%|█████████ | 237/261 [00:07<00:00, 31.30it/s]

Step 18,500, (N samples: 2,368,000), Loss: 1055.6189, (Recon: 1046.2504, KLD: 9.3685), Gradient norm: 235.1565


100%|██████████| 261/261 [00:08<00:00, 31.34it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.24it/s]


====> Test set loss: 1067.6590, (BCE: 1058.4311, KLD: 9.2280)
Epoch 72/100


 28%|██▊       | 74/261 [00:02<00:05, 31.83it/s]

Step 18,600, (N samples: 2,380,800), Loss: 1047.7050, (Recon: 1038.6627, KLD: 9.0422), Gradient norm: 224.8198


 67%|██████▋   | 174/261 [00:05<00:02, 31.73it/s]

Step 18,700, (N samples: 2,393,600), Loss: 1026.6772, (Recon: 1017.6097, KLD: 9.0676), Gradient norm: 192.5173


100%|██████████| 261/261 [00:08<00:00, 31.64it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.10it/s]


====> Test set loss: 1068.2318, (BCE: 1059.2125, KLD: 9.0193)
Epoch 73/100


  5%|▌         | 14/261 [00:00<00:08, 29.76it/s]

Step 18,800, (N samples: 2,406,400), Loss: 1065.7975, (Recon: 1056.6528, KLD: 9.1447), Gradient norm: 287.8217


 44%|████▎     | 114/261 [00:03<00:04, 31.88it/s]

Step 18,900, (N samples: 2,419,200), Loss: 1050.5475, (Recon: 1041.5427, KLD: 9.0047), Gradient norm: 177.9831


 82%|████████▏ | 214/261 [00:06<00:01, 31.91it/s]

Step 19,000, (N samples: 2,432,000), Loss: 1054.2200, (Recon: 1045.2469, KLD: 8.9730), Gradient norm: 205.2437


100%|██████████| 261/261 [00:08<00:00, 31.62it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.50it/s]


====> Test set loss: 1068.3228, (BCE: 1059.3418, KLD: 8.9809)
Epoch 74/100


 21%|██        | 54/261 [00:01<00:06, 31.90it/s]

Step 19,100, (N samples: 2,444,800), Loss: 1061.9857, (Recon: 1052.5901, KLD: 9.3956), Gradient norm: 283.2404


 59%|█████▉    | 154/261 [00:04<00:03, 31.94it/s]

Step 19,200, (N samples: 2,457,600), Loss: 1024.6031, (Recon: 1015.7939, KLD: 8.8093), Gradient norm: 239.2022


 97%|█████████▋| 254/261 [00:08<00:00, 31.90it/s]

Step 19,300, (N samples: 2,470,400), Loss: 1067.8857, (Recon: 1058.3405, KLD: 9.5453), Gradient norm: 221.7884


100%|██████████| 261/261 [00:08<00:00, 31.58it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.37it/s]


====> Test set loss: 1066.6163, (BCE: 1057.5909, KLD: 9.0254)
Epoch 75/100


 34%|███▍      | 90/261 [00:02<00:05, 31.62it/s]

Step 19,400, (N samples: 2,483,200), Loss: 1029.1714, (Recon: 1020.1825, KLD: 8.9889), Gradient norm: 101.2558


 73%|███████▎  | 190/261 [00:06<00:02, 31.75it/s]

Step 19,500, (N samples: 2,496,000), Loss: 1057.0416, (Recon: 1048.1331, KLD: 8.9086), Gradient norm: 221.0165


100%|██████████| 261/261 [00:08<00:00, 31.41it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.47it/s]


====> Test set loss: 1067.5190, (BCE: 1058.3583, KLD: 9.1607)
Epoch 76/100


 11%|█▏        | 30/261 [00:01<00:07, 31.06it/s]

Step 19,600, (N samples: 2,508,800), Loss: 1058.1818, (Recon: 1048.7083, KLD: 9.4735), Gradient norm: 304.2827


 50%|████▉     | 130/261 [00:04<00:04, 31.73it/s]

Step 19,700, (N samples: 2,521,600), Loss: 1120.6732, (Recon: 1111.4349, KLD: 9.2383), Gradient norm: 244.3212


 88%|████████▊ | 230/261 [00:07<00:00, 31.81it/s]

Step 19,800, (N samples: 2,534,400), Loss: 1043.3146, (Recon: 1033.7341, KLD: 9.5804), Gradient norm: 357.1203


100%|██████████| 261/261 [00:08<00:00, 31.48it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 31.11it/s]


====> Test set loss: 1067.8750, (BCE: 1058.7679, KLD: 9.1071)
Epoch 77/100


 27%|██▋       | 70/261 [00:02<00:06, 31.65it/s]

Step 19,900, (N samples: 2,547,200), Loss: 1087.1931, (Recon: 1078.2686, KLD: 8.9246), Gradient norm: 158.7528


 65%|██████▌   | 170/261 [00:05<00:02, 31.81it/s]

Step 20,000, (N samples: 2,560,000), Loss: 1021.9653, (Recon: 1012.8382, KLD: 9.1270), Gradient norm: 153.3664


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.15it/s]


====> Test set loss: 1068.3743, (BCE: 1059.2031, KLD: 9.1712)
Epoch 78/100


  1%|          | 3/261 [00:00<00:11, 23.24it/s]

Step 20,100, (N samples: 2,572,800), Loss: 1043.5896, (Recon: 1034.6320, KLD: 8.9576), Gradient norm: 231.1564


 42%|████▏     | 110/261 [00:03<00:04, 31.94it/s]

Step 20,200, (N samples: 2,585,600), Loss: 1077.0875, (Recon: 1067.7512, KLD: 9.3363), Gradient norm: 367.7054


 80%|████████  | 210/261 [00:06<00:01, 31.92it/s]

Step 20,300, (N samples: 2,598,400), Loss: 1039.0754, (Recon: 1030.1926, KLD: 8.8828), Gradient norm: 197.4719


100%|██████████| 261/261 [00:08<00:00, 31.73it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.17it/s]


====> Test set loss: 1067.5316, (BCE: 1058.3355, KLD: 9.1961)
Epoch 79/100


 18%|█▊        | 46/261 [00:01<00:06, 31.55it/s]

Step 20,400, (N samples: 2,611,200), Loss: 1096.1592, (Recon: 1086.9286, KLD: 9.2306), Gradient norm: 128.5488


 56%|█████▌    | 146/261 [00:04<00:03, 31.88it/s]

Step 20,500, (N samples: 2,624,000), Loss: 1054.2773, (Recon: 1044.9407, KLD: 9.3367), Gradient norm: 247.8474


 94%|█████████▍| 246/261 [00:07<00:00, 31.07it/s]

Step 20,600, (N samples: 2,636,800), Loss: 1054.8096, (Recon: 1045.4235, KLD: 9.3861), Gradient norm: 146.0749


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.47it/s]


====> Test set loss: 1067.6361, (BCE: 1058.5592, KLD: 9.0769)
Epoch 80/100


 33%|███▎      | 86/261 [00:02<00:05, 31.85it/s]

Step 20,700, (N samples: 2,649,600), Loss: 1090.1409, (Recon: 1080.5402, KLD: 9.6007), Gradient norm: 308.7229


 71%|███████▏  | 186/261 [00:05<00:02, 31.94it/s]

Step 20,800, (N samples: 2,662,400), Loss: 1053.7455, (Recon: 1044.5294, KLD: 9.2161), Gradient norm: 164.8458


100%|██████████| 261/261 [00:08<00:00, 31.73it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 35.95it/s]


====> Test set loss: 1067.4603, (BCE: 1058.3119, KLD: 9.1485)
Epoch 81/100


  8%|▊         | 22/261 [00:00<00:09, 26.21it/s]

Step 20,900, (N samples: 2,675,200), Loss: 1036.4584, (Recon: 1026.8942, KLD: 9.5642), Gradient norm: 251.1491


 49%|████▊     | 127/261 [00:04<00:04, 31.91it/s]

Step 21,000, (N samples: 2,688,000), Loss: 1123.3704, (Recon: 1114.0951, KLD: 9.2752), Gradient norm: 273.0569


 87%|████████▋ | 227/261 [00:07<00:01, 31.96it/s]

Step 21,100, (N samples: 2,700,800), Loss: 1084.2224, (Recon: 1075.2155, KLD: 9.0069), Gradient norm: 383.6544


100%|██████████| 261/261 [00:08<00:00, 29.93it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.42it/s]


====> Test set loss: 1069.1711, (BCE: 1060.0826, KLD: 9.0884)
Epoch 82/100


 25%|██▌       | 66/261 [00:02<00:06, 31.79it/s]

Step 21,200, (N samples: 2,713,600), Loss: 1041.5654, (Recon: 1031.9598, KLD: 9.6056), Gradient norm: 150.7111


 64%|██████▎   | 166/261 [00:05<00:02, 32.04it/s]

Step 21,300, (N samples: 2,726,400), Loss: 1047.4762, (Recon: 1038.6100, KLD: 8.8662), Gradient norm: 203.4317


100%|██████████| 261/261 [00:08<00:00, 31.72it/s]


Step 21,400, (N samples: 2,739,200), Loss: 1098.6597, (Recon: 1089.0974, KLD: 9.5623), Gradient norm: 298.1462


Testing: 100%|██████████| 29/29 [00:00<00:00, 36.43it/s]


====> Test set loss: 1068.8347, (BCE: 1059.7312, KLD: 9.1035)
Epoch 83/100


 39%|███▉      | 102/261 [00:03<00:04, 31.93it/s]

Step 21,500, (N samples: 2,752,000), Loss: 1025.8538, (Recon: 1016.7155, KLD: 9.1383), Gradient norm: 197.7445


 77%|███████▋  | 202/261 [00:06<00:01, 31.83it/s]

Step 21,600, (N samples: 2,764,800), Loss: 1072.8840, (Recon: 1063.5830, KLD: 9.3011), Gradient norm: 235.3203


100%|██████████| 261/261 [00:08<00:00, 31.77it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.47it/s]


====> Test set loss: 1067.2594, (BCE: 1058.0859, KLD: 9.1734)
Epoch 84/100


 16%|█▌        | 42/261 [00:01<00:06, 31.74it/s]

Step 21,700, (N samples: 2,777,600), Loss: 1103.9082, (Recon: 1094.6051, KLD: 9.3031), Gradient norm: 211.8514


 54%|█████▍    | 142/261 [00:04<00:03, 31.94it/s]

Step 21,800, (N samples: 2,790,400), Loss: 1051.8511, (Recon: 1042.5662, KLD: 9.2849), Gradient norm: 352.1729


 93%|█████████▎| 242/261 [00:07<00:00, 31.96it/s]

Step 21,900, (N samples: 2,803,200), Loss: 1070.5177, (Recon: 1060.6152, KLD: 9.9024), Gradient norm: 240.3872


100%|██████████| 261/261 [00:08<00:00, 31.68it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.34it/s]


====> Test set loss: 1067.1230, (BCE: 1057.8904, KLD: 9.2327)
Epoch 85/100


 31%|███▏      | 82/261 [00:02<00:05, 31.77it/s]

Step 22,000, (N samples: 2,816,000), Loss: 1076.1143, (Recon: 1066.8667, KLD: 9.2475), Gradient norm: 197.4046


 70%|██████▉   | 182/261 [00:05<00:02, 31.72it/s]

Step 22,100, (N samples: 2,828,800), Loss: 1046.9550, (Recon: 1037.7292, KLD: 9.2257), Gradient norm: 308.8392


100%|██████████| 261/261 [00:08<00:00, 31.65it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.43it/s]


====> Test set loss: 1067.7815, (BCE: 1058.6516, KLD: 9.1298)
Epoch 86/100


  8%|▊         | 22/261 [00:00<00:07, 31.12it/s]

Step 22,200, (N samples: 2,841,600), Loss: 1053.1105, (Recon: 1043.9526, KLD: 9.1579), Gradient norm: 218.1409


 47%|████▋     | 122/261 [00:03<00:04, 31.85it/s]

Step 22,300, (N samples: 2,854,400), Loss: 1041.3970, (Recon: 1031.9926, KLD: 9.4044), Gradient norm: 381.3744


 85%|████████▌ | 222/261 [00:07<00:01, 31.84it/s]

Step 22,400, (N samples: 2,867,200), Loss: 1026.1161, (Recon: 1017.0439, KLD: 9.0723), Gradient norm: 209.1040


100%|██████████| 261/261 [00:08<00:00, 31.69it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.38it/s]


====> Test set loss: 1067.3965, (BCE: 1058.1427, KLD: 9.2538)
Epoch 87/100


 22%|██▏       | 58/261 [00:01<00:06, 31.80it/s]

Step 22,500, (N samples: 2,880,000), Loss: 1112.4586, (Recon: 1103.2565, KLD: 9.2021), Gradient norm: 162.3248


 61%|██████    | 158/261 [00:04<00:03, 31.95it/s]

Step 22,600, (N samples: 2,892,800), Loss: 1045.5765, (Recon: 1036.8101, KLD: 8.7665), Gradient norm: 146.1125


100%|██████████| 261/261 [00:08<00:00, 31.70it/s]


Step 22,700, (N samples: 2,905,600), Loss: 1080.2324, (Recon: 1070.8038, KLD: 9.4286), Gradient norm: 185.4929


Testing: 100%|██████████| 29/29 [00:00<00:00, 36.17it/s]


====> Test set loss: 1067.5142, (BCE: 1058.3933, KLD: 9.1210)
Epoch 88/100


 38%|███▊      | 98/261 [00:03<00:05, 31.97it/s]

Step 22,800, (N samples: 2,918,400), Loss: 1052.9314, (Recon: 1043.5566, KLD: 9.3748), Gradient norm: 216.5289


 76%|███████▌  | 198/261 [00:06<00:01, 31.91it/s]

Step 22,900, (N samples: 2,931,200), Loss: 1045.2450, (Recon: 1036.1633, KLD: 9.0817), Gradient norm: 183.8200


100%|██████████| 261/261 [00:08<00:00, 31.77it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.23it/s]


====> Test set loss: 1068.3446, (BCE: 1059.0238, KLD: 9.3208)
Epoch 89/100


 14%|█▍        | 37/261 [00:01<00:07, 31.36it/s]

Step 23,000, (N samples: 2,944,000), Loss: 1009.4304, (Recon: 1000.9274, KLD: 8.5030), Gradient norm: 231.0093


 52%|█████▏    | 137/261 [00:04<00:03, 32.02it/s]

Step 23,100, (N samples: 2,956,800), Loss: 1032.9437, (Recon: 1023.9753, KLD: 8.9684), Gradient norm: 139.0533


 91%|█████████ | 237/261 [00:07<00:00, 30.91it/s]

Step 23,200, (N samples: 2,969,600), Loss: 1048.7534, (Recon: 1039.5996, KLD: 9.1538), Gradient norm: 140.4736


100%|██████████| 261/261 [00:08<00:00, 31.63it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.51it/s]


====> Test set loss: 1066.2146, (BCE: 1057.0073, KLD: 9.2073)
Epoch 90/100


 30%|██▉       | 78/261 [00:02<00:07, 26.09it/s]

Step 23,300, (N samples: 2,982,400), Loss: 1074.3645, (Recon: 1064.9214, KLD: 9.4431), Gradient norm: 241.7444


 68%|██████▊   | 178/261 [00:05<00:02, 31.82it/s]

Step 23,400, (N samples: 2,995,200), Loss: 1074.4020, (Recon: 1064.9440, KLD: 9.4580), Gradient norm: 193.5206


100%|██████████| 261/261 [00:08<00:00, 31.21it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 35.66it/s]


====> Test set loss: 1067.7211, (BCE: 1058.7464, KLD: 8.9747)
Epoch 91/100


  5%|▌         | 14/261 [00:00<00:08, 29.72it/s]

Step 23,500, (N samples: 3,008,000), Loss: 1067.8181, (Recon: 1058.2554, KLD: 9.5628), Gradient norm: 295.4897


 44%|████▎     | 114/261 [00:03<00:04, 31.86it/s]

Step 23,600, (N samples: 3,020,800), Loss: 1062.6423, (Recon: 1053.2501, KLD: 9.3922), Gradient norm: 148.7769


 82%|████████▏ | 214/261 [00:06<00:01, 31.86it/s]

Step 23,700, (N samples: 3,033,600), Loss: 1107.4337, (Recon: 1098.1958, KLD: 9.2379), Gradient norm: 174.0885


100%|██████████| 261/261 [00:08<00:00, 31.73it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.21it/s]


====> Test set loss: 1066.3473, (BCE: 1057.1507, KLD: 9.1966)
Epoch 92/100


 20%|██        | 53/261 [00:01<00:06, 31.77it/s]

Step 23,800, (N samples: 3,046,400), Loss: 1100.1831, (Recon: 1090.8217, KLD: 9.3615), Gradient norm: 209.1388


 59%|█████▊    | 153/261 [00:04<00:03, 31.89it/s]

Step 23,900, (N samples: 3,059,200), Loss: 1123.1951, (Recon: 1113.6763, KLD: 9.5188), Gradient norm: 220.4258


 97%|█████████▋| 253/261 [00:07<00:00, 31.82it/s]

Step 24,000, (N samples: 3,072,000), Loss: 1114.9462, (Recon: 1105.6016, KLD: 9.3446), Gradient norm: 214.9622


100%|██████████| 261/261 [00:08<00:00, 31.67it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.58it/s]


====> Test set loss: 1066.8950, (BCE: 1057.8926, KLD: 9.0024)
Epoch 93/100


 36%|███▌      | 94/261 [00:03<00:05, 31.91it/s]

Step 24,100, (N samples: 3,084,800), Loss: 1081.4517, (Recon: 1072.1296, KLD: 9.3221), Gradient norm: 195.6131


 74%|███████▍  | 194/261 [00:06<00:02, 31.79it/s]

Step 24,200, (N samples: 3,097,600), Loss: 1049.3999, (Recon: 1040.2751, KLD: 9.1247), Gradient norm: 178.1170


100%|██████████| 261/261 [00:08<00:00, 31.69it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.50it/s]


====> Test set loss: 1066.6587, (BCE: 1057.3671, KLD: 9.2916)
Epoch 94/100


 13%|█▎        | 34/261 [00:01<00:07, 31.59it/s]

Step 24,300, (N samples: 3,110,400), Loss: 1042.3740, (Recon: 1033.0583, KLD: 9.3156), Gradient norm: 258.4099


 51%|█████▏    | 134/261 [00:04<00:03, 31.86it/s]

Step 24,400, (N samples: 3,123,200), Loss: 1048.2985, (Recon: 1039.5623, KLD: 8.7362), Gradient norm: 106.5285


 89%|████████▊ | 231/261 [00:07<00:00, 30.59it/s]

Step 24,500, (N samples: 3,136,000), Loss: 1065.4490, (Recon: 1056.0427, KLD: 9.4062), Gradient norm: 218.1163


100%|██████████| 261/261 [00:08<00:00, 29.65it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.32it/s]


====> Test set loss: 1066.4191, (BCE: 1057.3924, KLD: 9.0267)
Epoch 95/100


 27%|██▋       | 70/261 [00:02<00:06, 31.78it/s]

Step 24,600, (N samples: 3,148,800), Loss: 1072.7289, (Recon: 1063.3665, KLD: 9.3624), Gradient norm: 251.4558


 65%|██████▌   | 170/261 [00:05<00:02, 31.92it/s]

Step 24,700, (N samples: 3,161,600), Loss: 1071.1675, (Recon: 1061.9403, KLD: 9.2271), Gradient norm: 353.4936


100%|██████████| 261/261 [00:08<00:00, 31.64it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.14it/s]


====> Test set loss: 1066.9156, (BCE: 1057.7349, KLD: 9.1807)
Epoch 96/100


  4%|▍         | 10/261 [00:00<00:08, 28.40it/s]

Step 24,800, (N samples: 3,174,400), Loss: 1090.8990, (Recon: 1081.7385, KLD: 9.1605), Gradient norm: 262.3862


 42%|████▏     | 110/261 [00:03<00:04, 31.93it/s]

Step 24,900, (N samples: 3,187,200), Loss: 1059.8442, (Recon: 1050.6613, KLD: 9.1830), Gradient norm: 182.2568


 80%|████████  | 210/261 [00:06<00:01, 31.88it/s]

Step 25,000, (N samples: 3,200,000), Loss: 1042.8618, (Recon: 1033.5986, KLD: 9.2631), Gradient norm: 192.7958


100%|██████████| 261/261 [00:08<00:00, 31.74it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.40it/s]


====> Test set loss: 1066.9973, (BCE: 1057.8600, KLD: 9.1374)
Epoch 97/100


 19%|█▉        | 50/261 [00:01<00:06, 31.74it/s]

Step 25,100, (N samples: 3,212,800), Loss: 1031.8875, (Recon: 1022.6479, KLD: 9.2395), Gradient norm: 308.5836


 57%|█████▋    | 150/261 [00:04<00:03, 31.13it/s]

Step 25,200, (N samples: 3,225,600), Loss: 1107.6555, (Recon: 1098.5557, KLD: 9.0999), Gradient norm: 165.0673


 96%|█████████▌| 250/261 [00:07<00:00, 31.95it/s]

Step 25,300, (N samples: 3,238,400), Loss: 1062.7678, (Recon: 1053.3596, KLD: 9.4082), Gradient norm: 216.0917


100%|██████████| 261/261 [00:08<00:00, 31.61it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.35it/s]


====> Test set loss: 1065.9958, (BCE: 1056.8005, KLD: 9.1953)
Epoch 98/100


 34%|███▍      | 90/261 [00:02<00:05, 31.80it/s]

Step 25,400, (N samples: 3,251,200), Loss: 1059.4032, (Recon: 1049.8513, KLD: 9.5519), Gradient norm: 192.5399


 73%|███████▎  | 190/261 [00:06<00:02, 31.86it/s]

Step 25,500, (N samples: 3,264,000), Loss: 1089.0865, (Recon: 1079.4604, KLD: 9.6261), Gradient norm: 166.8046


100%|██████████| 261/261 [00:08<00:00, 31.71it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.18it/s]


====> Test set loss: 1066.7570, (BCE: 1057.5553, KLD: 9.2017)
Epoch 99/100


 10%|▉         | 26/261 [00:00<00:07, 31.19it/s]

Step 25,600, (N samples: 3,276,800), Loss: 1020.0345, (Recon: 1011.2513, KLD: 8.7831), Gradient norm: 91.8331


 48%|████▊     | 126/261 [00:04<00:04, 31.53it/s]

Step 25,700, (N samples: 3,289,600), Loss: 1064.7904, (Recon: 1055.0811, KLD: 9.7093), Gradient norm: 406.4956


 87%|████████▋ | 226/261 [00:07<00:01, 31.79it/s]

Step 25,800, (N samples: 3,302,400), Loss: 1065.1919, (Recon: 1055.9385, KLD: 9.2534), Gradient norm: 186.6512


100%|██████████| 261/261 [00:08<00:00, 31.65it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.09it/s]


====> Test set loss: 1065.4396, (BCE: 1056.3812, KLD: 9.0585)
Epoch 100/100


 25%|██▌       | 66/261 [00:02<00:06, 31.79it/s]

Step 25,900, (N samples: 3,315,200), Loss: 1053.0640, (Recon: 1043.8555, KLD: 9.2085), Gradient norm: 153.3590


 64%|██████▎   | 166/261 [00:05<00:02, 31.89it/s]

Step 26,000, (N samples: 3,328,000), Loss: 1054.3584, (Recon: 1045.2986, KLD: 9.0599), Gradient norm: 275.4611


100%|██████████| 261/261 [00:08<00:00, 31.63it/s]
Testing: 100%|██████████| 29/29 [00:00<00:00, 36.45it/s]

====> Test set loss: 1066.5811, (BCE: 1057.3722, KLD: 9.2089)





In [10]:
def pearson_correlation(original_x, x_hat, mask=None):
    
    # Calculate mean and standard deviation
    mean_x = torch.mean(original_x)
    mean_x_hat = torch.mean(x_hat)
    
    std_x = torch.std(original_x)
    std_x_hat = torch.std(x_hat)
    
    # Calculate covariance
    covariance = torch.mean((original_x - mean_x) * (x_hat - mean_x_hat))
    
    # Calculate Pearson correlation
    correlation = covariance / (std_x * std_x_hat)
    
    return correlation.item()


correlations = []

# Disable gradient calculation for efficiency
with torch.no_grad():
    for batch in train_loader:
        # Unpack the batch to get the input data; adjust if there are labels
        original_x, _ = batch  # Adjust this if your batch contains more items
        original_x = original_x.to(device)  # Move input to device if using GPU

        # Get the model output for the entire batch
        out = model(original_x)
        x_hat_batch = out.x_recon

        # Compute Pearson correlation for each element in the batch
        for i in range(original_x.size(0)):  # Loop over each sample in the batch
            original_sample = original_x[i]  # Select the i-th sample
            x_hat_sample = x_hat_batch[i]    # Select the corresponding output sample

            # Compute Pearson correlation for the sample
            correlation = pearson_correlation(original_sample, x_hat_sample)
            correlations.append(correlation)
            # print(f"Sample {i} Pearson Correlation:", correlation)

# Optionally, compute average correlation across all elements
average_correlation = sum(correlations) / len(correlations)
print("Average Pearson Correlation across all samples:", average_correlation)


Average Pearson Correlation across all samples: 0.9239656174638335
