# VAE Training - Faces dataset

In [1]:
import torch as t
import torch.nn.functional as F
import zipfile
import os
from tqdm import tqdm
from models.CelebVariationalAutoencoder import CelebVariationalAutoencoder
from utils.data.FaceDataset import FaceDataset
from google_drive_downloader import GoogleDriveDownloader as gdd

# this will take some time :), its 1.3gb to download
faces_zip = 'data/faces.zip'
if not os.path.exists(faces_zip):
    gdd.download_file_from_google_drive(file_id='0B7EVK8r0v71pZjFTYXZWM3FlRnM',
                                        dest_path=faces_zip)
    with zipfile.ZipFile(faces_zip, 'r') as zip_ref:
        zip_ref.extractall('data/faces')

bs = 64
train_ds = FaceDataset("data/faces/img_align_celeba/")
train_dl = t.utils.data.DataLoader(dataset=train_ds, batch_size=bs, shuffle=True, drop_last=True)

device = t.device('cuda') if t.cuda.is_available() else 'cpu'
model = CelebVariationalAutoencoder(train_ds[0][0][None], in_c=3, enc_out_c=[32, 64, 64, 64],
                               enc_ks=[3, 3, 3, 3], enc_pads=[1, 1, 0, 1], enc_strides=[1, 2, 2, 1],
                               dec_out_c=[64, 64, 32, 3], dec_ks=[3, 3, 3, 3], dec_strides=[1, 2, 2, 1],
                               dec_pads=[1, 0, 1, 1], dec_op_pads=[0, 1, 1, 0], z_dim=200)
model.cuda(device)
model.train()

CelebVariationalAutoencoder(
  (enc_conv_layers): Sequential(
    (0): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (1): Sequential(
      (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (2): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(2, 2))
      (1): LeakyReLU(negative_slope=0.01)
      (2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (3): Dropout(p=0.25, inplace=False)
    )
    (3): Sequential(
      (0): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): LeakyReLU

In [2]:
def vae_kl_loss(mu, log_var):
    return -.5 * t.sum(1 + log_var - mu ** 2 - log_var.exp())

def vae_loss(y_pred, mu, log_var, y_true, r_loss_factor=1000):
    r_loss = F.binary_cross_entropy(y_pred, y_true, reduction='sum')
    kl_loss = vae_kl_loss(mu, log_var)
    return r_loss_factor * r_loss + kl_loss

lr = .0005
for epoch in tqdm(range(3)):
    optimizer = t.optim.Adam(model.parameters(), lr=lr / (epoch * 2 + 1), betas=(.9, .99), weight_decay=1e-2)
    for i, (data, _) in enumerate(train_dl):
        data = data.to(device)
        optimizer.zero_grad()
        pred, mu, log_var = model(data)
        loss = vae_loss(pred, mu, log_var, data)
        loss.backward()
        t.nn.utils.clip_grad_norm_(model.parameters(), .25)
        optimizer.step()
        if i % 33 == 0:
            print(loss / bs)

print(loss)

  0%|          | 0/3 [00:00<?, ?it/s]

tensor(55878044., device='cuda:0', grad_fn=<DivBackward0>)
tensor(42448192., device='cuda:0', grad_fn=<DivBackward0>)
tensor(35465824., device='cuda:0', grad_fn=<DivBackward0>)
tensor(30940976., device='cuda:0', grad_fn=<DivBackward0>)
tensor(30071830., device='cuda:0', grad_fn=<DivBackward0>)
tensor(29281372., device='cuda:0', grad_fn=<DivBackward0>)
tensor(29052596., device='cuda:0', grad_fn=<DivBackward0>)
tensor(27535826., device='cuda:0', grad_fn=<DivBackward0>)
tensor(27238960., device='cuda:0', grad_fn=<DivBackward0>)
tensor(27496606., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26781960., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26130110., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25813870., device='cuda:0', grad_fn=<DivBackward0>)
tensor(27273876., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26892190., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26024736., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26949876., device='cuda:0', grad_fn=<DivBackward0

 33%|███▎      | 1/3 [26:47<53:34, 1607.02s/it]

tensor(26885404., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25962064., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25586096., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26627036., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25364480., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26790918., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26882406., device='cuda:0', grad_fn=<DivBackward0>)
tensor(27293300., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25610908., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26584938., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26097558., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25284448., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26113560., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25607128., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25341408., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26445146., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25692274., device='cuda:0', grad_fn=<DivBackward0

 67%|██████▋   | 2/3 [53:06<26:38, 1598.89s/it]

tensor(26537340., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26078576., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25644586., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25917624., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26084042., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26706180., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25959336., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25217452., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26832634., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26372622., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26084352., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25709944., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26116460., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26044390., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25651988., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26190934., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26602564., device='cuda:0', grad_fn=<DivBackward0

100%|██████████| 3/3 [1:19:28<00:00, 1593.57s/it]

tensor(1657669632., device='cuda:0', grad_fn=<AddBackward0>)





In [3]:
t.save(model.state_dict(), 'models/state_dicts/03_05.pth')


In [4]:
lr = .0005
for epoch in tqdm(range(3, 5)):
    optimizer = t.optim.Adam(model.parameters(), lr=lr / (epoch * 2 + 1), betas=(.9, .99), weight_decay=1e-2)
    for i, (data, _) in enumerate(train_dl):
        data = data.to(device)
        optimizer.zero_grad()
        pred, mu, log_var = model(data)
        loss = vae_loss(pred, mu, log_var, data)
        loss.backward()
        t.nn.utils.clip_grad_norm_(model.parameters(), .25)
        optimizer.step()
        if i % 33 == 0:
            print(loss / bs)

print(loss)

  0%|          | 0/2 [00:00<?, ?it/s]

tensor(26122478., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26074100., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25833470., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26289132., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26832608., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26591926., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26524614., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26615118., device='cuda:0', grad_fn=<DivBackward0>)
tensor(27162468., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26599992., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26118604., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26165174., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26279220., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26323964., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26328696., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26620050., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26036438., device='cuda:0', grad_fn=<DivBackward0

 50%|█████     | 1/2 [26:45<26:45, 1605.57s/it]

tensor(26394400., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26236288., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25804286., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26903614., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26662452., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26644092., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26373228., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26691236., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25686442., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26356246., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26791182., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25466834., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26669580., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26525412., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26029278., device='cuda:0', grad_fn=<DivBackward0>)
tensor(25963376., device='cuda:0', grad_fn=<DivBackward0>)
tensor(26508208., device='cuda:0', grad_fn=<DivBackward0

100%|██████████| 2/2 [53:30<00:00, 1605.39s/it]

tensor(1662345472., device='cuda:0', grad_fn=<AddBackward0>)





In [5]:
t.save(model.state_dict(), 'models/state_dict/03_05_full.pth')


In [6]:
loss / bs

tensor(25974148., device='cuda:0', grad_fn=<DivBackward0>)