## Build Input for a Transformer-Based Recommender with Adjacency Data

SASRec [1] takes as input the sequence of items a user has interacted with and predicts the same sequence shifted by one step. The idea here is to enrich each item in this sequence with information from the other users who interacted with that item. Thus, each item gets a sequence of users (ordered by recency, most recent first) as an additional attribute.

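How do we handle the time aspect? When building this graph, only users who interacted with the item strictly before the current interaction's timestamp are taken into account; otherwise information from the future would leak into the input. The new dataset therefore looks like the example below. In the actual file the columns are tab-separated, and `0` means that no earlier user was found: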
        
        user-id, item-id, users who interacted with this item earlier (most recent first)
        
        11 56076 0
        11 14037 15,10,12,13
        11 4467 0
        11 33810 268,41162,60222,56206,49801,10441,13000,14299
        11 31260 12817,14614,30088,11039,25632,13,62373,47260,45849
        11 28006 32489
        11 16413 11359,11315,14025,34607,41079,11448,41139,40790,2541,10873,41072,41089,41083,13000,41099,29498,26935,9951,41060
        11 55039 59475,11315,14025,34607,11448,61040,41089,41099,41083,2541,13407,20417,9951
        11 20799 56213,56198
        11 58690 3616
        11 26147 0
        11 66039 51660,52634,58450,51306,59567,10873
        11 78708 58957
        11 18158 3597,951,61249,2901,40226,60070,32243,32556,3635

In [1]:
import os
import sys
import json
import re
import random
import copy
from tqdm import tqdm
import numpy as np
import pickle

from collections import defaultdict, Counter


In [2]:
# Every input/output of this notebook lives under the same data directory.
DATA_DIR = "/recsys_data/RecSys/SASRec-tf2/data"
# Interaction log, read below as tab-separated user / item / timestamp rows.
interaction_filename = os.path.join(DATA_DIR, "ae_original.txt")
# Enriched dataset: user id, item id, earlier users of the item.
output_filename = os.path.join(DATA_DIR, "ae_graph.txt")
# Pickled (user dict, item dict, user histories).
dict_filename = os.path.join(DATA_DIR, "ae_graph_dict.pkl")

In [3]:
def data_process_with_time(fname, pname, sep="\t", file_write=False, max_seq_len=50):
    """Re-index a (user, item, timestamp) log and optionally attach per-item user history.

    Each non-blank line of ``fname`` must be ``user<sep>item<sep>timestamp``.
    Users with at most two interactions are left out of the user id mapping.
    When ``file_write`` is true, one line per interaction of a retained user is
    written to ``pname`` as ``user_id<sep>item_id<sep>hist``. ``hist`` is a
    comma-joined list of up to ``max_seq_len`` ids of retained users who
    interacted with the same item strictly earlier, most recent first. It is
    ``'0'`` when there is no such user. A user may appear more than once if
    they interacted with the item repeatedly.

    Args:
        fname: Path of the interaction log to read.
        pname: Path of the output file; only used when ``file_write`` is true.
        sep: Field separator, used for both input and output.
        file_write: Whether to write the enriched dataset to ``pname``.
        max_seq_len: Maximum number of earlier users kept per interaction
            (``None`` keeps all of them).

    Returns:
        Tuple ``(user_dict, item_dict, User, item_user)``:
        ``user_dict`` maps raw user -> 1-based id (retained users only);
        ``item_dict`` maps raw item -> 1-based id (all items, in sorted order);
        ``User`` maps raw user -> list of ``(item, time)``, time-sorted for
        retained users; ``item_user`` maps raw item -> list of
        ``(user, time)`` in file order.
    """
    from bisect import bisect_right
    from itertools import islice

    User = defaultdict(list)
    Items = set()
    user_dict = {}
    item_user = defaultdict(list)  # raw item -> [(raw user, time), ...] in file order

    with open(fname, 'r') as fr:
        for line in fr:
            if not line.strip():
                continue  # tolerate blank / trailing empty lines instead of crashing
            u, i, t = line.rstrip().split(sep)
            t = float(t)
            User[u].append((i, t))
            Items.add(i)
            item_user[i].append((u, t))

    print(len(User), len(Items))

    # Iterate in sorted order: the iteration order of a set of str depends on the
    # hash seed, which would assign different item ids on every run. Ids start at 1.
    item_dict = {item: idx for idx, item in enumerate(sorted(Items), start=1)}

    count_del = 0
    user_count = 1  # start with 1
    for user in User:
        if len(User[user]) <= 2:
            count_del += 1
        else:
            User[user] = sorted(User[user], key=lambda x: x[1])
            user_dict[user] = user_count
            user_count += 1

    count_missing = 0  # bound unconditionally: the summary print below always reads it
    if file_write:
        print(f"Writing data in {pname}")
        # For each item: its interactions ordered most recent first, plus their
        # negated timestamps. The sort is stable, so equal timestamps keep file
        # order. The negated timestamps are ascending, so they can be bisected.
        item_index = {}
        for item_name, ut in item_user.items():
            ordered = sorted(ut, key=lambda x: -x[1])
            item_index[item_name] = (ordered, [-x[1] for x in ordered])

        def _recent_users(item_name, item_time):
            """Return ids (as str) of retained users who hit item_name before item_time, newest first."""
            ordered, neg_times = item_index[item_name]
            # time < item_time  <=>  -time > -item_time, which is a suffix of `ordered`.
            start = bisect_right(neg_times, -item_time)
            candidates = (ordered[k][0] for k in range(start, len(ordered)))
            kept = (str(user_dict[u]) for u in candidates if u in user_dict)
            return list(islice(kept, max_seq_len))

        with open(pname, 'w') as fw:
            for user in tqdm(User.keys()):
                items = User[user]  # already time-sorted for retained users
                if len(items) <= 2:
                    continue
                user_id = user_dict[user]
                missing_user = 0
                for item_name, item_time in items:
                    prev_u = _recent_users(item_name, item_time)
                    if prev_u:
                        hist = ','.join(prev_u)
                    else:
                        hist = '0'
                        missing_user += 1
                    fw.write(sep.join([str(user_id), str(item_dict[item_name]), hist]) + '\n')
                # Count users for whom no interaction had any earlier user.
                if missing_user == len(items):
                    count_missing += 1

    print(user_count - 1, count_del, count_missing)
    return user_dict, item_dict, User, item_user

In [4]:
# Build the enriched dataset and, when it is written, persist the id mappings.
write_file = True
max_user_list = 50  # cap on the number of earlier users kept per interaction
udict, idict, user_history, item_history = data_process_with_time(
    interaction_filename,
    output_filename,
    sep="\t",
    file_write=write_file,
    max_seq_len=max_user_list,
)

if write_file:
    # Only (user ids, item ids, user histories) are pickled; item_history is not.
    with open(dict_filename, "wb") as handle:
        pickle.dump((udict, idict, user_history), handle, protocol=pickle.HIGHEST_PROTOCOL)

print(f"Retained {len(udict)} users with {len(idict)} items from {len(user_history)} users")

63161 85930


  0%|          | 45/63161 [00:00<02:20, 448.56it/s]

Writing data in /recsys_data/RecSys/SASRec-tf2/data/ae_graph.txt


100%|██████████| 63161/63161 [00:52<00:00, 1194.03it/s]


63114 47 12
Retained 63114 users with 85930 items from 63161 users
