In [1]:
import tensorflow as tf 
from tensorflow import keras
import tensorflow_datasets as tfds

from matplotlib import pyplot as plt
from matplotlib import ticker
from pathlib import Path

from datetime import datetime

import sys
sys.path.insert(0, "..")

print(tf.__version__)
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
    # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized
        print(e)

2.2.0
2 Physical GPUs, 2 Logical GPUs


## Load Data

In [2]:
## Download file
data_path = Path("..") / "datasets" / "data"
if not data_path.is_dir():
    data_path.mkdir(parents=True)

In [3]:
datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True, data_dir=data_path)
mnist_train, mnist_test = datasets['train'], datasets['test']

In [4]:
# define model
num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 60000
BATCH_SIZE = 32
print(f"Training data samples: {num_train_examples}, Testing data samples: {num_test_examples}")

Training data samples: 60000, Testing data samples: 10000


In [5]:
# Network Parameters
num_input = 1 # MNIST data input (img shape: 28*28)
timesteps = 28 * 28 # timesteps
num_hidden = 128 # hidden layer num of features
num_classes = 10 # MNIST total classes (0-9 digits)

In [6]:
def scale(image, label):
    image = tf.cast(image, tf.float32)
    image /= 255
    image = tf.reshape(image, (-1, 1))
    return image, label

In [7]:
train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

In [8]:
logdir = "../logs/scalars/" + datetime.now().strftime("%Y%m%d-%H%M%S")
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)

## Create Model Using Neuromodulated Bistable RNNs

In [9]:
from bistablernn import NBR

In [10]:
model = tf.keras.Sequential([
  NBR(units=num_hidden, input_shape=(28*28, num_input), use_bias=True, 
                   recurrent_dropout=0, unroll=False, activation = "tanh", 
                   recurrent_activation = "sigmoid"),
  tf.keras.layers.Dense(num_classes)
])

model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
            optimizer=tf.keras.optimizers.Adam(learning_rate=0.002, beta_1=0.1),
            metrics=['accuracy'])

In [11]:
model.fit(train_dataset, epochs=35, validation_data=eval_dataset, callbacks=[tensorboard_callback])

Epoch 1/35
Epoch 2/35
Epoch 3/35
Epoch 4/35
Epoch 5/35
Epoch 6/35
Epoch 7/35
Epoch 8/35
Epoch 9/35
Epoch 10/35
Epoch 11/35
Epoch 12/35
Epoch 13/35
Epoch 14/35
Epoch 15/35
Epoch 16/35
Epoch 17/35
Epoch 18/35

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



Epoch 25/35
Epoch 26/35
Epoch 27/35
Epoch 28/35
Epoch 29/35
Epoch 30/35
Epoch 31/35
Epoch 32/35
Epoch 33/35
Epoch 34/35
Epoch 35/35


<tensorflow.python.keras.callbacks.History at 0x7fa580135790>

In [12]:
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
nbr (NBR)                    (None, 128)               50304     
_________________________________________________________________
dense (Dense)                (None, 10)                1290      
Total params: 51,594
Trainable params: 51,594
Non-trainable params: 0
_________________________________________________________________


In [13]:
model.evaluate(eval_dataset)



[0.06916925311088562, 0.978600025177002]