In [1]:
import tensorflow as tf

In [2]:
# Print the installed TensorFlow version so the recorded outputs can be tied to it.
print(tf.__version__)

2.17.0


## Creating Datasets from slices of tensor or ragged tensor

Here each element of the dataset is one slice of a tensor or ragged tensor, taken along the first dimension. For a ragged tensor, each slice (one variable-length row) is yielded as an ordinary tensor when the dataset is iterated.

In [17]:
# Five 1-D float32 tensors of increasing length (2 through 6 elements).
l = [
    tf.constant(values, dtype=tf.float32)
    for values in (
        [1, 2],
        [1, 2, 3],
        [1, 2, 3, 4],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5, 6],
    )
]

# Length of every tensor, in order; used later as the ragged row lengths.
len_list = list(map(len, l))

In [19]:
# Concatenate all rows into one flat tensor, then re-split it by row length
# so that every dataset element is one variable-length row.
d = tf.data.Dataset.from_tensor_slices(
    tf.RaggedTensor.from_row_lengths(
        values=tf.concat(l, axis=0),
        row_lengths=len_list,
    )
)

In [35]:
# Show the dataset repr; its element_spec is a RaggedTensorSpec.
d

<_TensorSliceDataset element_spec=RaggedTensorSpec(TensorShape([None]), tf.float32, 0, tf.int64)>

In [20]:
# Iterating the dataset yields each ragged row as a 1-D tensor.
for row in d:
    print(row)

tf.Tensor([1. 2.], shape=(2,), dtype=float32)
tf.Tensor([1. 2. 3.], shape=(3,), dtype=float32)
tf.Tensor([1. 2. 3. 4.], shape=(4,), dtype=float32)
tf.Tensor([1. 2. 3. 4. 5.], shape=(5,), dtype=float32)
tf.Tensor([1. 2. 3. 4. 5. 6.], shape=(6,), dtype=float32)


In [43]:
# Display the dataset again; the repr is unchanged after iterating over it.
d

<_TensorSliceDataset element_spec=RaggedTensorSpec(TensorShape([None]), tf.float32, 0, tf.int64)>

In [23]:
# Group consecutive rows into batches of up to two; each batch stays a RaggedTensor.
dbatch = d.ragged_batch(batch_size=2)
for ragged_rows in dbatch:
    print(ragged_rows)

<tf.RaggedTensor [[1.0, 2.0], [1.0, 2.0, 3.0]]>
<tf.RaggedTensor [[1.0, 2.0, 3.0, 4.0], [1.0, 2.0, 3.0, 4.0, 5.0]]>
<tf.RaggedTensor [[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]]>


In [32]:
# Six int32 rows: five of length 2 followed by one of length 3.
l2 = [
    tf.constant(values)
    for values in ([1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12, 13])
]

# Keep these row lengths under their own name: the cell that builds `d` reads
# `len_list`, so overwriting it here would break that cell when it is re-run
# after this one.
len_list2 = [len(elem) for elem in l2]

# Flatten the rows and re-split them by length, one ragged row per element.
d2 = tf.data.Dataset.from_tensor_slices(
    tf.RaggedTensor.from_row_lengths(tf.concat(l2, axis=0), len_list2)
)

In [42]:
# Plain batch() on ragged elements; the repr below shows the resulting element_spec.
d2batch = d2.batch(batch_size=2)
d2batch

<_BatchDataset element_spec=RaggedTensorSpec(TensorShape([None, None]), tf.int32, 1, tf.int64)>

In [46]:
# Convert every ragged batch to NumPy and left-pad the shorter rows with zeros.
for ragged_batch in d2batch:
    rows = ragged_batch.numpy()
    print(tf.keras.utils.pad_sequences(rows, padding='pre'))

[[1 2]
 [3 4]]
[[5 6]
 [7 8]]
[[ 0  9 10]
 [11 12 13]]


In [52]:
# Pull the first three batches out of the batched dataset by hand.
itr = iter(d2batch)
ret = [next(itr) for _ in range(3)]

In [53]:
# The third batch holds rows of different lengths (2 and 3).
for row in ret[2]:
    print(row)

tf.Tensor([ 9 10], shape=(2,), dtype=int32)
tf.Tensor([11 12 13], shape=(3,), dtype=int32)


## Constructing a padded batch dataset from the ragged batch dataset
Using the `tf.data.Dataset.from_generator()` method, we can create a dataset from a Python generator. Here the generator yields the ragged batches padded to a rectangular shape.

In [73]:
def padded_batch_generator(dataset, padding='pre'):
    """Build a zero-argument generator function that yields padded batches.

    The returned callable has the shape expected by
    ``tf.data.Dataset.from_generator``: calling it starts a fresh pass over
    ``dataset`` and yields one zero-padded array per batch.

    Args:
        dataset: Iterable of batches whose elements expose ``.numpy()``
            (here, a batched dataset of ragged tensors).
        padding: Passed through to ``tf.keras.utils.pad_sequences``;
            ``'pre'`` pads at the start of each row, ``'post'`` at the end.
            Defaults to ``'pre'``.

    Returns:
        A function taking no arguments that returns a generator of padded
        batches.
    """
    def generate_padded_batches():
        for ragged_batch in dataset:
            # .numpy() works here because the dataset is iterated from plain
            # Python rather than inside a traced Dataset.map function.
            yield tf.keras.utils.pad_sequences(ragged_batch.numpy(), padding=padding)
    return generate_padded_batches

# NOTE(review): pad_sequences printed integer arrays earlier, while this
# signature declares float32, so the batches come out as float32 — confirm
# that the cast is intended.
d2_padded_batch = tf.data.Dataset.from_generator(
    padded_batch_generator(d2batch),
    output_signature=tf.TensorSpec(shape=(None, None), dtype=tf.float32),
)

In [78]:
# Print every padded batch together with its position in the dataset.
for i, padded in enumerate(d2_padded_batch):
    print(f'Tensor {i} is:')
    print(padded)
    print('\n')

Tensor 0 is:
tf.Tensor(
[[1. 2.]
 [3. 4.]], shape=(2, 2), dtype=float32)


Tensor 1 is:
tf.Tensor(
[[5. 6.]
 [7. 8.]], shape=(2, 2), dtype=float32)


Tensor 2 is:
tf.Tensor(
[[ 0.  9. 10.]
 [11. 12. 13.]], shape=(2, 3), dtype=float32)




In [80]:
# Dataset.map traces its function in graph mode, where RaggedTensor.numpy()
# (and therefore tf.keras.utils.pad_sequences on its result) raises
# "RaggedTensor.numpy() is only supported in eager mode".
# Pre-pad with graph-compatible ops instead: reverse every row, right-pad with
# zeros via to_tensor(), then reverse back so the zeros end up on the left.
d2_padded_batch_2 = d2batch.map(
    lambda elem: tf.reverse(tf.reverse(elem, axis=[1]).to_tensor(), axis=[1])
)


In [81]:
# Six equal-length int32 rows; each row becomes one dataset element.
l3 = [
    tf.constant(pair)
    for pair in ([1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12])
]
d3 = tf.data.Dataset.from_tensor_slices(l3)

In [86]:
# Batch the whole dataset at once (batch size = number of elements) and
# extract that single batch as one tensor.
ret = (
    d3
    .batch(batch_size=d3.cardinality())
    .get_single_element()
)
print(ret)

tf.Tensor(
[[ 1  2]
 [ 3  4]
 [ 5  6]
 [ 7  8]
 [ 9 10]
 [11 12]], shape=(6, 2), dtype=int32)


In [None]:
# NOTE(review): unexecuted scratch cell; the layer instance is created but not
# assigned or used in the visible notebook — confirm whether it can be removed.
tf.keras.layers.StringLookup()