In [7]:
import tensorflow as tf

In [8]:
shuffle = True #called when running TFRecordBatchDiagram on split==C.DATA_TRAIN
seed = 1234 #Dataset class
num_parallel_calls = 4
batch_size = 128

## filtering options
min_length_threshold = 4 #default value defined in TFRecordStroke that is inherited by TFRecordDiagram
max_length_threshold = 201 #defined and default value in config.data
num_strokes_threshold = 4 #minimum number of strokes, default value defined in TFRecordStroke that is inherited by TFRecordDiagram

#RUN DOES NOT INCLUDE RDP, NOR RDP_DIDI_PP
normalize = True

In [9]:
#tf_data_transformations

# Transformation

In [10]:
tf_data_ = tf.data.TFRecordDataset.list_files(['/data/jcabrera/didi_wo_text/training/diagrams_wo_text_20200131-?????-of-?????'], seed = 1234, shuffle = True)

In [11]:
tf_data_.element_spec

TensorSpec(shape=(), dtype=tf.string, name=None)

In [12]:
tf_data_ = tf_data_.interleave(tf.data.TFRecordDataset, cycle_length = num_parallel_calls, block_length = 1)

In [13]:
tf_data_.element_spec

TensorSpec(shape=(), dtype=tf.string, name=None)

In [14]:
import functools

In [15]:
  def parse_tfexample_fn(proto, rdp = True):
    """Parses a single tfrecord proto storing diagram sequence as strokes.
    Args:
      proto:
    Returns:
    """

    if rdp:
      feature_to_type = dict()
      feature_to_type["rdp_ink"] = tf.io.VarLenFeature(dtype=tf.float32)
      feature_to_type["rdp_stroke_length"] = tf.io.VarLenFeature(dtype=tf.int64)
      feature_to_type["rdp_num_strokes"] = tf.io.FixedLenFeature([], dtype=tf.int64)
      
      parsed_features = tf.io.parse_single_example(serialized=proto, features=feature_to_type)
      parsed_features["ink"] = tf.reshape(tf.sparse.to_dense(parsed_features["rdp_ink"]), (parsed_features["rdp_num_strokes"], -1, 4))
      parsed_features["stroke_length"] = tf.sparse.to_dense(parsed_features["rdp_stroke_length"])
      parsed_features["num_strokes"] = tf.tile(tf.expand_dims(parsed_features["rdp_num_strokes"], axis=0), [parsed_features["rdp_num_strokes"]])
    else:
      feature_to_type = {
          "ink": tf.io.VarLenFeature(dtype=tf.float32),
          "stroke_length": tf.io.VarLenFeature(dtype=tf.int64),
          "num_strokes": tf.io.FixedLenFeature([], dtype=tf.int64),
          # "shape": tf.FixedLenFeature([3], dtype=tf.int64),
          # "ink_hash": tf.FixedLenFeature([], dtype=tf.string),
      }
      parsed_features = tf.io.parse_single_example(serialized=proto, features=feature_to_type)
      parsed_features["ink"] = tf.reshape(tf.sparse.to_dense(parsed_features["ink"]), (parsed_features["num_strokes"], -1, 4))
      parsed_features["stroke_length"] = tf.sparse.to_dense(parsed_features["stroke_length"])
      parsed_features["num_strokes"] = tf.tile(tf.expand_dims(parsed_features["num_strokes"], axis=0), [parsed_features["num_strokes"]])
    
    return parsed_features

In [16]:
tf_data_ = tf_data_.map(
    functools.partial(parse_tfexample_fn),
    num_parallel_calls=num_parallel_calls)

In [17]:
tf_data_.element_spec

{'rdp_ink': SparseTensorSpec(TensorShape([None]), tf.float32),
 'rdp_stroke_length': SparseTensorSpec(TensorShape([None]), tf.int64),
 'rdp_num_strokes': TensorSpec(shape=(), dtype=tf.int64, name=None),
 'ink': TensorSpec(shape=(None, None, 4), dtype=tf.float32, name=None),
 'stroke_length': TensorSpec(shape=(None,), dtype=tf.int64, name=None),
 'num_strokes': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}

In [18]:
tf_data_ = tf_data_.prefetch(batch_size*2) #boots performance by prestoring 2 batches

In [19]:
  def __pp_filter(sample):
    """Filters diagram samples.
    Works in batch mode. In other words, if an individual stroke of a diagram
    violates the conditions, then the entire diagram is discarded.
    Hence, the conditions should be relaxed.
    Args:
      sample:
    Returns:
    """
    has_strokes, is_long_enough = True, True
    if min_length_threshold > 0:
      is_long_enough = tf.math.greater(
          tf.reduce_min(input_tensor=sample["stroke_length"]), min_length_threshold)
    if max_length_threshold > 0:
      is_long_enough = tf.math.logical_and(
          is_long_enough,
          tf.math.less(
              tf.reduce_max(input_tensor=sample["stroke_length"]),
              max_length_threshold))
    if num_strokes_threshold > 0:
      has_strokes = (
          tf.shape(input=sample["num_strokes"])[0] > num_strokes_threshold)
    return tf.math.logical_and(has_strokes, is_long_enough)

In [20]:
tf_data_ = tf_data_.filter(functools.partial(__pp_filter)) #filtering for number of strokes, and max length of points in each stroke

In [21]:
import sys
i = 0
for sample_dict in tf_data_:
    sample = sample_dict['ink']
    if isinstance(sample, tf.Tensor):
        sample = sample.numpy()
    print (sample.shape)
    i+=1
    if i>5:
        sys.exit(0)

(13, 20, 4)
(7, 30, 4)
(9, 29, 4)
(6, 41, 4)
(5, 24, 4)
(7, 39, 4)


SystemExit: 0

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)


# Preprocessing

In [22]:
#CONFIG VARIABLES
affine_prob = 0.3#default
reverse_prob = 0#default
scale_factor = 0#default
pos_noise_factor = 0#default
pp_to_origin = True #in config
pp_relative_pos = False #inconfig
random_noise_factor = 0#default 
resampling_factor = 2#default
t_drop_ratio = 0#default
gt_targets = True

In [23]:
from common.constants import Constants as C

ModuleNotFoundError: No module named 'common'

In [24]:
def tf_preprocessing(tf_data_):
    if affine_prob > 0:
        tf_data_ = tf_data_.map(
          functools.partial(pp_random_affine_all),
          num_parallel_calls=num_parallel_calls)
    #Does this
    tf_data_ = tf_data_.map(
        functools.partial(set_start_end_coord),
        num_parallel_calls=num_parallel_calls)
    #Does this
    if pp_to_origin:
        tf_data_ = tf_data_.map(
          functools.partial(pp_translate_to_origin),
          num_parallel_calls=num_parallel_calls)
    if resampling_factor > 1:
        tf_data_ = tf_data_.map(
          functools.partial(pp_temporal_resampling),
          num_parallel_calls=num_parallel_calls)
    return tf_data_

In [25]:
  def pp_random_affine_all(sample):
    """Applies the same affine transformation to all strokes in a diagram."""
    
    rot_prob = affine_prob
    scale_prob = affine_prob
    flip_prob = affine_prob
    shear_prob = affine_prob/3.0
  
    n_strokes = tf.shape(input=sample["ink"])[0]
    
    # Rotation
    rot_angle = tf.random.uniform([1],
                                  minval=-np.pi/2,
                                  maxval=np.pi/2,
                                  dtype=tf.float32)
    rot_angle = tf.compat.v1.where(rot_prob > tf.random.uniform([1], maxval=1.0),
                         rot_angle,
                         tf.zeros_like(rot_angle))
    rot_angle = tf.tile(rot_angle, [n_strokes])
    
    # Scale
    scale_xy = tf.random.uniform([1],
                                 minval=0.5,
                                 maxval=2.5,
                                 dtype=tf.float32)
    scale_xy = tf.compat.v1.where(scale_prob > tf.random.uniform([1], maxval=1.0),
                        scale_xy,
                        tf.ones_like(scale_xy))
    
    # Flip around x, y or both.
    scale_x = tf.compat.v1.where(flip_prob > tf.random.uniform([1], maxval=1.0),
                       scale_xy*-1,
                       scale_xy)
    scale_y = tf.compat.v1.where(flip_prob > tf.random.uniform([1], maxval=1.0),
                       scale_xy*-1,
                       scale_xy)

    scale_x = tf.tile(scale_x, [n_strokes])
    scale_y = tf.tile(scale_y, [n_strokes])
    
    # Shear
    shear_xy = tf.random.uniform([1],
                                 minval=-0.3,
                                 maxval=0.3,
                                 dtype=tf.float32)
    shear_xy = tf.compat.v1.where(shear_prob > tf.random.uniform([1], maxval=1.0),
                        shear_xy,
                        tf.zeros_like(shear_xy))
    shear_xy = tf.tile(shear_xy, [n_strokes])
    
    # Apply affine.
    affine_ = apply_affine(sample["ink"][:, :, 0:2],
                                theta=rot_angle,
                                scale_x=scale_x,
                                scale_y=scale_y,
                                shear_x=shear_xy,
                                shear_y=shear_xy)
    augmented = tf.concat([affine_, sample["ink"][:, :, 2:]], axis=-1)
    sample["ink"] = augmented
    return sample

  def pp_temporal_resampling(sample):
    """Uniform re-sampling over time dimension."""
    print("Temporal resampling factor: ".format(resampling_factor))
    if gt_targets:
      print("Temporal resampling with original targets...")
    else:
      print("Temporal resampling the targets...")
    
    pen = sample["ink"][:, -1:]
    factor = tf.cast(
        tf.cond(pred=tf.reduce_max(sample["stroke_length"]) < 20,
                true_fn=lambda: 1,
                false_fn=lambda: resampling_factor // 2),
        dtype=tf.int32)

    factor = tf.cast(
        tf.cond(pred=tf.reduce_max(sample["stroke_length"]) > 100,
                true_fn=lambda: resampling_factor,
                false_fn=lambda: factor),
        dtype=tf.int64)
    
    if gt_targets:
      if "target_ink" not in sample:
        sample["target_ink"] = sample["ink"]
        sample["target_stroke_length"] = sample["stroke_length"]

    freq = factor
    # freq = tf.random_uniform([1], minval=1, maxval=factor + 1,
    #                          dtype=tf.int64)[0]
    
    sample["ink"] = sample["ink"][:, ::freq, :]
    sample["ink"].set_shape((None, None, 4))
    # We keep the pen event.
    sample["ink"] = tf.concat([sample["ink"][:, :-1], pen], axis=1)
    sample["stroke_length"] = tf.cast(
        tf.math.ceil(sample["stroke_length"] / freq), tf.int64)
    return sample

In [26]:
  def apply_affine(sample, theta=0.0, scale_x=1.0, scale_y=1.0,
                   shear_x=0.0, shear_y=0.0):
    """
    Affine transformation by applying scaling, rotation and shearing in order.
    The sample is a sequence of 2D points. If size of the transformation factors
    is equal to batch_size, then the operation runs in batch mode.
    The default values correspond to no transformation.
    Args:
      sample: (batch_size, seq_len, 2)
      theta: rotation angle in radians.
      scale_x: scale factor in x-axis.
      scale_y: scale factor in y-axis.
      shear_x: amount of shearing in x direction.
      shear_y: amount of shearing in y direction.
    Returns:
      Transformed sample.
    """
    rot_scale_mat = tf.stack([[scale_x*tf.cos(theta), -scale_y*tf.sin(theta)],
                              [scale_x*tf.sin(theta), scale_y*tf.cos(theta)]])
    rot_scale_mat = tf.transpose(a=tf.reshape(rot_scale_mat, [2, 2, -1]),
                                 perm=[2, 0, 1])
    
    shear_mat = tf.stack(
        [[tf.ones_like(shear_x), shear_x], [shear_y, tf.ones_like(shear_y)]])
    shear_mat = tf.transpose(a=tf.reshape(shear_mat, [2, 2, -1]), perm=[2, 0, 1])
    
    affine_mat = tf.matmul(shear_mat, rot_scale_mat)
    return tf.matmul(sample, affine_mat)

In [27]:
  def set_start_end_coord(sample):
    """Sets the start and end point coordinates."""
    sample[C.INP_START_COORD] = sample["ink"][:, 0:1, 0:2]
    # Strokes are padded. The end coordinate has the pen-up event where it is 1
    # and the rest has pen-up feature 0.
    tmp_ = sample["ink"][:, :, 0:2] * sample["ink"][:, :, 3:4]
    sample[C.INP_END_COORD] = tf.reduce_sum(input_tensor=tmp_, axis=1, keepdims=True)
    return sample

  def pp_translate_to_origin(sample):
    """Translate strokes to origin."""
    # Batch mode.
    if "target_ink" in sample:
      t_pen_event = sample["target_ink"][:, :, -1:]
      t_start_coord = sample["target_ink"][:, 0:1, 0:3]
      sample["target_ink"] = tf.concat(
          [sample["target_ink"][:, :, 0:3] - t_start_coord, t_pen_event],
          axis=-1)

    pen_event = sample["ink"][:, :, -1:]
    start_coord = sample["ink"][:, 0:1, 0:3]
    sample["ink"] = tf.concat(
        [sample["ink"][:, :, 0:3] - start_coord, pen_event], axis=-1)

    # sample["xy_cov"] = self.sequence_cov(sample["ink"][:, :, 0:2], sample["stroke_length"])
    return sample

In [28]:
import numpy as np

In [None]:
from common

In [29]:
tf_data_ = tf_preprocessing(tf_data_)

NameError: in converted code:

    <ipython-input-27-f35edeb7dbae>:3 set_start_end_coord  *
        sample[C.INP_START_COORD] = sample["ink"][:, 0:1, 0:2]

    NameError: name 'C' is not defined


In [None]:
tf_data_.element_spec

In [None]:
import sys
i = 0
for sample_dict in tf_data_:
    sample = sample_dict['ink']
    if isinstance(sample, tf.Tensor):
        sample = sample.numpy()
    print (sample.shape)
    i+=1
    if i>5:
        sys.exit(0)

In [None]:
tf_data_.element_spec

# Normalization

In [30]:
import numpy as np
import os

In [31]:
meta_data_path = "/data/jcabrera/didi_wo_text/didi_wo_text-stats-origin_abs_pos.npy"

In [32]:
def load_meta_data(meta_data_path):
    """Loads meta-data file given the path.
    It is assumed to be in numpy.
    Args:
        meta_data_path:
    Returns:
        Meta-data dictionary or False if it is not found.
    """
    # if not meta_data_path or not os.path.exists(meta_data_path):
    
    _, ext = os.path.splitext(meta_data_path)
    if ext == ".json":
      meta_fp = tf.io.gfile.GFile(meta_data_path, "r")
      try:
        meta_fp.size()
        print("Loading statistics " + meta_data_path)
        json_stats = json.load(meta_fp)
        stats_np = dict()
        for key_, value_ in json_stats.items():
          stats_np[key_] = np.array(value_) if isinstance(value_, list) else \
            value_
        return stats_np
      except tf.errors.NotFoundError:
        print("Meta-data not found.")
        return False
    
    elif ext == ".npy":
      meta_fp = tf.io.gfile.GFile(meta_data_path, "rb")
      try:
        meta_fp.size()
        print("Loading statistics " + meta_data_path)
        return np.load(meta_fp, allow_pickle=True).item()
      except tf.errors.NotFoundError:
        print("Meta-data not found.")
        return False
    else:
      err_unknown_type(ext)

In [33]:
meta_data = load_meta_data(meta_data_path)

Loading statistics /data/jcabrera/didi_wo_text/didi_wo_text-stats-origin_abs_pos.npy


In [34]:
mean_all = meta_data[C.MEAN_ALL]
std_all = np.sqrt(meta_data[C.VAR_ALL])
mean_channel = meta_data[C.MEAN_CHANNEL]
std_channel = np.sqrt(meta_data[C.VAR_CHANNEL])

NameError: name 'C' is not defined

In [35]:
  def normalize_zero_mean_unit_variance_channel(sample_dict, key):
    sample_dict[key] = (sample_dict[key] - mean_channel)/std_channel
    return sample_dict

In [33]:
  def pp_seq_mask(sample):
    sample["ink"] *= tf.expand_dims(
        tf.sequence_mask(sample["stroke_length"], dtype=tf.float32), axis=2)
    return sample

In [34]:
  def tf_data_normalization(tf_data_):
    # Apply normalizatiom inherited from TFRecordStrokes
    tf_data_ = tf_data_.map(
        functools.partial(
            normalize_zero_mean_unit_variance_channel, key="ink"))
    if gt_targets and (resampling_factor > 1 or random_noise_factor > 0 or t_drop_ratio > 0):
        tf_data_ = tf_data_.map(
          functools.partial(
              normalize_zero_mean_unit_variance_channel, key="target_ink"))
    # After preprocessing and normalization steps, the padded entries
    # may have non-zero values. Here we mask them.
    tf_data_ = tf_data_.map(
        functools.partial(pp_seq_mask),
        num_parallel_calls=num_parallel_calls)
    return tf_data_

In [35]:
tf_data_ = tf_data_normalization(tf_data_)

In [36]:
import sys
i = 0
for sample_dict in tf_data_:
    sample = sample_dict['ink']
    if isinstance(sample, tf.Tensor):
        sample = sample.numpy()
    print (sample.shape)
    i+=1
    if i>5:
        sys.exit(0)

(6, 41, 4)
(7, 34, 4)
(13, 20, 4)
(7, 30, 4)
(7, 39, 4)
(8, 31, 4)


SystemExit: 0

In [37]:
tf_data_.element_spec

{'rdp_ink': SparseTensorSpec(TensorShape([None]), tf.float32),
 'rdp_stroke_length': SparseTensorSpec(TensorShape([None]), tf.int64),
 'rdp_num_strokes': TensorSpec(shape=(), dtype=tf.int64, name=None),
 'ink': TensorSpec(shape=(None, None, 4), dtype=tf.float32, name=None),
 'stroke_length': TensorSpec(shape=(None,), dtype=tf.int64, name=None),
 'num_strokes': TensorSpec(shape=(None,), dtype=tf.int64, name=None),
 'start_coord': TensorSpec(shape=(None, None, 2), dtype=tf.float32, name=None),
 'end_coord': TensorSpec(shape=(None, 1, 2), dtype=tf.float32, name=None),
 'target_ink': TensorSpec(shape=(None, None, 4), dtype=tf.float32, name=None),
 'target_stroke_length': TensorSpec(shape=(None,), dtype=tf.int64, name=None)}

## DATA TO MODEL

In [40]:
C.INP_NUM_STROKE

'num_strokes'

In [41]:
#GENERAL CONFIG VARIABLES
fixed_len = False
concat_t_inputs = False

In [42]:
# FUNCTION INNER VARIABLES
mask_pen = True
int_t_samples = False #
n_t_targets = 4 #

In [43]:
  def tf_data_to_model(tf_data):

    if fixed_len:
        tf_data = tf_data.map(
          functools.partial(pp_pad_to_max_len),
          num_parallel_calls=num_parallel_calls)

    tf_data = tf_data.map(
        functools.partial(pp_get_t_targets),
        num_parallel_calls=num_parallel_calls)

    if concat_t_inputs:
        tf_data = tf_data.map(
          functools.partial(pp_concat_t_inputs),
          num_parallel_calls=num_parallel_calls)
    
    def element_length_func(model_inputs, _):
        return tf.cast(model_inputs[C.INP_NUM_STROKE], tf.int32)

    # Converts the data into the format that a model expects.
    # Creates input, target, sequence_length, etc.
    tf_data = tf_data.map(functools.partial(__to_model_batch))
    # TODO configurable bucket_batch_size
    if batch_size >= 1:
        bucket_batch_size = [
          batch_size,
          int(math.ceil(batch_size / 2)),
          int(math.ceil(batch_size / 3)),
          int(math.ceil(batch_size / 4)),
          int(math.ceil(batch_size / 5)),
      ]
        tf_data = tf_data.apply(
          tf.data.experimental.bucket_by_sequence_length(
              element_length_func=element_length_func,
              bucket_batch_sizes=bucket_batch_size,
              bucket_boundaries=[8, 13, 18, 23],
              pad_to_bucket_boundary=False))
    else:
        tf_data = tf_data.padded_batch(batch_size=1)
    return tf_data

In [44]:
  def __to_model_batch(tf_sample_dict):
    """Transforms a TFRecord sample into a more general sample representation.
    We use global keys to represent the required fields by the models.
    Args:
        tf_sample_dict:
    Returns:
    """
    # Target are the inputs shifted by one step.
    # We ignore the timestamp and pen event.
    model_input = dict()
    ink_ = tf.concat(
        [tf_sample_dict["ink"][:, :, 0:2], tf_sample_dict["ink"][:, :, 3:]],
        axis=-1)  # Ignore the timestamp.
    if mask_pen:
      mask_ = tf.sequence_mask(
          tf_sample_dict["stroke_length"] - 1,
          dtype=tf.float32,
          maxlen=tf.reduce_max(input_tensor=tf_sample_dict["stroke_length"]))
      model_input[C.INP_SEQ_LEN] = tf_sample_dict["stroke_length"] - 1
      model_input[C.INP_ENC] = (ink_ * tf.expand_dims(mask_, axis=2))[:, 0:-1]
    else:
      model_input[C.INP_SEQ_LEN] = tf_sample_dict["stroke_length"]
      model_input[C.INP_ENC] = ink_
    model_input[C.INP_DEC] = tf.concat(
        [tf.zeros_like(ink_[:, 0:1]), ink_[:, 0:-1]], axis=1)

    model_input[C.INP_START_COORD] = tf_sample_dict[C.INP_START_COORD]
    model_input[C.INP_END_COORD] = tf_sample_dict[C.INP_END_COORD]
    model_input[C.INP_NUM_STROKE] = tf.shape(input=tf_sample_dict["stroke_length"])[0]
    model_input[C.INP_T] = tf_sample_dict[C.INP_T]
    model_input[C.TARGET_T_INK] = tf_sample_dict[C.TARGET_T_INK]
    # model_input["xy_cov"] = tf_sample_dict["xy_cov"]

    model_target = dict()
    if "target_ink" in tf_sample_dict:
      ink_t = tf.concat([tf_sample_dict["target_ink"][:, :, 0:2],
                         tf_sample_dict["target_ink"][:, :, 3:4]
                         ], axis=-1)
      model_input[C.INP_DEC] = tf.concat(
          [tf.zeros_like(ink_t[:, 0:1]), ink_t[:, 0:-1]], axis=1)
  
      model_target["stroke"] = tf_sample_dict["target_ink"][:, :, 0:2]
      model_target["pen"] = tf_sample_dict["target_ink"][:, :, 3:4]
      model_target[C.INP_SEQ_LEN] = tf_sample_dict["target_stroke_length"]
      model_target[C.INP_NUM_STROKE] = tf.shape(input=tf_sample_dict["target_stroke_length"])[0]
    else:
      model_target = dict()
      model_target["stroke"] = tf_sample_dict["ink"][:, :, 0:2]
      model_target["pen"] = tf_sample_dict["ink"][:, :, 3:4]
      model_target[C.INP_SEQ_LEN] = tf_sample_dict["stroke_length"]
      model_target[C.INP_NUM_STROKE] = tf.shape(input=tf_sample_dict["stroke_length"])[0]

    model_target[C.INP_START_COORD] = model_input[C.INP_START_COORD]
    model_target[C.INP_END_COORD] = model_input[C.INP_END_COORD]

    model_target[C.TARGET_T_INK] = tf_sample_dict[C.TARGET_T_INK]
    model_target[C.TARGET_T_STROKE] = tf_sample_dict[C.TARGET_T_INK][:, :, 0:2]
    # timestamp already discarded.
    model_target[C.TARGET_T_PEN] = tf_sample_dict[C.TARGET_T_INK][:, :, 2:3]
    # model_target["xy_cov"] = tf_sample_dict["xy_cov"]
    
    return model_input, model_target

In [45]:
  def pp_get_t_targets(sample):
    """Draw a random t from [0,1] and get the interpolated point in the sequence.
    
    Handles multiple stroke and multiple t cases.
    Args:
      sample:
    Returns:
    """
    key_len = "stroke_length"
    key_ink = "ink"
    if gt_targets and "target_ink" in sample:
      key_len = "target_stroke_length"
      key_ink = "target_ink"
    
    if int_t_samples:
      n_strokes = tf.shape(input=sample[key_ink])[0]
      t = tf.random.uniform([n_strokes, n_t_targets], minval=0, maxval=1,
                            dtype=tf.float32)
      len_t = t*tf.cast(tf.expand_dims(sample[key_len], axis=-1), tf.float32)
      len_t = tf.floor(len_t)
      t = len_t / tf.tile(tf.expand_dims(tf.cast(sample[key_len]-1, tf.float32), axis=1), (1, n_t_targets))
      lower_idx = tf.cast(len_t, tf.int32)
      
      batch_indices = tf.ones_like(lower_idx)
      batch_indices *= tf.expand_dims(tf.range(n_strokes), axis=-1)
      
      gather_lower_idx = tf.stack([
          batch_indices,
          lower_idx
          ], axis=-1)

      lower_points = tf.gather_nd(sample[key_ink], gather_lower_idx)
      inter_points = tf.concat([lower_points[:, :, :-2], lower_points[:, :, -1:]], axis=-1)
    else:
      n_strokes = tf.shape(input=sample[key_ink])[0]
      t = tf.random.uniform([n_strokes, n_t_targets], minval=0, maxval=1,
                            dtype=tf.float32)
      len_t = t*tf.cast(tf.expand_dims(sample[key_len], axis=-1) - 1,
                        tf.float32)
      
      # Identify lower and upper points.
      lower_idx = tf.cast(tf.floor(len_t), tf.int32)
      upper_idx = tf.cast(tf.math.ceil(len_t), tf.int32)
  
      batch_indices = tf.ones_like(lower_idx)
      batch_indices *= tf.expand_dims(tf.range(n_strokes), axis=-1)
  
      gather_lower_idx = tf.stack([
          batch_indices,
          lower_idx
          ], axis=-1)
  
      gather_upper_idx = tf.stack([
          batch_indices,
          upper_idx
          ], axis=-1)
  
      lower_points = tf.gather_nd(sample[key_ink], gather_lower_idx)
      upper_points = tf.gather_nd(sample[key_ink], gather_upper_idx)
  
      factor = tf.expand_dims((len_t - tf.floor(len_t)), axis=-1)
      inter_points = factor*upper_points + (1 - factor)*lower_points
      
      max_pen = tf.maximum(lower_points[:, :, -1:], upper_points[:, :, -1:])
      inter_points = tf.concat([inter_points[:, :, :-2], max_pen], axis=-1)
    sample[C.INP_T] = t
    sample[C.TARGET_T_INK] = inter_points
    return sample

In [1]:
import tensorflow as tf

In [6]:
sample_key_len = tf_data_['seq_len']

NameError: name 'tf_data_' is not defined

In [None]:
t = tf.random.uniform([100, 4], minval=0, maxval=1,
                    dtype=tf.float32)
len_t = t*tf.cast(tf.expand_dims(sample[key_len], axis=-1) - 1,
                tf.float32)

# Identify lower and upper points.
lower_idx = tf.cast(tf.floor(len_t), tf.int32)
upper_idx = tf.cast(tf.math.ceil(len_t), tf.int32)

batch_indices = tf.ones_like(lower_idx)
batch_indices *= tf.expand_dims(tf.range(n_strokes), axis=-1)

gather_lower_idx = tf.stack([
  batch_indices,
  lower_idx
  ], axis=-1)

gather_upper_idx = tf.stack([
  batch_indices,
  upper_idx
  ], axis=-1)

lower_points = tf.gather_nd(sample[key_ink], gather_lower_idx)
upper_points = tf.gather_nd(sample[key_ink], gather_upper_idx)

factor = tf.expand_dims((len_t - tf.floor(len_t)), axis=-1)
inter_points = factor*upper_points + (1 - factor)*lower_points

max_pen = tf.maximum(lower_points[:, :, -1:], upper_points[:, :, -1:])
inter_points = tf.concat([inter_points[:, :, :-2], max_pen], axis=-1)

In [46]:
import math

In [47]:
tf_data_ = tf_data_to_model(tf_data_)

In [50]:
tf_data_.element_spec

({'seq_len': TensorSpec(shape=(None, None), dtype=tf.int64, name=None),
  'encoder_inputs': TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name=None),
  'decoder_inputs': TensorSpec(shape=(None, None, None, 3), dtype=tf.float32, name=None),
  'start_coord': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name=None),
  'end_coord': TensorSpec(shape=(None, None, 1, 2), dtype=tf.float32, name=None),
  'num_strokes': TensorSpec(shape=(None,), dtype=tf.int32, name=None),
  't_input': TensorSpec(shape=(None, None, 4), dtype=tf.float32, name=None),
  't_target_ink': TensorSpec(shape=(None, None, 4, 3), dtype=tf.float32, name=None)},
 {'stroke': TensorSpec(shape=(None, None, None, 2), dtype=tf.float32, name=None),
  'pen': TensorSpec(shape=(None, None, None, 1), dtype=tf.float32, name=None),
  'seq_len': TensorSpec(shape=(None, None), dtype=tf.int64, name=None),
  'num_strokes': TensorSpec(shape=(None,), dtype=tf.int32, name=None),
  'start_coord': TensorSpec(shape=(None, 

In [56]:
import sys
i = 0
for sample_dict in tf_data_:
    sample = sample_dict['encoder_inputs']
    if isinstance(sample, tf.Tensor):
        sample = sample.numpy()
    print (sample.shape)
    i+=1
    if i>5:
        sys.exit(0)

TypeError: tuple indices must be integers or slices, not str

In [49]:
tf_data_.take(1)

<DatasetV1Adapter shapes: ({seq_len: (None, None), encoder_inputs: (None, None, None, 3), decoder_inputs: (None, None, None, 3), start_coord: (None, None, None, 2), end_coord: (None, None, 1, 2), num_strokes: (None,), t_input: (None, None, 4), t_target_ink: (None, None, 4, 3)}, {stroke: (None, None, None, 2), pen: (None, None, None, 1), seq_len: (None, None), num_strokes: (None,), start_coord: (None, None, None, 2), end_coord: (None, None, 1, 2), t_target_ink: (None, None, 4, 3), t_target_stroke: (None, None, 4, 2), t_target_pen: (None, None, 4, 1)}), types: ({seq_len: tf.int64, encoder_inputs: tf.float32, decoder_inputs: tf.float32, start_coord: tf.float32, end_coord: tf.float32, num_strokes: tf.int32, t_input: tf.float32, t_target_ink: tf.float32}, {stroke: tf.float32, pen: tf.float32, seq_len: tf.int64, num_strokes: tf.int32, start_coord: tf.float32, end_coord: tf.float32, t_target_ink: tf.float32, t_target_stroke: tf.float32, t_target_pen: tf.float32})>

In [58]:
itr = tf_data_.make_one_shot_iterator()
for i in range(5):
    inputs, targets = itr.get_next()
    print(inputs['encoder_inputs'].shape)

(64, 12, 89, 3)
(128, 7, 63, 3)
(64, 12, 86, 3)
(43, 17, 80, 3)
(64, 12, 61, 3)


In [54]:
inputs.keys()

dict_keys(['seq_len', 'encoder_inputs', 'decoder_inputs', 'start_coord', 'end_coord', 'num_strokes', 't_input', 't_target_ink'])

In [55]:
targets.keys()

dict_keys(['stroke', 'pen', 'seq_len', 'num_strokes', 'start_coord', 'end_coord', 't_target_ink', 't_target_stroke', 't_target_pen'])

In [None]:
tf_data_

In [68]:
  def tf_data_to_model(self):

    if self.fixed_len:
      self.tf_data = self.tf_data.map(
          functools.partial(self.pp_pad_to_max_len),
          num_parallel_calls=self.num_parallel_calls)

    self.tf_data = self.tf_data.map(
        functools.partial(self.pp_get_t_targets),
        num_parallel_calls=self.num_parallel_calls)

    if self.concat_t_inputs:
      self.tf_data = self.tf_data.map(
          functools.partial(self.pp_concat_t_inputs),
          num_parallel_calls=self.num_parallel_calls)
    
    def element_length_func(model_inputs, _):
      return tf.cast(model_inputs[C.INP_NUM_STROKE], tf.int32)

    # Converts the data into the format that a model expects.
    # Creates input, target, sequence_length, etc.
    self.tf_data = self.tf_data.map(functools.partial(self.__to_model_batch))
    # TODO configurable bucket_batch_size
    if self.batch_size >= 1:
      bucket_batch_size = [
          self.batch_size,
          int(math.ceil(self.batch_size / 2)),
          int(math.ceil(self.batch_size / 3)),
          int(math.ceil(self.batch_size / 4)),
          int(math.ceil(self.batch_size / 5)),
      ]
      self.tf_data = self.tf_data.apply(
          tf.data.experimental.bucket_by_sequence_length(
              element_length_func=element_length_func,
              bucket_batch_sizes=bucket_batch_size,
              bucket_boundaries=[8, 13, 18, 23],
              pad_to_bucket_boundary=False))
    else:
      self.tf_data = self.tf_data.padded_batch(batch_size=1)