-
Notifications
You must be signed in to change notification settings - Fork 344
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[docs] Split the documentation of EmbeddingVariable. (#16)
- Loading branch information
1 parent
bd17771
commit a476f01
Showing
4 changed files
with
109 additions
and
228 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,96 +1,52 @@ | ||
# 基于Bloom Filter的特征准入 | ||
## Bloom Filter介绍 | ||
布隆过滤器实际上是由一个超长的二进制位数组和一系列的哈希函数组成。二进制位数组初始全部为0,当给定一个待查询的元素时,这个元素会被一系列哈希函数计算映射出一系列的值,所有的值在位数组的偏移量处置为1。 | ||
# EmbeddingVariable进阶功能:特征淘汰 | ||
## 功能介绍 | ||
对于一些对训练没有帮助的特征,我们需要将其淘汰以免影响训练效果,同时也能节约内存。在DeepRec中我们支持了特征淘汰功能,每次存ckpt的时候会触发特征淘汰,目前我们提供了两种特征淘汰的策略: | ||
|
||
如下图所示: | ||
![img_1.png](Feature-Eviction/img_1.png) | ||
如何判断某个元素是否在这个集合中呢? | ||
同样是这个元素经过哈希函数计算后得到所有的偏移位置,若这些位置全都为1,则判断这个元素在这个集合中,若有一个不为1,则判断这个元素不在这个集合中。 | ||
- 基于global step的特征淘汰功能:第一种方式是根据global step来判断一个特征是否要被淘汰。我们会给每一个特征分配一个时间戳,每次前向该特征被访问时就会用当前的global step更新其时间戳。在保存ckpt的时候判断当前的global step和时间戳之间的差距是否超过一个阈值,如果超过了则将这个特征淘汰(即删除)。这种方法的好处在于查询和更新的开销是比较小的,缺点是需要一个int64的数据来记录metadata,有额外的内存开销。 用户通过配置**steps_to_live**参数来配置淘汰的阈值大小。 | ||
- 基于l2 weight的特征淘汰: 在训练中如果一个特征的embedding值的L2范数越小,则代表这个特征在模型中的贡献越小,因此在存ckpt的时候淘汰淘汰L2范数小于某一阈值的特征。这种方法的好处在于不需要额外的metadata,缺点则是引入了额外的计算开销。用户通过配置**l2_weight_threshold**来配置淘汰的阈值大小。 | ||
|
||
Bloom Filter的优点包括: | ||
1. 可以在O(1)时间内判断一个元素是否属于一个集合 | ||
2. 不会出现漏判 (属于该集合的一定可以判断出来) | ||
## 使用方法 | ||
用户可以通过以下的方法使用特征淘汰功能 | ||
|
||
缺点包括: | ||
```python | ||
#使用global step特征淘汰 | ||
evict_opt = tf.GlobalStepEvict(steps_to_live=4000) | ||
|
||
1. 可能会出现误判(不属于该集合的可能会被误判为属于该集合) | ||
#使用l2 weight特征淘汰: | ||
evict_opt = tf.L2WeightEvict(l2_weight_threshold=1.0) | ||
|
||
## 基于Bloom Filter的特征准入功能 | ||
### 原理 | ||
而我们的特征准入功能是CBF (Counting Bloom Filter)实现的,CBF相比于基础的Bloom filter的不同之处在于它将比特位替换成了counter,因此拥有计数的功能,因此可以判断用于判断是否特征的频次是否已经超过某一阈值。 | ||
### 参数设置 | ||
主要设置的参数有四个:特征的数量 n、 hash函数的数量 k、以及counter的数量m、允许的错误率p。 | ||
ev_opt = tf.EmbeddingVariableOption(evict_option=evict_opt) | ||
|
||
这四个参数的关系可以参考下表: | ||
![img_2](Feature-Eviction/img_2.png) | ||
#通过get_embedding_variable接口使用 | ||
emb_var = tf.get_embedding_variable("var", embedding_dim = 16, ev_option=ev_opt) | ||
|
||
另外,当用户给定错误率p时以及特征数量n时,m和k可以通过如下计算得到: | ||
#通过sparse_column_with_embedding接口使用 | ||
from tensorflow.contrib.layers.python.layers import feature_column | ||
emb_var = feature_column.sparse_column_wth_embedding("var", ev_option=ev_opt) | ||
|
||
$$ | ||
m = -\frac{n\ln_{}{p}}{(\ln_{}{2})^2 } \\ | ||
\\ | ||
k = \frac{m}{n}\ln_{}{2} | ||
$$ | ||
### 使用方法 | ||
emb_var = tf.feature_column.categorical_column_with_embedding("var", ev_option=ev_opt) | ||
``` | ||
下面是特征淘汰接口的定义 | ||
```python | ||
evconfig = variables.EVConfig( bloom_filter_strategy = variables.BloomFilterStrategy( | ||
filter_freq=3, | ||
max_element_size = 2**30, | ||
false_positive_probability = 0.01, | ||
counter_type=dtypes.uint64)) | ||
embedding = tf.get_embedding_variable("var_dist", | ||
embedding_dim=6, | ||
initializer=tf.ones_initializer, | ||
ev=evconfig) | ||
@tf_export(v1=["GlobalStepEvict"]) | ||
class GlobalStepEvict(object): | ||
def __init__(self, | ||
steps_to_live = None): | ||
self.steps_to_live = steps_to_live | ||
|
||
@tf_export(v1=["L2WeightEvict"]) | ||
class L2WeightEvict(object): | ||
def __init__(self, | ||
l2_weight_threshold = -1.0): | ||
self.l2_weight_threshold = l2_weight_threshold | ||
if l2_weight_threshold <= 0 and l2_weight_threshold != -1.0: | ||
print("l2_weight_threshold is invalid, l2_weight-based eviction is disabled") | ||
``` | ||
参数解释: | ||
|
||
用户首先构造`BloomFilterStrategy`对象,配置最大的特征的数量、可接受的错误率以及存储frequency的数据类型。之后将该对象传入到`EVConfig`的构造函数并配置`filter_freq`即可。该功能默认关闭。 | ||
在save ckpt的时候不会记录特征对应的频次,即没有达到准入标准的特征在restore ckpt后会从0开始计数。 | ||
- `steps_to_live`:Global step特征淘汰的阈值,如果特征超过`steps_to_live`个global step没有被访问过,那么则淘汰 | ||
- `l2_weight_threshold`: L2 weight特征淘汰的阈值,如果特征的L2-norm小于阈值,则淘汰 | ||
|
||
**下面是一个使用示例:** | ||
功能开关: | ||
|
||
```python | ||
import tensorflow as tf | ||
import tensorflow | ||
import time | ||
import random | ||
from tensorflow.python.framework import ops | ||
from tensorflow.python.ops import variables | ||
from tensorflow.python.client import timeline | ||
from tensorflow.python.lib.io import file_io | ||
from tensorflow.python.ops import math_ops | ||
from tensorflow.python.framework import dtypes | ||
from tensorflow.python.training import saver as saver_module | ||
from tensorflow.python.training import incremental_saver as incr_saver_module | ||
from tensorflow.python.training import training_util | ||
from tensorflow.contrib.framework.python.framework import checkpoint_utils | ||
import numpy as np | ||
|
||
def main(unused_argv): | ||
evconfig = variables.EVConfig( filter_freq=3, | ||
bloom_filter_strategy = variables.BloomFilterStrategy( | ||
max_element_size = 2**30, | ||
false_positive_probability = 0.01), | ||
evict=evictconfig) | ||
embedding = tf.get_embedding_variable("var_dist", | ||
embedding_dim=6, | ||
initializer=tf.ones_initializer, | ||
ev=evconfig) | ||
ids = math_ops.cast([1,1,1,1,2,2,2,3,3,4], dtypes.int64) | ||
values = tf.nn.embedding_lookup(embedding, ids) | ||
fun = math_ops.multiply(values, 2.0, name='multiply') | ||
gs = training_util.get_or_create_global_step() | ||
loss1 = math_ops.reduce_sum(fun, name='reduce_sum') | ||
opt = tf.train.AdagradOptimizer(0.1) | ||
g_v = opt.compute_gradients(loss1) | ||
train_op = opt.apply_gradients(g_v) | ||
|
||
init = variables.global_variables_initializer() | ||
with tf.Session(config=config) as sess: | ||
sess.run([init]) | ||
print(sess.run([values, train_op])) | ||
|
||
|
||
if __name__=="__main__": | ||
tf.app.run() | ||
``` | ||
如果没有配置`GlobalStepEvict`以及`L2WeightEvict`、`steps_to_live`设置为`None`以及`l2_weight_threshold`设置小于0则功能关闭,否则功能打开。 |
Oops, something went wrong.