In [1]:
import math
import os
import random

import numpy as np
import pandas as pd

In [2]:
seed = 2023
# Seed Python's `random` module and NumPy's legacy global RNG with the same
# value so a Restart & Run All reproduces the same random draws.
for seed_fn in (random.seed, np.random.seed):
    seed_fn(seed)

In [3]:
# Load only these four columns from the tab-separated raw interaction file.
inter_df = pd.read_csv(
    "raw-yelp-2022/yelp2022.inter",
    sep='\t',
    usecols=["user_id:token", "item_id:token", "rating:float", "timestamp:float"],
)

In [4]:
# Raw header columns: user_id:token, item_id:token, rating:float, timestamp:float,
# useful:float, funny:float, cool:float, review_id:token.
# Strip the ":<type>" suffix from each of the four loaded column names.
inter_df = inter_df.rename(
    columns={
        raw_name: raw_name.split(":")[0]
        for raw_name in ("user_id:token", "item_id:token", "rating:float", "timestamp:float")
    }
)

In [5]:
# (distinct users, distinct items, raw interaction rows)
inter_df["user_id"].nunique(), inter_df["item_id"].nunique(), inter_df.shape[0]

(1987929, 150346, 6990280)

In [6]:
# Sparsity of the raw user-item matrix: 1 - rows / (n_users * n_items).
# Computed before de-duplication, so repeated (user, item) rows are still counted.
1 - inter_df.shape[0] / (inter_df["user_id"].nunique() * inter_df["item_id"].nunique())

0.999976611529346

In [7]:
# Keep a single row per (user, item) pair: the first one in file order
# (the frame has not been sorted by timestamp at this point).
inter_df = inter_df.drop_duplicates(subset=["user_id", "item_id"], keep="first")

In [8]:
# Treat ratings strictly above 3 as positive interactions.
clean_inter_df = inter_df.loc[inter_df["rating"].gt(3)]

In [9]:
# (distinct users, distinct items, positive interaction rows)
clean_inter_df["user_id"].nunique(), clean_inter_df["item_id"].nunique(), clean_inter_df.shape[0]

(1463053, 147434, 4532468)

In [10]:
threshold_inter_num = 15


def filter_k_core(df, k, user_col="user_id", item_col="item_id"):
    """Return the rows of `df` whose user and item each have at least `k` rows.

    Dropping a sparse user can push one of its items below `k` (and vice
    versa), so counts are recomputed until every remaining user and item
    meets the threshold. The surviving set (the bipartite k-core) does not
    depend on removal order, so dropping all violators per pass yields the
    same rows as alternating user-then-item passes.

    Row order and index labels of `df` are preserved. The result is an
    independent copy, so later column assignments on it do not trigger
    SettingWithCopyWarning. An empty input returns an empty frame.
    """
    core = df
    while True:
        # transform("size") broadcasts each group's row count back onto its rows,
        # avoiding one Python-level call per group as groupby().filter(lambda ...) does.
        user_counts = core.groupby(user_col)[user_col].transform("size")
        item_counts = core.groupby(item_col)[item_col].transform("size")
        keep = (user_counts >= k) & (item_counts >= k)
        # keep.all() is True for an empty frame, so the loop always terminates
        # (a `size().min() >= k` check is NaN >= k there and never breaks).
        if keep.all():
            return core.copy()
        core = core[keep]


clean_inter_df = filter_k_core(clean_inter_df, threshold_inter_num)

In [11]:
# Sanity check after the k-core filter: min / max interactions per user,
# total rows, distinct users, distinct items.
# groupby().size() is vectorised and, unlike groupby().apply(len), does not hit
# pandas' deprecation of apply() operating on the grouping columns.
sizes = clean_inter_df.groupby("user_id").size()
print(sizes.min(), sizes.max(), len(clean_inter_df), clean_inter_df.user_id.nunique(), clean_inter_df.item_id.nunique())

15 762 861748 24309 18412


In [12]:
# Re-encode raw user tokens as dense integer codes 0..n_users-1
# (pd.factorize assigns codes in order of first appearance by default).
# NOTE(review): `user_id_uniques` (code -> raw token) is not persisted in any visible cell.
user_id_codes, user_id_uniques = pd.factorize(clean_inter_df["user_id"])
clean_inter_df = clean_inter_df.assign(user_id=user_id_codes)

In [13]:
# Re-encode raw item tokens as dense integer codes 0..n_items-1
# (pd.factorize assigns codes in order of first appearance by default).
# NOTE(review): `item_id_uniques` (code -> raw token) is not persisted in any visible cell.
item_id_codes, item_id_uniques = pd.factorize(clean_inter_df["item_id"])
clean_inter_df = clean_inter_df.assign(item_id=item_id_codes)

In [14]:
# Shuffle all rows so the per-user split below is random rather than file-ordered.
# Alternative (chronological split): sort_values(by=["user_id", "timestamp"]) instead of shuffling.
#
# random_state=seed pins the permutation to this cell, so re-running the cell
# alone gives the same order instead of drawing from an already-advanced global
# RNG. On a fresh Restart & Run All it should match the previous global-RNG draw,
# since no NumPy random calls happen between the seeding cell and here.
clean_inter_df = clean_inter_df.sample(frac=1, random_state=seed).reset_index(drop=True)

In [15]:
def split_group(group, split_ratio=(0.8, 0.1, 0.1)):
    """Split one group's rows, in their current order, into train/valid/test slices.

    Parameters
    ----------
    group : pandas.DataFrame
        Rows belonging to one group (here: one user's interactions).
    split_ratio : sequence of three floats
        (train, valid, test) fractions. Only the valid and test fractions are
        read: each row count is rounded up with math.ceil, and train receives
        the remainder.

    Returns
    -------
    list of three DataFrames
        [train, valid, test] as contiguous, non-overlapping positional slices
        of `group`.
    """
    num = len(group)
    # Clamp so the counts never exceed the group size. Without this, tiny groups
    # or large ratios make train_num negative, and slicing with the resulting
    # negative offsets lets train and test overlap. Test is filled first,
    # then valid, and train gets whatever is left.
    test_num = min(math.ceil(num * split_ratio[2]), num)
    valid_num = min(math.ceil(num * split_ratio[1]), num - test_num)
    train_num = num - test_num - valid_num
    bounds = [0, train_num, train_num + valid_num, num]
    return [group.iloc[bounds[i]:bounds[i + 1]] for i in range(3)]

# Split every user's (already shuffled) interactions into [train, valid, test].
# Groups are iterated explicitly instead of using groupby(...).apply(split_group).
# Since pandas 2.2, apply() warns that it will stop passing the grouping column
# to the function, which would silently drop `user_id` from the saved splits.
# groupby sorts the keys and preserves row order within each group, so the
# triples come out in the same user order as before.
splits = [split_group(user_rows) for _, user_rows in clean_inter_df.groupby("user_id")]

In [16]:
# Transpose the per-user [train, valid, test] triples and concatenate each column
# of parts into one frame per split.
train_pos_inter, valid_pos_inter, test_pos_inter = (pd.concat(parts) for parts in zip(*splits))

In [17]:
# Row counts of (train, valid, test)
tuple(len(split_df) for split_df in (train_pos_inter, valid_pos_inter, test_pos_inter))

(668830, 96459, 96459)

In [18]:
# Drop valid/test rows whose item_id never appears in the training split.
train_used_item_set = train_pos_inter.item_id.unique()
# Series.isin replaces np.isin(..., assume_unique=True). The item_id columns
# contain many repeated ids, which violates assume_unique's precondition and can
# produce wrong masks with NumPy's sort-based algorithm.
valid_pos_inter = valid_pos_inter[valid_pos_inter.item_id.isin(train_used_item_set)]
test_pos_inter = test_pos_inter[test_pos_inter.item_id.isin(train_used_item_set)]

In [19]:
# Row counts of (train, valid, test) after dropping unseen items
tuple(len(split_df) for split_df in (train_pos_inter, valid_pos_inter, test_pos_inter))

(668830, 96459, 96459)

In [20]:

def print_count_ranges(frames, key):
    """Print the min and max number of rows per `key` value, one line per frame."""
    for frame in frames:
        # size() is vectorised; apply(len) makes one Python call per group.
        counts = frame.groupby(key).size()
        print(counts.min(), counts.max())


# Per-split interaction-count ranges, grouped first by user, then by item
# (order of lines: train, valid, test).
for label, key in (("user", "user_id"), ("item", "item_id")):
    print("-" * 10 + label + "-" * 10)
    print_count_ranges((train_pos_inter, valid_pos_inter, test_pos_inter), key)

----------user----------
11 608
2 77
2 77
----------item----------
6 1038
1 157
1 156


In [22]:
# Write the three splits as tab-separated files without the index.
# to_csv does not create missing parent directories, so make sure it exists first.
os.makedirs("pro-yelp", exist_ok=True)
for split_name, split_df in (("train", train_pos_inter), ("valid", valid_pos_inter), ("test", test_pos_inter)):
    split_df.to_csv(f"pro-yelp/yelp-{split_name}.clean", sep='\t', index=False)