use loc[i, field] instead of iloc[i, field], iloc is very slow.

In [1]:
import sys
sys.path = ['/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/RecStudio/'] + sys.path
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
import torch
import pickle


  from .autonotebook import tqdm as notebook_tqdm


In [4]:
def cast_dtype(df : pd.DataFrame):
    for k in df.columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [5]:
def _load_cache(path):
    with open(path, 'rb') as f:
        download_obj = pickle.load(f)
    return download_obj

In [6]:
def get_scores(merged_candidates_df, query_embeddings, product_embeddings):
    batch_size = 2048
    num_iter = (len(merged_candidates_df) - 1) // batch_size + 1
    score_list = []
    with torch.no_grad():
        for i in tqdm(range(num_iter)):
            st, ed = i * batch_size, (i + 1) * batch_size 
            batch_sess = merged_candidates_df.iloc[st : ed]
            batch_sess_id = torch.tensor(batch_sess['sess_id'].tolist(), dtype=torch.long, device=query_embeddings.device)
            batch_product_id = torch.tensor(batch_sess['dataset_id'].tolist(), dtype=torch.long, device=product_embeddings.device)
            query_emb = query_embeddings[batch_sess_id]
            product_emb = product_embeddings[batch_product_id]
            batch_score = (query_emb * product_emb).sum(dim=-1) 
            score_list.append(batch_score.cpu())
        score_list = torch.cat(score_list, dim=0).cpu().tolist()
        return score_list 

In [7]:
def normalize_scores(score_df, score_name, normalized_score_name):
    # score_df_g = cudf.from_pandas(score_df)
    score_df['exp_score'] = np.exp(score_df[score_name].to_numpy())
    scores_sum = score_df[['sess_id', 'exp_score']].groupby('sess_id').sum()
    scores_sum.reset_index(inplace=True)
    scores_sum = scores_sum.sort_values(by=['sess_id'], ascending=True)
    scores_sum.reset_index(drop=True, inplace=True)
    scores_sum.rename(columns={'exp_score' : 'score_sum'}, inplace=True)

    merged_score_df = score_df.merge(scores_sum, how='left', left_on=['sess_id'], right_on=['sess_id'])
    merged_score_df = merged_score_df.sort_values(by=['sess_id', 'product'])
    merged_score_df.reset_index(drop=True, inplace=True)
    
    # merged_score_df = merged_score_df_g.to_pandas(merged_score_df_g)
    score_df[normalized_score_name] = merged_score_df['exp_score'] / merged_score_df['score_sum']
    score_df['exp_score'] = merged_score_df['exp_score']
    score_df['score_sum'] = merged_score_df['score_sum']

    # del scores_sum_g
    # del merged_score_df_g  

# Merge valid score

In [None]:
FIELD_NAME = 'gru4rec_scores'

In [2]:
merged_candidates_feature_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates/merged_candidates_2_feature.parquet'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions.csv'
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1.csv'

In [None]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature():
    return pd.read_parquet(merged_candidates_feature_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)

In [8]:
DE_product_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_DE/product_embeddings_2023-04-29-16-27-43.pt'
DE_valid_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_DE/valid_embeddings_2023-04-29-16-27-55.pt'
JP_product_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_JP/product_embeddings_2023-04-29-16-29-08.pt'
JP_valid_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_JP/valid_embeddings_2023-04-29-16-29-19.pt'
UK_product_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_UK/product_embeddings_2023-04-29-16-30-21.pt'
UK_valid_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_UK/valid_embeddings_2023-04-29-16-30-32.pt'

In [9]:
DE_dataset_cache = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/.recstudio/cache/c76eddf0a07106ffcce7ce8010856a3b'
JP_dataset_cache = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/.recstudio/cache/81a71d0a18766af84b3beab69bf53e69'
UK_dataset_cache = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/.recstudio/cache/250dbc09c30162452e00486051e47756'

In [10]:
DE_train_dataset, DE_valid_dataset = _load_cache(DE_dataset_cache)
JP_train_dataset, JP_valid_dataset = _load_cache(JP_dataset_cache)
UK_train_dataset, UK_valid_dataset = _load_cache(UK_dataset_cache)
locale_map = {
    'DE' : DE_train_dataset.field2token2idx['product_id'], 
    'JP' : JP_train_dataset.field2token2idx['product_id'], 
    'UK' : UK_train_dataset.field2token2idx['product_id']
    }

In [11]:
merged_candidates = read_merged_candidates_feature()
valid_sessions = read_valid_sessions()
EMBED_DIM = 128
merged_candidates.sort_values(by=['sess_id', 'product'], inplace=True)
merged_candidates.reset_index(drop=True, inplace=True)

In [12]:
# sess embeddings 
valid_DE_query_emb = torch.load(DE_valid_embeddings_path, map_location='cpu')
valid_JP_query_emb = torch.load(JP_valid_embeddings_path, map_location='cpu')
valid_UK_query_emb = torch.load(UK_valid_embeddings_path, map_location='cpu')
valid_query_embeddings = torch.empty(len(valid_sessions), EMBED_DIM)
valid_query_embeddings[(valid_sessions[valid_sessions['locale'] == 'DE'].index).tolist()] = valid_DE_query_emb
valid_query_embeddings[(valid_sessions[valid_sessions['locale'] == 'JP'].index).tolist()] = valid_JP_query_emb
valid_query_embeddings[(valid_sessions[valid_sessions['locale'] == 'UK'].index).tolist()] = valid_UK_query_emb

In [13]:
# product_embeddings 
DE_product_emb = torch.load(DE_product_embeddings_path, map_location='cpu')
JP_product_emb = torch.load(JP_product_embeddings_path, map_location='cpu')
UK_product_emb = torch.load(UK_product_embeddings_path, map_location='cpu')
product_embeddings = torch.cat([DE_product_emb, JP_product_emb, UK_product_emb], dim=0)

In [14]:
merged_candidates_gru4rec = merged_candidates[['sess_id', 'sess_locale', 'product']].copy()

In [15]:
DE_product_list, DE_id_list = list(zip(*locale_map['DE'].items()))
JP_product_list, JP_id_list = list(zip(*locale_map['JP'].items()))
UK_product_list, UK_id_list = list(zip(*locale_map['UK'].items()))
product_list = list(DE_product_list) + list(JP_product_list) + list(UK_product_list)
id_list = list(DE_id_list) + list(JP_id_list) + list(UK_id_list)
locale_list = ['DE'] * len(DE_id_list) + ['JP'] * len(JP_id_list) + ['UK'] * len(UK_id_list)
product_id_df = pd.DataFrame({'locale' : locale_list, 'product' : product_list, 'dataset_id' : id_list})

In [16]:
# merged_candidates_gru4rec_g = cudf.from_pandas(merged_candidates_gru4rec)
# product_id_df_g = cudf.from_pandas(product_id_df)

In [17]:
# merged_candidates_gru4rec_score_g = merged_candidates_gru4rec_g.merge(product_id_df_g, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'product'])
# merged_candidates_gru4rec_score_g['dataset_id'] = merged_candidates_gru4rec_score_g['dataset_id'].fillna(0)
# merged_candidates_gru4rec_score_g.drop(columns=['locale'], inplace=True)
# merged_candidates_gru4rec_score_g = merged_candidates_gru4rec_score_g.sort_values(by=['sess_id', 'product'])
# merged_candidates_gru4rec_score_g.reset_index(drop=True, inplace=True)
# merged_candidates_gru4rec_score = merged_candidates_gru4rec_score_g.to_pandas()

In [None]:
merged_candidates_gru4rec_score = merged_candidates_gru4rec.merge(product_id_df, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'product'])
merged_candidates_gru4rec_score['dataset_id'] = merged_candidates_gru4rec_score['dataset_id'].fillna(0)
merged_candidates_gru4rec_score.drop(columns=['locale'], inplace=True)
merged_candidates_gru4rec_score = merged_candidates_gru4rec_score.sort_values(by=['sess_id', 'product'])
merged_candidates_gru4rec_score.reset_index(drop=True, inplace=True)

In [18]:
# del merged_candidates_gru4rec_g
# del product_id_df_g
# del merged_candidates_gru4rec_score_g

In [19]:
locale_offset = {'DE' : 0, 'JP' : len(DE_product_list), 'UK' : len(DE_product_list) + len(JP_product_list)}
for locale in ['DE', 'JP', 'UK']:
    merged_candidates_gru4rec_score['dataset_id'][merged_candidates_gru4rec_score['sess_locale'] == locale] = \
        merged_candidates_gru4rec_score['dataset_id'][merged_candidates_gru4rec_score['sess_locale'] == locale] + locale_offset[locale]

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  merged_candidates_gru4rec_score['dataset_id'][merged_candidates_gru4rec_score['sess_locale'] == locale] = \


In [20]:
valid_query_embeddings = valid_query_embeddings.to('cuda:0')
product_embeddings = product_embeddings.to('cuda:0')

In [21]:
merged_candidates_gru4rec_score[FIELD_NAME] = get_scores(merged_candidates_gru4rec_score, valid_query_embeddings, product_embeddings)

100%|██████████| 41215/41215 [00:28<00:00, 1469.66it/s]


In [22]:
normalize_scores(merged_candidates_gru4rec_score, FIELD_NAME, 'normalized_'+FIELD_NAME)

In [23]:
merged_candidates[FIELD_NAME] = merged_candidates_gru4rec_score[FIELD_NAME]
merged_candidates['normalized_'+FIELD_NAME] = merged_candidates_gru4rec_score['normalized_'+FIELD_NAME]

In [None]:
cast_dtype(merged_candidates)
merged_candidates.to_parquet(merged_candidates_feature_path, engine='pyarrow')

In [24]:
merged_candidates

Unnamed: 0,sess_id,sess_locale,product,target,sasrec_scores_2,sasrec_normalized_scores_2,product_freq,gru4rec_scores,gru4rec_normalized_scores,sess_avg_price,product_price,gru4rec_scores_2,gru4rec_normalized_scores_2
0,0,UK,B000OPPVCS,0.0,11.972421,2.286162e-04,104,6.484859,0.000038,7.388571,7.280000,12.291418,5.528012e-05
1,0,UK,B000V599Y2,0.0,13.152878,7.443427e-04,37,4.342063,0.000004,7.388571,5.200000,12.142086,4.761183e-05
2,0,UK,B0018HH444,0.0,5.606023,3.928400e-07,7,3.220763,0.000001,7.388571,15.800000,8.919555,1.897524e-06
3,0,UK,B0079JI4DU,0.0,0.000000,1.443945e-09,67,0.000000,0.0,7.388571,22.097065,0.000000,2.537897e-10
4,0,UK,B0079JI4EY,0.0,0.000000,1.443945e-09,77,0.000000,0.0,7.388571,22.097065,0.000000,2.537897e-10
...,...,...,...,...,...,...,...,...,...,...,...,...,...
84407334,361580,DE,B0BB7XV97M,0.0,9.117821,6.077226e-05,56,9.268379,0.000014,32.424000,47.990002,14.038595,8.992638e-05
84407335,361580,DE,B0BB7YSRBX,0.0,9.163816,6.363281e-05,58,7.047796,0.000002,32.424000,43.990002,13.342258,4.482001e-05
84407336,361580,DE,B0BB7ZMGY8,0.0,11.256460,5.158278e-04,452,9.359167,0.000015,32.424000,41.990002,12.778135,2.549625e-05
84407337,361580,DE,B0BD4CP7N3,0.0,-3.778687,1.523433e-10,1,-0.593306,0.0,32.424000,24.990000,-3.986487,1.335653e-12


In [26]:
# verify gru4rec scores
merged_candidates[merged_candidates['sess_id'] == 150001].sort_values(by=['sasrec_scores_2'], ascending=False)[
    ['sess_locale', 'product', 'normalized_sasrec_scores_2', 'sasrec_scores_2', 'normalized_gru4rec_scores', 'gru4rec_scores']
].iloc[:15]

Unnamed: 0,sess_locale,product,sasrec_normalized_scores_2,sasrec_scores_2,gru4rec_normalized_scores,gru4rec_scores,gru4rec_normalized_scores_2,gru4rec_scores_2
35011843,DE,B07TP1HY5B,0.341675,16.614416,0.199157,16.564659,0.013633,16.40517
35011902,DE,B08VN4VTMC,0.258014,16.33357,0.384811,17.223316,0.477046,19.960262
35011977,DE,B0BBRKH55W,0.146798,15.769612,0.025441,14.506912,0.129742,18.658195
35011978,DE,B0BBRPFP73,0.102788,15.413227,0.084324,15.705232,0.054098,17.783445
35011979,DE,B0BBRR84KK,0.054031,14.770121,0.117207,16.034508,0.137245,18.714417
35011888,DE,B08HHXDHRJ,0.04879,14.66808,0.063105,15.415358,0.020194,16.798056
35011786,DE,B011KJ6WLU,0.013794,13.404789,0.010225,13.595442,0.05854,17.862368
35011807,DE,B0793LCHRD,0.007168,12.750231,0.032911,14.764377,0.019142,16.744513
35011968,DE,B0B7NCTHWX,0.006028,12.577011,0.014645,13.954664,0.036256,17.383266
35011790,DE,B01MSZ7WK7,0.003257,11.961242,0.01858,14.192669,0.011652,16.248108


In [31]:
# verify gru4rec scores
merged_candidates[merged_candidates['sess_id'] == 150001].sort_values(by=['sasrec_scores_2'], ascending=False)[['sess_locale', 'product', 'normalized_sasrec_scores_2', 'sasrec_scores_2', 'normalized_gru4rec_scores', 'gru4rec_scores']].iloc[:15]

Unnamed: 0,sess_locale,product,sasrec_normalized_scores_2,sasrec_scores_2,gru4rec_normalized_scores,gru4rec_scores
35011843,DE,B07TP1HY5B,0.341675,16.614416,0.199157,16.564659
35011902,DE,B08VN4VTMC,0.258014,16.33357,0.384811,17.223316
35011977,DE,B0BBRKH55W,0.146798,15.769612,0.025441,14.506912
35011978,DE,B0BBRPFP73,0.102788,15.413227,0.084324,15.705232
35011979,DE,B0BBRR84KK,0.054031,14.770121,0.117207,16.034508
35011888,DE,B08HHXDHRJ,0.04879,14.66808,0.063105,15.415358
35011786,DE,B011KJ6WLU,0.013794,13.404789,0.010225,13.595442
35011807,DE,B0793LCHRD,0.007168,12.750231,0.032911,14.764377
35011968,DE,B0B7NCTHWX,0.006028,12.577011,0.014645,13.954664
35011790,DE,B01MSZ7WK7,0.003257,11.961242,0.01858,14.192669


# Merge test score

In [None]:
FIELD_NAME = 'gru4rec_scores'

In [36]:
merged_candidates_feature_test_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates/merged_candidates_test_2_feature.parquet'
test_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/sessions_test_task1.csv'

In [37]:
DE_product_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_DE/product_embeddings_2023-04-29-16-27-43.pt'
DE_test_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_DE/predict_embeddings_2023-04-29-16-28-26.pt'
JP_product_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_JP/product_embeddings_2023-04-29-16-29-08.pt'
JP_test_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_JP/predict_embeddings_2023-04-29-16-29-45.pt'
UK_product_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_UK/product_embeddings_2023-04-29-16-30-21.pt'
UK_test_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/candidates/query_embeddings/GRU4Rec_Next/kdd_cup_2023_UK/predict_embeddings_2023-04-29-16-30-59.pt'

In [38]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature_test():
    return pd.read_parquet(merged_candidates_feature_test_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_test_sessions():
    return pd.read_csv(test_sessions_path)

In [39]:
merged_candidates = read_merged_candidates_feature_test()
test_sessions = read_test_sessions()
EMBED_DIM = 128
merged_candidates.sort_values(by=['sess_id', 'product'], inplace=True)
merged_candidates.reset_index(drop=True, inplace=True)

In [40]:
# sess embeddings 
test_DE_query_emb = torch.load(DE_test_embeddings_path, map_location='cpu')
test_JP_query_emb = torch.load(JP_test_embeddings_path, map_location='cpu')
test_UK_query_emb = torch.load(UK_test_embeddings_path, map_location='cpu')
test_query_embeddings = torch.cat(
    [test_DE_query_emb[test_sessions['locale'] == 'DE'], test_JP_query_emb[test_sessions['locale'] == 'JP'], test_UK_query_emb[test_sessions['locale'] == 'UK']],
    dim=0)

In [41]:
# product embeddings 
DE_product_emb = torch.load(DE_product_embeddings_path, map_location='cpu')
JP_product_emb = torch.load(JP_product_embeddings_path, map_location='cpu')
UK_product_emb = torch.load(UK_product_embeddings_path, map_location='cpu')
product_embeddings = torch.cat([DE_product_emb, JP_product_emb, UK_product_emb], dim=0)

In [42]:
DE_dataset_cache = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/.recstudio/cache/c76eddf0a07106ffcce7ce8010856a3b'
JP_dataset_cache = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/.recstudio/cache/81a71d0a18766af84b3beab69bf53e69'
UK_dataset_cache = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/.recstudio/cache/250dbc09c30162452e00486051e47756'
DE_train_dataset, DE_valid_dataset = _load_cache(DE_dataset_cache)
JP_train_dataset, JP_valid_dataset = _load_cache(JP_dataset_cache)
UK_train_dataset, UK_valid_dataset = _load_cache(UK_dataset_cache)
locale_map = {
    'DE' : DE_train_dataset.field2token2idx['product_id'], 
    'JP' : JP_train_dataset.field2token2idx['product_id'], 
    'UK' : UK_train_dataset.field2token2idx['product_id']
    }

In [43]:
DE_product_list, DE_id_list = list(zip(*locale_map['DE'].items()))
JP_product_list, JP_id_list = list(zip(*locale_map['JP'].items()))
UK_product_list, UK_id_list = list(zip(*locale_map['UK'].items()))
product_list = list(DE_product_list) + list(JP_product_list) + list(UK_product_list)
id_list = list(DE_id_list) + list(JP_id_list) + list(UK_id_list)
locale_list = ['DE'] * len(DE_id_list) + ['JP'] * len(JP_id_list) + ['UK'] * len(UK_id_list)
product_id_df = pd.DataFrame({'locale' : locale_list, 'product' : product_list, 'dataset_id' : id_list})

In [44]:
# merged_candidates_g = cudf.from_pandas(merged_candidates)
# product_id_df_g = cudf.from_pandas(product_id_df)

In [45]:
# merged_candidates_gru4rec_g = merged_candidates_g.merge(product_id_df_g, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'product'])
# merged_candidates_gru4rec_g['dataset_id'] = merged_candidates_gru4rec_g['dataset_id'].fillna(0)
# merged_candidates_gru4rec_g.drop(columns=['locale'], inplace=True)
# merged_candidates_gru4rec_g = merged_candidates_gru4rec_g.sort_values(by=['sess_id', 'product'])
# merged_candidates_gru4rec_g.reset_index(drop=True, inplace=True)
# merged_candidates_gru4rec = merged_candidates_gru4rec_g.to_pandas()

In [None]:
merged_candidates_gru4rec = merged_candidates.merge(product_id_df, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'product'])
merged_candidates_gru4rec['dataset_id'] = merged_candidates_gru4rec['dataset_id'].fillna(0)
merged_candidates_gru4rec.drop(columns=['locale'], inplace=True)
merged_candidates_gru4rec = merged_candidates_gru4rec.sort_values(by=['sess_id', 'product'])
merged_candidates_gru4rec.reset_index(drop=True, inplace=True)
# merged_candidates_gru4rec = merged_candidates_gru4rec_g.to_pandas()

In [46]:
# del merged_candidates_g
# del product_id_df_g
# del merged_candidates_gru4rec_g

In [48]:
locale_offset = {'DE' : 0, 'JP' : len(DE_product_list), 'UK' : len(DE_product_list) + len(JP_product_list)}
for locale in ['DE', 'JP', 'UK']:
    merged_candidates_gru4rec['dataset_id'][merged_candidates_gru4rec['sess_locale'] == locale] = \
        merged_candidates_gru4rec['dataset_id'][merged_candidates_gru4rec['sess_locale'] == locale] + locale_offset[locale]

A value is trying to be set on a copy of a slice from a DataFrame

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  merged_candidates_gru4rec['dataset_id'][merged_candidates_gru4rec['sess_locale'] == locale] = \


In [49]:
merged_candidates_gru4rec[FIELD_NAME] = get_scores(merged_candidates_gru4rec, test_query_embeddings, product_embeddings)

100%|██████████| 33901/33901 [00:46<00:00, 729.17it/s] 


In [50]:
normalize_scores(merged_candidates_gru4rec, FIELD_NAME, 'normalized_'+FIELD_NAME)

In [51]:
merged_candidates_gru4rec.drop(columns=['dataset_id', 'exp_score', 'score_sum'], inplace=True)

In [59]:
cast_dtype(merged_candidates_gru4rec)
merged_candidates_gru4rec.to_parquet(merged_candidates_feature_test_path, engine='pyarrow')

In [60]:
merged_candidates

Unnamed: 0,sess_id,sess_locale,product,sasrec_scores_2,sasrec_normalized_scores_2
0,0,DE,4088833651,0.000000,2.975813e-09
1,0,DE,B000H6W2GW,0.000000,2.975813e-09
2,0,DE,B000JG2RAG,7.665308,6.347557e-06
3,0,DE,B000RYSOUW,-2.951060,1.555882e-10
4,0,DE,B000UGZVQM,3.977920,1.589257e-07
...,...,...,...,...,...
69428426,316970,UK,B0BJCTH4NH,11.327528,1.041200e-04
69428427,316970,UK,B0BJTQQWLG,5.604142,3.403292e-07
69428428,316970,UK,B0BJV3RL4H,9.146974,1.176336e-05
69428429,316970,UK,B0BK7SPC84,-10.383047,3.879279e-14


In [58]:
merged_candidates_gru4rec

Unnamed: 0,sess_id,sess_locale,product,sasrec_scores_2,sasrec_normalized_scores_2,gru4rec_scores,gru4rec_normalized_scores
0,0,DE,4088833651,0.000000,2.975813e-09,0.000000,0.0
1,0,DE,B000H6W2GW,0.000000,2.975813e-09,0.000000,0.0
2,0,DE,B000JG2RAG,7.665308,6.347557e-06,8.104032,0.000005
3,0,DE,B000RYSOUW,-2.951060,1.555882e-10,-2.857798,0.0
4,0,DE,B000UGZVQM,3.977920,1.589257e-07,4.688567,0.0
...,...,...,...,...,...,...,...
69428426,316970,UK,B0BJCTH4NH,11.327528,1.041200e-04,10.629994,0.000382
69428427,316970,UK,B0BJTQQWLG,5.604142,3.403292e-07,6.052083,0.000004
69428428,316970,UK,B0BJV3RL4H,9.146974,1.176336e-05,7.667603,0.00002
69428429,316970,UK,B0BK7SPC84,-10.383047,3.879279e-14,-6.356799,0.0


In [63]:
merged_candidates_gru4rec[(merged_candidates_gru4rec['sess_id'] == 100005)].sort_values(by='normalized_sasrec_scores_2', ascending=False)[['sess_locale', 'product', 'normalized_sasrec_scores_2', 'sasrec_scores_2', 'normalized_gru4rec_scores', 'gru4rec_scores']][:15]

Unnamed: 0,sess_locale,product,sasrec_normalized_scores_2,sasrec_scores_2,gru4rec_normalized_scores,gru4rec_scores
21586793,DE,B07TRQH45S,0.157271,14.288859,0.252197,13.23623
21586762,DE,B01N4ND1T2,0.12286,14.041931,0.073287,12.000405
21586937,DE,B0B6WNV91T,0.109754,13.929132,0.063878,11.863
21586840,DE,B092CMLDHW,0.084313,13.66542,0.036893,11.314045
21586866,DE,B09F2J37V4,0.066638,13.430163,0.02189,10.792058
21586887,DE,B09NQ7T1D2,0.044435,13.024924,0.016852,10.530479
21586870,DE,B09F66MWVX,0.03542,12.79815,0.015094,10.420308
21586934,DE,B0B62K5H9P,0.023618,12.39289,0.021667,10.781809
21586846,DE,B0953JXVQ2,0.02131,12.290046,0.012762,10.25253
21586824,DE,B08H1YNK3P,0.018914,12.170765,0.028193,11.045082
