### Test Differentiable Neural Computer

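Create synthetic input data `X` of shape *N×M*. Each of the first *N/2* rows is a one-hot vector (a single 1, the rest zeros), and the last *N/2* rows are all zeros. The target `y` swaps the two halves: its first *N/2* rows are zeros, and its last *N/2* rows repeat the one-hot rows of `X` in the same order. To predict `y` correctly, the *DNC* has to store the first half of the sequence in memory and recall it during the second half.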
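
A small deterministic version of the task makes the layout easier to see. The sketch below assumes *N/2* = 3 and *M* = 4. It uses fixed, hypothetical column indices instead of random ones and builds each one-hot row with `np.eye`.

```python
import numpy as np

rows, cols = 3, 4                            # N/2 = 3 one-hot rows, M = 4 columns
ones = np.array([2, 0, 3])                   # hypothetical fixed positions of the 1s
seq = np.eye(cols, dtype=np.float32)[ones]   # row i is one-hot at column ones[i]
zer = np.zeros_like(seq)

X = np.concatenate((seq, zer), axis=0)       # pattern first, then blanks
y = np.concatenate((zer, seq), axis=0)       # blanks first, then the same pattern

print(X)
print(y)
```

The random version in the notebook produces the same structure. Only the positions of the ones change from run to run.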

In [1]:
import logging
import numpy as np
import tensorflow as tf
from model import DNC
from trainer import trainer

logger = tf.get_logger()
logger.setLevel(logging.ERROR)

### Generate training data

In [2]:
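rows, cols = 6, 4                               # N/2 = 6 one-hot rows, M = 4 columns
ones = np.random.randint(0, cols, size=rows)    # column holding the 1 in each row
seq = np.zeros((rows, cols))
seq[np.arange(rows), ones] = 1                  # one-hot encode each row
zer = np.zeros((rows, cols))

X = np.concatenate((seq, zer), axis=0).astype(np.float32)  # pattern, then zeros
y = np.concatenate((zer, seq), axis=0).astype(np.float32)  # zeros, then pattern

# the second half of y must repeat the first half of X row by row
for i in range(rows):
    assert (X[i, :] == y[rows + i, :]).all()

# add a leading batch axis: shape (1, 2 * rows, cols)
X_train = np.expand_dims(X, axis=0)
y_train = np.expand_dims(y, axis=0)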
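
`np.expand_dims(..., axis=0)` turns the single sequence into a batch of one. Presumably `trainer` expects a (batch, time steps, features) layout, but its source is not shown, so treat that as an assumption. The quick shape check below uses a zero stand-in array with the notebook's shape:

```python
import numpy as np

rows, cols = 6, 4
X = np.zeros((2 * rows, cols), dtype=np.float32)  # stand-in with the same shape as X

X_train = np.expand_dims(X, axis=0)  # add a leading batch axis of size 1

print(X.shape)        # (time steps, features)
print(X_train.shape)  # (batch, time steps, features)

# X[np.newaxis] is an equivalent way to add the same axis
assert np.array_equal(X_train, X[np.newaxis])
```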

### Initialize and train DNC model

Initialize:

In [3]:
dnc = DNC(
    output_dim=cols,
    memory_shape=(10,4),  # shape of the memory matrix
    n_read=1              # number of read heads
)
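
For intuition about what a read head does with the memory matrix, the next sketch shows the content-based read addressing described in the original DNC paper. A key is compared with every memory row by cosine similarity, the scores are sharpened by a key strength `beta` and normalized with a softmax, and the read vector is the weighted sum of rows. This is a generic NumPy illustration, not the implementation inside `model.DNC`. The helper names `content_weighting` and `read` are hypothetical, and reading `memory_shape=(10, 4)` as 10 slots of width 4 is an assumption.

```python
import numpy as np

def content_weighting(memory, key, beta):
    """Softmax over beta-scaled cosine similarities between key and memory rows."""
    eps = 1e-8  # avoids division by zero for all-zero memory rows
    sim = memory @ key / (np.linalg.norm(memory, axis=1) * np.linalg.norm(key) + eps)
    scores = beta * sim
    scores -= scores.max()  # numerical stability before exponentiating
    w = np.exp(scores)
    return w / w.sum()

def read(memory, weighting):
    """A read head returns the weighted sum of memory rows: r = M^T w."""
    return memory.T @ weighting

N, W = 10, 4  # assumed reading of memory_shape=(10, 4): N slots of width W
memory = np.zeros((N, W))
memory[3] = [0, 1, 0, 0]  # pretend a one-hot row was written to slot 3
memory[7] = [0, 0, 1, 0]  # and another one to slot 7

w = content_weighting(memory, key=np.array([0.0, 1.0, 0.0, 0.0]), beta=50.0)
r = read(memory, w)
print(np.argmax(w), np.round(r, 3))
```

With a large `beta`, almost all of the weight lands on the best-matching slot, so the read vector recovers the stored one-hot row.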

Train:

In [4]:
trainer(
    model=dnc,
    loss_fn=tf.keras.losses.mse,
    X_train=X_train,
    y_train=y_train,
    epochs=2000,
    batch_size=1,
    verbose=False
)
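
The loss passed to `trainer` is `tf.keras.losses.mse`. It averages the squared error over the last (feature) axis, so it yields one value per time step. The NumPy sketch below mirrors that reduction on small made-up arrays. The notebook does not show how `trainer` turns the per-step values into a scalar, so the final mean is only an assumption.

```python
import numpy as np

def mse_per_step(y_true, y_pred):
    # Mean squared error over the last (feature) axis, one value per time step,
    # mirroring the reduction performed by tf.keras.losses.mse.
    return np.mean(np.square(y_true - y_pred), axis=-1)

y_true = np.array([[0., 0., 0., 1.],
                   [0., 1., 0., 0.]])
y_pred = np.array([[0., 0., 0., 0.5],   # off by 0.5 in one column
                   [0., 1., 0., 0.]])   # exact

loss = mse_per_step(y_true, y_pred)
print(loss)         # first step: 0.25 / 4 = 0.0625, second step: 0.0
print(loss.mean())  # assumed scalar reduction over time steps
```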

Predict on `X` (the same sequence used for training):

In [5]:
y_pred = dnc(X).numpy()

Check that every prediction is within the `decimal=2` tolerance of the ground truth `y`, i.e. an absolute error below 1.5e-2 per element:

In [6]:
np.testing.assert_almost_equal(y_pred, y, decimal=2)
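
`np.testing.assert_almost_equal` checks `abs(desired - actual) < 1.5 * 10**(-decimal)`, so `decimal=2` allows an error of up to 0.015. The sketch below applies the check to a single value, 0.9934. That value is the rounded prediction printed in the output for a target of 1, and the exact float is not available. It passes at `decimal=2` but would fail at `decimal=3`.

```python
import numpy as np

desired = 1.0
actual = 0.9934  # printed (rounded) prediction for a target of 1

# Tolerance used by assert_almost_equal for a given `decimal`
for decimal in (2, 3):
    tol = 1.5 * 10 ** (-decimal)
    print(decimal, abs(desired - actual) < tol)

np.testing.assert_almost_equal(actual, desired, decimal=2)  # passes: 0.0066 < 0.015

try:
    np.testing.assert_almost_equal(actual, desired, decimal=3)  # 0.0066 >= 0.0015
    failed = False
except AssertionError:
    failed = True
print(failed)
```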

In [7]:
np.set_printoptions(precision=3)
print('Prediction: ')
print(y_pred)
print('\nGround truth: ')
print(y)

Prediction: 
[[-1.922e-03  1.810e-03 -1.225e-04  1.335e-03]
 [-2.168e-03  2.258e-04 -1.364e-05  2.740e-03]
 [ 4.904e-04 -6.639e-04 -1.084e-03 -1.633e-03]
 [-2.993e-03  1.132e-03 -1.938e-04  4.551e-03]
 [-7.027e-04 -2.482e-03 -7.492e-04 -1.286e-05]
 [-9.379e-04 -6.188e-04  2.214e-03 -1.392e-03]
 [-1.105e-03 -1.609e-03  8.935e-04  9.934e-01]
 [-1.471e-03  9.989e-01 -2.230e-05  4.435e-03]
 [-2.691e-03  3.490e-03  9.998e-01  2.988e-03]
 [ 7.629e-05  9.961e-01 -2.395e-04  5.796e-04]
 [-4.813e-03  1.002e+00  1.611e-04  1.424e-03]
 [-5.659e-04  9.966e-01  3.111e-04  1.470e-03]]

Ground truth: 
[[0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 0.]
 [0. 0. 0. 1.]
 [0. 1. 0. 0.]
 [0. 0. 1. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]
 [0. 1. 0. 0.]]
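
Each target row in the second half is one-hot, so the predictions can also be checked by decoding them with `argmax` and comparing column indices. The sketch below uses the second half of the prediction exactly as printed, copied by hand. It therefore works at the printed precision only.

```python
import numpy as np

# Rows 6-11 of the printed prediction, copied from the output.
y_pred_tail = np.array([
    [-1.105e-03, -1.609e-03,  8.935e-04,  9.934e-01],
    [-1.471e-03,  9.989e-01, -2.230e-05,  4.435e-03],
    [-2.691e-03,  3.490e-03,  9.998e-01,  2.988e-03],
    [ 7.629e-05,  9.961e-01, -2.395e-04,  5.796e-04],
    [-4.813e-03,  1.002e+00,  1.611e-04,  1.424e-03],
    [-5.659e-04,  9.966e-01,  3.111e-04,  1.470e-03],
])
# Column of the single 1 in each ground-truth row 6-11.
y_true_idx = np.array([3, 1, 2, 1, 1, 1])

pred_idx = np.argmax(y_pred_tail, axis=1)  # most active column per time step
print(pred_idx)
print((pred_idx == y_true_idx).all())
```

Unlike `assert_almost_equal`, this check ignores the exact magnitudes and only asks whether the DNC recalled the right position at each step.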
