In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
"""
Generate some fake data with noise in it just in case

"""

from sklearn.datasets import fetch_olivetti_faces

images = fetch_olivetti_faces()["images"]
assert images.shape[1] == images.shape[2]

# Downsample them to 32x32 to match the real data
images = images[:, ::2, ::2]

# Normalize the images to [-1, 1]
images = 2 * (images - images.min()) / (images.max() - images.min()) - 1

In [None]:
"""
Plot an example

"""

import numpy as np
import matplotlib.pyplot as plt

# Plot both
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# Plot clean data
plot_kw = {"cmap": "grey", "aspect": "auto", "vmin": 0, "vmax": 1}
for axis, im in zip(axes, images, strict=False):
    axis.imshow(im, **plot_kw)
    axis.axis("off")

fig.tight_layout()

In [None]:
"""
Turn our images into a dataloader with the right transforms
"""

import torch


class ImageLoader(torch.utils.data.Dataset):
    def __init__(self, images):
        assert np.isclose(images.max(), 1.0, atol=0.01)
        assert np.isclose(images.min(), -1.0, atol=0.01)

        self.images = images
        self.mean, self.std = images.mean(), images.std()

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = torch.FloatTensor(self.images[idx]).unsqueeze(0)
        return image


# Dataset over the normalised face images; batched by the DataLoader in the config below
dataset = ImageLoader(images=images)

In [None]:
""" First, we'll quickly train a bad model"""
import pathlib
import torch

batch_size = 64
config = {
    "n_epochs": 500,
    "n_critic": 5,
    "lambda_gp": 10,
    "learning_rate": 0.00005,
    "latent_dim": 2,
    "img_size": images.shape[1],
    "channels": 1,
    "batch_size": batch_size,
    "dataloader": torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True, num_workers=8
    ),
    "output_dir": pathlib.Path("output/"),
}
if not config["output_dir"].is_dir():
    config["output_dir"].mkdir(parents=True)

In [None]:
"""
Define the GAN

"""

from current_denoising.generation import dcgan

generator = dcgan.Generator(config)
discriminator = dcgan.Discriminator(config)

In [None]:
"""
Train the GAN

"""

generator, discriminator, gen_loss, disc_loss, fid_scores = dcgan.train(generator, discriminator, config)

In [None]:
from current_denoising.plotting import training

# NOTE(review): the x positions assume dcgan.train records one FID score every
# 20 epochs, starting at epoch 0 - TODO confirm against dcgan.train
fid_interval = 20

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Left: generator and discriminator loss curves
_ = training.plot_losses(
    gen_loss, disc_loss, labels=("Generator Loss", "Discriminator Loss"), axis=axes[0]
)

# Right: FID score against the (assumed) epoch it was computed at
axes[1].plot(range(0, fid_interval * len(fid_scores), fid_interval), fid_scores)
axes[1].set_title("fid_score")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("FID")

fig.savefig("bad_fid.png")

In [None]:
"""Now, we'll train a better one"""

config["latent_dim"] = 64

generator, discriminator, gen_loss, disc_loss, fid_scores = dcgan.train(
    dcgan.Generator(config),
    dcgan.Discriminator(config),
    config,
)

In [None]:
# NOTE(review): the x positions assume dcgan.train records one FID score every
# 20 epochs, starting at epoch 0 - TODO confirm against dcgan.train
fid_interval = 20

fig, axes = plt.subplots(1, 2, figsize=(15, 5))

# Left: generator and discriminator loss curves
_ = training.plot_losses(
    gen_loss, disc_loss, labels=("Generator Loss", "Discriminator Loss"), axis=axes[0]
)

# Right: FID score against the (assumed) epoch it was computed at
axes[1].plot(range(0, fid_interval * len(fid_scores), fid_interval), fid_scores)
axes[1].set_title("fid_score")
axes[1].set_xlabel("Epoch")
axes[1].set_ylabel("FID")

fig.savefig("good_fid.png")

In [None]:
"""
Show some generated patches

"""
from matplotlib import animation
img_paths = sorted(list(pathlib.Path(config["output_dir"]).glob("*/*.png")))


# Quick and simple version
fig, ax = plt.subplots(figsize=(10, 10))
ax.axis('off')

def animate(frame):
    ax.clear()
    ax.axis('off')
    img = plt.imread(img_paths[frame])
    ax.imshow(img, cmap='gray')
    ax.set_title(f'Epoch {frame * 20}')  # Assuming every 20 epochs

anim = animation.FuncAnimation(fig, animate, frames=len(img_paths), interval=100, repeat=True)
anim.save("gan_simple.mp4", writer='ffmpeg', fps=10)


In [None]:
"""
Generate a new image and display it
"""
from torch.autograd import Variable

z_g = Variable(
    torch.cuda.FloatTensor(np.random.normal(0, 1, (batch_size, config["latent_dim"])))
)
gen_imgs = generator(z_g)

In [None]:
from current_denoising.plotting import img_validation

# Display the batch of freshly generated images
fig = img_validation.show(
    gen_imgs,
    cmap="gray",
)

In [None]:
# Histogram of the generated batch (presumably of pixel values).
# The trailing ";" hides the Text object returned by suptitle.
fig = img_validation.hist(gen_imgs, bins=50, density=True)
fig.suptitle("Generated images");

In [None]:
fig = img_validation.hist(next(iter(config["dataloader"])), bins=50, density=True)
fig.suptitle("Real images")

In [None]:
# Frequency content of the generated batch.
# The trailing ";" hides the Text object returned by suptitle.
fig = img_validation.fft(gen_imgs)
fig.suptitle("Generated images FFT");

In [None]:
fig = img_validation.fft(next(iter(config["dataloader"])))
fig.suptitle("Real images FFT")