In [None]:
import numpy as np
from collections import defaultdict
from tqdm import tqdm
import torch
import pandas as pd
import os

data_path = '../../data/jeehoshin/foodcom_dataset/'
dataset_name = 'foodcom'

uid = pd.read_csv(data_path + 'u_id_mapping.csv', sep='\t')
iid = pd.read_csv(data_path + 'i_id_mapping.csv', sep='\t')
interaction = pd.read_csv(data_path + 'foodcom.inter', sep='\t')

print(uid.head())
print(iid.head())
print(interaction.head())

def gen_user_matrix(all_edge, no_users):
    edge_dict = defaultdict(set)

    for edge in all_edge:
        user, item = edge
        edge_dict[user].add(item)

    min_user = 0             # 0
    num_user = no_users      # in our case, users/items ids start from 1
    user_graph_matrix = torch.zeros(num_user, num_user)
    key_list = list(edge_dict.keys())
    key_list.sort()
    bar = tqdm(total=len(key_list))
    for head in range(len(key_list)):
        bar.update(1)
        for rear in range(head+1, len(key_list)):
            head_key = key_list[head]
            rear_key = key_list[rear]
            # print(head_key, rear_key)
            item_head = edge_dict[head_key]
            item_rear = edge_dict[rear_key]
            # print(len(user_head.intersection(user_rear)))
            inter_len = len(item_head.intersection(item_rear))
            if inter_len > 0:
                user_graph_matrix[head_key-min_user][rear_key-min_user] = inter_len
                user_graph_matrix[rear_key-min_user][head_key-min_user] = inter_len
    bar.close()

    return user_graph_matrix

   user_id  userID
0     1535       0
1     1634       1
2     1676       2
3     1891       3
4     2586       4
   recipe_id  itemID
0         40       0
1         49       1
2         58       2
3         66       3
4        142       4
   userID  itemID  rating   timestamp  x_label
0       6     175       5  2000-10-23        0
1      25     445       4  2001-03-24        0
2      24      28       3  2001-04-02        0
3      21     353       4  2001-05-07        0
4      32     477       4  2001-05-16        0


In [None]:
uid_field = 'userID'
iid_field = 'itemID'

num_user = len(pd.unique(interaction[uid_field]))
train_df = interaction[interaction['x_label'] == 0].copy()
train_data = train_df[[uid_field, iid_field]].to_numpy()

user_graph_matrix = gen_user_matrix(train_data, num_user)
user_graph = user_graph_matrix
user_num = torch.zeros(num_user)

user_graph_dict = {}
item_graph_dict = {}
edge_list_i = []
edge_list_j = []

for i in range(num_user):
    user_num[i] = len(torch.nonzero(user_graph[i]))

for i in range(num_user):
    if user_num[i] <= 200:
        user_i = torch.topk(user_graph[i],int(user_num[i]))
        edge_list_i =user_i.indices.numpy().tolist()
        edge_list_j =user_i.values.numpy().tolist()
        edge_list = [edge_list_i, edge_list_j]
        user_graph_dict[i] = edge_list
    else:
        user_i = torch.topk(user_graph[i], 200)
        edge_list_i = user_i.indices.numpy().tolist()
        edge_list_j = user_i.values.numpy().tolist()
        edge_list = [edge_list_i, edge_list_j]
        user_graph_dict[i] = edge_list

np.save(os.path.join(data_path, 'user_graph_dict.npy'), user_graph_dict, allow_pickle=True)

100%|█████████████████████████████████| 7585/7585 [01:15<00:00, 100.17it/s]


In [3]:
ugd = np.load(data_path + "user_graph_dict.npy", allow_pickle=True).item()
print(len(list(ugd.keys())))
print(ugd[0])
print(ugd[0][0])
print(ugd[0][1])

7585
[[2045, 229, 41, 1589, 1834, 316, 1161, 1143, 2012, 2231, 2585, 508, 125, 1063, 68, 1492, 306, 709, 1939, 534, 574, 1800, 1410, 2704, 1912, 712, 4485, 1716, 2630, 266, 992, 973, 1545, 315, 463, 3079, 47, 1142, 4418, 2740, 602, 2990, 1422, 1427, 2009, 1384, 322, 1158, 1325, 12, 1010, 721, 2162, 404, 2892, 551, 286, 1433, 4121, 479, 20, 445, 637, 2175, 296, 2172, 2684, 3046, 579, 3272, 2640, 1409, 3738, 3308, 1677, 438, 1286, 2257, 457, 194, 1451, 3936, 1028, 6010, 4227, 2093, 5316, 327, 408, 3651, 1241, 2637, 73, 204, 1663, 2893, 646, 904, 568, 2782, 3029, 671, 660, 197, 196, 6473, 730, 593, 2785, 462, 784, 17, 144, 5325, 2105, 1670, 264, 511, 1193, 1440, 1439, 2734, 94, 641, 1450, 3985, 3322, 3802, 3557, 796, 861, 773, 625, 1475, 732, 3877, 1535, 92, 735, 2128, 6431, 6641, 11, 3511, 2975, 1368, 1965, 4249, 2176, 607, 304, 2145, 2267, 1509, 4047, 1354, 1455, 1393, 608, 1490, 25, 3031, 5100, 592, 4515, 1024, 1598, 672, 1486, 2921, 249, 2703, 358, 715, 1738, 3071, 704, 3866, 483, 110