## Build Input for Transformer Based Recommender with Adjacency Data

SASRec [1] takes sequence of items (interacted by a user) as input and predicts the same sequence (shifted). The idea is to enrich this item sequence with additional information coming from the associated users. Thus, each item will have a sequence 
of users (sorted by time) as additional attributes. 

How to take care of the time aspect? While creating this graph only the users interacted before this time should be taken into account. Thus, the new dataset will look like 
        
        user-id, item-id, all the users who interacted this item recently    
        
        11 56076 0
        11 14037 15,10,12,13
        11 4467 0
        11 33810 268,41162,60222,56206,49801,10441,13000,14299
        11 31260 12817,14614,30088,11039,25632,13,62373,47260,45849
        11 28006 32489
        11 16413 11359,11315,14025,34607,41079,11448,41139,40790,2541,10873,41072,41089,41083,13000,41099,29498,26935,9951,41060
        11 55039 59475,11315,14025,34607,11448,61040,41089,41099,41083,2541,13407,20417,9951
        11 20799 56213,56198
        11 58690 3616
        11 26147 0
        11 66039 51660,52634,58450,51306,59567,10873
        11 78708 58957
        11 18158 3597,951,61249,2901,40226,60070,32243,32556,3635

In [None]:
import os
import sys
import json
import re
import random
import copy
from tqdm import tqdm
import numpy as np
import pickle

from collections import defaultdict, Counter


In [None]:
interaction_filename = "/recsys_data/RecSys/SASRec-tf2/data/ae_original.txt"
output_filename = "/recsys_data/RecSys/SASRec-tf2/data/ae_graph.txt"
dict_filename = "/recsys_data/RecSys/SASRec-tf2/data/ae_graph_dict.pkl"

In [None]:
def data_process_with_time(fname, pname, sep="\t", file_write=False, max_seq_len=50, max_item_len=50):
    User = defaultdict(list)
    Items = set()
    user_dict, item_dict = {}, {}
    item_user = defaultdict(list)  # track user interaction time
    final_item_user = dict()

    with open(fname, 'r') as fr:
        for line in fr:
            u, i, t = line.rstrip().split(sep)
            t = float(t)
            User[u].append((i, t))
            Items.add(i)
            item_user[i].append((u, t))
            
    print(len(User), len(Items))
    
    item_count = 1  # always start with 1
    for item in Items:
        item_dict[item] = item_count
        item_count += 1

    count_del = 0
    user_count = 1  # start with 1

    # get the user-ids
    for user in User.keys():
        if len(User[user]) <= 2:
            count_del += 1
        else:
            User[user] = sorted(User[user], key=lambda x: x[1])
            user_dict[user] = user_count
            user_count += 1

    if file_write:
        print(f"Writing data in {pname}")
        with open(pname, 'w') as fw:
            for user in tqdm(User.keys()):
                if len(User[user]) > 2:
                    items = sorted(User[user], key=lambda x: x[1])
                    current_items = [x[0] for x in items]
                    user_id = user_dict[user]
                    missing_user = 0
                    for it in items:
                        item_name, item_time = it
                        ut = item_user[item_name]
                        item_id = item_dict[item_name]
                        prev_ut = [x for x in ut if x[1] < item_time]  # previous user-time
                        prev_ut = sorted(prev_ut, key=lambda x: item_time - x[1])
                        prev_u = [user] + [x[0] for x in prev_ut if x[0] in user_dict]
                        
                        # items interacted by these users (but before the current item_time)
                        # and not in the current user's item list
                        prev_it = [User[u] for u in prev_u]
                        prev_it = [item for sublist in prev_it for item in sublist]
                        prev_it = [x for x in prev_it if x[1] < item_time]
                        prev_it = sorted(prev_it, key=lambda x: item_time - x[1])
                        prev_i = [x[0] for x in prev_it if x[0] in item_dict]
                        prev_i = [item for item in prev_i if item not in current_items]
                        if len(prev_i) > 0:
                            prev_i = [item_dict[item] for item in prev_i][:max_item_len]
                            prev_i = [str(item) for item in prev_i]
                        else:
                            prev_i = ['0']
                        
                        prev_u = [str(user_id)] + [str(user_dict[x[0]]) for x in prev_ut if x[0] in user_dict][:max_seq_len]
                        hist_u = ','.join(prev_u)
                        hist_i = ','.join(prev_i)
                        fw.write(sep.join([str(user_id), str(item_id), hist_u, hist_i]) + '\n')
        
    print(user_count-1, count_del)
    return user_dict, item_dict, User, item_user

In [None]:
write_file = True
max_user_list = 49
max_item_list = 50
udict, idict, user_history, item_history = data_process_with_time(interaction_filename, 
                                                                  output_filename, 
                                                                  "\t", 
                                                                  write_file,
                                                                  max_user_list,
                                                                  max_item_list
                                                                 )

if write_file:
    with open(dict_filename, 'wb') as handle:
        pickle.dump((udict, idict, user_history), handle, protocol=pickle.HIGHEST_PROTOCOL)

print(f"Retained {len(udict)} users with {len(idict)} items from {len(user_history)} users")

In [None]:
# 'B005UEB5TQ', 'B000W3LJ6Y', 'B0089MVZDW', 'B00005TQ09', 'B0001Y7UAI', 'B00020BJA8', 'B000BQ7GW8', 'B000EPR7XO', 'B000M2GYF6', 'B001EH8FZA', 'B001EZRYYU'
item_history['B005UEB5TQ']

In [None]:
user_history['A2C8I2RQ0WG940']

In [None]:
udict['A2C8I2RQ0WG940']

In [None]:
idict['B005UEB5TQ'], idict['B007A4JTDI']

In [None]:
item_history['B007A4JTDI']

In [None]:
[udict[x[0]] for x in item_history['B007A4JTDI'] if x[1] < 1390176000.0]

## Processing MovieLens

In [None]:
def data_process_movielens(fname, pname, sep="\t", header=False, file_write=False, max_seq_len=50, max_item_len=50):
    User = defaultdict(list)
    Items = set()
    user_dict, item_dict = {}, {}
    item_user = defaultdict(list)  # track user interaction time
    final_item_user = dict()

    with open(fname, 'r') as fr:
        count_line = 0
        for line in fr:
            if header and count_line==0:
                count_line += 1
                continue
            u, i, r, t = line.rstrip().split(sep)
            t = float(t)
            User[u].append((i, t))
            Items.add(i)
            item_user[i].append((u, t))
            
    print(f"{len(User)}-users and {len(Items)}-items")
    
    item_count = 1  # always start with 1
    for item in Items:
        item_dict[item] = item_count
        item_count += 1

    count_del = 0
    user_count = 1  # start with 1

    # get the user-ids
    for user in User.keys():
        if len(User[user]) <= 2:
            count_del += 1
        else:
            User[user] = sorted(User[user], key=lambda x: x[1])
            user_dict[user] = user_count
            user_count += 1

    if file_write:
        print(f"Writing data in {pname}")
        with open(pname, 'w') as fw:
            for user in tqdm(User.keys()):
                if len(User[user]) > 2:
                    items = sorted(User[user], key=lambda x: x[1])
                    print(items)
                    current_items = [x[0] for x in items]
                    print(current_items)
                    user_id = user_dict[user]
                    missing_user = 0
                    for it in items:
                        print(it)
                        item_name, item_time = it
                        ut = item_user[item_name]
                        item_id = item_dict[item_name]
                        prev_ut = [x for x in ut if x[1] < item_time]  # previous user-time
                        prev_ut = sorted(prev_ut, key=lambda x: item_time - x[1])
                        prev_u = [user] + [x[0] for x in prev_ut if x[0] in user_dict]
                        print('here')
                        
                        # items interacted by these users (but before the current item_time)
                        # and not in the current user's item list
                        prev_it = [User[u] for u in prev_u]
                        prev_it = [item for sublist in prev_it for item in sublist]
                        prev_it = [x for x in prev_it if x[1] < item_time]
                        prev_it = sorted(prev_it, key=lambda x: item_time - x[1])
                        prev_i = [x[0] for x in prev_it if x[0] in item_dict]
                        prev_i = [item for item in prev_i if item not in current_items]
                        print('here-2')
                        if len(prev_i) > 0:
                            prev_i = [item_dict[item] for item in prev_i][:max_item_len]
                            prev_i = [str(item) for item in prev_i]
                        else:
                            prev_i = ['0']
                        
                        prev_u = [str(user_id)] + [str(user_dict[x[0]]) for x in prev_ut if x[0] in user_dict][:max_seq_len]
                        hist_u = ','.join(prev_u)
                        hist_i = ','.join(prev_i)
                        fw.write(sep.join([str(user_id), str(item_id), hist_u, hist_i]) + '\n')
        
    print(user_count-1, count_del)
    return user_dict, item_dict, User, item_user

In [None]:
write_file = False
max_user_list = 49
max_item_list = 50
# data_dir = "/recsys_data/RecSys/Movielens/KeBERT4Rec/Data/ml-20m"
data_dir = "/recsys_data/RecSys/Movielens/ml-1m"
interaction_filename = os.path.join(data_dir, "ratings.dat")
output_filename = os.path.join(data_dir, "ratings_processed.csv")

udict, idict, user_history, item_history = data_process_movielens(interaction_filename, 
                                                                  output_filename, 
                                                                  sep="::",
                                                                  header=False,
                                                                  file_write=True,
                                                                  max_seq_len=max_user_list,
                                                                  max_item_len=max_item_list
                                                                 )



In [None]:
import pandas as pd

interaction_filename = "/recsys_data/RecSys/SASRec-tf2/data/ae_original.txt"

# data_dir = "/recsys_data/RecSys/Movielens/ml-1m"
# interaction_filename = os.path.join(data_dir, "ratings.dat")
# df = pd.read_csv(interaction_filename, sep="::", names=['User', 'item', 'rating', 'time'])

df = pd.read_csv(interaction_filename, sep="\t", names=['User', 'item', 'time'])
df

In [None]:
def concat_users(df_g):
    users, time = df_g['User'], df_g['time']
    concat = [(u,t) for u,t in zip(users, time)]
    return concat
    
def concat_items(df_g):
    items, time = df_g['item'], df_g['time']
    concat = [(i,t) for i,t in zip(items, time)]
    return concat

# Users who interacted with this item
df_item = df.groupby('item')['User', 'time'].apply(concat_users)
df_item = df_item.reset_index().rename(columns={0: 'user_time'})

# items interacted with
df_user = df.groupby('User')['item', 'time'].apply(concat_items)
df_user = df_user.reset_index().rename(columns={0: 'item_time'})

In [None]:
df2 = df.merge(df_item, on='item', how='left')

In [None]:
def filter_users(row):
    filtered = [x[0] for x in row['user_time'] if x[1] < row['time']]
    df_u = pd.DataFrame({"User": filtered}).merge(df_user, how="inner")
    items = list(df_u['item_time'])
    flat = [item for sublist in items for item in sublist]
    flat = [x for x in flat if x[1] < row['time']]
    flat = sorted(flat, key=lambda x: row['time'] - x[1])
    flat = [x[0] for x in flat]
    return flat[:50]

# this one takes lot of time
df3 = df2.apply(filter_users, axis=1)

In [None]:
df_item[df_item['item']==1193]

In [None]:
df_user

## Amazon Dataset

In [32]:
def data_process_amazon(fname, pname, sep="\t", header=False, file_write=False, max_seq_len=50, max_item_len=50):
    User = defaultdict(list)
    Items = set()
    item_user = defaultdict(list)  # track user interaction time
    final_item_user = dict()

    with open(fname, 'r') as fr:
        count_line = 0
        for line in fr:
            if header and count_line==0:
                count_line += 1
                continue
                
            u, i, t = line.rstrip().split(sep)
#             u, i, r, t = line.rstrip().split(sep)
            t = float(t)
            User[u].append((i, t))
            Items.add(i)
            item_user[i].append((u, t))
            
    print(f"{len(User)}-users and {len(Items)}-items")
    
    count_del = 0
    
    # get the user-ids
    for user in User.keys():
        User[user] = sorted(User[user], key=lambda x: x[1])

    fr = open(fname, 'r')
    lines = fr.readlines()
    if header:
        lines = lines[1:]
    
    print(f"Writing data in {pname}")
    with open(pname, 'w') as fw:
        for line in tqdm(lines):
            user, item_name, item_time = line.strip().split(sep)
            item_time = float(item_time)
            ut = item_user[item_name]
            prev_ut = [x for x in ut if x[1] < item_time]  # previous user-time
            prev_ut = sorted(prev_ut, key=lambda x: item_time - x[1])
            prev_u = [user] + [x[0] for x in prev_ut]

            # items interacted by these users (but before the current item_time)
            # and not in the current user's item list
            prev_it = [User[u] for u in prev_u]
            prev_it = [item for sublist in prev_it for item in sublist]  # flatten
            prev_it = [x for x in prev_it if x[1] < item_time]
            prev_it = sorted(prev_it, key=lambda x: item_time - x[1])
            prev_i = [x[0] for x in prev_it]
            prev_i = [item for item in prev_i if item != item_name]
            if len(prev_i) > 0:
                prev_i = prev_i[:max_item_len]
            else:
                # if there is no history then keep the current item
                prev_i = [item_name]

            prev_u = [user] + [x[0] for x in prev_ut][:max_seq_len]
            hist_u = ','.join(prev_u)
            hist_i = ','.join(prev_i)
            fw.write(sep.join([user, item_name, str(item_time), hist_u, hist_i]) + '\n')
        
    return User, item_user

In [33]:
max_user_list = 49
max_item_list = 50

data_dir = "/recsys_data/RecSys/KeBERT4Rec/Data/amazon-electronics"
filename = "reviews_Electronics_10filter.json_output"

interaction_filename = os.path.join(data_dir, filename)
output_filename = os.path.join(data_dir, "temp.csv")

user_history, item_history = data_process_amazon(interaction_filename, 
                                                  output_filename, 
                                                  sep="\t",
                                                  header=False,
                                                  file_write=True,
                                                  max_seq_len=max_user_list,
                                                  max_item_len=max_item_list
                                                 )



63161-users and 85930-items


  0%|          | 99/949416 [00:00<16:02, 986.63it/s]

Writing data in /recsys_data/RecSys/KeBERT4Rec/Data/amazon-electronics/temp.csv


100%|██████████| 949416/949416 [12:06<00:00, 1306.78it/s] 


In [68]:
history_file = '/recsys_data/RecSys/KeBERT4Rec/Data/amazon-electronics/ratings_with_history.tsv'
df = pd.read_csv(history_file, sep="\t", names=['uid', 'sid', 'timestamp', 'user_hist', 'item_hist'])
df.head()

Unnamed: 0,uid,sid,time,user_hist,item_hist
0,A2C8I2RQ0WG940,B005UEB5TQ,1383523000.0,A2C8I2RQ0WG940,"B007G5NNOW,B00CU9GKTO,B0095ZRQN0,B000VDCTCI,B0..."
1,AM8OIQGVZEEKT,B005UEB5TQ,1405469000.0,"AM8OIQGVZEEKT,A2C8I2RQ0WG940","B00144KS6W,B0037MH5W4,B00FDPSH0W,B003OBUJIK,B0..."
2,A2YQL8DH5AKIGV,B000W3LJ6Y,1238890000.0,"A2YQL8DH5AKIGV,A33QUFNY4E5D0","B001H0GEW0,B000E922SA,B001AZ01EO,B0007U9KAY,B0..."
3,A33QUFNY4E5D0,B000W3LJ6Y,1199232000.0,A33QUFNY4E5D0,"B000IJV4BC,B00005T6GZ"
4,ATFBVUXDIRXT6,B0089MVZDW,1374278000.0,"ATFBVUXDIRXT6,A4PXVT3HVX0MH,A3L1V09GLOVFLT","B00CMNZHEC,B00B5Q79D4,B001T9NX9Q,B0062K951C,B0..."


In [69]:
umap = {u: i for i, u in enumerate(set(df["uid"]))}
smap = {s: i for i, s in enumerate(set(df["sid"]))}

In [70]:
def convert_users(lst):
    return [umap[x] for x in lst.split(',')]

def convert_items(lst):
    return [smap[x] for x in lst.split(',')]

In [71]:
df["uid"] = df["uid"].map(umap)
df["sid"] = df["sid"].map(smap)

df["user_hist"] = df["user_hist"].apply(convert_users)
df["item_hist"] = df["item_hist"].apply(convert_items)

In [72]:
df.head()

Unnamed: 0,uid,sid,time,user_hist,item_hist
0,579,60247,1383523000.0,[579],"[82614, 49598, 47840, 3623, 77365, 69845, 2364..."
1,36174,60247,1405469000.0,"[36174, 579]","[73774, 25402, 31990, 15044, 48220, 54052, 632..."
2,38663,56269,1238890000.0,"[38663, 38340]","[42029, 78672, 56772, 82497, 36126, 50490, 455..."
3,38340,56269,1199232000.0,[38340],"[36126, 62251]"
4,58116,68163,1374278000.0,"[58116, 29294, 23727]","[53873, 10332, 67455, 20317, 81203, 11539, 258..."


In [45]:
umap['A3QO5UPOCBMAR5']

38148

In [73]:
tqdm.pandas()

user_group = df.groupby("uid")
hist_items = user_group.progress_apply(
                lambda d: list(d.sort_values(by="time")["item_hist"])
            )

100%|██████████| 63161/63161 [00:43<00:00, 1461.25it/s]


In [74]:
hist_items[0]

[[14156],
 [23994,
  84535,
  33915,
  33915,
  43254,
  33915,
  29456,
  53636,
  80303,
  26956,
  58857,
  37960,
  55456,
  76879,
  60259,
  45665,
  41613,
  82172,
  69715,
  80128,
  64442,
  38796,
  19063,
  25565,
  28365,
  53846,
  14156],
 [15121,
  13602,
  5112,
  80985,
  49432,
  33901,
  34630,
  18749,
  43802,
  47582,
  6869,
  7552,
  56432,
  66568,
  77053,
  72143,
  17997,
  123,
  23608,
  39112,
  9182,
  13188,
  9182,
  33989,
  51389,
  7420,
  72735,
  83229,
  81030,
  19255,
  46627,
  70398,
  38867,
  6300,
  41587,
  1220,
  70081,
  60239,
  41715,
  1150,
  84577,
  80288,
  52199,
  75801,
  65430,
  12719,
  12719,
  52098,
  40854,
  29797],
 [69737,
  78980,
  36535,
  1150,
  82155,
  39883,
  40295,
  26916,
  12059,
  26916,
  81777,
  21943,
  81944,
  25170,
  38353,
  63742,
  52321,
  2146,
  73688,
  63742,
  35082,
  28796,
  67569,
  74730,
  6997,
  16831,
  83241,
  63742,
  24550,
  63364,
  61482,
  48097,
  41213,
  73688,
  5

In [62]:
[len(u) for u in user2items[0]]

[1, 27, 50, 50, 50, 50, 6, 50, 50, 50, 50, 50]

In [64]:
df[df['uid']==0].sort_values(by="time")['item_hist'].apply(lambda x: list(map(int, x.split(','))))

821441                                              [14156]
798452    [23994, 84535, 33915, 33915, 43254, 33915, 294...
311890    [15121, 13602, 5112, 80985, 49432, 33901, 3463...
423463    [69737, 78980, 36535, 1150, 82155, 39883, 4029...
549839    [6804, 49750, 31711, 15426, 35623, 3002, 66521...
542840    [20953, 42968, 54123, 6618, 67347, 17867, 6480...
823645              [73688, 83128, 47, 49385, 35463, 14156]
599472    [49498, 7097, 30496, 69060, 825, 49178, 36246,...
644485    [30615, 13574, 37670, 24283, 62413, 69190, 474...
677171    [13842, 75548, 47763, 52718, 45589, 68216, 807...
891813    [1402, 63319, 59579, 24511, 16181, 4736, 70522...
182031    [17637, 58406, 22763, 51569, 19008, 71738, 637...
Name: item_hist, dtype: object

In [75]:
len(umap), len(smap)

(63161, 85930)

In [67]:
df['item_hist'][0]

'82614,49598,47840,3623,77365,69845,23640,19339,9301,32097,13518,50060,41654'