In [1]:
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import torch
import pandas as pd
import os

data_path = '../../data/jeehoshin/allrecipe_dataset/'

uid = pd.read_csv(data_path + 'u_id_mapping.csv', sep='\t')
iid = pd.read_csv(data_path + 'i_id_mapping.csv', sep='\t')
interaction = pd.read_csv(data_path + 'allrecipe.inter', sep='\t')

print(uid.head())
print(iid.head())
print(interaction.head())

def gen_user_matrix(all_edge, no_users):
    edge_dict = defaultdict(set)

    for edge in all_edge:
        user, item = edge
        edge_dict[user].add(item)

    min_user = 0             # 0
    num_user = no_users      # in our case, users/items ids start from 1
    user_graph_matrix = torch.zeros(num_user, num_user)
    key_list = list(edge_dict.keys())
    key_list.sort()
    bar = tqdm(total=len(key_list))
    for head in range(len(key_list)):
        bar.update(1)
        for rear in range(head+1, len(key_list)):
            head_key = key_list[head]
            rear_key = key_list[rear]
            # print(head_key, rear_key)
            item_head = edge_dict[head_key]
            item_rear = edge_dict[rear_key]
            # print(len(user_head.intersection(user_rear)))
            inter_len = len(item_head.intersection(item_rear))
            if inter_len > 0:
                user_graph_matrix[head_key-min_user][rear_key-min_user] = inter_len
                user_graph_matrix[rear_key-min_user][head_key-min_user] = inter_len
    bar.close()

    return user_graph_matrix

   user_id  userID
0  5215572       0
1  3622615       1
2  1313770       2
3  3181149       3
4   880574       4
   recipe_id  itemID
0      17991       0
1     170724       1
2      18045       2
3      60598       3
4      47519       4
   userID  itemID  rating                  timestamp  x_label
0       0       0       5   2010-08-25T14:38:53.84\n        0
1       0       1       4  2010-09-09T14:04:45.733\n        0
2       0       2       5  2010-08-16T14:51:25.833\n        0
3       1       3       4   2009-03-15T12:10:20.85\n        0
4       2       4       5  2005-10-04T15:43:36.653\n        0


In [None]:
uid_field = 'userID'
iid_field = 'itemID'

num_user = len(pd.unique(interaction[uid_field]))
train_df = interaction[interaction['x_label'] == 0].copy()
train_data = train_df[[uid_field, iid_field]].to_numpy()

user_graph_matrix = gen_user_matrix(train_data, num_user)
user_graph = user_graph_matrix
user_num = torch.zeros(num_user)

user_graph_dict = {}
item_graph_dict = {}
edge_list_i = []
edge_list_j = []

for i in range(num_user):
    user_num[i] = len(torch.nonzero(user_graph[i]))

for i in range(num_user):
    if user_num[i] <= 200:
        user_i = torch.topk(user_graph[i],int(user_num[i]))
        edge_list_i =user_i.indices.numpy().tolist()
        edge_list_j =user_i.values.numpy().tolist()
        edge_list = [edge_list_i, edge_list_j]
        user_graph_dict[i] = edge_list
    else:
        user_i = torch.topk(user_graph[i], 200)
        edge_list_i = user_i.indices.numpy().tolist()
        edge_list_j = user_i.values.numpy().tolist()
        edge_list = [edge_list_i, edge_list_j]
        user_graph_dict[i] = edge_list

np.save(os.path.join(data_path, 'user_graph_dict.npy'), user_graph_dict, allow_pickle=True)

100%|████████████████████████████████| 68768/68768 [32:20<00:00, 35.44it/s]


In [3]:
ugd = np.load(data_path + "user_graph_dict.npy", allow_pickle=True).item()
print(len(list(ugd.keys())))
print(ugd[0])
print(ugd[0][0])
print(ugd[0][1])

68768
[[31751, 23037, 31777, 47875, 11366, 8106, 58340, 38494, 43711, 46186, 29355, 10620, 36301, 27262, 22981, 20063, 18810, 18359, 16607, 15770, 15275, 15183, 13256, 11486, 11432, 9709, 20832, 20129, 33, 18964, 17191, 16764, 15953, 16414, 14240, 13404, 12324, 11948, 10679, 10307, 11415, 8892, 9450, 8820, 8239, 8195, 6783, 6479, 6456, 5734, 6253, 5727, 5660, 5124, 4975, 4786, 4122, 21002, 20961, 20281, 20768, 19521, 19371, 18907, 19842, 17657, 17620, 16801, 17142, 16108, 16084, 15418, 18012, 14563, 14300, 13886, 13997, 12611, 12598, 11628, 13223, 11209, 11081, 10468, 10636, 9079, 9001, 8466, 15132, 7254, 7027, 6521, 6675, 5794, 6108, 5688, 7759, 5177, 5453, 4855, 5530, 4221, 4250, 7787, 3992, 5581, 4342, 4002, 3967, 3493, 3503, 3481, 2877, 3527, 3125, 114, 1880, 2276, 1858, 1807, 21115, 21161, 20851, 21194, 20323, 20608, 20091, 21228, 19559, 19564, 19275, 19788, 18935, 18948, 18502, 21241, 17665, 17689, 17597, 17833, 16964, 17035, 16641, 18002, 16142, 16146, 16048, 16292, 15536, 15571