In [1]:
import pandas as pd
import numpy as np

In [2]:
from pyspark.sql import functions as sf

from replay.session_handler import State, get_spark_session
from replay.experiment import Experiment
from replay.metrics import Recall, NDCG, MAP
from replay.models import KNN, PopRec
from replay.utils import get_log_info

In [3]:
session = get_spark_session(spark_memory=300)
spark = State(session).session
spark

22/03/27 11:09:27 WARN Utils: Your hostname, trans4rec resolves to a loopback address: 127.0.1.1; using 192.168.1.13 instead (on interface ens160)
22/03/27 11:09:27 WARN Utils: Set SPARK_LOCAL_IP if you need to bind to another address
22/03/27 11:09:27 WARN NativeCodeLoader: Unable to load native-hadoop library for your platform... using builtin-java classes where applicable
Using Spark's default log4j profile: org/apache/spark/log4j-defaults.properties
Setting default log level to "WARN".
To adjust logging level use sc.setLogLevel(newLevel). For SparkR, use setLogLevel(newLevel).
22/03/27 11:09:28 WARN SparkConf: Note that spark.local.dir will be overridden by the value set by the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone/kubernetes and LOCAL_DIRS in YARN).
22/03/27 11:09:28 WARN Utils: Service 'SparkUI' could not bind on port 4040. Attempting port 4041.
22/03/27 11:09:28 WARN Utils: Service 'SparkUI' could not bind on port 4041. Attempting port 4042.


### Предобработка данных

In [4]:
train = pd.read_csv('./data/hist_data.csv').sort_values(['buyer_id', 'created'])

In [5]:
def get_last_session(x):
    ind = x['pav_order_id'].drop_duplicates().index
    return x.loc[ind, 'pav_order_id'].values[-1]

res = train.groupby('buyer_id').apply(get_last_session)

In [6]:
X_train = train.loc[~train.pav_order_id.isin(res.values)].copy()
val = train.loc[train.pav_order_id.isin(res.values)].copy()
val.loc[:, 'weight'] = None

In [7]:
X_train.reset_index(drop=True, inplace=True)
val.reset_index(drop=True, inplace=True)

In [8]:
np.random.seed(0)

val_shuffled = val.sample(frac=1, random_state=0)
def get_eval_ind(x): 
    num_to_get = int(len(x) * 0.3)
    ind = np.random.choice(x.index, size=num_to_get, replace=False)
    return pd.Series(ind, name='index')

target_index = val_shuffled.groupby('pav_order_id').apply(get_eval_ind).values

In [9]:
X_val = val.loc[target_index].sort_values(['buyer_id', 'created']).reset_index(drop=True)
y_val = val.loc[~val.index.isin(target_index)].sort_values(['buyer_id', 'created']).reset_index(drop=True)

In [10]:
X_train.to_csv('./data/X_train.csv', index=False)
X_val.to_csv('./data/X_val.csv', index=False)
y_val.to_csv('./data/y_val.csv', index=False)

In [13]:
data_path = "./data/"

In [15]:
X_train = pd.read_csv(data_path+"X_train.csv")
X_val = pd.read_csv(data_path+"X_val.csv") 
y_val = pd.read_csv(data_path+"y_val.csv")

In [17]:
train_log = X_train[["pav_order_id", "created", "item_id"]]
train_log.columns = ["user_idx", "timestamp", "item_idx"]
train_log['relevance'] = 1.0

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  train_log['relevance'] = 1.0


In [18]:
train_log.to_parquet("./data/train_log.parquet")

In [19]:
test_log_pred = X_val[["pav_order_id", "created", "item_id"]]
test_log_pred.columns = ["user_idx", "timestamp", "item_idx"]
test_log_pred['relevance'] = 1.0

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  test_log_pred['relevance'] = 1.0


In [20]:
test_log_pred.to_parquet("./data/test_log_pred.parquet")

In [21]:
test_log_check = y_val[["pav_order_id", "created", "item_id"]]
test_log_check.columns = ["user_idx", "timestamp", "item_idx"]

In [22]:
test_log_check.to_parquet("./data/test_log_check.parquet")

In [26]:
test_csv = pd.read_csv('./data/test.csv')

In [27]:
test_csv = test_csv[["pav_order_id", "created", "item_id"]]
test_csv.columns = ["user_idx", "timestamp", "item_idx"]
test_csv.to_parquet("./data/test.parquet")

In [28]:
from gc import collect

del train_log, test_log_pred, test_log_check, test_csv
collect()

96

### Загрузка подготовленных данных

In [32]:
%%time

train_log = spark.read.parquet("./data/train_log.parquet")
test_log_pred = spark.read.parquet("./data/test_log_pred.parquet")
test_log_check = spark.read.parquet("./data/test_log_check.parquet")

CPU times: user 0 ns, sys: 2.15 ms, total: 2.15 ms
Wall time: 204 ms


In [33]:
test = spark.read.parquet("./data/test.parquet")

In [34]:
united_log = train_log.union(test_log_pred.withColumn('relevance', sf.lit(1)))

###  Настройка метрик

In [37]:
metrics_dict = {
        Recall(): 20,
        NDCG(): 20,
        MAP(): 20
    }
metrics_replay = Experiment(test_log_check, metrics_dict)

### Обучение модели KNN из RePlay

In [50]:
%%time

knn = KNN(num_neighbours=5)
knn.fit(united_log)

22/03/27 11:33:21 WARN CacheManager: Asked to cache already cached data.
22/03/27 11:33:21 WARN CacheManager: Asked to cache already cached data.


CPU times: user 17.8 ms, sys: 0 ns, total: 17.8 ms
Wall time: 418 ms


22/03/27 11:33:22 WARN CacheManager: Asked to cache already cached data.


In [53]:
knn_preds = knn.predict(test_log_pred, 20, filter_seen_items=True)
knn_preds.count()

                                                                                

1155684

In [54]:
metrics_replay.add_result('kNN_5', knn_preds)

                                                                                

In [55]:
metrics_replay.results

Unnamed: 0,MAP@20,NDCG@20,Recall@20
kNN_5,0.050006,0.121028,0.095974


### Предсказание

In [46]:
%%time

knn = KNN(num_neighbours=5)
knn.fit(united_log)
test_preds = knn.predict(test, 20, filter_seen_items=True)
test_preds.count()

22/03/27 11:31:32 WARN CacheManager: Asked to cache already cached data.
22/03/27 11:31:32 WARN CacheManager: Asked to cache already cached data.
22/03/27 11:31:33 WARN CacheManager: Asked to cache already cached data.

CPU times: user 63.6 ms, sys: 10.1 ms, total: 73.7 ms
Wall time: 11.5 s


                                                                                

1603306

In [47]:
test_preds_df = test_preds.toPandas()

                                                                                

In [49]:
(test_preds_df[["user_idx", "item_idx"]]
    .groupby("user_idx")
    .agg({'item_idx': lambda x: x.tolist()})
    .reset_index().to_csv("res_top.csv", index=None, header=["pav_order_id","preds"]))