In [None]:
import torch
from torch import nn
import numpy as np
import matplotlib.animation as animation
from IPython.display import HTML

import math
import matplotlib.pyplot as plt
import torchvision
from torchvision import datasets, transforms, utils, models
import torchvision.transforms as transforms

### Use the GPU if one is available

In [None]:
# Run on CUDA when a GPU is visible to PyTorch; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Set up data

In [None]:
DATA_DIR = "../input/pokemon-bw-28/pokemon_sprites_bw_28x28"
IMAGE_SIZE = 28
batch_size = 32

# Load each sprite as a single-channel tensor; ToTensor yields [0, 1] and
# Normalize(0.5, 0.5) rescales that to [-1, 1].
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

# ImageFolder reads images from class sub-directories of DATA_DIR.
dataset = datasets.ImageFolder(root=DATA_DIR, transform=transform)

In [None]:
# Shuffled mini-batches of sprites. drop_last is left at its default (False),
# so the final batch of an epoch may hold fewer than `batch_size` images.
train_loader = torch.utils.data.DataLoader(dataset, shuffle=True, batch_size=batch_size)

In [None]:
# Preview one shuffled batch from the sprite loader as a single image grid.
sample_batch = next(iter(train_loader))
plt.figure(figsize=(10, 8))
plt.axis("off")
plt.title("Sample Training Images")
plt.imshow(
    np.transpose(
        utils.make_grid(sample_batch[0], padding=1, normalize=True),
        (1, 2, 0),  # CHW -> HWC for matplotlib
    )
);

In [None]:
# Reference MNIST loader with the same [-1, 1] normalisation, used only to
# sanity-check the pipeline. NOTE: this rebinds `transform` and `batch_size`
# from the data-setup cell.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])

batch_size = 32
train_set_ex = torchvision.datasets.MNIST(
    root=".",
    train=True,
    download=True,
    transform=transform,
)
train_loader_ex = torch.utils.data.DataLoader(
    train_set_ex,
    batch_size=batch_size,
    shuffle=True,
)

In [None]:
# Smoke test: pull one batch from the MNIST example loader to confirm it iterates.
real_samples_ex, mnist_labels_ex = next(iter(train_loader_ex))

In [None]:
# Pull one batch from the sprite loader (train_loader wraps the ImageFolder dataset).
# NOTE(review): despite its name, `mnist_labels` holds ImageFolder class indices, not
# MNIST labels; `real_samples` is rebound again by the training loop.
real_samples, mnist_labels = next(iter(train_loader))

### Define the GAN (discriminator and generator)

In [None]:
class Discriminator(nn.Module):
    """MLP discriminator: maps a flattened image to the probability that it is real.

    Args:
        in_features: number of values per image after flattening
            (default 784 = 1 x 28 x 28).
        dropout: dropout probability applied after each hidden layer (default 0.3).
    """

    def __init__(self, in_features: int = 784, dropout: float = 0.3):
        super().__init__()
        self.in_features = in_features
        self.model = nn.Sequential(
            nn.Linear(in_features, 1024),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, 1),
            nn.Sigmoid(),  # output is a probability, as BCELoss expects
        )

    def forward(self, x):
        """Flatten each sample in `x` and return a (batch, 1) tensor of probabilities."""
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. transposed tensors.
        x = x.reshape(x.size(0), self.in_features)
        return self.model(x)

In [None]:
# Instantiate the discriminator and move its parameters to the selected device.
discriminator = Discriminator().to(device=device)

In [None]:
class Generator(nn.Module):
    """MLP generator: maps latent noise vectors to images with values in [-1, 1].

    Args:
        latent_dim: length of each input noise vector (default 100).
        img_shape: (channels, height, width) of each generated image
            (default (1, 28, 28)).
    """

    def __init__(self, latent_dim: int = 100, img_shape=(1, 28, 28)):
        super().__init__()
        self.img_shape = tuple(img_shape)
        out_features = math.prod(self.img_shape)
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256),
            nn.ReLU(),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Linear(512, 1024),
            nn.ReLU(),
            nn.Linear(1024, out_features),
            nn.Tanh(),  # [-1, 1], the same range as the normalised training images
        )

    def forward(self, x):
        """Map a (batch, latent_dim) tensor to a (batch, *img_shape) image tensor."""
        output = self.model(x)
        return output.view(x.size(0), *self.img_shape)


In [None]:
# Instantiate the generator and move its parameters to the selected device.
generator = Generator().to(device=device)

In [None]:
# Smoke-test the untrained generator: map one batch of noise to images and show the first.
latent_space_samples = torch.randn(batch_size, 100).to(device=device)
generated_samples = generator(latent_space_samples)
generated_list_img = generated_samples.detach().cpu()
first_img = None  # placeholder removed below
del first_img
plt.imshow(generated_list_img[0].view(28, 28), cmap="gray")

In [None]:
# Training hyperparameters shared by both networks.
num_epochs = 10000
lr = 5e-5
# Binary cross-entropy: the discriminator ends in Sigmoid and targets are 0 (fake) / 1 (real).
loss_function = nn.BCELoss()

optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)


In [None]:
LATENT_SIZE = 100  # must equal the generator's input width (its first Linear layer takes 100)
# Fixed noise used to snapshot generator progress during training. The generator is
# fully connected, so the noise is (batch, LATENT_SIZE). The (N, C, 1, 1) layout used
# by convolutional DCGAN generators would not fit its Linear input layer.
fixed_noise = torch.randn(batch_size, LATENT_SIZE, device=device)

In [None]:
# GAN training: per batch, one discriminator step then one generator step.
# Every LOG_EVERY epochs, print the latest losses and record a grid of generator
# output on `fixed_noise`, so successive snapshots are directly comparable.
img_list = []     # image grids of generator output on fixed_noise, one per log step
gen_losses = []   # generator loss per iteration
dis_losses = []   # discriminator loss per iteration
ITERS = 0         # total training iterations run

LOG_EVERY = 10    # epochs between progress logs / snapshots
print_n = 0       # epochs since the last log


for epoch in range(num_epochs):
    for n, (real_samples, labels) in enumerate(train_loader):
        # Size labels and noise from the actual batch: the last batch of an epoch
        # can be smaller than `batch_size`, and BCELoss needs matching shapes.
        real_samples = real_samples.to(device=device)
        n_real = real_samples.size(0)
        real_samples_labels = torch.ones((n_real, 1), device=device)
        generated_samples_labels = torch.zeros((n_real, 1), device=device)

        # Discriminator step: real images -> 1, generated images -> 0.
        latent_space_samples = torch.randn((n_real, LATENT_SIZE), device=device)
        # detach(): the discriminator loss must not backpropagate into the generator.
        generated_samples = generator(latent_space_samples).detach()
        all_samples = torch.cat((real_samples, generated_samples))
        all_samples_labels = torch.cat((real_samples_labels, generated_samples_labels))

        discriminator.zero_grad()
        output_discriminator = discriminator(all_samples)
        loss_discriminator = loss_function(output_discriminator, all_samples_labels)
        loss_discriminator.backward()
        optimizer_discriminator.step()

        # Generator step: push the discriminator to label fresh fakes as real (1).
        latent_space_samples = torch.randn((n_real, LATENT_SIZE), device=device)
        generator.zero_grad()
        generated_samples = generator(latent_space_samples)
        output_discriminator_generated = discriminator(generated_samples)
        loss_generator = loss_function(output_discriminator_generated, real_samples_labels)
        loss_generator.backward()
        optimizer_generator.step()

        dis_losses.append(loss_discriminator.item())
        gen_losses.append(loss_generator.item())
        ITERS += 1

    # Count whole epochs rather than keying on a batch index, so logging fires
    # no matter how many batches the loader yields per epoch.
    print_n += 1
    if print_n == LOG_EVERY:
        print_n = 0
        # Flatten to (N, LATENT_SIZE): the Linear generator expects 2-D input.
        with torch.no_grad():
            snapshot = generator(fixed_noise.reshape(fixed_noise.size(0), -1)).cpu()
        img_list.append(utils.make_grid(snapshot, nrow=4, normalize=True))
        print(f"Epoch: {epoch} Loss D.: {loss_discriminator.item()}")
        print(f"Epoch: {epoch} Loss G.: {loss_generator.item()}")
        

In [None]:
# Draw a fresh batch from the trained generator for inspection. no_grad: the
# samples are only plotted and saved, so no autograd graph is needed.
with torch.no_grad():
    latent_space_samples = torch.randn(batch_size, LATENT_SIZE).to(device=device)
    generated_samples = generator(latent_space_samples)

### Examine/Save Results

In [None]:
# Plot up to 32 generated samples on a 4 x 8 grid of subplots.
generated_samples = generated_samples.cpu().detach()
n_show = min(32, generated_samples.size(0))  # avoid IndexError if batch_size < 32
for i in range(n_show):
    ax = plt.subplot(4, 8, i + 1)
    plt.imshow(generated_samples[i].reshape(28, 28), cmap="gray")
    plt.xticks([])
    plt.yticks([])
# Save once, after every subplot is drawn (inside the loop it would save 32 partial figures).
#plt.savefig('pokemon_1000_results.png')

In [None]:
# Show one generated sample (index 15) on its own.
generated_samples = generated_samples.detach().cpu()
plt.imshow(generated_samples[15].squeeze(0), cmap="gray")
#plt.savefig('pokemon_1.png')

In [None]:
# Show one generated sample (index 30) on its own.
generated_samples = generated_samples.detach().cpu()
plt.imshow(generated_samples[30].squeeze(0), cmap="gray")

In [None]:
%%capture
fig = plt.figure(figsize=(6,6))
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in generated_samples]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
ani.save('pokemon_1000.gif', writer='imagemagick', fps=2)

In [None]:
# NOTE(review): saves the same animation to the same file with the same settings as
# the previous cell, so it is redundant when that cell has already run.
ani.save('pokemon_1000.gif', writer='imagemagick', fps=2)

In [None]:
# Render the current `ani` as an inline HTML/JavaScript player.
HTML(ani.to_jshtml())

In [None]:
# Summarise the recorded progress snapshots (count, shape of one grid) instead of
# echoing every grid tensor's values into the notebook output.
(len(img_list), tuple(img_list[0].shape) if img_list else None)

In [None]:
%%capture
fig = plt.figure(figsize=(6,6))
ims = [[plt.imshow(np.transpose(i,(1,2,0)), animated=True)] for i in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

In [None]:
# Render the training-progress animation (built from img_list) as an inline HTML/JavaScript player.
HTML(ani.to_jshtml())