Parameters to explore:
- width and depth of the NN
- image resolution
- grayscale vs. color
- dropout (may increase the interpretability of the middle layers?)
- L1 vs. L2 loss

In [None]:
import os
from time import time
from functools import lru_cache
import matplotlib.pylab as plt
import numpy as np
from scipy.ndimage import zoom
from IPython import display
from imageio import imread

In [None]:
import tensorflow as tf 
from tensorflow.keras.layers import Dense, InputLayer
from tensorflow.keras import Model

# Report the TensorFlow version and whether TensorFlow can see a GPU device.
print("TF version:", tf.__version__)
print("GPU is", "available" if tf.config.list_physical_devices('GPU') else "NOT AVAILABLE")

In [None]:
# Locate the image folder: on Google Colab, mount Google Drive and read from
# it; anywhere else (google.colab not importable) use a local `images` dir.
try:
    from google.colab import drive
except ModuleNotFoundError:
    # Absolute path, so a later `os.chdir(ROOT)` can be re-run without
    # resolving 'images' relative to the already-changed directory.
    ROOT = os.path.abspath('images')
else:
    drive.mount('/content/drive')
    ROOT = '/content/drive/My Drive/Colab Notebooks/images'

In [None]:
class MyModel(Model):
    """Coordinate MLP mapping a 2-D point to ``n_channels`` values in (0, 1).

    Architecture: ``depth`` ReLU ``Dense`` layers of ``width`` units each,
    followed by a sigmoid ``Dense`` layer with ``n_channels`` outputs.

    Attributes:
        width, depth, n_channels: the constructor arguments, stored so the
            network's shape can be read back later.
        myLayers: the hidden ReLU ``Dense`` layers, in application order.
        last: the sigmoid output ``Dense`` layer.
    """

    def __init__(self, width, depth, n_channels = 3):
        super().__init__()
        self.width = width
        self.depth = depth
        self.n_channels = n_channels
        # Hidden kernels ~ N(0, 1/width); biases ~ N(0, 0.01**2).
        self.myLayers = [
            Dense(
                width, activation='relu', name=f'relu_layer_{i}',
                kernel_initializer=tf.initializers.RandomNormal(
                    stddev=width ** -.5,
                ),
                bias_initializer=tf.initializers.RandomNormal(stddev=0.01),
            )
            for i in range(depth)
        ]
        self.last = Dense(n_channels, activation='sigmoid', name='sigmoid_layer')

    def call(self, x):
        """Apply every hidden layer in order, then the sigmoid output layer."""
        for layer in self.myLayers:
            x = layer(x)
        return self.last(x)


In [None]:
@lru_cache()
def getRaster(width, height):
    """Return a ``(width * height, 2)`` array of grid coordinates in [-1, 1].

    Row ``x * height + y`` holds ``(xs[x], ys[y])`` where
    ``xs = linspace(-1, 1, width)`` and ``ys = linspace(-1, 1, height)``,
    so per-row model outputs can be reshaped to ``(width, height, ...)``.

    Results are memoised by ``lru_cache``: repeated calls with the same
    arguments return the *same* array object, so callers must not modify it.
    """
    xs = np.linspace(-1, 1, width)
    ys = np.linspace(-1, 1, height)
    # indexing='ij' makes the first axis follow xs, so a C-order ravel gives
    # row index x * height + y (x-major), matching the original loop order.
    grid_x, grid_y = np.meshgrid(xs, ys, indexing='ij')
    return np.stack([grid_x.ravel(), grid_y.ravel()], axis=1)

In [None]:
def view(model, width, height, view_h = 5):
    """Draw ``model``'s prediction over a ``width`` x ``height`` grid.

    The prediction on ``getRaster(width, height)`` is reshaped to
    ``(width, height, n_channels)``, so the image has ``width`` rows and
    ``height`` columns. It is drawn on the current matplotlib figure with a
    fixed [0, 1] value range, and the figure is resized to ``view_h`` inches
    tall with a width matching the image's aspect ratio.
    """
    output = model.predict(getRaster(width, height))
    image = np.reshape(output, (width, height, model.n_channels))
    if model.n_channels == 1:
        # Show a single-channel (grayscale) model in gray instead of
        # matplotlib's default colormap.
        plt.imshow(image[:, :, 0], cmap='gray', vmin=0, vmax=1)
    else:
        plt.imshow(image, vmin=0, vmax=1)
    plt.axis('off')
    # `height` columns by `width` rows -> horizontal:vertical = height:width.
    plt.gcf().set_size_inches(view_h / width * height, view_h)

In [None]:
def viewInitField(width = 100, height = 100):
    """Build an untrained ``MyModel(4, 4, 3)``, print its summary and render it.

    Args:
        width, height: grid size passed to ``view`` (previously omitted,
            which made ``view`` raise TypeError).
    """
    model = MyModel(4, 4, 3)
    model.build((None, 2))  # 2 inputs: the (x, y) coordinate
    model.summary()
    view(model, width, height)
# viewInitField()

In [None]:
# Work from the image folder so listed file names can be opened directly
# (train() also writes frames to ../frames relative to this directory).
os.chdir(ROOT)
# os.listdir returns entries in arbitrary order; sort for a stable order.
img_names = sorted(os.listdir())

In [None]:
def _to_rgb(img):
    """Return ``img`` as a (rows, cols, 3) array: gray replicated, alpha dropped.

    Accepts 2-D grayscale, (rows, cols, 1) gray, (rows, cols, 2) gray+alpha,
    (rows, cols, 3) RGB and (rows, cols, 4) RGBA arrays.

    Raises:
        ValueError: for any other shape.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.ndim != 3:
        raise ValueError(f'unsupported image shape {img.shape}')
    n_in = img.shape[2]
    if n_in in (1, 2):
        return np.repeat(img[:, :, :1], 3, axis=2)
    if n_in in (3, 4):
        return img[:, :, :3]
    raise ValueError(f'unsupported image shape {img.shape}')


def loadData(img_name, resolution = 50, to_gray = False):
    """Load an image and turn it into (coordinate, colour) training pairs.

    The image is scaled to [0, 1] and converted to 3-channel RGB. Each
    channel is then resized with ``scipy.ndimage.zoom(order=1)`` so that
    ``sqrt(rows * cols)`` is approximately ``resolution``.

    Args:
        img_name: path of the image file.
        resolution: target geometric-mean side length, in pixels.
        to_gray: if true, targets are the mean of the RGB channels.

    Returns:
        ``(x, y, width, height, img)``.
        - ``width``/``height`` are the resized row/column counts; note that
          ``width`` is the number of rows.
        - ``x`` is ``getRaster(width, height)``.
        - ``y`` is the matching ``(width * height, n_channels)`` target
          array, where ``n_channels`` is 1 if ``to_gray`` else 3.
        - ``img`` is the resized ``(width, height, 3)`` image, with gray
          copied to all three channels when ``to_gray``.

    Raises:
        ValueError: if the image has an unsupported shape.
    """
    raw = imread(img_name)
    if np.issubdtype(raw.dtype, np.integer):
        # Divide by the dtype's maximum so 16-bit images also land in [0, 1]
        # (identical to / 255 for 8-bit images).
        img = raw / np.iinfo(raw.dtype).max
    else:
        img = raw / 255
    img = _to_rgb(img)

    rows, cols = img.shape[:2]
    zoom_k = resolution / (rows * cols) ** .5
    img = np.stack(
        [zoom(img[:, :, c], zoom_k, order=1) for c in range(3)], axis=2,
    )
    width, height = img.shape[:2]

    if to_gray:
        gray = np.mean(img, axis=2)
        img = np.repeat(gray[:, :, np.newaxis], 3, axis=2)
        n_channels = 1
    else:
        n_channels = 3

    x = getRaster(width, height)
    # C-order flatten: row index x * height + y, matching getRaster's order.
    y = img[:, :, :n_channels].reshape(width * height, n_channels).copy()
    return x, y, width, height, img

In [None]:
# Iterator over the image file names; each run of the preview cell consumes one.
previewIter = iter(img_names)

In [None]:
# Run this cell multiple times to preview all data: each run takes the next
# file name from `previewIter`, loads it at resolution 200 and shows it.
try:
  name = next(previewIter)
except StopIteration:
  # Every image has already been previewed.
  print("No more.")
else:
  x, y, w, h, img = loadData(name, 200)
  print(name)
  plt.imshow(img)

In [None]:
# class MyCallback(tf.keras.callbacks.Callback):
#     def __init__(self, width, height):
#         super().__init__()
#         self.width = width
#         self.height = height
#     def on_epoch_begin(self, epoch, logs=None):
#         print(epoch)
#         sleep(1)
#         view(model, self.width, self.height, 5)
#         display.clear_output(wait=True)
#         display.display(plt.gcf())
# #         sleep(.01)

In [None]:
def _loss_label(loss_fn):
    """Return a short plot label for a Keras loss: 'L2', 'L1' or its name."""
    if loss_fn is tf.keras.losses.mean_squared_error:
        return 'L2'
    if loss_fn is tf.keras.losses.mean_absolute_error:
        return 'L1'
    return getattr(loss_fn, '__name__', str(loss_fn))


def _render_frame(fig, run, max_w, max_h, canvas_size):
    """Draw each model's prediction on a (max_w, max_h) grid into its axes, then display ``fig``."""
    raster = getRaster(max_w, max_h)  # same grid for every model
    for model, _x, _y, _w, _h, ax in run:
        image = np.reshape(
            model.predict(raster), (max_w, max_h, model.n_channels),
        )
        if model.n_channels == 1:
            # Grayscale models are shown in gray, not matplotlib's default colormap.
            ax.imshow(image[:, :, 0], cmap='gray', vmin=0, vmax=1)
        else:
            ax.imshow(image, vmin=0, vmax=1)
    fig.set_size_inches(*canvas_size)
    fig.tight_layout()
    display.clear_output(wait=True)
    display.display(fig)


def train(
    img_name, is_gray = False, SPF = 1.5, steps_per_epoch = 32,
    canvas_size = (12, 5),
    resolution = (150,),
    nn_width = (64,), nn_depth = (3,),
    loss = (tf.keras.losses.mean_squared_error,),
):
    """Train a 2-D grid of coordinate MLPs on one image and animate them.

    Exactly two of ``resolution``, ``nn_width``, ``nn_depth`` and ``loss``
    must contain more than one value. In that priority order they become
    the rows and columns of a subplot grid. One model is trained for every
    combination of all four sequences.

    Each loop iteration trains the model with the least accumulated
    ``fit`` wall-clock time for ``steps_per_epoch`` full-image steps.
    Every ``SPF`` seconds of total training time, all models are rendered
    at the largest loaded resolution. The figure is then displayed and
    saved as ``../frames/<n>.jpg``, relative to the current directory,
    which is created if missing.

    The loop runs until interrupted (e.g. "Interrupt kernel" in Jupyter).
    The interrupt is caught so the function returns normally.

    Args:
        img_name: image file passed to ``loadData``.
        is_gray: train single-channel (grayscale) models.
        SPF: seconds of training time between rendered frames.
        steps_per_epoch: gradient steps per ``model.fit`` call.
        canvas_size: figure size in inches, ``(width, height)``.
        resolution, nn_width, nn_depth, loss: values to sweep.

    Returns:
        The ``(model, x, y, w, h, ax)`` tuples built by this call. They are
        also appended to the module-level ``setups`` list, which must exist.

    Raises:
        ValueError: if the number of swept parameters is not exactly two.
    """
    swept = []  # (title, labels) for each parameter with more than one value
    if len(resolution) > 1:
        swept.append(('resolution', list(resolution)))
    if len(nn_width) > 1:
        swept.append(('NN width', list(nn_width)))
    if len(nn_depth) > 1:
        swept.append(('NN depth', list(nn_depth)))
    if len(loss) > 1:
        swept.append(('loss', [_loss_label(l) for l in loss]))
    if len(swept) != 2:
        raise ValueError('comparison limited to 2D, sorry')
    shape = [len(labels) for _, labels in swept]
    (row_title, row_labels), (col_title, col_labels) = swept

    fig, axes = plt.subplots(*shape)
    for ax, label in zip(axes[0, :], col_labels):
        ax.set_title(f'{col_title} = {label}')
    for ax, label in zip(axes[:, 0], row_labels):
        ax.set_ylabel(f'{row_title} = {label}')
    flat_axes = list(axes.ravel())
    for ax in flat_axes:
        ax.tick_params(
            axis='both', which='both',
            bottom=False, top=False,
            labelbottom=False,
            right=False, left=False,
            labelleft=False,
        )

    # One model per (resolution, width, depth, loss) combination. The loop
    # nesting follows the same priority as `swept`, so the models line up
    # with the row-major order of flat_axes.
    n_channels = 1 if is_gray else 3
    run = []
    datasets = []
    iter_axes = iter(flat_axes)
    max_w = max_h = 0
    for r in resolution:
        x, y, w, h, _ = loadData(img_name, r, is_gray)
        if w > max_w:
            max_w, max_h = w, h
        # The whole image is one batch, repeated forever, so fit() can run
        # exactly `steps_per_epoch` steps without running out of data.
        data = tf.data.Dataset.from_tensors(
            (x.astype(np.float32), y.astype(np.float32)),
        ).repeat()
        for nw in nn_width:
            for nd in nn_depth:
                for l in loss:
                    model = MyModel(nw, nd, n_channels)
                    model.compile(optimizer='adam', loss=l)
                    run.append((model, x, y, w, h, next(iter_axes)))
                    datasets.append(data)
    # Keep models reachable from the notebook after the loop is interrupted.
    setups.extend(run)

    os.makedirs('../frames', exist_ok=True)
    age = np.zeros(len(run))  # seconds spent in fit() per model
    epoch = np.zeros(len(run), dtype=np.int32)
    next_render = 0
    render_i = 0
    try:
        while True:
            if next_render < np.sum(age):
                start = time()
                _render_frame(fig, run, max_w, max_h, canvas_size)
                next_render += SPF
                fig.savefig(f'../frames/{render_i}.jpg')
                render_i += 1
                print('Render overhead:', format((time() - start) / SPF, '.1%'))
                print('Epochs:')
                print(np.reshape(epoch, shape))
            # Give the next slice of training to the least-trained model.
            elected = int(np.argmin(age))
            model = run[elected][0]
            start = time()
            model.fit(
                datasets[elected],
                steps_per_epoch=steps_per_epoch,
                epochs=1,
                verbose=0,
            )
            age[elected] += time() - start
            epoch[elected] += 1
    except KeyboardInterrupt:
        # The loop only ends on a user interrupt; fall through to clean up.
        pass
    display.clear_output(wait=True)
    print("ok")
    return run

In [None]:
# `train` appends every (model, x, y, w, h, ax) it builds to this
# module-level list, so the models stay reachable after its endless
# training loop is interrupted.
setups = []

# Compare resolution (rows) against NN width (columns); exactly two of the
# swept arguments may have more than one value.
train(
    'polyak et al.png', 
    is_gray = False, 
    SPF = 40, 
    canvas_size = (7, 5), 
    steps_per_epoch = 128, 
    resolution = [300, 500], 
    nn_width = [128, 256], 
    nn_depth = [6], 
#     nn_depth = [4, 8], 
#     loss = [tf.keras.losses.mean_squared_error, tf.keras.losses.mean_absolute_error], 
)

In [None]:
def inspect(model, resolution = 150):
    """Plot the activation field of every unit of a ``MyModel``.

    The network is evaluated on ``getRaster(resolution, resolution)``.
    - Row 0 shows the two input coordinate fields.
    - Rows 1..depth show each hidden ReLU unit, one column per unit.
    - The last row shows the sigmoid output channels on a fixed [0, 1]
      colour scale.
    """
    activations = [getRaster(resolution, resolution)]
    for relu_layer in model.myLayers:
        activations.append(relu_layer(activations[-1]))
    activations.append(model.last(activations[-1]))

    n_rows = model.depth + 2
    # Enough columns for every hidden unit, both input fields and every
    # output channel. At least 2, so plt.subplots returns a 2-D array.
    n_cols = max(model.width, 2, model.n_channels)
    fig, axes = plt.subplots(n_rows, n_cols)
    for ax in axes.ravel():
        ax.axis('off')

    def draw(i, j, field, absolute = False):
        axes[i, j].imshow(
            np.reshape(field, (resolution, resolution)),
            **({"vmin": 0, "vmax": 1} if absolute else {})
        )
        axes[i, j].axis('off')

    # Start near the middle, but never past the last column of the grid.
    in_col = min(model.width // 2, n_cols - 2)
    out_col = min(model.width // 2, n_cols - model.n_channels)

    draw(0, in_col, activations[0][:, 0])
    draw(0, in_col + 1, activations[0][:, 1])
    for i in range(1, model.depth + 1):
        for j in range(model.width):
            draw(i, j, activations[i][:, j])
    for c in range(model.n_channels):
        draw(n_rows - 1, out_col + c, activations[-1][:, c], absolute=True)
    fig.set_size_inches(n_cols * 2, n_rows * 2)

# `model` only exists inside train(); inspect the first model that train()
# stored in the module-level `setups` list instead.
inspect(setups[0][0])