In [1]:
import numpy as np
import matplotlib.pyplot as plt
import kagglehub
from data_loading import load_dataset
from preprocessing import preprocess_images, generated_to_image

from Models.ModelBase import ModelBase
from Models.WandbConfig import WandbConfig
from Models.VariationalAutoEncoder import VariationalAutoEncoder
from Models.GenerativeAdversarialNetwork import GenerativeAdversarialNetwork
from FidScorer import FidScorer

# Fetch the Kaggle cat dataset, then load and preprocess it in a single pipeline
# so `images` only ever refers to the preprocessed data.
dataset_path = kagglehub.dataset_download("borhanitrash/cat-dataset")
images = preprocess_images(load_dataset(dataset_path))

In [2]:
# Hyperparameters are gathered in one mapping so the full VAE configuration can
# be read (or tweaked) in a single place before the model is built.
vae_hyperparameters = dict(
    latent_dim=512,
    hidden_dims=[128, 128, 256, 256, 512],
    learning_rate=1e-3,
    lr_decay=0.999,
    beta_1=0.9,
    beta_2=0.999,
    weight_decay=1e-2,
    kl_weight=1.0,
    print_every=5,
    fid_scorer=FidScorer(),
    n_images_for_fid=1000,
)

# Build the model first, then attach the experiment-tracking configuration.
vae = VariationalAutoEncoder(**vae_hyperparameters).with_wandb(
    WandbConfig(
        experiment_name="vae_long_run",
        config_name="vae_default",
        artifact_name="vae_default_long_run",
        init_project=True,
    )
)

In [None]:
# Training budget for the long VAE run, named so it is easy to spot and adjust.
VAE_EPOCHS = 500
VAE_BATCH_SIZE = 8
vae.train(images, epochs=VAE_EPOCHS, batch_size=VAE_BATCH_SIZE)

In [2]:
# Hyperparameters are gathered in one mapping so the full GAN configuration can
# be read (or tweaked) in a single place before the model is built.
gan_hyperparameters = dict(
    latent_dim=512,
    hidden_dims_generator=[1024, 512, 256, 128, 64, 32],
    hidden_dims_discriminator=[32, 64, 128, 256, 512, 1024],
    learning_rate_generator=1e-4,
    learning_rate_discriminator=1e-4,
    beta_1=0.5,
    beta_2=0.999,
    weight_decay=0.0,
    print_every=3,
    fid_scorer=FidScorer(),
    n_images_for_fid=1000,
)

# Build the model first, then attach the experiment-tracking configuration.
gan = GenerativeAdversarialNetwork(**gan_hyperparameters).with_wandb(
    WandbConfig(
        experiment_name="gan_long_run",
        config_name="gan_default",
        artifact_name="gan_default_long_run",
        init_project=True,
    )
)

In [None]:
# Training budget for the long GAN run, named so it is easy to spot and adjust.
GAN_EPOCHS = 150
GAN_BATCH_SIZE = 8
gan.train(images, epochs=GAN_EPOCHS, batch_size=GAN_BATCH_SIZE)

In [None]:
# Replace the in-memory `vae` with the one restored from the stored artifact.
# The artifact name matches the `artifact_name` used when the VAE was configured.
vae_artifact = WandbConfig.get_artifact_from_wandb("vae_default_long_run")
vae = VariationalAutoEncoder.load_state_dict(vae_artifact)

In [89]:
def generate_from_latent_line(model: ModelBase, latent_start: np.ndarray, latent_end: np.ndarray, n_samples: int) -> np.ndarray:
    """Generate images along the straight segment between two latent vectors.

    Args:
        model: Model exposing ``generate_from_latent``.
        latent_start: Latent vector at the beginning of the segment.
        latent_end: Latent vector at the end of the segment.
        n_samples: Number of evenly spaced latent points, endpoints included.

    Returns:
        Whatever ``model.generate_from_latent`` produces for the stacked
        interpolated latents (one entry per sample).
    """
    # np.linspace interpolates element-wise between the two vectors and stacks
    # the results along a new leading axis of length n_samples.
    interpolated = np.linspace(latent_start, latent_end, num=n_samples)
    return model.generate_from_latent(interpolated)

def generate_from_latent_circle(model: ModelBase, latent_first: np.ndarray, latent_second: np.ndarray, n_samples: int) -> np.ndarray:
    """Generate images along a circle around the origin of latent space.

    The circle lies in the plane spanned by ``latent_first`` and
    ``latent_second``, has radius ``||latent_first||`` and starts (angle 0) at
    ``latent_first`` itself. ``latent_second`` only selects the plane: it is
    orthogonalised against ``latent_first`` (Gram-Schmidt) so that every
    sampled point has exactly the same norm. Without that step the two
    directions are generally not perpendicular and the path is an ellipse.

    Args:
        model: Model exposing ``generate_from_latent``.
        latent_first: 1-D latent vector defining the radius and the start point.
        latent_second: 1-D latent vector defining, together with
            ``latent_first``, the plane of the circle.
        n_samples: Number of evenly spaced angles in ``[0, 2*pi)``.

    Returns:
        Whatever ``model.generate_from_latent`` produces for the
        ``(n_samples, latent_dim)`` array of latents.

    Raises:
        ValueError: If ``latent_first`` is the zero vector, or if
            ``latent_second`` is zero or parallel to ``latent_first`` (no
            plane is defined in either case).
    """
    radius = np.linalg.norm(latent_first)
    if radius == 0:
        raise ValueError("latent_first must be a non-zero vector")
    u1 = latent_first / radius

    # Gram-Schmidt: drop the component of latent_second that lies along u1 so
    # the two basis directions of the circle's plane are orthonormal.
    orthogonal_part = latent_second - np.dot(latent_second, u1) * u1
    orthogonal_norm = np.linalg.norm(orthogonal_part)
    if orthogonal_norm <= 1e-12 * np.linalg.norm(latent_second):
        raise ValueError("latent_second must be non-zero and not parallel to latent_first")
    u2 = orthogonal_part / orthogonal_norm

    # Evenly spaced angles; 2*pi itself is excluded so the start is not repeated.
    thetas = 2 * np.pi * np.arange(n_samples) / n_samples
    latents = radius * (np.outer(np.cos(thetas), u1) + np.outer(np.sin(thetas), u2))
    return model.generate_from_latent(latents)

def plot_image_series(image_rows: list[np.ndarray], titles: list[str]) -> None:
    """Draw rows of generated images, each row under its own centred title.

    The figure is created on the current pyplot state; the caller is
    responsible for showing or saving it (e.g. ``plt.show()``).

    Args:
        image_rows: One sequence of generated images per row. Every row must
            contain the same, non-zero number of images. Each image is passed
            through ``generated_to_image`` before being displayed.
        titles: One title per row, in the same order as ``image_rows``.

    Raises:
        ValueError: If the number of rows and titles differ, if there are no
            rows, if rows differ in length, or if the rows contain no images.
    """
    if len(image_rows) != len(titles):
        raise ValueError("Number of image rows and titles must be the same")

    # Guard explicitly: indexing image_rows[0] below would otherwise fail with
    # an unhelpful IndexError on empty input.
    if len(image_rows) == 0:
        raise ValueError("At least one row of images is required")

    n_rows = len(image_rows)
    n_cols = len(image_rows[0])

    if any(len(row) != n_cols for row in image_rows):
        raise ValueError("All rows must have the same number of images")

    # A zero-width grid cannot be laid out by matplotlib.
    if n_cols == 0:
        raise ValueError("Each row must contain at least one image")

    # Create figure with extra space for titles
    fig = plt.figure(figsize=(n_cols * 2, n_rows * 2.5))

    # Each logical row uses two grid rows: a thin one for the title and a
    # full-height one for the images.
    gs = plt.GridSpec(n_rows * 2, n_cols, height_ratios=[0.1, 1] * n_rows)

    for i, (row, title) in enumerate(zip(image_rows, titles)):
        # Title centred above the row, spanning every column
        ax = fig.add_subplot(gs[i * 2, :])
        ax.text(0.5, 0.5, title, ha='center', va='center', fontsize=16)
        ax.axis('off')

        # Images in the grid row directly below the title
        for j in range(n_cols):
            ax = fig.add_subplot(gs[i * 2 + 1, j])
            ax.imshow(generated_to_image(row[j]))
            ax.axis('off')

In [None]:
# Fixed seed so the sampled endpoints (and therefore the figure) are identical
# on every "Restart & Run All"; change LATENT_SEED to explore other pairs.
LATENT_SEED = 0
# Must match the `latent_dim` the model was constructed with (512 in this notebook).
LATENT_DIM = 512
N_INTERPOLATION_STEPS = 10

latent_rng = np.random.default_rng(LATENT_SEED)
# Standard-normal draws: the same distribution as np.random.normal(0, 1, ...).
latent_start = latent_rng.standard_normal(LATENT_DIM)
latent_end = latent_rng.standard_normal(LATENT_DIM)

line_images = generate_from_latent_line(vae, latent_start, latent_end, N_INTERPOLATION_STEPS)
circle_images = generate_from_latent_circle(vae, latent_start, latent_end, N_INTERPOLATION_STEPS)
plot_image_series(
    [line_images, circle_images],
    ["Line between two latent representations", "Circle around the origin in latent space"],
)
plt.show()