In [1]:
import tensorflow as tf
from tensorflow import keras

import numpy as np
from scipy.signal import convolve2d

In [2]:
class CyclicPadding2D(keras.layers.Layer):

    def __init__(self,):
        super(CyclicPadding2D, self).__init__()

    def build(self, input_shape):
        self.grid = tf.Variable(tf.zeros(shape=(input_shape[0], input_shape[1]+2, input_shape[2]+2), dtype=tf.float32), 
                        trainable=False, validate_shape=True)
        super(CyclicPadding2D, self).build(input_shape)  

    def call(self, inputs):
        
        self.grid[:,1:-1, 1:-1].assign(inputs)
        self.grid[:,0,0].assign(inputs[:,-1,-1])
        self.grid[:,0,-1].assign( inputs[:,-1,0])
        self.grid[:,-1,0].assign(inputs[:,0,-1])
        self.grid[:,-1,-1].assign(inputs[:,0,0])
        self.grid[:, 1:-1, 0].assign(inputs[:,:,-1])
        self.grid[:,1 : -1 , -1].assign(inputs[:,:, 0])

        self.grid[:,0, 1:-1].assign(inputs[:,-1,:])
        self.grid[:,-1, 1:-1].assign(inputs[:,0,:])
        return self.grid

In [3]:
class MakeBinary(keras.layers.Layer):
    def __init__(self,):
        super(MakeBinary, self).__init__()
        
    def build(self, input_shape):
        self.B = self.add_weight(name="B", shape=(), initializer="zeros", trainable=True)
        
    def call(self, input_grid):
        input_grid = input_grid - self.B
        input_grid = input_grid > 0.5
        return tf.cast(input_grid, dtype="float32")

In [4]:
class Conv2D(keras.layers.Layer):
    
    def __init__(self,kernel):
        super(Conv2D, self).__init__()
        self.kernel = kernel 
        
    def call(self, x):
        print(x.shape)
        x = tf.nn.conv2d(x, self.kernel, strides=1, padding='VALID')
        return x

In [5]:
class MultiplyLayer(keras.layers.Layer):
    def __init__(self, ):
        super(MultiplyLayer, self).__init__()
    def build(self, input_shape):
        input_shape = input_shape[0]
        self.w1 = self.add_weight(name="w1", shape=(input_shape[-2], input_shape[-1]), initializer="ones", trainable=True)
        self.w2 = self.add_weight(name="w2", shape=(input_shape[-2], input_shape[-1]), initializer="ones", trainable=True)
        self.b = self.add_weight(name="b", shape=(input_shape[-2], input_shape[-1]), initializer='zeros', trainable=True)
    def call(self, inputs):
        input_1 = inputs[0]
        input_2 = inputs[1]
        return tf.matmul(input_1, self.w1) + tf.matmul(input_2, self.w2) + self.b

In [6]:
class LocallyDense(keras.layers.Layer):
    def __init__(self, ):
        super(LocallyDense, self).__init__()
        

    def build(self, input_shape):
        m = input_shape[-2] - 2
        n = input_shape[-1] - 2
        self.w00 = self.add_weight(name="w00", shape=(m,n), initializer="ones", trainable=True)
        self.w01 = self.add_weight(name="w01", shape=(m,n), initializer="ones", trainable=True)
        self.w02 = self.add_weight(name="w02", shape=(m,n), initializer="ones", trainable=True)
        self.w10 = self.add_weight(name="w10", shape=(m,n), initializer="ones", trainable=True)
        self.w11 = self.add_weight(name="w11", shape=(m,n), initializer="ones", trainable=True)
        self.w12 = self.add_weight(name="w12", shape=(m,n), initializer="ones", trainable=True)
        self.w20 = self.add_weight(name="w20", shape=(m,n), initializer="ones", trainable=True)
        self.w21 = self.add_weight(name="w21", shape=(m,n), initializer="ones", trainable=True)
        self.w22 = self.add_weight(name="w22", shape=(m,n), initializer="ones", trainable=True)
        self.b = self.add_weight(name="b", shape=(m,n), initializer='zeros', trainable=True)

    def call(self, padded_input):
        p00 = padded_input[:,:-2,:-2]
        p01 = padded_input[:,:-2,1:-1]
        p02 = padded_input[:,:-2,2:]
        p10 = padded_input[:,1:-1,:-2]
        p11 = padded_input[:,1:-1,1:-1]
        p12 = padded_input[:,1:-1,2:]
        p20 = padded_input[:,2:,:-2]
        p21 = padded_input[:,2:,1:-1]
        p22 = padded_input[:,2:,2:]
        
        return tf.matmul(p00, self.w00) + tf.matmul(p01, self.w01) + tf.matmul(p02, self.w02) + tf.matmul(p10, self.w10) + tf.matmul(p11, self.w11) + tf.matmul(p12, self.w12) + tf.matmul(p20, self.w20) + tf.matmul(p21, self.w21) + tf.matmul(p22, self.w22) + self.b

In [12]:
class MyShittyModel(tf.keras.Model):
    def __init__(self, grid_size):
        super(MyShittyModel, self).__init__()
        self.d1 = grid_size[0]
        self.d2 = grid_size[1]

        # self.kernel = tf.ones(shape=(3, 3, 1, 1), dtype="float32")
        # self.kernel2 = tf.Variable(tf.ones(shape=(3, 3, 1, 1), dtype="float32"),  trainable=False, validate_shape=True)
        # self.kernel2[1, 1, : ,:].assign(tf.constant(0., shape=(1,1)))
        # self.conv2d = Conv2D(kernel=self.kernel)
        # self.conv2d_no_centre = Conv2D(kernel = self.kernel2)

        self.padding = CyclicPadding2D()
        # self.add_dim = tf.keras.layers.Reshape(target_shape=(self.d1 + 2, self.d2+2,1))
        # self.reduce_dim = tf.keras.layers.Reshape(target_shape=(self.d1, self.d2))
        self.locally_dense1 = LocallyDense()
        self.locally_dense2 = LocallyDense()
        self.locally_dense3 = LocallyDense()
        # self.weighted_sum_2D = MultiplyLayer()
        
    def call(self, x):
        x = self.padding(x)
        x = self.locally_dense1(x)
        x = tf.keras.activations.relu(x)
        x = self.padding(x)
        x = self.locally_dense2(x)
        x = tf.keras.activations.relu(x)
        x = self.padding(x)
        x = self.locally_dense3(x)
        x = tf.keras.activations.sigmoid(x)
        return x

In [13]:
def life_step(X):
    nbrs_count = convolve2d(X, np.ones((3, 3)), mode='same', boundary='wrap') - X
    return (nbrs_count == 3) | (X & (nbrs_count == 2))

sample_input = []
sample_output = []
probs = np.random.uniform(0.2, 0.99, 1000)
for prob in probs:
    grid = np.random.binomial(n=1,p=prob, size=(6,6))
    for _ in range(2):
        grid = life_step(grid)
    
    grid_1_step_more = life_step(grid)
    if grid.sum() > 0:
        sample_input.append(grid)
        sample_output.append(grid_1_step_more)
sample_x = tf.constant(np.array(sample_input).astype(float), shape= [len(sample_input),6,6], dtype="float32")
sample_y = tf.constant(np.array(sample_output).astype(float), shape= [len(sample_input),6,6], dtype="float32")

In [14]:
game = MyShittyModel(grid_size = (6,6))
game.compile(loss=tf.keras.losses.BinaryCrossentropy(), optimizer=keras.optimizers.Adam(),
    metrics=["accuracy"])
input_layer = keras.layers.Input(shape=(6, 6), name='input_layer', batch_size=32)

In [15]:
game(input_layer)

<tf.Tensor 'my_shitty_model_1/Identity:0' shape=(32, 6, 6) dtype=float32>

In [None]:
tf.keras.layers.LocallyConnected2D(
    filters, kernel_size, strides=(1, 1), padding='valid', data_format=None,
    activation=None, use_bias=True, kernel_initializer='glorot_uniform',
    bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None,
    activity_regularizer=None, kernel_constraint=None, bias_constraint=None,
    implementation=1, **kwargs
)

In [16]:
for i in range(0,len(sample_x),32):
    x = sample_x[i:i+32]
    y = sample_y[i:i+32]
    if len(y) == 32:
        game.fit(x, y, batch_size=32)

Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples
Train on 32 samples


In [None]:
game.fit(inputs,outputs)

In [None]:
outputs[30:32]

In [18]:
game.call(sample_y[30:62])

<tf.Tensor: id=5404, shape=(32, 6, 6), dtype=float32, numpy=
array([[[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]],

       [[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]],

       [[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]],

       ...,

       [[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]],

       [[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
