In [34]:
import os
import pickle
import random
from tqdm import tqdm
import numpy as np


### Pay attention to the reproducibility !!

In [35]:
data_dir=f"/Users/yehaoran/Desktop/KGAgentEcno/Graph-CoT-main/data/processed_data/amazon"

In [36]:
# read processed graph
import json 
graph = json.load(open(os.path.join(data_dir, 'magazine_graph.json')))
print(graph.keys())

dict_keys(['item_nodes', 'brand_nodes'])


In [37]:
all_generated_data = {} # key: triple (question (str), answer (str)), value: generated data (List)
k = 10

### Design questions (one type of question in one cell)

In [38]:
# 1-hop reasoning (easy)
# What is the brand of item xxx?
# What is the price of item xxx?
# What is the category of item xxx?

random.seed(2023)
item_ids = list(graph['item_nodes'].keys()) # 9430088

question = "What is the brand of item {item_title}?"
answer = "{brand_name}"
generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    brand_ids = graph['item_nodes'][item_id]['neighbors']['brand']

    if len(brand_ids) != 1:
        continue

    brand_names = [graph['brand_nodes'][brand_id]['features']['name'] for brand_id in brand_ids]
    if len(brand_names)>0 and item_title!='':
        generated_data.append({"item_title":item_title, "brand_name": ', '.join(brand_names)})
    if len(generated_data)==k:
        break
all_generated_data[(question, answer)] = generated_data

In [39]:
random.seed(2024)
question = "What is the category of item {item_title}?"
answer = "{category}"

item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    category = graph['item_nodes'][item_id]['features']['category']

    if len(category) != 1:
        continue

    if item_title!='':
        generated_data.append({"item_title":item_title, "category":', '.join(category)})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

In [40]:
random.seed(2025)

question = "What is the price of item {item_title}?"
answer = "{price}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    price = graph['item_nodes'][item_id]['features']['price']
    
    if price!='' and item_title!='':
        generated_data.append({"item_title":item_title, "price":price})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

### Degree-based reasoning (easy)

In [41]:
##### How many “co-viewed” items does item xxx have?
##### How many “co-purchased” items does item xxx have? # TODO ambiguous question?
##### How many items are in brand xxx?

random.seed(2026)

question = "How many co-viewed items does item {item_title} have?"
answer = "{num}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    related_item_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    
    if item_title!='':
        generated_data.append({"item_title":item_title, "num": len(related_item_ids)})
    if len(generated_data)==k:
        break
all_generated_data[(question, answer)] = generated_data

In [42]:
random.seed(2027)

question = "How many bought-together items does item {item_title} have?"
answer = "{num}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    related_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    if item_title!='':
        generated_data.append({"item_title":item_title, "num": len(related_item_ids)})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

In [43]:
random.seed(2028)

question = "How many buy-after-viewing items does item {item_title} have?"
answer = "{num}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    related_item_ids = graph['item_nodes'][item_id]['neighbors']['buy_after_viewing_item']
    if item_title!='':
        generated_data.append({"item_title":item_title, "num": len(related_item_ids)})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

In [44]:
random.seed(2029)

question = "How many also-bought items does item {item_title} have?"
answer = "{num}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    related_item_ids = graph['item_nodes'][item_id]['neighbors']['also_bought_item']
    if item_title!='':
        generated_data.append({"item_title":item_title, "num": len(related_item_ids)})
    if len(generated_data) == k:
        break
    
all_generated_data[(question, answer)] = generated_data

In [45]:
random.seed(2030)

question = "How many items are in brand {brand_name}?"
answer = "{num}"
generated_data = []

brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for brand_id in brand_ids:
    brand_name = graph['brand_nodes'][brand_id]['features']['name']
    within_item_ids = graph['brand_nodes'][brand_id]['neighbors']['item']
    if brand_name!='':
        generated_data.append({"brand_name":brand_name, "num": len(within_item_ids)})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

In [46]:
random.seed(2031)

question = "What is the main category of item {item_title}?"
answer = "{main_category}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    main_cat = graph['item_nodes'][item_id]['features']['main_cat']

    if main_cat != '' and item_title != '':
        generated_data.append({"item_title": item_title, "main_category": main_cat})
    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

random.seed(2032)

question = "What is the publisher of item {item_title}?"
answer = "{publisher}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    details = graph['item_nodes'][item_id]['features']['details']
    publisher = details.get('Publisher:', '') if details else ''

    if publisher != '' and item_title != '':
        generated_data.append({"item_title": item_title, "publisher": publisher})
    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

random.seed(2033)

question = "What is the format of item {item_title}?"
answer = "{format}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    details = graph['item_nodes'][item_id]['features']['details']
    format_type = details.get('Format:', '') if details else ''

    if format_type != '' and item_title != '':
        generated_data.append({"item_title": item_title, "format": format_type})
    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

### Multi-hop reasoning (medium)

In [47]:
random.seed(2034)

question = "Find the items which are in the same brand and same category as item {item_title}."
answer = "{item_title_neighbour}"

generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)
brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for item_id in item_ids:
    item_features = graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    if item_title == '':
        continue
    brand_ids = graph['item_nodes'][item_id]['neighbors']['brand']
    if len(brand_ids) != 1:
        continue

    brand_id = brand_ids[0]  # search for same brand, just use the first brand
    within_item_ids = graph['brand_nodes'][brand_id]['neighbors']['item']
    result_list = []
    for within_item_id in within_item_ids: 
        if within_item_id==item_id:
            continue
        neighbor_features = graph['item_nodes'][within_item_id]['features']
        neighbor_categories = neighbor_features['category']
        if len(neighbor_categories)==0:
            continue
        neighbor_category = neighbor_categories[0]  # search for same category, just use the first category
        if neighbor_category in item_features['category']: 
            result_list.append(neighbor_features['title'])
            #generated_data.append({"item_title":item_title, "item_title_neighbour":neighbor_features['title']})
            #break

    if len(result_list) < 20 and len(result_list) > 0:
        generated_data.append({"item_title":item_title, "item_title_neighbour": ', '.join(result_list)})

    if len(generated_data)==k:
        break
all_generated_data[(question, answer)] = generated_data

In [48]:
random.seed(2035)

question = "Which item shares over {num} co-viewed items with item {item_title}?"
answer = "{item_title_neighbour}"

num = 4
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)

brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']

    if item_title=='':
        continue

    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    if len(coview_item_ids) < num:
        continue

    res = []
    for search_item_id in tqdm(item_ids): 
        if (search_item_id==item_id) or (search_item_id not in graph['item_nodes']):
            continue
        neighbor_coview_item_ids= graph['item_nodes'][search_item_id]['neighbors']['also_viewed_item']
        if len(neighbor_coview_item_ids)<num:
            continue   
        coview_item_ids_set = set(coview_item_ids)
        neighbor_coview_item_ids_set = set(neighbor_coview_item_ids)
        if len(coview_item_ids_set.intersection(neighbor_coview_item_ids_set))>=num:
            neighbor_features = graph['item_nodes'][search_item_id]['features']
            res.append(neighbor_features['title'])
        if len(res) > 30:
            break

    if len(res) < 20 and len(res) > 0:
        generated_data.append({"num":num, "item_title":item_title, "item_title_neighbour": ', '.join(res)})

    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

 14%|█▍        | 335/2320 [00:00<00:00, 480045.04it/s]
 43%|████▎     | 992/2320 [00:00<00:00, 889809.57it/s]
 65%|██████▌   | 1516/2320 [00:00<00:00, 1079003.03it/s]
100%|██████████| 2320/2320 [00:00<00:00, 763378.46it/s]
 37%|███▋      | 865/2320 [00:00<00:00, 733685.13it/s]
100%|██████████| 2320/2320 [00:00<00:00, 1107911.34it/s]
100%|██████████| 2320/2320 [00:00<00:00, 1305619.92it/s]
 35%|███▍      | 808/2320 [00:00<00:00, 871431.64it/s]
 13%|█▎        | 302/2320 [00:00<00:00, 839416.71it/s]
100%|██████████| 2320/2320 [00:00<00:00, 1218480.50it/s]
100%|██████████| 2320/2320 [00:00<00:00, 998746.31it/s]
 35%|███▍      | 806/2320 [00:00<00:00, 974519.75it/s]
100%|██████████| 2320/2320 [00:00<00:00, 965834.77it/s]
100%|██████████| 2320/2320 [00:00<00:00, 1073445.70it/s]
 16%|█▌        | 371/2320 [00:00<00:00, 872735.16it/s]
 26%|██▌       | 595/2320 [00:00<00:00, 1036383.26it/s]
 29%|██▉       | 670/2320 [00:00<00:00, 1058848.41it/s]
 26%|██▌       | 607/2320 [00:00<00:00, 922777.28i

In [49]:
random.seed(2036)

question = "Which item shares over {num} bought-together items with item {item_title}?"
answer = "{item_title_neighbour}"

num = 4
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)

brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    if len(coview_item_ids) < num:
        continue
    
    res = []
    for search_item_id in tqdm(item_ids): 
        if (search_item_id==item_id) or (search_item_id not in graph['item_nodes']):
            continue
        neighbor_coview_item_ids= graph['item_nodes'][search_item_id]['neighbors']['bought_together_item']
        if len(neighbor_coview_item_ids)<num:
            continue   
        coview_item_ids_set = set(coview_item_ids)
        neighbor_coview_item_ids_set = set(neighbor_coview_item_ids)
        if len(coview_item_ids_set.intersection(neighbor_coview_item_ids_set))>=num:
            neighbor_features = graph['item_nodes'][search_item_id]['features']
            res.append(neighbor_features['title'])
        if len(res) > 30:
            break

    if len(res) < 20 and len(res) > 0:
        generated_data.append({"num":num, "item_title":item_title, "item_title_neighbour": ', '.join(res)})
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data


In [50]:
random.seed(2037)

question = "How many items have the same bought-together items with item {item_title}?"
answer = "{num}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)

brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']

    if item_title=='':
        continue

    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    
    num_shared = 0
    for search_item_id in item_ids: 
        if (search_item_id==item_id) or (search_item_id not in graph['item_nodes']):
            continue
        neighbor_coview_item_ids= graph['item_nodes'][search_item_id]['neighbors']['bought_together_item']

        coview_item_ids_set = set(coview_item_ids)
        neighbor_coview_item_ids_set = set(neighbor_coview_item_ids)

        if coview_item_ids_set==neighbor_coview_item_ids_set:
            num_shared+=1
    
    if num_shared>0 and num_shared<100:
        generated_data.append({"num":num_shared, "item_title":item_title})
    
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

In [51]:
random.seed(2038)

question = "What is the average price of the bought-together items with {item_title}?"

answer = "{average_price}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)


for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    cobuy_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    if len(cobuy_item_ids)==0:
        continue
    
    all_price=[]
    for search_item_id in tqdm(cobuy_item_ids):
        if search_item_id not in graph['item_nodes']:
            continue
        price = graph['item_nodes'][search_item_id]['features']['price']
        if price!='':
            try:
                price_float = float(price)
                all_price.append(price_float)
            except ValueError:
                continue
    
    if len(all_price)>0:
        generated_data.append({"item_title":item_title, "average_price": round(sum(all_price)/len(all_price),2) })
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

In [52]:
random.seed(2039)

question = "What is the average price of the co-viewed items with {item_title}?"

answer = "{average_price}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)


for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    if len(coview_item_ids)==0:
        continue
    
    all_price=[]
    for search_item_id in tqdm(coview_item_ids): 
        if search_item_id not in graph['item_nodes']:
            continue
        price = graph['item_nodes'][search_item_id]['features']['price']
        if price!='':
            try:
                price_float = float(price)
                all_price.append(price_float)
            except ValueError:
                continue
    
    if len(all_price)>0:
        generated_data.append({"item_title":item_title, "average_price": round(sum(all_price)/len(all_price),2) })
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

100%|██████████| 6/6 [00:00<00:00, 78889.73it/s]
100%|██████████| 34/34 [00:00<00:00, 976755.73it/s]
100%|██████████| 29/29 [00:00<00:00, 367476.79it/s]
100%|██████████| 4/4 [00:00<00:00, 118149.41it/s]
100%|██████████| 21/21 [00:00<00:00, 128772.49it/s]
100%|██████████| 30/30 [00:00<00:00, 939023.28it/s]
100%|██████████| 30/30 [00:00<00:00, 1066348.47it/s]
100%|██████████| 5/5 [00:00<00:00, 313007.76it/s]
100%|██████████| 2/2 [00:00<00:00, 127100.12it/s]
100%|██████████| 27/27 [00:00<00:00, 871124.68it/s]
100%|██████████| 7/7 [00:00<00:00, 315700.30it/s]
100%|██████████| 10/10 [00:00<00:00, 301748.49it/s]
100%|██████████| 30/30 [00:00<00:00, 1113532.04it/s]
100%|██████████| 5/5 [00:00<00:00, 279620.27it/s]
100%|██████████| 21/21 [00:00<00:00, 1000913.45it/s]
100%|██████████| 3/3 [00:00<00:00, 174762.67it/s]
100%|██████████| 2/2 [00:00<00:00, 118149.41it/s]
100%|██████████| 11/11 [00:00<00:00, 640796.44it/s]
100%|██████████| 3/3 [00:00<00:00, 199728.76it/s]
100%|██████████| 31/31 [00:0

In [53]:
random.seed(2040)

question = "What is the most popular category name of the bought-together items with {item_title}?"

answer = "{category}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)


for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    cobuy_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    if len(cobuy_item_ids)==0:
        continue
    
    category_counter={}
    for search_item_id in tqdm(coview_item_ids):
        if search_item_id not in graph['item_nodes']:
            continue
        price = graph['item_nodes'][search_item_id]['features']['price']
        if price!='':
            try:
                price_float = float(price)
                all_price.append(price_float)
            except ValueError:
                continue
    
        for cate in category:
            if cate in category_counter:
                category_counter[cate]+=1
            else:
                category_counter[cate]=1
        
    if len(category_counter)>0:
        most_popular_category= max(category_counter, key= lambda x: category_counter[x]) 
        generated_data.append({"item_title":item_title, "category": most_popular_category })
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

In [54]:
random.seed(2041)

question = "What is the most popular category name of the co-viewed items with {item_title}?"

answer = "{category}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)


for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    if len(coview_item_ids)==0:
        continue
    
    category_counter={}
    for search_item_id in tqdm(coview_item_ids): 
        if search_item_id not in graph['item_nodes']:
            continue
        category = graph['item_nodes'][search_item_id]['features']['category']
        
        if len(category)!=1: # a list with a string
            continue
    
        for cate in category:
            if cate in category_counter:
                category_counter[cate]+=1
            else:
                category_counter[cate]=1
        

    if len(category_counter)>0:
        most_popular_category= max(category_counter, key= lambda x: category_counter[x]) 
        generated_data.append({"item_title":item_title, "category": most_popular_category })
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

100%|██████████| 4/4 [00:00<00:00, 142179.80it/s]
100%|██████████| 7/7 [00:00<00:00, 333637.82it/s]
100%|██████████| 29/29 [00:00<00:00, 805528.58it/s]
100%|██████████| 86/86 [00:00<00:00, 1283665.99it/s]
100%|██████████| 32/32 [00:00<00:00, 1147160.07it/s]
100%|██████████| 23/23 [00:00<00:00, 570822.44it/s]
100%|██████████| 30/30 [00:00<00:00, 1198372.57it/s]
100%|██████████| 1/1 [00:00<00:00, 59074.70it/s]
100%|██████████| 7/7 [00:00<00:00, 315700.30it/s]
100%|██████████| 31/31 [00:00<00:00, 812646.40it/s]
100%|██████████| 9/9 [00:00<00:00, 428962.91it/s]
100%|██████████| 22/22 [00:00<00:00, 961194.67it/s]
100%|██████████| 2/2 [00:00<00:00, 118149.41it/s]
100%|██████████| 37/37 [00:00<00:00, 1193763.45it/s]
100%|██████████| 18/18 [00:00<00:00, 857925.82it/s]
100%|██████████| 10/10 [00:00<00:00, 530924.56it/s]
100%|██████████| 1/1 [00:00<00:00, 72315.59it/s]
100%|██████████| 48/48 [00:00<00:00, 1170503.44it/s]
100%|██████████| 5/5 [00:00<00:00, 313007.76it/s]
100%|██████████| 2/2 [00:

In [55]:
random.seed(2042)

question = "What is the most popular publisher in category {category}?"
answer = "{publisher}"
generated_data = []

# 获取所有独特的类别
all_categories = set()
for item_id in graph['item_nodes']:
    categories = graph['item_nodes'][item_id]['features']['category']
    if categories:
        all_categories.update(categories)

categories_list = list(all_categories)
random.shuffle(categories_list)

for category in categories_list:
    publisher_counter = {}

    # 统计该类别中的出版商数量
    for item_id in graph['item_nodes']:
        item_categories = graph['item_nodes'][item_id]['features']['category']
        if category in item_categories:
            details = graph['item_nodes'][item_id]['features']['details']
            publisher = details.get('Publisher:', '') if details else ''
            if publisher != '':
                if publisher in publisher_counter:
                    publisher_counter[publisher] += 1
                else:
                    publisher_counter[publisher] = 1

    if len(publisher_counter) > 0:
        most_popular_publisher = max(publisher_counter, key=lambda x: publisher_counter[x])
        generated_data.append({"category": category, "publisher": most_popular_publisher})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

random.seed(2043)

question = "Which items from publisher {publisher} have the most also-viewed items?"
answer = "{item_title}"
generated_data = []

# 收集所有出版商
all_publishers = set()
for item_id in graph['item_nodes']:
    details = graph['item_nodes'][item_id]['features']['details']
    publisher = details.get('Publisher:', '') if details else ''
    if publisher != '':
        all_publishers.add(publisher)

publishers_list = list(all_publishers)
random.shuffle(publishers_list)

for publisher in publishers_list:
    publisher_items = []

    # 找到该出版商的所有商品
    for item_id in graph['item_nodes']:
        details = graph['item_nodes'][item_id]['features']['details']
        item_publisher = details.get('Publisher:', '') if details else ''
        if item_publisher == publisher:
            item_title = graph['item_nodes'][item_id]['features']['title']
            also_viewed_count = len(graph['item_nodes'][item_id]['neighbors']['also_viewed_item'])
            if item_title != '':
                publisher_items.append((item_title, also_viewed_count))

    if len(publisher_items) > 0:
        # 找到也观看数量最多的商品
        best_item = max(publisher_items, key=lambda x: x[1])
        if best_item[1] > 0:  # 至少有一些also-viewed
            generated_data.append({"publisher": publisher, "item_title": best_item[0]})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

random.seed(2044)

question = "What is the average number of also-bought items for magazines in category {category}?"
answer = "{average_count}"
generated_data = []

# 获取所有独特的类别
all_categories = set()
for item_id in graph['item_nodes']:
    categories = graph['item_nodes'][item_id]['features']['category']
    if categories:
        all_categories.update(categories)

categories_list = list(all_categories)
random.shuffle(categories_list)

for category in categories_list:
    also_bought_counts = []

    for item_id in graph['item_nodes']:
        item_categories = graph['item_nodes'][item_id]['features']['category']
        if category in item_categories:
            also_bought_count = len(graph['item_nodes'][item_id]['neighbors']['also_bought_item'])
            also_bought_counts.append(also_bought_count)

    if len(also_bought_counts) >= 3:
        avg_count = round(sum(also_bought_counts) / len(also_bought_counts), 2)
        generated_data.append({"category": category, "average_count": avg_count})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

random.seed(2045)

question = "Which brand has items spanning the most different categories?"
answer = "{brand_name}"
generated_data = []

brand_category_count = {}
brand_names = {}

for item_id in graph['item_nodes']:
    brand_ids = graph['item_nodes'][item_id]['neighbors']['brand']
    categories = graph['item_nodes'][item_id]['features']['category']

    if len(brand_ids) == 1 and len(categories) > 0:
        brand_id = brand_ids[0]
        if brand_id in graph['brand_nodes']:
            brand_name = graph['brand_nodes'][brand_id]['features']['name']
            if brand_name != '':
                brand_names[brand_id] = brand_name
                if brand_id not in brand_category_count:
                    brand_category_count[brand_id] = set()
                brand_category_count[brand_id].update(categories)

# 找到跨越最多类别的品牌
if brand_category_count:
    most_diverse_brand = max(brand_category_count, key=lambda x: len(brand_category_count[x]))
    if len(brand_category_count[most_diverse_brand]) > 1:
        generated_data.append({"brand_name": brand_names[most_diverse_brand]})

all_generated_data[(question, answer)] = generated_data

## Inductive reasoning (hard)
### Recommendation - What item should be recommended to the user based on his history: {item_titles}?

In [56]:
import json
import gzip
from collections import defaultdict
from tqdm import tqdm

# Function to load reviews
def load_reviews(file_path):
    user_history = defaultdict(list)
    with open(file_path, 'r') as f:
        readin = f.readlines()
        for line in tqdm(readin):
            tmp = line.strip().split(',')
            user_history[tmp[0]].append((tmp[-1], tmp[1]))
    return user_history

# Load and preprocess reviews
# 创建模拟用户历史数据
user_history = defaultdict(list)
item_ids = list(graph['item_nodes'].keys())

# 生成一些模拟用户和购买历史
random.seed(2046)
for user_id in range(1000):  # 创建1000个模拟用户
    user_key = f"user_{user_id}"
    # 每个用户随机购买3-10个商品
    num_purchases = random.randint(3, 10)
    selected_items = random.sample(item_ids, min(num_purchases, len(item_ids)))

    for i, item_id in enumerate(selected_items):
        # 模拟时间戳，每个商品间隔一些时间
        timestamp = f"2023-{random.randint(1,12):02d}-{random.randint(1,28):02d}"
        user_history[user_key].append((timestamp, item_id))

In [57]:
random.seed(2046)

question = "What next item should be recommended to the user based on his history: {item_titles}?"
answer = "{targe_item_title}"
generated_data = []

user_ids = list(user_history.keys())
random.shuffle(user_ids)

for user_id in user_ids:
    tmp_history = user_history[user_id]
    tmp_history.sort(key=lambda x: x[0])
    
    if len(tmp_history) < 2 or tmp_history[-1][-1] not in graph['item_nodes'] or graph['item_nodes'][tmp_history[-1][-1]]['features']['title'] == '':
        continue

    item_titles = [graph['item_nodes'][idd[-1]]['features']['title'] for idd in tmp_history[-8:-1] if idd[-1] in graph['item_nodes'] and graph['item_nodes'][idd[-1]]['features']['title'] != '']
    targe_item_title = graph['item_nodes'][tmp_history[-1][-1]]['features']['title']

    if targe_item_title != '' and len(item_titles) >= 5:
        generated_data.append({"item_titles": item_titles, "targe_item_title": targe_item_title})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

In [58]:
random.seed(2047)

question = "What is the network centrality score of item {item_title} based on its total connections?"
answer = "{centrality_score}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)

for item_id in item_ids:
    item_features = graph['item_nodes'][item_id]['features']
    item_title = item_features['title']

    if item_title == '':
        continue

    # 计算总连接数作为中心性分数
    neighbors = graph['item_nodes'][item_id]['neighbors']
    total_connections = (
            len(neighbors.get('also_viewed_item', [])) +
            len(neighbors.get('also_bought_item', [])) +
            len(neighbors.get('bought_together_item', [])) +
            len(neighbors.get('buy_after_viewing_item', []))
    )

    if total_connections > 0:
        generated_data.append({
            "item_title": item_title,
            "centrality_score": total_connections
        })

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

random.seed(2048)

question = "Which publisher has the highest average network connectivity for their magazines?"
answer = "{publisher}"
generated_data = []

publisher_connectivity = {}

for item_id in graph['item_nodes']:
    details = graph['item_nodes'][item_id]['features']['details']
    publisher = details.get('Publisher:', '') if details else ''

    if publisher != '':
        neighbors = graph['item_nodes'][item_id]['neighbors']
        total_connections = (
                len(neighbors.get('also_viewed_item', [])) +
                len(neighbors.get('also_bought_item', [])) +
                len(neighbors.get('bought_together_item', [])) +
                len(neighbors.get('buy_after_viewing_item', []))
        )

        if publisher not in publisher_connectivity:
            publisher_connectivity[publisher] = []
        publisher_connectivity[publisher].append(total_connections)

# 计算每个出版商的平均连接度
publisher_avg = {}
for publisher, connections in publisher_connectivity.items():
    if len(connections) >= 2:  # 至少2个杂志
        publisher_avg[publisher] = sum(connections) / len(connections)

if publisher_avg:
    best_publisher = max(publisher_avg, key=publisher_avg.get)
    generated_data.append({"publisher": best_publisher})

all_generated_data[(question, answer)] = generated_data

random.seed(2049)

question = "Identify magazines that serve as bridges between different categories based on their also-viewed patterns."
answer = "{bridge_magazines}"
generated_data = []

bridge_items = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)

for item_id in item_ids[:100]:  # 限制数量以提高效率
    item_features = graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    item_categories = set(item_features['category'])

    if item_title == '' or len(item_categories) == 0:
        continue

    also_viewed_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    connected_categories = set()

    for viewed_id in also_viewed_ids:
        if viewed_id in graph['item_nodes']:
            viewed_categories = graph['item_nodes'][viewed_id]['features']['category']
            connected_categories.update(viewed_categories)

    # 如果连接的类别比自身类别多，说明是桥梁
    unique_connected = connected_categories - item_categories
    if len(unique_connected) >= 3:  # 连接到至少3个不同类别
        bridge_items.append(item_title)

    if len(bridge_items) >= k:
        break

if bridge_items:
    generated_data.append({"bridge_magazines": ', '.join(bridge_items[:k])})

all_generated_data[(question, answer)] = generated_data

random.seed(2050)

question = "What is the clustering coefficient of the magazine network for category {category}?"
answer = "{clustering_coefficient}"
generated_data = []

# 获取所有独特的类别
all_categories = set()
for item_id in graph['item_nodes']:
    categories = graph['item_nodes'][item_id]['features']['category']
    if categories:
        all_categories.update(categories)

categories_list = list(all_categories)
random.shuffle(categories_list)

for category in categories_list[:k]:
    # 找到该类别的所有商品
    category_items = []
    for item_id in graph['item_nodes']:
        if category in graph['item_nodes'][item_id]['features']['category']:
            category_items.append(item_id)

    if len(category_items) < 3:
        continue

    # 计算聚类系数（简化版本）
    total_triangles = 0
    total_possible = 0

    for item_id in category_items[:20]:  # 限制计算量
        neighbors = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
        category_neighbors = [n for n in neighbors if n in category_items]

        if len(category_neighbors) >= 2:
            # 检查邻居之间的连接
            neighbor_connections = 0
            for i, neighbor1 in enumerate(category_neighbors):
                for neighbor2 in category_neighbors[i + 1:]:
                    if neighbor1 in graph['item_nodes'] and neighbor2 in graph['item_nodes']:
                        if neighbor2 in graph['item_nodes'][neighbor1]['neighbors']['also_viewed_item']:
                            neighbor_connections += 1

            possible_connections = len(category_neighbors) * (len(category_neighbors) - 1) // 2
            if possible_connections > 0:
                total_triangles += neighbor_connections
                total_possible += possible_connections

    if total_possible > 0:
        clustering_coeff = round(total_triangles / total_possible, 3)
        generated_data.append({
            "category": category,
            "clustering_coefficient": clustering_coeff
        })

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

### Retrieval - What is the exact matched/substitute/complement item given this query: {item_titles}?

In [59]:
import pandas as pd
df_examples = pd.read_parquet('/Users/yehaoran/Desktop/KGAgentEcno/Graph-CoT-main/data/processed_data/amazon/shopping_queries_dataset_examples.parquet')
df_products = pd.read_parquet('/Users/yehaoran/Desktop/KGAgentEcno/Graph-CoT-main/data/processed_data/amazon/shopping_queries_dataset_products.parquet')
df_examples_products = pd.merge(
    df_examples,
    df_products,
    how='left',
    left_on=['product_locale','product_id'],
    right_on=['product_locale', 'product_id']
)

In [60]:
prod_ids = set(list(graph['item_nodes'].keys()))

In [61]:
df_task_2 = df_examples_products[df_examples_products["large_version"] == 1]
df_task_2_train = df_task_2[df_task_2["split"] == "train"]
df_task_2_test = df_task_2[df_task_2["split"] == "test"]

In [62]:
# exact match
np.random.seed(2040)

question = "What is the exact matched item given this query: {query_text}?"
answer = "{targe_item_title}"
generated_data = []

## Exact match
df_em = df_task_2_test[df_task_2_test["esci_label"] == "E"]

# shuffle the DataFrame rows
df_em = df_em.sample(frac = 1)

# process
for _, row in df_em.iterrows():
    cnt = len(df_em[(df_em.query_id == row['query_id'])])

    if row['product_locale'] == 'us' and cnt == 1 and row['product_id'] in prod_ids:
        #generated_data.append({"query_text": row['query'], "targe_item_title": row['product_title']})
        generated_data.append({"query_text": row['query'], "targe_item_title": graph['item_nodes'][row['product_id']]['features']['title']})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

In [63]:
# substitutive
np.random.seed(2041)

question = "What is the substitutive item given this query: {query_text}?"
answer = "{targe_item_title}"
generated_data = []

## Exact match
df_em = df_task_2_test[df_task_2_test["esci_label"] == "S"]

# shuffle the DataFrame rows
df_em = df_em.sample(frac = 1)

# process
for _, row in df_em.iterrows():
    cnt = len(df_em[(df_em.query_id == row['query_id'])])

    if row['product_locale'] == 'us' and cnt == 1 and row['product_id'] in prod_ids:
        #generated_data.append({"query_text": row['query'], "targe_item_title": row['product_title']})
        generated_data.append({"query_text": row['query'], "targe_item_title": graph['item_nodes'][row['product_id']]['features']['title']})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

In [64]:
# complementary
np.random.seed(2042)

question = "What is the complementary item given this query: {query_text}?"
answer = "{targe_item_title}"
generated_data = []

## Exact match
df_em = df_task_2_test[df_task_2_test["esci_label"] == "C"]

# shuffle the DataFrame rows
df_em = df_em.sample(frac = 1)

# process
for _, row in df_em.iterrows():
    cnt = len(df_em[(df_em.query_id == row['query_id'])])

    if row['product_locale'] == 'us' and cnt == 1 and row['product_id'] in prod_ids:
        #generated_data.append({"query_text": row['query'], "targe_item_title": row['product_title']})
        generated_data.append({"query_text": row['query'], "targe_item_title": graph['item_nodes'][row['product_id']]['features']['title']})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

In [65]:
import json
import re

# 保存为pickle格式
pickle.dump(all_generated_data, open(os.path.join(data_dir, 'preprocess_samples.pkl'), 'wb'))

# 转换为新格式
formatted_data = []
qid_counter = 0

for (question_template, answer_template), data_list in all_generated_data.items():
    for data_item in data_list:
        # 替换问题模板中的占位符
        question_text = question_template
        for key, value in data_item.items():
            placeholder = "{" + key + "}"
            if placeholder in question_text:
                question_text = question_text.replace(placeholder, str(value))

        # 替换答案模板中的占位符
        answer_text = answer_template
        for key, value in data_item.items():
            placeholder = "{" + key + "}"
            if placeholder in answer_text:
                answer_text = answer_text.replace(placeholder, str(value))

        # 添加格式化后的数据
        formatted_data.append({
            "qid": str(qid_counter),
            "question": question_text,
            "answer": answer_text
        })
        qid_counter += 1

# 保存为JSONL格式（每行一个JSON对象）
with open(os.path.join(data_dir, 'new_data.json'), 'w', encoding='utf-8') as f:
    for item in formatted_data:
        f.write(json.dumps(item, ensure_ascii=False) + '\n')

# 或者保存为JSON数组格式
with open(os.path.join(data_dir, 'new_data.json'), 'w', encoding='utf-8') as f:
    json.dump(formatted_data, f, ensure_ascii=False, indent=2)

print(f"Generated {len(all_generated_data)} question types")
print(f"Files saved to {data_dir}")

Generated 31 question types
Files saved to /Users/yehaoran/Desktop/KGAgentEcno/Graph-CoT-main/data/processed_data/amazon


In [66]:
print(len(all_generated_data))

31
