In [1]:
import sys
sys.path = ['/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/RecStudio/'] + sys.path
import random
import numpy as np
import pandas as pd
import cudf, itertools
import scipy.sparse as ssp
from functools import lru_cache, partial
from tqdm import tqdm, trange
from collections import Counter, defaultdict
import torch
import pickle

In [2]:
def cast_dtype(df : pd.DataFrame, columns=None):
    if columns is None:
        columns = df.columns
    for k in columns:
        dt = type(df[k].iloc[0])
        if 'float' in str(dt):
            df[k] = df[k].astype('float32')
        elif 'int' in str(dt):
            df[k] = df[k].astype('int32')
        elif dt == list:
            dt_ = type(df.iloc[0][k][0])
            if 'float' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.float32))
            elif 'int' in str(dt_):
                df[k] = df[k].apply(lambda x : np.array(x, dtype=np.int32))

In [19]:
def get_scores(merged_candidates_df, product_id_name, query_embeddings, product_embeddings):
    batch_size = 10000
    num_iter = (len(merged_candidates_df) - 1) // batch_size + 1
    score_list = []
    with torch.no_grad():
        for i in tqdm(range(num_iter)):
            st, ed = i * batch_size, (i + 1) * batch_size 
            batch_sess = merged_candidates_df.iloc[st : ed]
            batch_sess_id = torch.tensor(batch_sess['sess_id'].tolist(), dtype=torch.long, device=query_embeddings.device)
            batch_product_id = torch.tensor(batch_sess[product_id_name].tolist(), dtype=torch.long, device=product_embeddings.device)
            query_emb = query_embeddings[batch_sess_id]
            product_emb = product_embeddings[batch_product_id]
            batch_score = (query_emb * product_emb).sum(dim=-1) 
            score_list.append(batch_score.cpu())
        score_list = torch.cat(score_list, dim=0).cpu().tolist()
        return score_list 

In [4]:
def normalize_scores(score_df, score_name, normalized_score_name):
    # score_df_g = cudf.from_pandas(score_df)
    score_df['exp_score'] = np.exp(score_df[score_name].to_numpy())
    scores_sum = score_df[['sess_id', 'exp_score']].groupby('sess_id').sum()
    scores_sum.reset_index(inplace=True)
    scores_sum = scores_sum.sort_values(by=['sess_id'], ascending=True)
    scores_sum.reset_index(drop=True, inplace=True)
    scores_sum.rename(columns={'exp_score' : 'score_sum'}, inplace=True)

    merged_score_df = score_df.merge(scores_sum, how='left', left_on=['sess_id'], right_on=['sess_id'])
    merged_score_df = merged_score_df.sort_values(by=['sess_id', 'product'])
    merged_score_df.reset_index(drop=True, inplace=True)
    
    # merged_score_df = merged_score_df_g.to_pandas(merged_score_df_g)
    score_df[normalized_score_name] = merged_score_df['exp_score'] / merged_score_df['score_sum']
    score_df['exp_score'] = merged_score_df['exp_score']
    score_df['score_sum'] = merged_score_df['score_sum']

    # del scores_sum_g
    # del merged_score_df_g 

# Merge valid score

In [5]:
FIELD_NAME = 'roberta_scores'

In [6]:
merged_candidates_feature_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/XGBoost/candidates_phase2/merged_candidates_150_feature.parquet'
valid_sessions_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/data_for_recstudio/task1_data/task13_4_task1_valid_sessions_phase2.csv'
product_data_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/raw_data/products_train.csv'

In [7]:
@lru_cache(maxsize=1)
def read_merged_candidates_feature():
    return pd.read_parquet(merged_candidates_feature_path, engine='pyarrow')

@lru_cache(maxsize=1)
def read_valid_sessions():
    return pd.read_csv(valid_sessions_path)

@lru_cache(maxsize=1)
def read_product_data():
    return pd.read_csv(product_data_path)

In [8]:
merged_candidates = read_merged_candidates_feature()
valid_sessions = read_valid_sessions()
product_data = read_product_data()

In [9]:
merged_candidates_product = merged_candidates[['sess_id', 'sess_locale', 'product']]
merged_candidates_product

Unnamed: 0,sess_id,sess_locale,product
0,0,DE,355165591X
1,0,DE,3833237058
2,0,DE,B00CIXSI6U
3,0,DE,B00NVDOWUW
4,0,DE,B00NVDP3ZU
...,...,...,...
78842194,261815,UK,B0BCX524Y6
78842195,261815,UK,B0BCX6QB4L
78842196,261815,UK,B0BFPJYXQL
78842197,261815,UK,B0BH3X67S3


In [10]:
product_index = product_data[['id', 'locale']]
product_index['product_index'] = product_index.index + 1

A value is trying to be set on a copy of a slice from a DataFrame.
Try using .loc[row_indexer,col_indexer] = value instead

See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy
  product_index['product_index'] = product_index.index + 1


In [11]:
# merged_candidates_product_g = cudf.from_pandas(merged_candidates_product)
# product_index_g = cudf.from_pandas(product_index)

In [12]:
# merged_candidates_product_index_g = merged_candidates_product_g.merge(product_index_g, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'id'])
# merged_candidates_product_index_g = merged_candidates_product_index_g.sort_values(by=['sess_id', 'product'])
# merged_candidates_product_index_g.reset_index(drop=True, inplace=True)
# assert len(merged_candidates_product_index_g) == len(merged_candidates_product_g)
# merged_candidates_product_index_g.drop(columns=['id', 'locale'], inplace=True)
# merged_candidates_product_index_g['product_index'] = merged_candidates_product_index_g['product_index'].fillna(0)
# merged_candidates_product_index = merged_candidates_product_index_g.to_pandas()

In [13]:
merged_candidates_product_index = merged_candidates_product.merge(product_index, how='left', left_on=['sess_locale', 'product'], right_on=['locale', 'id'])
merged_candidates_product_index = merged_candidates_product_index.sort_values(by=['sess_id', 'product'])
merged_candidates_product_index.reset_index(drop=True, inplace=True)
assert len(merged_candidates_product_index) == len(merged_candidates_product)
merged_candidates_product_index.drop(columns=['id', 'locale'], inplace=True)
merged_candidates_product_index['product_index'] = merged_candidates_product_index['product_index'].fillna(0)

In [14]:
# del merged_candidates_product_g
# del product_index_g
# del merged_candidates_product_index_g

In [15]:
roberta_product_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/text_method/phase2_task1_xlm_roberta_results/results/item_reps/item.npy'
roberta_valid_embeddings_path = '/root/autodl-tmp/xiaolong/WorkSpace/Amazon-KDDCUP-23/text_method/phase2_task1_xlm_roberta_results/valid_results/valid_query_reps/query.npy'

In [16]:
roberta_product_embeddings = np.load(roberta_product_embeddings_path)
roberta_valid_embeddings = np.load(roberta_valid_embeddings_path)

In [17]:
roberta_product_embeddings = torch.from_numpy(roberta_product_embeddings)
roberta_valid_embeddings = torch.from_numpy(roberta_valid_embeddings)
roberta_product_embeddings = torch.cat([torch.tensor([[0.0] * roberta_product_embeddings.shape[1]]), roberta_product_embeddings], dim=0)

In [20]:
roberta_valid_embeddings = roberta_valid_embeddings.to('cuda:7')
roberta_product_embeddings = roberta_product_embeddings.to('cuda:7')

In [21]:
merged_candidates_product_index[FIELD_NAME] = get_scores(merged_candidates_product_index, 'product_index', roberta_valid_embeddings, roberta_product_embeddings)

  batch_product_id = torch.tensor(batch_sess[product_id_name].tolist(), dtype=torch.long, device=product_embeddings.device)
100%|██████████| 7885/7885 [02:36<00:00, 50.53it/s]


In [24]:
normalize_scores(merged_candidates_product_index, FIELD_NAME, 'normalized_'+FIELD_NAME)

In [25]:
assert len(merged_candidates) == len(merged_candidates_product_index)
merged_candidates[FIELD_NAME] = merged_candidates_product_index[FIELD_NAME]
merged_candidates['normalized_'+FIELD_NAME] = merged_candidates_product_index['normalized_'+FIELD_NAME]

In [30]:
cast_dtype(merged_candidates, [FIELD_NAME, 'normalized_'+FIELD_NAME])
merged_candidates.to_parquet(merged_candidates_feature_path, engine='pyarrow')

In [26]:
roberta_product_embeddings

tensor([[ 0.0000,  0.0000,  0.0000,  ...,  0.0000,  0.0000,  0.0000],
        [-0.0163,  0.3441, -0.1964,  ..., -0.5166,  0.1799,  0.1612],
        [-0.0200,  0.1681, -0.2068,  ..., -0.1228,  0.1850,  0.3288],
        ...,
        [ 0.2957, -0.0038, -0.0214,  ..., -0.4679,  0.0706, -0.1150],
        [ 0.0821,  0.3195,  0.0097,  ..., -0.4260,  0.0666, -0.0825],
        [-0.0979,  0.2652, -0.2106,  ..., -0.0616,  0.0766,  0.0603]])

In [27]:
merged_candidates

Unnamed: 0,sess_id,sess_locale,product,target,sess_avg_price,product_price,product_freq,sasrec_scores_3,normalized_sasrec_scores_3,sasrec_scores_2,...,normalized_co_graph_counts_0,co_graph_counts_1,normalized_co_graph_counts_1,co_graph_counts_2,normalized_co_graph_counts_2,cos_text_bert_scores,text_bert_scores,normalized_text_bert_scores,roberta_scores,normalized_roberta_scores
0,0,DE,355165591X,0.0,43.256542,8.990000,51.0,2.230508,7.658405e-09,0.512931,...,0.000000,0.000000,0.00000,0.0,0.000000,0.903757,378.286041,1.296655e-08,276.525787,7.975509e-07
1,0,DE,3833237058,0.0,43.256542,22.000000,84.0,9.605231,1.221631e-05,9.325538,...,0.002217,0.090909,0.00083,0.0,0.000000,0.921604,387.624756,1.474268e-04,284.460052,2.226209e-03
2,0,DE,B00CIXSI6U,0.0,43.256542,6.470000,7.0,0.714114,1.681035e-09,-0.115904,...,0.000000,0.000000,0.00000,0.0,0.000000,0.901061,374.802551,3.980740e-10,278.039612,3.624132e-06
3,0,DE,B00NVDOWUW,0.0,43.256542,11.990000,166.0,8.750996,5.199363e-06,8.507557,...,0.000000,0.000000,0.00000,0.0,0.000000,0.927298,385.701782,2.154962e-05,285.239197,4.852260e-03
4,0,DE,B00NVDP3ZU,0.0,43.256542,22.990000,502.0,8.056712,2.596729e-06,5.898870,...,0.000000,0.000000,0.00000,0.0,0.000000,0.930655,385.398499,1.591202e-05,284.763611,3.015780e-03
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
78842194,261815,UK,B0BCX524Y6,0.0,9.383333,16.990000,7.0,6.813615,1.076201e-03,7.203015,...,0.000000,0.000000,0.00000,0.0,0.000000,0.972680,438.956238,1.636532e-04,281.902344,1.476179e-04
78842195,261815,UK,B0BCX6QB4L,0.0,9.383333,10.990000,53.0,9.030836,9.881445e-03,10.123234,...,0.017115,4.250000,0.01369,2.0,0.010638,0.972680,438.956238,1.636532e-04,281.902344,1.476179e-04
78842196,261815,UK,B0BFPJYXQL,0.0,9.383333,10.560000,7.0,0.796892,2.623396e-06,1.711608,...,0.000000,0.000000,0.00000,0.0,0.000000,0.953467,430.164368,2.486932e-08,283.306732,6.012531e-04
78842197,261815,UK,B0BH3X67S3,0.0,9.383333,6.830000,38.0,4.250781,8.296004e-05,6.447586,...,0.000000,0.000000,0.00000,0.0,0.000000,0.961829,434.029083,1.186011e-06,273.954742,5.218429e-08


In [28]:
merged_candidates.query('sess_id==60').sort_values(by=['normalized_sasrec_scores_2'], ascending=False)[['product', 'normalized_sasrec_scores_2', 'roberta_scores', 'normalized_roberta_scores']][:15]

Unnamed: 0,product,normalized_sasrec_scores_2,roberta_scores,normalized_roberta_scores
18132,B08511MSM5,0.128722,288.872253,0.076249
18008,B00W0G743S,0.112961,286.239502,0.005481
18039,B01IN3MVMK,0.10014,285.603973,0.002903
17976,B00305B1KK,0.066778,278.270935,2e-06
18018,B00W0G748S,0.055488,286.615387,0.007982
18102,B07K7PLWRB,0.046246,284.095154,0.000642
18146,B08FJBBZMZ,0.046075,286.598206,0.007846
18011,B00W0G746A,0.030626,286.818512,0.009779
18131,B08511GG29,0.02731,288.748596,0.06738
18103,B07K7QGDS7,0.025443,283.537506,0.000368


In [29]:
merged_candidates.query('sess_id==60').sort_values(by=['normalized_roberta_scores'], ascending=False)[['product', 'normalized_sasrec_scores_2', 'roberta_scores', 'normalized_roberta_scores']][:15]

Unnamed: 0,product,normalized_sasrec_scores_2,roberta_scores,normalized_roberta_scores
18130,B08511GC19,0.0001841048,288.925812,0.080444
18132,B08511MSM5,0.1287216,288.872253,0.076249
18131,B08511GG29,0.02731024,288.748596,0.06738
18078,B07GPG2NGM,0.005105171,287.354492,0.016714
18088,B07GPRXZ5P,0.004547593,287.313568,0.016044
18072,B07DWBPWWG,0.0001820841,287.216492,0.014559
17980,B0033VX9GK,1.769011e-07,287.208069,0.014437
18071,B07DW9MB7F,0.0002861425,287.188354,0.014156
18074,B07DX1S9N6,1.492338e-05,287.12149,0.01324
18083,B07GPKK1LY,0.001104486,287.095276,0.012897
