# VAE Training

In [1]:
import torch as t
import torch.nn.functional as F
from tqdm import tqdm
from models.VariationalAutoencoder import VariationalAutoencoder
from torchvision import datasets, transforms

bs = 32
train_ds = datasets.MNIST(root='./data/', train=True, transform=transforms.ToTensor(), download=True)
train_dl = t.utils.data.DataLoader(dataset=train_ds, batch_size=bs, shuffle=True, drop_last=True)

device = t.device('cuda' if t.cuda.is_available() else 'cpu')
model = VariationalAutoencoder(train_ds[0][0][None], in_c=1, enc_out_c=[32, 64, 64, 64],
                    enc_ks=[3, 3, 3, 3], enc_pads=[1, 1, 0, 1], enc_strides=[1, 2, 2, 1],
                    dec_out_c=[64, 64, 32, 1], dec_ks=[3, 3, 3, 3], dec_strides=[1, 2, 2, 1],
                    dec_pads=[1, 0, 1, 1], dec_op_pads=[0, 1, 1, 0], z_dim=2)
model.cuda(device)
model.train()

VariationalAutoencoder(
  (enc_conv_layers): Sequential(
    (0): Sequential(
      (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(nega

In [2]:
def vae_kl_loss(mu, log_var):
    return -.5 * t.sum(1 + log_var - mu ** 2 - log_var.exp())

def vae_loss(y_pred, mu, log_var, y_true, r_loss_factor=100):
    r_loss = F.binary_cross_entropy(y_pred, y_true, reduction='sum')
    kl_loss = vae_kl_loss(mu, log_var)
    return r_loss_factor * r_loss + kl_loss

In [3]:
lr = .0005
for epoch in tqdm(range(20)):
    optimizer = t.optim.Adam(model.parameters(), lr=lr / (epoch * 2 + 1), betas=(.9, .99), weight_decay=1e-2)
    for i, (data, _) in enumerate(train_dl):
        data = data.to(device)
        optimizer.zero_grad()
        pred, mu, log_var = model(data)
        loss = vae_loss(pred, mu, log_var, data)
        loss.backward()
        optimizer.step()
        if i % 33 == 0:
            print(loss)

print(loss)

  0%|          | 0/20 [00:00<?, ?it/s]

tensor(3661121.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1680244.6250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1436101., device='cuda:0', grad_fn=<AddBackward0>)
tensor(1181363.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(1196160.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(848027.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(817010.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(854536.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(708343.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(667996.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(661654.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(579668.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(640547.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(610358.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(657156.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(587984.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(599292.1250, dev

  5%|▌         | 1/20 [00:43<13:47, 43.54s/it]

tensor(503492.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(548645.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(599311.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(558159.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(559970.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(516686.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(492067.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(495852.2188, device='cuda:0', grad_fn=<AddBackward0>)
tensor(544231.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(540126.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(494645.0938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(530919.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(502906.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(583522.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(516078.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(501771.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(548026.2500, devi

 10%|█         | 2/20 [01:26<12:59, 43.31s/it]

tensor(521679.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(527384.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(451608.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(524861.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(506000.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(497120.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(531675.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(511057.6562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(502735.8125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(528228.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496170.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(486999.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(511002.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(541538.6250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(539785.6250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(527747.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(504401.5625, devi

 15%|█▌        | 3/20 [02:09<12:13, 43.13s/it]

tensor(483841.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(490881.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(458083.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(533851.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(522755.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(513932.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(504177.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(564163.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(532030.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(505052.5312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(525183.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(493468.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(520065.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(541589.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(520243.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(511500.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(527484.9375, devi

 20%|██        | 4/20 [02:51<11:28, 43.02s/it]

tensor(475581.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(494108.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(505255.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(520384.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(535897.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(501814.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(579885.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(510257.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(516128.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(516433., device='cuda:0', grad_fn=<AddBackward0>)
tensor(484626.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(500878.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(528676.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(502010.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(500886.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(477061.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(498338.1875, device='

 25%|██▌       | 5/20 [03:34<10:44, 42.97s/it]

tensor(520147.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(483633.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(527644.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(482843.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(451934.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(455265.5312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(486621.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(458094.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(484218.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(506286.7812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(484664.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(469247.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(474150.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(482845.8125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(434306.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(481811.4062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(508680., device='

 30%|███       | 6/20 [04:17<10:00, 42.90s/it]

tensor(503267.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(506404.4062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(561696.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496974.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(479573.0938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(520283.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(434623.9062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(480020.3438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(507768.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(533355.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(497705., device='cuda:0', grad_fn=<AddBackward0>)
tensor(530289.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(549439.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(478376.6562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(537704.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(486247.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(507343.4375, device='

 35%|███▌      | 7/20 [05:00<09:16, 42.84s/it]

tensor(535288., device='cuda:0', grad_fn=<AddBackward0>)
tensor(461061.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(501573.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(490363.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(479718., device='cuda:0', grad_fn=<AddBackward0>)
tensor(454361.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(484457.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(479432., device='cuda:0', grad_fn=<AddBackward0>)
tensor(531561.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(509360.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(553887.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(569884.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(561966.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(470501.5312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(490383.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(485228.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(541064., device='cuda:0', gra

 40%|████      | 8/20 [05:42<08:34, 42.84s/it]

tensor(514628.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(466060.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(523414.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496350.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(452465.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(504353.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(499630.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(462366.5312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(494879.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(460017.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(491002.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(524860., device='cuda:0', grad_fn=<AddBackward0>)
tensor(474704.2188, device='cuda:0', grad_fn=<AddBackward0>)
tensor(459503.7812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(498096.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(471478.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(490083.6250, device='

 45%|████▌     | 9/20 [06:25<07:50, 42.81s/it]

tensor(519387.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(495216.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(493738.6562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(461369.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(470361.9688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(463637.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(470214.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496362.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(490746.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496937.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(481471.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(494319.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(446119.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(469243.9688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(482452.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(500731.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(523293.6250, devi

 50%|█████     | 10/20 [07:08<07:07, 42.79s/it]

tensor(465882.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(491294.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(452922.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(469016., device='cuda:0', grad_fn=<AddBackward0>)
tensor(530792.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(509344.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(565647.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(461235.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(500471.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(495579.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(437185.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(498217.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(499353.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(533876.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(498490.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(544459.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(484287.1562, device='

 55%|█████▌    | 11/20 [07:51<06:25, 42.80s/it]

tensor(446244.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(471853.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(491685.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(442998.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(514736.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(491880.9062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(441311.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(522293.0938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(475799.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(458691.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(508435.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(524634.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(525730.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(415036.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(529667.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(509684.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(451427.1562, devi

 60%|██████    | 12/20 [08:34<05:42, 42.79s/it]

tensor(475589.0938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(493539.2188, device='cuda:0', grad_fn=<AddBackward0>)
tensor(484101.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(517451.8125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(555931.6250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(485907.7812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(491494.5938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(485450.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(500642.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(456014.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(459489.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496176.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(453539.7188, device='cuda:0', grad_fn=<AddBackward0>)
tensor(518645.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(459875.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(535716.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(478560.7812, devi

 65%|██████▌   | 13/20 [09:16<04:59, 42.81s/it]

tensor(483018.7188, device='cuda:0', grad_fn=<AddBackward0>)
tensor(491883.4062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(539674.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(513071.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(471183.6562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(475222.4062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(558578.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(550604.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(482405.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(482730.9688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(489631.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(452092.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(494906.6562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(442080., device='cuda:0', grad_fn=<AddBackward0>)
tensor(503988.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(501580.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(505578.6562, device='

 70%|███████   | 14/20 [09:59<04:16, 42.82s/it]

tensor(545046., device='cuda:0', grad_fn=<AddBackward0>)
tensor(583974.8125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(531929.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(470638.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(465068., device='cuda:0', grad_fn=<AddBackward0>)
tensor(505524.0938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(513623.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(486227.4062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(482753.4062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(513689.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(509918.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(492238.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(427169.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(545187.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(481465.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(514192.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(491566.5625, device='cuda

 75%|███████▌  | 15/20 [10:42<03:33, 42.70s/it]

tensor(457142.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(521937.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(533925.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(483046.6250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(528979.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(505089.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(474241.9688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(465238.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(475141.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(500828.3438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(531346.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496502.5000, device='cuda:0', grad_fn=<AddBackward0>)
tensor(467549.9062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(502904.9375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(500967.7812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(426588.0938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(542487., device='

 80%|████████  | 16/20 [11:24<02:50, 42.53s/it]

tensor(491514.7500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(456100.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(499222.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(482183.5312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(454081.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(466742.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(473458.9688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(527438.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(490735.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(449398.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(461201.1875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(433856.4688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(548244.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(512711.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(451806., device='cuda:0', grad_fn=<AddBackward0>)
tensor(485005.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(483006.7500, device='

 85%|████████▌ | 17/20 [12:06<02:07, 42.39s/it]

tensor(464632.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(521314.6562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(527617.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(486496.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(481516.6875, device='cuda:0', grad_fn=<AddBackward0>)
tensor(490423.5312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(464988.9062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(457457.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(497814.9688, device='cuda:0', grad_fn=<AddBackward0>)
tensor(435572.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(463081.6562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(499817.0312, device='cuda:0', grad_fn=<AddBackward0>)
tensor(448658.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(493813.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(472870.8750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(492549.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496076.3750, devi

 90%|█████████ | 18/20 [12:48<01:24, 42.30s/it]

tensor(493854.0938, device='cuda:0', grad_fn=<AddBackward0>)
tensor(468791.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(486138.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(446362.6250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(498076.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496157.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(475429.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(508219.2500, device='cuda:0', grad_fn=<AddBackward0>)
tensor(537365.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(514909., device='cuda:0', grad_fn=<AddBackward0>)
tensor(454199.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(502949.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(483019.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(493767.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(459879.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(447755.6562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496561.2188, device='

 95%|█████████▌| 19/20 [13:30<00:42, 42.26s/it]

tensor(435721.0625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(502921.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(537985.1250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(471795.4062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(502709.4375, device='cuda:0', grad_fn=<AddBackward0>)
tensor(496579.1562, device='cuda:0', grad_fn=<AddBackward0>)
tensor(495587.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(499000.5625, device='cuda:0', grad_fn=<AddBackward0>)
tensor(484532.3750, device='cuda:0', grad_fn=<AddBackward0>)
tensor(448906.8438, device='cuda:0', grad_fn=<AddBackward0>)
tensor(498294.7812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(463099.3125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(483027.2812, device='cuda:0', grad_fn=<AddBackward0>)
tensor(446956.6250, device='cuda:0', grad_fn=<AddBackward0>)
tensor(478200.8125, device='cuda:0', grad_fn=<AddBackward0>)
tensor(515652.9062, device='cuda:0', grad_fn=<AddBackward0>)
tensor(494237.9688, devi

100%|██████████| 20/20 [14:12<00:00, 42.20s/it]

tensor(507324.0312, device='cuda:0', grad_fn=<AddBackward0>)





In [4]:
t.save(model.state_dict(), 'models/state_dicts/03_02.pth')