-
Notifications
You must be signed in to change notification settings - Fork 49
/
index.py
284 lines (244 loc) · 9.96 KB
/
index.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
#
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Optional, Tuple, Union

import numpy as np
import tensorflow as tf
from tensorflow.python import to_dlpack

import merlin.io
from merlin.core.dispatch import DataFrameType
from merlin.models.tf.blocks.core.base import Block, PredictionOutput
from merlin.models.tf.utils import tf_utils
from merlin.models.tf.utils.batch_utils import TFModelEncode
from merlin.models.utils.constants import MIN_FLOAT
from merlin.schema import Tags
@tf.keras.utils.register_keras_serializable(package="merlin_models")
class IndexBlock(Block):
    """Block that keeps candidate embeddings and their ids in non-trainable variables.

    Parameters
    ----------
    values: tf.Tensor
        The pre-computed embeddings of the candidates (one row per candidate).
    ids: Optional[tf.Tensor]
        The candidates ids. When omitted, positional ids ``0..n-1`` (int64)
        are generated, where ``n`` is the number of rows of `values`.
    """

    def __init__(self, values: tf.Tensor, ids: Optional[tf.Tensor] = None, **kwargs):
        super(IndexBlock, self).__init__(**kwargs)

        # The first dimension is declared unknown (and shape validation is
        # disabled) so that `update` can later assign a different number of rows.
        self.values = tf.Variable(
            values,
            name="values",
            trainable=False,
            dtype=tf.float32,
            validate_shape=False,
            shape=tf.TensorShape([None, tf.shape(values)[-1]]),
        )

        if ids is not None:
            id_dtype = ids.dtype
        else:
            id_dtype = tf.int64
            # `tf.Variable` needs an initial value, so fall back to positional
            # ids; this mirrors the default used by `update`.
            ids = tf.range(tf.shape(values)[0], dtype=id_dtype)

        self.ids = tf.Variable(
            ids,
            name="ids",
            trainable=False,
            dtype=id_dtype,
            validate_shape=False,
            shape=tf.TensorShape([None]),
        )

    @classmethod
    def from_dataset(
        cls, data: merlin.io.Dataset, check_unique_ids: bool = True, **kwargs
    ) -> "IndexBlock":
        """Build the index from a dataset of embeddings indexed by candidate id.

        Parameters
        ----------
        data: merlin.io.Dataset
            Dataset whose index provides the ids and whose columns provide the
            embedding values (see `extract_ids_embeddings`).
        check_unique_ids: bool
            Whether to raise if the index contains duplicated ids.
        **kwargs
            Forwarded to the class constructor.
        """
        ids, values = cls.extract_ids_embeddings(data, check_unique_ids)
        return cls(values=values, ids=ids, **kwargs)

    @classmethod
    def from_block(
        cls, block: Block, data: merlin.io.Dataset, id_column: Optional[str] = None, **kwargs
    ) -> "IndexBlock":
        """Build candidates embeddings from applying `block` to a dataset of features `data`.

        Parameters
        ----------
        block: Block
            The Block that returns embeddings from raw item features.
        data: merlin.io.Dataset
            Dataset containing raw item features.
        id_column: Optional[str]
            The candidates ids column name.
            Note, this will be inferred automatically if the block contains
            a schema with an item-id Tag.
        **kwargs
            Forwarded to `from_dataset` (and from there to the constructor).
        """
        embedding_df = cls.get_candidates_dataset(block, data, id_column)
        return cls.from_dataset(embedding_df, **kwargs)

    @staticmethod
    def _check_unique_ids(data: DataFrameType):
        """Raise a ValueError when the index of `data` has duplicated entries."""
        if data.index.to_series().nunique() != data.shape[0]:
            raise ValueError("Please make sure that `data` contains unique indices")

    @classmethod
    def extract_ids_embeddings(cls, data: merlin.io.Dataset, check_unique_ids: bool = True):
        """Split `data` into an ids tensor (its index) and a values tensor (its columns).

        Parameters
        ----------
        data: merlin.io.Dataset
            Either an object exposing `to_ddf()` or an already materialised
            dataframe (as produced by `get_candidates_dataset`).
        check_unique_ids: bool
            Whether to raise if the index contains duplicated ids.

        Returns
        -------
        ids, values: tf.Tensor, tf.Tensor
            The 1D ids and the embeddings converted by `tf_utils.df_to_tensor`.
        """
        if hasattr(data, "to_ddf"):
            data = data.to_ddf()
        if check_unique_ids:
            cls._check_unique_ids(data=data)

        values = tf_utils.df_to_tensor(data)
        ids = tf_utils.df_to_tensor(data.index)
        if len(ids.shape) == 2:
            # Flatten instead of squeezing every axis: a bare `tf.squeeze`
            # would turn a single-candidate (1, 1) tensor into a scalar.
            ids = tf.reshape(ids, [-1])

        return ids, values

    @classmethod
    def get_candidates_dataset(
        cls, block: Block, data: merlin.io.Dataset, id_column: Optional[str] = None
    ):
        """Apply `block` to `data` and return the embeddings indexed by `id_column`.

        Parameters
        ----------
        block: Block
            The Block that returns embeddings from raw item features.
        data: merlin.io.Dataset
            Dataset containing raw item features.
        id_column: Optional[str]
            The candidates ids column name. When not given, it is taken from the
            first column tagged `Tags.ITEM_ID` in `block.schema`, if any.

        Raises
        ------
        ValueError
            If `id_column` is not given and cannot be inferred from the block.
        """
        if not id_column and getattr(block, "schema", None):
            tagged = block.schema.select_by_tag(Tags.ITEM_ID)
            if tagged.column_schemas:
                id_column = tagged.first.name

        if not id_column:
            # Fail early with a clear message rather than passing `None` as a
            # column name to the encoder and to `set_index` below.
            raise ValueError(
                "Could not determine the candidates id column: please provide `id_column` "
                "or a block whose schema contains a column tagged as item-id."
            )

        model_encode = TFModelEncode(model=block, output_concat_func=np.concatenate)

        data = data.to_ddf()
        embedding_df = data.map_partitions(model_encode, filter_input_columns=[id_column]).compute()
        embedding_df.set_index(id_column, inplace=True)

        return embedding_df

    def update_from_block(
        self,
        block: Block,
        data: merlin.io.Dataset,
        id_column: Optional[str] = None,
        check_unique_ids: bool = True,
    ):
        """Recompute the embeddings by applying `block` to `data` and store them in place.

        Parameters are the same as for `get_candidates_dataset`, plus
        `check_unique_ids` which is forwarded to `extract_ids_embeddings`.
        """
        embedding_df = IndexBlock.get_candidates_dataset(block, data, id_column)
        ids, embeddings = IndexBlock.extract_ids_embeddings(embedding_df, check_unique_ids)
        self.update(embeddings, ids)

    def update(self, values: tf.Tensor, ids: Optional[tf.Tensor] = None):
        """Replace the stored embeddings (and ids) with new ones.

        Parameters
        ----------
        values: tf.Tensor
            The new 2D tensor of candidates embeddings.
        ids: Optional[tf.Tensor]
            The new candidates ids. Defaults to positional ids ``0..n-1``
            created with the dtype of the existing `ids` variable.

        Returns
        -------
        IndexBlock
            The block itself, to allow chaining.

        Raises
        ------
        ValueError
            If `values` is not a 2D tensor.
        """
        if len(tf.shape(values)) != 2:
            raise ValueError(f"The candidates embeddings tensor must be 2D (got {values.shape}).")

        if ids is not None:
            _ids: tf.Tensor = ids
        else:
            # Match the dtype of the variable: a default int32 range cannot be
            # assigned to an int64 `ids` variable.
            _ids = tf.range(tf.shape(values)[0], dtype=self.ids.dtype)

        self.ids.assign(_ids)
        self.values.assign(values)
        return self

    def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor:
        """Index the stored embeddings with `inputs`.

        NOTE(review): `inputs` is used directly as a position into `values`;
        the `ids` variable is not consulted here.
        """
        return self.values[inputs]

    def to_dataset(self, gpu=True) -> merlin.io.Dataset:
        """Export the stored embeddings as a `merlin.io.Dataset`.

        Parameters
        ----------
        gpu: bool
            When True the dataframe is built with cudf (through dlpack),
            otherwise with pandas.

        Returns
        -------
        merlin.io.Dataset
            Dataset with one row per candidate, string column names and a
            range index.
        """
        if gpu:
            import cudf

            df = cudf.from_dlpack(to_dlpack(tf.convert_to_tensor(self.values)))
            # The static first dimension of `self.values` is declared as None,
            # so the row count is taken from the dataframe itself.
            range_index = cudf.RangeIndex(0, len(df))
        else:
            import pandas as pd

            df = pd.DataFrame(self.values.numpy())
            range_index = pd.RangeIndex(0, len(df))

        df.columns = [str(col) for col in df.columns]
        # Assign the index: `set_index` returns a new frame unless
        # `inplace=True`, so calling it without using the result has no effect.
        df.index = range_index

        return merlin.io.Dataset(df)
@tf.keras.utils.register_keras_serializable(package="merlin_models")
class TopKIndexBlock(IndexBlock):
    """Top-K index to retrieve top-k scores and indices from an item block.

    Parameters
    ----------
    k: int
        Number of top candidates to retrieve.
    values: tf.Tensor
        The pre-computed embeddings of candidates.
    ids: tf.Tensor
        The candidates ids.
    """

    def __init__(self, k, values: tf.Tensor, ids: Optional[tf.Tensor] = None, **kwargs):
        self._k = k
        super().__init__(values, ids, **kwargs)
        # Score given to retrieved candidates that turn out to be the positive
        # item (see `call_outputs`).
        self.false_negatives_score = MIN_FLOAT

    @classmethod
    def from_block(  # type: ignore
        cls,
        block: Block,
        data: merlin.io.Dataset,
        k: int = 20,
        id_column: Optional[str] = None,
        **kwargs,
    ) -> "TopKIndexBlock":
        """Build candidates embeddings from applying `block` to a dataset of features `data`.

        Parameters
        ----------
        block: Block
            The Block that returns embeddings from raw item features.
        data: merlin.io.Dataset
            Dataset containing raw item features.
        k: int
            Number of top candidates to retrieve.
            Defaults to 20
        id_column: Optional[str]
            The candidates ids column name.
            Note, this will be inferred automatically if the block contains
            a schema with an item-id Tag.
        **kwargs
            Forwarded, together with `k`, to the parent `from_block`.
        """
        return super().from_block(block=block, data=data, id_column=id_column, k=k, **kwargs)

    def call(self, inputs: tf.Tensor, k=None, **kwargs) -> Tuple[tf.Tensor, tf.Tensor]:
        """Compute Top-k scores and related ids from query inputs.

        Parameters
        ----------
        inputs: tf.Tensor
            Tensor of pre-computed query embeddings.
        k: int
            Number of top candidates to retrieve
            Defaults to constructor `_k` parameter.

        Returns
        -------
        top_scores, top_indices: tf.Tensor, tf.Tensor
            2D Tensors with the scores for the top-k candidates and related ids.
        """
        k = self._k if k is None else k
        # Dot product of every query with every stored candidate embedding.
        scores = tf.matmul(inputs, self.values, transpose_b=True)
        top_scores, top_positions = tf.math.top_k(scores, k=k)
        # `top_k` returns positions into `values`; map them to candidate ids.
        top_indices = tf.gather(self.ids, top_positions)

        return top_scores, top_indices

    def call_outputs(
        self, outputs: PredictionOutput, training=False, **kwargs
    ) -> "PredictionOutput":
        """Retrieve top-k negative scores for evaluation.

        The query embeddings are read from `self.context["query"]` and the
        positive candidates embeddings from
        `self.context["positive_candidates_embeddings"]`.

        Parameters
        ----------
        outputs: PredictionOutput
            Prediction output whose `positive_item_ids` are used to rescore
            retrieved candidates that are actually the positive item.
        training: bool
            Not used by this method.

        Returns
        -------
        PredictionOutput
            Copy of `outputs` where `predictions` holds the scores kept by
            `tf_utils.extract_topk`, `targets` the matching 0/1 labels (1 for
            the positive item), and `label_relevant_counts` is 1 per example.
        """
        queries = self.context["query"]
        top_scores, top_ids = self(queries, k=self._k)

        # remove accidental hits
        top_scores, _ = tf_utils.rescore_false_negatives(
            outputs.positive_item_ids, top_ids, top_scores, self.false_negatives_score
        )

        # Update top-k scores with positives: the positive score goes in column 0.
        positive_scores = tf.reduce_sum(
            queries * self.context["positive_candidates_embeddings"], axis=1, keepdims=True
        )
        predictions = tf.concat([positive_scores, top_scores], axis=-1)

        batch_size = tf.shape(predictions)[0]
        # Labels aligned with `predictions`: 1 for the positive column, 0 for
        # the k retrieved candidates.
        targets = tf.concat(
            [
                tf.ones([batch_size, 1]),
                tf.zeros([batch_size, self._k]),
            ],
            axis=1,
        )

        # Sort the updated scores
        predictions_sorted, targets_sorted, _ = tf_utils.extract_topk(self._k, predictions, targets)
        label_relevant_counts = tf.ones([batch_size])

        return outputs.copy_with_updates(
            predictions=predictions_sorted,
            targets=targets_sorted,
            label_relevant_counts=label_relevant_counts,
        )

    def compute_output_shape(self, input_shape):
        """Return the `(batch_size, k)` shapes of the scores and ids produced by `call`."""
        batch_size = input_shape[0]
        output_shape = tf.TensorShape((batch_size, self._k))

        return output_shape, tf.TensorShape((batch_size, self._k))