In [None]:
import ast
import json
import os

# Input files: the train/test ASIN map files and the per-product metadata dump.
filemap = {'train': './AmazonCat-3M_mappings/amazon-3M_train_map.txt',
            'test': './AmazonCat-3M_mappings/amazon-3M_test_map.txt',
            'meta': './metadata.json'}

prod_all = dict()   # asin -> every product record parsed from the metadata file
prod_rcd = dict()   # asin -> records carrying 'related', 'categories' and 'description'
with open(filemap['meta'], 'r') as f:
    for line in f:
        record = line.strip()
        if not record:
            # A blank line (e.g. a trailing empty line at EOF) would make
            # ast.literal_eval raise SyntaxError, so skip it.
            continue
        # One record per line, parsed as a Python literal. A stripped line from
        # file iteration cannot contain an interior newline, so no escaping is needed.
        # NOTE(review): presumably the dump uses Python-dict syntax rather than
        # strict JSON, hence literal_eval instead of json.loads -- confirm.
        prod = ast.literal_eval(record)
        asin = prod['asin']
        prod_all[asin] = prod
        if all(key in prod for key in ('related', 'categories', 'description')):
            prod_rcd[asin] = prod

print('#products in metadata.json:', len(prod_all))
print('#products with rel/cat/des:', len(prod_rcd))

In [None]:
# Select the graph nodes: every ASIN listed in the train/test map files that
# also has a complete metadata record in prod_rcd. Node ids are assigned in
# encounter order (train file first, then test file).
testNodes = set()    # ASINs that were taken from the test map
prod_gcn = dict()    # asin -> metadata record of each selected node
asin2id = dict()     # asin -> integer node id
cnt_id = 0
asinlist = []        # node id -> asin (inverse of asin2id)

for kword in ['train', 'test']:
    with open(filemap[kword], 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Blank line: nothing to index (fields[0] would raise IndexError).
                continue
            asin = fields[0]
            if asin not in prod_rcd:
                continue
            if asin in asin2id:
                # Already assigned an id. Appending it again would create a
                # second node for the same product and overwrite asin2id,
                # so the first occurrence (and its split) wins.
                continue
            if kword == 'test':
                testNodes.add(asin)
            prod_gcn[asin] = prod_rcd[asin]
            asin2id[asin] = cnt_id
            cnt_id += 1
            asinlist.append(asin)

# Node ids must be a 1:1 mapping onto ASINs for the later cells to be consistent.
assert len(asin2id) == len(asinlist) == len(prod_gcn)

print('#products with rel/cat/des/feat (GCN assumptions)', len(prod_gcn))
print('#trainNodes:', len(prod_gcn)-len(testNodes), 'testNodes:', len(testNodes))

print(len(asin2id))

In [None]:
# Build the node list, the identity id_map and the node -> class-label map.
# cat2id maps a category name to (label id, #train nodes, #test nodes).
topK = 1   # how many leading entries of the first category path are considered
cat2id = dict()
cnt_id = 0

id_map = dict()
class_map = dict()
nodes = []
for idx, asin in enumerate(asinlist):
    prod = prod_gcn[asin]
    isTest = asin in testNodes
    id_map[idx] = idx

    label_set = set()
    for cat in prod['categories'][0][:topK]:
        if cat not in cat2id:
            # First time this category is seen: give it the next label id.
            cat2id[cat] = (cnt_id, 0, 0)
            cnt_id += 1
        label_id, n_train, n_test = cat2id[cat]
        label_set.add(label_id)

        # Bump the per-split usage counter for this category.
        if isTest:
            n_test += 1
        else:
            n_train += 1
        cat2id[cat] = (label_id, n_train, n_test)

    # Keep a single label per node. If the first category path is empty,
    # label_set is empty and this line raises IndexError.
    class_map[idx] = list(label_set)[0]

    # Nodes from the test map are flagged 'val'; no node is flagged 'test'.
    # NOTE(review): confirm this val/test assignment is intentional.
    nodes.append({'id': idx, 'val': isTest, 'test': False})

print('#classes:', len(cat2id))

# Collect undirected, de-duplicated edges between selected products from each
# product's 'related' lists (all relation types are treated alike).
links = []
links_set = set()   # canonical (smaller id, larger id) pairs already emitted
for idx, asin in enumerate(asinlist):
    for neighbors in prod_gcn[asin]['related'].values():
        for asin_nei in neighbors:
            idx_nei = asin2id.get(asin_nei)
            if idx_nei is None:
                # Related product is not one of the selected nodes.
                continue
            lk = (min(idx, idx_nei), max(idx, idx_nei))
            if lk in links_set:
                continue
            links_set.add(lk)
            links.append({'source': lk[0], 'target': lk[1]})
print('#links between products:', len(links))

In [None]:
# Assemble the node-link graph description and write it, together with the
# id map and the class map, as three JSON files under outputs/.
G_json = {
    'directed': False,
    'graph': {},
    'nodes': nodes,
    'links': links,
    'multigraph': False,
}

# open(..., "w") does not create missing parent directories; make sure the
# target directory exists so the run does not fail at the very last step.
os.makedirs("outputs", exist_ok=True)

with open("outputs/Amazon2M-G.json", "w") as G_json_f,\
        open("outputs/Amazon2M-id_map.json", "w") as id_map_f,\
        open("outputs/Amazon2M-class_map.json", "w") as class_map_f:
    json.dump(G_json, G_json_f)
    # Note: json serialises the integer keys of id_map/class_map as strings.
    json.dump(id_map, id_map_f)
    json.dump(class_map, class_map_f)