In [4]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

In [5]:
from datetime import datetime
import numpy as np

from tensorflow.keras.optimizers import Adam
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

from data_loader import DataLoader
from model.gan_discriminators import patchgan70
from model.gan_generator import ternausNet16

In [6]:
# Setup data loader
img_sz=(256, 256, 3)
dataset_name = 'facades'
samples_pth = f'images/{dataset_name}'
data_loader = DataLoader(dataset_name=dataset_name, img_res=img_sz[:2])

In [7]:
# Determine patch size for PatchGan
patch = int(img_sz[0] / 2**4)
disc_patch = (patch, patch, 1)

In [8]:
# Setup optimiser
optimizer = Adam(0.0002, 0.5)

In [9]:
# Build discriminator
discriminator = patchgan70(input_size=img_sz)
discriminator.compile(loss='mse', optimizer=optimizer, metrics=['accuracy'])

In [10]:
# Build generator
generator = ternausNet16(input_size=img_sz, output_channels=3, 
                         dropout=True, batch_norm=True)

In [11]:
# Build GAN model by combining generator and 
# non-trainable discriminator

x = Input(shape=img_sz) # condition
z = generator(x)        # generated

# Discriminate without training
discriminator.trainable = False
valid = discriminator([z, x])

combined = Model(inputs=[x], outputs=[valid, z])
combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=optimizer)



In [14]:
# Input images and their conditioning images
img_A = Input(shape=img_sz)
img_B = Input(shape=img_sz)

# By conditioning on B generate a fake version of A
fake_A = generator(img_B)

# For the combined model we will only train the generator
discriminator.trainable = False

# Discriminators determines validity of translated images / condition pairs
valid = discriminator([fake_A, img_B])

combined = Model(inputs=[img_A, img_B], outputs=[valid, fake_A])
combined.compile(loss=['mse', 'mae'], loss_weights=[1, 100], optimizer=optimizer)

In [15]:
import os

def sample_from_model(sample_dir, epoch, batch_i):
    os.makedirs(sample_dir, exist_ok=True)
    r, c = 3, 3
    
    imgs_A, imgs_B = data_loader.load_batch(batch_size=3, is_testing=True)
    fake_A = generator.predict(imgs_B)
    
    gen_imgs = np.concatenate([imgs_B, fake_A, imgs_A])
    
    # Rescale 0-1
    gen_imgs = 0.5 * gen_imgs + 0.5
    
    titles = ['Condition', 'Generated', 'Original']
    fig, axs = plt.subplots(r, c)
    cnt = 0
    
    for i in range(r):
        for j in range (c):
            axs[i, j].imshow(gen_imgs[cnt])
            axs[i, j].set_title(titles[i])
            axs[i, j].axis['off']
            cnt += 1
    fig.save(f'{sample_dir}/{epoch}_{batch_i}.png')
    plt.close()

In [16]:
def train(epochs, batch_size=1, sample_interval=50):
    # Train
    start_time = datetime.now()
    
    # Adversarial loss ground truths
    valid = np.ones((batch_size,) + disc_patch)  # TODO: this patch size will be wrong
    fake  = np.ones((batch_size,) + disc_patch)
    
    for epoch in range(epochs):
        for batch_i, (imgs_A, imgs_B) in enumerate(data_loader.load_batch(batch_size=batch_size)):
            
            # Train Discriminator
            # --------------------------
            
            # generate some fake samples
            fake_A = generator.predict(imgs_B)
            
            # train the discriminator
            d_loss_real = discriminator.train_on_batch([imgs_A, imgs_B], valid)
            d_loss_fake = discriminator.train_on_batch([fake_A, imgs_B], valid)
            d_loss = 0.5 * np.add(d_loss_real, d_loss_fake) # TODO: check math
            
            # Train Generator
            # --------------------------
            
            g_loss = combined.train_on_batch([imgs_A, imgs_B], [valid, imgs_A])
            
            elapsed_time = datetime.now() - start_time
            #print(f'Epoch {epoch}/{epochs} Batch {batch_i} D loss: {d_loss[0]} acc: {100*d_loss[1]} ' )
            print ("[Epoch %d/%d] [Batch %d/%d] [D loss: %f, acc: %3d%%] [G loss: %f] time: %s" % (epoch, epochs,
                                                                        batch_i, self.data_loader.n_batches,
                                                                        d_loss[0], 100*d_loss[1],
                                                                        g_loss[0],
                                                                        elapsed_time))   
            
            if batch_i % sample_interval == 0:
                sample_images(epoch, batch_i)





train(epochs=200, batch_size=1, sample_interval=200)

ValueError: Error when checking target: expected activation to have shape (32, 32, 1) but got array with shape (16, 16, 1)