In [12]:
import os
import imageio
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.utils import save_image
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.autograd import grad as torch_grad

In [93]:
# Impostiamo gli hyperparametri

# Iperparametri generali
batch_size = 64
num_epoch = 100
z_dimension = 100  # dimensione del vettore di rumore in input al generatore
# clip_value = 0.01

# Iperparametri ottimizzatore
learning_rate = 1e-4
b1 = 0.9
b2 = 0.999

# Iperparametri per la WGAN
n_critic = 5
gp_weight = 10

# Trasformazione delle immagini
def to_img(x):
    out = 0.5 * (x + 1)
    out = out.clamp(0, 1)
    out = out.view(-1, 1, 28, 28)
    return out

img_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize((32, 32))
    # transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [94]:
# Solita dataset ;)
dataset = datasets.MNIST('./data', transform=img_transform, download=True)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

In [95]:
if not os.path.exists('./w_img'):
    os.mkdir('./w_img')


# Definiamo la classe dei modelli nella GAN:
# Discriminatore
class WDiscriminator(nn.Module):
    def __init__(self, img_size=32, dim=16):
        """
        img_size : (int, int, int)
            Height and width must be powers of 2.  E.g. (32, 32, 1) or
            (64, 128, 3). Last number indicates number of channels, e.g. 1 for
            grayscale or 3 for RGB
        """
        super(WDiscriminator, self).__init__()

        self.img_size = img_size

        self.image_to_features = nn.Sequential(
            nn.Conv2d(1, dim, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(dim, 2 * dim, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(2 * dim, 4 * dim, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(4 * dim, 8 * dim, 4, 2, 1),
            nn.Sigmoid()
        )

        # 4 convolutions of stride 2, i.e. halving of size everytime
        # So output size will be 8 * (img_size / 2 ^ 4) * (img_size / 2 ^ 4)
        output_size = int(8 * dim * (img_size / 16) * (img_size / 16))
        # output_size = 128
        self.features_to_prob = nn.Sequential(
            nn.Linear(output_size, 1),
            # nn.Sigmoid()
        )
        print("output_size", output_size)

    def forward(self, input_data):
        batch_size = input_data.size()[0]
        x = self.image_to_features(input_data)
        print("x.shape", x.shape)
        x = x.view(batch_size, -1)
        print("x_view.shape", x.shape)
        return self.features_to_prob(x)

# Generatore
class Generator(nn.Module):
    def __init__(self, img_size=32, latent_dim=z_dimension, dim=16):
        super(Generator, self).__init__()

        self.dim = dim
        self.latent_dim = latent_dim
        # self.img_size = (img_size, img_size, 1)
        # self.feature_sizes = (self.img_size[0] / 16, self.img_size[1] / 16)

        self.img_size = img_size
        self.feature_sizes = int(img_size / 16)
        # self.feature_sizes = int(img_size / 14)

        # print(self.feature_sizes)

        self.latent_to_features = nn.Sequential(
            # nn.Linear(latent_dim, 8 * dim * self.feature_sizes[0] * self.feature_sizes[1]),
            nn.Linear(latent_dim, 8 * dim * self.feature_sizes * self.feature_sizes),
            nn.ReLU()
        )

        self.features_to_image = nn.Sequential(
            nn.ConvTranspose2d(8 * dim, 4 * dim, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(4 * dim),
            nn.ConvTranspose2d(4 * dim, 2 * dim, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(2 * dim),
            nn.ConvTranspose2d(2 * dim, dim, 4, 2, 1),
            nn.ReLU(),
            nn.BatchNorm2d(dim),
            nn.ConvTranspose2d(dim, 1, 4, 2, 1),
            nn.Sigmoid()
        )

    def forward(self, input_data):
        # Map latent into appropriate size for transposed convolutions
        x = self.latent_to_features(input_data)
        # Reshape
        # x = x.view(-1, 8 * self.dim, self.feature_sizes[0], self.feature_sizes[1])
        x = x.view(-1, 8 * self.dim, self.feature_sizes, self.feature_sizes)
        # Return generated image
        return self.features_to_image(x)

    def sample_latent(self, num_samples):
        return torch.randn((num_samples, self.latent_dim))

In [102]:
class Trainer():
    def __init__(self, G, D, g_optimizer, d_optimizer,
                 gp_weight=gp_weight, critic_iterations=n_critic, print_every=50,
                 use_cuda=True):
        self.G = G
        self.G_opt = g_optimizer
        self.D = D
        self.D_opt = d_optimizer
        self.losses = {'G': [], 'D': [], 'GP': [], 'gradient_norm': []}
        self.num_steps = 0
        self.use_cuda = use_cuda
        self.gp_weight = gp_weight
        self.critic_iterations = critic_iterations
        self.print_every = print_every

        if self.use_cuda:
            self.G.cuda()
            self.D.cuda()

    def _critic_train_iteration(self, data):
        """ """
        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calcoliamo il valore del discriminatore su immagini reali e generate
        if self.use_cuda:
            data = data.cuda()

        print("data.shape", data.shape)
        print("generated_data.shape", generated_data.shape)


        d_real = self.D(data)
        d_generated = self.D(generated_data)


        # Get gradient penalty
        gradient_penalty = self._gradient_penalty(data, generated_data)
        self.losses['GP'].append(gradient_penalty.item())

        # Create total loss and optimize
        self.D_opt.zero_grad()
        d_loss = d_generated.mean() - d_real.mean() + gradient_penalty
        d_loss.backward()

        self.D_opt.step()

        # Record loss
        self.losses['D'].append(d_loss.item())

    def _generator_train_iteration(self, data):
        """ """
        self.G_opt.zero_grad()

        # Get generated data
        batch_size = data.size()[0]
        generated_data = self.sample_generator(batch_size)

        # Calculate loss and optimize
        d_generated = self.D(generated_data)
        g_loss = - d_generated.mean()
        g_loss.backward()
        self.G_opt.step()

        # Record loss
        self.losses['G'].append(g_loss.item())

    def _gradient_penalty(self, real_data, generated_data):
        batch_size = real_data.size()[0]

        # Calculate interpolation
        alpha = torch.rand(batch_size, 1, 1, 1)
        alpha = alpha.expand_as(real_data)
        if self.use_cuda:
            alpha = alpha.cuda()
        interpolated = alpha * real_data.data + (1 - alpha) * generated_data.data
        if self.use_cuda:
            interpolated = interpolated.cuda()

        # Calculate probability of interpolated examples
        prob_interpolated = self.D(interpolated)

        # Calculate gradients of probabilities with respect to examples
        gradients = torch_grad(outputs=prob_interpolated, inputs=interpolated,
                               grad_outputs=torch.ones(prob_interpolated.size(), requires_grad=True).cuda() if self.use_cuda else torch.ones(
                               prob_interpolated.size(), requires_grad=True),
                               create_graph=True, retain_graph=True, only_inputs=True)[0]

        # Gradients have shape (batch_size, num_channels, img_width, img_height),
        # so flatten to easily take norm per example in batch
        gradients = gradients.view(batch_size, -1)
        self.losses['gradient_norm'].append(gradients.norm(2, dim=1).mean().item())

        # Derivatives of the gradient close to 0 can cause problems because of
        # the square root, so manually calculate norm and add epsilon
        gradients_norm = torch.sqrt(torch.sum(gradients ** 2, dim=1) + 1e-12)

        # Return gradient penalty
        return self.gp_weight * ((gradients_norm - 1) ** 2).mean()

    def _train_epoch(self, data_loader):
        for i, data in enumerate(data_loader):
            self.num_steps += 1
            self._critic_train_iteration(data[0])
            # Only update generator every |critic_iterations| iterations
            if self.num_steps % self.critic_iterations == 0:
                self._generator_train_iteration(data[0])

            if i % self.print_every == 0:
                print("Iteration {}".format(i + 1))
                print("D: {}".format(self.losses['D'][-1]))
                print("GP: {}".format(self.losses['GP'][-1]))
                print("Gradient norm: {}".format(self.losses['gradient_norm'][-1]))
                if self.num_steps > self.critic_iterations:
                    print("G: {}".format(self.losses['G'][-1]))

    def train(self, data_loader, epochs, save_training_gif=True):
        if save_training_gif:
            # Fix latents to see how image generation improves during training
            fixed_latents = self.G.sample_latent(64)
            if self.use_cuda:
                fixed_latents = fixed_latents.cuda()
            training_progress_images = []

        for epoch in range(epochs):
            print("\nEpoch {}".format(epoch + 1))
            self._train_epoch(data_loader)

            if save_training_gif:
                # Generate batch of images and convert to grid
                img_grid = make_grid(self.G(fixed_latents).cpu().data)
                # Convert to numpy and transpose axes to fit imageio convention
                # i.e. (width, height, channels)
                img_grid = np.transpose(img_grid.numpy(), (1, 2, 0))
                # Add image grid to training progress
                training_progress_images.append(img_grid)

        if save_training_gif:
            imageio.mimsave('./training_{}_epochs.gif'.format(epochs),
                            training_progress_images)

    def sample_generator(self, num_samples):
        latent_samples = self.G.sample_latent(num_samples)
        if self.use_cuda:
            latent_samples = latent_samples.cuda()
        generated_data = self.G(latent_samples)
        return generated_data

    def sample(self, num_samples):
        generated_data = self.sample_generator(num_samples)
        # Remove color channel
        return generated_data.data.cpu().numpy()[:, 0, :, :]

In [103]:
# Creiamo un'istanza del Discriminatore
D = WDiscriminator().cuda()
# Creiamo un'istanza del Generatore
# G = Generator(img_size=(32, 32, 1), latent_dim=z_dimension, dim=16).cuda()
G = Generator().cuda()

# Definiamo gli ottimizzatori per il discriminatore e il generatore
d_optimizer = torch.optim.Adam(D.parameters(), lr=learning_rate, betas=(b1, b2))
g_optimizer = torch.optim.Adam(G.parameters(), lr=learning_rate, betas=(b1, b2))

output_size 512


In [104]:
# data_loader, _ = get_mnist_dataloaders(batch_size=64)
# img_size = (32, 32, 1)

# generator = Generator(img_size=img_size, latent_dim=100, dim=16)
# discriminator = Discriminator(img_size=img_size, dim=16)

# print(G)
# print(D)

# Initialize optimizers
# lr = 1e-4
# betas = (.9, .99)
# G_optimizer = optim.Adam(generator.parameters(), lr=lr, betas=betas)
# D_optimizer = optim.Adam(discriminator.parameters(), lr=lr, betas=betas)

# Train model
# epochs = 200
trainer = Trainer(G, D, g_optimizer, d_optimizer,
                  use_cuda=torch.cuda.is_available())
trainer.train(dataloader, num_epoch, save_training_gif=True)

# Save models
name = 'mnist_model'
torch.save(trainer.G.state_dict(), './gen_' + name + '.pt')
torch.save(trainer.D.state_dict(), './dis_' + name + '.pt')


Epoch 1
data.shape torch.Size([64, 1, 32, 32])
generated_data.shape torch.Size([64, 1, 32, 32])
x.shape torch.Size([64, 128, 2, 2])
x_view.shape torch.Size([64, 512])
x.shape torch.Size([64, 128, 2, 2])
x_view.shape torch.Size([64, 512])
x.shape torch.Size([64, 128, 2, 2])
x_view.shape torch.Size([64, 512])




RuntimeError: ignored

In [None]:
# Alleniamo le nostre reti
for epoch in range(num_epoch):
    for i, (img, _) in enumerate(dataloader):
        num_img = img.size(0)
        d_optimizer.zero_grad()

        real_img = img.cuda()
        # real_label = torch.ones(num_img).cuda()
        # fake_label = torch.zeros(num_img).cuda()

        # Calcoliamo la loss del Discriminatore sulle immagini reali
        real_out = D(real_img).squeeze(1)
        # d_loss_real = criterion(real_out, real_label)
        # real_scores = real_out  # Più è vicino a 1, meglio è


        z = torch.randn(num_img, z_dimension).cuda() # Noise di input
        fake_img = G(z)
        fake_out = D(fake_img).squeeze(1)

        # Calcoliamo la distanza di Wasserstein
        d_loss = torch.mean(fake_out) - torch.mean(real_out)
        d_loss.backward()
        d_optimizer.step()

        # Clip dei pesi del discriminatore per mantenere la condizione di
        # 1-Lipschitz (in alternativa si può inserire la gradient penalty)
        for p in D.parameters():
            p.data.clamp_(-clip_value, clip_value)


        if i % n_critic == 0:
            # train Generator
            g_optimizer.zero_grad()

            # Calcoliamo la loss del Generatore sulle immagini generate
            z = torch.randn(num_img, z_dimension).cuda() # Noise di input
            fake_img = G(z)
            # Adversarial loss
            g_loss = -torch.mean(D(fake_img))
            g_loss.backward()
            g_optimizer.step()

        if (i+1) % 100 == 0:
            print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f}'
                    .format(epoch, num_epoch, d_loss.item(), g_loss.item()))
    if epoch == 0:
        real_images = to_img(real_img.cpu().data)
        save_image(real_images, './w_img/real_images.png')

    fake_images = to_img(fake_img.cpu().data)
    save_image(fake_images, './w_img/fake_images-{}.png'.format(epoch+1))

torch.save(G.state_dict(), './generator.pth')
torch.save(D.state_dict(), './discriminator.pth')

Epoch [0/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [0/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [0/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [0/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [1/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [1/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [1/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [1/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [2/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [2/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [2/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [2/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [3/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [3/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [3/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [3/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [4/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [4/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [4/100], d_loss: -0.000000, g_loss: -0.000100
Epoch [4/100

KeyboardInterrupt: ignored

In [None]:
# Generiamo nuovi campioni
z = torch.randn(4, z_dimension).cuda()
fake_img = G(z)

# Guardiamo i campioni generati
plt.subplot(2,2,1)
plt.imshow(fake_img.detach().cpu()[0][0], cmap="gray")
plt.subplot(2,2,2)
plt.imshow(fake_img.detach().cpu()[1][0], cmap="gray")
plt.subplot(2,2,3)
plt.imshow(fake_img.detach().cpu()[2][0], cmap="gray")
plt.subplot(2,2,4)
plt.imshow(fake_img.detach().cpu()[3][0], cmap="gray")