In [3]:
# Imports
import tensorflow as tf
import tensorflow.keras as keras
import numpy as np
import matplotlib.pyplot as plt
import random
import time
import json
import training

In [4]:
# Constants

# Network Structure
CONTEXT_SIZE = 8        # How many other voxels are considered for a training example
EMBEDDING_SIZE = 64     # Dimensionality of the voxel embedding vector
STACKED_LAYERS = 2      # How many times the network structure repeats itself
ATTENTION_HEADS = 2     # Number of heads in each multi-headed attention mechanism

# Training Hyperparameters
CHECK_RADIUS = 7        # How far away voxels can be to be part of a training example
CENTER_FOCUS = 0.3      # How much to focus on picking voxels close to the center of the cube. Must be between 0 and 1.
LEARNING_RATE = 1e-4
TRAINING_EXAMPLES = 100

In [6]:
# Load voxel palette
# The output is a 255-dimensional vector of probabilities for different colors
# Which 255 colors can be generated is decided by the palette file

# Index 0 is reserved as 'undecided' voxel
# Index 1 is reserved as 'air' voxel
# Index 2-255 are colors. So there are 254 possible colors.
with open('data/palette.json', 'r') as json_file:
    palette = json.load(json_file)['colors']
    palette_size = len(palette)

print(f"Palette has {palette_size} colors")


Palette has 256 colors


In [7]:
# Create model
def main_model():
    input = keras.Input(shape=(CONTEXT_SIZE, EMBEDDING_SIZE,), name='input_layer')

    x = input

    # Stacked layers
    for i in range(STACKED_LAYERS):
        # Multi-headed attention
        fx = keras.layers.MultiHeadAttention(
            num_heads=ATTENTION_HEADS,
            key_dim=EMBEDDING_SIZE,
            name=f'multi_head_attention_{i}',
        )(x, x)

        # Normalization
        fx = keras.layers.LayerNormalization(name=f'normalization_{i}a')(x)

        # Residual connection
        x = keras.layers.Add(name=f'residual_connection_{i}a')([x,fx])

        # Feedforward
        fx = keras.layers.Dense(EMBEDDING_SIZE, name=f'feedforward_{i}')(x)
        fx = keras.layers.LeakyReLU(name=f'relu_{i}')(fx)

        # Normalization
        fx = keras.layers.LayerNormalization(name=f'normalization_{i}b')(fx)

        # Residual connection
        x = keras.layers.Add(name=f'residual_connection_{i}b')([x,fx])
    
    # Final feedforward layer
    # Output size should be palette_size-1, since we don't want it to be able to choose "undecided"
    x = keras.layers.Dense(palette_size-1, name='feedforward_final')(x)

    # Softmax
    x = keras.layers.Softmax(name='softmax')(x)
    
    # Build and return model
    return keras.Model(inputs=input, outputs=x)

model = main_model()
model.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
model.summary()


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_layer (InputLayer)       [(None, 8, 64)]      0           []                               
                                                                                                  
 normalization_0a (LayerNormali  (None, 8, 64)       128         ['input_layer[0][0]']            
 zation)                                                                                          
                                                                                                  
 residual_connection_0a (Add)   (None, 8, 64)        0           ['input_layer[0][0]',            
                                                                  'normalization_0a[0][0]']       
                                                                                              

In [8]:
# Set up loss and optimizer
loss_function = tf.keras.losses.CategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam(LEARNING_RATE)

In [10]:
# Design training examples
training_examples = training.generate_training_examples(TRAINING_EXAMPLES, CONTEXT_SIZE)

print(f"{len(training_examples)} training examples created.")

100 training examples loaded.
