<a href="https://colab.research.google.com/github/Muzhi1920/awesome-models/blob/main/03-Loss%E4%B8%8E%E4%BC%98%E5%8C%96/NerualSort.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import tensorflow as tf

In [2]:
# inputs: [batch_size, seq_len]
# return: M=[batch_size, seq_len, seq_len]

SEQUENCE_LEN = 6

def neural_sort(inputs, realshow_mask=None):
  # preds 预先处理非realshow被mask
  inputs = inputs[:,:SEQUENCE_LEN] * realshow_mask

  A1 = tf.tile(tf.expand_dims(inputs, 2), [1, 1, SEQUENCE_LEN])
  A2 = tf.tile(tf.expand_dims(inputs, 1), [1, SEQUENCE_LEN, 1])
  A = tf.abs(A1-A2)
  A = tf.transpose(tf.reduce_sum(A, axis=2, keep_dims=True), [0, 2, 1])

  ck_list = []
  tempreture = 1.0
  for i in range(SEQUENCE_LEN):
    ck_list.append( (SEQUENCE_LEN + 1 - 2*i) * tf.expand_dims(inputs, 1) - A)
  ck = tf.concat(ck_list, axis=1) / tempreture

  # realshow mask to control loss on realshow rewards
  ck_masked = ck
  if realshow_mask != None:
    realshow_mask = tf.tile(tf.expand_dims(realshow_mask, 1), [1, SEQUENCE_LEN, 1])
    infinity = tf.fill(tf.shape(realshow_mask), -1e9)
    ck_masked = tf.where(realshow_mask > 0.000001, ck, infinity)
  ck = tf.nn.softmax(ck, axis=2)
  ck_masked = tf.nn.softmax(ck_masked, axis=2)
  return ck, ck_masked

In [3]:
input_emb = tf.random.normal([4, 6])
realshow_mask = tf.concat([tf.ones([4, 5]), tf.zeros([4,1])], axis=-1)
input_emb, realshow_mask

(<tf.Tensor: shape=(4, 6), dtype=float32, numpy=
 array([[-0.43738857,  0.94122016,  0.8991876 ,  0.926741  , -0.7687766 ,
         -0.03910059],
        [ 1.1879306 ,  1.1556386 ,  0.09810921, -0.19238792, -1.1775075 ,
          0.6740587 ],
        [-0.08911345, -1.0718447 , -0.40152082,  1.7250129 , -1.1861783 ,
          1.1754292 ],
        [ 1.4859723 ,  1.0844629 ,  0.21600237, -0.12503721, -0.44688466,
         -1.3961648 ]], dtype=float32)>,
 <tf.Tensor: shape=(4, 6), dtype=float32, numpy=
 array([[1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 0.]], dtype=float32)>)

In [4]:
inputs = input_emb
A1 = tf.tile(tf.expand_dims(inputs, 2), [1, 1, SEQUENCE_LEN])
A1

<tf.Tensor: shape=(4, 6, 6), dtype=float32, numpy=
array([[[-0.43738857, -0.43738857, -0.43738857, -0.43738857,
         -0.43738857, -0.43738857],
        [ 0.94122016,  0.94122016,  0.94122016,  0.94122016,
          0.94122016,  0.94122016],
        [ 0.8991876 ,  0.8991876 ,  0.8991876 ,  0.8991876 ,
          0.8991876 ,  0.8991876 ],
        [ 0.926741  ,  0.926741  ,  0.926741  ,  0.926741  ,
          0.926741  ,  0.926741  ],
        [-0.7687766 , -0.7687766 , -0.7687766 , -0.7687766 ,
         -0.7687766 , -0.7687766 ],
        [-0.03910059, -0.03910059, -0.03910059, -0.03910059,
         -0.03910059, -0.03910059]],

       [[ 1.1879306 ,  1.1879306 ,  1.1879306 ,  1.1879306 ,
          1.1879306 ,  1.1879306 ],
        [ 1.1556386 ,  1.1556386 ,  1.1556386 ,  1.1556386 ,
          1.1556386 ,  1.1556386 ],
        [ 0.09810921,  0.09810921,  0.09810921,  0.09810921,
          0.09810921,  0.09810921],
        [-0.19238792, -0.19238792, -0.19238792, -0.19238792,
         -0.1

In [5]:
A2 = tf.tile(tf.expand_dims(inputs, 1), [1, SEQUENCE_LEN, 1])
A2

<tf.Tensor: shape=(4, 6, 6), dtype=float32, numpy=
array([[[-0.43738857,  0.94122016,  0.8991876 ,  0.926741  ,
         -0.7687766 , -0.03910059],
        [-0.43738857,  0.94122016,  0.8991876 ,  0.926741  ,
         -0.7687766 , -0.03910059],
        [-0.43738857,  0.94122016,  0.8991876 ,  0.926741  ,
         -0.7687766 , -0.03910059],
        [-0.43738857,  0.94122016,  0.8991876 ,  0.926741  ,
         -0.7687766 , -0.03910059],
        [-0.43738857,  0.94122016,  0.8991876 ,  0.926741  ,
         -0.7687766 , -0.03910059],
        [-0.43738857,  0.94122016,  0.8991876 ,  0.926741  ,
         -0.7687766 , -0.03910059]],

       [[ 1.1879306 ,  1.1556386 ,  0.09810921, -0.19238792,
         -1.1775075 ,  0.6740587 ],
        [ 1.1879306 ,  1.1556386 ,  0.09810921, -0.19238792,
         -1.1775075 ,  0.6740587 ],
        [ 1.1879306 ,  1.1556386 ,  0.09810921, -0.19238792,
         -1.1775075 ,  0.6740587 ],
        [ 1.1879306 ,  1.1556386 ,  0.09810921, -0.19238792,
         -1.1

In [6]:
A = tf.abs(A1-A2)
A

<tf.Tensor: shape=(4, 6, 6), dtype=float32, numpy=
array([[[0.        , 1.3786087 , 1.3365762 , 1.3641295 , 0.33138803,
         0.39828798],
        [1.3786087 , 0.        , 0.04203254, 0.01447916, 1.7099967 ,
         0.98032075],
        [1.3365762 , 0.04203254, 0.        , 0.02755338, 1.6679642 ,
         0.9382882 ],
        [1.3641295 , 0.01447916, 0.02755338, 0.        , 1.6955175 ,
         0.9658416 ],
        [0.33138803, 1.7099967 , 1.6679642 , 1.6955175 , 0.        ,
         0.729676  ],
        [0.39828798, 0.98032075, 0.9382882 , 0.9658416 , 0.729676  ,
         0.        ]],

       [[0.        , 0.03229201, 1.0898213 , 1.3803185 , 2.365438  ,
         0.5138719 ],
        [0.03229201, 0.        , 1.0575293 , 1.3480265 , 2.333146  ,
         0.4815799 ],
        [1.0898213 , 1.0575293 , 0.        , 0.29049712, 1.2756168 ,
         0.5759495 ],
        [1.3803185 , 1.3480265 , 0.29049712, 0.        , 0.9851196 ,
         0.8664466 ],
        [2.365438  , 2.333146  , 1.27

In [7]:
A = tf.transpose(tf.compat.v1.reduce_sum(A, axis=2, keep_dims=True), [0, 2, 1])
A

<tf.Tensor: shape=(4, 1, 6), dtype=float32, numpy=
array([[[ 4.80899  ,  4.1254377,  4.0124145,  4.067521 ,  6.1345425,
          4.0124145]],

       [[ 5.3817415,  5.252574 ,  4.2894144,  4.870408 ,  8.810886 ,
          4.2894144]],

       [[ 5.4708724,  6.81152  ,  5.4708724, 10.198292 ,  7.2688546,
          7.999958 ]],

       [[ 8.097483 ,  6.491445 ,  4.754524 ,  4.754524 ,  5.398219 ,
          9.195339 ]]], dtype=float32)>

In [8]:

ck_list = []
tempreture = 1.0
for i in range(SEQUENCE_LEN):
  ck_list.append((SEQUENCE_LEN + 1 - 2*i) * tf.expand_dims(inputs, 1) - A)
ck = tf.concat(ck_list, axis=1) / tempreture
ck

<tf.Tensor: shape=(4, 6, 6), dtype=float32, numpy=
array([[[ -7.87071   ,   2.4631033 ,   2.281899  ,   2.4196658 ,
         -11.515979  ,  -4.2861185 ],
        [ -6.9959326 ,   0.5806632 ,   0.48352385,   0.56618404,
          -9.978426  ,  -4.207917  ],
        [ -6.1211557 ,  -1.3017774 ,  -1.3148515 ,  -1.2872982 ,
          -8.440872  ,  -4.1297164 ],
        [ -5.2463784 ,  -3.1842175 ,  -3.113227  ,  -3.14078   ,
          -6.903319  ,  -4.051515  ],
        [ -4.3716016 ,  -5.066658  ,  -4.911602  ,  -4.994262  ,
          -5.365766  ,  -3.9733138 ],
        [ -3.4968243 ,  -6.949098  ,  -6.709977  ,  -6.847744  ,
          -3.8282127 ,  -3.8951128 ]],

       [[  2.933772  ,   2.836896  ,  -3.60265   ,  -6.2171235 ,
         -17.05344   ,   0.4289961 ],
        [  0.5579114 ,   0.52561903,  -3.7988684 ,  -5.832348  ,
         -14.698423  ,  -0.919121  ],
        [ -1.8179498 ,  -1.7856584 ,  -3.9950867 ,  -5.4475718 ,
         -12.343409  ,  -2.2672384 ],
        [ -4.193811 

In [9]:
# realshow mask to control loss on realshow rewards
ck_masked = ck
if realshow_mask != None:
  mask = tf.tile(tf.expand_dims(realshow_mask, 1), [1, SEQUENCE_LEN, 1])
  infinity = tf.fill(tf.shape(mask), -1e9)
  ck_masked = tf.where(mask > 0.000001, ck, infinity)
ck_masked

<tf.Tensor: shape=(4, 6, 6), dtype=float32, numpy=
array([[[-7.87070990e+00,  2.46310329e+00,  2.28189898e+00,
          2.41966581e+00, -1.15159788e+01, -1.00000000e+09],
        [-6.99593258e+00,  5.80663204e-01,  4.83523846e-01,
          5.66184044e-01, -9.97842598e+00, -1.00000000e+09],
        [-6.12115574e+00, -1.30177736e+00, -1.31485152e+00,
         -1.28729820e+00, -8.44087219e+00, -1.00000000e+09],
        [-5.24637842e+00, -3.18421745e+00, -3.11322689e+00,
         -3.14077997e+00, -6.90331888e+00, -1.00000000e+09],
        [-4.37160158e+00, -5.06665802e+00, -4.91160202e+00,
         -4.99426222e+00, -5.36576605e+00, -1.00000000e+09],
        [-3.49682426e+00, -6.94909811e+00, -6.70997715e+00,
         -6.84774399e+00, -3.82821274e+00, -1.00000000e+09]],

       [[ 2.93377209e+00,  2.83689594e+00, -3.60264993e+00,
         -6.21712351e+00, -1.70534401e+01, -1.00000000e+09],
        [ 5.57911396e-01,  5.25619030e-01, -3.79886842e+00,
         -5.83234787e+00, -1.46984234e+0

In [10]:
ck = tf.nn.softmax(ck, axis=2)
ck

<tf.Tensor: shape=(4, 6, 6), dtype=float32, numpy=
array([[[1.16417186e-05, 3.58042777e-01, 2.98702508e-01, 3.42823237e-01,
         3.04016169e-07, 4.19551332e-04],
        [1.76539223e-04, 3.44599724e-01, 3.12699974e-01, 3.39646161e-01,
         8.94459299e-06, 2.86853989e-03],
        [2.62959884e-03, 3.25775862e-01, 3.21544319e-01, 3.30527127e-01,
         2.58493354e-04, 1.92646254e-02],
        [3.44706178e-02, 2.71040499e-01, 2.90981233e-01, 2.83073246e-01,
         6.57429174e-03, 1.13860108e-01],
        [2.23333225e-01, 1.11453615e-01, 1.30146980e-01, 1.19821630e-01,
         8.26405436e-02, 3.32604021e-01],
        [4.00584906e-01, 1.26879402e-02, 1.61153600e-02, 1.40413428e-02,
         2.87590414e-01, 2.68980086e-01]],

       [[5.02280831e-01, 4.55904454e-01, 7.28139654e-04, 5.33043858e-05,
         1.04860098e-09, 4.10332792e-02],
        [4.52276736e-01, 4.37904954e-01, 5.79800690e-03, 7.58839480e-04,
         1.07068672e-07, 1.03261210e-01],
        [3.55768204e-01, 3.

In [11]:
ck_masked = tf.nn.softmax(ck_masked, axis=2)
ck_masked

<tf.Tensor: shape=(4, 6, 6), dtype=float32, numpy=
array([[[1.16466044e-05, 3.58193040e-01, 2.98827857e-01, 3.42967123e-01,
         3.04143754e-07, 0.00000000e+00],
        [1.77047084e-04, 3.45591068e-01, 3.13599527e-01, 3.40623260e-01,
         8.97032442e-06, 0.00000000e+00],
        [2.68125231e-03, 3.32175106e-01, 3.27860445e-01, 3.37019712e-01,
         2.63570982e-04, 0.00000000e+00],
        [3.88997458e-02, 3.05866480e-01, 3.28369409e-01, 3.19445312e-01,
         7.41902227e-03, 0.00000000e+00],
        [3.34633738e-01, 1.66997731e-01, 1.95007145e-01, 1.79536045e-01,
         1.23825356e-01, 0.00000000e+00],
        [5.47980845e-01, 1.73564907e-02, 2.20450368e-02, 1.92078799e-02,
         3.93409818e-01, 0.00000000e+00]],

       [[5.23772895e-01, 4.75412101e-01, 7.59296003e-04, 5.55852239e-05,
         1.09346954e-09, 0.00000000e+00],
        [5.04357338e-01, 4.88330603e-01, 6.46565948e-03, 8.46221403e-04,
         1.19397853e-07, 0.00000000e+00],
        [4.60248977e-01, 4.

In [12]:
mock_reward = [[6.0,4,5,3,2,1],
               [7,6,4,5,3,2],
               [9,6,4,8,3,1],
               [7,8,5,6,1,2]]
rewards = tf.constant(mock_reward)
rewards = tf.expand_dims(rewards, 2)
rewards

<tf.Tensor: shape=(4, 6, 1), dtype=float32, numpy=
array([[[6.],
        [4.],
        [5.],
        [3.],
        [2.],
        [1.]],

       [[7.],
        [6.],
        [4.],
        [5.],
        [3.],
        [2.]],

       [[9.],
        [6.],
        [4.],
        [8.],
        [3.],
        [1.]],

       [[7.],
        [8.],
        [5.],
        [6.],
        [1.],
        [2.]]], dtype=float32)>

In [13]:
result = tf.matmul(ck_masked, rewards)
result

<tf.Tensor: shape=(4, 6, 1), dtype=float32, numpy=
array([[[3.9558835],
        [3.9733117],
        [3.9956768],
        [4.0718856],
        [4.437088 ],
        [4.31198  ]],

       [[6.5221977],
        [6.490579 ],
        [6.3436513],
        [5.593516 ],
        [4.645846 ],
        [4.316944 ]],

       [[8.000191 ],
        [8.001968 ],
        [7.8191066],
        [6.762853 ],
        [5.6579323],
        [4.8103127]],

       [[7.2235665],
        [7.321497 ],
        [6.9246383],
        [5.450488 ],
        [4.307716 ],
        [3.3586226]]], dtype=float32)>

In [14]:
allposition_reward = tf.squeeze(result)
allposition_reward

<tf.Tensor: shape=(4, 6), dtype=float32, numpy=
array([[3.9558835, 3.9733117, 3.9956768, 4.0718856, 4.437088 , 4.31198  ],
       [6.5221977, 6.490579 , 6.3436513, 5.593516 , 4.645846 , 4.316944 ],
       [8.000191 , 8.001968 , 7.8191066, 6.762853 , 5.6579323, 4.8103127],
       [7.2235665, 7.321497 , 6.9246383, 5.450488 , 4.307716 , 3.3586226]],
      dtype=float32)>

In [15]:
target_loss = - tf.reduce_sum(allposition_reward * realshow_mask, 1)
target_loss

<tf.Tensor: shape=(4,), dtype=float32, numpy=array([-20.433846, -29.595789, -36.24205 , -31.227905], dtype=float32)>