# PokeGAN

A simple GAN trained on the Pokémon sprite images. No transfer learning is used, and the generator's only input is a random noise vector. (Note: a dataset cell below currently restricts training to three sprites.)

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
from helper import *

### Load Dataset

In [None]:
# Load the sprite images from `image_root_dir` using the project helper
# `load_raw_data` (presumably defined in helper.py, which is star-imported above).
# It returns the images plus a parallel list of their file names, which is used
# below to select specific sprites by name.
image_root_dir = "../data/images/"
image_data, filenames = load_raw_data(image_root_dir, image_preprocessing)

In [None]:
# Restrict training to a handful of named sprites (a quick overfitting/debug run).
# Set `training_subset = None` (or an empty list) to train on every loaded sprite.
training_subset = ["001MS.png", "002MS.png", "003XYMS.png"]
if training_subset:
    image_data = [image_data[filenames.index(name)] for name in training_subset]
    # Keep `filenames` aligned with `image_data` so later index/name lookups stay valid.
    filenames = list(training_subset)

In [None]:
# Image dimensions taken from the first sprite. PIL's `Image.size` is
# (width, height), hence the unpacking order. Assumes every sprite shares this
# size (presumably enforced by `image_preprocessing` -- TODO confirm).
image_w, image_h = image_data[0].size
# Displayed as the cell output: height, width and PIL mode of the first sprite.
image_h, image_w, image_data[0].mode

In [None]:
# Sanity check on the element type (expected to be a PIL image, given the
# `.size` / `.mode` attributes used above).
type(image_data[0])

In [None]:
# Show a random selection of the loaded sprites. The 3x6 subplot grid holds at
# most 18 images, so the sample size is capped: otherwise add_subplot raises
# as soon as more than 18 sprites are loaded (e.g. the full dataset).
n_rows, n_cols = 3, 6
n_show = min(len(image_data), n_rows * n_cols)
sample = np.random.choice(len(image_data), size=n_show, replace=False)
fig = plt.figure(figsize=(15, 7))
for i, img_index in enumerate(sample):
    ax = fig.add_subplot(n_rows, n_cols, i + 1, xticks=[], yticks=[])
    ax.imshow(image_data[img_index])

### Datasets & DataLoaders

In [None]:
# Colour (3-channel) or grayscale (1-channel) training images.
use_color = True
n_channels = 3 if use_color else 1

# Always convert PIL -> tensor; in grayscale mode, first collapse to one channel.
transform_steps = [transforms.ToTensor()]
if not use_color:
    transform_steps.insert(0, transforms.Grayscale())
data_transforms = transforms.Compose(transform_steps)

In [None]:
# Wrap the sprite list in the project's Dataset (which presumably applies
# `data_transforms` per item) and serve it in shuffled mini-batches of 16.
training_dataset = ImageDataset(image_data, data_transforms)
training_dataloader = DataLoader(
    training_dataset,
    shuffle=True,
    batch_size=16,
)

In [None]:
# Preview the first image of one batch. Each item is (C, H, W); imshow wants
# (H, W, C), or (H, W) for single-channel data.
batch_sample = next(iter(training_dataloader))
dataset_sample = batch_sample[0].permute(1, 2, 0)
if not use_color:
    # Drop only the singleton channel axis.
    dataset_sample = dataset_sample.squeeze(-1)
# cmap only affects single-channel data; without it grayscale sprites are
# rendered with matplotlib's default (viridis) colormap.
plt.imshow(dataset_sample, cmap=None if use_color else "gray")

In [None]:
# Shape of one batch; with ToTensor this should be (batch, channels, height, width).
batch_sample.shape

## Networks

In [None]:
# Seed torch's RNG so weight initialisation (presumably torch-based) and later
# DataLoader shuffling are repeatable.
# NOTE(review): noise vectors come from `generate_random`, which returns a NumPy
# array (it is passed to torch.from_numpy); if it draws from NumPy's RNG, that
# randomness is NOT covered by this seed -- TODO confirm.
torch.random.manual_seed(50)

### Discriminator

In [None]:
# Build the discriminator for image_h x image_w inputs (colour mode follows
# `use_color`) and apply its custom weight initialisation. The bare `D` on the
# last line displays the architecture as the cell output.
D = DCDiscriminator(image_h, image_w, kernel_size=4, padding=1, is_rgb=use_color)
D.init_weights()
D

### Generator

In [None]:
# Length of the latent noise vector fed to the generator; every z sampled
# later in the notebook uses this size.
noise_size = 100
# Build the generator producing image_h x image_w images (colour mode follows
# `use_color`), initialise its weights, and display the architecture.
G = DCGenerator(noise_size, image_h, image_w, kernel_size=4, padding=1, is_rgb=use_color)
G.init_weights()
G

### Optimisers

In [None]:
# Learning rates: the generator takes larger steps than the discriminator.
lr_d = 0.0004
lr_g = 0.001

# Both networks use Adam with beta1 = 0.5 (lower momentum than Adam's default 0.9).
adam_betas = (0.5, 0.999)
d_optimizer = optim.Adam(D.parameters(), lr=lr_d, betas=adam_betas)
g_optimizer = optim.Adam(G.parameters(), lr=lr_g, betas=adam_betas)

## Training

In [None]:
# training hyperparams
num_epochs = 500

# keep track of loss and generated, "fake" samples
samples = []
losses = []

# `batch_i` restarts at 0 every epoch, so this prints on each epoch's first
# batch and then every `print_every`-th batch within the epoch.
print_every = 400

# Fixed latent vectors, held constant throughout training, so the per-epoch
# samples show how the generator evolves on the same inputs.
sample_size = 16
fixed_z = generate_random((sample_size, noise_size))
fixed_z = torch.from_numpy(fixed_z).float()

# train the network
D.train()
G.train()
for epoch in range(num_epochs):
    for batch_i, real_images in enumerate(training_dataloader):
        batch_size = real_images.size(0)
        real_images = scale(real_images)

        # ============================================
        #            TRAIN THE DISCRIMINATOR
        # ============================================
        d_optimizer.zero_grad()

        # 1. Real images, scored against smoothed "real" labels.
        batch_out_real = D(real_images)
        d_real_loss = real_loss(batch_out_real, smooth=True)

        # 2. Fake images. They are detached for the discriminator step so
        # d_loss neither backpropagates into nor accumulates gradients on G.
        z = generate_random((batch_size, noise_size))
        z = torch.from_numpy(z).float()
        fake_images = G(z)
        d_fake_loss = fake_loss(D(fake_images.detach()))

        # Combine the real and fake losses and update D.
        d_loss = d_fake_loss + d_real_loss
        d_loss.backward()
        d_optimizer.step()

        # =========================================
        #            TRAIN THE GENERATOR
        # =========================================
        g_optimizer.zero_grad()

        # Re-score the fakes with the *updated* discriminator. Reusing the
        # pre-step output would backprop through D weights that
        # d_optimizer.step() has already modified in place, which autograd rejects.
        batch_out_fake = D(fake_images)
        # Flipped labels: the generator is rewarded when D calls its fakes "real".
        g_loss = real_loss(batch_out_fake)
        g_loss.backward()
        g_optimizer.step()

        # Print some loss stats
        if batch_i % print_every == 0:
            print('Epoch [{:5d}/{:5d}] | d_loss: {:6.4f} | g_loss: {:6.4f}'.format(
                    epoch+1, num_epochs, d_loss.item(), g_loss.item()))

    ## AFTER EACH EPOCH ##
    # Record the losses of the epoch's final batch.
    losses.append((d_loss.item(), g_loss.item()))

    # Generate and keep samples from the fixed noise. no_grad prevents every
    # stored tensor from holding onto its autograd graph for the whole run.
    G.eval()  # eval mode for generating samples
    with torch.no_grad():
        samples_z = G(fixed_z)
    samples.append(samples_z)
    G.train()  # back to train mode


# Save training generator samples
with open('train_samples.pkl', 'wb') as f:
    pkl.dump(samples, f)

## Training loss

Here we plot the generator and discriminator training losses, recorded once per epoch (the losses from each epoch's final batch).

In [None]:
# One row per epoch: column 0 is the discriminator loss, column 1 the generator loss.
fig, ax = plt.subplots()
losses = np.array(losses)
ax.plot(losses[:, 0], label='Discriminator')
ax.plot(losses[:, 1], label='Generator')
ax.set_title("Training Losses")
ax.legend()

## Sampling

In [None]:
def reshape_sample(samples, img_size, is_rgb=True):
    """Convert a batch of generator outputs into arrays ready for ``imshow``.

    Args:
        samples: tensor of shape (N, C, H, W), presumably in [-1, 1] (the range
            produced by ``scale`` during training -- TODO confirm). It may
            require grad; it is detached here.
        img_size: unused; kept so existing callers keep working.
        is_rgb: when False, the singleton channel axis is dropped so grayscale
            images come back as (N, H, W). matplotlib's ``imshow`` rejects
            (H, W, 1) arrays.

    Returns:
        NumPy array of shape (N, H, W, C) for colour input, or (N, H, W) for
        single-channel input when ``is_rgb`` is False, with values mapped
        through (x + 1) / 2.
    """
    # .numpy() refuses tensors that require grad or live off the CPU, so
    # detach and move to CPU before any conversion.
    samples = samples.detach().cpu()
    samples = (samples + 1) / 2  # reverse the [-1, 1] scaling
    samples = samples.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C)
    if not is_rgb:
        # torch's squeeze(dim) is a no-op when that axis is not of size 1.
        samples = samples.squeeze(-1)
    return samples.numpy()

In [None]:
# Draw 16 brand-new latent vectors (not the fixed ones used during training)
# and decode them with the generator in eval mode, without tracking gradients.
sample_size = 16
rand_z = torch.from_numpy(generate_random((sample_size, noise_size))).float()

G.eval()
with torch.no_grad():
    rand_images = G(rand_z)

# Convert to (N, H, W, C) arrays for plotting. The trailing squeeze() drops
# every size-1 axis (e.g. the channel axis of grayscale output).
reshaped_images = reshape_sample(rand_images, image_w, is_rgb=use_color).squeeze()

In [None]:
# Lay the generated sprites out on a 4x4 grid (cmap only affects grayscale data).
fig = plt.figure(figsize=(7, 7))
for pos, sprite in enumerate(reshaped_images, start=1):
    ax = fig.add_subplot(4, 4, pos, xticks=[], yticks=[])
    ax.imshow(sprite, cmap='gray')

### Load Training Samples

In [None]:
# Reload the per-epoch fixed-noise samples written by the training cell.
# The context manager closes the file handle, which the bare open() leaked.
# NOTE: pickle can execute arbitrary code on load; only open files this notebook produced.
with open("train_samples.pkl", 'rb') as f:
    training_samples = pkl.load(f)

In [None]:
# Show the fixed-noise samples from every 25th epoch to visualise training progress.
# samples[k] was recorded at the end of epoch k + 1.
for epoch_i in range(0, len(training_samples), 25):
    reshaped_images = reshape_sample(training_samples[epoch_i].detach(), image_w, is_rgb=use_color)
    # Drop a trailing singleton channel axis (grayscale): imshow rejects (H, W, 1).
    if reshaped_images.ndim == 4 and reshaped_images.shape[-1] == 1:
        reshaped_images = reshaped_images[..., 0]
    fig = plt.figure(figsize=(7, 7))
    fig.suptitle("Epoch {}".format(epoch_i + 1))
    # Use a distinct loop variable rather than shadowing the outer epoch index.
    for pos, s in enumerate(reshaped_images, start=1):
        ax = fig.add_subplot(4, 4, pos, xticks=[], yticks=[])
        ax.imshow(s, cmap='gray')