In [2]:
import tensorflow as tf
from tensorflow.python.ops import math_ops,array_ops,sparse_ops
from tensorflow.python.framework import tensor_shape,sparse_tensor,dtypes
from tensorflow.python.ops import embedding_ops

In [3]:
def safe_embedding_lookup_sparse(embedding_weights,
                                 sparse_ids,
                                 sparse_weights=None,
                                 combiner="mean",
                                 default_id=None,
                                 name=None,
                                 partition_strategy="div",
                                 max_norm=None):
  """Signature-only placeholder; performs no work.

  Every argument is accepted and ignored. This stub only records the
  parameter list and its defaults.

  Returns:
    None.
  """
  return None

## 0.准备工作

In [5]:
# Toy ids: 7 stored values in a dense shape of [4, 3, 1]; one id (-11) is negative.
sparse_ids = tf.sparse.SparseTensor(
    dense_shape=[4, 3, 1],
    indices=[
        [0, 0, 0], [0, 1, 0], [0, 2, 0],
        [1, 0, 0], [1, 1, 0],
        [2, 0, 0],
        [3, 0, 0],
    ],
    values=[123, 234, -11, 1245, 8989, 124, 2121],
)
# Toy weights: same dense shape, 7 stored values; the last weight is negative.
# NOTE(review): the indices here ([2,0,0], [2,1,0]) do not line up with the ids'
# indices ([1,1,0], [2,0,0]) -- confirm this mismatch is intended.
sparse_weights = tf.sparse.SparseTensor(
    dense_shape=[4, 3, 1],
    indices=[
        [0, 0, 0], [0, 1, 0], [0, 2, 0],
        [1, 0, 0],
        [2, 0, 0], [2, 1, 0],
        [3, 0, 0],
    ],
    values=[1.0, 2, 1, 1, 1, 1, -1],
)
# Show both tensors in dense form.
(tf.sparse.to_dense(sparse_ids), tf.sparse.to_dense(sparse_weights))

(<tf.Tensor: shape=(4, 3, 1), dtype=int32, numpy=
 array([[[ 123],
         [ 234],
         [ -11]],
 
        [[1245],
         [8989],
         [   0]],
 
        [[ 124],
         [   0],
         [   0]],
 
        [[2121],
         [   0],
         [   0]]], dtype=int32)>,
 <tf.Tensor: shape=(4, 3, 1), dtype=float32, numpy=
 array([[[ 1.],
         [ 2.],
         [ 1.]],
 
        [[ 1.],
         [ 0.],
         [ 0.]],
 
        [[ 1.],
         [ 1.],
         [ 0.]],
 
        [[-1.],
         [ 0.],
         [ 0.]]], dtype=float32)>)

In [6]:
use_safe_embedding_lookup = True
# Static length of the ids' dense_shape vector; the check below allows for None.
sparse_id_rank = tensor_shape.dimension_value(
    sparse_ids.dense_shape.get_shape()[0])
# The non-safe lookup is chosen only when the safe one is switched off AND the
# rank is known and at most 2; in every other case the safe variant is used.
if use_safe_embedding_lookup or sparse_id_rank is None or sparse_id_rank > 2:
    embedding_lookup_sparse = embedding_ops.safe_embedding_lookup_sparse
else:
    embedding_lookup_sparse = embedding_ops.embedding_lookup_sparse_v2
embedding_lookup_sparse.__name__

'safe_embedding_lookup_sparse'

## 1.降维至可线性切分
1. [d0, d1, ..., dn]，变为[d0 * d1 * ... * dn-1, dn]；
2. 保留最后一维，前N维度相乘；

In [7]:
original_shape = sparse_ids.dense_shape
# Static rank taken from the length of the dense_shape vector.
original_rank_dim = tensor_shape.dimension_value(original_shape.get_shape()[0])
# Use the static rank when it is available, otherwise compute it with a size op.
if original_rank_dim is not None:
    original_rank = original_rank_dim
else:
    original_rank = array_ops.size(original_shape)
# NOTE(review): the 2nd and 3rd labels in this message are identical although
# they show different variables (original_rank_dim vs original_rank).
'原来shape为：{}, 原来的维度为：{}，原来的维度为：{}'.format(
    original_shape, original_rank_dim, original_rank)

'原来shape为：[4 3 1], 原来的维度为：3，原来的维度为：3'

In [8]:
# Product of all leading dimensions: 4 * 3 = 12 for the toy input.
flat_rows = math_ops.reduce_prod(
    array_ops.slice(original_shape, [0], [original_rank - 1]))
# The last dimension is kept unchanged: 1 for the toy input.
last_dim = array_ops.gather(original_shape, original_rank - 1)
sparse_ids = sparse_ops.sparse_reshape(sparse_ids, [flat_rows, last_dim])
if sparse_weights is not None:
    # Rebuild the weights on the reshaped ids' indices and dense shape,
    # reusing the weight values in their stored order.
    sparse_weights = sparse_tensor.SparseTensor(
        indices=sparse_ids.indices,
        values=sparse_weights.values,
        dense_shape=sparse_ids.dense_shape)
(tf.sparse.to_dense(sparse_ids), tf.sparse.to_dense(sparse_weights))

(<tf.Tensor: shape=(12, 1), dtype=int32, numpy=
 array([[ 123],
        [ 234],
        [ -11],
        [1245],
        [8989],
        [   0],
        [ 124],
        [   0],
        [   0],
        [2121],
        [   0],
        [   0]], dtype=int32)>,
 <tf.Tensor: shape=(12, 1), dtype=float32, numpy=
 array([[ 1.],
        [ 2.],
        [ 1.],
        [ 1.],
        [ 1.],
        [ 0.],
        [ 1.],
        [ 0.],
        [ 0.],
        [-1.],
        [ 0.],
        [ 0.]], dtype=float32)>)

### 结论
1. 输入[4,3,1]维度的矩阵，保留最后一个维度；reshape成[4*3, 1]；
2. 如果最后一维是1，单值，那么segment_sum时就是加自身；
3. 如果最后一维非1，多值，那么segment_sum时就是pooling。

## 2.检查非法id和weights，过滤掉非法的pair<id,weights>

In [9]:
def _prune_invalid_ids(sparse_ids, sparse_weights):
  """Drop every entry whose id is negative, from the ids and from the weights.

  Args:
    sparse_ids: sparse tensor of ids; entries with a value >= 0 are kept.
    sparse_weights: sparse tensor of weights, or None. When given, it is
      filtered with the same mask as the ids.

  Returns:
    A `(sparse_ids, sparse_weights)` pair after filtering; the second element
    is None when no weights were passed in.
  """
  keep_mask = math_ops.greater_equal(sparse_ids.values, 0)
  if sparse_weights is None:
    return sparse_ops.sparse_retain(sparse_ids, keep_mask), sparse_weights
  # AND-ing with an all-True mask shaped like the weight values leaves the
  # mask's truth values unchanged.
  all_true = array_ops.ones_like(sparse_weights.values, dtype=dtypes.bool)
  keep_mask = math_ops.logical_and(keep_mask, all_true)
  # Retain only the positions where the mask is True, in both tensors.
  kept_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  kept_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return kept_ids, kept_weights
def _prune_invalid_weights(sparse_ids, sparse_weights):
  """Drop every entry whose weight is not strictly positive (weight <= 0).

  Args:
    sparse_ids: sparse tensor of ids, filtered with the weight-based mask.
    sparse_weights: sparse tensor of weights, or None. When None, both inputs
      are returned untouched.

  Returns:
    A `(sparse_ids, sparse_weights)` pair after filtering.
  """
  if sparse_weights is None:
    return sparse_ids, sparse_weights
  # The mask is built from the weight values and applied to ids and weights alike.
  keep_mask = math_ops.greater(sparse_weights.values, 0)
  kept_ids = sparse_ops.sparse_retain(sparse_ids, keep_mask)
  kept_weights = sparse_ops.sparse_retain(sparse_weights, keep_mask)
  return kept_ids, kept_weights

In [10]:
# Remove entries with a negative id; 66 is only the filler shown for cells
# that hold no stored value in the dense view.
sparse_ids, sparse_weights = _prune_invalid_ids(sparse_ids, sparse_weights)
(
    sparse_ids,
    sparse_weights,
    tf.sparse.to_dense(sparse_ids, default_value=66),
    tf.sparse.to_dense(sparse_weights, default_value=66),
)

(<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f9ecfcc7690>,
 <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f9ed02a6750>,
 <tf.Tensor: shape=(12, 1), dtype=int32, numpy=
 array([[ 123],
        [ 234],
        [  66],
        [1245],
        [8989],
        [  66],
        [ 124],
        [  66],
        [  66],
        [2121],
        [  66],
        [  66]], dtype=int32)>,
 <tf.Tensor: shape=(12, 1), dtype=float32, numpy=
 array([[ 1.],
        [ 2.],
        [66.],
        [ 1.],
        [ 1.],
        [66.],
        [ 1.],
        [66.],
        [66.],
        [-1.],
        [66.],
        [66.]], dtype=float32)>)

In [11]:
# Remove entries whose weight is not strictly positive; 66 is again only the
# filler shown for cells that hold no stored value in the dense view.
sparse_ids, sparse_weights = _prune_invalid_weights(sparse_ids, sparse_weights)
(
    sparse_ids,
    sparse_weights,
    tf.sparse.to_dense(sparse_ids, default_value=66),
    tf.sparse.to_dense(sparse_weights, default_value=66),
)

(<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f9ecfcc3e10>,
 <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f9ecfccae90>,
 <tf.Tensor: shape=(12, 1), dtype=int32, numpy=
 array([[ 123],
        [ 234],
        [  66],
        [1245],
        [8989],
        [  66],
        [ 124],
        [  66],
        [  66],
        [  66],
        [  66],
        [  66]], dtype=int32)>,
 <tf.Tensor: shape=(12, 1), dtype=float32, numpy=
 array([[ 1.],
        [ 2.],
        [66.],
        [ 1.],
        [ 1.],
        [66.],
        [ 1.],
        [66.],
        [66.],
        [66.],
        [66.],
        [66.]], dtype=float32)>)

### 结论
过滤条件为：id >= 0, 且weights>0；保留合法id和weights

## 3.填补空行id与weights

In [12]:
# Rows with no stored entry get the id 0; is_row_empty flags those rows.
sp_ids, is_row_empty = sparse_ops.sparse_fill_empty_rows(
    sparse_ids, default_value=0)
if sparse_weights is not None:
    # Empty rows of the weights get a filler weight of 1.0; the second return
    # value (the empty-row indicator) is discarded here.
    sp_weights, _ = sparse_ops.sparse_fill_empty_rows(
        sparse_weights, default_value=1.0)
(sp_ids, sp_weights, tf.sparse.to_dense(sp_ids), tf.sparse.to_dense(sp_weights))

(<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f9ecc4c2bd0>,
 <tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f9ecc4c2690>,
 <tf.Tensor: shape=(12, 1), dtype=int32, numpy=
 array([[ 123],
        [ 234],
        [   0],
        [1245],
        [8989],
        [   0],
        [ 124],
        [   0],
        [   0],
        [   0],
        [   0],
        [   0]], dtype=int32)>,
 <tf.Tensor: shape=(12, 1), dtype=float32, numpy=
 array([[1.],
        [2.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], dtype=float32)>)

### 结论
1. 先过滤非法id与weights（见第2节）；
2. 再处理过滤后为空的特征行：空行的id用默认值（此处为0）填补，weights用1.0填补，并由`is_row_empty`标记哪些行原本为空。

## 4.embedding_lookup_sparse
1. 构造假的emb；

In [13]:
# Stand-in for a lookup result: an all-ones tensor of shape (12, 8).
result = array_ops.ones(shape=[12, 8])
result

<tf.Tensor: shape=(12, 8), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.]], dtype=float32)>

### 变长特征缺省处理
1. 变长特征长度不足而产生的空位，不参与取emb；
2. 样本维度：
    1. 该特征为空，按照默认ID取emb；
    2. 若默认ID为空，则对最终的combiner_emb进行置0处理；

In [14]:
# Display the per-row indicator returned by sparse_fill_empty_rows.
is_row_empty

<tf.Tensor: shape=(12,), dtype=bool, numpy=
array([False, False,  True, False, False,  True, False,  True,  True,
        True,  True,  True])>

In [15]:
# Expand the per-row flag to the shape of `result` by repeating it across the
# embedding dimension. The tiled mask is stored under its own name so that
# `is_row_empty` keeps its original per-row shape; overwriting it made this
# cell fail on a second run, because the already-tiled mask was reshaped and
# tiled again and no longer matched `result`.
is_row_empty_tiled = array_ops.tile(
    array_ops.reshape(is_row_empty, [-1, 1]),
    array_ops.stack([1, array_ops.shape(result)[1]]))
print(is_row_empty_tiled)
# Replace the rows flagged as empty with zeros; other rows are left unchanged.
result = array_ops.where(
    is_row_empty_tiled, array_ops.zeros_like(result), result)

tf.Tensor(
[[False False False False False False False False]
 [False False False False False False False False]
 [ True  True  True  True  True  True  True  True]
 [False False False False False False False False]
 [False False False False False False False False]
 [ True  True  True  True  True  True  True  True]
 [False False False False False False False False]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]
 [ True  True  True  True  True  True  True  True]], shape=(12, 8), dtype=bool)


In [16]:
# Display `result` after the rows flagged as empty were replaced with zeros.
result

<tf.Tensor: shape=(12, 8), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 0., 0., 0.]], dtype=float32)>

#### 结论
某行样本，该特征值为空时
1. 按照default_id or 0获取 combiner emb；
2. 若未指定default_id（即默认ID为空），则对该合并后的emb进行置0。

## 5.还原结果向量

In [17]:
# Leading dimensions of the original shape (all but the last), cast to int32.
rank_1 = array_ops.slice(
    math_ops.cast(original_shape, dtypes.int32),
    begin=[0],
    size=[original_rank - 1])
# Every dimension of `result` after the first one (size -1 means "to the end").
rank_2 = array_ops.slice(array_ops.shape(result), begin=[1], size=[-1])
# Target shape = original leading dimensions followed by result's trailing ones.
target_shape = array_ops.concat([rank_1, rank_2], axis=0)
'前N维度为：{}，取向量维度为：{}，目标维度为：{}'.format(
    rank_1, rank_2, target_shape)

'前N维度为：[4 3]，取向量维度为：[8]，目标维度为：[4 3 8]'

In [18]:
# Reshape the flat result to the target shape computed above.
final_result = array_ops.reshape(result, shape=target_shape)
final_result

<tf.Tensor: shape=(4, 3, 8), dtype=float32, numpy=
array([[[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.]],

       [[1., 1., 1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.]],

       [[1., 1., 1., 1., 1., 1., 1., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]],

       [[0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0.]]], dtype=float32)>

In [19]:
# Number of leading dimensions: the static rank minus one.
shape1 = (tensor_shape.Dimension(original_rank_dim) - 1).value
# Static shape: `shape1` unknown leading dimensions followed by the static
# trailing dimensions of `result`.
final_shape = tensor_shape.unknown_shape(shape1).concatenate(result.get_shape()[1:])
final_result.set_shape(final_shape)
# The message has two placeholders, so exactly two arguments are passed; the
# former third argument (final_result) was never rendered.
'Tensor的维度：{}，最终维度：{}'.format(shape1, final_shape)

'Tensor的维度：2，最终维度：(None, None, 8)'