In [1]:
import tensorflow as tf
from tensorflow.python.ops import math_ops,array_ops,sparse_ops
from tensorflow.python.framework import tensor_shape,sparse_tensor,dtypes

In [2]:
def embedding_lookup_sparse(params,
                            sp_ids,
                            sp_weights,
                            partition_strategy="mod",
                            name=None,
                            combiner=None,
                            max_norm=None):
  """Look up embeddings for the given ids and weights from a list of tensors.

  This definition is a documentation-only stub: it performs no computation
  and returns None. The cells that follow reproduce the intended logic one
  step at a time.

  Study notes (translated from the original):

  1. It is assumed that no row of `sp_ids` is empty.
  2. According to these notes, `safe_embedding_lookup_sparse` deals with
     empty rows beforehand by filling in `default_id` (or 0), which is what
     makes assumption 1 hold.
  3. `sp_ids` and `sp_weights` are expected to be 2-dimensional.
  4. Embeddings are aggregated along the last dimension.
  5. Ids are assumed to lie in `[0, p0)`, where `p0` is the total size of
     `params`. The original note additionally speculates (unverified) that
     with sharded embeddings each shard has roughly the same size and `p0`
     refers to the size of that shard.

  Returns:
    Intended result: a dense tensor holding one combined embedding per row
    of `sp_ids`. For each row, the embeddings of all its ids are looked up,
    multiplied by the matching weights, and merged according to `combiner`.
    (The stub itself returns None.)

  Example: given the sparse entries

      ```python
      [0, 0]: id 1, weight 2.0
      [0, 1]: id 3, weight 0.5
      [1, 0]: id 0, weight 1.0
      [2, 3]: id 1, weight 3.0
      ```

  and `combiner`="mean", the output is a 3x20 matrix (i.e. `params` is
  assumed to have 20 columns) where

      ```python
      output[0, :] = (params[1, :] * 2.0 + params[3, :] * 0.5) / (2.0 + 0.5)
      output[1, :] = (params[0, :] * 1.0) / 1.0
      output[2, :] = (params[1, :] * 3.0) / 3.0
      ```
  """
  return None

### check weights and combiner
safe_embedding_lookup_sparse 处理后断言weights和ids的值维度、dense维度和下标均相同

In [3]:
# Toy sparse id matrix with dense shape 4x3; rows hold 3, 2, 2 and 1 ids.
# Unset positions show up as 0 in the dense view.
sp_ids = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2],
             [1, 0], [1, 1],
             [2, 0], [2, 1],
             [3, 0]],
    values=[123, 234, 11, 1245, 124, 2121, 124, 2121],
    dense_shape=[4, 3],
)
# Display the sparse handle next to its dense rendering.
(sp_ids, tf.sparse.to_dense(sp_ids))

(<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f42efea5ed0>,
 <tf.Tensor: shape=(4, 3), dtype=int32, numpy=
 array([[ 123,  234,   11],
        [1245,  124,    0],
        [2121,  124,    0],
        [2121,    0,    0]], dtype=int32)>)

In [4]:
# Per-id weights laid out on the same 4x3 grid with one weight per stored
# position; every weight is 1.0 except the entry at [0, 1], which is 2.0.
sp_weights = tf.sparse.SparseTensor(
    indices=[[0, 0], [0, 1], [0, 2],
             [1, 0], [1, 1],
             [2, 0], [2, 1],
             [3, 0]],
    values=[1.0, 2.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0],
    dense_shape=[4, 3],
)
# Display the sparse handle next to its dense rendering.
(sp_weights, tf.sparse.to_dense(sp_weights))

(<tensorflow.python.framework.sparse_tensor.SparseTensor at 0x7f42efebb110>,
 <tf.Tensor: shape=(4, 3), dtype=float32, numpy=
 array([[1., 2., 1.],
        [1., 1., 0.],
        [1., 1., 0.],
        [1., 0., 0.]], dtype=float32)>)

### 对sp_ids进行去重，与取首列
1. segment_ids：每个非空值所在的行号（即 indices 的首列），同一行的非空值共享同一个 segment id；
2. 对values去重得到唯一id集合ids，同时返回每个原始值在ids中的下标idx，满足 ids[idx] == values，可据此复原values。

In [5]:
# Row index of every stored entry: the first column of the sparse indices.
segment_ids = sp_ids.indices[:, 0]
# De-duplicate the stored ids; `idx` maps each original value to its slot in
# the unique list, so gathering `ids` with `idx` restores the original values.
ids, idx = array_ops.unique(sp_ids.values)
(segment_ids, ids, idx)

(<tf.Tensor: shape=(8,), dtype=int64, numpy=array([0, 0, 0, 1, 1, 2, 2, 3])>,
 <tf.Tensor: shape=(6,), dtype=int32, numpy=array([ 123,  234,   11, 1245,  124, 2121], dtype=int32)>,
 <tf.Tensor: shape=(8,), dtype=int32, numpy=array([0, 1, 2, 3, 4, 5, 4, 5], dtype=int32)>)

### 对去重后的ids获取embeddings
> 这里跳转 01-02-embedding_lookup_and_transform.ipynb
1. 执行embedding_lookup(params,ids)，获取embeddings
2. 此处假定获取为如下

In [6]:
# Stand-in for a real embedding lookup: the k-th unique id (0-based) gets a
# 1x8 row filled with the value k + 1, so rows are easy to tell apart.
emb = [array_ops.ones([1, 8]) * (position + 1) for position, _ in enumerate(ids)]
# Stack the per-id rows into a single (num_unique_ids, 8) matrix.
embeddings = tf.concat(emb, axis=0)
embeddings

<tf.Tensor: shape=(6, 8), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1.],
       [2., 2., 2., 2., 2., 2., 2., 2.],
       [3., 3., 3., 3., 3., 3., 3., 3.],
       [4., 4., 4., 4., 4., 4., 4., 4.],
       [5., 5., 5., 5., 5., 5., 5., 5.],
       [6., 6., 6., 6., 6., 6., 6., 6.]], dtype=float32)>

### 假设sp_weights非空

In [7]:
# Cast the segment ids to int32 unless they already have that dtype.
segment_ids = (segment_ids if segment_ids.dtype == dtypes.int32
               else math_ops.cast(segment_ids, dtypes.int32))
# Take the stored weights and bring them to the embeddings' dtype if needed.
weights = sp_weights.values
weights = (weights if weights.dtype == embeddings.dtype
           else math_ops.cast(weights, embeddings.dtype))
# Expand the unique-id embeddings back to one row per stored id, repeating
# rows for ids that occurred more than once.
embeddings = array_ops.gather(embeddings, idx)
(embeddings, weights)

(<tf.Tensor: shape=(8, 8), dtype=float32, numpy=
 array([[1., 1., 1., 1., 1., 1., 1., 1.],
        [2., 2., 2., 2., 2., 2., 2., 2.],
        [3., 3., 3., 3., 3., 3., 3., 3.],
        [4., 4., 4., 4., 4., 4., 4., 4.],
        [5., 5., 5., 5., 5., 5., 5., 5.],
        [6., 6., 6., 6., 6., 6., 6., 6.],
        [5., 5., 5., 5., 5., 5., 5., 5.],
        [6., 6., 6., 6., 6., 6., 6., 6.]], dtype=float32)>,
 <tf.Tensor: shape=(8,), dtype=float32, numpy=array([1., 2., 1., 1., 1., 1., 1., 1.], dtype=float32)>)

### Reshape weights to allow broadcast


In [8]:
# Rank of the embeddings minus one: the number of trailing dimensions the
# weights will need to broadcast over.
emb_rank = array_ops.rank(embeddings)
shape_1 = emb_rank - 1
shape_1

<tf.Tensor: shape=(), dtype=int32, numpy=1>

In [9]:
# Build a 1-D tensor of `shape_1` ones: the unit dimensions that will be
# appended to the weights' shape.
fill_dims = array_ops.expand_dims(shape_1, 0)
ones = array_ops.fill(fill_dims, 1)
ones

<tf.Tensor: shape=(1,), dtype=int32, numpy=array([1], dtype=int32)>

In [10]:
# Keep the static shape from before the reshape so it can be re-applied later.
orig_weights_shape = weights.get_shape()
# Dynamic shape of the weights followed by the trailing unit dimensions.
w_shape = array_ops.shape(weights)
bcast_weights_shape = array_ops.concat([w_shape, ones], axis=0)
# Reshape so the weights can broadcast against the embeddings.
weights = array_ops.reshape(weights, bcast_weights_shape)
(w_shape, bcast_weights_shape, orig_weights_shape, weights)

(<tf.Tensor: shape=(1,), dtype=int32, numpy=array([8], dtype=int32)>,
 <tf.Tensor: shape=(2,), dtype=int32, numpy=array([8, 1], dtype=int32)>,
 TensorShape([8]),
 <tf.Tensor: shape=(8, 1), dtype=float32, numpy=
 array([[1.],
        [2.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], dtype=float32)>)

In [11]:
# Reshaping to bcast_weights_shape leaves the static shape as None, so put it
# back: the original weights shape followed by one unit dim per trailing
# embedding dimension. Only possible when the embeddings' rank is known.
emb_ndims = embeddings.get_shape().ndims
if emb_ndims is not None:
    _shape = [1] * (emb_ndims - 1)
    weights.set_shape(orig_weights_shape.concatenate(_shape))
(_shape, weights)

([1], <tf.Tensor: shape=(8, 1), dtype=float32, numpy=
 array([[1.],
        [2.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.],
        [1.]], dtype=float32)>)

In [12]:
# Scale every gathered embedding row by its weight; the weights carry trailing
# unit dimensions, so they broadcast across the embedding axis.
embeddings = embeddings * weights
embeddings

<tf.Tensor: shape=(8, 8), dtype=float32, numpy=
array([[1., 1., 1., 1., 1., 1., 1., 1.],
       [4., 4., 4., 4., 4., 4., 4., 4.],
       [3., 3., 3., 3., 3., 3., 3., 3.],
       [4., 4., 4., 4., 4., 4., 4., 4.],
       [5., 5., 5., 5., 5., 5., 5., 5.],
       [6., 6., 6., 6., 6., 6., 6., 6.],
       [5., 5., 5., 5., 5., 5., 5., 5.],
       [6., 6., 6., 6., 6., 6., 6., 6.]], dtype=float32)>

In [13]:
# Op name handed to the final combining operation.
name = "combiner"
# Reduction applied within each row; "sum", "mean" and "sqrtn" are recognised.
combiner = "sum"

In [14]:
# Reduce the weighted embeddings to one row per segment id, according to
# `combiner`.
if combiner == "sum":
    # Plain weighted sum: sum(w * e).
    embeddings = math_ops.segment_sum(embeddings, segment_ids, name=name)
elif combiner == "mean":
    # Weighted mean: sum(w * e) / sum(w).
    embeddings = math_ops.segment_sum(embeddings, segment_ids)
    weight_sum = math_ops.segment_sum(weights, segment_ids)
    embeddings = math_ops.divide(embeddings, weight_sum, name=name)
elif combiner == "sqrtn":
    # Weighted sum normalised by the L2 norm of the weights:
    # sum(w * e) / sqrt(sum(w^2)).
    embeddings = math_ops.segment_sum(embeddings, segment_ids)
    weights_squared = math_ops.pow(weights, 2)
    weight_sum = math_ops.segment_sum(weights_squared, segment_ids)
    weight_sum_sqrt = math_ops.sqrt(weight_sum)
    embeddings = math_ops.divide(embeddings, weight_sum_sqrt, name=name)
else:
    # Raise explicitly rather than `assert False`: assert statements are
    # removed when Python runs with -O, which would let an unknown combiner
    # fall through silently and leave `embeddings` un-reduced.
    raise ValueError("Unrecognized combiner: {!r}".format(combiner))

In [15]:
# Display the combined result produced by the combiner step above.
embeddings

<tf.Tensor: shape=(4, 8), dtype=float32, numpy=
array([[ 8.,  8.,  8.,  8.,  8.,  8.,  8.,  8.],
       [ 9.,  9.,  9.,  9.,  9.,  9.,  9.,  9.],
       [11., 11., 11., 11., 11., 11., 11., 11.],
       [ 6.,  6.,  6.,  6.,  6.,  6.,  6.,  6.]], dtype=float32)>