Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 12 additions & 14 deletions tensorflow/python/data/experimental/ops/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,14 @@ def __init__(self, name, dtype=None, ragged_rank=None, shape=None):
self._ragged_rank = ragged_rank
if shape:
shape = tensor_shape.TensorShape(shape)
shape_rank = 0
for _ in shape:
shape_rank += 1
if ragged_rank is not None and ragged_rank != shape_rank:
for d in shape:
if d.value is None:
raise ValueError(
f'Field {name} has incomplete shape: {shape}')
if ragged_rank is not None and ragged_rank > 1:
raise ValueError(
f'Field {name} is a nested list ({ragged_rank}) '
f'with shape {shape}')
self._ragged_rank = shape_rank
elif ragged_rank is not None:
shape = tensor_shape.TensorShape([None for _ in xrange(ragged_rank)])

self._shape = shape

@property
Expand Down Expand Up @@ -134,16 +131,17 @@ def output_classes(self):
def output_types(self):
return self.map(lambda i: self._dtype if i == 0 else dtypes.int32)

def output_shapes(self, batch_size=None):
@property
def output_shapes(self):
if self._shape is None:
return self.map(lambda i: tensor_shape.vector(batch_size) if i == 0
else tensor_shape.vector(None))
return self.map(lambda _: tensor_shape.vector(None))
return self.map(
lambda i: tensor_shape.vector(batch_size).concatenate(self._shape) if i == 0
lambda i: tensor_shape.vector(None).concatenate(self._shape) if i == 0
else tensor_shape.vector(None))

def output_specs(self, batch_size=None):
shape = tensor_shape.vector(batch_size)
@property
def output_specs(self):
shape = tensor_shape.vector(None)
if self._shape is not None:
shape = shape.concatenate(self._shape)
specs = [tensor_spec.TensorSpec(shape, dtype=self._dtype)]
Expand Down
28 changes: 16 additions & 12 deletions tensorflow/python/data/experimental/ops/parquet_dataset_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from tensorflow.python.data.ops import readers
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_spec
from tensorflow.python.framework import type_spec
from tensorflow.python.util import nest
Expand All @@ -38,25 +39,23 @@ class DataFrameValueSpec(type_spec.BatchableTypeSpec):
def value_type(self):
return DataFrame.Value if self._ragged_rank > 0 else ops.Tensor

def __init__(self, field, batch_size=None):
def __init__(self, field):
"""Constructs a type specification for a `tf.RaggedTensor`.

Args:
field: The field definition.
batch_size: The batch_size of DataFrame.
"""
if field.incomplete:
raise ValueError(
f'Field {field} is incomplete, please specify dtype and ragged_rank')
self._field = field
self._batch_size = batch_size

def _serialize(self):
return (self._field.dtype, self._field.ragged_rank)

@property
def _component_specs(self):
return self._field.output_specs(self._batch_size)
return self._field.output_specs

def _to_components(self, value):
if isinstance(value, DataFrame.Value):
Expand All @@ -80,7 +79,7 @@ def _to_legacy_output_types(self):
return self._field.output_types

def _to_legacy_output_shapes(self):
return self._field.output_shapes(self._batch_size)
return self._field.output_shapes

def _to_legacy_output_classes(self):
return self._field.output_classes
Expand Down Expand Up @@ -110,13 +109,18 @@ def __init__(
self._batch_size = ops.convert_to_tensor(
batch_size, dtype=dtypes.int64, name='batch_size')
self._fields = fields
self._output_specs = {
f.name: (
DataFrameValueSpec(f, batch_size if drop_remainder else None)
if f.ragged_rank > 0
else tensor_spec.TensorSpec(
shape=[batch_size if drop_remainder else None], dtype=f.dtype))
for f in self._fields}
self._output_specs = {}
for f in self._fields:
item = None
if f.ragged_rank > 0:
item = DataFrameValueSpec(f)
else:
shape = tensor_shape.vector(batch_size if drop_remainder else None)
if f.shape:
shape = shape.concatenate(f.shape)
item = tensor_spec.TensorSpec(shape=shape, dtype=f.dtype)
self._output_specs[f.name] = item

self._field_names = nest.flatten({f.name: f.name for f in self._fields})
self._field_dtypes = nest.flatten({f.name: f.dtype for f in self._fields})
self._field_ragged_ranks = nest.flatten(
Expand Down