In [2]:
import os
import pandas as pd
import numpy as np
import requests
import warnings
import pickle
import time

import nmslib

from implicit.als import AlternatingLeastSquares

from rectools.metrics import Precision, Recall, MAP, calc_metrics
from rectools.models import PopularModel, RandomModel, ImplicitALSWrapperModel
from rectools import Columns
from rectools.dataset import Dataset
from rectools.models import ImplicitALSWrapperModel, LightFMWrapperModel

import matplotlib.pyplot as plt
import seaborn as sns

import matplotlib.pyplot as plt
from pathlib import Path
import typing as tp
from tqdm import tqdm

from lightfm import LightFM

from implicit.bpr import BayesianPersonalizedRanking

from implicit.lmf import LogisticMatrixFactorization

In [3]:
url = "https://storage.yandexcloud.net/itmo-recsys-public-data/kion_train.zip"

req = requests.get(url, stream=True)

with open('kion_train.zip', "wb") as fd:
    total_size_in_bytes = int(req.headers.get('Content-Length', 0))
    progress_bar = tqdm(desc='kion dataset download', total=total_size_in_bytes, unit='iB', unit_scale=True)
    for chunk in req.iter_content(chunk_size=2 ** 20):
        progress_bar.update(len(chunk))
        fd.write(chunk)

kion dataset download: 100%|█████████▉| 78.6M/78.8M [00:09<00:00, 10.9MiB/s]

In [4]:
!unzip kion_train.zip

Archive:  kion_train.zip
   creating: kion_train/
  inflating: kion_train/interactions.csv  
  inflating: __MACOSX/kion_train/._interactions.csv  
  inflating: kion_train/users.csv    
  inflating: __MACOSX/kion_train/._users.csv  
  inflating: kion_train/items.csv    
  inflating: __MACOSX/kion_train/._items.csv  


In [5]:
warnings.filterwarnings('ignore')

In [6]:
os.environ["OPENBLAS_NUM_THREADS"] = "1" 

# Load data

In [7]:
interactions = pd.read_csv('kion_train/interactions.csv')
users = pd.read_csv('kion_train/users.csv')
items = pd.read_csv('kion_train/items.csv')

# Users

In [8]:
users.head()

Unnamed: 0,user_id,age,income,sex,kids_flg
0,973171,age_25_34,income_60_90,М,1
1,962099,age_18_24,income_20_40,М,0
2,1047345,age_45_54,income_40_60,Ж,0
3,721985,age_45_54,income_20_40,Ж,0
4,704055,age_35_44,income_60_90,Ж,0


In [9]:
users.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 840197 entries, 0 to 840196
Data columns (total 5 columns):
 #   Column    Non-Null Count   Dtype 
---  ------    --------------   ----- 
 0   user_id   840197 non-null  int64 
 1   age       826102 non-null  object
 2   income    825421 non-null  object
 3   sex       826366 non-null  object
 4   kids_flg  840197 non-null  int64 
dtypes: int64(2), object(3)
memory usage: 32.1+ MB


# Items

In [10]:
items.head()

Unnamed: 0,item_id,content_type,title,title_orig,release_year,genres,countries,for_kids,age_rating,studios,directors,actors,description,keywords
0,10711,film,Поговори с ней,Hable con ella,2002.0,"драмы, зарубежные, детективы, мелодрамы",Испания,,16.0,,Педро Альмодовар,"Адольфо Фернандес, Ана Фернандес, Дарио Гранди...",Мелодрама легендарного Педро Альмодовара «Пого...,"Поговори, ней, 2002, Испания, друзья, любовь, ..."
1,2508,film,Голые перцы,Search Party,2014.0,"зарубежные, приключения, комедии",США,,16.0,,Скот Армстронг,"Адам Палли, Брайан Хаски, Дж.Б. Смув, Джейсон ...",Уморительная современная комедия на популярную...,"Голые, перцы, 2014, США, друзья, свадьбы, прео..."
2,10716,film,Тактическая сила,Tactical Force,2011.0,"криминал, зарубежные, триллеры, боевики, комедии",Канада,,16.0,,Адам П. Калтраро,"Адриан Холмс, Даррен Шалави, Джерри Вассерман,...",Профессиональный рестлер Стив Остин («Все или ...,"Тактическая, сила, 2011, Канада, бандиты, ганг..."
3,7868,film,45 лет,45 Years,2015.0,"драмы, зарубежные, мелодрамы",Великобритания,,16.0,,Эндрю Хэй,"Александра Риддлстон-Барретт, Джеральдин Джейм...","Шарлотта Рэмплинг, Том Кортни, Джеральдин Джей...","45, лет, 2015, Великобритания, брак, жизнь, лю..."
4,16268,film,Все решает мгновение,,1978.0,"драмы, спорт, советские, мелодрамы",СССР,,12.0,Ленфильм,Виктор Садовский,"Александр Абдулов, Александр Демьяненко, Алекс...",Расчетливая чаровница из советского кинохита «...,"Все, решает, мгновение, 1978, СССР, сильные, ж..."


In [11]:
items.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 15963 entries, 0 to 15962
Data columns (total 14 columns):
 #   Column        Non-Null Count  Dtype  
---  ------        --------------  -----  
 0   item_id       15963 non-null  int64  
 1   content_type  15963 non-null  object 
 2   title         15963 non-null  object 
 3   title_orig    11218 non-null  object 
 4   release_year  15865 non-null  float64
 5   genres        15963 non-null  object 
 6   countries     15926 non-null  object 
 7   for_kids      566 non-null    float64
 8   age_rating    15961 non-null  float64
 9   studios       1065 non-null   object 
 10  directors     14454 non-null  object 
 11  actors        13344 non-null  object 
 12  description   15961 non-null  object 
 13  keywords      15540 non-null  object 
dtypes: float64(3), int64(1), object(10)
memory usage: 1.7+ MB


# Interactions

In [12]:
interactions.head()

Unnamed: 0,user_id,item_id,last_watch_dt,total_dur,watched_pct
0,176549,9506,2021-05-11,4250,72.0
1,699317,1659,2021-05-29,8317,100.0
2,656683,7107,2021-05-09,10,0.0
3,864613,7638,2021-07-05,14483,100.0
4,964868,9506,2021-04-30,6725,100.0


In [13]:
interactions.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5476251 entries, 0 to 5476250
Data columns (total 5 columns):
 #   Column         Dtype  
---  ------         -----  
 0   user_id        int64  
 1   item_id        int64  
 2   last_watch_dt  object 
 3   total_dur      int64  
 4   watched_pct    float64
dtypes: float64(1), int64(3), object(1)
memory usage: 208.9+ MB


# Preprocess

In [14]:
Columns.Datetime = 'last_watch_dt'
interactions.drop(interactions[interactions[Columns.Datetime].str.len() != 10].index, inplace=True)
interactions[Columns.Datetime] = pd.to_datetime(interactions[Columns.Datetime], format='%Y-%m-%d')   #bringing the date to a single format

In [15]:
max_date = interactions[Columns.Datetime].max()
print("Max date of interactions: ", max_date)

kion dataset download: 100%|██████████| 78.8M/78.8M [00:20<00:00, 10.9MiB/s]

Max date of interactions:  2021-08-22 00:00:00


In [16]:
interactions[Columns.Weight] = np.where(interactions['watched_pct'] > 10, 3, 1)   #if the user viewed more than 10%, then the weight of interactions = 3, otherwise 1

In [17]:
train = interactions[interactions[Columns.Datetime] < max_date - pd.Timedelta(days=7)].copy()
test = interactions[interactions[Columns.Datetime] >= max_date - pd.Timedelta(days=7)].copy()   #data separation, test - data about the last week 

print(f"train: {train.shape}")
print(f"test: {test.shape}")

train: (4985269, 6)
test: (490982, 6)


In [18]:
cold_users = set(test[Columns.User]) - set(train[Columns.User])   #filter out cold users from the test
test.drop(test[test[Columns.User].isin(cold_users)].index, inplace=True)

# Prepare features

## User features

In [19]:
users.isnull().sum()

user_id         0
age         14095
income      14776
sex         13831
kids_flg        0
dtype: int64

In [20]:
users.fillna('Unknown', inplace=True)

In [21]:
users.nunique()

user_id     840197
age              7
income           7
sex              3
kids_flg         2
dtype: int64

In [22]:
users = users.loc[users[Columns.User].isin(train[Columns.User])].copy()

In [23]:
user_features_frames = []
for feature in ["sex", "age"]:    #as features, we take the sex  and age of the user
    feature_frame = users.reindex(columns=[Columns.User, feature])
    feature_frame.columns = ["id", "value"]
    feature_frame["feature"] = feature
    user_features_frames.append(feature_frame)
user_features = pd.concat(user_features_frames)
user_features.head()

Unnamed: 0,id,value,feature
0,973171,М,sex
1,962099,М,sex
3,721985,Ж,sex
4,704055,Ж,sex
5,1037719,М,sex


In [24]:
user_features.query(f"id == 973171")

Unnamed: 0,id,value,feature
0,973171,М,sex
0,973171,age_25_34,age


## Item features

In [25]:
items.isnull().sum()

item_id             0
content_type        0
title               0
title_orig       4745
release_year       98
genres              0
countries          37
for_kids        15397
age_rating          2
studios         14898
directors        1509
actors           2619
description         2
keywords          423
dtype: int64

In [26]:
items = items.loc[items[Columns.Item].isin(train[Columns.Item])].copy()

In [27]:
items.nunique()

item_id         15565
content_type        2
title           14937
title_orig      10377
release_year      105
genres           2720
countries         676
for_kids            2
age_rating          6
studios            38
directors        7809
actors          12671
description     15225
keywords        15123
dtype: int64

## Genre

In [28]:
items["genre"] = items["genres"].str.lower().str.replace(", ", ",", regex=False).str.split(",")       #explode genres to flatten table
genre_feature = items[["item_id", "genre"]].explode("genre")
genre_feature.columns = ["id", "value"]
genre_feature["feature"] = "genre"
genre_feature.head()

Unnamed: 0,id,value,feature
0,10711,драмы,genre
0,10711,зарубежные,genre
0,10711,детективы,genre
0,10711,мелодрамы,genre
1,2508,зарубежные,genre


## Content

In [29]:
content_feature = items.reindex(columns=[Columns.Item, "content_type"])
content_feature.columns = ["id", "value"]
content_feature["feature"] = "content_type"
item_features = pd.concat((genre_feature, content_feature))

In [30]:
item_features

Unnamed: 0,id,value,feature
0,10711,драмы,genre
0,10711,зарубежные,genre
0,10711,детективы,genre
0,10711,мелодрамы,genre
1,2508,зарубежные,genre
...,...,...,...
15958,6443,series,content_type
15959,2367,series,content_type
15960,10632,series,content_type
15961,4538,series,content_type


# Metrics

In [31]:
metrics_name = {
    'Precision': Precision,
    'Recall': Recall,
    'MAP': MAP,
}

metrics = {}
for metric_name, metric in metrics_name.items():
    for k in range(1, 11):
        metrics[f'{metric_name}@{k}'] = metric(k=k)

# Models

In [32]:
K_RECOS = 10
RANDOM_STATE = 42
NUM_THREADS = 8
N_FACTORS = (20, 30)
N_EPOCHS = (6, ) 
USER_ALPHA = [0, 0.1] 
ITEM_ALPHA = [0, 0.1] 
LEARNING_RATE = 0.05 

In [33]:
models = {
    'popular': PopularModel(),
}

In [34]:
implicit_models = {
    'ALS': AlternatingLeastSquares,
}

for implicit_name, implicit_model in implicit_models.items():
    for is_fitting_features in (True, False):
        for n_factors in N_FACTORS:
            models[f"{implicit_name}_n_factors:{n_factors}_is_fitting_features:{is_fitting_features}"] = (
                ImplicitALSWrapperModel(
                    model=implicit_model(
                        factors=n_factors, 
                        random_state=RANDOM_STATE, 
                        num_threads=NUM_THREADS,
                    ),
                    fit_features_together=is_fitting_features,
                )
            )



In [35]:
lightfm_losses = ('bpr', 'warp') 

for n_epoch in N_EPOCHS:
  for user_alpha in USER_ALPHA:
    for item_alpha in ITEM_ALPHA:
      for loss in lightfm_losses:
          for n_factors in N_FACTORS:
              models[f"LightFM_{loss}_n_factors:{n_factors}_user_alpha:{user_alpha}_item_alpha:{item_alpha}_n_epoch:{n_epoch}"] = LightFMWrapperModel(
                  LightFM(
                      no_components=n_factors, 
                      loss=loss, 
                      random_state=RANDOM_STATE,
                      learning_rate=LEARNING_RATE,
                      user_alpha=user_alpha,
                      item_alpha=item_alpha,
                  ),
                  epochs=n_epoch,
                  num_threads=NUM_THREADS,
              )

In [36]:
%%time
dataset = Dataset.construct(
    interactions_df=train,
    user_features_df=user_features,
    cat_user_features=["sex", "age"],
    item_features_df=item_features,
    cat_item_features=["genre", "content_type"],
)

CPU times: user 1.72 s, sys: 28.8 ms, total: 1.75 s
Wall time: 1.76 s


In [37]:
test_users = test[Columns.User].unique()

In [None]:
%%time
results = []
for model_name, model in models.items():
    print(f"Fitting model {model_name}...")
    model_quality = {'model': model_name}

    model.fit(dataset)
    recos = model.recommend(
        users=test_users,
        dataset=dataset,
        k=K_RECOS,
        filter_viewed=True,
    )
    metric_values = calc_metrics(metrics, recos, test, train)
    model_quality.update(metric_values)
    results.append(model_quality)

Fitting model popular...
Fitting model ALS_n_factors:20_is_fitting_features:True...
Fitting model ALS_n_factors:30_is_fitting_features:True...
Fitting model ALS_n_factors:20_is_fitting_features:False...
Fitting model ALS_n_factors:30_is_fitting_features:False...
Fitting model LightFM_bpr_n_factors:20_user_alpha:0_item_alpha:0_n_epoch:6...
Fitting model LightFM_bpr_n_factors:30_user_alpha:0_item_alpha:0_n_epoch:6...
Fitting model LightFM_warp_n_factors:20_user_alpha:0_item_alpha:0_n_epoch:6...
Fitting model LightFM_warp_n_factors:30_user_alpha:0_item_alpha:0_n_epoch:6...
Fitting model LightFM_bpr_n_factors:20_user_alpha:0_item_alpha:0.1_n_epoch:6...
Fitting model LightFM_bpr_n_factors:30_user_alpha:0_item_alpha:0.1_n_epoch:6...
Fitting model LightFM_warp_n_factors:20_user_alpha:0_item_alpha:0.1_n_epoch:6...
Fitting model LightFM_warp_n_factors:30_user_alpha:0_item_alpha:0.1_n_epoch:6...
Fitting model LightFM_bpr_n_factors:20_user_alpha:0.1_item_alpha:0_n_epoch:6...
Fitting model LightFM

In [None]:
df_quality = pd.DataFrame(results).T

df_quality.columns = df_quality.iloc[0]

df_quality.drop('model', inplace=True)

In [None]:
df_quality.style.highlight_max(color='lightgreen', axis=1)

model,popular,ALS_n_factors:20_is_fitting_features:True,ALS_n_factors:30_is_fitting_features:True,ALS_n_factors:20_is_fitting_features:False,ALS_n_factors:30_is_fitting_features:False,LightFM_bpr_n_factors:20_user_alpha:0_item_alpha:0_n_epoch:6,LightFM_bpr_n_factors:30_user_alpha:0_item_alpha:0_n_epoch:6,LightFM_warp_n_factors:20_user_alpha:0_item_alpha:0_n_epoch:6,LightFM_warp_n_factors:30_user_alpha:0_item_alpha:0_n_epoch:6,LightFM_bpr_n_factors:20_user_alpha:0_item_alpha:0.1_n_epoch:6,LightFM_bpr_n_factors:30_user_alpha:0_item_alpha:0.1_n_epoch:6,LightFM_warp_n_factors:20_user_alpha:0_item_alpha:0.1_n_epoch:6,LightFM_warp_n_factors:30_user_alpha:0_item_alpha:0.1_n_epoch:6,LightFM_bpr_n_factors:20_user_alpha:0.1_item_alpha:0_n_epoch:6,LightFM_bpr_n_factors:30_user_alpha:0.1_item_alpha:0_n_epoch:6,LightFM_warp_n_factors:20_user_alpha:0.1_item_alpha:0_n_epoch:6,LightFM_warp_n_factors:30_user_alpha:0.1_item_alpha:0_n_epoch:6,LightFM_bpr_n_factors:20_user_alpha:0.1_item_alpha:0.1_n_epoch:6,LightFM_bpr_n_factors:30_user_alpha:0.1_item_alpha:0.1_n_epoch:6,LightFM_warp_n_factors:20_user_alpha:0.1_item_alpha:0.1_n_epoch:6,LightFM_warp_n_factors:30_user_alpha:0.1_item_alpha:0.1_n_epoch:6
Precision@1,0.073308,0.085862,0.085862,0.062164,0.062164,0.025299,0.025722,0.080054,0.080469,0.0,0.0,0.0,5e-05,0.00492,0.002191,8e-06,2.5e-05,8e-06,0.0,0.062463,0.078552
Recall@1,0.038149,0.044848,0.044848,0.031754,0.031754,0.01361,0.013408,0.040546,0.040927,0.0,0.0,0.0,7e-06,0.002908,0.00116,0.0,1.2e-05,0.0,0.0,0.031665,0.04169
Precision@2,0.069263,0.073258,0.073258,0.055311,0.055311,0.019163,0.020076,0.067491,0.067939,4e-06,4e-06,4e-06,8.3e-05,0.004277,0.003456,4e-06,1.2e-05,4e-06,0.0,0.058634,0.069752
Recall@2,0.071011,0.074754,0.074754,0.055191,0.055191,0.020225,0.020416,0.066716,0.067415,0.0,2e-06,3e-06,5.2e-05,0.004653,0.003948,0.0,1.2e-05,0.0,0.0,0.058741,0.071784
Precision@3,0.066225,0.062087,0.062087,0.051309,0.051309,0.015843,0.016902,0.059462,0.059982,6e-06,8e-06,3e-06,8.9e-05,0.004423,0.003112,3e-06,1.4e-05,3e-06,0.0,0.05707,0.065929
Recall@3,0.1004,0.093394,0.093394,0.075896,0.075896,0.024397,0.025435,0.087277,0.088542,1e-06,5e-06,3e-06,0.000103,0.007097,0.005061,0.0,1.4e-05,0.0,0.0,0.086599,0.100448
Precision@4,0.059383,0.05461,0.05461,0.047447,0.047447,0.013871,0.014782,0.053272,0.053859,4e-06,1.9e-05,1.2e-05,8.5e-05,0.004207,0.002941,4e-06,1.7e-05,2e-06,2e-06,0.051148,0.057283
Recall@4,0.118878,0.108308,0.108308,0.092287,0.092287,0.027895,0.028815,0.103247,0.104795,1e-06,2.6e-05,1.4e-05,0.000142,0.008911,0.006309,1e-06,1.6e-05,0.0,4e-06,0.103107,0.114805
Precision@5,0.052735,0.048517,0.048517,0.043812,0.043812,0.012362,0.013324,0.048442,0.048968,3e-06,3.5e-05,1.7e-05,8.6e-05,0.00389,0.002561,3e-06,2e-05,2e-06,2e-06,0.041645,0.052307
Recall@5,0.130473,0.118726,0.118726,0.105005,0.105005,0.03069,0.031978,0.116114,0.117826,1e-06,5.9e-05,2.2e-05,0.000175,0.010052,0.006703,1e-06,3e-05,0.0,4e-06,0.104484,0.129606


## Approximate Nearest Neighbors

In [40]:
user_embeddings, item_embeddings = model.get_vectors(dataset)

user_embeddings.shape, item_embeddings.shape

((896791, 32), (15565, 32))

In [41]:
def augment_inner_product(factors):
    normed_factors = np.linalg.norm(factors, axis=1)
    max_norm = normed_factors.max()
    
    extra_dim = np.sqrt(max_norm ** 2 - normed_factors ** 2).reshape(-1, 1)
    augmented_factors = np.append(factors, extra_dim, axis=1)
    return max_norm, augmented_factors

In [42]:
print('Pre shape items: ', item_embeddings.shape)

max_norm, augmented_item_embeddings = augment_inner_product(item_embeddings)

print('Shape items after augmented: ', augmented_item_embeddings.shape)

Pre shape items:  (15565, 32)
Shape items after augmented:  (15565, 33)


In [43]:
extra_zero = np.zeros((user_embeddings.shape[0], 1))
augmented_user_embeddings = np.append(user_embeddings, extra_zero, axis=1)

print('Shape users after augmented: ', augmented_user_embeddings.shape)

Shape users after augmented:  (896791, 33)


### Examples of user embeddings and item embeddings

In [44]:
user_id = 30

print('User embeddings for ', user_id)
user_embeddings[user_id]

User embeddings for  30


array([-7.02635677e-36,  1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00])

In [45]:
print('User augmented embeddings for ', user_id)
augmented_user_embeddings[user_id]

User augmented embeddings for  30


array([-7.02635677e-36,  1.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00])

In [46]:
item_id = 0

print('Item embeddings for ', item_id)
item_embeddings[item_id]

Item embeddings for  0


array([ 1.00000000e+00, -7.31829641e-05,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00])

In [47]:
print('Item augmented embeddings for ', item_id)
augmented_item_embeddings[item_id]

Item augmented embeddings for  0


array([ 1.00000000e+00, -7.31829641e-05,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        0.00000000e+00,  0.00000000e+00,  0.00000000e+00,  0.00000000e+00,
        7.95802230e-05])

In [48]:
#set index parameters
M = 48
efC = 100

num_threads = 4
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC, 'post' : 0}

print('Index-time parameters', index_time_params)

Index-time parameters {'M': 48, 'indexThreadQty': 4, 'efConstruction': 100, 'post': 0}


In [49]:
K = 10   #number of neighbors 

space_name = 'negdotprod'   #used for brute-force search

In [50]:
#intitialize the library, specify the space, the type of the vector and add data points 
index = nmslib.init(method='hnsw', space=space_name, data_type=nmslib.DataType.DENSE_VECTOR) 
index.addDataPointBatch(augmented_item_embeddings) 

15565

In [51]:
#create an index
start = time.time()
index_time_params = {'M': M, 'indexThreadQty': num_threads, 'efConstruction': efC}
index.createIndex(index_time_params) 
end = time.time() 

print('Index-time parameters', index_time_params)

print('Indexing time = %f' % (end-start))

Index-time parameters {'M': 48, 'indexThreadQty': 4, 'efConstruction': 100}
Indexing time = 23.432128


In [52]:
#setting query-time parameters
efS = 100
query_time_params = {'efSearch': efS}

print('Setting query-time parameters', query_time_params) 

index.setQueryTimeParams(query_time_params)

Setting query-time parameters {'efSearch': 100}


In [53]:
query_matrix = augmented_user_embeddings[:1000, :]

In [54]:
# Querying
query_qty = query_matrix.shape[0]
start = time.time() 
nbrs = index.knnQueryBatch(query_matrix, k = K, num_threads = num_threads)
end = time.time() 

print('kNN time total=%f (sec), per query=%f (sec), per query adjusted for thread number=%f (sec)' % 
      (end-start, float(end-start)/query_qty, num_threads*float(end-start)/query_qty)) 

kNN time total=0.099473 (sec), per query=0.000099 (sec), per query adjusted for thread number=0.000398 (sec)


In [55]:
nbrs[0]

(array([ 3507,   133, 13880,  1732,  1710,  3554,    93,  1801, 13842,
         2987], dtype=int32),
 array([0.00016064, 0.0001668 , 0.00017086, 0.00017226, 0.00017367,
        0.00017454, 0.00017495, 0.00017515, 0.00017515, 0.00017675],
       dtype=float32))