In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [2]:
# Pinned pre-release Milvus ORM client used below; after a fresh install the
# kernel may need a restart before `import pymilvus_orm` picks it up.
%pip install pymilvus-orm==2.0.0rc1

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Note: you may need to restart the kernel to use updated packages.


In [3]:
import os

import cudf
import numpy as np
import rmm
import tensorflow as tf

from tensorflow import keras
from tqdm import tqdm

import nvtabular as nvt

In [4]:
# from merlin_models.tensorflow.models.retrieval import YouTubeDNN

In [5]:
# Root directories for the input data and the saved models. Each can be
# overridden through the environment variable of the same name.
INPUT_DATA_DIR = os.getenv("INPUT_DATA_DIR", os.path.expanduser("./data/"))
MODEL_BASE_DIR = os.getenv("MODEL_BASE_DIR", os.path.expanduser("./models/"))

In [6]:
# Name of the saved TF retrieval model (overridable via the environment) and the
# SavedModel directory to load: <MODEL_BASE_DIR>/<MODEL_NAME_TF>/1/model.savedmodel.
MODEL_NAME_TF = os.environ.get("MODEL_NAME_TF", "movielens_retrieval_tf")
# Pass each path component separately so os.path.join inserts the platform's
# separator instead of relying on a hard-coded "/".
MODEL_PATH_TEMP_TF = os.path.join(MODEL_BASE_DIR, MODEL_NAME_TF, "1", "model.savedmodel")

def sampled_softmax_loss(y_true, y_pred):
    """Sampled-softmax loss the saved retrieval model was compiled with.

    It is registered with ``keras.models.load_model`` through ``custom_objects``
    under the name ``sampled_softmax_loss`` so the SavedModel can be deserialized.

    Args:
        y_true: Target item ids. ``tf.nn.sampled_softmax_loss`` requires int64
            labels, so they are cast here (a no-op if they already are int64).
            NOTE(review): shape is assumed to be ``[batch, num_true]`` — confirm.
        y_pred: Query vectors fed as ``inputs``; presumably ``[batch, dim]`` in
            the same space as the item embedding table — confirm.

    Returns:
        The per-example loss tensor produced by ``tf.nn.sampled_softmax_loss``.
    """
    # `item_embeddings` is looked up in the notebook globals at *call* time. It is
    # only assigned after the model has been loaded, so this loss must not be
    # invoked (e.g. by fit/evaluate) before that cell has run.
    # NOTE(review): it is a NumPy snapshot of the embedding table, so no gradient
    # reaches the table through this loss; fine for loading, confirm before training.
    num_items = item_embeddings.shape[0]
    return tf.nn.sampled_softmax_loss(
        weights=item_embeddings,
        biases=tf.zeros((num_items,)),
        labels=tf.cast(y_true, tf.int64),
        inputs=y_pred,
        num_sampled=20,
        num_classes=num_items,
    )

# Deserialize the SavedModel; the custom loss has to be resolvable by its name.
model = keras.models.load_model(
    MODEL_PATH_TEMP_TF,
    custom_objects=dict(sampled_softmax_loss=sampled_softmax_loss),
)





In [7]:
# Copy the "movie_ids" embedding table out of the model's input layer as a NumPy
# array. Later cells use the row index as the item id.
# NOTE(review): assumes row i corresponds to encoded movie id i — confirm.
item_embeddings = model.input_layer.embedding_tables["movie_ids"].numpy()

In [8]:
# (number of items, embedding dimension); the dimension must match the
# FLOAT_VECTOR field of the Milvus schema.
item_embeddings.shape

(62424, 128)

In [9]:
# Connect to Milvus
import pymilvus_orm
from pymilvus_orm import schema, DataType, Collection

In [10]:
# Open a Milvus connection; no host/port is passed, so the client's defaults apply.
pymilvus_orm.connections.connect()

<pymilvus.client.stub.Milvus at 0x7ff23c387340>

In [11]:
# Derive the vector dimension from the embeddings instead of hard-coding it,
# so the schema cannot drift from the model's embedding size.
dim = item_embeddings.shape[1]
# Schema: an explicit INT64 primary key (the item id) plus the item embedding.
default_fields = [
    schema.FieldSchema(name="item_id", dtype=DataType.INT64, is_primary=True),
    schema.FieldSchema(name="item_vector", dtype=DataType.FLOAT_VECTOR, dim=dim),
]
default_schema = schema.CollectionSchema(fields=default_fields, description="MovieLens item vectors")

collection_name = "movielens_retrieval_tf"
# NOTE(review): if this collection already exists, re-running the insert cell
# presumably adds the rows again; drop the collection first for a clean rebuild.
collection = Collection(name=collection_name, data=None, schema=default_schema)

In [12]:
# Inspect partitions before inserting (num_entities is expected to be 0 here).
collection.partitions

[{"name": "_default", "description": "", "num_entities": 0}]

In [13]:
# Item ids are simply the row indices of the embedding table.
item_ids = np.arange(item_embeddings.shape[0])
# L2-normalize every row in one vectorized step, so the inner-product (IP)
# metric used for indexing and search equals cosine similarity. Rows with zero
# norm are left as zero vectors instead of becoming NaN through 0/0.
item_norms = np.linalg.norm(item_embeddings, axis=1, keepdims=True)
item_vectors = item_embeddings / np.where(item_norms == 0, 1, item_norms)

In [14]:
# Split ids and vectors with the *same* batch count, so zip() in the insert loop
# pairs each id batch with its matching vector batch. array_split tolerates
# lengths that are not a multiple of the batch count.
NUM_INSERT_BATCHES = 1000
assert len(item_ids) == len(item_vectors), "ids and vectors must be aligned row-for-row"
item_id_groups = np.array_split(item_ids, NUM_INSERT_BATCHES)
item_vector_groups = np.array_split(item_vectors, NUM_INSERT_BATCHES)

In [15]:
# Number of insert batches produced by the split above.
len(item_id_groups)

1000

In [16]:
# Insert vectors into the collection batch by batch. zip() has no len(), so pass
# total= explicitly to let tqdm render a real progress bar.
for ids, vectors in tqdm(zip(item_id_groups, item_vector_groups), total=len(item_id_groups)):
    # Column-ordered payload matching the schema field order: [item_id, item_vector].
    collection.insert([list(ids), list(vectors)])
# Flush the collection so the inserted rows are persisted.
# NOTE(review): presumably also what makes num_entities reflect them — confirm.
pymilvus_orm.utility.get_connection().flush([collection_name])

1000it [00:11, 87.95it/s]


In [17]:
# Build an IVF_FLAT index on the item vectors using the inner-product (IP) metric.
# A compressed product-quantization index is an alternative:
#   {"index_type": "IVF_PQ", "params": {"nlist": 8192, "m": 64}, "metric_type": "IP"}
default_index = dict(index_type="IVF_FLAT", params=dict(nlist=128), metric_type="IP")
collection.create_index(field_name="item_vector", index_params=default_index)

Status(code=0, message='')

In [18]:
# After insert + flush, num_entities is expected to equal item_embeddings.shape[0].
collection.partitions

[{"name": "_default", "description": "", "num_entities": 62424}]

In [19]:
# Optional reset (uncomment to run): drop the index, then the whole collection.
# collection.drop_index()
# collection.drop()

In [20]:
# Load the collection so it can be searched; it is released again at the end.
collection.load()

In [21]:
# Raw (un-normalized) embedding of item 0. The search below queries with its
# normalized form, item_vectors[0].
item_embeddings[0]

array([-0.31062913, -0.4289455 ,  0.24396195,  0.26398358, -0.3841409 ,
        0.34630343, -0.4548074 ,  0.26321203, -0.45085925, -0.3827003 ,
        0.26288456,  0.18450627,  0.31302097,  0.24035817, -0.8597852 ,
        0.26624614,  0.2516488 ,  0.23955749, -0.56014144,  0.30333632,
       -0.42345428,  0.30990636, -0.53592336, -0.4828622 ,  0.4123447 ,
        0.33423057, -0.35016668,  0.31573853, -0.36890745, -0.4432149 ,
        0.24892788,  0.3341653 , -0.5647062 , -0.43414497,  0.2669691 ,
       -0.37266985, -0.18514106,  0.39805707,  0.1711441 ,  0.33105138,
        0.2585441 ,  0.25489008,  0.21271987, -0.56644106,  0.34249297,
        0.31677136,  0.23902795,  0.4251639 ,  0.26552078,  0.32625797,
       -0.24719307,  0.34270865,  0.24732196, -0.40771168, -0.4178784 ,
       -0.78579783,  0.24509811, -0.53588665,  0.41048205, -0.33893058,
        0.28821918, -0.24842651,  0.26238173, -0.4987572 , -0.4171932 ,
       -0.36096543,  0.16542494,  0.22025682, -0.62787   ,  0.40

In [22]:
# Retrieve the topK nearest items to item 0 by inner product.
topK = 100
# Milvus documents nprobe for IVF indexes in the range [1, nlist]. Probing every
# cluster (nprobe == nlist) makes the search exhaustive, so derive it from the
# index definition instead of using an out-of-range constant.
search_params = {"metric_type": "IP", "params": {"nprobe": default_index["params"]["nlist"]}}
res = collection.search([item_vectors[0]], "item_vector", search_params, topK)

In [23]:
# Search result object; iterated below as one list of hits per query vector.
res

<pymilvus_orm.search.SearchResult at 0x7ff23c3420d0>

In [24]:
# Flatten the per-query hit lists and print one "id, distance, score" line per hit.
hit_lines = (f"{hit.id}, {hit.distance}, {hit.score}" for query_hits in res for hit in query_hits)
for line in hit_lines:
    print(line)

0, 1.0, 1.0
51, 0.9666732549667358, 0.9666732549667358
703, 0.9582085609436035, 0.9582085609436035
3017, 0.9579051733016968, 0.9579051733016968
938, 0.9577622413635254, 0.9577622413635254
1280, 0.9568179845809937, 0.9568179845809937
5818, 0.9567961692810059, 0.9567961692810059
8358, 0.9565551280975342, 0.9565551280975342
8719, 0.9563549757003784, 0.9563549757003784
1129, 0.956012487411499, 0.956012487411499
2110, 0.9553396701812744, 0.9553396701812744
52948, 0.9548351168632507, 0.9548351168632507
27415, 0.9548328518867493, 0.9548328518867493
2039, 0.9532161951065063, 0.9532161951065063
9876, 0.9532091021537781, 0.9532091021537781
13752, 0.9527285099029541, 0.9527285099029541
8317, 0.9524883031845093, 0.9524883031845093
2086, 0.952184796333313, 0.952184796333313
5133, 0.9518344402313232, 0.9518344402313232
3662, 0.951751708984375, 0.951751708984375
4157, 0.9514551162719727, 0.9514551162719727
6806, 0.9511220455169678, 0.9511220455169678
5609, 0.9509130716323853, 0.9509130716323853
7497,

In [25]:
# Release the collection loaded for searching (counterpart of collection.load()).
collection.release()

In [26]:
# Optional teardown (uncomment to run): drop the index, then the whole collection.
# collection.drop_index()
# collection.drop()