In [None]:
import ast
import copy
import csv
import json
import os
import pickle
import random
import re
import sys

import numpy as np
import pandas as pd
import torch

# Raise the csv module's per-field size limit to sys.maxsize.
csv.field_size_limit(sys.maxsize)

from tqdm import tqdm

# Echo every top-level expression of a cell, not only the last one.
from IPython.core.interactiveshell import InteractiveShell
InteractiveShell.ast_node_interactivity = 'all'

# pandarallel supplies DataFrame.parallel_apply; progress bars are enabled.
from pandarallel import pandarallel
pandarallel.initialize(progress_bar=True)

# Suppress all Python warnings for the rest of the session.
import warnings
warnings.filterwarnings("ignore")

In [None]:
# Show the installed NumPy and PyTorch versions.
np.__version__
torch.__version__

In [None]:
# Seed Python's built-in random module only; the NumPy and torch generators are not seeded here.
random.seed(2025)

In [None]:
# read_path: root the raw pickles are loaded from; save_path: root the outputs are written under.
read_path = '../../Qilin'
save_path = '..'


In [None]:
# Load the item feature table, then preview its first row and its shape.
# NOTE(review): read_pickle can execute code embedded in the file — only load trusted data.
item_features = pd.read_pickle(f'{read_path}/raw_data/item_feat_encoded.pkl')
item_features.head(1)
item_features.shape


In [None]:
# Inspect the image_path value of the row at position 3.
item_features['image_path'].iloc[3]

In [None]:
def text_null(row):
    """Tell whether a row's ``text`` field is empty.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with a ``'text'`` entry.

    Returns:
        True when the text is missing (``None`` / NaN) or contains only
        whitespace, otherwise False.
    """
    text = row['text']
    # Missing values have no .strip(); treat them as empty instead of crashing.
    # NaN is the only float that is not equal to itself.
    if text is None or (isinstance(text, float) and text != text):
        return True
    return len(str(text).strip()) == 0

# Row-wise boolean mask produced by text_null (True = row has empty text).
text_null_data = item_features.apply(text_null,axis=1)
# How many rows are flagged.
text_null_data.sum()

# Keep only the unflagged rows and renumber the index.
item_features = item_features[~text_null_data].reset_index(drop=True)

In [None]:
# List the columns of the item feature table.
item_features.columns

In [None]:
# Load the rec_inter interaction table and order it by user, then timestamp.
rec_inter = pd.read_pickle(f"{read_path}/raw_data/rec_inter.pkl")
rec_inter = rec_inter.sort_values(by=['user_id','timestamp']).reset_index(drop=True)
rec_inter.head()
rec_inter.shape

In [None]:
# Reduce the interactions to (user_id, item_id) pairs.
all_user_item = rec_inter[['user_id','item_id']]
all_user_item.head(1)
all_user_item.shape

# Interaction counts per user and per item.
all_user_item['user_id'].value_counts()
all_user_item['item_id'].value_counts()

In [None]:
# Iteratively drop users and items with fewer than filter_num interactions
# until every remaining user AND item meets the threshold.
# The loop always checks both sides, so sparse items are removed even when
# every user already satisfies the threshold on the first pass.
remain_user_item = all_user_item.copy()
filter_num=3
while True:
    # One pass is enough for users: removing a user never lowers another user's count.
    remain_user_cnt = remain_user_item['user_id'].value_counts()
    if (remain_user_cnt < filter_num).any():
        print("filter user")
        remained_user = remain_user_cnt[remain_user_cnt>=filter_num].index.to_list()
        remain_user_item = remain_user_item[remain_user_item['user_id'].isin(remained_user)]

    # Likewise one pass is enough for items, but removing items can push
    # users back under the threshold, so loop again after any item removal.
    remain_item_cnt = remain_user_item['item_id'].value_counts()
    if (remain_item_cnt < filter_num).any():
        print("filter item")
        remained_item = remain_item_cnt[remain_item_cnt>=filter_num].index.to_list()
        remain_user_item = remain_user_item[remain_user_item['item_id'].isin(remained_item)]
        continue

    # No item was removed after the user pass: both sides satisfy the threshold.
    break

In [None]:
# Per-user and per-item interaction counts after the iterative filter.
remain_user_item['user_id'].value_counts()
remain_user_item['item_id'].value_counts()

In [None]:
# Shape of the rec_inter subset whose user and item both appear in remain_user_item.
# This is only displayed; rec_inter itself is not modified here.
rec_inter[(rec_inter['user_id'].isin(remain_user_item['user_id'].unique())) &
          (rec_inter['item_id'].isin(remain_user_item['item_id'].unique()))].shape

## Process Src data

In [None]:
# Load the src_inter interaction table and order it by user, then timestamp.
src_inter = pd.read_pickle(f"{read_path}/raw_data/src_inter_encoded.pkl")
src_inter = src_inter.sort_values(by=['user_id','timestamp']).reset_index(drop=True)
src_inter.head()
src_inter.shape

In [None]:
# Load the user feature table, then preview its first row and its shape.
user_features = pd.read_pickle(f'{read_path}/raw_data/user_feat.pkl')
user_features.head(1)
user_features.shape

## Map user/item to id

In [None]:
# Keep only rec interactions whose item is in item_features and whose user is in user_features.
rec_inter = rec_inter[rec_inter['item_id'].isin(item_features['item_id'].unique())].reset_index(drop=True)
rec_inter = rec_inter[rec_inter['user_id'].isin(user_features['user_id'].unique())].reset_index(drop=True)
rec_inter.shape

# Same restriction for the src interactions.
src_inter = src_inter[src_inter['item_id'].isin(item_features['item_id'].unique())].reset_index(drop=True)
src_inter = src_inter[src_inter['user_id'].isin(user_features['user_id'].unique())].reset_index(drop=True)
src_inter.shape

## Filter User

In [None]:
# Distinct users present in each interaction table.
rec_user_set = set(rec_inter['user_id'].unique())
src_user_set = set(src_inter['user_id'].unique())

# Sizes of each set, of their union and of their intersection.
len(rec_user_set)
len(src_user_set)
len(rec_user_set | src_user_set)
len(rec_user_set & src_user_set)

# Union of both user sets, materialised as a list.
all_user_set = list(rec_user_set | src_user_set)

In [None]:
# Number of users that appear only in rec_inter, and only in src_inter.
len(rec_user_set - src_user_set)
len(src_user_set - rec_user_set)

In [None]:
# all_user_set is the union of both tables' users, so these isin filters
# retain every row; the visible effect is only the renumbered index.
rec_inter = rec_inter[rec_inter['user_id'].isin(all_user_set)].reset_index(drop=True)
rec_inter.shape

src_inter = src_inter[src_inter['user_id'].isin(all_user_set)].reset_index(drop=True)
src_inter.shape

In [None]:
# Number of rec_inter rows per user.
rec_inter['user_id'].value_counts()

In [None]:
# Fraction of rec_inter users with at least 5, 10, 20 and 30 rows respectively.
(rec_inter['user_id'].value_counts() >= 5).sum() / rec_inter['user_id'].nunique()

(rec_inter['user_id'].value_counts() >= 10).sum() / rec_inter['user_id'].nunique()

(rec_inter['user_id'].value_counts() >= 20).sum() / rec_inter['user_id'].nunique()

(rec_inter['user_id'].value_counts() >= 30).sum() / rec_inter['user_id'].nunique()

In [None]:
# Number of src_inter rows per user.
src_inter['user_id'].value_counts()

In [None]:
# Fraction of src_inter users with at least 5, 10, 20 and 30 rows respectively.
(src_inter['user_id'].value_counts() >= 5).sum() / src_inter['user_id'].nunique()

(src_inter['user_id'].value_counts() >= 10).sum() / src_inter['user_id'].nunique()

(src_inter['user_id'].value_counts() >= 20).sum() / src_inter['user_id'].nunique()

(src_inter['user_id'].value_counts() >= 30).sum() / src_inter['user_id'].nunique()


In [None]:
# {user_id: row count} for rec_inter, and the number of users it covers.
num_rec_inter_dict = rec_inter['user_id'].value_counts().to_dict()
len(num_rec_inter_dict)

# {user_id: row count} for src_inter, and the number of users it covers.
num_src_inter_dict = src_inter['user_id'].value_counts().to_dict()
len(num_src_inter_dict)

In [None]:
# Keep the raw id under 'user'; 'user_id' is re-created below as a dense index.
user_features = user_features.rename(columns={'user_id':'user'})

# Row counts per user in each table; users absent from a table get 0.
user_features['num_rec_inter'] = user_features['user'].map(lambda x: num_rec_inter_dict.get(x,0))
user_features['num_src_inter'] = user_features['user'].map(lambda x: num_src_inter_dict.get(x,0))

# The categorical's category list holds the distinct raw ids.
user_features = user_features.astype({"user":'category'})
id2user = user_features['user'].cat.categories.to_list()

# Raw id -> position in id2user (0-based; no slot is reserved for padding).
user2id = {id2user[k]:k for k in range(len(id2user))} 

user_features['user_id'] = user_features['user'].map(user2id)

user_features.head(1)

## Filter Item

In [None]:
# Distinct items present in each interaction table.
rec_item_set = set(rec_inter['item_id'].unique())
src_item_set = set(src_inter['item_id'].unique())

# Sizes of each set, of their union and of their intersection.
len(rec_item_set)
len(src_item_set)
len(rec_item_set | src_item_set)
len(rec_item_set & src_item_set)

In [None]:
# Drop items that occur in neither interaction table.
item_features = item_features[item_features['item_id'].isin(rec_item_set | src_item_set)].reset_index(drop=True)
item_features.shape

In [None]:
# Keep the raw id under 'item'; 'item_id' is re-created below as a dense index.
item_features = item_features.rename(columns={'item_id':'item'})
item_features = item_features.astype({"item":'category'})
# The categorical's category list holds the distinct raw ids.
id2item = item_features['item'].cat.categories.to_list()
# id2item[0]
# Raw id -> 1-based dense id; 0 is left free for the padding item.
item2id = {id2item[k]:k+1 for k in range(len(id2item))} # +1 for padding
# item2id[0]
item_features['item_id'] = item_features['item'].map(item2id)

item_features.head(1)

In [None]:
# Single-row frame for the padding item: only item_id (0) is provided,
# so its other columns are filled with NaN by the concat below.
pad_item_df = pd.DataFrame({"item_id": [0],})
pad_item_df

# Prepend the padding row, force item_id to int and order rows by item_id.
all_item_df = pd.concat([pad_item_df, item_features],axis=0)
all_item_df['item_id'] = all_item_df['item_id'].astype('int')
all_item_df = all_item_df.sort_values(by=['item_id']).reset_index(drop=True)

all_item_df['item_id'].nunique()
all_item_df.head()
all_item_df.shape

In [None]:
# {item_id: {column: value, ...}} lookup; item_id also stays as a field (drop=False).
item_vocab = all_item_df.set_index('item_id',drop=False).to_dict('index')

In [None]:
# Replace raw ids with the dense ids from user2id / item2id.
# Direct dict indexing raises KeyError for an id missing from the mapping.
rec_inter['user_id'] = rec_inter['user_id'].apply(lambda x:user2id[int(x)])
rec_inter['item_id'] = rec_inter['item_id'].apply(lambda x:item2id[int(x)])
rec_inter.head()
rec_inter.shape

In [None]:
# Replace raw ids with the dense ids from user2id / item2id.
# Direct dict indexing raises KeyError for an id missing from the mapping.
src_inter['user_id'] = src_inter['user_id'].apply(lambda x:user2id[int(x)])
src_inter['item_id'] = src_inter['item_id'].apply(lambda x:item2id[int(x)])
src_inter.head()
src_inter.shape

In [None]:
# Stringify keyword before grouping — presumably so it is usable as a group
# key; it is parsed back into an object in a later cell.
src_inter['keyword'] = src_inter['keyword'].astype('str')
# One row per (user, session, query, query_id, keyword); the click flags,
# item ids and timestamps of the group's rows are collected into lists.
# pos_items holds every item_id of the group, whatever its click value.
# NOTE(review): groupby drops rows with NaN in a key column by default —
# confirm none of these keys can be missing.
session_src_inter = src_inter.groupby(by=['user_id','search_session_id','query','query_id','keyword']).agg(
    click_list=('click',list),
    pos_items=("item_id",list),
    time_list=('timestamp',list)
).reset_index()
session_src_inter = session_src_inter.sort_values(by=['user_id']).reset_index(drop=True)
session_src_inter.head()
session_src_inter.shape

In [None]:
# Keep the first row per (user_id, search_session_id), then compare the row
# count with the number of distinct session ids.
session_src_inter = session_src_inter.drop_duplicates(subset=['user_id','search_session_id'], keep='first').reset_index(drop=True)
session_src_inter.shape
session_src_inter['search_session_id'].nunique()

In [None]:
# Convert to categorical so the category list yields the distinct session ids.
session_src_inter['search_session_id'] = session_src_inter['search_session_id'].astype('category')

id2session = session_src_inter['search_session_id'].cat.categories.to_list()
# Raw session id -> 1-based dense id (0 is left free for padding).
session2id = {id2session[k]:k+1 for k in range(len(id2session))}

session_src_inter['search_session_id'] = session_src_inter['search_session_id'].apply(lambda x:session2id[x])
# +1 for padding
session_src_inter.head()

In [None]:
# Apply the same session mapping to the row-level table; ids absent from
# session2id become NaN because Series.map is used with a dict.
src_inter['search_session_id'] = src_inter['search_session_id'].map(session2id)
src_inter.head()

In [None]:
# keyword was cast to str before the session groupby; parse each string back
# into a Python object. ast.literal_eval accepts only literals (lists, strings,
# numbers, ...), so text loaded from the data file cannot execute code the way
# eval() would.
session_src_inter['keyword'] = session_src_inter['keyword'].apply(ast.literal_eval)

In [None]:
# {search_session_id: {column: value, ...}} lookup over the selected session
# columns; the id also stays as a field (drop=False). Then preview entry 1.
session_vocab = session_src_inter[['search_session_id', 'query', 'query_id', 'keyword',
                                   'pos_items','click_list','time_list']].set_index('search_session_id',drop=False).to_dict('index')
session_vocab[1]

In [None]:
def get_session_time(row):
    """Return the first entry of the row's ``time_list``.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with a ``'time_list'``
            sequence.

    Returns:
        The element at position 0 of ``row['time_list']``.
    """
    session_times = row['time_list']
    return session_times[0]

# Give each session a single timestamp: the first element of its time_list.
session_src_inter['timestamp'] = session_src_inter.apply(get_session_time,axis=1)
session_src_inter.head()

## Joint SAR data

In [None]:
# rec events: one row per interaction, behavior code 1; the session column
# is filled with the string placeholder 'nan'.
sub_rec_inter = rec_inter[['user_id','item_id','timestamp']].copy()
sub_rec_inter['search_session_id'] = 'nan'
sub_rec_inter['behavior'] = 1

# src events: one row per session, behavior code 2; the item column is
# filled with the string placeholder 'nan'.
sub_session_src_inter = session_src_inter[['user_id','timestamp','search_session_id']].copy()
sub_session_src_inter['item_id'] = 'nan'
sub_session_src_inter['behavior'] = 2

# Stack both event types and order by user, then timestamp.
sar_inter = pd.concat([sub_rec_inter,sub_session_src_inter],axis=0)
sar_inter = sar_inter.sort_values(by=['user_id','timestamp']).reset_index(drop=True)
sar_inter.head()
sar_inter.shape

In [None]:
# {user_id: {column: value, ...}} lookup over user_features; each entry is then
# extended with empty history lists that the loop below fills in.
user_vocab = user_features.set_index('user_id',drop=False).to_dict('index')
for key in user_vocab.keys():
    user_vocab[key]['rec_his'] = []
    user_vocab[key]['rec_his_ts'] = []
    user_vocab[key]['src_session_his'] = []
    user_vocab[key]['src_session_his_ts'] = []
    user_vocab[key]['src_his'] = []
    user_vocab[key]['src_his_ts'] = []
    user_vocab[key]['src_his_query'] = []
    user_vocab[key]['all_his'] = []
    user_vocab[key]['all_his_ts'] = []
    user_vocab[key]['all_his_query'] = []

# Walk sar_inter in its current row order. For every event, first record how
# long each of the user's histories is BEFORE the event, then append the event.
new_sar_inter_list = []
for _, line in tqdm(sar_inter.iterrows()):
    user_id, item_id, timestamp,\
        search_session_id, behavior = line['user_id'], line['item_id'], \
            line['timestamp'], line['search_session_id'], line['behavior']
    
    # History lengths prior to this event; they act as slice end-points into
    # the per-user lists stored in user_vocab.
    cur_rec_his_len = len(user_vocab[user_id]['rec_his'])
    cur_src_session_his_len = len(user_vocab[user_id]['src_session_his'])
    cur_src_his_len = len(user_vocab[user_id]['src_his'])
    cur_all_his_len = len(user_vocab[user_id]['all_his'])
    
    new_sar_inter_list.append((user_id,item_id,timestamp, search_session_id,behavior,\
                               cur_rec_his_len,cur_src_session_his_len,cur_src_his_len,cur_all_his_len))

    if behavior == 1:
        # rec event: the item joins the rec history and the merged history;
        # 0 is stored as the query entry of the merged history.
        user_vocab[user_id]['rec_his'].append(item_id)
        user_vocab[user_id]['rec_his_ts'].append(timestamp)
        user_vocab[user_id]['all_his'].append(item_id)
        user_vocab[user_id]['all_his_ts'].append(timestamp)
        user_vocab[user_id]['all_his_query'].append(0)
    elif behavior == 2:
        # src event: the session id joins the session-level history ...
        user_vocab[user_id]['src_session_his'].append(search_session_id)
        user_vocab[user_id]['src_session_his_ts'].append(timestamp)

        session_info = session_vocab[search_session_id]
        cur_query = session_info['keyword']
        cur_session_pos = session_info['pos_items']
        cur_session_ts = session_info['time_list']
        assert len(cur_session_pos) == len(cur_session_ts)
        
        # ... and every item of the session joins the item-level src history
        # with its own timestamp and the session keyword.
        user_vocab[user_id]['src_his'].extend(cur_session_pos)
        user_vocab[user_id]['src_his_ts'].extend(cur_session_ts)
        user_vocab[user_id]['src_his_query'].extend([cur_query]*len(cur_session_pos))

        # In the merged history all items of the session share the session's
        # timestamp rather than their individual ones.
        user_vocab[user_id]['all_his'].extend(cur_session_pos)
        user_vocab[user_id]['all_his_ts'].extend([timestamp]*len(cur_session_pos))
        user_vocab[user_id]['all_his_query'].extend([cur_query]*len(cur_session_pos))


In [None]:
# Turn the collected tuples into a frame. The four extra columns hold history
# LENGTHS (integers), not the histories themselves.
# The tuple field order must match sar_inter's column order for the names to line up.
new_sar_inter_df = pd.DataFrame(data=new_sar_inter_list,
                                columns=sar_inter.columns.to_list()+['rec_his','src_session_his','src_his','all_his'])
new_sar_inter_df.head()
new_sar_inter_df.shape

In [None]:
# Persist the three lookup dictionaries.
# The target directory is created if missing, and each file is written inside
# a `with` block so the handle is flushed and closed even if pickling fails.
os.makedirs(f'{save_path}/vocab', exist_ok=True)

with open(f'{save_path}/vocab/item_vocab.pkl','wb') as vocab_file:
    pickle.dump(item_vocab, vocab_file)

with open(f'{save_path}/vocab/user_vocab.pkl','wb') as vocab_file:
    pickle.dump(user_vocab, vocab_file)

with open(f'{save_path}/vocab/src_session_vocab.pkl','wb') as vocab_file:
    pickle.dump(session_vocab, vocab_file)

## train/val/test

In [None]:
# Events with neither prior rec history nor prior src item history.
new_sar_inter_df[(new_sar_inter_df['rec_his'] == 0) & (new_sar_inter_df['src_his'] == 0)].shape

# Events with both kinds of prior history.
new_sar_inter_df[(new_sar_inter_df['rec_his'] > 0) & (new_sar_inter_df['src_his'] > 0)].shape

# Events with some prior rec history.
new_sar_inter_df[(new_sar_inter_df['rec_his'] > 0)].shape

# Events with some prior src item history.
new_sar_inter_df[(new_sar_inter_df['src_his'] > 0)].shape

In [None]:
# Order all events chronologically.
# A single-column sort_values defaults to quicksort, which is not stable and
# may reorder rows that share a timestamp. The history lengths stored on each
# row were computed in the pre-sort order, and later cells re-attach them to
# the interaction tables by position, so ties must keep their relative order.
# mergesort is a stable algorithm.
new_sar_inter_df = new_sar_inter_df.sort_values(by=['timestamp'], kind='mergesort').reset_index(drop=True)

data_num = len(new_sar_inter_df)

# Every event starts flagged as training data (1); the per-user split
# overwrites the last rows later.
new_sar_inter_df['rec_train'] = 1
new_sar_inter_df['src_train'] = 1

# Flag distribution for rec events (behavior 1) and src events (behavior 2).
new_sar_inter_df.loc[new_sar_inter_df['behavior'] == 1, 'rec_train'].value_counts()
new_sar_inter_df.loc[new_sar_inter_df['behavior'] == 2, 'src_train'].value_counts()

In [None]:
def splitTrainTest(user_df):
    """Flag the last row of a frame as test and the one before it as validation.

    Args:
        user_df: DataFrame with a ``'train'`` column, expected to hold one
            user's rows in chronological order.

    Returns:
        A copy of ``user_df`` in which ``train`` is 3 on the last row and 2 on
        the second-to-last row (when those rows exist); other rows are
        unchanged.
    """
    # Work on a copy: functions passed to groupby.apply should not mutate the
    # group they receive.
    user_df = user_df.copy()
    train_col = user_df.columns.get_loc('train')
    n_rows = len(user_df)
    # A single positional (row, column) write replaces the chained
    # `df['train'].iloc[i] = v`, which is silently a no-op under pandas
    # Copy-on-Write.
    if n_rows >= 1:
        user_df.iloc[-1, train_col] = 3
    if n_rows >= 2:
        user_df.iloc[-2, train_col] = 2
    return user_df

## Rec Data

In [None]:
# Copy of the rec interactions, ordered by user then timestamp.
rec_w_his_inter = rec_inter.copy()
rec_w_his_inter = rec_w_his_inter.sort_values(by=['user_id','timestamp']).reset_index(drop=True)
rec_w_his_inter.shape
rec_w_his_inter.head(1)

# The rec events (behavior 1) of the joint table, with rec_train exposed as
# 'train' and the same (user, timestamp) ordering.
rec_new_sar_inter_df = new_sar_inter_df[new_sar_inter_df.behavior==1].rename(columns={'rec_train':'train'})
rec_new_sar_inter_df = rec_new_sar_inter_df.sort_values(by=['user_id','timestamp']).reset_index(drop=True)
rec_new_sar_inter_df.shape
rec_new_sar_inter_df.head(1)

In [None]:
# Copy the train flag and the history-length columns across by index label.
# NOTE(review): this is a positional join — it assumes both frames hold the
# same rows in the same order after their sorts. Rows of one user sharing a
# timestamp could break that assumption — TODO confirm.
rec_w_his_inter[['train','rec_his','src_session_his','src_his','all_his']] = rec_new_sar_inter_df[['train','rec_his','src_session_his','src_his','all_his']]
rec_w_his_inter = rec_w_his_inter.reset_index(drop=True)
rec_w_his_inter.head(1)

In [None]:
# Keep rows where the user had at least one earlier rec item or one earlier search session.
rec_w_his_inter = rec_w_his_inter[(rec_w_his_inter['rec_his'] > 0) | (rec_w_his_inter['src_session_his'] > 0)].reset_index(drop=True)
rec_w_his_inter.shape

In [None]:
# Per-user non-null counts of every column; users with at least 3 counted
# item_id values are kept.
rec_inter_num = rec_w_his_inter.groupby(by=['user_id']).count().reset_index()
filtered_users_rec = rec_inter_num[rec_inter_num['item_id'] >= 3]
filtered_users_rec.head(3), filtered_users_rec['item_id'].describe()

In [None]:
# Number of users that pass the minimum-row filter.
filtered_users_rec['user_id'].nunique()

In [None]:
# Restrict the rec rows to the users selected above and renumber the index.
rec_w_his_inter = rec_w_his_inter[rec_w_his_inter['user_id'].isin(set(filtered_users_rec['user_id'].unique()))]
rec_w_his_inter = rec_w_his_inter.reset_index(drop=True)
rec_w_his_inter.head()
rec_w_his_inter.shape

In [None]:
# Leave-one-out split per user: the last row of each user becomes test
# (train == 3), the one before it validation (train == 2); everything else
# keeps its existing flag.
# cumcount(ascending=False) numbers each user's rows from the end (0 = last
# row in the current row order), which avoids groupby.apply with a mutating
# function — slow, and dependent on pandas-version-specific handling of the
# grouping column.
_rank_from_end = rec_w_his_inter.groupby('user_id').cumcount(ascending=False)
rec_w_his_inter_train = rec_w_his_inter.copy()
rec_w_his_inter_train.loc[_rank_from_end == 0, 'train'] = 3
rec_w_his_inter_train.loc[_rank_from_end == 1, 'train'] = 2

In [None]:
def _take_split(flag):
    """Rows of rec_w_his_inter_train whose train flag equals ``flag``.

    The rows are ordered by user then timestamp, re-indexed from 0, and
    returned without the ``train`` column.
    """
    subset = rec_w_his_inter_train[rec_w_his_inter_train.train == flag].reset_index(drop=True)
    ordered = subset.sort_values(by=['user_id', 'timestamp']).reset_index(drop=True)
    return ordered.drop(['train'], axis=1)


# Flag 1 -> training rows.
rec_train = _take_split(1)
rec_train.shape

# Flag 2 -> validation rows.
rec_val = _take_split(2)
rec_val.shape

# Flag 3 -> test rows.
rec_test = _take_split(3)
rec_test.shape


In [None]:
# Number of distinct users in each split.
rec_train['user_id'].nunique()
rec_val['user_id'].nunique()
rec_test['user_id'].nunique()

### Sample negatives for train, val and test

In [None]:
# Negative-sample counts: 4 for training rows, 99 for evaluation rows.
num_train_neg_samples = 4
num_test_neg_samples = 99

# Despite the name this is a list with one entry per remaining rec row
# (duplicates kept), so random.choice over it picks an item in proportion to
# how often it occurs. It also rebinds a name that held a set earlier.
rec_item_set = rec_w_his_inter['item_id'].to_list()

def SampleNegatives(row, cur_num_samples, max_attempts=None):
    """Draw distinct negative items for one interaction row.

    Candidates are drawn with ``random.choice`` from the module-level
    ``rec_item_set``. A candidate is rejected when it is the row's own item,
    is already part of the user's merged history prior to this row, or has
    already been drawn for this row.

    Args:
        row: Mapping-like row with ``'user_id'``, ``'item_id'`` and
            ``'all_his'`` (the length of the user's merged history before
            this row).
        cur_num_samples: Number of negatives to return.
        max_attempts: Upper bound on the number of draws. Defaults to
            ``1000 + 100 * cur_num_samples``.

    Returns:
        List of ``cur_num_samples`` distinct item ids, in the order drawn.

    Raises:
        RuntimeError: If the draw budget is exhausted before enough negatives
            are found (e.g. too few eligible items), instead of looping
            forever.
    """
    user_id = int(row['user_id'])
    cur_pos = int(row['item_id'])

    # A set gives O(1) membership tests; it holds everything a draw must avoid:
    # the prior history, the positive item and the negatives accepted so far.
    excluded = set(user_vocab[user_id]['all_his'][:int(row['all_his'])])
    excluded.add(cur_pos)

    if max_attempts is None:
        max_attempts = 1000 + 100 * max(cur_num_samples, 0)

    neg_samples = []
    attempts = 0
    while len(neg_samples) < cur_num_samples:
        if attempts >= max_attempts:
            raise RuntimeError(
                f"could not sample {cur_num_samples} negatives for user {user_id} "
                f"within {max_attempts} draws"
            )
        attempts += 1
        cur_neg = random.choice(rec_item_set)
        if cur_neg in excluded:
            continue
        excluded.add(cur_neg)
        neg_samples.append(cur_neg)
    return neg_samples

In [None]:
# Draw the training negatives for every row in parallel; the count comes from
# the named constant instead of a repeated literal.
rec_train['neg_items'] = rec_train.parallel_apply(SampleNegatives,cur_num_samples=num_train_neg_samples,axis=1)
rec_train.head()

In [None]:
# Draw the validation negatives; the count comes from the named evaluation
# constant instead of a repeated literal.
rec_val['neg_items'] = rec_val.parallel_apply(SampleNegatives,cur_num_samples=num_test_neg_samples,axis=1)

In [None]:
# Draw the test negatives; the count comes from the named evaluation constant
# instead of a repeated literal.
rec_test['neg_items'] = rec_test.parallel_apply(SampleNegatives,cur_num_samples=num_test_neg_samples,axis=1)

In [None]:
# Write the three splits; create the target directory first so a fresh
# checkout does not fail on a missing folder.
os.makedirs(f'{save_path}/dataset', exist_ok=True)

rec_train.to_pickle(f'{save_path}/dataset/rec_train.pkl')

rec_val.to_pickle(f'{save_path}/dataset/rec_val.pkl')

rec_test.to_pickle(f'{save_path}/dataset/rec_test.pkl')