In [1]:
import tensorflow as tf
from tensorflow.python.framework import ops, dtypes
from tensorflow.python.ops import array_ops, clip_ops, math_ops, data_flow_ops
from six.moves import xrange

In [2]:
def _embedding_lookup_and_transform(params,
                                    ids,
                                    partition_strategy="mod",
                                    name=None,
                                    max_norm=None,
                                    transform_fn=None):
  """Helper function for embedding_lookup and _compute_sampled_logits.

  This function is a generalization of embedding_lookup that optionally
  applies a caller-specified transformation to each embedding. This is
  done through the `transform_fn` argument. If provided, the function is
  applied to each partitioned tensor of retrieved embeddings, colocated
  with the embeddings. This function will be called with a single `Tensor`
  argument of the same type as the `params` tensor and should return a
  `Tensor`. The shape of the argument will be the same as `params` except
  for the size of the first dimension. The first dimension of the result's
  shape must be the same size as the argument's.

  Args:
    params: See embedding_lookup.
    ids: See embedding_lookup.
    partition_strategy: See embedding_lookup.
    name: See embedding_lookup.
    max_norm: See embedding_lookup.
    transform_fn: 可选择转换函数，对检索到的emb进行转换。

  Returns:
    See embedding_lookup for details.
  Raises:
    ValueError: If `params` is empty.
  """
  pass

## embedding_lookup input
1. `ids`：去重后的 id 集合，一维；
2. `sp_ids`：由 `sparse_ids` 转成 dense 后的张量，二维；原本为空的位置用 `default_id` 填充（本例为 0）。

In [3]:
def _clip(params, ids, max_norm):
  def _rank(x):
    rank = ops.convert_to_tensor(x).get_shape().ndims
    if rank:
      return rank, True
    else:
      return array_ops.rank(x), False
  if max_norm is None:
    return params
  ids_rank, ids_static = _rank(ids)
  params_rank, params_static = _rank(params)
  return clip_ops.clip_by_norm(
      params,
      max_norm,
      axes=(list(range(ids_rank, params_rank)) if ids_static and params_static
            else math_ops.range(ids_rank, params_rank)))

In [4]:
def sparse_slice(sparse_feature, max_length):
    """Keep at most `max_length` columns of a 2-D sparse feature.

    Args:
        sparse_feature: sparse tensor to slice. Its `dense_shape[0]` is used
            as the row count of the slice, so every row is kept.
        max_length: column limit, or None to skip slicing altogether.

    Returns:
        `sparse_feature` itself when `max_length` is None, otherwise the
        `tf.sparse.slice` starting at [0, 0] with size [rows, max_length].
    """
    if max_length is not None:
        # dense_shape is int64, so the column limit is built as int64 too.
        column_limit = tf.constant(max_length, dtype=tf.int64)
        row_count = sparse_feature.dense_shape[0]
        return tf.sparse.slice(sparse_feature, [0, 0], [row_count, column_limit])
    return sparse_feature

def SparseTensor_to_Dense(sp_input, max_length, default_id):
    """Densify a sparse id tensor and build a mask of its filled positions.

    Args:
        sp_input: sparse tensor of ids.
        max_length: forwarded to `sparse_slice`; None keeps every column.
        default_id: value written into positions that have no stored entry.

    Returns:
        A pair `(dense_ids, mask)`. `dense_ids` is the densified tensor with
        gaps set to `default_id`. `mask` is float32 with an extra trailing
        axis of size 1. It is 1.0 wherever the tensor densified with a -1
        fill holds a value greater than -1, so a stored id of -1 or below is
        masked out just like a gap.
    """
    sliced = sparse_slice(sp_input, max_length)
    # Fill gaps with -1 so that "> -1" tells stored ids apart from gaps.
    sentinel_filled = tf.sparse.to_dense(sliced, -1, name='check_empty')
    is_present = tf.cast(sentinel_filled > -1, tf.float32)
    presence_mask = tf.expand_dims(is_present, axis=-1)
    dense_ids = tf.sparse.to_dense(sliced, default_id, name='default_id')
    return dense_ids, presence_mask

### build params

In [5]:
# Eight (12, 8) shards. Row `row` of shard `shard` is filled with the constant
# shard * 12 + row + 1, so a looked-up value identifies exactly which shard
# and row it came from.
params = [
    tf.concat(
        [array_ops.ones([1, 8]) * (shard * 12 + row + 1) for row in range(12)],
        axis=0)
    for shard in range(8)
]
params

[<tf.Tensor: shape=(12, 8), dtype=float32, numpy=
 array([[ 1.,  1.,  1.,  1.,  1.,  1.,  1.,  1.],
        [ 2.,  2.,  2.,  2.,  2.,  2.,  2.,  2.],
        [ 3.,  3.,  3.,  3.,  3.,  3.,  3.,  3.],
        [ 4.,  4.,  4.,  4.,  4.,  4.,  4.,  4.],
        [ 5.,  5.,  5.,  5.,  5.,  5.,  5.,  5.],
        [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.],
        [ 7.,  7.,  7.,  7.,  7.,  7.,  7.,  7.],
        [ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.],
        [ 9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.],
        [10., 10., 10., 10., 10., 10., 10., 10.],
        [11., 11., 11., 11., 11., 11., 11., 11.],
        [12., 12., 12., 12., 12., 12., 12., 12.]], dtype=float32)>,
 <tf.Tensor: shape=(12, 8), dtype=float32, numpy=
 array([[13., 13., 13., 13., 13., 13., 13., 13.],
        [14., 14., 14., 14., 14., 14., 14., 14.],
        [15., 15., 15., 15., 15., 15., 15., 15.],
        [16., 16., 16., 16., 16., 16., 16., 16.],
        [17., 17., 17., 17., 17., 17., 17., 17.],
        [18., 18., 18., 18., 18.

In [6]:
# (row, column) position of every stored id: the four rows hold 3, 2, 2 and 1 ids.
sparse_indices = [[0, 0], [0, 1], [0, 2],
                  [1, 0], [1, 1],
                  [2, 0], [2, 1],
                  [3, 0]]
sparse_values = [1, 10, 13, 14, 15, 6, 7, 8]
sparse_ids = tf.sparse.SparseTensor(indices=sparse_indices,
                                    values=sparse_values,
                                    dense_shape=[4, 3])
# Show the sparse object next to its dense view.
sparse_ids, tf.sparse.to_dense(sparse_ids)

(<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7fba704c7250>,
 <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
 array([[ 1, 10, 13],
        [14, 15,  0],
        [ 6,  7,  0],
        [ 8,  0,  0]], dtype=int32)>)

In [7]:
# No column limit (max_length=None); positions without an entry are filled with id 0.
sp_ids, mask = SparseTensor_to_Dense(sp_input=sparse_ids,
                                     max_length=None,
                                     default_id=0)
sp_ids

<tf.Tensor: shape=(4, 3), dtype=int32, numpy=
array([[ 1, 10, 13],
       [14, 15,  0],
       [ 6,  7,  0],
       [ 8,  0,  0]], dtype=int32)>

In [8]:
# Normalize sp_ids to a tensor before the lookup steps below. The previous
# cell's output already shows a tf.Tensor, and the value displayed here is
# identical to it.
sp_ids = ops.convert_to_tensor(sp_ids)
sp_ids

<tf.Tensor: shape=(4, 3), dtype=int32, numpy=
array([[ 1, 10, 13],
       [14, 15,  0],
       [ 6,  7,  0],
       [ 8,  0,  0]], dtype=int32)>

### 假设按照mod分片查找与reshape

In [9]:
# Flatten the (4, 3) id matrix to a single axis so each id is handled on its own.
flat_ids = array_ops.reshape(sp_ids, [-1])
# Record every id's position in the flattened order.
num_flat_ids = array_ops.size(flat_ids)
original_indices = math_ops.range(num_flat_ids)
flat_ids, original_indices

(<tf.Tensor: shape=(12,), dtype=int32, numpy=array([ 1, 10, 13, 14, 15,  0,  6,  7,  0,  8,  0,  0], dtype=int32)>,
 <tf.Tensor: shape=(12,), dtype=int32, numpy=array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11], dtype=int32)>)

In [10]:
# `np` is the number of partitions (shards), not the numpy alias. It is taken
# from `params` so it cannot drift from the shard list built earlier.
np = len(params)
# Modulo split: id -> shard (id % np) and row inside that shard (id // np).
p_assignments = flat_ids % np
new_ids = flat_ids // np
p_assignments, new_ids

(<tf.Tensor: shape=(12,), dtype=int32, numpy=array([1, 2, 5, 6, 7, 0, 6, 7, 0, 0, 0, 0], dtype=int32)>,
 <tf.Tensor: shape=(12,), dtype=int32, numpy=array([0, 1, 1, 1, 1, 0, 0, 0, 0, 1, 0, 0], dtype=int32)>)

In [11]:
# Expected embedding value for each id: row r of shard p in `params` was
# filled with p * 12 + r + 1 when the shards were built.
expected_values = p_assignments * 12 + new_ids + 1
expected_values

<tf.Tensor: shape=(12,), dtype=int32, numpy=array([13, 26, 62, 74, 86,  1, 73, 85,  1,  2,  1,  1], dtype=int32)>

In [12]:
# Lookup options for this walkthrough: no per-partition transform and no norm
# clipping. `max_norm` is defined here because both `_clip` calls below read it.
transform_fn = None
max_norm = None
# Cast partition assignments to int32 for use in dynamic_partition.
# There really should not be more than 2^32 partitions.
p_assignments = math_ops.cast(p_assignments, dtypes.int32)
# Partition list of ids based on assignments into np separate lists
gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
# Similarly, partition the original indices.
pindices = data_flow_ops.dynamic_partition(original_indices, p_assignments, np)
# Do np separate lookups, finding embeddings for gather_ids[p] in params[p]
partitioned_result = []
for p in range(np):
    pids = gather_ids[p]
    with ops.colocate_with(params[p]):
        result = array_ops.gather(params[p], pids)
        if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
    partitioned_result.append(result)
# Stitch these back together, using the partitioned original positions
ret = data_flow_ops.parallel_dynamic_stitch(pindices, partitioned_result)
print('result is {}'.format(ret))
# Determine the static element shape.
if transform_fn is None:
    element_shape_s = params[0].get_shape()[1:]
    for shard in params[1:]:
        element_shape_s = element_shape_s.merge_with(shard.get_shape()[1:])
else:
    element_shape_s = ret.get_shape()[1:]

# Compute the dynamic element shape.
if element_shape_s.is_fully_defined():
    element_shape_d = element_shape_s
elif transform_fn is None:
    # It's important that we compute params[0].shape on the right device
    # to avoid data motion.
    with ops.colocate_with(params[0]):
        params_shape = array_ops.shape(params[0])
    element_shape_d = params_shape[1:]
else:
    element_shape_d = array_ops.shape(ret)[1:]

# Reshape to reverse the flattening of ids.
ret = array_ops.reshape(ret, array_ops.concat([array_ops.shape(sp_ids), element_shape_d], 0))
ret.set_shape(sp_ids.get_shape().concatenate(element_shape_s))
if not transform_fn:
    # If transform_fn was provided, the clip_by_norm was done above.
    ret = _clip(ret, sp_ids, max_norm)
print('final result is {}'.format(ret))

result is [[13. 13. 13. 13. 13. 13. 13. 13.]
 [26. 26. 26. 26. 26. 26. 26. 26.]
 [62. 62. 62. 62. 62. 62. 62. 62.]
 [74. 74. 74. 74. 74. 74. 74. 74.]
 [86. 86. 86. 86. 86. 86. 86. 86.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]
 [73. 73. 73. 73. 73. 73. 73. 73.]
 [85. 85. 85. 85. 85. 85. 85. 85.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]
 [ 2.  2.  2.  2.  2.  2.  2.  2.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]
 [ 1.  1.  1.  1.  1.  1.  1.  1.]]
final result is [[[13. 13. 13. 13. 13. 13. 13. 13.]
  [26. 26. 26. 26. 26. 26. 26. 26.]
  [62. 62. 62. 62. 62. 62. 62. 62.]]

 [[74. 74. 74. 74. 74. 74. 74. 74.]
  [86. 86. 86. 86. 86. 86. 86. 86.]
  [ 1.  1.  1.  1.  1.  1.  1.  1.]]

 [[73. 73. 73. 73. 73. 73. 73. 73.]
  [85. 85. 85. 85. 85. 85. 85. 85.]
  [ 1.  1.  1.  1.  1.  1.  1.  1.]]

 [[ 2.  2.  2.  2.  2.  2.  2.  2.]
  [ 1.  1.  1.  1.  1.  1.  1.  1.]
  [ 1.  1.  1.  1.  1.  1.  1.  1.]]]


#### CAN网络此处，返回且计算

## Real test: raw 1-D id set

In [13]:
# A plain 1-D list of ids, in contrast to the padded 2-D sp_ids used earlier.
raw_ids = [0, 1, 14, 27, 8, 19, 15, 20]
ids = ops.convert_to_tensor(tf.constant(raw_ids))
ids

<tf.Tensor: shape=(8,), dtype=int32, numpy=array([ 0,  1, 14, 27,  8, 19, 15, 20], dtype=int32)>

In [14]:
# for raw ids
flat_ids = array_ops.reshape(ids, [-1])
original_indices = math_ops.range(array_ops.size(flat_ids))
np = 8
p_assignments = flat_ids % np
new_ids = flat_ids // np
p_assignments,new_ids

(<tf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 6, 3, 0, 3, 7, 4], dtype=int32)>,
 <tf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 0, 1, 3, 1, 2, 1, 2], dtype=int32)>)

In [15]:
# Expected embedding value for each raw id: row r of shard p in `params` was
# filled with p * 12 + r + 1 when the shards were built.
expected_values = p_assignments * 12 + new_ids + 1
expected_values

<tf.Tensor: shape=(8,), dtype=int32, numpy=array([ 1, 13, 74, 40,  2, 39, 86, 51], dtype=int32)>

In [16]:
# Lookup options for this walkthrough: no per-partition transform and no norm
# clipping. `max_norm` is defined here because both `_clip` calls below read it.
transform_fn = None
max_norm = None
# Cast partition assignments to int32 for use in dynamic_partition.
# There really should not be more than 2^32 partitions.
p_assignments = math_ops.cast(p_assignments, dtypes.int32)
# Partition list of ids based on assignments into np separate lists
gather_ids = data_flow_ops.dynamic_partition(new_ids, p_assignments, np)
# Similarly, partition the original indices.
pindices = data_flow_ops.dynamic_partition(original_indices, p_assignments, np)
# Do np separate lookups, finding embeddings for gather_ids[p] in params[p]
partitioned_result = []
for p in range(np):
    pids = gather_ids[p]
    with ops.colocate_with(params[p]):
        result = array_ops.gather(params[p], pids)
        if transform_fn:
            # If transform_fn is provided, the clip_by_norm precedes
            # the transform and hence must be co-located. See below
            # for the counterpart if transform_fn is not provided.
            result = transform_fn(_clip(result, pids, max_norm))
    partitioned_result.append(result)
# Stitch these back together, using the partitioned original positions
ret = data_flow_ops.parallel_dynamic_stitch(pindices, partitioned_result)
print('result is {}'.format(ret))
# Determine the static element shape.
if transform_fn is None:
    element_shape_s = params[0].get_shape()[1:]
    for shard in params[1:]:
        element_shape_s = element_shape_s.merge_with(shard.get_shape()[1:])
else:
    element_shape_s = ret.get_shape()[1:]

# Compute the dynamic element shape.
if element_shape_s.is_fully_defined():
    element_shape_d = element_shape_s
elif transform_fn is None:
    # It's important that we compute params[0].shape on the right device
    # to avoid data motion.
    with ops.colocate_with(params[0]):
        params_shape = array_ops.shape(params[0])
    element_shape_d = params_shape[1:]
else:
    element_shape_d = array_ops.shape(ret)[1:]

# Reshape to reverse the flattening of ids.
ret = array_ops.reshape(ret, array_ops.concat([array_ops.shape(ids), element_shape_d], 0))
ret.set_shape(ids.get_shape().concatenate(element_shape_s))
if not transform_fn:
    # If transform_fn was provided, the clip_by_norm was done above.
    ret = _clip(ret, ids, max_norm)
print('final result is {}'.format(ret))

result is [[ 1.  1.  1.  1.  1.  1.  1.  1.]
 [13. 13. 13. 13. 13. 13. 13. 13.]
 [74. 74. 74. 74. 74. 74. 74. 74.]
 [40. 40. 40. 40. 40. 40. 40. 40.]
 [ 2.  2.  2.  2.  2.  2.  2.  2.]
 [39. 39. 39. 39. 39. 39. 39. 39.]
 [86. 86. 86. 86. 86. 86. 86. 86.]
 [51. 51. 51. 51. 51. 51. 51. 51.]]
final result is [[ 1.  1.  1.  1.  1.  1.  1.  1.]
 [13. 13. 13. 13. 13. 13. 13. 13.]
 [74. 74. 74. 74. 74. 74. 74. 74.]
 [40. 40. 40. 40. 40. 40. 40. 40.]
 [ 2.  2.  2.  2.  2.  2.  2.  2.]
 [39. 39. 39. 39. 39. 39. 39. 39.]
 [86. 86. 86. 86. 86. 86. 86. 86.]
 [51. 51. 51. 51. 51. 51. 51. 51.]]
