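| Name | Meaning |
| ----------- | ----------- |
| n_user | Number of users |
| n_item | Number of interaction items |
| train_data | Training set |
| eval_data | Validation set |
| test_data | Test set |
| adj_item | Rows: users; columns: IDs of the interactions produced by that user; 20 interactions are sampled per user |
| n_classes | Number of target classes; for example, gender has 2 classes (male, female) |
| ratings_np | Raw data read from the ratings file, as a NumPy array; one row per user, with columns (user id, interaction id, gender) |
| adj_user | adj_user[i][j]: ID of a user that has had an interaction with user i (unordered; randomly sampled) |
| n_entity | Number of entities in the knowledge base (union of heads and tails) |
| n_relation | Number of relations in the knowledge base (union of heads and tails) |
| adj_entity | adj_entity[i]: set of entities related to entity i (sampled) |
| adj_relation | adj_relation[i]: relations connecting entity i to its sampled neighbors |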
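
The table only says that `adj_entity` and `adj_relation` are sampled. As a rough illustration of what fixed-width sampled adjacency arrays could look like, here is a minimal NumPy sketch. The helper `build_adjacency`, the toy triples, and the choice to treat the graph as undirected and to sample with replacement when an entity has too few neighbors are all assumptions for this example, not something the notebook confirms.

```python
import numpy as np

def build_adjacency(triples, n_entity, neighbor_sample_size, seed=0):
    """Build fixed-width adjacency arrays from (head, relation, tail) triples.

    Hypothetical helper: the actual preprocessing is not shown in the notebook.
    """
    rng = np.random.default_rng(seed)
    # Assumption: the graph is undirected, so each triple links both directions.
    neighbors = {}
    for head, relation, tail in triples:
        neighbors.setdefault(head, []).append((tail, relation))
        neighbors.setdefault(tail, []).append((head, relation))

    # Entities without any neighbor keep an all-zero row.
    adj_entity = np.zeros((n_entity, neighbor_sample_size), dtype=np.int64)
    adj_relation = np.zeros((n_entity, neighbor_sample_size), dtype=np.int64)
    for entity, pairs in neighbors.items():
        # Sample with replacement only when there are too few neighbors.
        replace = len(pairs) < neighbor_sample_size
        picked = rng.choice(len(pairs), size=neighbor_sample_size, replace=replace)
        adj_entity[entity] = [pairs[k][0] for k in picked]
        adj_relation[entity] = [pairs[k][1] for k in picked]
    return adj_entity, adj_relation

# Toy knowledge graph with 4 entities and 2 relation types.
toy_triples = [(0, 0, 1), (0, 1, 2), (1, 0, 3), (2, 1, 3)]
toy_adj_entity, toy_adj_relation = build_adjacency(
    toy_triples, n_entity=4, neighbor_sample_size=4)
print(toy_adj_entity.shape)
```

Row `i` of both arrays is aligned: `toy_adj_relation[i][j]` is the relation that links entity `i` to `toy_adj_entity[i][j]`.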

# Data preprocessing

In [1]:
import tensorflow as tf
import numpy as np

In [2]:
# Load the preprocessed data
def read_file(filename):
  with open('{}.npy'.format(filename), 'rb') as f:
    ret = np.load(f)
  return ret

n_user, n_item, n_entity, n_relation, n_classes = read_file('n_file')     
train_data = read_file('train_data')
eval_data = read_file('eval_data')
test_data = read_file('test_data')    
adj_entity = read_file('adj_entity')
adj_relation = read_file('adj_relation')
adj_item = read_file('adj_item')   
ratings_np = read_file('ratings_np') 
adj_user = read_file('adj_user')   
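
To make the loading step concrete, the sketch below writes a small `.npy` file and reads it back with the same `read_file` helper. It assumes that `n_file.npy` stores a 1-D array of five integers, which is what the tuple unpacking above implies; the counts used here are made up for illustration.

```python
import os
import tempfile
import numpy as np

def read_file(filename):
    # Same helper as in the notebook: appends '.npy' and loads the array.
    with open('{}.npy'.format(filename), 'rb') as f:
        return np.load(f)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Hypothetical counts, chosen only for this illustration.
    counts = np.array([5, 7, 11, 3, 2])
    np.save(os.path.join(tmp_dir, 'n_file.npy'), counts)
    # Unpacking a length-5 array yields the five scalar counts.
    n_user, n_item, n_entity, n_relation, n_classes = read_file(
        os.path.join(tmp_dir, 'n_file'))

print(int(n_user), int(n_classes))
```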

In [9]:
# Mock the command-line arguments
arg = {}
arg['ratio'] = 1
arg['n_interact'] = 20
arg['n_neighbors'] = 20
arg['dataset'] = 'movie'
arg['mission'] = 'gender'
arg['neighbor_sample_size'] = 4
arg['dim'] = 32
arg['lr'] = 0.02
arg['l2_weight'] = 0.001
arg['n_iter'] = 2
arg['batch_size'] = 128
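
The dictionary above stands in for command-line arguments. Outside a notebook, the same defaults could be exposed through `argparse`; the sketch below is one possible layout. The function name `build_parser` and the flag spellings are hypothetical, and only the default values come from the cell above.

```python
import argparse

def build_parser():
    # Hypothetical parser mirroring the mocked `arg` dictionary.
    parser = argparse.ArgumentParser()
    parser.add_argument('--ratio', type=float, default=1)
    parser.add_argument('--n_interact', type=int, default=20)
    parser.add_argument('--n_neighbors', type=int, default=20)
    parser.add_argument('--dataset', type=str, default='movie')
    parser.add_argument('--mission', type=str, default='gender')
    parser.add_argument('--neighbor_sample_size', type=int, default=4)
    parser.add_argument('--dim', type=int, default=32)
    parser.add_argument('--lr', type=float, default=0.02)
    parser.add_argument('--l2_weight', type=float, default=0.001)
    parser.add_argument('--n_iter', type=int, default=2)
    parser.add_argument('--batch_size', type=int, default=128)
    return parser

# Parsing an empty argument list returns the defaults as a plain dict.
parsed_args = vars(build_parser().parse_args([]))
print(parsed_args['dim'], parsed_args['mission'])
```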

In [4]:
# Build the input placeholders
user_indicies = tf.placeholder(dtype=tf.int64, shape=[None], name='user_indicies')
item_indicies = tf.placeholder(dtype=tf.int64, shape=[None], name='item_indicies')
labels = tf.placeholder(dtype=tf.int64, shape=[None], name='labels')

In [5]:
def get_initializer():
  # Xavier initializer: aims to keep the output variance of every layer roughly equal.
  return tf.contrib.layers.xavier_initializer()
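
For intuition, the uniform variant of Xavier (Glorot) initialization draws weights from $[-l, l]$ with $l = \sqrt{6 / (\text{fan\_in} + \text{fan\_out})}$, which gives a variance of $2 / (\text{fan\_in} + \text{fan\_out})$. The NumPy sketch below illustrates that rule only; the function `xavier_uniform` is a stand-in written for this example, not the TensorFlow implementation.

```python
import numpy as np

def xavier_uniform(shape, seed=0):
    """Illustrative Glorot-uniform sampler for a 2-D weight matrix."""
    fan_in, fan_out = shape
    # Bound chosen so that Var(w) = limit**2 / 3 = 2 / (fan_in + fan_out).
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    rng = np.random.default_rng(seed)
    return rng.uniform(-limit, limit, size=shape)

# Same shape as the output layer with dim=32 and n_classes=2.
sample_weights = xavier_uniform((32, 2))
sample_limit = np.sqrt(6.0 / (32 + 2))
print(sample_weights.shape, float(np.abs(sample_weights).max()) <= sample_limit)
```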

# Building the model

## Create embedding matrices as the user feature input and build the trainable parameter matrices

In [6]:
# Build the model
dim = arg['dim']
# User feature embedding matrix
user_emb_matrix = tf.get_variable(
    shape=[n_user, dim], initializer=get_initializer(), name='user_emb_matrix')
# Entity feature embedding matrix
entity_emb_matrix = tf.get_variable(
    shape=[n_entity, dim], initializer=get_initializer(), name='entity_emb_matrix')
# Relation embedding matrix
relation_emb_matrix = tf.get_variable(
    shape=[n_relation, dim], initializer=get_initializer(), name='relation_emb_matrix')
# Output layer weights
output_weights = tf.get_variable(
    shape=[dim, n_classes], initializer=get_initializer(), name='output_weights1')
# Output layer bias
output_bias = tf.get_variable(
    shape=[n_classes], initializer=tf.zeros_initializer(), name='output_bias1')

## Getting neighbor entities and relations

In [7]:
seeds = tf.expand_dims(item_indicies, axis=1) # item_indicies: [n] -> seeds: [n, 1]
seeds.shape # the row vector has become a column vector

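### Collection algorithm
1. For each entity $e_i^0$ in the interaction entity set $N_i(u)$ of user $u$, collect its neighbor entity set $N_i(e_i^0)$; every entity in this set is directly connected to $e_i^0$. The relation $r_{e_i^0,e_j^1}$ then denotes the relation between entity $e_i^0$ and entity $e_j^1$, where $e_j^1 \in N_i(e_i^0)$.
2. For each entity $e_j^1$ in $N_i(e_i^0)$, collect its neighbors to obtain the neighbor entity set $N_i(e_j^1)$.
3. For each entity $e_k^2$ in the neighbor entity set $N_i(e_j^1)$, collect its neighbors in the same way, and repeat until the recursion depth reaches $H$, at which point the collection process terminates.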


![Algorithm for iteratively collecting neighbor entities](assets/alg3.1_get_neighbor_entities.png)
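
The collection steps can be sketched in plain NumPy, where integer-array indexing plays the role of `tf.gather`. The function `collect_neighbors` and the toy adjacency arrays are made up for this illustration; the loop structure is meant to mirror the TensorFlow cell in this notebook, with `n_iter` standing in for $H$.

```python
import numpy as np

def collect_neighbors(seed_items, adj_entity, adj_relation, n_iter):
    """Collect sampled neighbors hop by hop, starting from the seed items."""
    batch_size = seed_items.shape[0]
    entities = [seed_items.reshape(batch_size, 1)]  # hop 0: [batch, 1]
    relations = []
    for hop in range(n_iter):
        # Indexing [batch, k] into [n_entity, K] gives [batch, k, K];
        # flattening the last two axes gives [batch, k * K].
        neighbor_entities = adj_entity[entities[hop]].reshape(batch_size, -1)
        neighbor_relations = adj_relation[entities[hop]].reshape(batch_size, -1)
        entities.append(neighbor_entities)
        relations.append(neighbor_relations)
    return entities, relations

# Toy graph: 4 entities, 2 sampled neighbors each (hypothetical data).
demo_adj_entity = np.array([[1, 2], [0, 3], [0, 3], [1, 2]])
demo_adj_relation = np.array([[0, 1], [0, 0], [1, 1], [0, 1]])

hop_entities, hop_relations = collect_neighbors(
    np.array([0, 3]), demo_adj_entity, demo_adj_relation, n_iter=2)
print([e.shape for e in hop_entities])  # [(2, 1), (2, 2), (2, 4)]
```

With `K` sampled neighbors per entity, hop `h` holds `K ** h` entities per seed, so the width grows geometrically with the number of iterations.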

In [17]:
adj_entity.shape

(37473, 4)

In [20]:
entities = [seeds] # entities: a list holding one [n, 1] tensor
relations = []
n_iter = arg['n_iter'] # the number of iterations is a hyperparameter
batch_size = arg['batch_size']
for i in range(n_iter):
    # tf.gather(a, b): fetch the slices of tensor a at the indices in b (b may be a list)
    # tf.gather(adj_entity, entities[i]) -> [n, 1, 4]
    neighbor_entities = tf.reshape(tf.gather(adj_entity, entities[i]), [batch_size, -1])
    neighbor_relations = tf.reshape(tf.gather(adj_relation, entities[i]), [batch_size, -1])
    entities.append(neighbor_entities) # each iteration searches for neighbors of the previous iteration's entities
    relations.append(neighbor_relations)

(?, 1, 4)
(128, ?, 4)


In [None]:
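users = tf.expand_dims(user_indicies, axis=1) # user_indicies: [n] -> users: [n, 1]
# Get the movies each user has interacted with (the sampled interactions stored in adj_item)
interact_item = tf.reshape(tf.gather(adj_item, users), [batch_size, -1])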

In [12]:
user_embedding = tf.nn.embedding_lookup(user_emb_matrix, user_indicies)


[<tf.Tensor 'ExpandDims:0' shape=(?, 1) dtype=int64>,
 <tf.Tensor 'Reshape:0' shape=(128, ?) dtype=int64>,
 <tf.Tensor 'Reshape_2:0' shape=(128, ?) dtype=int64>]
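
The `tf.nn.embedding_lookup` call in cell `In [12]` selects one row of `user_emb_matrix` per user index. The NumPy sketch below shows the equivalent row indexing. It also shows one plausible way the output parameters could be applied, namely an affine map from embeddings to class logits; that second step is an assumption, since the notebook does not show how `output_weights` and `output_bias` are used. All sizes and values are made up.

```python
import numpy as np

rng = np.random.default_rng(0)
demo_n_user, demo_dim, demo_n_classes = 6, 4, 2  # hypothetical sizes

demo_user_emb = rng.normal(size=(demo_n_user, demo_dim))
demo_output_weights = rng.normal(size=(demo_dim, demo_n_classes))
demo_output_bias = np.zeros(demo_n_classes)

demo_user_indices = np.array([0, 3, 3])
# Row indexing is the NumPy counterpart of tf.nn.embedding_lookup.
demo_user_embedding = demo_user_emb[demo_user_indices]  # [batch, dim]
# Assumed use of the output parameters: an affine map to class logits.
demo_logits = demo_user_embedding @ demo_output_weights + demo_output_bias
print(demo_user_embedding.shape, demo_logits.shape)
```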