In [60]:
import random
from typing import Tuple
import tensorflow as tf

# test_tensor = tf.constant([1,2,3,0,1,1,0,3,3,0,0,0], shape=(2,2,3), dtype="float32")
# print(test_tensor)
# indices = tf.constant([[0,0]])
# updates = tf.constant([[9,9,9]], dtype="float32")
# test_tensor = tf.tensor_scatter_nd_update(test_tensor,indices,updates)
# print(test_tensor)


def parse_example(raw):
    """Decode one serialized SequenceExample into its per-step tensors.

    Args:
        raw: A scalar string tensor holding a serialized
            ``tf.train.SequenceExample``.

    Returns:
        A ``(features, categories)`` pair taken from the example's feature
        lists: ``features`` holds 2048 float32 values per step and
        ``categories`` holds one int64 value per step.
    """
    sequence_spec = {
        "categories": tf.io.FixedLenSequenceFeature([], tf.int64),
        "features": tf.io.FixedLenSequenceFeature(2048, tf.float32),
    }
    parsed = tf.io.parse_single_sequence_example(
        raw, sequence_features=sequence_spec
    )
    # Index 1 is the feature-list dict; index 0 (context features) is unused.
    feature_lists = parsed[1]
    return feature_lists["features"], feature_lists["categories"]


def get_dataset(filenames):
    """Return a dataset of parsed ``(features, categories)`` pairs.

    Args:
        filenames: TFRecord file path(s) passed straight to
            ``tf.data.TFRecordDataset``.

    Returns:
        The record dataset with ``parse_example`` mapped over every record.
    """
    return tf.data.TFRecordDataset(filenames).map(parse_example)


def append_targets(features, categories, token_pos):
    """Pack the three inputs into a tuple and pair them with a target.

    Returns:
        ``(inputs, targets)`` where ``inputs`` is
        ``(features, categories, token_pos)`` and ``targets`` is the very
        same ``features`` object (no copy is made).
    """
    inputs = (features, categories, token_pos)
    targets = features
    return inputs, targets


def add_special_token_positions(features, categories):
    """Attach one randomly drawn sequence position to an example.

    The position is sampled with ``tf.random.uniform`` using ``minval=0`` and
    ``maxval`` equal to the length of ``categories`` along its first axis.

    Args:
        features: Passed through unchanged.
        categories: Tensor whose first dimension gives the sequence length.

    Returns:
        ``(features, categories, token_positions)`` where ``token_positions``
        is an int32 tensor of shape ``(1, 1)``.
    """
    num_steps = tf.shape(categories)[0]
    drawn = tf.random.uniform(
        (1,), minval=0, maxval=num_steps, dtype="int32"
    )
    # Add a leading axis so the result has shape (1, 1).
    return features, categories, tf.expand_dims(drawn, 0)
    
    
def get_training_dataset(filenames, batch_size):
    """Build the batched training dataset of ``(inputs, targets)`` pairs.

    Pipeline: parse the records, attach a random token position to every
    example, pad and batch, then pair each batch with its targets via
    ``append_targets``.

    Args:
        filenames: TFRecord file path(s) to read.
        batch_size: Number of examples per padded batch.

    Returns:
        A dataset yielding ``((features, categories, token_pos), features)``.
    """
    # Per-example shapes to pad to: features, categories, token positions.
    padded_shapes = ([None, 2048], [None], [None, 1])
    return (
        get_dataset(filenames)
        .map(add_special_token_positions)
        .padded_batch(batch_size, padded_shapes)
        .map(append_targets)
    )



# Scratch driver: for every batch, overwrite the feature row at each example's
# sampled token position with zeros and print the intermediate tensors.
dataset = get_training_dataset(["output-000-5.tfrecord"], 2)
for x, y in dataset:
    # x is the (features, categories, token_positions) tuple built by
    # append_targets; y is the untouched features tensor (unused here).
    zeroes = tf.zeros_like(x[0])
    # One all-zero row per example, gathered at that example's position.
    zero_slices = tf.gather_nd(zeroes, x[2], batch_dims=1)
    updates = tf.reshape(zero_slices, shape=(-1, 2048))
    print(x[2])
    # Batch indices 0..B-1, shaped (B, 1, 1) so they line up with x[2].
    # Using -1 instead of r.shape[0] avoids relying on a static batch size.
    r = tf.range(0, limit=tf.shape(x[2])[0], dtype="int32")
    r = tf.reshape(r, shape=[-1, 1, 1])
    print(r)
    # Pair every batch index with its token position -> (B, 2) rows such as
    # [[0, 4], [1, 6]].  An explicit reshape is used instead of a bare
    # tf.squeeze: squeeze without an axis would also drop the batch dimension
    # when a batch holds a single example, yielding shape (2,) and breaking
    # the scatter update below.
    indices = tf.reshape(tf.concat([r, x[2]], axis=-1), shape=(-1, 2))
    print(indices)
    replaced = tf.tensor_scatter_nd_update(x[0], indices, updates)
    print(replaced)
    
    

tf.Tensor(
[[[4]]

 [[6]]], shape=(2, 1, 1), dtype=int32)
tf.Tensor(
[[[0]]

 [[1]]], shape=(2, 1, 1), dtype=int32)
tf.Tensor(
[[0 4]
 [1 6]], shape=(2, 2), dtype=int32)
tf.Tensor(
[[[0.00544265 0.2782234  0.01296203 ... 0.03073396 0.02102465 0.60515   ]
  [0.10327241 0.59317654 0.23712789 ... 0.5107244  0.32906044 0.2647133 ]
  [0.04759569 0.24221316 0.03452357 ... 0.4702158  0.04748438 0.21493322]
  ...
  [0.80003095 0.6951431  0.6862365  ... 0.28535375 0.02023683 0.32046562]
  [1.0514071  0.         0.         ... 0.04131878 0.7136081  2.0880868 ]
  [0.         0.         0.         ... 0.         0.         0.        ]]

 [[0.10257128 0.05374552 0.2554618  ... 0.4756894  0.22387642 1.6893127 ]
  [0.05468541 0.03893444 0.03038425 ... 0.1826662  0.13914062 0.36530462]
  [0.02581285 0.28361136 0.4748394  ... 0.407583   0.03706785 0.02978181]
  ...
  [0.42446607 0.17985591 0.07060237 ... 0.19590387 0.10340778 1.5159997 ]
  [0.         0.         0.         ... 0.         0.         0. 