In [93]:
import random
from typing import Tuple
import tensorflow as tf

# test_tensor = tf.constant([1,2,3,0,1,1,0,3,3,0,0,0], shape=(2,2,3), dtype="float32")
# print(test_tensor)
# indices = tf.constant([[0,0]])
# updates = tf.constant([[9,9,9]], dtype="float32")
# test_tensor = tf.tensor_scatter_nd_update(test_tensor,indices,updates)
# print(test_tensor)


def parse_example(raw):
    example = tf.io.parse_single_sequence_example(
          raw, sequence_features={
              "categories": tf.io.FixedLenSequenceFeature([], tf.int64),
              "features": tf.io.FixedLenSequenceFeature(2048, tf.float32)
          })
    return example[1]["features"], example[1]["categories"]


def get_dataset(filenames):
    raw_dataset = tf.data.TFRecordDataset(filenames)
    return raw_dataset.map(parse_example)


def append_targets(features, categories, token_pos):
    return (features, categories, token_pos), features


def add_special_token_positions(features, categories):
    seq_length = tf.shape(categories)[0]
    random_position = tf.random.uniform((1,), minval=0, maxval=seq_length, dtype="int32")
    token_positions = tf.expand_dims(random_position, 0)
    return features, categories, token_positions
    
    
def get_training_dataset(filenames, batch_size):
    outfits = get_dataset(filenames)
    outfits = outfits.map(add_special_token_positions)
    outfits = outfits.padded_batch(batch_size, ([None, 2048], [None], [None,1]))
    return outfits.map(append_targets)


def replace_slices(batch, batched_positions, updates):
    """

    Args:
        batch: Tensor of shape [batch_size, seq_length, feature_dim]
        batched_positions: Tensor of shape [batch_size, None, 1] corresponding to indices of sequence
        updates: Tensor of shape [total_number_of_indices, feature_dim]

    Returns:
        Tensor with the same shape as batch. Defined slices are replaced
    """
    batch_size = tf.shape(batched_positions)[0]
    r = tf.range(0, limit=batch_size, dtype="int32")
    r = tf.reshape(r, shape=[tf.shape(r)[0], -1, 1])
    indices = tf.squeeze(tf.concat([r, batched_positions], axis=-1))
    print(indices)
    return tf.tensor_scatter_nd_update(batch, indices, updates)


batch = tf.range(0, 24)
batch = tf.reshape(batch, [2, 3, 4])
b_pos = tf.constant([[[0]], [[2]]])
up = tf.constant([[8, 8, 8, 8], [9, 9, 9, 9]])
print(batch)
res = replace_slices(batch, b_pos, up)
print(res)


tf.Tensor(
[[[ 0  1  2  3]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [20 21 22 23]]], shape=(2, 3, 4), dtype=int32)
tf.Tensor(
[[0 0]
 [1 2]], shape=(2, 2), dtype=int32)
tf.Tensor(
[[[ 8  8  8  8]
  [ 4  5  6  7]
  [ 8  9 10 11]]

 [[12 13 14 15]
  [16 17 18 19]
  [ 9  9  9  9]]], shape=(2, 3, 4), dtype=int32)


In [None]:
dataset = get_training_dataset(["output-000-5.tfrecord"], 2)

# for x, y in dataset:
#     zeroes = tf.meshgrid(tf.range(tf.shape(x[0])[0]), tf.range(tf.shape(x[0])[0]))
#     mask_positions = x[2]
#     zero_slices = tf.gather_nd(zeroes, x[2], batch_dims=1)
#     updates = tf.reshape(zero_slices, shape=(-1, 2048))
#     print(x[2])
#     r = tf.range(0, limit=tf.shape(x[2])[0], dtype="int32")
#     r = tf.reshape(r, shape=[r.shape[0], -1, 1])
#     print(r)
#     indices = tf.squeeze(tf.concat([r, x[2]], axis=-1))
#     print(indices)
#     # indices = tf.constant([[0,1], [1, 6]]) 
#     replaced = tf.tensor_scatter_nd_update(x[0], indices, updates)
#     print(replaced)
# 
#     r = tf.range(0, limit=tf.shape(mask_positions)[0], dtype="int64")
#     r = tf.reshape(r, shape=[tf.shape(r)[0], -1, 1])
#     mask_positions = tf.cast(mask_positions, dtype="int64")
#     indices = tf.squeeze(tf.concat([r, mask_positions], axis=-1))
#     weights = tf.sparse.to_dense(tf.SparseTensor(indices, 1, dense_shape=tf.cast(tf.shape(y), dtype="int64")))
#     weights = tf.reshape(weights, [-1])
#     print(weights)
    
    
    # r = tf.range(0, limit=tf.shape(mask_positions)[0])
    # r = tf.reshape(r, shape=[tf.shape(r)[0], -1, 1])
    # indices = tf.squeeze(tf.concat([r, mask_positions], axis=-1))
    # print(indices)
    # updates = tf.ones(shape=(tf.shape(mask_positions)[0]))
    # print(updates)
    # weights = tf.scatter_nd(indices, updates, tf.shape(x[1]))
    # weights = tf.cast(weights, dtype="float32")
    # weights = tf.reshape(weights, [-1])
    # length = tf.shape(weights)[0]
    # weights = tf.tile(weights, [length])
    # weights = tf.reshape(weights, [-1, length])
    # weights = tf.transpose(weights)
    # print(weights)
    