In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
"""
Generate some fake data with noise in it just in case

"""

from sklearn.datasets import fetch_olivetti_faces
from skimage.util import random_noise

images = fetch_olivetti_faces()["images"]
assert images.shape[1] == images.shape[2], "expected square images"

# Downsample to 32x32 (keep every second pixel) to match the real data
images = images[:, ::2, ::2]
assert images.shape[1:] == (32, 32), f"unexpected image shape {images.shape[1:]}"

# Add some speckle noise
noisy_images = random_noise(images, mode="speckle", var=0.01)

# Rescale to [-1, 1]. Both arrays use the *clean* images' min/max so they share
# one intensity scale: their difference is then just the (rescaled) noise,
# rather than also containing an offset from normalising each array separately.
# Because of the noise, the noisy images may extend slightly outside [-1, 1].
clean_min, clean_max = images.min(), images.max()
images = 2 * (images - clean_min) / (clean_max - clean_min) - 1
noisy_images = 2 * (noisy_images - clean_min) / (clean_max - clean_min) - 1

In [None]:
"""
Plot an example

"""

import numpy as np
import matplotlib.pyplot as plt

# Plot the first clean image, its noisy version and the difference between them
fig, axes = plt.subplots(1, 3, figsize=(15, 5))

# The images were rescaled to (about) [-1, 1], so use that as the colour range;
# a [0, 1] range would clip the darker half of every image to black.
plot_kw = {"cmap": "gray", "aspect": "auto", "vmin": -1, "vmax": 1}

# The difference is signed, so show it on a diverging colour map centred on
# zero (fall back to +-1 if the two images happen to be identical).
difference = noisy_images[0] - images[0]
diff_limit = float(np.abs(difference).max()) or 1.0
diff_kw = {
    "cmap": "RdBu_r",
    "aspect": "auto",
    "vmin": -diff_limit,
    "vmax": diff_limit,
}

im1 = axes[0].imshow(images[0], **plot_kw)
im2 = axes[1].imshow(noisy_images[0], **plot_kw)
im3 = axes[2].imshow(difference, **diff_kw)

for axis, im, label in zip(axes, [im1, im2, im3], ["Clean", "Noisy", "Difference"]):
    axis.set_title(label)
    axis.set_xlabel("X")
    axis.set_ylabel("Y")
    plt.colorbar(im, ax=axis)

fig.tight_layout()

In [None]:
"""
Wrap our images in a torch Dataset that yields (1, H, W) float tensors
"""

import torch


class ImageLoader(torch.utils.data.Dataset):
    """
    In-memory torch Dataset of single-channel images.

    Each item is one image, returned as a float32 tensor with a leading
    channel axis added, so an (H, W) image comes back as (1, H, W).

    :param images: array of images indexed along its first axis. It must already
        be scaled so that its min and max are (within 0.01) -1 and 1.

    """

    def __init__(self, images):
        # The data are expected to be normalised to [-1, 1] before they get here
        assert np.isclose(images.max(), 1.0, atol=0.01), (
            f"expected images.max() ~ 1, got {images.max()}"
        )
        assert np.isclose(images.min(), -1.0, atol=0.01), (
            f"expected images.min() ~ -1, got {images.min()}"
        )

        self.images = images
        # Global statistics over the whole array; not used by this class itself
        self.mean, self.std = images.mean(), images.std()

    def __len__(self):
        """Number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return image ``idx`` as a float32 tensor with a channel axis prepended."""
        # torch.tensor copies the data and casts to float32; it replaces the
        # legacy torch.FloatTensor(...) constructor
        return torch.tensor(self.images[idx], dtype=torch.float32).unsqueeze(0)


# Training set: the clean (noise-free) images, wrapped as a torch Dataset
dataset = ImageLoader(images=images)

In [None]:
import torch

# Hyperparameters and training objects passed to dcgan.Generator,
# dcgan.Discriminator and dcgan.train
config = {
    "n_epochs": 100,
    "n_critic": 5,
    "lambda_gp": 10,
    "learning_rate": 0.0001,
    "latent_dim": 64,
    "img_size": images.shape[1],  # images are square (asserted when generated)
    "channels": 1,  # ImageLoader yields single-channel (1, H, W) tensors
    # NOTE(review): n_critic / lambda_gp look like WGAN-GP settings, whereas
    # BCELoss is the standard (non-Wasserstein) GAN loss -- confirm which of
    # these dcgan.train actually uses.
    "loss": torch.nn.BCELoss(),
    # num_workers=0: ImageLoader is defined in this notebook (__main__), which
    # DataLoader worker processes started with the "spawn" method (the default
    # on macOS and Windows) cannot import. The data are also an in-memory
    # array, so there is little loading work to parallelise.
    "dataloader": torch.utils.data.DataLoader(
        dataset, batch_size=64, shuffle=True, num_workers=0
    ),
}

In [None]:
"""
Define the GAN

"""

from current_denoising.generation import dcgan

# Build both networks from the shared config (generator first, then discriminator)
generator, discriminator = dcgan.Generator(config), dcgan.Discriminator(config)

In [None]:
"""
Train the GAN

"""

# Train the networks; dcgan.train hands back the trained generator and
# discriminator along with their losses (plotted in the next cell)
generator, discriminator, gen_loss, disc_loss = dcgan.train(
    generator,
    discriminator,
    config,
)

In [None]:
from current_denoising.plotting import training

# Plot the generator and discriminator losses returned by dcgan.train
training.plot_losses(gen_loss, disc_loss)

In [None]:
"""
Show some generated patches

"""