# we use Generative Adversarial Networks(GAN) to generate Atari gameplay photos. One is generator and one is discriminator

In [1]:
import random
import argparse
import cv2

import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter

import torchvision.utils as vutils

import gym
import gym.spaces

import numpy as np

In [2]:
log = gym.logger
log.set_level(gym.logger.INFO)

LATENT_VECTOR_SIZE = 100
DISCR_FILTERS = 64
GENER_FILTERS = 64
BATCH_SIZE = 16

# dimension input image will be rescaled
IMAGE_SIZE = 64

LEARNING_RATE = 0.0001
REPORT_EVERY_ITER = 100
SAVE_IMAGE_EVERY_ITER = 1000

In [3]:
class InputWrapper(gym.ObservationWrapper):
    """
    Preprocessing of input numpy array:
    1. resize image into predefined size
    2. move color channel axis to a first place
    """
    def __init__(self, *args):
        super(InputWrapper, self).__init__(*args)
        assert isinstance(self.observation_space, gym.spaces.Box)
        old_space = self.observation_space
        self.observation_space = gym.spaces.Box(self.observation(old_space.low),
                                                self.observation(old_space.high), dtype=np.float32)
        
    def observation(self, observation):
        #resize image
        new_obs = cv2.resize(observation, (IMAGE_SIZE, IMAGE_SIZE))
        #transform (210, 160, 3) -> (3, 210, 160)
        new_obs = np.moveaxis(new_obs, 2, 0)
        #normalize value
        return new_obs.astype(np.float32) / 255.0

In [4]:
class Discriminator(nn.Module):
    def __init__(self, input_shape):
        super(Discriminator, self).__init__()
        # this pipe converges image into the single number
        self.conv_pipe = nn.Sequential(
            nn.Conv2d(in_channels=input_shape[0], out_channels=DISCR_FILTERS,
                      kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS, out_channels=DISCR_FILTERS*2,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS*2),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 2, out_channels=DISCR_FILTERS * 4,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 4),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 4, out_channels=DISCR_FILTERS * 8,
                      kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(DISCR_FILTERS * 8),
            nn.ReLU(),
            nn.Conv2d(in_channels=DISCR_FILTERS * 8, out_channels=1,
                      kernel_size=4, stride=1, padding=0),
            nn.Sigmoid()
        )

    def forward(self, x):
        conv_out = self.conv_pipe(x)
        return conv_out.view(-1, 1).squeeze(dim=1)

In [5]:
class Generator(nn.Module):
    def __init__(self, output_shape):
        super(Generator, self).__init__()
        # pipe deconvolves input vector into (3, 64, 64) image
        self.pipe = nn.Sequential(
            nn.ConvTranspose2d(in_channels=LATENT_VECTOR_SIZE, out_channels=GENER_FILTERS * 8,
                               kernel_size=4, stride=1, padding=0),
            nn.BatchNorm2d(GENER_FILTERS * 8),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 8, out_channels=GENER_FILTERS * 4,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 4),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 4, out_channels=GENER_FILTERS * 2,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS * 2),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS * 2, out_channels=GENER_FILTERS,
                               kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(GENER_FILTERS),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=GENER_FILTERS, out_channels=output_shape[0],
                               kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        return self.pipe(x)

In [6]:
def iterate_batches(envs, batch_size=BATCH_SIZE):
    #sample from environment and store in batch, if batch has the minimum size we needed, we use yield to get the data
    batch = [e.reset() for e in envs]
    env_gen = iter(lambda: random.choice(envs), None)
    
    while True:
        e = next(env_gen)
        obs, reward, is_done, _ = e.step(e.action_space.sample())
        if np.mean(obs) > 0.01:
            batch.append(obs)
        if len(batch) == batch_size:
            # Normalising input between -1 to 1
            batch_np = np.array(batch, dtype=np.float32) * 2.0 / 255.0 - 1.0
            yield torch.tensor(batch_np)
            batch.clear()
        if is_done:
            e.reset()

In [7]:
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    #Use GPU to process model
    parser.add_argument("--cuda", default=False, action='store_true')
    #args = parser.parse_args()
    args, unknown = parser.parse_known_args()
    device = torch.device("cuda" if args.cuda else "cpu")
    
    #Use wrapper to create environment pool, it will pass to iterate_batches function to generate training data
    env_names = ('Breakout-v0', 'AirRaid-v0', 'Pong-v0')
    envs = [InputWrapper(gym.make(name)) for name in env_names]
    input_shape = envs[0].observation_space.shape
    
    #We create a output writer, 2 networks, loss function and 2 optimizers
    writer = SummaryWriter()
    net_discr = Discriminator(input_shape=input_shape).to(device)
    net_gener = Generator(output_shape=input_shape).to(device)
    
    # 1 means real sample, 0 means fake
    objective = nn.BCELoss()
    gen_optimizer = optim.Adam(params=net_gener.parameters(), lr=LEARNING_RATE)
    dis_optimizer = optim.Adam(params=net_discr.parameters(), lr=LEARNING_RATE)
    
    #We will send all real and fake samples to discriminator, but all tags are 1, we will only update generator weightings
    #The 2nd input will use fake input to pretend as real and let generator to learn how to deceive discriminator
    gen_losses = []
    dis_losses = []
    iter_no = 0
    
    true_labels_v = torch.ones(BATCH_SIZE, dtype=torch.float32, device=device)
    fake_labels_v = torch.zeros(BATCH_SIZE, dtype=torch.float32, device=device)
    
    #in the loop, we will generate random vector and pass to generator
    for batch_v in iterate_batches(envs):
        #generate extra fake samples, input is 4D: batch, filters, x, y
        gen_input_v = torch.FloatTensor(BATCH_SIZE, LATENT_VECTOR_SIZE, 1, 1).normal_(0, 1).to(device)
        batch_v = batch_v.to(device)
        gen_output_v = net_gener(gen_input_v)
        
        #We will train discriminator twice: first with batch of real examples, 2nd with fake examples
        #We need to use generator output to call detach() function to prevent the gradient go to generator
        #detach() is a tensor method, it can copy item but no need to use parent functions
        dis_optimizer.zero_grad()
        dis_output_true_v = net_discr(batch_v)
        dis_output_fake_v = net_discr(gen_output_v.detach())
        dis_loss = objective(dis_output_true_v, true_labels_v) + objective(dis_output_fake_v, fake_labels_v)
        dis_loss.backward()
        dis_optimizer.step()
        dis_losses.append(dis_loss.item())
        
        #We train generator here, we put the generated output to discriminator, but we put true labels to the 
        #generated data, so we can let generator make some more "real" images to the discriminator
        gen_optimizer.zero_grad()
        dis_output_v = net_discr(gen_output_v)
        gen_loss_v = objective(dis_output_v, true_labels_v)
        gen_loss_v.backward()
        gen_optimizer.step()
        gen_losses.append(gen_loss_v.item())
        
        #Below will report losses and send the images to TensorBoard
        iter_no += 1
        if iter_no % REPORT_EVERY_ITER == 0:
            log.info("Iter %d: gen_loss=%.3e, dis_loss=%.3e", iter_no, np.mean(gen_losses), np.mean(dis_losses))
            writer.add_scalar("gen_loss", np.mean(gen_losses), iter_no)
            writer.add_scalar("dis_loss", np.mean(dis_losses), iter_no)
            gen_losses = []
            dis_losses = []
        
        if iter_no % SAVE_IMAGE_EVERY_ITER == 0:
            writer.add_image("fake", vutils.make_grid(gen_output_v.data[:64]), iter_no)
            writer.add_image("real", vutils.make_grid(batch_v.data[:64]), iter_no)

INFO: Making new env: Breakout-v0
INFO: Making new env: AirRaid-v0
INFO: Making new env: Pong-v0
INFO: Iter 100: gen_loss=6.267e+00, dis_loss=2.519e-02
INFO: Iter 200: gen_loss=7.570e+00, dis_loss=9.488e-04
INFO: Iter 300: gen_loss=8.003e+00, dis_loss=5.950e-04
INFO: Iter 400: gen_loss=8.407e+00, dis_loss=3.536e-04
INFO: Iter 500: gen_loss=8.614e+00, dis_loss=2.673e-04
INFO: Iter 600: gen_loss=8.697e+00, dis_loss=2.410e-04
INFO: Iter 700: gen_loss=8.933e+00, dis_loss=1.882e-04
INFO: Iter 800: gen_loss=9.050e+00, dis_loss=1.602e-04
INFO: Iter 900: gen_loss=9.126e+00, dis_loss=1.423e-04
INFO: Iter 1000: gen_loss=9.573e+00, dis_loss=9.164e-05
INFO: Iter 1100: gen_loss=9.791e+00, dis_loss=7.444e-05
INFO: Iter 1200: gen_loss=9.970e+00, dis_loss=6.314e-05
INFO: Iter 1300: gen_loss=1.011e+01, dis_loss=5.432e-05
INFO: Iter 1400: gen_loss=1.014e+01, dis_loss=5.110e-05
INFO: Iter 1500: gen_loss=1.016e+01, dis_loss=4.933e-05
INFO: Iter 1600: gen_loss=1.041e+01, dis_loss=3.860e-05
INFO: Iter 1700:

INFO: Iter 14500: gen_loss=2.011e+01, dis_loss=2.980e-10
INFO: Iter 14600: gen_loss=2.017e+01, dis_loss=4.470e-10
INFO: Iter 14700: gen_loss=2.018e+01, dis_loss=2.235e-10
INFO: Iter 14800: gen_loss=2.020e+01, dis_loss=0.000e+00
INFO: Iter 14900: gen_loss=2.020e+01, dis_loss=2.235e-10
INFO: Iter 15000: gen_loss=2.014e+01, dis_loss=1.490e-10
INFO: Iter 15100: gen_loss=2.012e+01, dis_loss=2.980e-10
INFO: Iter 15200: gen_loss=2.008e+01, dis_loss=4.470e-10
INFO: Iter 15300: gen_loss=2.009e+01, dis_loss=1.490e-10
INFO: Iter 15400: gen_loss=2.012e+01, dis_loss=3.725e-10
INFO: Iter 15500: gen_loss=2.017e+01, dis_loss=7.451e-11
INFO: Iter 15600: gen_loss=2.017e+01, dis_loss=4.470e-10
INFO: Iter 15700: gen_loss=2.011e+01, dis_loss=3.278e-09
INFO: Iter 15800: gen_loss=2.017e+01, dis_loss=7.451e-11
INFO: Iter 15900: gen_loss=2.024e+01, dis_loss=0.000e+00
INFO: Iter 16000: gen_loss=2.024e+01, dis_loss=1.490e-10
INFO: Iter 16100: gen_loss=2.008e+01, dis_loss=0.000e+00
INFO: Iter 16200: gen_loss=1.99

INFO: Iter 28900: gen_loss=3.138e+00, dis_loss=4.522e-01
INFO: Iter 29000: gen_loss=3.263e+00, dis_loss=3.632e-01
INFO: Iter 29100: gen_loss=2.571e+00, dis_loss=5.284e-01
INFO: Iter 29200: gen_loss=2.117e+00, dis_loss=6.875e-01
INFO: Iter 29300: gen_loss=2.626e+00, dis_loss=4.661e-01
INFO: Iter 29400: gen_loss=3.175e+00, dis_loss=3.580e-01
INFO: Iter 29500: gen_loss=2.987e+00, dis_loss=5.307e-01
INFO: Iter 29600: gen_loss=3.116e+00, dis_loss=3.696e-01
INFO: Iter 29700: gen_loss=3.197e+00, dis_loss=5.278e-01
INFO: Iter 29800: gen_loss=3.212e+00, dis_loss=4.250e-01
INFO: Iter 29900: gen_loss=3.061e+00, dis_loss=4.012e-01
INFO: Iter 30000: gen_loss=3.526e+00, dis_loss=3.050e-01
INFO: Iter 30100: gen_loss=3.619e+00, dis_loss=3.266e-01
INFO: Iter 30200: gen_loss=3.485e+00, dis_loss=2.495e-01
INFO: Iter 30300: gen_loss=3.526e+00, dis_loss=3.412e-01
INFO: Iter 30400: gen_loss=3.609e+00, dis_loss=3.168e-01
INFO: Iter 30500: gen_loss=3.552e+00, dis_loss=2.828e-01
INFO: Iter 30600: gen_loss=3.46

KeyboardInterrupt: 