In [1]:
import pandas as pd
import numpy as np
from datetime import datetime, timedelta
import os
from matplotlib import pyplot as plt
import sys
import torch
from lightgbm import LGBMRanker
import time

# Make the project root importable so the `scripts.*` imports below resolve.
# The root is resolved relative to the working directory (one level up, the
# same convention as the '../data/...' paths used elsewhere in this notebook)
# rather than a user-specific absolute path, so the notebook runs on any checkout.
PROJECT_ROOT = os.path.abspath('..')
# Guard against piling up duplicate entries when the cell is re-run.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

from scripts.first_stage_models.LastPurchasesPopularity import LastPurchasesPopularity
from scripts.utils import create_predictions_for_second_stage,\
prepare_dataset, create_labels_for_second_stage, combine_train_sets_and_labels

from scripts.images_scripts.similarity_search import SimilaritySearch
from scripts.images_scripts.index import FlatFaissIndex
from scripts.first_stage_models.LastPurchasesImageSimilarity import LastPurchasesImageSimilarity


# Allow wide frames to render up to 500 columns before pandas elides them.
pd.options.display.max_columns = 500

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Device handed to the FAISS index wrapper; everything in this notebook runs on CPU.
DEVICE = torch.device('cpu')

# All artefacts live under one data directory, addressed relative to the
# notebook's working directory so the notebook is not tied to one machine.
DATA_DIR = '../data/compressed_dataset'
INDEX_DIR = f'{DATA_DIR}/index'

# Files consumed by the image-similarity search components.
PATH_TO_ARTICLES_INDEX = f'{INDEX_DIR}/articles_index.parquet'
PATH_TO_INDEX = f'{INDEX_DIR}/faiss.index'
PATH_TO_EMBEDDINGS = f'{INDEX_DIR}/embeddings.pickle'
PATH_TO_ARTICLE_ID_INT = f'{DATA_DIR}/article_id_int.pickle'

In [3]:
# Load the three parquet tables (transactions, articles, customers) in that order.
df_transactions, df_articles, df_customers = (
    pd.read_parquet(parquet_path)
    for parquet_path in (
        '../data/compressed_dataset/transactions.parquet',
        '../data/compressed_dataset/articles.parquet',
        '../data/compressed_dataset/customers.parquet',
    )
)

# Load the pickled id lookups for articles and customers.
# NOTE(review): by their names these look like forward/inverse mappings between
# raw ids and integer ids — confirm against the script that wrote them.
article_id_int, int_article_id, customer_id_int, int_customer_id = (
    pd.read_pickle(pickle_path)
    for pickle_path in (
        '../data/compressed_dataset/article_id_int.pickle',
        '../data/compressed_dataset/int_article_id.pickle',
        '../data/compressed_dataset/customer_id_int.pickle',
        '../data/compressed_dataset/int_customer_id.pickle',
    )
)

In [4]:
# Create the flat FAISS index wrapper on the configured device, then load the
# saved index file into it.
# NOTE(review): dimension=384 presumably has to match the width of the stored
# image embeddings — confirm against the embedding model.
index = FlatFaissIndex(device=DEVICE, dimension=384)
index.load_ranking_model(PATH_TO_INDEX)

In [5]:
# Combine the loaded index with the article-index table, the embeddings file
# and the article id mapping into a single similarity-search object.
Searcher = SimilaritySearch(
    path_article_id_int=PATH_TO_ARTICLE_ID_INT,
    pickle_file_with_embeddings=PATH_TO_EMBEDDINGS,
    parquet_file_with_articles_index=PATH_TO_ARTICLES_INDEX,
    index=index,
)

In [6]:
%%time

Searcher.search_similar(
    target_int_article_id=1000,
    n_images=4
)

CPU times: user 34.7 ms, sys: 423 µs, total: 35.1 ms
Wall time: 33.3 ms


(array([0.9999997 , 0.96716285, 0.9669125 , 0.9665159 ], dtype=float32),
 array([ 1000, 26087, 43345, 38025]))

In [7]:
# First-stage candidate model built on the image-similarity searcher; it also
# receives the int->article and int->customer lookups loaded earlier.
model = LastPurchasesImageSimilarity(
    int_customer_id=int_customer_id,
    int_article_id=int_article_id,
    searcher=Searcher,
)
# Fit on the full transaction table.
model.fit(df_transactions)

In [8]:
# Predict for the first 100 customers only (quick preview run).
# Slice the Series *before* converting to a list so that only 100 values are
# materialised as Python objects instead of the whole customer column.
# return_submit=True makes predict return a second, submission-style frame.
df_predict, df_submit = model.predict(
    df_customers['customer_id'].iloc[:100].tolist(),
    return_submit=True,
)

In [9]:
# Show the submission-style frame returned by model.predict
# (columns: customer_id, prediction).
df_submit

Unnamed: 0,customer_id,prediction
0,0003e867a930d0d6842f923d6ba7c9b77aba33fe2a0fbf...,0837741002 0889722001 0823242001 0875736004 09...
1,00039306476aaf41a07fed942884f16b30abfa83a2a8be...,0855793001 0792301002 0881916001 0771602002 06...
2,00040239317e877c77ac6e79df42eb2633ad38fcac09fc...,0604295001 0763706005 0881570001 0875272002 08...
3,000058a12d5b43e67d225668fa1f8d618c13dc232df0ca...,0794321011 0730687001 0794321001 0763785001 09...
4,0001d44dbe7f6c4b35200abdb052c77a87596fe1bdcc37...,0865912001 0760735001 0721632001 0928040002 08...
...,...,...
95,00008469a21b50b3d147c97135e25b4201a8c58997f787...,0661162002 0680441019 0867948004 0661162004 05...
96,000097d91384a0c14893c09ed047a963c4fc6a5c021044...,0741762005 0884901001 0741762001 0812854003 07...
97,000383021a8cf9a542b9c855b42f99aa76374873661e83...,0643140001 0815650001 0710760001 0516000073 05...
98,0004518be81f6f0dd216dc5699016bc159ebb9dbd62a76...,0817086002 0889456001 0501722004 0616733003 06...


In [13]:
# Show the per-candidate frame returned by model.predict
# (columns: customer_id, article_id, score).
df_predict

Unnamed: 0,customer_id,article_id,score
0,86,89351,0.939401
1,86,100352,0.931560
2,86,86274,0.924446
3,80,92916,0.988128
4,80,78075,0.985729
...,...,...,...
994,54,28415,0.937954
995,54,42168,0.893162
996,83,48959,0.935450
997,83,48961,0.934343
