# VAE Training

In [7]:
import os
os.environ["CUDA_VISIBLE_DEVICE"]="0,1"

In [8]:
import torch as t
import torch.nn.functional as F
from tqdm import tqdm
from models.VariationalAutoencoder import VariationalAutoencoder
from torchvision import datasets, transforms

bs = 128
train_ds = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
train_dl = t.utils.data.DataLoader(dataset=train_ds, batch_size=bs, shuffle=True, drop_last=True)

device = t.device('cuda') if t.cuda.is_available() else 'cpu'
model = VariationalAutoencoder(train_ds[0][0][None], in_c=1, enc_out_c=[32, 64, 64, 64],
                    enc_ks=[3, 3, 3, 3], enc_pads=[1, 1, 0, 1], enc_strides=[1, 2, 2, 1],
                    dec_out_c=[64, 64, 32, 1], dec_ks=[3, 3, 3, 3], dec_strides=[1, 2, 2, 1],
                    dec_pads=[1, 0, 1, 1], dec_op_pads=[0, 1, 1, 0], z_dim=2)
model.cuda(device)
model.train()

VariationalAutoencoder(
  (enc_conv_layers): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(nega

In [9]:
train_ds

Dataset MNIST
    Number of datapoints: 60000
    Root location: ./data/
    Split: Train
    StandardTransform
Transform: ToTensor()

In [10]:
60000/bs/33

14.204545454545455

In [24]:
def vae_kl_loss(mu, log_var):
    return -.5 * t.sum(1 + log_var - mu ** 2 - log_var.exp())

def vae_loss(y_pred, mu, log_var, y_true):
    r_loss = F.binary_cross_entropy(y_pred, y_true, reduction='sum')
    r_loss_factor = 1
    kl_loss = vae_kl_loss(mu, log_var)
    return r_loss_factor*r_loss + kl_loss

In [25]:
lr = .0005
for epoch in tqdm(range(20)):
    optimizer = t.optim.Adam(model.parameters(), lr=lr / (epoch * 2 + 1), betas=(.9, .99), weight_decay=1e-2)
    for i, (data, _) in enumerate(train_dl):
        data = data.to(device)
        optimizer.zero_grad()
        pred, mu, log_var = model(data)
        loss = vae_loss(pred, mu, log_var, data)
        loss.backward()
        optimizer.step()
        if i % 33 == 0:
            print(loss)

print(loss)

  0%|                                                    | 0/20 [00:00<?, ?it/s]

tensor(27425.9941, device='cuda:0', grad_fn=<AddBackward0>)
tensor(27194.8574, device='cuda:0', grad_fn=<AddBackward0>)
tensor(26925.9551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(27584.6895, device='cuda:0', grad_fn=<AddBackward0>)
tensor(27446.9551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(25211.4531, device='cuda:0', grad_fn=<AddBackward0>)
tensor(26710.0527, device='cuda:0', grad_fn=<AddBackward0>)
tensor(27481.5137, device='cuda:0', grad_fn=<AddBackward0>)
tensor(24401.8633, device='cuda:0', grad_fn=<AddBackward0>)
tensor(24927.9707, device='cuda:0', grad_fn=<AddBackward0>)
tensor(25145.7012, device='cuda:0', grad_fn=<AddBackward0>)
tensor(23197.8594, device='cuda:0', grad_fn=<AddBackward0>)
tensor(23601.8770, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22855.0430, device='cuda:0', grad_fn=<AddBackward0>)


  5%|██▏                                         | 1/20 [00:16<05:19, 16.79s/it]

tensor(23399.7520, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21946.8398, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22132.9160, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22751.9805, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22531.1582, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22632.1699, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22708.0586, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22096.5254, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21962.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21613.1719, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21507.3164, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21882.6094, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21843.3281, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22493.6484, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21218.6641, device='cuda:0', grad_fn=<AddBackward0>)


 10%|████▍                                       | 2/20 [00:33<05:03, 16.83s/it]

tensor(21972.8066, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21702.7363, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21772.7344, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21447.6016, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22361.7891, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22463.9395, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21520.9805, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20768.1621, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21493.5703, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21191.2949, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21195.2402, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21838.0508, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22304.3672, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21234.9297, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22536.2207, device='cuda:0', grad_fn=<AddBackward0>)


 15%|██████▌                                     | 3/20 [00:50<04:46, 16.85s/it]

tensor(21157.7070, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22727.0215, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21174.9141, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21394.5039, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21393.9805, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22052.9648, device='cuda:0', grad_fn=<AddBackward0>)
tensor(19875.2148, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22059.1094, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21324.9160, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21447.2129, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21935.3281, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21420.9531, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21510.5098, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21337.0137, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20834.4395, device='cuda:0', grad_fn=<AddBackward0>)


 20%|████████▊                                   | 4/20 [01:07<04:29, 16.86s/it]

tensor(21103.5137, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21335.5859, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21773.2910, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20831.9629, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20726.4531, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21406.0742, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21337.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22220.1289, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20887.2832, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21500.2754, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21543.3926, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20839.0938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21631.7949, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21166.7637, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21281.7227, device='cuda:0', grad_fn=<AddBackward0>)


 25%|███████████                                 | 5/20 [01:24<04:13, 16.88s/it]

tensor(20918.7871, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22658.7305, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21302.9141, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21264.4746, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21740.4121, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21367.0254, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21380.2656, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21238.2832, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22508.3809, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21282.7598, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20299.0391, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20139.9199, device='cuda:0', grad_fn=<AddBackward0>)
tensor(19905.5410, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21565.8008, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21265.8066, device='cuda:0', grad_fn=<AddBackward0>)


 30%|█████████████▏                              | 6/20 [01:41<03:56, 16.88s/it]

tensor(21196.0020, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21234.9648, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20182.0117, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22231.5879, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21798.0898, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21023.3652, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21236.3418, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21532.6836, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20649.0020, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21905.1484, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21231.7012, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21459.9844, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21289.2266, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21199.6953, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21053.6406, device='cuda:0', grad_fn=<AddBackward0>)


 35%|███████████████▍                            | 7/20 [01:58<03:40, 16.94s/it]

tensor(21122.8984, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21592.5801, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20871.1758, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22520.5879, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22549.5098, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21404.4453, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21377.5195, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21517.7930, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21715.2695, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21986.9922, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22145.3691, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21177.4023, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20757.8672, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21098.7285, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21141.9336, device='cuda:0', grad_fn=<AddBackward0>)


 40%|█████████████████▌                          | 8/20 [02:15<03:23, 16.94s/it]

tensor(21000.7715, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21575.2188, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20975.9004, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21363.4590, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22326.5586, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22292.9492, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21952.9883, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21193.9863, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21362.3730, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20607.2051, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22281.3418, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20517.4023, device='cuda:0', grad_fn=<AddBackward0>)
tensor(19997.6465, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20228.1641, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21704.8242, device='cuda:0', grad_fn=<AddBackward0>)


 45%|███████████████████▊                        | 9/20 [02:32<03:06, 16.97s/it]

tensor(21284.5527, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20661.3594, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21835.0371, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21117.9062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22029.0918, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20696.3008, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21952.3320, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20589.9551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20908.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20253.8184, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20558.4473, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21207.1055, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21811.3711, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20556.3848, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21484.6816, device='cuda:0', grad_fn=<AddBackward0>)


 50%|█████████████████████▌                     | 10/20 [02:49<02:50, 17.03s/it]

tensor(20931.1445, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21075.8242, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20595.7070, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21591.9395, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22356.7207, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21810.9414, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21094.6973, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20344.4785, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20841.1074, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21651.1992, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21387.0449, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21025.7012, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20787.8379, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21194.4551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21659.6055, device='cuda:0', grad_fn=<AddBackward0>)


 55%|███████████████████████▋                   | 11/20 [03:06<02:34, 17.12s/it]

tensor(20992.3320, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21123.9727, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20655.3691, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21178.9785, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20839.1582, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21191.0137, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21109.7812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20165.7676, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20931.7441, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21095.4863, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21418.6016, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20956.2852, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20539.9590, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22384.9180, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20984.9727, device='cuda:0', grad_fn=<AddBackward0>)


 60%|█████████████████████████▊                 | 12/20 [03:23<02:17, 17.15s/it]

tensor(21254.7305, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21491.1328, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20540.2109, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21033.2559, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21638.3340, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21909.9941, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21202.8848, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20950.1895, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21419.1035, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21621.4551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20951.3984, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21313.1855, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20743.2422, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20064.1387, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20430.9980, device='cuda:0', grad_fn=<AddBackward0>)


 65%|███████████████████████████▉               | 13/20 [03:41<02:00, 17.15s/it]

tensor(20966.6191, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21590.3359, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20823.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21378.2148, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20585.0137, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21325.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20755.5703, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21612.4531, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21034.7090, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21225.6582, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20945.6289, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21345.3184, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22220.0352, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21353.3398, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20616.7305, device='cuda:0', grad_fn=<AddBackward0>)


 70%|██████████████████████████████             | 14/20 [03:58<01:42, 17.11s/it]

tensor(21103.4160, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20897.0781, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21204.4746, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21702.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21623.1543, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21401.0820, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20443.8652, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21282.0215, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20047.7578, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21639.0918, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21431.0742, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20361.4883, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21542.6406, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21039.1992, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20783.4922, device='cuda:0', grad_fn=<AddBackward0>)


 75%|████████████████████████████████▎          | 15/20 [04:14<01:25, 17.03s/it]

tensor(20750.3555, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21047.7109, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21223.6738, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20881.3008, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21667.4824, device='cuda:0', grad_fn=<AddBackward0>)
tensor(23717.8457, device='cuda:0', grad_fn=<AddBackward0>)
tensor(19949.4570, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20769.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21992.5234, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21570.9727, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20899.2773, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21021.3301, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21387.7715, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20624.2266, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20835.6875, device='cuda:0', grad_fn=<AddBackward0>)


 80%|██████████████████████████████████▍        | 16/20 [04:32<01:08, 17.08s/it]

tensor(20561.9102, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20282.8379, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21180.4961, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21040.8184, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21483.2422, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20284.4785, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21133.6309, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21323.4551, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21420.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21075.1230, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20402.1406, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21746.0215, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20438.2637, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20060.9785, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20169.2832, device='cuda:0', grad_fn=<AddBackward0>)


 85%|████████████████████████████████████▌      | 17/20 [04:49<00:51, 17.05s/it]

tensor(21797.4043, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21462.0078, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21813.1484, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21000.3164, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20323.2539, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21052.8984, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21275.2422, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21192.2734, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21200.5723, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21018.9863, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21679.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21240.7383, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20243.6426, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21647.8086, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21320.3594, device='cuda:0', grad_fn=<AddBackward0>)


 90%|██████████████████████████████████████▋    | 18/20 [05:06<00:34, 17.09s/it]

tensor(20408.6836, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22111.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20995.6523, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21709.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20714.3281, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20860.6914, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21295.3105, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21339.0039, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20850.4941, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20414.3770, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21192.6621, device='cuda:0', grad_fn=<AddBackward0>)
tensor(22407.6133, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21123.0332, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20852.9082, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21873.9062, device='cuda:0', grad_fn=<AddBackward0>)


 95%|████████████████████████████████████████▊  | 19/20 [05:23<00:17, 17.01s/it]

tensor(20584.6582, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21277.3301, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20750.3711, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20366.0449, device='cuda:0', grad_fn=<AddBackward0>)
tensor(19736.5410, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21710.2266, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20330.7715, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21635.8164, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21919.7891, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20429.6816, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20928.4395, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21645.4668, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21156.7246, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21668.0547, device='cuda:0', grad_fn=<AddBackward0>)
tensor(20927.2285, device='cuda:0', grad_fn=<AddBackward0>)


100%|███████████████████████████████████████████| 20/20 [05:37<00:00, 16.88s/it]

tensor(20646.6816, device='cuda:0', grad_fn=<AddBackward0>)
tensor(21408.9414, device='cuda:0', grad_fn=<AddBackward0>)





In [13]:
mu.shape

torch.Size([128, 2])

In [14]:
log_var.shape

torch.Size([128, 2])

In [15]:
mu[0:10]

tensor([[ 0.2095, -2.3853],
        [ 1.2137, -2.5154],
        [ 0.4480, -0.9542],
        [ 0.1077, -4.3537],
        [-0.9446, -2.6549],
        [ 1.4333, -0.6426],
        [ 3.2849, -1.8471],
        [-0.0786, -1.9685],
        [-0.1759, -0.9804],
        [-0.1184, -1.4328]], device='cuda:0', grad_fn=<SliceBackward0>)

In [16]:
log_var[0:10]

tensor([[-4.1537, -3.4211],
        [-3.1524, -3.2208],
        [-3.1086, -2.8434],
        [-4.3717, -2.1544],
        [-3.8358, -5.3702],
        [-2.8891,  0.5710],
        [-3.0989, -3.4410],
        [-3.6518, -3.1645],
        [-4.0692, -1.9565],
        [-2.7796, -4.3813]], device='cuda:0', grad_fn=<SliceBackward0>)

In [None]:
std = t.exp(log_var[0:10]/2)
std

In [None]:
loss

In [None]:
print('20_epoch_loss=21229')

In [None]:
t.save(model.state_dict(), 'models/state_dicts/03_03.pth')

In [None]:
m = t.Tensor([0, 0])

In [None]:
log_va = t.Tensor([1, 1])

In [None]:
-.5 * t.sum(1 + log_va - m ** 2 - log_va.exp())

In [None]:
log_va - m ** 2 - log_va.exp()

In [23]:
t.sum(log_va - m ** 2 - log_va.exp())

tensor(-3.4366)