In [None]:
# Make the parent directory importable so project modules (e.g. `cgael`) resolve.
# Guarded so that re-running this cell does not append duplicate sys.path entries.
import os
import sys

_PARENT_DIR = os.path.abspath("..")
if _PARENT_DIR not in sys.path:
    sys.path.append(_PARENT_DIR)

In [None]:
import tensorflow as tf
import tensorflow.keras as keras

import cgael

In [None]:
# Utterance grid: WORD_COUNT words of WORD_LENGTH tokens each. These set the
# speaker's final Reshape and the listener's Input shape.
WORD_COUNT, WORD_LENGTH = 4, 10

# Color input: COLOR_COUNT colors with COLOR_CHANNELS values each
# (presumably RGB given 3 channels — confirm).
COLOR_COUNT, COLOR_CHANNELS = 2, 3

In [None]:
# Vocabulary: the single-character tokens C, H, A, T, with '-' as the pad token.
TOKEN_SET = cgael.LanguageTokenSet(alphabet_tokens=list('CHAT'), pad_token='-')

In [None]:
def build_simple_modular_color_speaker(channel_dense):
    """Build the speaker model: a set of colors -> a grid of per-token scores.

    Every color is passed through the same small sigmoid Dense layer (weights
    shared across colors via TimeDistributed). The per-color features are then
    flattened and projected to WORD_COUNT * WORD_LENGTH * TOKEN_SET.token_count
    sigmoid scores, reshaped to (WORD_COUNT, WORD_LENGTH, token_count).

    Args:
        channel_dense: number of units in the shared per-color Dense layer.

    Returns:
        A ``keras.Model`` named "speaker" whose input has shape
        (COLOR_COUNT, COLOR_CHANNELS).
    """
    token_count = TOKEN_SET.token_count
    colors = keras.layers.Input((COLOR_COUNT, COLOR_CHANNELS), name="input")

    # One Dense layer reused for each color along axis 1.
    shared_dense = keras.layers.Dense(channel_dense, activation="sigmoid")
    per_color = keras.layers.TimeDistributed(shared_dense, name="distributed_dense")(colors)

    # Mix all colors together, then project to the utterance grid.
    flat = keras.layers.Flatten(name="flatten")(per_color)
    scores = keras.layers.Dense(
        WORD_COUNT * WORD_LENGTH * token_count,
        activation="sigmoid",
        name="shape_dense",
    )(flat)
    words = keras.layers.Reshape((WORD_COUNT, WORD_LENGTH, token_count), name="final_shape")(scores)

    return keras.Model(colors, words, name="speaker")

In [None]:
# Width of the shared per-color Dense layer inside the speaker.
SPEAKER_CHANNEL_DENSE = 5

spk = build_simple_modular_color_speaker(channel_dense=SPEAKER_CHANNEL_DENSE)
spk.summary()

In [None]:
def build_simple_modular_color_listener(embedding_size):
    """Build the listener model: a token grid -> reconstructed colors.

    The (WORD_COUNT, WORD_LENGTH) grid of token ids is embedded, flattened, and
    projected to COLOR_COUNT * COLOR_CHANNELS sigmoid values, reshaped to
    (COLOR_COUNT, COLOR_CHANNELS).

    Args:
        embedding_size: dimensionality of each token embedding.

    Returns:
        A ``keras.Model`` named "listener".
    """
    # NOTE(review): Input uses Keras' default (float) dtype. Presumably the values
    # are integer token ids that Embedding casts internally — confirm.
    tokens = keras.layers.Input((WORD_COUNT, WORD_LENGTH), name="input")
    embedded = keras.layers.Embedding(
        TOKEN_SET.token_count,
        embedding_size,
        embeddings_initializer="random_normal",
        name="embeddings",
    )(tokens)

    flat = keras.layers.Flatten(name="flatten")(embedded)
    color_values = keras.layers.Dense(
        COLOR_COUNT * COLOR_CHANNELS,
        activation="sigmoid",
        name="shape_dense",
    )(flat)
    colors = keras.layers.Reshape((COLOR_COUNT, COLOR_CHANNELS), name="final_shape")(color_values)

    return keras.Model(tokens, colors, name="listener")

In [None]:
# Dimensionality of each token embedding inside the listener.
LISTENER_EMBEDDING_SIZE = 6

lsn = build_simple_modular_color_listener(embedding_size=LISTENER_EMBEDDING_SIZE)
lsn.summary()

In [None]:
class CgaelToolkit:
    """Holds a speaker model and a listener model together.

    Attributes:
        speaker: the speaker object passed to the constructor.
        listener: the listener object passed to the constructor.
    """

    def __init__(self, speaker, listener):
        self.speaker, self.listener = speaker, listener

In [None]:
# Rebuild a standalone Input matching the speaker's input.
# In tf.keras 2.x the InputLayer's `input_shape` is a list holding one shape
# tuple that INCLUDES the batch axis, e.g. [(None, COLOR_COUNT, COLOR_CHANNELS)].
# keras.layers.Input expects the shape WITHOUT the batch axis, so drop it with [1:].
# Otherwise the result gains a spurious extra dimension.
keras.layers.Input(spk.layers[0].input_shape[0][1:])

In [None]:
def build_train_model(input_shape, speaker, listener):
    """Chain speaker -> argmax -> denoise -> listener into one model.

    Args:
        input_shape: shape of a single color input, without the batch axis.
        speaker: model mapping colors to per-token scores.
        listener: model mapping a token grid back to colors.

    Returns:
        A ``keras.Model`` with two outputs, in this order: the denoised token
        grid, then the listener's reconstructed colors.
    """
    color_input = keras.layers.Input(input_shape, name="input")
    token_scores = speaker(color_input)

    # NOTE(review): if ArgmaxLayer is a hard argmax, no gradient flows back into
    # the speaker through this path — confirm how cgael handles this.
    token_ids = cgael.layers.ArgmaxLayer(name="argmax")(token_scores)
    denoised = cgael.layers.LanguageDenoiseLayer(do_columns=True, name="denoise")(token_ids)

    reconstructed = listener(denoised)
    return keras.Model(color_input, [denoised, reconstructed])

In [None]:
# End-to-end training graph: colors -> speaker -> argmax/denoise -> listener.
trn = build_train_model(
    input_shape=(COLOR_COUNT, COLOR_CHANNELS),
    speaker=spk,
    listener=lsn,
)
trn.summary()