In [1]:
import tensorflow as tf

In [2]:
@tf.function
def iterate_over(data, nexts):
    """Read `data` in the order dictated by the transition matrix `nexts`.

    A one-hot row vector (the "pointer") starts on element 0. At each step the
    value under the pointer is extracted with `pointer @ data`, stored in the
    next output slot, and the pointer is advanced with `pointer @ nexts`.
    Only matmuls, products and sums are used, so gradients flow to both
    `data` and `nexts`.

    Args:
        data: 1-D float tensor of values to visit (the notebook passes float32).
        nexts: square matrix whose row i is the (possibly soft) one-hot
            successor of position i; multiplied from the left by the
            (1, len(data)) pointer.

    Returns:
        1-D tensor with `len(data)` entries; entry k is the value read on
        step k.
    """
    length = tf.shape(data)[0]
    column = tf.expand_dims(data, -1)   # (n, 1) so it can be right-multiplied
    slots = tf.eye(length)              # row k is the one-hot mask for output slot k
    pointer = tf.one_hot([0], length)   # (1, n) row vector, starts at index 0

    # The scalar length tensor is passed directly as the shape argument.
    visited = tf.zeros(length)

    for step in tf.range(length):
        current = tf.squeeze(pointer @ column)
        # Write `current` into slot `step` without any indexed assignment.
        visited = visited + slots[step] * current
        pointer = pointer @ nexts

    return visited

# Sanity check on a hand-written permutation: starting at index 0 and
# following `nexts` through `data` should read the values in `target` order.
data  = tf.Variable([1, 3, 2], dtype=tf.float32)
target = tf.Variable([1, 2, 3], dtype=tf.float32)
data_len = tf.shape(data)[0]
nexts = tf.Variable(tf.one_hot([2, 0, 1], data_len), dtype=tf.float32)

# persistent=True because gradient() is called three times on the same tape.
with tf.GradientTape(persistent=True) as tape:
    result = iterate_over(data, nexts)
    loss = tf.nn.l2_loss(result - target)

print(result)
print(tape.gradient(result, data))
print(tape.gradient(result, nexts))

print(loss)
print(tape.gradient(loss, nexts))

# A persistent tape keeps its recorded resources until it is garbage
# collected, so drop the reference once the last gradient has been taken.
del tape

tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
tf.Tensor([1. 1. 1.], shape=(3,), dtype=float32)
tf.Tensor(
[[3. 4. 5.]
 [0. 0. 0.]
 [1. 3. 2.]], shape=(3, 3), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]], shape=(3, 3), dtype=float32)


In [3]:
@tf.function
def bistable_loss(x):
    """Element-wise double-well penalty `x^2 * (x - 1)^2`.

    The product is zero exactly where `x` equals 0 or 1 and positive
    everywhere else, including strictly between 0 and 1.
    """
    return (x ** 2) * ((x - 1) ** 2)

@tf.function
def permute_matrix_loss(P, cycle_length=1, cycle_weight=0):
    """Score how far `P` is from a 0/1 permutation matrix with a given order.

    The result is the sum of four terms (`tf.nn.l2_loss` is half the sum of
    squares of its argument):

      * row term: l2_loss of (sum of squared entries of each row - 1),
      * column term: the same, taken over columns,
      * entry term: sum of `bistable_loss` over all entries, which is zero
        only for entries equal to 0 or 1,
      * cycle term: `cycle_weight` * l2_loss of (P^cycle_length - identity).

    Args:
        P: square matrix to score.
        cycle_length: power of `P` that is compared against the identity.
        cycle_weight: multiplier of the cycle term; the default 0 disables it.

    Returns:
        Scalar loss tensor.
    """
    squared_entries = tf.math.square(P)

    # Sum of squares (not the plain sum) along each axis should equal one.
    row_term = tf.nn.l2_loss(tf.reduce_sum(squared_entries, axis=1) - 1)
    column_term = tf.nn.l2_loss(tf.reduce_sum(squared_entries, axis=0) - 1)

    # Push every individual entry towards either 0 or 1.
    entry_term = tf.math.reduce_sum(bistable_loss(P))

    # Raise P to `cycle_length` by repeated left-multiplication.
    power = P
    for _ in tf.range(cycle_length - 1):
        power = P @ power
    identity = tf.eye(tf.shape(power)[0])
    cycle_term = tf.nn.l2_loss(power - identity) * cycle_weight

    return row_term + column_term + entry_term + cycle_term

# Hand-made 3x3 inputs for permute_matrix_loss.

# Identity matrix.
test1 = tf.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)

# Permutation that swaps the first two positions.
test2 = tf.constant([[0, 1, 0], [1, 0, 0], [0, 0, 1]], dtype=tf.float32)

# Identity with one entry flipped to -1.
test3 = tf.constant([[-1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)

# Identity with one entry raised to 2.
test4 = tf.constant([[2, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=tf.float32)

# Diagonal scaled down to 0.1.
test5 = tf.constant([[0.1, 0, 0], [0, 0.1, 0], [0, 0, 0.1]], dtype=tf.float32)

# First two rows are an even 0.5/0.5 mix.
test6 = tf.constant([[0.5, 0.5, 0], [0.5, 0.5, 0], [0, 0, 1]], dtype=tf.float32)

# Same swap as test2, scored below with the cycle term switched on.
test7 = tf.constant([[0, 1, 0], [1, 0, 0], [0, 0, 1]], dtype=tf.float32)

# Default arguments: the cycle term has weight 0.
for candidate in (test1, test2, test3, test4, test5, test6):
    print(permute_matrix_loss(candidate))

# cycle_length=1, cycle_weight=1: the matrix itself is compared to the identity.
print(permute_matrix_loss(test7, 1, 1))

tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(0.0, shape=(), dtype=float32)
tf.Tensor(4.0, shape=(), dtype=float32)
tf.Tensor(13.0, shape=(), dtype=float32)
tf.Tensor(2.9646, shape=(), dtype=float32)
tf.Tensor(0.75, shape=(), dtype=float32)
tf.Tensor(2.0, shape=(), dtype=float32)


In [4]:
# Single Adam optimizer (default settings) shared by every train_step call.
opt = tf.keras.optimizers.Adam()

@tf.function
def train_step(data, nexts, target_data):
    """Run one Adam update on the transition logits `nexts`.

    The logits are turned into a row-stochastic matrix with a row-wise
    softmax. The loss is the l2 distance between the sequence read by
    `iterate_over` and `target_data`, plus `permute_matrix_loss` of the soft
    matrix with `cycle_length = len(data)` and `cycle_weight = 1`.

    Args:
        data: 1-D tensor of values to be visited.
        nexts: trainable variable holding the transition logits; updated in
            place by the module-level optimizer `opt`.
        target_data: 1-D tensor with the desired visiting order of values.

    Returns:
        Tuple `(loss, actual_data)` computed before the update is applied.
    """
    n_items = tf.shape(data)[0]

    with tf.GradientTape() as tape:
        transition = tf.nn.softmax(nexts, axis=1)
        actual_data = iterate_over(data, transition)
        fit_term = tf.nn.l2_loss(actual_data - target_data)
        structure_term = permute_matrix_loss(transition, n_items, 1)
        loss = fit_term + structure_term

    # Gradients are applied unclipped.
    grads = tape.gradient(loss, nexts)
    opt.apply_gradients([(grads, nexts)])

    return loss, actual_data

NUM_STEPS = 10000
LOG_EVERY = 1000
# Successor indices that read `data` in `target_data` order; printed next to
# the learned indices for comparison only.
EXPECTED_NEXT = [2, 4, 1, 0, 3]

data  = tf.constant([1, 3, 2, 5, 4], dtype=tf.float32)
target_data  = tf.constant([1, 2, 3, 4, 5], dtype=tf.float32)
data_len = tf.shape(data)[0]

# Initial logits: every row points at index 1. Other starting points are the
# expected permutation itself or every row pointing at index 0.
nexts = tf.Variable(tf.one_hot([1, 1, 1, 1, 1], data_len), dtype=tf.float32)

for i in range(NUM_STEPS):
    loss, actual_data = train_step(data, nexts, target_data)
    if i % LOG_EVERY != 0:
        continue

    # Harden the logits into a strict one-hot matrix and replay the walk, to
    # see what the current solution gives once it is made discrete.
    argmax_next = tf.argmax(nexts, axis=1)
    defuzzified = tf.one_hot(argmax_next, data_len)
    defuzzified_data = iterate_over(data, defuzzified)
    tf.print(loss, tf.round(actual_data), defuzzified_data, argmax_next, EXPECTED_NEXT)

tf.print(nexts)

8.7374239 [1 3 3 3 3] [1 3 3 3 3] [1 1 1 1 1] [2, 4, 1, 0, 3]
5.66760302 [1 2 3 4 4] [1 2 3 5 5] [2 3 1 3 3] [2, 4, 1, 0, 3]
3.87311959 [1 2 3 4 4] [1 2 3 4 5] [2 4 1 3 3] [2, 4, 1, 0, 3]
1.50285113 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.205469772 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0740788057 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0350763313 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0185657572 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.0103952968 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
0.00600781338 [1 2 3 4 5] [1 2 3 4 5] [2 4 1 0 3] [2, 4, 1, 0, 3]
[[-2.18890929 -2.56413817 4.51624966 -2.87268877 -3.21818304]
 [-2.38402414 -2.36196756 -2.74159765 -1.29020119 4.7228055]
 [-2.63246131 5.06587076 -2.92895341 -2.58874464 -2.98058629]
 [3.2075038 -3.05285668 -3.33832383 -1.79198897 -2.805336]
 [-2.95142794 -2.45104289 -3.21477652 4.53860283 -2.06163]]


In [11]:
# Replay the walk with the soft (softmax) version of the trained logits.
data  = tf.constant([1, 3, 2, 5, 4], dtype=tf.float32)
P = tf.nn.softmax(nexts, axis=1)
iterate_over(data, P)

<tf.Tensor: shape=(5,), dtype=float32, numpy=
array([1.       , 2.0023386, 2.9993038, 3.992791 , 4.96837  ],
      dtype=float32)>

In [12]:
tf.print(data)
tf.print(tf.argmax(P, axis=1))

# For each row of P, print the rounded value it selects from `data`
# alongside the row itself.
data_column = tf.expand_dims(data, -1)
for i in range(5):
    row = tf.expand_dims(P[i], 0)
    x = tf.squeeze(tf.round(row @ data_column))
    tf.print(x, P[i])

[1 3 2 5 4]
[2 4 1 0 3]
2 [0.00122076692 0.000838828098 0.996888101 0.000616128382 0.00043613906]
4 [0.000815673964 0.000833864906 0.000570458942 0.00243533053 0.995344698]
3 [0.000452865468 0.998417616 0.000336669764 0.000473102235 0.000319727667]
1 [0.98762 0.00188690366 0.00141831581 0.0066579082 0.0024168333]
5 [0.000556805404 0.000918370206 0.000427890482 0.996741354 0.00135561905]


In [13]:
tf.print(tf.argmax(P, axis=1))

# Fifth power of the soft matrix, rounded to the nearest integer.
P_squared = P @ P
tf.round(P_squared @ P @ P @ P)

[2 4 1 0 3]


<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>

In [14]:
# Loss of the soft matrix with the cycle term enabled (P^5 compared to identity).
permute_matrix_loss(P, cycle_length=5, cycle_weight=1)

<tf.Tensor: shape=(), dtype=float32, numpy=0.0030093766>

In [15]:
# Harden P into a strict one-hot matrix and take its fifth power.
argmax_next = tf.argmax(P, axis=1)
DQ = tf.one_hot(argmax_next, depth=data_len)
DQ_squared = DQ @ DQ
DQ_squared @ DQ @ DQ @ DQ

<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>

In [16]:
# Same fifth-power check on the hand-written reference permutation.
Q = tf.one_hot([2, 4, 1, 0, 3], depth=5)
Q_squared = Q @ Q
Q_squared @ Q @ Q @ Q

<tf.Tensor: shape=(5, 5), dtype=float32, numpy=
array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)>