In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [2]:
import os

import cudf
import numpy as np
import rmm
from tensorflow import keras

import nvtabular as nvt

In [3]:
# from merlin_models.tensorflow.models.retrieval import YouTubeDNN

In [4]:
# Root folders for the input data and the exported models. Each one can be
# overridden through the environment variable of the same name.
INPUT_DATA_DIR = os.getenv("INPUT_DATA_DIR", os.path.expanduser("./data/"))
MODEL_BASE_DIR = os.getenv("MODEL_BASE_DIR", os.path.expanduser("./models/"))

In [5]:
# Name of the exported TensorFlow retrieval model (MODEL_NAME_TF env var overrides it).
MODEL_NAME_TF = os.getenv("MODEL_NAME_TF", "movielens_retrieval_tf")
# SavedModel path of version "1" of that model below MODEL_BASE_DIR.
MODEL_PATH_TEMP_TF = os.path.join(
    MODEL_BASE_DIR,
    MODEL_NAME_TF,
    "1/model.savedmodel",
)

def sampled_softmax_loss(y_true, y_pred):
    """Sampled-softmax loss computed against the item embedding table.

    The function is registered under its own name in ``custom_objects`` when
    the saved model is loaded.

    Args:
        y_true: Labels, forwarded as ``labels`` to
            ``tf.nn.sampled_softmax_loss``.
        y_pred: Model outputs, forwarded as ``inputs``.

    Returns:
        The value returned by ``tf.nn.sampled_softmax_loss`` using the item
        embeddings as weights, zero biases and 20 sampled classes.
    """
    # The notebook only imports `keras` from tensorflow, so `tf` would be an
    # undefined name here; import it locally to keep the function self-contained.
    import tensorflow as tf

    # `item_embeddings` is a notebook-level variable assigned after the model
    # has been loaded; it is looked up when the loss is actually called.
    num_items = item_embeddings.shape[0]
    return tf.nn.sampled_softmax_loss(
        weights=item_embeddings,
        biases=tf.zeros((num_items,)),
        labels=y_true,
        inputs=y_pred,
        num_sampled=20,
        num_classes=num_items,
    )

# Restore the trained retrieval model from disk, registering the custom loss
# under the name it is referenced by.
model = keras.models.load_model(
    MODEL_PATH_TEMP_TF,
    custom_objects={"sampled_softmax_loss": sampled_softmax_loss},
)





In [6]:
# Take the "movie_ids" embedding table from the model's input layer and
# convert it to a NumPy array. The row index is used as the item id when the
# vectors are inserted into Milvus later on.
item_embeddings = model.input_layer.embedding_tables["movie_ids"].numpy()
item_embeddings

array([[ 0.3013805 ,  0.35265744,  0.29645905, ..., -0.5987685 ,
        -0.26150104, -0.37054107],
       [-0.46744588, -0.2959144 , -0.4388544 , ...,  0.36167163,
         0.30155638,  0.254408  ],
       [-0.28044578, -0.19978808, -0.25708318, ...,  0.33076942,
         0.27765957,  0.23243934],
       ...,
       [ 0.0078734 , -0.00585937,  0.01805443, ..., -0.01167713,
        -0.00263723, -0.00176259],
       [ 0.00395395, -0.00427187, -0.00336335, ...,  0.00960268,
        -0.00115871, -0.00155194],
       [ 0.00155976,  0.00864351, -0.00166825, ..., -0.00875954,
         0.00638987, -0.00645928]], dtype=float32)

In [7]:
# (number of embedding rows, embedding dimension)
item_embeddings.shape

(62424, 128)

In [8]:
# Milvus ORM client imports (the connection itself is opened in the next cell)
import pymilvus_orm
from pymilvus_orm import schema, DataType, Collection

In [9]:
# Open a Milvus connection. No alias/host/port is passed, so the client's
# built-in defaults are used (NOTE(review): confirm they match your deployment).
pymilvus_orm.connections.connect()

<pymilvus.client.stub.Milvus at 0x7f1b1c3376a0>

In [10]:
# collection.drop()  # uncomment to drop an existing collection; only valid after `collection` has been created

NameError: name 'collection' is not defined

In [11]:
# Define the Milvus collection that stores one vector per item.
# The vector dimension is read from the embedding table instead of being
# hard-coded, so the schema cannot drift from the model's embedding size.
dim = int(item_embeddings.shape[1])
default_fields = [
    # Primary key: the item id (row index of the embedding table).
    schema.FieldSchema(name="item_id", dtype=DataType.INT64, is_primary=True),
    # The item's embedding vector.
    schema.FieldSchema(name="item_vector", dtype=DataType.FLOAT_VECTOR, dim=dim),
]
default_schema = schema.CollectionSchema(fields=default_fields, description="MovieLens item vectors")
collection = Collection(name="movielens_retrieval_tf", data=None, schema=default_schema)

In [12]:
# Insert vectors into collection
# The row index of the embedding table doubles as the item id.
ids = list(range(item_embeddings.shape[0]))
# L2-normalize every row in one vectorized pass (instead of a Python loop over
# all rows) so that the inner-product metric used for indexing and search
# behaves like cosine similarity.
row_norms = np.linalg.norm(item_embeddings, axis=1, keepdims=True)
# Keep all-zero rows as zeros rather than turning them into NaNs via 0/0.
row_norms[row_norms == 0] = 1.0
# One 1-D array per item, matching the list-of-vectors layout passed to insert().
item_vectors = list(item_embeddings / row_norms)
collection.insert([ids, item_vectors])

<pymilvus_orm.search.MutationResult at 0x7f1b1c2e4e80>

In [13]:
# Build an IVF_FLAT index (nlist=128) on the vector field with the
# inner-product ("IP") metric, then load the collection for searching.
default_index = {
    "index_type": "IVF_FLAT",
    "params": {"nlist": 128},
    "metric_type": "IP",
}
collection.create_index(field_name="item_vector", index_params=default_index)
collection.load()

In [14]:
# Raw (not normalized) embedding row 0; its normalized form is the query
# vector used in the search.
item_embeddings[0]

array([ 0.3013805 ,  0.35265744,  0.29645905, -0.34815103,  0.4222383 ,
       -0.5879    ,  0.3410872 ,  0.36045167,  0.32395712,  0.32554877,
       -0.6963399 , -0.6799426 ,  0.4677404 ,  0.3946024 ,  0.41948768,
        0.23198539, -0.40055826,  0.35271585,  0.34618428, -0.5401364 ,
       -0.4036751 , -0.77406543, -0.49029392,  0.26344872, -0.32161382,
        0.28308153,  0.2837937 , -0.64452857,  0.2568618 ,  0.35715148,
       -0.23408331, -0.2622291 , -0.34247452,  0.344981  , -0.36402032,
        0.33669725, -0.29191914,  0.35381636, -0.52548295, -0.24700928,
        0.34325278,  0.37425342,  0.39291283, -0.28088292, -0.25973767,
       -0.30222186, -0.5407103 , -0.275059  , -0.2634201 ,  0.32273275,
       -0.33639053,  0.35481033,  0.4484163 ,  0.33927503,  0.35650772,
        0.20732169, -0.23978533, -0.6183755 , -0.19136897,  0.4520049 ,
        0.23173772,  0.2923455 , -0.27322826, -0.22439322,  0.29219753,
        0.3009594 , -0.8256293 ,  0.23708996,  0.3194406 ,  0.28

In [15]:
# Query the index with the normalized vector of item 0 and ask for the
# 100 closest items under the inner-product metric.
topK = 100
search_params = {
    "metric_type": "IP",
    "params": {"nprobe": 10},
}
res = collection.search(
    [item_vectors[0]],
    "item_vector",
    search_params,
    topK,
)

In [16]:
# Show the raw search result object; its repr only names the type, the hits
# themselves are printed in the next cell.
res

<pymilvus_orm.search.SearchResult at 0x7f1f0402a490>

In [17]:
# Iterate over the result (presumably one hit list per query vector; a single
# query was submitted) and print id, distance and score of every hit.
# In the output, distance and score show the same value for each hit.
for hits in res:
    for h in hits:
        print(f"{h.id}, {h.distance}, {h.score}")

0, 1.0, 1.0
1681, 0.9064712524414062, 0.9064712524414062
14042, 0.9031546115875244, 0.9031546115875244
23279, 0.9025656580924988, 0.9025656580924988
8685, 0.9024295806884766, 0.9024295806884766
10083, 0.9021850824356079, 0.9021850824356079
10540, 0.9020233154296875, 0.9020233154296875
10483, 0.9018194675445557, 0.9018194675445557
5178, 0.9016884565353394, 0.9016884565353394
6252, 0.9016156196594238, 0.9016156196594238
5286, 0.9015621542930603, 0.9015621542930603
13098, 0.9012813568115234, 0.9012813568115234
10124, 0.9010730385780334, 0.9010730385780334
36555, 0.9009886980056763, 0.9009886980056763
22805, 0.9008907079696655, 0.9008907079696655
25287, 0.900621771812439, 0.900621771812439
37755, 0.9005725383758545, 0.9005725383758545
2869, 0.9000678062438965, 0.9000678062438965
13096, 0.8999995589256287, 0.8999995589256287
3141, 0.8999850749969482, 0.8999850749969482
9798, 0.8999186158180237, 0.8999186158180237
14592, 0.8998691439628601, 0.8998691439628601
30798, 0.8998188972473145, 0.899