In [1]:
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from data_loading import load_dataset
from preprocessing import preprocess_images, generated_to_image

from Models.ModelBase import ModelBase
from Models.WandbConfig import WandbConfig
from Models.VariationalAutoEncoder import VariationalAutoEncoder
from Models.GenerativeAdversarialNetwork import GenerativeAdversarialNetwork
from FidScorer import FidScorer

In [2]:
# Download the Kaggle cat dataset, load it from the returned location and
# preprocess it in one go, so `images` only ever refers to the preprocessed data.
images = preprocess_images(
    load_dataset(kagglehub.dataset_download("borhanitrash/cat-dataset"))
)

In [2]:
# Build the VAE used for the long training run and attach a WandbConfig to it.
# The experiment name "vae_long_run" and the artifact name "vae_default_long_run"
# are the same names used further down in this notebook to fetch the run metrics
# and to load the saved model, so keep them in sync when renaming.
# NOTE(review): a FidScorer is passed together with n_images_for_fid=1000 --
# presumably FID is evaluated on 1000 images during training; confirm in
# VariationalAutoEncoder.
vae = VariationalAutoEncoder(
    latent_dim=512,
    hidden_dims=[128, 128, 256, 256, 512],
    learning_rate=1e-3,
    lr_decay=0.999,
    beta_1=0.9,
    beta_2=0.999,
    weight_decay=1e-2,
    kl_weight=1.0,
    print_every=5,
    fid_scorer=FidScorer(),
    n_images_for_fid=1000
).with_wandb(WandbConfig(
    experiment_name="vae_long_run",
    artifact_name="vae_default_long_run",
    init_project=True
))

In [None]:
# Train the VAE on the preprocessed images: 500 epochs with a batch size of 8.
vae.train(images, epochs=500, batch_size=8)

In [74]:
# Build the GAN (generator and discriminator get separate layer widths and
# learning rates) and attach a WandbConfig to it.
# "gan_tweaked_lr_long_run" is used as both experiment and artifact name; the same
# name is used further down in this notebook to fetch the run metrics.
# use_wgan_gp is not passed here (unlike the wgan_gp model below), so whatever
# default GenerativeAdversarialNetwork defines applies -- TODO confirm the default.
gan = GenerativeAdversarialNetwork(
    latent_dim=512,
    hidden_dims_generator=[512, 256, 128, 64],
    hidden_dims_discriminator=[64, 128, 256, 512],
    learning_rate_generator=1e-4,
    learning_rate_discriminator=7e-5,
    beta_1=0.0,
    beta_2=0.9,
    weight_decay=0.0,
    print_every=5,
    fid_scorer=FidScorer(),
    n_images_for_fid=1000,
    critic_iterations=1
).with_wandb(WandbConfig(
    experiment_name="gan_tweaked_lr_long_run",
    artifact_name="gan_tweaked_lr_long_run",
    init_project=True
))

In [None]:
# Train the GAN on the preprocessed images: 200 epochs with a batch size of 64.
gan.train(images, epochs=200, batch_size=64)

In [5]:
# Build the WGAN-GP variant: the same GenerativeAdversarialNetwork class as the
# plain GAN above, but with use_wgan_gp=True, a gradient penalty weight of 10.0
# and critic_iterations=2 (the plain GAN uses 1).
# "wgan_gp_tweaked_long_run" is used as both experiment and artifact name; the same
# name is used further down in this notebook to fetch the run metrics.
wgan_gp = GenerativeAdversarialNetwork(
    latent_dim=512,
    hidden_dims_generator=[512, 256, 128, 64],
    hidden_dims_discriminator=[64, 128, 256, 512],
    learning_rate_generator=1e-4,
    learning_rate_discriminator=5e-5,
    beta_1=0.0,
    beta_2=0.9,
    weight_decay=0.0,
    print_every=3,
    fid_scorer=FidScorer(),
    n_images_for_fid=1000,
    use_wgan_gp=True,
    gradient_penalty_weight=10.0,
    critic_iterations=2
).with_wandb(WandbConfig(
    experiment_name="wgan_gp_tweaked_long_run",
    artifact_name="wgan_gp_tweaked_long_run",
    init_project=True
))

In [None]:
# Train the WGAN-GP on the preprocessed images: 150 epochs with a batch size of 64.
wgan_gp.train(images, epochs=150, batch_size=64)

In [2]:
def generate_from_latent_line(model: ModelBase, latent_start: np.ndarray, latent_end: np.ndarray, n_samples: int) -> np.ndarray:
    """Decode evenly spaced points on the straight segment between two latents.

    Args:
        model: Object exposing ``generate_from_latent``.
        latent_start: Latent vector at the beginning of the segment.
        latent_end: Latent vector at the end of the segment.
        n_samples: Number of interpolation points, both endpoints included.

    Returns:
        Whatever ``model.generate_from_latent`` returns for the stacked
        interpolated latents (one row per sample).
    """
    # np.linspace interpolates element-wise between the two vectors, giving an
    # array with one latent per row.
    return model.generate_from_latent(
        np.linspace(latent_start, latent_end, num=n_samples)
    )

def generate_from_latent_circle(model: ModelBase, latent_first: np.ndarray, latent_second: np.ndarray, n_samples: int) -> np.ndarray:
    """Decode points on a circle around the origin of the latent space.

    The circle lies in the plane spanned by ``latent_first`` and
    ``latent_second``, has radius ``||latent_first||`` and starts at
    ``latent_first``. ``latent_second`` only selects the plane: it is
    orthogonalised against ``latent_first`` (Gram-Schmidt) so that every
    generated latent has exactly the same norm. Without that step the path
    would be an ellipse whenever the two inputs are not orthogonal.

    Args:
        model: Object exposing ``generate_from_latent``.
        latent_first: 1-D latent vector; defines the radius and the start point.
        latent_second: 1-D latent vector; defines the second direction of the plane.
        n_samples: Number of equally spaced angles in ``[0, 2*pi)``.

    Returns:
        Whatever ``model.generate_from_latent`` returns for the
        ``(n_samples, latent_dim)`` array of latents on the circle.

    Raises:
        ValueError: If ``latent_first`` is the zero vector, or if
            ``latent_second`` is zero or parallel to ``latent_first`` (no plane
            is defined in either case).
    """
    first = np.asarray(latent_first, dtype=float)
    second = np.asarray(latent_second, dtype=float)

    radius = np.linalg.norm(first)
    if radius == 0:
        raise ValueError("latent_first must be a non-zero vector")
    u1 = first / radius

    # Gram-Schmidt: drop the component of latent_second that points along u1 so
    # the two basis vectors of the plane are orthonormal.
    residual = second - np.dot(second, u1) * u1
    residual_norm = np.linalg.norm(residual)
    if residual_norm <= 1e-12 * np.linalg.norm(second):
        raise ValueError("latent_second must be non-zero and not parallel to latent_first")
    u2 = residual / residual_norm

    # All angles at once; the end point 2*pi is excluded so the first sample is
    # not repeated.
    thetas = 2 * np.pi * np.arange(n_samples) / n_samples
    latents = radius * (np.outer(np.cos(thetas), u1) + np.outer(np.sin(thetas), u2))
    return model.generate_from_latent(latents)

def plot_image_series(image_rows: list[np.ndarray], titles: list[str]) -> None:
    """Draw several rows of generated images, each row under its own centred title.

    The figure is built on a grid with two grid rows per image row: a thin one
    holding the title and a tall one holding the images. The figure is created
    but not shown; call ``plt.show()`` afterwards if needed.

    Args:
        image_rows: One sequence of images per row. Every image is passed
            through ``generated_to_image`` before being displayed.
        titles: One title per row, in the same order as ``image_rows``.

    Raises:
        ValueError: If the number of rows and titles differ, if there are no
            rows, if the rows contain no images, or if the rows have
            different lengths.
    """
    if len(image_rows) != len(titles):
        raise ValueError("Number of image rows and titles must be the same")

    # Reject empty input up front; otherwise it would only fail later with an
    # opaque IndexError or grid-spec error. len() is used instead of truthiness
    # because rows may be numpy arrays.
    if len(image_rows) == 0 or len(image_rows[0]) == 0:
        raise ValueError("At least one row containing at least one image is required")

    n_rows = len(image_rows)
    n_cols = len(image_rows[0])
    if any(len(row) != n_cols for row in image_rows):
        raise ValueError("All rows must have the same number of images")

    # Create figure with extra space for titles
    fig = plt.figure(figsize=(n_cols * 2, n_rows * 2.5))

    # Two grid rows per image row: a thin title strip (0.1) above the images (1)
    grid = plt.GridSpec(n_rows * 2, n_cols, height_ratios=[0.1, 1] * n_rows)

    for row_index, (row, title) in enumerate(zip(image_rows, titles)):
        # Title centred above the row, spanning all columns
        title_ax = fig.add_subplot(grid[row_index * 2, :])
        title_ax.text(0.5, 0.5, title, ha='center', va='center', fontsize=16)
        title_ax.axis('off')

        # Images in the grid row directly below the title
        for col_index, image in enumerate(row):
            image_ax = fig.add_subplot(grid[row_index * 2 + 1, col_index])
            image_ax.imshow(generated_to_image(image))
            image_ax.axis('off')

In [None]:
# Rebuild the trained VAE from the artifact named "vae_default_long_run" (the
# artifact_name given to the VAE's WandbConfig above).
# NOTE(review): presumably this downloads the artifact from Weights & Biases --
# confirm in WandbConfig.get_artifact_from_wandb.
loaded_model = VariationalAutoEncoder.load_state_dict(WandbConfig.get_artifact_from_wandb("vae_default_long_run"))

In [None]:
# Decode the all-zeros latent vector (the origin of the latent space) and ask the
# model for four further generated images.
zero_image = loaded_model.generate_from_latent(np.zeros(512))[0]
other_images = loaded_model.generate(4)

# Give the single origin image a leading batch axis so it can be stacked in front
# of the other images; all five are then drawn as one titled row.
generated_images = np.concatenate((np.expand_dims(zero_image, axis=0), other_images))
plot_image_series([generated_images], ["VAE images"])

In [None]:
# Draw the two latent endpoints from a seeded generator so that the figure below
# is reproducible across kernel restarts (change the seed to explore other pairs).
rng = np.random.default_rng(42)
latent_start = rng.normal(0, 1, (512,))
latent_end = rng.normal(0, 1, (512,))

# Decode 10 latents on the straight segment between the two endpoints and 10 on
# the circle around the origin, then show both series as titled rows.
line_images = generate_from_latent_line(loaded_model, latent_start, latent_end, 10)
circle_images = generate_from_latent_circle(loaded_model, latent_start, latent_end, 10)
plot_image_series([line_images, circle_images], ["Line between two latent representations", "Circle around the origin in latent space"])
plt.show()

In [11]:
def plot_vae_metrics(metrics: dict[str, list[float | int | str]]) -> None:
    """Plot the per-epoch training curves of a VAE run on two stacked axes.

    The upper axis shows the total, reconstruction and KL losses; the lower
    axis shows the FID score. The figure is created but not shown.

    Args:
        metrics: Mapping that must contain the keys ``"total_loss"``,
            ``"recon_loss"``, ``"kl_loss"`` and ``"fid_score"``. The x axis
            runs from 1 to ``len(metrics["total_loss"])``, so every series is
            expected to have that length.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    total_loss = metrics["total_loss"]
    reconstruction_loss = metrics["recon_loss"]
    kl_loss = metrics["kl_loss"]
    fid_score = metrics["fid_score"]
    epochs = range(1, len(total_loss) + 1)

    # plt.subplots creates the figure itself; calling plt.figure() beforehand
    # would leave an additional, empty figure behind.
    _, (loss_ax, fid_ax) = plt.subplots(2, 1, figsize=(12, 10))

    loss_ax.plot(epochs, total_loss, label="Total Loss")
    loss_ax.plot(epochs, reconstruction_loss, label="Reconstruction Loss")
    loss_ax.plot(epochs, kl_loss, label="KL Loss")
    loss_ax.legend()
    loss_ax.set_ylabel("Loss")

    fid_ax.plot(epochs, fid_score, label="FID Score")
    fid_ax.set_xlabel("Epoch")
    fid_ax.set_ylabel("FID Score")
    
def plot_gan_metrics(metrics: dict[str, list[float | int | str]]) -> None:
    """Plot the per-epoch training curves of a GAN run on three stacked axes.

    From top to bottom: discriminator and generator loss, FID score, and the
    discriminator's accuracy on real and fake samples. The figure is created
    but not shown.

    Args:
        metrics: Mapping that must contain the keys ``"discriminator_loss"``,
            ``"generator_loss"``, ``"fid_score"``,
            ``"discriminator_real_accuracy"`` and
            ``"discriminator_fake_accuracy"``. The x axis runs from 1 to
            ``len(metrics["discriminator_loss"])``, so every series is
            expected to have that length.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    discriminator_loss = metrics["discriminator_loss"]
    generator_loss = metrics["generator_loss"]
    fid_score = metrics["fid_score"]
    real_acc = metrics["discriminator_real_accuracy"]
    fake_acc = metrics["discriminator_fake_accuracy"]
    epochs = range(1, len(discriminator_loss) + 1)

    # plt.subplots creates the figure itself; calling plt.figure() beforehand
    # would leave an additional, empty figure behind.
    _, (top, middle, bottom) = plt.subplots(3, 1, figsize=(12, 15))

    top.plot(epochs, discriminator_loss, label="Discriminator Loss")
    top.plot(epochs, generator_loss, label="Generator Loss")
    top.legend()
    top.set_ylabel("Loss")

    middle.plot(epochs, fid_score, label="FID Score")
    middle.set_ylabel("FID Score")

    bottom.plot(epochs, real_acc, label="Real Accuracy")
    bottom.plot(epochs, fake_acc, label="Fake Accuracy")
    bottom.legend()
    bottom.set_xlabel("Epoch")
    bottom.set_ylabel("Accuracy")
    
def plot_wgan_gp_metrics(metrics: dict[str, list[float | int | str]]) -> None:
    """Plot the per-epoch training curves of a WGAN-GP run on three stacked axes.

    From top to bottom: discriminator and generator loss, FID score, and the
    gradient penalty together with the Wasserstein distance. The figure is
    created but not shown.

    Args:
        metrics: Mapping that must contain the keys ``"discriminator_loss"``,
            ``"generator_loss"``, ``"fid_score"``, ``"gradient_penalty"`` and
            ``"wasserstein_distance"``. The x axis runs from 1 to
            ``len(metrics["discriminator_loss"])``, so every series is
            expected to have that length.

    Raises:
        KeyError: If one of the required keys is missing.
    """
    discriminator_loss = metrics["discriminator_loss"]
    generator_loss = metrics["generator_loss"]
    fid_score = metrics["fid_score"]
    gradient_penalty = metrics["gradient_penalty"]
    wasserstein_distance = metrics["wasserstein_distance"]
    epochs = range(1, len(discriminator_loss) + 1)

    # plt.subplots creates the figure itself; calling plt.figure() beforehand
    # would leave an additional, empty figure behind.
    _, (top, middle, bottom) = plt.subplots(3, 1, figsize=(12, 15))

    top.plot(epochs, discriminator_loss, label="Discriminator Loss")
    top.plot(epochs, generator_loss, label="Generator Loss")
    top.legend()
    top.set_ylabel("Loss")

    middle.plot(epochs, fid_score, label="FID Score")
    middle.set_ylabel("FID Score")

    # Two different quantities share this axis, so the legend tells them apart.
    bottom.plot(epochs, gradient_penalty, label="Gradient Penalty")
    bottom.plot(epochs, wasserstein_distance, label="Wasserstein Distance")
    bottom.legend()
    bottom.set_xlabel("Epoch")

In [None]:
# Fetch the metrics recorded for the "wgan_gp_tweaked_long_run" experiment (the
# experiment_name given to the WGAN-GP's WandbConfig above) and plot them.
wgan_gp_run_metrics = WandbConfig.get_run_metrics("wgan_gp_tweaked_long_run")
plot_wgan_gp_metrics(wgan_gp_run_metrics)
plt.show()

In [None]:
# Fetch the metrics recorded for the "gan_tweaked_lr_long_run" experiment (the
# experiment_name given to the GAN's WandbConfig above) and plot them.
gan_run_metrics = WandbConfig.get_run_metrics("gan_tweaked_lr_long_run")
plot_gan_metrics(gan_run_metrics)
plt.show()

In [None]:
# Fetch the metrics recorded for the "vae_long_run" experiment (the
# experiment_name given to the VAE's WandbConfig above) and plot them.
vae_run_metrics = WandbConfig.get_run_metrics("vae_long_run")
plot_vae_metrics(vae_run_metrics)
plt.show()