<h1> Toy Task to predict maximum number in a set

In [1]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''

<h1> Data Generator

In [2]:
def getData(size, n, maxN):
    """
    Params:
      size: number of sets to generate
      n: number of elements in each set
      maxN: range of elements in set
    Returns:
      x: numpy array of shape (size, n, dimension of each element in set)
      y: numpy array of shape (size, dimension of set-function output label)
    """
    x = []
    for i in range(size):
        xp = np.random.randint(1, np.random.randint(2, maxN), (n)).tolist()
        x.append(xp)
    x = np.array(x)
    x = np.reshape(x, [-1, n, 1])
    y = np.max(x, axis=1)
    y = np.reshape(y, [-1, 1])
    print(x.shape, y.shape)
    
    return x, y

<h1> Model Parameters

In [3]:
numHidden = 128
n = 10
temp = 0.1
alpha = 1e-5
numLSTMUnits = 128
numLSTMHidden = 128
numLSTMOut = 1
maxN = 30

numEpochs = 100
numSubEpochs = 150 # number of epochs to individually optimize adversary and learner
batchSize = 32

<h1> SPAN Architecture

In [4]:
def sinkhorn(X, num_iters=100):
    """
    Sinkhorn Normalization works better in practice in the log-space
    Reference: https://github.com/google/gumbel_sinkhorn/blob/master/sinkhorn_ops.py
    """
    n = tf.shape(X)[1]
    X = tf.reshape(X, [-1, n, n])
    for _ in range(num_iters):
        X -= tf.reshape(tf.reduce_logsumexp(X, axis=2), [-1, n, 1])
        X -= tf.reshape(tf.reduce_logsumexp(X, axis=1), [-1, 1, n])
    return tf.exp(X)

In [5]:
tf.reset_default_graph()

X = tf.placeholder(tf.float32, [None, n, 1])
Y = tf.placeholder(tf.float32, [None, 1])
keep_prob = tf.placeholder(tf.float32)

X_flattened = tf.reshape(X, [-1, 1])

initializer = tf.contrib.layers.xavier_initializer()

W1 = tf.Variable(initializer([1, numHidden]))
B1 = tf.Variable(initializer([numHidden]))
hidden = tf.nn.dropout(tf.nn.relu(tf.add(tf.matmul(X_flattened, W1), B1)), keep_prob=keep_prob)

W2 = tf.Variable(initializer([numHidden, n]))
B2 = tf.Variable(initializer([n]))
preSinkhorn = tf.add(tf.matmul(hidden, W2), B2)
preSinkhorn = tf.reshape(preSinkhorn, [-1, n, n])
preSinkhorn /= temp

# Apply the Sinkhorn operator on the matrix to convert it into a Doubly-Stochastic Matrix
postSinkhorn = sinkhorn(preSinkhorn)

# Since the inverse of a permutation matrix is its transpose, approximate the inverse
# of the double stochastic matrix with its transpose
postSinkhornInv = tf.transpose(postSinkhorn, [0, 2, 1])

# Apply the adversarial permutations on the input sets
permutedX = tf.matmul(postSinkhornInv, X)

In [6]:
# Define LSTM and FC layerds of the learner function

LSTMW1 = tf.Variable(initializer([numLSTMUnits, numLSTMHidden]))
LSTMB1 = tf.Variable(initializer([numLSTMHidden]))

LSTMW2 = tf.Variable(initializer([numLSTMHidden, numLSTMOut]))
LSTMB2 = tf.Variable(initializer([numLSTMOut]))

permutedXUnstacked = tf.unstack(permutedX, n, axis=1)
cell = tf.nn.rnn_cell.BasicLSTMCell(numLSTMUnits, forget_bias=1)
outputs, states = tf.nn.static_rnn(cell, permutedXUnstacked, dtype=tf.float32)
LSTMHidden = tf.nn.dropout(tf.nn.relu(tf.add(tf.matmul(outputs[-1], LSTMW1), LSTMB1)), keep_prob=keep_prob)
LSTMOut = tf.add(tf.matmul(LSTMHidden, LSTMW2), LSTMB2)

loss = tf.nn.l2_loss(LSTMOut - Y)

# Parameters of the learner function
varListFunction = [LSTMW1, LSTMB1, LSTMW2, LSTMB2]
varListFunction.extend(cell.variables)

# Parameters of the adversary
varListPerm = [W1, B1, W2, B2]

# Optimizer which updates parameters of the learner to minimize loss, keeping 
# parameters of the adversary fixed
optimizerFunction = tf.train.AdamOptimizer(learning_rate=alpha).minimize(loss, var_list=varListFunction)

# Optimizer which updates parameters of the adversary to maximize loss, keeping
# parameters of the learner fixed
optimizerPerm = tf.train.AdamOptimizer(learning_rate=alpha).minimize(-1 * loss, var_list=varListPerm)

Instructions for updating:
This class is deprecated, please use tf.nn.rnn_cell.LSTMCell, which supports all the feature this cell currently has. Please replace the existing code with tf.nn.rnn_cell.LSTMCell(name='basic_lstm_cell').


<h1> Create Train, Val, Test Splits

In [7]:
x_train, y_train = getData(1000, n, maxN)
x_val, y_val = getData(200, n, maxN)
x_test, y_test = getData(200, n, maxN)

(1000, 10, 1) (1000, 1)
(200, 10, 1) (200, 1)
(200, 10, 1) (200, 1)


<h1> Initialize session

In [8]:
sess = tf.Session()
sess.run(tf.global_variables_initializer())
lossList = []
minValLoss = 1000000000
weights = {'W1': sess.run(W1),
           'B1': sess.run(B1),
           'W2': sess.run(W2),
           'B2': sess.run(B2),
           'LSTMW1': sess.run(LSTMW1),
           'LSTMB1': sess.run(LSTMB1),
           'LSTMW2': sess.run(LSTMW2),
           'LSTMB2': sess.run(LSTMB2),
           'kernel': sess.run(cell.variables[0]),
           'bias': sess.run(cell.variables[1])}

<h1> SPAN training with Alternating Optimization

In [9]:
for i in range(numEpochs):
    for j in range(numSubEpochs):
        totalLoss = 0
        for k in range(0, len(x_train), batchSize):
            batch_x_train = x_train[k:min(k+batchSize, len(x_train)), :, :]
            batch_y_train = y_train[k:min(k+batchSize, len(y_train)), :]
            _, l = sess.run([optimizerPerm, loss], feed_dict={X:batch_x_train, Y:batch_y_train, keep_prob:1.0})
            totalLoss += l
        lossList.append(totalLoss)
        print("Optimizing Permutation: Epoch %d Sub-epoch %d Loss %f" % (i, j, totalLoss), end='\r')
    print()
    for j in range(numSubEpochs):
        totalLoss = 0
        for k in range(0, len(x_train), batchSize):
            batch_x_train = x_train[k:min(k+batchSize, len(x_train)), :, :]
            batch_y_train = y_train[k:min(k+batchSize, len(y_train)), :]
            _, l = sess.run([optimizerFunction, loss], feed_dict={X:batch_x_train, Y:batch_y_train, keep_prob:1.0})
            totalLoss += l
        lossList.append(totalLoss)
        print("Optimizing Function: Epoch %d Sub-epoch %d Loss %f" % (i, j, l), end='\r')
    print()
    valLoss = sess.run(loss, feed_dict={X:x_val, Y:y_val, keep_prob:1.0})
    print("Validation loss is %f" % valLoss)
    if valLoss < minValLoss:
        minValLoss = valLoss
        weights = {'W1': sess.run(W1),
                   'B1': sess.run(B1),
                   'W2': sess.run(W2),
                   'B2': sess.run(B2),
                   'LSTMW1': sess.run(LSTMW1),
                   'LSTMB1': sess.run(LSTMB1),
                   'LSTMW2': sess.run(LSTMW2),
                   'LSTMB2': sess.run(LSTMB2),
                   'kernel': sess.run(cell.variables[0]),
                   'bias': sess.run(cell.variables[1])}   

Optimizing Permutation: Epoch 0 Sub-epoch 149 Loss 93503.203491
Optimizing Function: Epoch 0 Sub-epoch 149 Loss 8.9673541
Validation loss is 194.698502
Optimizing Permutation: Epoch 1 Sub-epoch 149 Loss 20028.749390
Optimizing Function: Epoch 1 Sub-epoch 149 Loss 1.211377
Validation loss is 47.591011
Optimizing Permutation: Epoch 2 Sub-epoch 149 Loss 1678.978518
Optimizing Function: Epoch 2 Sub-epoch 149 Loss 0.100525
Validation loss is 2.682525
Optimizing Permutation: Epoch 3 Sub-epoch 149 Loss 59317.536438
Optimizing Function: Epoch 3 Sub-epoch 149 Loss 3.087206
Validation loss is 136.017822
Optimizing Permutation: Epoch 4 Sub-epoch 149 Loss 11677.958481
Optimizing Function: Epoch 4 Sub-epoch 149 Loss 0.562636
Validation loss is 13.547885
Optimizing Permutation: Epoch 5 Sub-epoch 149 Loss 15261.074387
Optimizing Function: Epoch 5 Sub-epoch 149 Loss 4.323665
Validation loss is 129.654755
Optimizing Permutation: Epoch 6 Sub-epoch 149 Loss 25525.040283
Optimizing Function: Epoch 6 Sub-e

KeyboardInterrupt: 

<h1> Restore weights that gave least loss on validation set

In [10]:
print("Minimum validation loss is", minValLoss)
print("Restoring weights with minimum validation loss")
sess.run(tf.assign(W1, weights['W1']))
sess.run(tf.assign(B1, weights['B1']))
sess.run(tf.assign(W2, weights['W2']))
sess.run(tf.assign(B2, weights['B2']))
sess.run(tf.assign(LSTMW1, weights['LSTMW1']))
sess.run(tf.assign(LSTMB1, weights['LSTMB1']))
sess.run(tf.assign(LSTMW2, weights['LSTMW2']))
sess.run(tf.assign(LSTMB2, weights['LSTMB2']))
sess.run(tf.assign(cell.variables[0], weights['kernel']))
sess.run(tf.assign(cell.variables[1], weights['bias']))
print("Validation loss after restoring is", sess.run(loss, feed_dict={X:x_val, Y:y_val, keep_prob:1.0}))

Minimum validation loss is 2.6825252
Restoring weights with minimum validation loss
Validation loss after restoring is 2.6825252


<h1> Accuracy on test set after rounding

In [14]:
preds = sess.run(LSTMOut, feed_dict={X:x_test, keep_prob:1.0})
preds = np.round(preds)
cnt = 0
for i in range(len(y_test)):
    if preds[i, 0] == y_test[i, 0]:
        cnt += 1
print(cnt / len(y_test))

1.0


<h1> Average Relative Error on Test Set

In [13]:
preds = sess.run(LSTMOut, feed_dict={X:x_test, keep_prob:1.0})
total_relative_error = 0
cnt = 0
for i in range(len(preds)):
    if y_test[i, 0] == 0:
        continue
    relative_error = abs(preds[i, 0] - y_test[i, 0]) / y_test[i, 0]
    cnt += 1
    total_relative_error += relative_error
    
print("Average relative error on test set is", (total_relative_error / cnt))

Average relative error on test set is 0.018251811778837182


<h1> Studying Permutation Invariance of SPAN

In [19]:
print("Studying output on different permutations")
for i in range(20):
    tt = []
    for num in x_test[10]:
        tt.append(num[0])
    np.random.shuffle(tt)
    print("x_test is", tt)
    x_test_perm = np.reshape(tt, [-1, n, 1])
    preds = sess.run(LSTMOut, feed_dict={X:x_test_perm, keep_prob:1.0})
    print("Prediction is", preds[0, 0], "\n")

Studying output on different permutations
x_test is [20, 19, 2, 11, 16, 20, 6, 16, 8, 1]
Prediction is 20.173746 

x_test is [20, 20, 11, 16, 8, 2, 1, 19, 16, 6]
Prediction is 20.173746 

x_test is [11, 16, 20, 2, 8, 19, 16, 1, 6, 20]
Prediction is 20.17375 

x_test is [1, 20, 19, 6, 16, 16, 11, 8, 20, 2]
Prediction is 20.173748 

x_test is [16, 8, 20, 19, 11, 6, 20, 16, 1, 2]
Prediction is 20.17375 

x_test is [16, 19, 2, 20, 20, 6, 8, 11, 16, 1]
Prediction is 20.173748 

x_test is [20, 16, 2, 20, 6, 11, 19, 1, 8, 16]
Prediction is 20.173748 

x_test is [2, 16, 20, 19, 1, 16, 8, 20, 11, 6]
Prediction is 20.173748 

x_test is [11, 1, 2, 20, 20, 19, 8, 6, 16, 16]
Prediction is 20.173746 

x_test is [11, 2, 16, 8, 6, 1, 20, 20, 16, 19]
Prediction is 20.173746 

x_test is [1, 16, 8, 16, 2, 19, 11, 6, 20, 20]
Prediction is 20.173746 

x_test is [11, 19, 20, 2, 16, 1, 8, 16, 20, 6]
Prediction is 20.17375 

x_test is [2, 8, 20, 20, 1, 19, 16, 6, 16, 11]
Prediction is 20.173748 

x_test is [2