In [1]:
import numpy as np
import tensorflow as tf

class PolicyNetwork(tf.keras.Model):
    def __init__(self, n_hidden, n_output_1, n_output_2):
        super().__init__()
        self.n_hidden = n_hidden
        self.n_output_1 = n_output_1
        self.n_output_2 = n_output_2
        
        self.flatten = tf.keras.layers.Flatten()
        self.dense_1 = tf.keras.layers.Dense(self.n_hidden)
        self.dense_2 = tf.keras.layers.Dense(self.n_output_1)
        self.dense_3 = tf.keras.layers.Dense(self.n_output_2)
    
    def call(self, data_input):
        hidden = self.flatten(data_input)
        hidden = self.dense_1(hidden)
        hidden = tf.keras.activations.tanh(hidden)
        output_1 = self.dense_2(hidden)
        output_2 = self.dense_3(hidden)
        return output_1, output_2

class Agent:
    NEG_INF = -10000000000.0
    def __init__(self, env_grid_length):
        self.policy = PolicyNetwork(40, 5, 6)
        self.env_grid_length = env_grid_length
        
    def backward_action(self, current_position):
        encoded_position = tf.one_hot(
            current_position, depth=self.env_grid_length, axis=-1
        )
        action_logits = self.policy.predict(encoded_position)[0]
        
        back_coord_mask = tf.math.equal(current_position, 0)
        masked_logits = self._mask_action_logits(
            action_logits, back_coord_mask
        )
        
        action_indices = tf.random.categorical(masked_logits, 1)
        encoded_actions = tf.one_hot(
            tf.reshape(action_indices, shape=(-1,)), depth=5, dtype=tf.int32
        )
        
        is_at_origin = tf.math.reduce_all(
            back_coord_mask, axis=1, keepdims=True
        )
        action_mask = tf.cast(
            tf.math.logical_not(is_at_origin), 
            dtype=tf.int32
        )
        back_actions = encoded_actions*action_mask
        return back_actions
    
    def _mask_action_logits(self, action_logits, mask):
        avoid_inds = tf.where(mask)
        # Need validation that masked are not sampled
        masked_logits = tf.tensor_scatter_nd_add(
            action_logits, avoid_inds, 
            tf.constant([self.NEG_INF]*avoid_inds.shape[0])
        )
        return masked_logits
        
    def calculate_log_proba_ratios(self, trajectories, backward_actions, forward_actions):
        reshaped_positions = tf.reshape(trajectories, shape=(-1, 5))
        encoded_positions = tf.one_hot(
            reshaped_positions, 
            depth=self.env_grid_length, axis=-1
        )
        backward_logits, forward_logits = self.policy.predict(encoded_positions)
        
        back_coord_mask = tf.math.equal(reshaped_positions, 0)
        backward_log_probas = self._action_log_probas(
            backward_logits, back_coord_mask, backward_actions
        )
        
        forward_coord_mask = tf.math.equal(reshaped_positions, self.env_grid_length-1)
        forward_log_probas = self._action_log_probas(
            forward_logits, forward_coord_mask, forward_actions
        )
        log_proba_ratios = (
            tf.reduce_sum(forward_log_probas, axis=0)
            - tf.reduce_sum(backward_log_probas, axis=0)
        )
        return log_proba_ratios
    
    def _action_log_probas(self, logits, mask, actions):
        masked_logits = self._mask_action_logits(logits, mask)
        log_softmax = tf.nn.log_softmax(masked_logits)
        action_log_probas = tf.reduce_sum(
            (
                tf.reshape(log_softmax, shape=actions.shape)
                * tf.cast(actions, dtype=tf.float32)
            ), 
            axis=2, keepdims=True
        )
        return action_log_probas
        
    
def backward_step(current_position, back_action):
    new_position = current_position - back_action
    return new_position

In [2]:
batch_size = 2
grid_length = 3

current_position = tf.random.uniform(
    shape=(batch_size, 5), minval=0, maxval=grid_length, dtype=tf.int32
)

agent = Agent(grid_length)

trajectory = [current_position]
actions = []

sampling = tf.constant(True)

while sampling:
    action = agent.backward_action(current_position)
    current_position = backward_step(current_position, action)
    trajectory.append(current_position)
    actions.append(action)
    
    sampling = tf.math.reduce_any(tf.math.not_equal(current_position, 0))
    
trajectory = tf.stack(trajectory)
back_actions = tf.stack(
    actions + [tf.zeros(shape=action.shape, dtype=tf.int32)]
)
forward_actions = tf.stack(
    [tf.zeros(shape=action.shape, dtype=tf.int32)] + actions
)
terminal_action = tf.concat([
    tf.ones(shape=(1, forward_actions.shape[1], 1), dtype=tf.int32),
    tf.zeros(shape=(forward_actions.shape[0]-1, forward_actions.shape[1], 1), dtype=tf.int32)
], axis=0)
forward_actions = tf.concat([forward_actions, terminal_action], axis=2)

2022-11-25 16:16:30.706918: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  SSE4.1 SSE4.2 AVX AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.




In [3]:
log_proba_ratios = agent.calculate_log_proba_ratios(trajectory, back_actions, forward_actions)



In [4]:
log_proba_ratios

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[-6.537716],
       [-8.074623]], dtype=float32)>

In [8]:
trajectory

<tf.Tensor: shape=(8, 2, 5), dtype=int32, numpy=
array([[[0, 1, 2, 2, 1],
        [2, 0, 2, 1, 2]],

       [[0, 1, 2, 1, 1],
        [2, 0, 1, 1, 2]],

       [[0, 1, 2, 1, 0],
        [2, 0, 1, 1, 1]],

       [[0, 1, 1, 1, 0],
        [2, 0, 1, 1, 0]],

       [[0, 0, 1, 1, 0],
        [2, 0, 1, 0, 0]],

       [[0, 0, 0, 1, 0],
        [1, 0, 1, 0, 0]],

       [[0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 0],
        [0, 0, 0, 0, 0]]], dtype=int32)>

In [4]:
reshaped_positions = tf.reshape(trajectory, shape=(-1, 5))
encoded_positions = tf.one_hot(
    reshaped_positions, 
    depth=agent.env_grid_length, axis=-1
)
logits = agent.policy.predict(encoded_positions)[1]
logits



array([[ 0.36606342, -0.50462985,  0.89018464,  0.21219297, -0.3064791 ,
         0.51369846],
       [-0.06173432, -0.24918589,  0.33766556,  0.19875363,  0.05723034,
         0.36787444],
       [ 0.1182244 , -0.35777742,  0.4566329 ,  0.17104968, -0.29075134,
         0.6330337 ],
       [ 0.24353316,  0.15930244,  0.382346  ,  0.19979723,  0.59276474,
         0.16319683],
       [ 0.21388142, -0.36071864,  0.47404975,  0.21301773, -0.46088445,
         0.79795426],
       [ 0.03015062,  0.50287676,  0.16799748,  0.2246193 ,  0.7768584 ,
         0.5853298 ],
       [ 0.5263218 , -0.01107349,  0.50734246,  0.19541554,  0.11375482,
         0.6073983 ],
       [ 0.10943656,  0.36290202,  0.2654329 ,  0.29338276,  0.71696436,
         0.6251638 ],
       [ 0.31811303,  0.13062839,  0.4559734 ,  0.17071219,  0.1198009 ,
         0.46686274],
       [-0.02628805,  0.3948001 ,  0.32909453,  0.05131832,  0.6474758 ,
         0.7828636 ],
       [ 0.22249931,  0.49489203,  0.2441214 ,  0.

In [5]:
forward_coord_mask = tf.math.equal(reshaped_positions, agent.env_grid_length-1)
masked_logits = agent._mask_action_logits(logits, forward_coord_mask)
log_softmax = tf.nn.log_softmax(masked_logits)

In [10]:
action_log_probas = tf.reduce_sum(
    (
        tf.reshape(log_softmax, shape=forward_actions.shape)
        * tf.cast(forward_actions, dtype=tf.float32)
    ), 
    axis=2, keepdims=True
)
action_log_probas

<tf.Tensor: shape=(8, 2, 1), dtype=float32, numpy=
array([[[-0.9799405],
        [-0.8687525]],

       [[-1.5592407],
        [-1.234432 ]],

       [[-2.2567647],
        [-1.3098645]],

       [[-1.6344278],
        [-1.3623778]],

       [[-1.9488367],
        [-2.0312302]],

       [[-1.7943294],
        [-1.9450111]],

       [[-2.1282582],
        [-1.4863207]],

       [[ 0.       ],
        [-1.9877853]]], dtype=float32)>

In [9]:
forward_actions

<tf.Tensor: shape=(8, 2, 6), dtype=int32, numpy=
array([[[0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 1]],

       [[0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0]],

       [[0, 0, 0, 0, 1, 0],
        [0, 0, 0, 0, 1, 0]],

       [[0, 0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1, 0]],

       [[0, 1, 0, 0, 0, 0],
        [0, 0, 0, 1, 0, 0]],

       [[0, 0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0]],

       [[0, 0, 0, 1, 0, 0],
        [0, 0, 1, 0, 0, 0]],

       [[0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0]]], dtype=int32)>

In [10]:
chosen_log_softmax = tf.reduce_sum(log_softmax*tf.reshape(tf.cast(back_actions, dtype=tf.float32), shape=(-1, 5)), axis=1, keepdims=True)
chosen_log_softmax

<tf.Tensor: shape=(14, 1), dtype=float32, numpy=
array([[-0.91850126],
       [-1.6424545 ],
       [-1.4310194 ],
       [-1.9285947 ],
       [-0.9301102 ],
       [-1.8424559 ],
       [-1.0008314 ],
       [-0.5848421 ],
       [-0.44553822],
       [-0.7281079 ],
       [ 0.        ],
       [-0.23103859],
       [ 0.        ],
       [ 0.        ]], dtype=float32)>

In [11]:
tf.stack(tf.split(chosen_log_softmax, back_actions.shape[0]))

<tf.Tensor: shape=(7, 2, 1), dtype=float32, numpy=
array([[[-0.91850126],
        [-1.6424545 ]],

       [[-1.4310194 ],
        [-1.9285947 ]],

       [[-0.9301102 ],
        [-1.8424559 ]],

       [[-1.0008314 ],
        [-0.5848421 ]],

       [[-0.44553822],
        [-0.7281079 ]],

       [[ 0.        ],
        [-0.23103859]],

       [[ 0.        ],
        [ 0.        ]]], dtype=float32)>

In [28]:
tf.split(log_softmax, back_actions.shape[0])

[<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
 array([[-1.5279682e+00, -2.0923474e+00, -6.4403170e-01, -1.0000000e+10,
         -2.0065184e+00],
        [-1.0000000e+10, -9.5508361e-01, -1.0000000e+10, -1.0000000e+10,
         -4.8577532e-01]], dtype=float32)>,
 <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
 array([[-1.0000000e+10, -1.4237180e+00, -5.3001684e-01, -1.0000000e+10,
         -1.7685046e+00],
        [-1.0000000e+10, -9.4627082e-01, -1.0000000e+10, -1.0000000e+10,
         -4.9132681e-01]], dtype=float32)>,
 <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
 array([[-1.000000e+10, -1.232618e+00, -7.472019e-01, -1.000000e+10,
         -1.449100e+00],
        [-1.000000e+10, -1.000000e+10, -1.000000e+10, -1.000000e+10,
          0.000000e+00]], dtype=float32)>,
 <tf.Tensor: shape=(2, 5), dtype=float32, numpy=
 array([[-1.0000000e+10, -4.7735494e-01, -1.0000000e+10, -1.0000000e+10,
         -9.6869588e-01],
        [-1.6094379e+00, -1.6094379e+00, -1.6094379e+00, -1.6094379

In [4]:
# Remove if already at origin for back sampling

current_position = tf.concat([current_position, 
                              tf.constant([
                                  [2, 0, 0, 0, 0],
#                                   [0, 0, 0, 0, 0]
                              ])
                             ], axis=0)
current_position

<tf.Tensor: shape=(3, 5), dtype=int32, numpy=
array([[1, 3, 2, 6, 6],
       [5, 5, 3, 5, 3],
       [2, 0, 0, 0, 0]], dtype=int32)>

In [6]:
agent = Agent(grid_length)

trajectory = [current_position]
back_actions = []

for _ in range(4):
    back_action = agent.backward_action(current_position)
    current_position = backward_step(current_position, back_action)
    trajectory.append(current_position)
    back_actions.append(back_action)



In [7]:
trajectory

[<tf.Tensor: shape=(3, 5), dtype=int32, numpy=
 array([[1, 3, 2, 6, 6],
        [5, 5, 3, 5, 3],
        [2, 0, 0, 0, 0]], dtype=int32)>,
 <tf.Tensor: shape=(3, 5), dtype=int32, numpy=
 array([[0, 3, 2, 6, 6],
        [5, 4, 3, 5, 3],
        [1, 0, 0, 0, 0]], dtype=int32)>,
 <tf.Tensor: shape=(3, 5), dtype=int32, numpy=
 array([[0, 2, 2, 6, 6],
        [5, 4, 3, 5, 2],
        [0, 0, 0, 0, 0]], dtype=int32)>,
 <tf.Tensor: shape=(3, 5), dtype=int32, numpy=
 array([[0, 1, 2, 6, 6],
        [5, 4, 3, 4, 2],
        [0, 0, 0, 0, 0]], dtype=int32)>,
 <tf.Tensor: shape=(3, 5), dtype=int32, numpy=
 array([[0, 1, 2, 6, 5],
        [5, 4, 2, 4, 2],
        [0, 0, 0, 0, 0]], dtype=int32)>]

In [8]:
a = tf.stack(back_actions)
a

<tf.Tensor: shape=(4, 3, 5), dtype=int32, numpy=
array([[[1, 0, 0, 0, 0],
        [0, 1, 0, 0, 0],
        [1, 0, 0, 0, 0]],

       [[0, 1, 0, 0, 0],
        [0, 0, 0, 0, 1],
        [1, 0, 0, 0, 0]],

       [[0, 1, 0, 0, 0],
        [0, 0, 0, 1, 0],
        [0, 0, 0, 0, 0]],

       [[0, 0, 0, 0, 1],
        [0, 0, 1, 0, 0],
        [0, 0, 0, 0, 0]]], dtype=int32)>

In [14]:
mask = tf.cast(tf.ones(shape=a.shape, dtype=tf.int32)*tf.reduce_sum(a, axis=2, keepdims=True), dtype=tf.bool)

In [16]:
b = tf.ragged.boolean_mask(a, mask)

In [17]:
b

<tf.RaggedTensor [[[1, 0, 0, 0, 0], [0, 1, 0, 0, 0], [1, 0, 0, 0, 0]],
 [[0, 1, 0, 0, 0], [0, 0, 0, 0, 1], [1, 0, 0, 0, 0]],
 [[0, 1, 0, 0, 0], [0, 0, 0, 1, 0], []],
 [[0, 0, 0, 0, 1], [0, 0, 1, 0, 0], []]]>

In [9]:
new_position = backward_step(current_position, back_actions)
new_position

<tf.Tensor: shape=(11, 5), dtype=int32, numpy=
array([[5, 3, 2, 1, 1],
       [4, 5, 5, 4, 4],
       [1, 2, 2, 0, 5],
       [3, 2, 0, 6, 0],
       [3, 4, 5, 3, 3],
       [0, 1, 5, 2, 1],
       [2, 0, 4, 4, 0],
       [6, 6, 0, 0, 4],
       [1, 6, 1, 4, 6],
       [0, 3, 0, 5, 6],
       [1, 0, 0, 0, 0]], dtype=int32)>

In [10]:
back_actions = agent.backward_action(new_position)
back_actions

tf.Tensor(
[[0 0 0 0 1]
 [0 0 0 1 0]
 [0 1 0 0 0]
 [1 0 0 0 0]
 [0 0 0 1 0]
 [0 0 0 1 0]
 [0 0 1 0 0]
 [0 1 0 0 0]
 [0 1 0 0 0]
 [0 0 0 0 1]
 [1 0 0 0 0]], shape=(11, 5), dtype=int32)
tf.Tensor(
[[1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]], shape=(11, 1), dtype=int32)


<tf.Tensor: shape=(11, 5), dtype=int32, numpy=
array([[0, 0, 0, 0, 1],
       [0, 0, 0, 1, 0],
       [0, 1, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [0, 0, 0, 1, 0],
       [0, 0, 0, 1, 0],
       [0, 0, 1, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 0, 0, 1],
       [1, 0, 0, 0, 0]], dtype=int32)>

In [11]:
new_position = backward_step(new_position, back_actions)
new_position

<tf.Tensor: shape=(11, 5), dtype=int32, numpy=
array([[5, 3, 2, 1, 0],
       [4, 5, 5, 3, 4],
       [1, 1, 2, 0, 5],
       [2, 2, 0, 6, 0],
       [3, 4, 5, 2, 3],
       [0, 1, 5, 1, 1],
       [2, 0, 3, 4, 0],
       [6, 5, 0, 0, 4],
       [1, 5, 1, 4, 6],
       [0, 3, 0, 5, 5],
       [0, 0, 0, 0, 0]], dtype=int32)>

In [12]:
back_actions = agent.backward_action(new_position)
back_actions

tf.Tensor(
[[0 0 0 1 0]
 [0 0 0 1 0]
 [0 0 1 0 0]
 [1 0 0 0 0]
 [0 0 1 0 0]
 [0 0 0 1 0]
 [0 0 1 0 0]
 [0 1 0 0 0]
 [1 0 0 0 0]
 [0 1 0 0 0]
 [0 0 1 0 0]], shape=(11, 5), dtype=int32)
tf.Tensor(
[[1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]], shape=(11, 1), dtype=int32)


<tf.Tensor: shape=(11, 5), dtype=int32, numpy=
array([[0, 0, 0, 1, 0],
       [0, 0, 0, 1, 0],
       [0, 0, 1, 0, 0],
       [1, 0, 0, 0, 0],
       [0, 0, 1, 0, 0],
       [0, 0, 0, 1, 0],
       [0, 0, 1, 0, 0],
       [0, 1, 0, 0, 0],
       [1, 0, 0, 0, 0],
       [0, 1, 0, 0, 0],
       [0, 0, 0, 0, 0]], dtype=int32)>

In [13]:
new_position = backward_step(new_position, back_actions)
new_position

<tf.Tensor: shape=(11, 5), dtype=int32, numpy=
array([[5, 3, 2, 0, 0],
       [4, 5, 5, 2, 4],
       [1, 1, 1, 0, 5],
       [1, 2, 0, 6, 0],
       [3, 4, 4, 2, 3],
       [0, 1, 5, 0, 1],
       [2, 0, 2, 4, 0],
       [6, 4, 0, 0, 4],
       [0, 5, 1, 4, 6],
       [0, 2, 0, 5, 5],
       [0, 0, 0, 0, 0]], dtype=int32)>

In [14]:
back_actions = agent.backward_action(new_position)
back_actions

tf.Tensor(
[[0 0 1 0 0]
 [0 0 0 0 1]
 [0 0 1 0 0]
 [1 0 0 0 0]
 [0 0 0 1 0]
 [0 0 0 0 1]
 [0 0 0 1 0]
 [0 0 0 0 1]
 [0 1 0 0 0]
 [0 0 0 0 1]
 [0 1 0 0 0]], shape=(11, 5), dtype=int32)
tf.Tensor(
[[1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [1]
 [0]], shape=(11, 1), dtype=int32)


<tf.Tensor: shape=(11, 5), dtype=int32, numpy=
array([[0, 0, 1, 0, 0],
       [0, 0, 0, 0, 1],
       [0, 0, 1, 0, 0],
       [1, 0, 0, 0, 0],
       [0, 0, 0, 1, 0],
       [0, 0, 0, 0, 1],
       [0, 0, 0, 1, 0],
       [0, 0, 0, 0, 1],
       [0, 1, 0, 0, 0],
       [0, 0, 0, 0, 1],
       [0, 0, 0, 0, 0]], dtype=int32)>

In [15]:
new_position = backward_step(new_position, back_actions)
new_position

<tf.Tensor: shape=(11, 5), dtype=int32, numpy=
array([[5, 3, 1, 0, 0],
       [4, 5, 5, 2, 3],
       [1, 1, 0, 0, 5],
       [0, 2, 0, 6, 0],
       [3, 4, 4, 1, 3],
       [0, 1, 5, 0, 0],
       [2, 0, 2, 3, 0],
       [6, 4, 0, 0, 3],
       [0, 4, 1, 4, 6],
       [0, 2, 0, 5, 4],
       [0, 0, 0, 0, 0]], dtype=int32)>

In [10]:
encoded = tf.one_hot(current_position, depth=grid_length, axis=-1)
action_logits = policy.predict(encoded)
action_logits



array([[-6.4809972e-01,  5.3084195e-01,  2.4392086e-01, -5.3548342e-01,
        -4.3538299e-01],
       [-1.1343048e-01, -1.7665544e-01, -1.6544846e-01,  3.6007795e-01,
         3.9295644e-01],
       [ 3.7668648e-01,  3.0121180e-01,  1.9785663e-01,  3.6049068e-02,
        -1.0613856e+00],
       [-4.0208122e-01, -3.5049197e-01,  2.6646286e-01,  3.1395680e-01,
         1.4197114e-01],
       [ 3.0549276e-01, -3.1088769e-02,  2.2941329e-01, -8.6708754e-02,
        -6.2300704e-02],
       [-1.2447351e+00,  4.9847800e-01,  6.4200854e-01,  3.6725062e-01,
         4.3760002e-01],
       [-2.8469381e-01,  8.6022317e-02,  8.6110282e-01,  4.6994925e-02,
        -2.1885723e-02],
       [-1.2475082e-01, -4.0141165e-01,  2.9444158e-02, -1.9728884e-01,
         2.9799452e-01],
       [-5.1271170e-04,  2.3457837e-01, -5.1820165e-01, -4.6456134e-01,
        -1.1830505e+00],
       [ 2.8238037e-01,  1.9005036e-01, -2.9617256e-01,  2.3690698e-01,
        -4.1451451e-01]], dtype=float32)

# Mask logits to remove out of bounds actions

In [11]:
current_position = tf.concat([current_position, tf.constant([[0, 0, 0, 0, 0]])], axis=0)
current_position

<tf.Tensor: shape=(12, 5), dtype=int32, numpy=
array([[5, 5, 2, 2, 4],
       [1, 0, 1, 5, 5],
       [2, 6, 6, 2, 4],
       [6, 4, 2, 6, 6],
       [0, 2, 6, 1, 3],
       [3, 6, 5, 1, 4],
       [1, 3, 0, 1, 2],
       [6, 6, 0, 1, 1],
       [3, 5, 5, 2, 2],
       [2, 6, 6, 6, 5],
       [2, 0, 0, 0, 0],
       [0, 0, 0, 0, 0]], dtype=int32)>

In [12]:
# Back Sampling

back_coord_mask = tf.math.equal(current_position, 0)
avoid_inds = tf.where(back_coord_mask)

In [14]:
back_coord_mask

<tf.Tensor: shape=(12, 5), dtype=bool, numpy=
array([[False, False, False, False, False],
       [False,  True, False, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [ True, False, False, False, False],
       [False, False, False, False, False],
       [False, False,  True, False, False],
       [False, False,  True, False, False],
       [False, False, False, False, False],
       [False, False, False, False, False],
       [False,  True,  True,  True,  True],
       [ True,  True,  True,  True,  True]])>

In [20]:
tf.cast(tf.math.reduce_all(back_coord_mask, axis=1, keepdims=True), dtype=tf.int32)

<tf.Tensor: shape=(12, 1), dtype=int32, numpy=
array([[0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [0],
       [1]], dtype=int32)>

In [13]:
neg_inf = -10000000000.0
masked_logits = tf.tensor_scatter_nd_add(
    action_logits, avoid_inds, tf.constant([neg_inf]*avoid_inds.shape[0])
)
masked_logits

NameError: name 'action_logits' is not defined

In [27]:
actions = tf.random.categorical(masked_logits, 1)
actions

<tf.Tensor: shape=(10, 1), dtype=int64, numpy=
array([[2],
       [1],
       [3],
       [0],
       [0],
       [2],
       [1],
       [4],
       [0],
       [1]])>

In [34]:
encoded_actions = tf.one_hot(tf.reshape(actions, shape=(-1,)), depth=5)
encoded_actions

<tf.Tensor: shape=(10, 5), dtype=float32, numpy=
array([[0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 1., 0.],
       [1., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 0., 0., 1.],
       [1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.]], dtype=float32)>

In [26]:
# gradient test

x = tf.Variable(1.0)
a = 1
with tf.GradientTape() as tape:
    y = a*tf.math.log(x)
    
dy_dx = tape.gradient(y, x)
dy_dx

<tf.Tensor: shape=(), dtype=float32, numpy=1.0>