In [1]:
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from data_loading import load_dataset
from preprocessing import preprocess_images, generated_to_image

# Fetch the Kaggle cat dataset (kagglehub.dataset_download returns the local
# download path), then load and preprocess it in one step. `images` is the
# training set consumed by both models below. Nesting the calls avoids keeping
# a second reference to the un-preprocessed data alive in the kernel.
cat_dataset_dir = kagglehub.dataset_download("borhanitrash/cat-dataset")
images = preprocess_images(load_dataset(cat_dataset_dir))

In [2]:
from Models.WandbConfig import WandbConfig
from Models.VariationalAutoEncoder import VariationalAutoEncoder
from FidScorer import FidScorer

# VAE hyperparameters, collected in one mapping so the run configuration can be
# read (and tweaked) at a glance before the model is built.
vae_hparams = dict(
    latent_dim=512,
    hidden_dims=[128, 128, 256, 256, 512],
    learning_rate=1e-3,
    lr_decay=0.999,       # NOTE(review): decay schedule is defined inside the model — confirm granularity
    beta_1=0.9,           # presumably optimizer (Adam-style) betas — confirm in VariationalAutoEncoder
    beta_2=0.999,
    weight_decay=1e-2,
    kl_weight=1.0,
    print_every=5,
    fid_scorer=FidScorer(),
    n_images_for_fid=1000,
)

# Build the model, then attach Weights & Biases tracking; with_wandb's return
# value is the object trained in the next cell.
vae = VariationalAutoEncoder(**vae_hparams).with_wandb(
    WandbConfig(
        experiment_name="vae_long_run",
        config_name="vae_default",
        artifact_name="vae_default_long_run",
        init_project=True,
    )
)

In [None]:
# Long VAE run over the preprocessed cat images.
VAE_EPOCHS, VAE_BATCH_SIZE = 500, 8
vae.train(images, epochs=VAE_EPOCHS, batch_size=VAE_BATCH_SIZE)

In [2]:
# Import every name this cell uses. Previously WandbConfig and FidScorer were
# only imported in the VAE cell, so running the GAN section on a fresh kernel
# straight after the data cell (skipping the long VAE run) raised NameError.
from Models.WandbConfig import WandbConfig
from Models.GenerativeAdversarialNetwork import GenerativeAdversarialNetwork
from FidScorer import FidScorer

# GAN hyperparameters, collected in one mapping so the run configuration can be
# read (and tweaked) at a glance before the model is built.
gan_hparams = dict(
    latent_dim=512,
    # The discriminator's widths are the generator's widths in reverse order.
    hidden_dims_generator=[1024, 512, 256, 128, 64, 32],
    hidden_dims_discriminator=[32, 64, 128, 256, 512, 1024],
    learning_rate_generator=1e-4,
    learning_rate_discriminator=1e-4,
    beta_1=0.5,           # presumably optimizer (Adam-style) betas — confirm in GenerativeAdversarialNetwork
    beta_2=0.999,
    weight_decay=0.0,
    print_every=3,
    fid_scorer=FidScorer(),
    n_images_for_fid=1000,
)

# Build the model, then attach Weights & Biases tracking; with_wandb's return
# value is the object trained in the next cell.
gan = GenerativeAdversarialNetwork(**gan_hparams).with_wandb(
    WandbConfig(
        experiment_name="gan_long_run",
        config_name="gan_default",
        artifact_name="gan_default_long_run",
        init_project=True,
    )
)

In [None]:
# Long GAN run over the same preprocessed cat images used for the VAE.
GAN_EPOCHS, GAN_BATCH_SIZE = 150, 8
gan.train(images, epochs=GAN_EPOCHS, batch_size=GAN_BATCH_SIZE)