In [1]:
# Copyright 2021 NVIDIA Corporation. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

In [2]:
import os

import cudf
import rmm
from tensorflow import keras

import nvtabular as nvt

In [3]:
# from merlin_models.tensorflow.models.retrieval import YouTubeDNN

In [4]:
INPUT_DATA_DIR = os.environ.get(
    "INPUT_DATA_DIR", os.path.expanduser("./data/")
)
MODEL_BASE_DIR = os.environ.get(
    "MODEL_BASE_DIR", os.path.expanduser("./models/")
)

In [5]:
MODEL_NAME_TF = os.environ.get("MODEL_NAME_TF", "movielens_retrieval_tf")
MODEL_PATH_TEMP_TF = os.path.join(MODEL_BASE_DIR, MODEL_NAME_TF, "1/model.savedmodel")

def sampled_softmax_loss(y_true, y_pred):
    return tf.nn.sampled_softmax_loss(
        weights=item_embeddings,
        biases=tf.zeros((item_embeddings.shape[0],)),
        labels=y_true,
        inputs=y_pred,
        num_sampled=20,
        num_classes=item_embeddings.shape[0],
    )

model = keras.models.load_model(MODEL_PATH_TEMP_TF, custom_objects={"sampled_softmax_loss": sampled_softmax_loss})

In [15]:
item_embeddings = model.input_layer.embedding_tables["movie_ids"].numpy()
item_embeddings

array([[-1.82572678e-02, -7.22970217e-02, -1.66036021e-02, ...,
        -3.73224229e-01, -2.28869040e-02, -2.40866601e-01],
       [-1.19940303e-01,  1.65193588e-01, -1.10976525e-01, ...,
         3.20175916e-01, -2.75454193e-01,  8.93029273e-02],
       [-1.96990877e-01,  1.04352841e-02, -1.91766590e-01, ...,
         2.46757627e-01, -2.81902462e-01,  2.44285036e-02],
       ...,
       [ 3.97796324e-03, -1.07903769e-02,  6.37699105e-03, ...,
        -2.49420106e-03, -1.08170928e-02, -3.63471918e-03],
       [-1.08442246e-03,  3.19797080e-03, -2.84928526e-03, ...,
        -5.60011948e-04,  1.45310897e-03,  1.51821936e-03],
       [-7.40269490e-04, -2.22443463e-03,  3.06616239e-05, ...,
        -1.23117846e-02, -4.94043203e-03, -2.20183432e-02]], dtype=float32)

In [19]:
item_embeddings.shape

(62424, 128)

In [18]:
# Connect to Milvus
from milvus import Milvus, IndexType, MetricType, Status
milvus = Milvus(host='localhost', port='19530')

In [20]:
# Create a Milvus collection
param = {'collection_name':'movielens_items', 'dimension':128, 'index_file_size':1024, 'metric_type':MetricType.L2}
milvus.create_collection(param)

Status(code=0, message='Create collection successfully!')

In [None]:
# Create partitions

# milvus.create_partition('test01', 'tag01')

In [23]:
# Insert vectors into collection

item_ids = list(range(item_embeddings.shape[0]))
status = milvus.insert(collection_name='movielens_items', records=item_embeddings, ids=item_ids)

In [9]:
# TODO: Test a nearest neighbor query

In [31]:
# Search a collection
import random

q_records = item_embeddings[[random.choice(item_ids) for _ in range(5)]]

In [32]:
search_param = {'nprobe': 16}
milvus.search(collection_name='movielens_items', query_records=q_records, top_k=10, params=search_param)

(Status(code=0, message='Search vectors successfully!'),
 [
 [
 (id:48373, distance:0.0),
 (id:48373, distance:0.0),
 (id:60339, distance:0.006392437964677811),
 (id:60339, distance:0.006392437964677811),
 (id:56017, distance:0.006486162543296814),
 (id:56017, distance:0.006486162543296814),
 (id:57837, distance:0.006511373445391655),
 (id:57837, distance:0.006511373445391655),
 (id:34449, distance:0.00652097025886178),
 (id:34449, distance:0.00652097025886178)
 ],
 [
 (id:57167, distance:0.0),
 (id:57167, distance:0.0),
 (id:24828, distance:0.005453018471598625),
 (id:24828, distance:0.005453018471598625),
 (id:57032, distance:0.0056641618721187115),
 (id:57032, distance:0.0056641618721187115),
 (id:27950, distance:0.00574270635843277),
 (id:27950, distance:0.00574270635843277),
 (id:30908, distance:0.005797192454338074),
 (id:30908, distance:0.005797192454338074)
 ],
 [
 (id:28216, distance:0.0),
 (id:28216, distance:0.0),
 (id:19616, distance:0.0066065192222595215),
 (id:19616, dist