In [1]:
import yaml
from PIL import Image

import torch
from torch.utils.data import DataLoader
from torch.optim import Adam
import torchvision.transforms as transforms

from models.vae import VAE
from models.lpips import LPIPS
from models.discriminator import PatchGANDiscriminator
from utils.dataset import Dataset
from utils.train_vae import train


# Training is stochastic: weight init, DataLoader shuffling, and DataLoader worker
# seeds are all drawn from the torch RNG. Seed it before any model is built so that
# runs are reproducible. torch.manual_seed seeds the CPU and all CUDA generators.
# NOTE(review): if models/ or utils/ also use `random` or `numpy` RNGs, seed those too.
SEED = 42
torch.manual_seed(SEED)

# Where to train, and which YAML file holds the model/dataset/training settings.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
CONFIG = "configs/vae.yaml"

In [2]:
# Read the config file.
# A malformed or empty YAML file must stop the notebook here. Merely printing the
# error would leave `config` undefined (NameError on the next line) or, on a re-run,
# silently keep a stale `config` from an earlier execution of this cell.
with open(CONFIG, 'r', encoding='utf-8') as file:
    try:
        config = yaml.safe_load(file)
    except yaml.YAMLError as exc:
        raise ValueError(f"Could not parse config file {CONFIG!r}: {exc}") from exc

# safe_load returns None for an empty file and a scalar/list for non-mapping YAML.
if not isinstance(config, dict):
    raise ValueError(
        f"Config file {CONFIG!r} must contain a YAML mapping, "
        f"got {type(config).__name__}"
    )

# The three top-level sections consumed by the cells below.
model_config = config['model_config']
dataset_config = config['dataset_config']
training_config = config['training_config']

In [None]:
# Build the networks on DEVICE.
# - model: the VAE, configured from `model_config`; optimized by `optimizer_g`.
# - lpips_model: LPIPS perceptual-similarity network, used only as a fixed loss.
#   `.eval()` alone does not freeze weights, so also turn off requires_grad:
#   backward passes then skip computing gradients for its parameters, while
#   gradients still flow through it to its inputs. This is a no-op if LPIPS
#   already freezes itself internally.
# - discriminator: PatchGAN discriminator; optimized by `optimizer_d`.
model = VAE(model_config = model_config).to(DEVICE)
lpips_model = LPIPS().eval().requires_grad_(False).to(DEVICE)  # frozen
discriminator = PatchGANDiscriminator().to(DEVICE)

In [4]:
# Per-image preprocessing steps, passed to `Dataset` as a plain list.
# ToTensor produces a float CHW tensor (scaled to [0, 1] for 8-bit images);
# Resize then resamples it to an IMG_SIZE x IMG_SIZE square with bicubic filtering.
# NOTE(review): bicubic resampling of a tensor can overshoot [0, 1] slightly, and the
# [-1, 1] Normalize below is disabled; confirm the VAE and LPIPS expect [0, 1] inputs.
# NOTE(review): this is a list, not transforms.Compose; presumably Dataset applies the
# steps in order itself. Confirm against utils/dataset.py.
transform = [
    transforms.ToTensor(),
    transforms.Resize(
        (dataset_config['IMG_SIZE'],) * 2,
        interpolation=transforms.InterpolationMode.BICUBIC,
    ),
    # Would map [0, 1] -> [-1, 1]; intentionally disabled for now.
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
]

# Shuffled mini-batches of preprocessed images, loaded by 2 worker processes.
data_loader = DataLoader(
    dataset=Dataset(dataset_config['ROOT'], transform),
    batch_size=dataset_config['BATCH_SIZE'],
    shuffle=True,
    num_workers=2,
)

In [5]:
# Loss functions handed to `train`:
#   recon_criterion: mean-squared-error reconstruction loss (mean reduction).
#   adv_criterion:   binary cross-entropy computed on raw logits (sigmoid applied internally).
recon_criterion, adv_criterion = torch.nn.MSELoss(), torch.nn.BCEWithLogitsLoss()

In [6]:
# Adam optimizers for the discriminator and the VAE (generator).
# Learning rates can be set in the YAML `training_config` (optional keys LR_D / LR_G),
# like every other training hyperparameter here; without them the previous
# hard-coded 1e-5 is used.
# float() guards against PyYAML loading exponent-only literals such as `1e-5`
# (no decimal point) as strings, which Adam would reject.
optimizer_d = Adam(
    discriminator.parameters(),
    lr=float(training_config.get('LR_D', 1E-5)),
    betas=(0.5, 0.999),
)
optimizer_g = Adam(
    model.parameters(),
    lr=float(training_config.get('LR_G', 1E-5)),
    betas=(0.5, 0.999),
)

In [None]:
# Launch training. Networks, data, losses and optimizers come from the cells above.
# The epoch count, ADV_START and SAMPLE_STEP are read from `training_config`.
# ADV_START is presumably the point at which the adversarial loss is enabled, and
# SAMPLE_STEP the sampling interval; see utils/train_vae.py.
train_kwargs = dict(
    model=model,
    discriminator=discriminator,
    lpips_model=lpips_model,
    num_epochs=training_config['NUM_EPOCHS'],
    data_loader=data_loader,
    optimizer_g=optimizer_g,
    optimizer_d=optimizer_d,
    recon_criterion=recon_criterion,
    adv_criterion=adv_criterion,
    adv_start=training_config['ADV_START'],
    sample_step=training_config['SAMPLE_STEP'],
)
train(**train_kwargs)