<img src="https://s3.amazonaws.com/pokemontcg/xy7/54.png" alt="Bulbasaur Pic" style="width: 256px;"/>

# Pokemon TCG card generator

- Downloads and saves card data from pokemontcg.io, then reformats it as YAML
    - Actually, just use the API directly, because the Python SDK is out of date and parts of it are not compatible with each other
- Uses the Keras LSTM text-generation example (adapted to GRU layers) to generate card data

## Load card data

In [14]:
# imports
import yaml, json, os, random, requests
from pprint import pprint

data_dir = '/home/ubuntu/fastai-data/pokemon'

## Preprocessing

 - Convert the JSON data to a text representation that is easy for a character-level model to parse

In [4]:
# load data
with open(os.path.join(data_dir,'cards.json')) as f:
    cards = json.load(f)
#pprint(cards[-1])

In [5]:
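# augment data: append a shuffled copy of the whole list; the list doubles on
# each of the 3 passes, so the corpus ends up 8x the original size
for i in range(3):
    cards.extend(random.sample(cards, len(cards)))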

In [6]:
# encode card categories (subtypes) as greek letters
# note: 'ω' appears twice (indices 1 and 21), so subtype 21 decodes back as subtype 1
alphabet = 'θωερτψυιοπασφγηςκλζχξωβνμ'
# encode type as a unicode character, following https://redd.it/4xvh2q
# ☽ and ☘ are repeated so the alias names 'Dark' and 'Green' inserted below
# share the symbol of the type at the neighboring position
type_char = '✴☽☽⛩❤✊♨☘☘⚡⛓⚛☔'

types = requests.get('https://api.pokemontcg.io/v1/types').json()['types']
types.insert(2, 'Dark')
types.insert(7, 'Green')
subtypes = requests.get('https://api.pokemontcg.io/v1/subtypes').json()['subtypes']

In [7]:
# encode type as unicode character
def type_to_char(t_list):
    if t_list and t_list[0] != 'Free':
        return ''.join([type_char[types.index(t)] for t in t_list])
    else:
        return ''

# convert a list of lines to a single string, replacing the card's name with @
def singlify(text, name=None):
    if text:
        text = ''.join(text) if isinstance(text, list) else text
        if name:
            text = text.replace(name, '@')
        return text
    else:
        return ''

# write data as txt file
with open(os.path.join(data_dir,'cards.txt'), 'w+') as f:
    for card in cards:
        lines = ['\n']
        lines.append('|'.join([card['supertype'][0],
                alphabet[subtypes.index(card['subtype'])] if card['subtype'] else '',
                type_to_char(card['types']),
                type_char[types.index(card['weaknesses'][0]['type'])] \
                    + ('^'*int(card['weaknesses'][0]['value'][1]) if '0' in card['weaknesses'][0]['value'] else 'x')\
                    if card['weaknesses'] else '',     
                type_char[types.index(card['resistances'][0]['type'])] \
                    + ('^'*int(card['resistances'][0]['value'][1]) if '0' in card['resistances'][0]['value'] else 'x')\
                    if card['resistances'] else '',     
                '^'*(int(card['hp'])//10) if card['hp'] and card['hp'].isdigit() else '',
                type_to_char(card['retreat_cost']),
                singlify(card['name']), singlify(card['evolvesFrom']), singlify(card['text'],name=card['name'])]))
        if card['ability']:
            lines.append(
                '|'.join(['x', card['ability']['name'],
                          singlify(card['ability']['text'], name=card['name'])]))
        if card['ancient_trait']:
            lines.append(
                '|'.join(['y', card['ancient_trait']['name'],
                          singlify(card['ancient_trait']['text'], name=card['name'])]))
        if card['attacks']:
            for attack in card['attacks']:
                lines.append(
                    '|'.join(['z', type_to_char(attack['cost']) if 'cost' in attack else '',
                              str(attack['damage']), singlify(attack['name']),
                              singlify(attack['text'], name=card['name'])]))
        if 'マ' not in ''.join(lines): # skip Japanese cards (heuristic: text contains the katakana 'マ')
            for line in lines:
                f.write(line+'\n')
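The HP and weakness/resistance fields above are packed into runs of `^`. Pulled out as standalone functions (the names `encode_modifier` and `encode_hp` are mine, not the notebook's), the inline expressions look like the sketch below. Writing them out makes two limitations visible: the sign of the modifier is dropped, and only the second character of the value is read, so a three-digit value such as `+100` would come out as a single `^`.

```python
# Hypothetical helpers mirroring the inline expressions in the cell above.

def encode_modifier(value):
    """Encode a weakness/resistance value the way the cell above does.

    Values containing a '0' (e.g. '+20', '-30') become one '^' per ten points,
    read from the second character only; anything else (e.g. '×2') becomes 'x'.
    """
    if '0' in value:
        return '^' * int(value[1])
    return 'x'


def encode_hp(hp):
    """Encode HP as one '^' per ten points; missing or non-numeric HP -> ''."""
    if hp and hp.isdigit():
        return '^' * (int(hp) // 10)
    return ''


print(encode_modifier('+20'))  # ^^
print(encode_modifier('×2'))   # x
print(encode_hp('60'))         # ^^^^^^
```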
            

## Create Model

 - Turn text into embedded sequences that keras can use
 - Setup model architecture

In [8]:
# imports
from keras.models import Sequential
from keras.layers import *
from keras.optimizers import Adam
import numpy as np

Using Theano backend.
 https://github.com/Theano/Theano/wiki/Converting-to-the-new-gpu-back-end%28gpuarray%29

Using gpu device 0: Tesla K80 (CNMeM is disabled, cuDNN Mixed dnn version. The header is from one version, but we link with a different version (6021, 5103))


In [15]:
# load text
path = os.path.join(data_dir,'cards.txt')
with open(path) as f:
    text = f.read()

print('corpus length:', len(text))
print(text[:128])

corpus length: 17472744


P|φ|☘|♨x||^^^^^^||Shroomish||
z|✴||Spore|Your opponent's Active Pokémon is now Asleep.


P|π|☘|♨x||^^^^^^^^^^||Breloom|Shroomi


In [16]:
# get characters used in text
chars = sorted(list(set(text)))
vocab_size = len(chars)

print('total chars:', vocab_size)
print(''.join(chars))

total chars: 131

 !"#&'()*+,-./0123456789:;?@ABCDEFGHIJKLMNOPQRSTUVWXYZ[]^_abcdefghijklmnopqrstuvwxyz{|}~ ×éαβγδεηθικοπρςστυφψωݎ—’•↓−☔☘☽♀♂♨⚛⚡⛓⛩✊✴❤＋


In [17]:
# create character indices
char_indices = dict((c, i) for i, c in enumerate(chars))
# turn text into char indices
idx = [char_indices[c] for c in text]

In [None]:
maxlen = 128
sentences = []
next_chars = []
for i in range(len(idx)-maxlen+1):
    sentences.append(idx[i: i + maxlen])
    next_chars.append(idx[i+1: i+maxlen+1])

In [None]:
print('# of sequences:', len(sentences))

# the last window's target is one character short, so drop the tail to keep every row maxlen long
sentences = np.array(sentences[:-2])
next_chars = np.array(next_chars[:-2])

# of sequences: 17472617
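Building 17 million Python lists of 128 indices each is very memory-hungry. One possible alternative, assuming NumPy 1.20 or later for `sliding_window_view`, is to take the overlapping windows as views over a single integer array. The helper name `make_windows` is hypothetical; it keeps one more window than the `[:-2]` slice above, since only the very last window has a short target.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view  # NumPy >= 1.20


def make_windows(idx, maxlen):
    """Return (inputs, targets), with targets shifted one character ahead.

    Both arrays are read-only views into one int32 array, so no per-window
    Python lists are built. Row i is idx[i:i+maxlen] / idx[i+1:i+maxlen+1].
    """
    arr = np.asarray(idx, dtype=np.int32)
    windows = sliding_window_view(arr, maxlen)  # shape (len(arr) - maxlen + 1, maxlen)
    inputs = windows[:-1]   # the last window has no full-length target
    targets = windows[1:]   # each target row is the next window
    return inputs, targets
```

Usage would be `sentences, next_chars = make_windows(idx, maxlen)`; `model.fit` should accept these views, since it slices batches out of the arrays as it goes.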


In [None]:
np.save(os.path.join(data_dir,'sentences'), sentences)
np.save(os.path.join(data_dir,'next_chars'), next_chars)

In [None]:
# size of embedding
n_fac = 42

In [None]:
# model architecture
model=Sequential([
        Embedding(vocab_size, n_fac, input_length=maxlen),
        GRU(256, input_shape=(n_fac,),return_sequences=True, dropout=0.01, recurrent_dropout=0.01),
        Dropout(0.2),
        GRU(512, return_sequences=True, dropout=0.01, recurrent_dropout=0.01),
        Dropout(0.2),
        TimeDistributed(Dense(vocab_size)),
        Activation('softmax')
    ])
model.compile(loss='sparse_categorical_crossentropy', optimizer=Adam(lr=0.01), metrics=['acc'])
model.summary()

## Train Model

In [None]:
from numpy.random import choice
import random

# print example text, sampling one character at a time from the model
def print_example(length=800, temperature=0.7, mult=2):
    seed_len = maxlen  # the Embedding layer was built with input_length=maxlen
    path = os.path.join(data_dir,'cards.txt')
    with open(path) as f:
        text = f.read()
    ind = random.randint(0,len(text)-seed_len-1)
    seed_string = text[ind:ind+seed_len]
    
    for i in range(length):
        last_line = seed_string.rsplit('\n', 1)[-1]  # the line currently being generated
        # name fields: after 7 pipes on a header line, 1 on an x/y line, 3 on a z line
        if (last_line.count('|') == 7 or
                last_line.startswith(('x','y')) and last_line.count('|') == 1 or
                last_line.startswith('z') and last_line.count('|') == 3):
            temp = temperature * mult # make names more creative
        else:
            temp = temperature
        
        x = np.array([char_indices[c] for c in seed_string[-seed_len:]])[np.newaxis,:]
        preds = model.predict(x, verbose=0)[0][-1].astype('float64')
        preds = np.log(preds) / temp
        exp_preds = np.exp(preds)
        preds = exp_preds / np.sum(exp_preds)
        next_char = choice(chars, p=preds)
        print(next_char, end="")
        seed_string = seed_string + next_char
    
    #print(seed_string[seed_len:])
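The sampling step inside `print_example` is plain temperature scaling: each probability is raised to the power 1/T and the result renormalized. Isolated as a function (the name and the `rng` argument are my additions), with a max-subtraction added for numerical stability, it might look like this:

```python
import numpy as np


def sample_with_temperature(probs, temperature=0.7, rng=None):
    """Sample an index after temperature scaling; return (index, scaled probs).

    Equivalent to p_i ** (1 / T), renormalized. T < 1 sharpens the
    distribution, T > 1 flattens it.
    """
    rng = np.random.default_rng() if rng is None else rng
    probs = np.asarray(probs, dtype=np.float64)
    with np.errstate(divide='ignore'):   # log(0) -> -inf, which exp maps back to 0
        logits = np.log(probs) / temperature
    logits -= logits.max()               # keep exp() in a safe range
    scaled = np.exp(logits)
    scaled /= scaled.sum()
    return rng.choice(len(scaled), p=scaled), scaled
```

Since a higher T spreads probability toward less likely characters, multiplying the temperature by `mult` for name fields is what makes generated names more varied.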

In [None]:
from keras.callbacks import ModelCheckpoint, ReduceLROnPlateau, LambdaCallback
import h5py

def print_callback(epoch, logs):
    print_example()

result_dir = os.path.join(data_dir, 'results')
weight_path = "weights-{epoch:02d}-{acc:.2f}.hdf5"
checkpoint = ModelCheckpoint(os.path.join(result_dir, weight_path),
                             monitor='acc', verbose=1, save_best_only=True, mode='max')
reduce_lr = ReduceLROnPlateau(monitor='loss', factor=0.2,
                              patience=2, min_lr=0.00000001)
printer = LambdaCallback(on_epoch_end=print_callback)

callbacks_list = [printer, checkpoint, reduce_lr]

In [None]:
num_epochs = 50
history = model.fit(sentences,
                    np.expand_dims(next_chars,-1),
                    batch_size=256,
                    epochs=num_epochs,
                    callbacks=callbacks_list)

In [None]:
%%capture generated_cards
print_example(length=300000, temperature=0.7, mult=2)

In [None]:
with open(os.path.join(data_dir,'cards_generated.txt'), 'w+') as f:
    f.write(generated_cards.stdout)

## Process Output

 - At this point I redid the previous steps with a premade TensorFlow char-rnn model, to see if it would do any better. I didn't see significant improvements, but it ran considerably faster.
 - Decode generated text back into JSON format
 - Convert JSON format to card images
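In the decoding cell below, each position of the `|`-separated header line has to be matched to the right field by hand. A small sketch that names the positions, in the order the encoding cell writes them (`HEADER_FIELDS` and `split_header` are hypothetical names), checked against the Shroomish line printed earlier:

```python
# Field order of a header ('P'/'E'/'T') line, as written by the encoding cell.
HEADER_FIELDS = ['supertype', 'subtype', 'types', 'weakness', 'resistance',
                 'hp', 'retreat_cost', 'name', 'evolvesFrom', 'text']


def split_header(line):
    """Split one encoded header line into a dict keyed by field name.

    Missing trailing fields (e.g. a truncated generated line) come back as ''.
    """
    parts = line.rstrip('\n').split('|')
    parts += [''] * (len(HEADER_FIELDS) - len(parts))
    return dict(zip(HEADER_FIELDS, parts))


fields = split_header('P|φ|☘|♨x||^^^^^^||Shroomish||')
print(fields['name'], len(fields['hp']) * 10, fields['weakness'])  # Shroomish 60 ♨x
```

HP is the number of `^` times ten (60 here), and the weakness field is a type symbol followed by `x` for ×2 or one `^` per ten points.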

In [None]:
# encode card categories as greek letters (must match the encoding cell)
alphabet = 'θωερτψυιοπασφγηςκλζχξωβνμ'
# encode type as a unicode character, following https://redd.it/4xvh2q
# one symbol per type here: the 'Dark'/'Green' aliases from the encoding cell are not needed
type_char = '✴☽⛩❤✊♨☘⚡⛓⚛☔'

types = requests.get('https://api.pokemontcg.io/v1/types').json()['types']
subtypes = requests.get('https://api.pokemontcg.io/v1/subtypes').json()['subtypes']
supertypes = requests.get('https://api.pokemontcg.io/v1/supertypes').json()['supertypes']
with open(os.path.join(data_dir,'cards.json')) as f:
    old_names = [card['name'] for card in json.load(f)]

In [None]:
# decode type from unicode character
def char_to_type(chars):
    if chars:
        return [types[type_char.index(char)] for char in chars]
    else:
        return None

# decode a weakness/resistance field: type symbol, then 'x' (×2) or one '^' per 10 points
def decode_modifier(field, sign):
    return {'type': types[type_char.index(field[0])],
            'value': '×2' if field[1] == 'x' else sign+str(len(field)-1)+'0'}

# the encoding cell writes the first letter of the supertype
header_supertypes = {'P': 'Pokémon', 'E': 'Energy', 'T': 'Trainer'}

cards = []
card = None
with open(os.path.join(data_dir,'cards_generated.txt')) as f:
    for line in f:
        line = line.split('|')
        if line[0] in header_supertypes:
            # fields: supertype|subtype|types|weakness|resistance|hp|retreat|name|evolvesFrom|text
            if card and card['name'] not in old_names:
                cards.append(card)
            try:
                card = {'supertype': header_supertypes[line[0]],
                        'subtype': subtypes[alphabet.index(line[1])] if line[1] else None,
                        'types': char_to_type(line[2]),
                        'weaknesses': decode_modifier(line[3], '+') if line[3] else None,
                        'resistances': decode_modifier(line[4], '-') if line[4] else None,
                        'hp': len(line[5])*10 if line[5] else None,
                        'retreat_cost': char_to_type(line[6]),
                        'name': line[7].rstrip(),
                        'evolvesFrom': line[8].rstrip(),
                        'text': line[9].replace('@',line[7].rstrip()).rstrip() if len(line) > 9 else None}
            except (ValueError, IndexError):
                card = None
                print('Skipped card')
        elif line[0] == 'x' and card and card['supertype'] == 'Pokémon':
            try:
                card['ability'] = {'name':line[1].rstrip(),
                                   'text':line[2].replace('@',card['name']).rstrip() if len(line) > 2 else None}
            except IndexError:
                print('Skipped ability')
        elif line[0] == 'y' and card and card['supertype'] == 'Pokémon':
            try:
                card['ancient_trait'] = {'name':line[1].rstrip(),
                                         'text':line[2].replace('@',card['name']).rstrip() if len(line) > 2 else None}
            except IndexError:
                print('Skipped trait')
        elif line[0] == 'z' and card and card['supertype'] == 'Pokémon':
            try:
                card.setdefault('attacks', []).append(
                    {'cost': char_to_type(line[1]),
                     'damage': line[2],
                     'name': line[3].rstrip(),
                     'text': line[4].replace('@',card['name']).rstrip() if len(line) > 4 else None})
            except (ValueError, IndexError):
                print('Skipped attack')

# keep the final card too (the loop only appends a card when the next header starts)
if card and card['name'] not in old_names:
    cards.append(card)

In [None]:
class ExplicitDumper(yaml.SafeDumper):
    def ignore_aliases(self, data):
        return True
    
with open('cards_generated.yml', 'w+') as f:
    yaml.dump(cards, f, allow_unicode=True, Dumper=ExplicitDumper, default_flow_style=False)

In [None]:
from IPython.display import FileLink
FileLink('cards_generated.yml')

## Create Card Mockups

- Using Paulsnoop's BWXY card templates and symbol sheet (from DeviantArt)

In [None]:
from PIL import Image, ImageDraw, ImageFont
import os, textwrap

data_dir = '/home/ubuntu/fastai-data/pokemon_img'
template_path = os.path.join(data_dir, 'templates')
save_path = os.path.join(data_dir, 'card_results')

In [None]:
import yaml, pprint, re, unidecode

# the file was written with a SafeDumper subclass, so safe_load can read it back
with open('cards_generated.yml') as f:
    card_data = yaml.safe_load(f)

In [None]:
def get_energy_img(energy, category):
    # symbol order on the sheet; names follow the API's type names ('Lightning', 'Darkness')
    energies = ['Grass', 'Fire', 'Water', 'Lightning', 'Psychic', 'Fighting',
                'Darkness', 'Metal', 'Fairy', 'Dragon', 'Colorless']
    offset = energies.index(energy)*57
    full_img = Image.open('symbols.png')
    if category == 'attack':
        return full_img.crop((46+offset, 85, 85+offset, 135))
    elif category == 'weakness':
        return full_img.crop((50+offset, 210, 80+offset, 250))
    raise ValueError('unknown category: {}'.format(category))

# only supports basic pokemon for now
def gen_card_img(card):
    if card['supertype'] == 'Pokémon':
        img = Image.open(os.path.join(template_path,
                         card['supertype'], 'Basic',
                         card['types'][0]+'.png'))
        
        d = ImageDraw.Draw(img)
        
        f = ImageFont.truetype(font='fonts/gill-rb.ttf', size=48)
        d.text((180,36), card['name'], font=f, fill='black')

        f = ImageFont.truetype(font='fonts/gill-rb.ttf', size=18)
        d.text((556, 68), 'HP', font=f, fill='black')

        f = ImageFont.truetype(font='fonts/futura-cb.ttf', size=44)
        d.text((582, 42), str(card['hp']), font=f, fill='black')
        
        f = ImageFont.truetype(font='fonts/futura-cb.ttf', size=30)
        if card['weaknesses']:
            energy_img = get_energy_img(card['weaknesses']['type'], 'weakness')
            img.paste(energy_img, (65, 888), energy_img)
            d.text((100, 890), card['weaknesses']['value'], font=f, fill='black')
        if card['resistances']:
            energy_img = get_energy_img(card['resistances']['type'], 'weakness')
            img.paste(energy_img, (195, 888), energy_img)
            d.text((230, 890), card['resistances']['value'], font=f, fill='black')
        
        full_img = Image.open('symbols.png')
        if card['retreat_cost']:  # None when the card has no retreat cost
            retreat_img = full_img.crop((517, 433, 517+32*len(card['retreat_cost']), 463))
            img.paste(retreat_img, (150, 938), retreat_img)

        start_height = 560
        if 'ability' in card:
            ability = card['ability']
            ability_text = textwrap.fill(ability['text'] or '', width=48)
            
            ability_img = full_img.crop((50, 433, 212, 475))
            img.paste(ability_img, (60, start_height+5), ability_img)
            
            f = ImageFont.truetype(font='fonts/gill-cb.ttf', size=44)
            d.text((240, start_height), ability['name'], font=f, fill='#c23600')
            
            f = ImageFont.truetype(font='fonts/gill-rp.ttf', size=30)
            d.multiline_text((60, start_height+54), ability_text, font=f, fill='black')
            
            start_height += 80 + d.multiline_textsize(ability_text, font=f)[1]
        if 'attacks' in card:
            for attack in card['attacks']:
                if start_height >= 760:
                    break
                cost = attack['cost'] or []  # char_to_type returns None for an empty cost
                attack_text = textwrap.fill(attack['text'] or '', width=48)
                
                for n, energy in enumerate(cost):
                    energy_img = get_energy_img(energy, 'attack')
                    img.paste(energy_img, (60+n*45, start_height), energy_img)
                
                # name starts right after the last cost symbol (x=70 when there is no cost)
                f = ImageFont.truetype(font='fonts/gill-cb.ttf', size=44)
                d.text((70+len(cost)*45, start_height), attack['name'], font=f, fill='black')

                f = ImageFont.truetype(font='fonts/futura-cb.ttf', size=44)
                d.text((612, start_height), attack['damage'], font=f, fill='black')

                f = ImageFont.truetype(font='fonts/gill-rp.ttf', size=30)
                d.multiline_text((60, start_height+54), attack_text, font=f, fill='black')

                start_height += 80 + d.multiline_textsize(attack_text, font=f)[1]
        
    elif card['supertype'] == 'Trainer':
        img = Image.open(os.path.join(template_path,
            card['supertype'], (card['subtype'].replace(' ','_') if card['subtype'] else 'Supporter')+'.png'))
        d = ImageDraw.Draw(img)
        
        f = ImageFont.truetype(font='fonts/gill-rb.ttf', size=44)
        d.text((85,105), card['name'], font=f, fill='black')
        
        f = ImageFont.truetype(font='fonts/gill-rp.ttf', size=30)
        d.multiline_text((95, 570), textwrap.fill(card['text'] if card['text'] else '', width=42), font=f,fill='black')
    else:
        img = Image.open(os.path.join(template_path,
            card['supertype'], (card['subtype'].replace(' ','_') if card['subtype'] else 'Supporter')+'.png'))
        d = ImageDraw.Draw(img)
        
        f = ImageFont.truetype(font='fonts/gill-rb.ttf', size=30)
        d.text((80,100), 'Special Energy', font=f, fill='black')
        
        f = ImageFont.truetype(font='fonts/gill-rb.ttf', size=40)
        d.text((60,655), card['name'], font=f, fill='black')
        
        f = ImageFont.truetype(font='fonts/gill-rp.ttf', size=30)
        d.multiline_text((60, 720), textwrap.fill(card['text'] or '', width=48), font=f, fill='black')
    
    background = Image.open('holosheet.jpg')
    background.paste(img, (0, 0), img)
    img = background

    img.thumbnail((512,512))
    return img

In [None]:
# turns string into filename
def slugify(value):
    value = unidecode.unidecode(value)
    value = re.sub(r'[^\w\s-]', '', value).strip()
    value = re.sub(r'[-\s]+', '-', value)
    return value

In [None]:
for card in card_data:
    try:
        print(card['name'])
        img = gen_card_img(card)
        img.save(os.path.join(save_path, slugify(card['name'])+'.jpg'))
    except Exception as e:  # malformed generated cards are expected; report and move on
        print('skipped a card:', e)

<img src="https://cdn.bulbagarden.net/upload/2/21/001Bulbasaur.png" alt="Bulbasaur Pic" style="width: 256px;"/>

# Generate Card Images

1. Download Ken Sugimori art from [Bulbapedia](https://archives.bulbagarden.net/wiki/Category:Ken_Sugimori_Pok%C3%A9mon_artwork) and augment it
2. Use [DRAGAN](https://github.com/kodalinaveen3/DRAGAN) to generate new pokemon art
3. Use PIL to combine random char-rnn generated card properties and DRAGAN generated art, using [templates](https://pokemoncardresources.deviantart.com/gallery/51274687/Resources-Classic)

## Download from Bulbapedia

In [None]:
cat_name = 'Ken_Sugimori_Pokémon_artwork'
data_dir = '/home/ubuntu/fastai-data/pokemon_img'
png_path = os.path.join(data_dir, 'pngs')
img_width = 256

In [None]:
import mwclient, requests, shutil

site = mwclient.Site('archives.bulbagarden.net')
category = site.Categories[cat_name]
filenames = (x.page_title for x in category.members(namespace=6))
for file in filenames:
    # ask for a thumbnail of width img_width; fall back to the full-size file if that fails
    file_url = 'http://archives.bulbagarden.net/w/index.php?title=Special:FilePath&file={}&width={}'.format(
            file, img_width)
    r = requests.get(file_url, stream=True)
    if r.status_code != 200:
        print('Requested width is bigger than source - downloading full size')
        file_url = 'http://archives.bulbagarden.net/w/index.php?title=Special:FilePath&file={}'.format(file)
        r = requests.get(file_url, stream=True)
    else:
        print('Thumbnail found')
    if r.status_code == 200:
        print('Saving file '+file)
        output_filepath = os.path.join(png_path, file.replace(' ','_'))
        with open(output_filepath, 'wb') as f:
            r.raw.decode_content = True
            shutil.copyfileobj(r.raw, f)

## Create GAN

- Basically copied [makegirlsmoe](https://makegirlsmoe.github.io/assets/pdf/technical_report.pdf)'s architecture and the code from [PyTorch DRAGAN](https://github.com/jfsantos/dragan-pytorch/blob/master/dragan.py) / [Keras WGAN-GP](https://github.com/farizrahman4u/keras-contrib/blob/master/examples/improved_wgan.py), because I don't know what I'm doing that well.
- GANs (Generative Adversarial Networks) use a competition between a generator network and a discriminator network to gradually make the generator's output match existing samples
    - GANs are inefficient in pure Keras; probably good to learn PyTorch sometime
    - Also, next time, don't combine learning a new generator architecture (ResNet), a new concept (GANs), and an experimental GAN architecture (DRAGAN) all at once. Especially don't do it while training on the entire data set rather than on samples.
    - Every batch, first train the discriminator network to correctly classify real and generated images
        - the discriminator is a convolutional image classifier that labels images as real or fake (binary)
        - the discriminator's training model is actually the generator stacked with the discriminator, with only the discriminator weights trainable. Input a batch of real images and a batch of noise; the noise goes through the generator and then the discriminator, the real images go straight to the discriminator, and the discriminator is optimized to classify real images as real and generated images as fake.
    - Then, train the generator network to fool the discriminator network
        - the generator takes a noise vector and generates an image. Here a dense layer projects the noise vector to a low-resolution image with many channels (depth), convolutional layers create features, and pixel shuffles increase the resolution by decreasing depth (the current code uses upsampling + convolution instead), until we end up with an image of the right size.
        - the generator's training model is also the generator stacked with the discriminator, but now only the generator weights are trained. Input a batch of noise, feed it through the generator and discriminator, and optimize the generator so the discriminator classifies the generated images as real (by feeding it the opposite labels).
    - Overall, discriminator loss should stay lower than generator loss (so the generator always has a somewhat accurate goal); a discriminator loss of about 0.1-1 and a generator loss of 2-7 may be good. Loss should not converge but reach an equilibrium, as a more accurate discriminator leads to a better generator, which lowers discriminator accuracy, and vice versa.
    - DRAGAN adds a gradient penalty to the discriminator to make it easier to train. It creates perturbed images (real images plus random noise) and adds a loss term that penalizes the squared difference between 1 and the norm of the discriminator's gradient at those perturbed images (should make the discriminator more linear?)
    - GAN tricks is very helpful; follow the PyTorch DRAGAN implementation as closely as possible
        - Remember that example implementations and overview papers often don't show things like dropout
- ResNets use skip connections to make networks more linear and help with training deep networks
    - resblocks have skip connections, where the output of an earlier layer is summed elementwise with the output of a later layer. See the ResNet paper. This helps make deep networks more linear and generalizable
    - pixel shuffle is used to increase resolution by making an image tensor wider (higher res) but shallower (fewer channels). See [subpixel](https://github.com/tetrachrome/subpixel)
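For reference, the penalty and discriminator loss described above can be written out as below, following the `perturb`, `gradient_penalty`, and `discriminator_model` cells further down. The notation is my own shorthand: $x$ is a real image, $\alpha$ is drawn once per image, $u$ once per pixel, $\sigma_x$ is the batch standard deviation (`K.std`), $c$ is the `perturb` scale (0.5 by default), and the 0.4/0.4/0.2 weights are the `loss_weights`. The labels (1 for real, 0 for generated) are assumed from the description above.

$$
\hat{x} = \alpha x + (1-\alpha)\left(x + c\,\sigma_x\,u\right) = x + (1-\alpha)\,c\,\sigma_x\,u,
\qquad \alpha \sim \mathcal{U}(0,1),\quad u \sim \mathcal{U}(0,1)
$$

$$
L_{\mathrm{GP}} = \mathbb{E}_{\hat{x}}\left[\left(\lVert \nabla_{\hat{x}} D(\hat{x}) \rVert_2 - 1\right)^2\right]
$$

$$
L_D = 0.4\,\mathrm{BCE}\big(D(x),\, 1\big) + 0.4\,\mathrm{BCE}\big(D(G(z)),\, 0\big) + 0.2\,L_{\mathrm{GP}}
$$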

In [None]:
# imports
import theano

#theano.config.optimizer_including='alloc_empty_to_zeros'
#theano.config.nvcc.fastmath = False

from itertools import product

import theano.tensor as T
from keras import backend as K
from keras.models import Model
from keras.initializers import *
from keras.layers import *
from keras.engine.topology import Layer
from keras.optimizers import Adam
import itertools, os
import numpy as np

K.set_image_data_format('channels_first')

- Defines some layer functions for use in the generator and discriminator networks
- pixel shuffle from https://github.com/titu1994/Super-Resolution-using-Generative-Adversarial-Networks/blob/master/models.py
- It appears that switching from deconvolution-style upscaling (convolution + pixel shuffle) to upsampling + convolution improves the diversity of results and eliminates checkerboard artifacts?

In [None]:
def pixel_shuffle(input, scale, channels):
    b, k, row, col = input.shape
    output_shape = (b, channels, row * scale, col * scale)
    out = T.zeros(output_shape)
    r = scale
    for y, x in itertools.product(range(scale), repeat=2):
        out = T.inc_subtensor(out[:, :, y::r, x::r], input[:, r * y + x :: r * r, :, :])
    return out

class PixelShuffle(Layer):
    def __init__(self, r, channels, **kwargs):
        super(PixelShuffle, self).__init__(**kwargs)
        self.r = r
        self.channels = channels

    def build(self, input_shape):
        super(PixelShuffle, self).build(input_shape)

    def call(self, x, mask=None):
        return pixel_shuffle(x, self.r, self.channels)

    def compute_output_shape(self, input_shape):
        b, k, r, c = input_shape
        return (b, self.channels, r * self.r, c * self.r)
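The Theano `pixel_shuffle` is hard to test on its own. A NumPy reference with the same channel ordering can be used to check shapes and index mapping by hand; `pixel_shuffle_np` is a hypothetical name. As far as I know this ordering (input channel `c*r*r + r*y + x`) is also the one PyTorch's `nn.PixelShuffle` uses.

```python
import numpy as np


def pixel_shuffle_np(x, r):
    """NumPy reference for the Theano pixel_shuffle above (channels_first).

    Output channel c at sub-pixel offset (y, xo) is read from input channel
    c * r * r + r * y + xo, the same slicing as input[:, r*y + x :: r*r].
    """
    b, k, rows, cols = x.shape
    assert k % (r * r) == 0, 'channel count must be divisible by r*r'
    channels = k // (r * r)
    out = np.zeros((b, channels, rows * r, cols * r), dtype=x.dtype)
    for y in range(r):
        for xo in range(r):
            out[:, :, y::r, xo::r] = x[:, r * y + xo::r * r, :, :]
    return out
```

Each output pixel comes from exactly one input value, so the layer only rearranges data: depth shrinks by `r*r` while height and width grow by `r`.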

In [None]:
# residual block for generative network
def resblock(x):
    skip = x
    
    x = Conv2D(64, 3, padding='same')(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Activation('relu')(x)
    x = Conv2D(64, 3, padding='same')(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Add()([x, skip])
    
    return x

# residual block for discriminator network
def resblock2(x, filters=32, strides=1):
    skip = x
    
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(filters, 3, padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Conv2D(filters, 3, padding='same', strides=strides)(x)
    x = Add()([x, skip])
    
    return x

# superresolution CNN
def sp_cnn(x):
    #x = Conv2D(256, 3, padding='same')(x)
    #x = PixelShuffle(2, 64)(x)
    x = UpSampling2D()(x)
    x = Conv2D(64, 3, padding='same')(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Activation('relu')(x)
    #x = PReLU()(x)
    
    return x

- Creates a super-resolution-ResNet-like generator model and a ResNet discriminator model, roughly following the [technical report](https://makegirlsmoe.github.io/assets/pdf/technical_report.pdf)
- the generator has fewer resblocks to reduce complexity and keep it from converging as fast
- looks more random if activations are removed from sp_cnn blocks?
- Using tanh instead of relu and family seems to make things better as well?
- Adding Gaussian noise to the discriminator's input might stabilize training (currently disabled: `GaussianNoise(0.0)`)

In [None]:
# generator network architecture
def get_generator(x, dim=16, depth=256):
    x = Dense(depth*dim*dim)(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Activation('relu')(x)
    x = Reshape((depth, dim, dim))(x)
    x = Dropout(0.2)(x)
    skip = x = Conv2D(64, 1, padding='same')(x)

    for i in range(12):
        x = resblock(x)
        
    x = Conv2D(64, 3, padding='same')(x)
    x = BatchNormalization(momentum=0.9)(x)
    x = Activation('relu')(x)
    x = Add()([x, skip])
    
    for i in range(3):
        x = sp_cnn(x)
    
    x = Conv2D(3, 9, padding='same')(x)
    x = Activation('tanh')(x)
    
    return x

# discriminator network architecture
def get_discriminator(x):
    x = GaussianNoise(0.0)(x)
    x = Conv2D(32, 4, strides=2, padding='same')(x)
    x = LeakyReLU(alpha=0.2)(x)
    x = Dropout(0.2)(x)
    
    for i in range(5):
        x = resblock2(x, filters=2**(5+i))
        x = resblock2(x, filters=2**(5+i))
        if (i < 2):
            x = Conv2D(2**(6+i), 4, strides=2, padding='same')(x)
        else:
            x = Dropout(0.1)(x)
            x = Conv2D(2**(6+i), 3, strides=2, padding='same')(x)
        x = LeakyReLU(alpha=0.2)(x)

    x = Flatten()(x)
    x = Dense(1)(x)
    x = Activation('sigmoid')(x)
    
    return x
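Every strided convolution in `get_discriminator` uses `padding='same'`, so each one halves the spatial size while `resblock2` and dropout keep shapes unchanged; that makes the tensor reaching `Flatten` easy to work out by hand. A small helper that only does the arithmetic and builds no layers (`discriminator_shapes` is a hypothetical name):

```python
def discriminator_shapes(size=128, in_channels=3):
    """Trace (channels, height, width) through get_discriminator's strided convs.

    With padding='same' and strides=2, each conv maps size -> ceil(size / 2).
    """
    shapes = [(in_channels, size, size)]
    size = -(-size // 2)                  # first Conv2D(32, 4, strides=2)
    shapes.append((32, size, size))
    for i in range(5):                    # loop body: 2 resblocks + 1 strided conv
        size = -(-size // 2)
        shapes.append((2 ** (6 + i), size, size))
    return shapes
```

For 128×128 inputs this ends at 1024 × 2 × 2, i.e. 4,096 features feeding `Dense(1)`.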

In [None]:
g_input = Input(shape=[128])
generator = Model(inputs=g_input, outputs=get_generator(g_input))

test = generator.predict(np.random.normal(size=(8, 128), scale=1))
print('max: {} min: {} mean: {}'.format(np.max(test), np.min(test), np.mean(test)))

In [None]:
d_input = Input(shape=[3, 128, 128])
discriminator = Model(inputs=d_input, outputs=get_discriminator(d_input))

test2 = discriminator.predict(generator.predict(np.random.normal(size=(8, 128), scale=1)))
print('max: {} min: {} mean: {}'.format(np.max(test2), np.min(test2), np.mean(test2)))

- Converts the generator and discriminator networks into the form required for DRAGAN training
    - adds non-trainable discriminator layers on top of the generator, so the generator can be trained to maximize the discriminator's loss
    - adds extra inputs and loss functions to the discriminator, so it trains on real as well as generated examples and uses the DRAGAN gradient penalty
    - the generator should have a lower learning rate than the discriminator (both currently use `Adam(0.0001)`)
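The training cell itself is not part of this section, but the target arrays the two stacked models expect follow from their outputs and losses. A sketch, assuming real images are labeled 1 and generated ones 0 as described earlier (`make_targets` is a hypothetical helper):

```python
import numpy as np


def make_targets(batch_size):
    """Target arrays for one training step of the two stacked models.

    discriminator_model has three outputs: D(real), D(G(noise)), D(perturbed).
    The third loss, gradient_penalty, ignores y_true, so a dummy array is enough.
    generator_model gets the opposite label (1) for its generated images.
    """
    real = np.ones((batch_size, 1), dtype=np.float32)
    fake = np.zeros((batch_size, 1), dtype=np.float32)
    dummy = np.zeros((batch_size, 1), dtype=np.float32)
    d_targets = [real, fake, dummy]
    g_targets = np.ones((batch_size, 1), dtype=np.float32)
    return d_targets, g_targets
```

In a training loop these would presumably be passed as `discriminator_model.train_on_batch([real_batch, noise], d_targets)` followed by `generator_model.train_on_batch(noise, g_targets)`, where `real_batch` and `noise` are placeholders for a batch of images and a `(batch_size, 128)` noise array.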

In [None]:
# discriminator_model trains discriminator with real and generated images
for layer in discriminator.layers:
    layer.trainable = True
for layer in generator.layers:
    layer.trainable = False
discriminator.trainable = True
generator.trainable = False

# blog article suggests perturbing in all directions (uniform in [-1, 1] rather than [0, 1]);
# this version still samples the noise from [0, 1]
def perturb(input, c=0.5):
    # channels_first: (batch, channels, rows, cols)
    b, k, row, col = input.shape
    alpha = K.repeat_elements(K.repeat_elements(K.repeat_elements(
        K.random_uniform((b, 1, 1, 1), 0, 1), k, 1), row, 2), col, 3)
    # equivalent to input + (1-alpha) * c * std * noise
    x_hat = alpha*input + (1-alpha)*(input + c * K.std(input) * K.random_uniform((b, k, row, col), 0, 1))
    return x_hat

imgs = Input(shape=[3, 128, 128]) # real mini-batch
noise = Input(shape=[128])
p_imgs = Lambda(perturb, output_shape=K.int_shape(imgs)[1:])(imgs) # perturbed mini-batch

# from keras WGAN-GP, though called with randomly perturbed inputs rather than averaged inputs
def gradient_penalty(y_true, y_pred, x_hat=p_imgs):
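    # with the Theano backend, K.gradients returns a single tensor for a single x_hat;
    # the TensorFlow backend returns a list, which would need K.gradients(...)[0]
    gradients = K.gradients(K.sum(y_pred), x_hat)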
    gradient_l2_norm = K.sqrt(K.sum(K.square(gradients), axis=[1,2,3]))
    gradient_penalty = K.mean(K.square(gradient_l2_norm - 1))
    return gradient_penalty

discriminator_model = Model(inputs=[imgs, noise],
                            outputs=[discriminator(imgs), discriminator(generator(noise)), discriminator(p_imgs)])
discriminator_model.compile(optimizer=Adam(0.0001),
                            loss=['binary_crossentropy', 'binary_crossentropy', gradient_penalty],
                            loss_weights=[0.4, 0.4, 0.2])


#perturb_test_model = Model(inputs=imgs, outputs=p_imgs)
#perturb_test_model.compile(optimizer='adam', loss='binary_crossentropy')
#test2 = discriminator_model.predict([generator.predict(np.random.normal(size=(4, 128))),
#                                     np.random.normal(size=(4, 128))])
#print('max: {} min: {} mean: {}'.format(np.max(test2[1]), np.min(test2[1]), np.mean(test2[1])))
discriminator_model.summary()

In [None]:
# generator_model trains generator to create image, optimizes to maximize discriminator loss
for layer in discriminator.layers:
    layer.trainable = False
for layer in generator.layers:
    layer.trainable = True
discriminator.trainable = False
generator.trainable = True

gm_input = Input(shape=[128])
generator_model = Model(inputs=gm_input, outputs=discriminator(generator(gm_input)))
generator_model.compile(optimizer=Adam(0.0001), loss='binary_crossentropy')

#test = generator_model.predict(np.random.normal(size=(4, 128)))
#print('max: {} min: {} mean: {}'.format(np.max(test), np.min(test), np.mean(test)))
generator_model.summary()

## Training Your DRAGAN

In [None]:
from keras.preprocessing import image
from matplotlib import pyplot as plt
import numpy as np
import random, os

# Instantiate plotting tool
%matplotlib inline

In [None]:
def plots(ims, figsize=(12,6), rows=2, interp=False, titles=None):
    if type(ims[0]) is np.ndarray:
        ims = np.array(ims).astype(np.uint8)
        if (ims.shape[-1] != 3):
            ims = ims.transpose((0,2,3,1))
    f = plt.figure(figsize=figsize)
    cols = len(ims)//rows if len(ims) % rows == 0 else len(ims)//rows + 1
    for i in range(len(ims)):
        sp = f.add_subplot(rows, cols, i+1)
        sp.axis('Off')
        if titles is not None:
            sp.set_title(titles[i], fontsize=16)
        plt.imshow(ims[i], interpolation=None if interp else 'none')

In [None]:
data_dir = '/home/ubuntu/fastai-data/pokemon_img'
train_path = os.path.join(data_dir, 'train')
temp_path = os.path.join(data_dir, 'temp')
result_path = os.path.join(data_dir, 'results')
batch_size = 32

In [None]:
def get_batches(dirname, temp_dir=None, shuffle=True, batch_size=batch_size):
    gen = image.ImageDataGenerator(preprocessing_function=lambda x: (x - 127.5)/128,
                                  horizontal_flip=True,
                                  width_shift_range=0.09,
                                  height_shift_range=0.09,
                                  zoom_range=[1.1, 1.65],
                                  shear_range=0.18,
                                  fill_mode='constant',
                                  cval=255,
                                  channel_shift_range=8)
    return gen.flow_from_directory(dirname,
                                  target_size=(128,128),
                                  class_mode='binary',
                                  color_mode='rgb',
                                  shuffle=shuffle,
                                  save_to_dir=temp_dir,
                                  batch_size=batch_size)

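# save_to_dir writes every augmented image to temp_path, including during training;
# switch to the commented-out call once the augmentation looks right
batches = get_batches(train_path, temp_dir=temp_path)
#batches = get_batches(train_path)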

for i in range(5000):
    batch, labels = next(batches)
print('max: {} min: {} mean: {}'.format(np.max(batch), np.min(batch), np.mean(batch)))
plots([image.load_img(os.path.join(temp_path, img)) for img in random.sample(os.listdir(temp_path), 4)])
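`get_batches` maps pixels from [0, 255] into roughly [-1, 1] with `(x - 127.5)/128`, while `generate_images` maps generator outputs back with `x * 127.5 + 127.5`. The two scales differ slightly, so here is a quick NumPy check that the round trip still recovers every integer pixel value: the mismatch is at most about 0.498 of a gray level, which rounding absorbs. The helpers `to_model_range` and `to_pixels` are hypothetical, and the `np.clip` is an addition that only matters if the generator can produce values outside [-1, 1].

```python
import numpy as np

def to_model_range(pixels):
    """Same mapping as the preprocessing_function in get_batches."""
    return (pixels - 127.5) / 128

def to_pixels(outputs):
    """Same mapping as generate_images, plus a clip so out-of-range
    values saturate instead of wrapping around when cast to uint8."""
    return np.clip(np.round(outputs * 127.5 + 127.5), 0, 255).astype(np.uint8)

pixels = np.arange(256, dtype=np.float64)
round_trip = to_pixels(to_model_range(pixels))
max_error = int(np.max(np.abs(round_trip.astype(int) - pixels.astype(int))))
print(to_model_range(pixels).min(), to_model_range(pixels).max(), max_error)
```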

In [None]:
from PIL import Image

def tile_images(image_stack):
    """Place a batch of images side by side in one (H, N*W, 3) array."""
    assert len(image_stack.shape) == 4
    if image_stack.shape[-1] != 3:
        # channels_first (N, 3, H, W) -> channels_last (N, H, W, 3), as in plots()
        image_stack = image_stack.transpose((0, 2, 3, 1))
    image_list = [image_stack[i, :, :, :] for i in range(image_stack.shape[0])]
    # concatenate along the width axis so the images sit in a single row
    return np.concatenate(image_list, axis=1)

def generate_images(generator, output_dir, epoch):
    test_image_stack = generator.predict(np.random.normal(size=(8, 128), scale=1)) 
    test_image_stack = (test_image_stack * 127.5) + 127.5
    test_image_stack = np.squeeze(np.round(test_image_stack).astype(np.uint8))
    tiled_output = tile_images(test_image_stack)
    tiled_output = Image.fromarray(tiled_output, mode='RGB')
    outfile = os.path.join(output_dir, 'epoch_{}.png'.format(epoch))
    tiled_output.save(outfile)
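`tile_images` lays the samples out in a single row. If a grid is easier to inspect, a reshape plus transpose does the same job. This is a sketch of a hypothetical `tile_grid` helper, assuming a channels_last batch of shape (N, H, W, C) with N = rows * cols.

```python
import numpy as np

def tile_grid(image_stack, rows, cols):
    """Arrange a (rows*cols, H, W, C) batch into one (rows*H, cols*W, C) image."""
    n, h, w, c = image_stack.shape
    assert n == rows * cols, 'batch size must equal rows * cols'
    grid = image_stack.reshape(rows, cols, h, w, c)
    # move each image's row axis next to the grid-row axis, then merge them
    grid = grid.transpose(0, 2, 1, 3, 4)  # (rows, h, cols, w, c)
    return grid.reshape(rows * h, cols * w, c)

demo_stack = np.random.randint(0, 256, size=(8, 128, 128, 3)).astype(np.uint8)
demo_grid = tile_grid(demo_stack, rows=2, cols=4)
print(demo_grid.shape)  # (256, 512, 3)
```

With a channels_last `uint8` stack, the result can go straight into `Image.fromarray(..., mode='RGB')`, as in `generate_images`.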

In [None]:
#import dcgan
#generator = dcgan.generator_model()
#generator.compile(optimizer=Adam(0.0001), loss='binary_crossentropy')

test_image_stack = generator.predict(np.random.normal(size=(8, 128), scale=1))
test_image_stack = (test_image_stack * 127.5) + 127.5
test_image_stack = np.squeeze(np.round(test_image_stack).astype(np.uint8))
tiled_output = tile_images(test_image_stack)
tiled_output = Image.fromarray(tiled_output, mode='RGB')
outfile = os.path.join(temp_path, 'perturb_test.png')
tiled_output.save(outfile)

plots([image.load_img(outfile)])

- The consensus from [Animeface-GAN](https://github.com/forcecore/Keras-GAN-Animeface-Character) and others seems to be that a generator loss of 2-4 and a discriminator loss of around 0.3 indicate healthy training
- One-sided label noise and two-sided label smoothing for the discriminator, with no smoothing or noise for the generator, seem to work well for me
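The discriminator label scheme described above can be sketched as a small helper. `discriminator_targets` is hypothetical, and its default ranges are copied from the training loop: noise is applied to the real labels only, while both sides are smoothed away from 0 and 1. The second part turns the quoted generator-loss range into discriminator outputs. It assumes plain targets of 1, so it is only approximate for the loop below, which draws generator targets from 0.85 to 1.05.

```python
import numpy as np

def discriminator_targets(batch_size, rng, real_range=(0.7, 1.2), fake_value=0.05):
    """Hypothetical helper: real targets are noisy and smoothed (uniform draw),
    fake targets are smoothed to a constant with no noise."""
    real = rng.uniform(real_range[0], real_range[1], size=(batch_size, 1))
    fake = np.full((batch_size, 1), fake_value)
    return real, fake

rng = np.random.default_rng(0)
real_y, fake_y = discriminator_targets(32, rng)

# With all-ones targets the generator's BCE is mean(-log D(G(z))), so
# exp(-loss) is the geometric mean of D's output on generated images.
for g_loss in (2.0, 4.0):
    print('G loss {:.1f} -> D(G(z)) ~ {:.3f}'.format(g_loss, np.exp(-g_loss)))
```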

In [None]:
batches_per_epoch = int(1426 / batch_size)

true_positive_y = np.ones((batch_size, 1))
true_negative_y = np.zeros((batch_size, 1))
dummy_y = np.zeros((batch_size, 1)) + 0.5

epoch = 0
if epoch > 0:
    generator_model.load_weights(os.path.join(result_path, 'gan_g_weights{}.h5'.format(epoch)))
    discriminator_model.load_weights(os.path.join(result_path, 'gan_d_weights{}.h5'.format(epoch)))
    discriminator_loss = np.loadtxt(os.path.join(result_path, 'gan_d_loss_history.csv'))
    discriminator_loss = list(np.broadcast_to(
        np.expand_dims(discriminator_loss, axis=1),(discriminator_loss.shape[0], 4)))
    generator_loss =  list(np.loadtxt(os.path.join(result_path, 'gan_g_loss_history.csv')))
    batch_num = len(discriminator_loss) + 1
else:
    discriminator_loss = []
    generator_loss = []
    batch_num = 0

epoch += 1
print('Epoch ' + str(epoch) + '\n')
for batch, labels in batches: # folder labels not actually used
    
    if batch_num % batches_per_epoch == 0 and batch_num > 0:
        generate_images(generator, result_path, epoch)
        if generator_loss[-1] > 0.5:
            try:
                generator_model.save_weights(os.path.join(result_path, 'gan_g_weights{}.h5'.format(epoch)))
                discriminator_model.save_weights(os.path.join(result_path, 'gan_d_weights{}.h5'.format(epoch)))
            except Exception as e:
                print('Weights could not be saved: {}'.format(e))
        epoch += 1
        print('\nEpoch ' + str(epoch))
    
    if len(batch) == batch_size:
        # noisy, smoothed labels for real images; fixed smoothed label (0.05) for generated ones
        positive_y = np.random.uniform(0.7, 1.2, size=(batch_size, 1))
        negative_y = np.zeros((batch_size, 1))+0.05
    
        # train discriminator with real, generated, and perturbed images
        noise = np.random.normal(size=(batch_size, 128), scale=1)
        discriminator_loss.append(
            discriminator_model.train_on_batch([batch, noise],[positive_y, negative_y, dummy_y]))
        
        print('D. Loss | Total: ' + str(discriminator_loss[-1][0])
              + ' | Real: ' + str(discriminator_loss[-1][1])
              + ' | Fake: ' + str(discriminator_loss[-1][2])
              + ' | Penalty: ' + str(discriminator_loss[-1][3]))
        np.savetxt(os.path.join(result_path, 'gan_d_loss_history.csv'), np.asarray(discriminator_loss)[:,0])

        # train generator to maximize discriminator loss
        noise2 = np.random.normal(size=(batch_size, 128), scale=1)
        positive_y = np.random.uniform(0.85, 1.05, size=(batch_size, 1))
        generator_loss.append(
            generator_model.train_on_batch(noise2, positive_y))
            
        print('G. Loss: '+  str(generator_loss[-1]))
        np.savetxt(os.path.join(result_path, 'gan_g_loss_history.csv'), np.asarray(generator_loss))
    batch_num += 1
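The resume branch at the top of the training cell reloads only the total discriminator loss, because column 0 is all the loop writes. It then broadcasts each total to four columns so that `np.asarray(discriminator_loss)[:, 0]` still works once new four-element entries are appended. A minimal round trip through a temporary file with made-up loss values shows the shapes involved. The `ndmin=1` argument is an addition: without it, a history file holding a single value would load as a 0-d array and break `expand_dims(..., axis=1)`.

```python
import os
import tempfile
import numpy as np

# illustrative history: one [total, real, fake, penalty] row per batch
demo_history = [[1.0, 0.9, 0.8, 0.5],
                [0.9, 0.8, 0.7, 0.4],
                [0.8, 0.7, 0.6, 0.3]]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'gan_d_loss_history.csv')
    # the training loop saves only column 0, the total loss
    np.savetxt(path, np.asarray(demo_history)[:, 0])
    # ndmin=1 keeps a one-row file from loading as a 0-d array
    totals = np.loadtxt(path, ndmin=1)

# repeat each total across 4 columns so the history stays rectangular
restored = list(np.broadcast_to(np.expand_dims(totals, axis=1), (totals.shape[0], 4)))
restored.append(np.array([0.7, 0.6, 0.5, 0.2]))  # a new batch after resuming
print(np.asarray(restored)[:, 0])  # totals only, as written by np.savetxt
```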

## Compare with HyperGAN card images because Keras is soooo slow

- Seriously, raw TensorFlow seems to be at least 10 times faster than Keras for this. Keras also keeps making me tweak the hyperparameters every few epochs to avoid NaNs. I really should learn PyTorch soon. Maybe my generator and discriminator architectures are overkill, but the double evaluation of each model and the gradient penalty are very inefficient in Keras.

- It looks like HyperGAN isn't very effective for this task either: the discriminator or generator loss goes to 0 after just a few epochs, no matter which hyperparameters I try. It may need more data augmentation, which is easier to do in Keras.