# In [None]:
import numpy as np
import onnx 

# Load the training and validation boards from .npy files (outputs of a
# physically based model). NOTE(review): the original note cited 8000 train /
# 2000 validation samples — confirm against the files on disk.
X_train, X_val = np.load("X_train.npy"), np.load("X_val.npy")

def reshape_input(X):
    """Append a trailing single-channel axis to a batch of boards.

    An array of shape ``(N, H, W)`` becomes ``(N, H, W, 1)``, the
    channels-last layout used for the model input. ``X`` needs at least
    three dimensions, and its total size must equal ``N * H * W``;
    otherwise indexing or ``reshape`` raises.
    """
    dims = X.shape
    channels_last = (dims[0], dims[1], dims[2], 1)
    return X.reshape(channels_last)

# Give both splits the trailing channel axis expected by the Conv2D input layer.
X_train, X_val = reshape_input(X_train), reshape_input(X_val)

def life_step(X):
    """Advance a Game of Life board by one generation.

    Neighbours are counted with ``np.roll`` along axes 0 and 1, so the board
    wraps around at its edges (toroidal topology). Any trailing axes, such as
    the ``(H, W, 1)`` channel axis added by ``reshape_input``, are carried
    through unchanged.

    Args:
        X: Board whose first two axes are rows and columns. Any non-zero cell
            is treated as alive, so int, bool and float (0.0/1.0) boards are
            all accepted.

    Returns:
        Integer array with the same shape as ``X``: 1 where the cell is alive
        in the next generation, 0 otherwise.
    """
    # Normalise to bool so float boards (e.g. loaded from .npy) don't break
    # the bitwise '&' below.
    alive = np.asarray(X) != 0
    # Count the 8 neighbours of every cell; summing bool arrays yields ints.
    live_neighbors = sum(
        np.roll(np.roll(alive, i, 0), j, 1)
        for i in (-1, 0, 1)
        for j in (-1, 0, 1)
        if (i, j) != (0, 0)
    )
    # Birth on exactly 3 neighbours, survival on 2 (a live cell with 3 is
    # covered by the first term). Cast the whole expression, not just one term.
    return ((live_neighbors == 3) | (alive & (live_neighbors == 2))).astype(int)


import matplotlib.pyplot as plt

# Derive the board size from the loaded data instead of hard-coding it.
# X_train is (N, H, W, 1) after reshape_input, so this is (H, W). The Conv2D
# input_shape and the ONNX input signature are built from board_shape, and
# they must match the arrays actually fed to the model.
board_shape = X_train.shape[1:3]
board_size = board_shape[0] * board_shape[1]
probability_alive = 0.15  # per-cell alive probability for the generate_frames demo boards

def generate_frames(num_frames, board_shape=(100,100), prob_alive=0.15):
    """Draw ``num_frames`` random boards and stack them into one int array.

    Each cell is independently alive (1) with probability ``prob_alive`` and
    dead (0) otherwise. Sampling uses ``np.random.choice`` on NumPy's global
    random state, so seed it with ``np.random.seed`` for reproducibility.

    Returns:
        Array of shape ``(num_frames, *board_shape)`` with integer 0/1
        entries, or shape ``(0,)`` when ``num_frames`` is 0.
    """
    states = [False, True]
    weights = [1 - prob_alive, prob_alive]
    boards = [
        np.random.choice(states, size=board_shape, p=weights)
        for _ in range(num_frames)
    ]
    return np.array(boards).astype(int)


# Demo: sample a few random boards sized by board_shape and report their shape.
frames = generate_frames(10, board_shape=board_shape, prob_alive=probability_alive)
print(frames.shape)  # (10, board_shape[0], board_shape[1])

# Targets: the next Game of Life generation of every input board.
y_train, y_val = (
    np.array([life_step(board) for board in boards])
    for boards in (X_train, X_val)
)

# CNN hyperparameters
filters = 50
kernel_size = (3, 3)  # 3x3 window: each cell plus its 8 neighbours
strides = 1
hidden_dims = 100

# NOTE(review): Sequential, Conv2D, Dense and Activation are not imported in
# the visible code (presumably from tensorflow.keras) — confirm they are in scope.
# NOTE(review): Dense(hidden_dims) has no activation, so it composes linearly
# with Dense(1); consider activation='relu' if a hidden nonlinearity was intended.
model = Sequential([
    # 'same'-padded ReLU convolution over the single-channel board.
    Conv2D(
        filters,
        kernel_size,
        padding='same',
        activation='relu',
        strides=strides,
        input_shape=(board_shape[0], board_shape[1], 1),
    ),
    # On 4-D input, Dense acts on the last (channel) axis, i.e. per cell.
    Dense(hidden_dims),
    Dense(1),
    Activation('sigmoid'),  # per-cell probability of being alive
])
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

def train(model, X_train, y_train, X_val, y_val, batch_size=50, epochs=5, filename_suffix=''):
    """Fit ``model`` on the training split, validating on the validation split.

    Args:
        model: Compiled model exposing a Keras-style ``fit`` method.
        X_train: Training inputs.
        y_train: Training targets.
        X_val: Validation inputs, passed to ``fit`` as part of ``validation_data``.
        y_val: Validation targets, passed to ``fit`` as part of ``validation_data``.
        batch_size: Samples per gradient update.
        epochs: Number of passes over the training data.
        filename_suffix: Currently unused; accepted for call-site
            compatibility (presumably meant to tag saved artifacts — confirm).

    Returns:
        Whatever ``model.fit`` returns (a ``History`` object for Keras
        models), so callers can inspect per-epoch loss and accuracy.
    """
    return model.fit(
        X_train,
        y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_data=(X_val, y_val),
    )

# Train the model on the next-generation targets.
train(model, X_train, y_train, X_val=X_val, y_val=y_val, filename_suffix='_basic')

# Persist the trained Keras model as an HDF5 (.h5) file.
model_save_path = './keras_model.h5'
model.save(model_save_path)

# Export to ONNX: dynamic batch dimension, single-channel board input named 'input'.
# NOTE(review): `tf` and `tf2onnx` are not imported in the visible code —
# confirm `import tensorflow as tf` and `import tf2onnx` exist elsewhere.
input_spec = [
    tf.TensorSpec(
        shape=[None, board_shape[0], board_shape[1], 1],
        dtype=tf.float32,
        name='input',
    )
]
onnx_model, _ = tf2onnx.convert.from_keras(
    model,
    input_signature=input_spec,
    opset=13,
)

# Write the serialized ONNX model to disk.
onnx_model_path = 'game_of_life.onnx'
with open(onnx_model_path, "wb") as f:
    f.write(onnx_model.SerializeToString())
