In [None]:
# Imports
import tensorflow as tf
from tensorflow.keras.layers import (Dense, 
                                     BatchNormalization, 
                                     LeakyReLU, 
                                     Reshape, 
                                     Conv2DTranspose,
                                     Conv2D,
                                     Dropout,
                                     Flatten)
import numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets
import time
from IPython import display as display_console # A command shell for interactive computing in Python.
import asyncio
import logging

In [None]:
# ---- Configuration constants ----

# Side length, in pixels, of the square single-channel images.
IMAGE_SIZE = 8
# A quarter of the side length: the generator's starting feature-map size,
# which its two stride-2 transposed convolutions then upsample by 4x.
IMAGE_SIZE_O4 = IMAGE_SIZE // 4

# Number of training-loop iterations and the generator mini-batch size.
MAX_EPOCHS = 200000
BATCH_SIZE = 450

# Length of the random noise vector fed to the generator.
NOISE_DIM = 100

# Adam learning rates for the two networks.
GENERATOR_LEARNING_RATE = 1e-4
JUDGE_LEARNING_RATE = 1e-3

# How many star-rating buttons are shown, and how many of the most recent
# human ratings are used for each judge update.
RATING_BUTTONS = 5
RATING_MEMORY = 15

In [None]:
# Define the generator model
def generator_model():
    """Build the generator network.

    The model maps a noise vector of length ``NOISE_DIM`` to a single-channel
    image whose side is ``4 * IMAGE_SIZE_O4`` pixels (equal to ``IMAGE_SIZE``
    when ``IMAGE_SIZE`` is a multiple of 4). The final layer uses ``tanh``, so
    pixel values lie in [-1, 1].

    Returns:
        tf.keras.Sequential: the (uncompiled) generator model.
    """
    # Layers are applied in linear order
    model = tf.keras.Sequential()

    # Dense projection of the noise to IMAGE_SIZE_O4 * IMAGE_SIZE_O4 * 256 units, no bias.
    # The input width is NOISE_DIM so it always matches the noise sampled by the callers.
    model.add(Dense(IMAGE_SIZE_O4 * IMAGE_SIZE_O4 * 256, use_bias=False, input_shape=(NOISE_DIM,)))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # Reshape the flat vector into a 2D feature map with 256 channels
    model.add(Reshape((IMAGE_SIZE_O4, IMAGE_SIZE_O4, 256)))

    # Transposed convolution with stride 1: spatial size is unchanged
    model.add(Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # Transposed convolution with stride 2: doubles the spatial size
    model.add(Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    model.add(BatchNormalization())
    model.add(LeakyReLU())

    # Final stride-2 transposed convolution: doubles the size again, one output channel
    model.add(Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))

    # summary() prints the table itself and returns None, so it is not wrapped in print()
    model.summary()

    return model

# Build the generator once at module level; generate_one_image() and
# generator_train_step() both use this global instance.
generator = generator_model()

In [None]:
# Generates one image from the generator to use as an example
def generate_one_image():
  """Sample one noise vector and return the single image the generator produces for it.

  The generator is run in inference mode (``training=False``); the batch
  dimension is stripped from the result.
  """
  latent = tf.random.normal([1, NOISE_DIM])
  return generator(latent, training=False)[0]

In [None]:
def judge_model():
    """Build the judge network that predicts a rating for an image.

    The model takes ``IMAGE_SIZE`` x ``IMAGE_SIZE`` single-channel images and
    outputs one unbounded scalar per image (no final activation).

    Returns:
        tf.keras.Sequential: the (uncompiled) judge model.
    """
    model = tf.keras.Sequential()

    # Strided convolution followed by LeakyReLU and dropout
    model.add(Conv2D(64, (5, 5), strides=(2, 2), padding='same', input_shape=[IMAGE_SIZE, IMAGE_SIZE, 1]))
    model.add(LeakyReLU())
    model.add(Dropout(0.3))  # Randomly zeroes 30% of activations during training to limit overfitting

    # Second convolution stage, same pattern with more filters
    model.add(Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(LeakyReLU())
    model.add(Dropout(0.3))

    # Flatten to a 1D vector
    model.add(Flatten())

    # Single scalar output representing the predicted rating
    model.add(Dense(1))

    # summary() prints the table itself and returns None, so it is not wrapped in print()
    model.summary()

    return model

# Build the judge once at module level; generator_train_step() and
# judge_train_step() both use this global instance.
judge = judge_model()

In [None]:
def judge_loss(predicted_output, human_output):
    """Mean squared error between the judge's predictions and the human ratings.

    Args:
        predicted_output: ratings predicted by the judge.
        human_output: ratings supplied by the user (the regression target).

    Returns:
        Scalar loss tensor.
    """
    mse = tf.keras.losses.MeanSquaredError()
    # Keras losses are called as loss(y_true, y_pred): the human rating is the target.
    return mse(human_output, predicted_output)

def generator_loss(fake_output):
    """Mean squared error between the judge's scores for generated images and 1.

    The generator's target is always one, which corresponds to the highest rating.
    """
    best_rating = tf.ones_like(fake_output)
    mse = tf.keras.losses.MeanSquaredError()
    return mse(best_rating, fake_output)

# One Adam optimizer per network.
# The judge gets the larger learning rate so it can make the most of the
# comparatively scarce human feedback.
generator_optimizer = tf.keras.optimizers.Adam(learning_rate=GENERATOR_LEARNING_RATE)
judge_optimizer = tf.keras.optimizers.Adam(learning_rate=JUDGE_LEARNING_RATE)

In [None]:
# Generator training step
@tf.function
def generator_train_step():
    """Run one optimization step of the generator against the current judge.

    Samples ``BATCH_SIZE`` noise vectors, generates images, scores them with
    the judge, and updates only the generator's variables so that the scores
    move toward 1.

    Returns:
        Scalar tensor with the generator loss for this batch.
    """
    # Give the generator random noise
    noise = tf.random.normal([BATCH_SIZE, NOISE_DIM])

    # Only the forward pass needs to be recorded on the tape
    with tf.GradientTape() as tape:
        # Generate the fake images for the mini-batch
        generated_images = generator(noise, training=True)

        # Judge the output
        # NOTE(review): training=True keeps the judge's Dropout active even though
        # the judge is not updated in this step -- confirm this is intended.
        judge_output = judge(generated_images, training=True)

        loss = generator_loss(judge_output)

    # Gradients and the optimizer update are computed outside the tape context
    # so they are not themselves recorded.
    gradients = tape.gradient(loss, generator.trainable_variables)
    generator_optimizer.apply_gradients(zip(gradients, generator.trainable_variables))

    return loss

In [None]:
# Judge training step
@tf.function
def judge_train_step(images, labels):
    """Run one optimization step of the judge on human-rated images.

    Args:
        images: batch of rated images, shaped as the judge's input expects.
        labels: the human rating for each image, one value per row.

    Returns:
        Scalar tensor with the judge loss for this batch.
    """
    # Only the forward pass needs to be recorded on the tape
    with tf.GradientTape() as tape:
        # Score the rated images with the judge
        judge_output = judge(images, training=True)

        loss = judge_loss(judge_output, labels)

    # Gradients and the optimizer update are computed outside the tape context
    # so they are not themselves recorded.
    gradients = tape.gradient(loss, judge.trainable_variables)
    judge_optimizer.apply_gradients(zip(gradients, judge.trainable_variables))

    return loss

In [None]:
# ====================
# Define training loop
# ====================

# Training loop
async def train(critical_data):
  """Background training loop that interleaves generator and judge updates.

  Each iteration runs one generator step, yields to the event loop so widget
  callbacks can run, merges newly rated examples into the training set,
  publishes a fresh image for rating when none is pending, and (once at least
  two rated examples exist) runs one judge step on the most recent
  ``RATING_MEMORY`` examples.

  Args:
    critical_data: dict of state shared with the UI callbacks. Keys used here:
      'user_data_lock' (asyncio.Lock guarding the other entries),
      'user_data_images_queue' / 'user_data_labels_queue' (new ratings),
      'user_data_images' / 'user_data_labels' (accumulated training set),
      'current_image' (image awaiting a rating, or None).
  """
  for _epoch in range(MAX_EPOCHS):
    # Do one generator training batch
    generator_train_step()

    # Total hack, but this sleep is necessary to allow Jupyter widget button presses to work
    await asyncio.sleep(0.05)

    # Pick up new user ratings, but skip this round rather than wait if the lock is busy
    if not critical_data['user_data_lock'].locked():
      # 'async with' guarantees the lock is released even if something below raises
      async with critical_data['user_data_lock']:
        # Move all queued entries into the main dataset and start fresh queues
        critical_data['user_data_images'].extend(critical_data['user_data_images_queue'])
        critical_data['user_data_images_queue'] = []

        critical_data['user_data_labels'].extend(critical_data['user_data_labels_queue'])
        critical_data['user_data_labels_queue'] = []

        # Publish a new image for the user if the previous one has been rated.
        # 'is None' because the stored value may be a tensor, which overloads '=='.
        if critical_data['current_image'] is None:
          critical_data['current_image'] = generate_one_image()

    # Do a judge training step once we have at least two rated examples
    if len(critical_data['user_data_images']) >= 2 and len(critical_data['user_data_labels']) >= 2:
      recent_images = critical_data['user_data_images'][-RATING_MEMORY:]
      recent_labels = critical_data['user_data_labels'][-RATING_MEMORY:]

      # Batch the examples as plain tensors; no trainable tf.Variable is needed for input data.
      # The image shape follows IMAGE_SIZE so it stays in sync with the judge's input.
      image_tensor = tf.reshape(tf.stack(recent_images), [-1, IMAGE_SIZE, IMAGE_SIZE, 1])
      label_tensor = tf.reshape(tf.constant(recent_labels, dtype=tf.float32), [-1, 1])
      judge_train_step(image_tensor, label_tensor)

# Handle task errors
def check_for_errors(task: asyncio.Task) -> None:
    """Done-callback that logs any exception raised by a finished task.

    Fetching the task's result re-raises whatever the task raised. Ordinary
    exceptions are logged with their traceback; a cancellation is silently
    ignored.

    Args:
        task: the asyncio task to inspect.
    """
    try:
        task.result()
    except (asyncio.CancelledError, Exception) as outcome:
        # Task cancellation should not be logged as an error
        if not isinstance(outcome, asyncio.CancelledError):
            logging.exception('Exception raised by task = %r', task)

# ==========
# Start task
# ==========

# State shared between the training task and the UI callbacks.
current_image = None                                     # image awaiting a rating, if any
user_data_images_queue, user_data_labels_queue = [], []  # ratings not yet merged by train()
user_data_images, user_data_labels = [], []              # accumulated training set
user_data_lock = asyncio.Lock()                          # guards all of the above

critical_data = dict(
  current_image=current_image,
  user_data_images_queue=user_data_images_queue,
  user_data_labels_queue=user_data_labels_queue,
  user_data_images=user_data_images,
  user_data_labels=user_data_labels,
  user_data_lock=user_data_lock,
)

# Schedule the training loop as a background task; check_for_errors logs any
# exception it ends with.
training_task = asyncio.create_task(train(critical_data))
training_task.add_done_callback(check_for_errors)


# ==========
# User input
# ==========

# Create all the clickable buttons for the user
async def create_buttons():
    """Display one clickable "N Star" button for each rating level."""

    def _handler_for(stars, value):
        # Bind this button's own rating now; the loop variables change afterwards.
        return lambda _button: press_rating_button(stars, value)

    for stars in range(1, RATING_BUTTONS + 1):
        # Map 1..RATING_BUTTONS stars linearly onto 0.0..1.0, the value used in the loss function
        value = (stars - 1) / (RATING_BUTTONS - 1)

        # Create and display the button
        button = widgets.Button(description=(str(stars) + " Star"))
        button.on_click(_handler_for(stars, value))
        display(button)

# Show image
async def show_image(image):
    """Draw ``image`` in grayscale, without axes, in the first cell of a 4x4 subplot grid."""
    # Rescale from the generator's tanh range [-1, 1] to [0, 255]
    pixels = image * 127.5 + 127.5

    plt.subplot(4, 4, 1)
    plt.imshow(pixels, cmap='gray')
    plt.axis('off')
    plt.show()

# TODO: Save image button

def press_rating_button(rating, rating_value):
    """Synchronous click entry point: schedule the async rating handler as a task.

    Any exception raised by the handler is logged through check_for_errors.
    """
    asyncio.create_task(
        press_rating_button_results(rating, rating_value)
    ).add_done_callback(check_for_errors)

async def press_rating_button_results(rating, rating_value):
    """Record the user's rating for the current image and move on to the next one.

    Args:
        rating: star count shown to the user (used only in the printed message).
        rating_value: numeric rating stored as the training label.
    """
    # 'async with' guarantees the lock is released even if something below raises
    async with critical_data['user_data_lock']:
        rated_image = critical_data['current_image']
        if rated_image is None:
            # Stale click (e.g. a second press before the buttons were replaced):
            # there is no image to rate, and queuing None would break judge training.
            return

        # Save this result to the user data queue
        critical_data['user_data_images_queue'].append(rated_image)
        critical_data['user_data_labels_queue'].append(rating_value)
        critical_data['current_image'] = None

    # Clear console
    display_console.clear_output(wait=True)

    # Print a holdover text
    print("You rated that image", rating, "star.")

    # Move on to next image
    await next_image()

# Wait for an image for the human to rate and then set up the image and buttons
async def next_image():
    """Wait until an image is available for rating, then show it with the rating buttons."""
    # 'is None' because the stored value may be a tensor, which overloads '=='
    while critical_data['current_image'] is None:
        # If no image has been posted yet, sleep to wait for a new one
        print("No image; sleeping...") # DEBUG
        await asyncio.sleep(1)

    # Clear current display
    display_console.clear_output(wait=True)

    # Show the image and give the user the rating buttons; the coroutine then
    # ends so the button callbacks can take over.
    await show_image(critical_data['current_image'])
    await create_buttons()

# Start user input: schedule the first image/button display as a background
# task; check_for_errors logs any exception it ends with.
button_task = asyncio.create_task(next_image())
button_task.add_done_callback(check_for_errors)