In [1]:
import os
import pickle
import random
from tqdm import tqdm
import numpy as np


### Pay attention to the reproducibility !!

In [2]:
data_dir=f"/shared/data3/bowenj4/llm-graph-plugin/data/processed_data/amazon"

In [3]:
# read processed graph
import json 
graph = json.load(open(os.path.join(data_dir, 'graph.json')))
print(graph.keys())

dict_keys(['item_nodes', 'brand_nodes'])


In [6]:
all_generated_data = {} # key: triple (question (str), answer (str)), value: generated data (List)
k = 10

### Design questions (one type of question in one cell)

In [7]:
# 1-hop reasoning (easy)
# What is the brand of item xxx?
# What is the price of item xxx?
# What is the category of item xxx?

random.seed(2023)
item_ids = list(graph['item_nodes'].keys()) # 9430088

question = "What is the brand of item {item_title}?"
answer = "{brand_name}"
generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    brand_ids = graph['item_nodes'][item_id]['neighbors']['brand']

    if len(brand_ids) != 1:
        continue

    brand_names = [graph['brand_nodes'][brand_id]['features']['name'] for brand_id in brand_ids]
    if len(brand_names)>0 and item_title!='':
        generated_data.append({"item_title":item_title, "brand_name": ', '.join(brand_names)})
    if len(generated_data)==k:
        break
all_generated_data[(question, answer)] = generated_data

In [8]:
random.seed(2024)
question = "What is the category of item {item_title}?"
answer = "{category}"

item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    category = graph['item_nodes'][item_id]['features']['category']

    if len(category) != 1:
        continue

    if item_title!='':
        generated_data.append({"item_title":item_title, "category":', '.join(category)})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

In [9]:
random.seed(2025)

question = "What is the price of item {item_title}?"
answer = "{price}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    price = graph['item_nodes'][item_id]['features']['price']
    
    if price!='' and item_title!='':
        generated_data.append({"item_title":item_title, "price":price})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

### Degree-based reasoning (easy)

In [10]:
##### How many “co-viewed” items does item xxx have?
##### How many “co-purchased” items does item xxx have? # TODO ambiguous question?
##### How many items are in brand xxx?

random.seed(2026)

question = "How many co-viewed items does item {item_title} have?"
answer = "{num}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    related_item_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    
    if item_title!='':
        generated_data.append({"item_title":item_title, "num": len(related_item_ids)})
    if len(generated_data)==k:
        break
all_generated_data[(question, answer)] = generated_data

In [11]:
random.seed(2027)

question = "How many bought-together items does item {item_title} have?"
answer = "{num}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    related_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    if item_title!='':
        generated_data.append({"item_title":item_title, "num": len(related_item_ids)})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

In [12]:
random.seed(2028)

question = "How many buy-after-viewing items does item {item_title} have?"
answer = "{num}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    related_item_ids = graph['item_nodes'][item_id]['neighbors']['buy_after_viewing_item']
    if item_title!='':
        generated_data.append({"item_title":item_title, "num": len(related_item_ids)})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

In [13]:
random.seed(2029)

question = "How many also-bought items does item {item_title} have?"
answer = "{num}"
item_ids = list(graph['item_nodes'].keys())

generated_data = []
random.shuffle(item_ids)
for item_id in item_ids:
    item_title = graph['item_nodes'][item_id]['features']['title']
    related_item_ids = graph['item_nodes'][item_id]['neighbors']['also_bought_item']
    if item_title!='':
        generated_data.append({"item_title":item_title, "num": len(related_item_ids)})
    if len(generated_data) == k:
        break
    
all_generated_data[(question, answer)] = generated_data

In [14]:
random.seed(2030)

question = "How many items are in brand {brand_name}?"
answer = "{num}"
generated_data = []

brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for brand_id in brand_ids:
    brand_name = graph['brand_nodes'][brand_id]['features']['name']
    within_item_ids = graph['brand_nodes'][brand_id]['neighbors']['item']
    if brand_name!='':
        generated_data.append({"brand_name":brand_name, "num": len(within_item_ids)})
    if len(generated_data)==k:
        break

all_generated_data[(question, answer)] = generated_data

### Multi-hop reasoning (medium)

In [15]:
random.seed(2031)

question = "Find the items which are in the same brand and same category as item {item_title}."
answer = "{item_title_neighbour}"

generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)
brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for item_id in item_ids:
    item_features = graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    if item_title == '':
        continue
    brand_ids = graph['item_nodes'][item_id]['neighbors']['brand']
    if len(brand_ids) != 1:
        continue

    brand_id = brand_ids[0]  # search for same brand, just use the first brand
    within_item_ids = graph['brand_nodes'][brand_id]['neighbors']['item']
    result_list = []
    for within_item_id in within_item_ids: 
        if within_item_id==item_id:
            continue
        neighbor_features = graph['item_nodes'][within_item_id]['features']
        neighbor_categories = neighbor_features['category']
        if len(neighbor_categories)==0:
            continue
        neighbor_category = neighbor_categories[0]  # search for same category, just use the first category
        if neighbor_category in item_features['category']: 
            result_list.append(neighbor_features['title'])
            #generated_data.append({"item_title":item_title, "item_title_neighbour":neighbor_features['title']})
            #break

    if len(result_list) < 20 and len(result_list) > 0:
        generated_data.append({"item_title":item_title, "item_title_neighbour": ', '.join(result_list)})

    if len(generated_data)==k:
        break
all_generated_data[(question, answer)] = generated_data

In [16]:
random.seed(2032)

question = "Which item shares over {num} co-viewed items with item {item_title}?"
answer = "{item_title_neighbour}"

num = 4
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)

brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']

    if item_title=='':
        continue

    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    if len(coview_item_ids) < num:
        continue

    res = []
    for search_item_id in tqdm(item_ids): 
        if (search_item_id==item_id) or (search_item_id not in graph['item_nodes']):
            continue
        neighbor_coview_item_ids= graph['item_nodes'][search_item_id]['neighbors']['also_viewed_item']
        if len(neighbor_coview_item_ids)<num:
            continue   
        coview_item_ids_set = set(coview_item_ids)
        neighbor_coview_item_ids_set = set(neighbor_coview_item_ids)
        if len(coview_item_ids_set.intersection(neighbor_coview_item_ids_set))>=num:
            neighbor_features = graph['item_nodes'][search_item_id]['features']
            res.append(neighbor_features['title'])
        if len(res) > 30:
            break

    if len(res) < 20 and len(res) > 0:
        generated_data.append({"num":num, "item_title":item_title, "item_title_neighbour": ', '.join(res)})

    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

  3%|▎         | 283248/9430088 [00:01<00:51, 177475.41it/s]
 72%|███████▏  | 6822990/9430088 [00:24<00:09, 278025.97it/s]
 14%|█▍        | 1328009/9430088 [00:04<00:25, 322740.77it/s]
 91%|█████████ | 8581565/9430088 [00:25<00:02, 330355.67it/s]
  4%|▍         | 413145/9430088 [00:01<00:29, 301051.18it/s]
  4%|▍         | 420765/9430088 [00:01<00:28, 319062.84it/s]
 41%|████▏     | 3908380/9430088 [00:12<00:17, 321275.80it/s]
 36%|███▌      | 3409239/9430088 [00:10<00:18, 318776.58it/s]
  9%|▊         | 804030/9430088 [00:02<00:27, 319014.98it/s]
100%|██████████| 9430088/9430088 [00:26<00:00, 360472.84it/s]
 52%|█████▏    | 4887762/9430088 [00:13<00:12, 349841.18it/s]
 51%|█████▏    | 4850786/9430088 [00:12<00:12, 378639.81it/s]
  5%|▍         | 430062/9430088 [00:01<00:28, 315121.03it/s]
  5%|▌         | 498277/9430088 [00:01<00:25, 344238.43it/s]
 28%|██▊       | 2645891/9430088 [00:08<00:20, 327886.52it/s]
  8%|▊         | 759418/9430088 [00:02<00:26, 322402.37it/s]
 52%|█████▏    

In [17]:
random.seed(2033)

question = "Which item shares over {num} bought-together items with item {item_title}?"
answer = "{item_title_neighbour}"

num = 4
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)

brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    if len(coview_item_ids) < num:
        continue
    
    res = []
    for search_item_id in tqdm(item_ids): 
        if (search_item_id==item_id) or (search_item_id not in graph['item_nodes']):
            continue
        neighbor_coview_item_ids= graph['item_nodes'][search_item_id]['neighbors']['bought_together_item']
        if len(neighbor_coview_item_ids)<num:
            continue   
        coview_item_ids_set = set(coview_item_ids)
        neighbor_coview_item_ids_set = set(neighbor_coview_item_ids)
        if len(coview_item_ids_set.intersection(neighbor_coview_item_ids_set))>=num:
            neighbor_features = graph['item_nodes'][search_item_id]['features']
            res.append(neighbor_features['title'])
        if len(res) > 30:
            break

    if len(res) < 20 and len(res) > 0:
        generated_data.append({"num":num, "item_title":item_title, "item_title_neighbour": ', '.join(res)})
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data


100%|██████████| 9430088/9430088 [00:18<00:00, 522886.22it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 588777.92it/s]
100%|██████████| 9430088/9430088 [00:15<00:00, 597087.29it/s]
100%|██████████| 9430088/9430088 [00:15<00:00, 589662.57it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 586618.86it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 586532.46it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 577400.70it/s]
100%|██████████| 9430088/9430088 [00:15<00:00, 598090.23it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 584304.69it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 581786.26it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 585713.70it/s]
100%|██████████| 9430088/9430088 [00:15<00:00, 593577.23it/s]
100%|██████████| 9430088/9430088 [00:15<00:00, 601773.54it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 585243.25it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 583035.11it/s]
100%|██████████| 9430088/9430088 [00:16<00:00, 584807.47it/s]
100%|███

In [18]:
random.seed(2034)

question = "How many items have the same bought-together items with item {item_title}?"
answer = "{num}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)

brand_ids = list(graph['brand_nodes'].keys()) # 110796
random.shuffle(brand_ids)

for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']

    if item_title=='':
        continue

    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    
    num_shared = 0
    for search_item_id in item_ids: 
        if (search_item_id==item_id) or (search_item_id not in graph['item_nodes']):
            continue
        neighbor_coview_item_ids= graph['item_nodes'][search_item_id]['neighbors']['bought_together_item']

        coview_item_ids_set = set(coview_item_ids)
        neighbor_coview_item_ids_set = set(neighbor_coview_item_ids)

        if coview_item_ids_set==neighbor_coview_item_ids_set:
            num_shared+=1
    
    if num_shared>0 and num_shared<100:
        generated_data.append({"num":num_shared, "item_title":item_title})
    
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

In [20]:
random.seed(2035)

question = "What is the average price of the bought-together items with {item_title}?"

answer = "{average_price}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)


for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    cobuy_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    if len(cobuy_item_ids)==0:
        continue
    
    all_price=[]
    for search_item_id in tqdm(cobuy_item_ids): 
        if search_item_id not in graph['item_nodes']:
            continue
        price = graph['item_nodes'][search_item_id]['features']['price']
        if price!='':
            all_price.append(price)
    
    if len(all_price)>0:
        generated_data.append({"item_title":item_title, "average_price": round(sum(all_price)/len(all_price),2) })
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

100%|██████████| 1/1 [00:00<00:00, 14614.30it/s]
100%|██████████| 1/1 [00:00<00:00, 19784.45it/s]
100%|██████████| 2/2 [00:00<00:00, 30174.85it/s]
100%|██████████| 1/1 [00:00<00:00, 14266.34it/s]
100%|██████████| 1/1 [00:00<00:00, 15768.06it/s]
100%|██████████| 2/2 [00:00<00:00, 31775.03it/s]
100%|██████████| 2/2 [00:00<00:00, 27776.85it/s]
100%|██████████| 1/1 [00:00<00:00, 17403.75it/s]
100%|██████████| 1/1 [00:00<00:00, 17403.75it/s]
100%|██████████| 2/2 [00:00<00:00, 33420.75it/s]
100%|██████████| 2/2 [00:00<00:00, 28630.06it/s]
100%|██████████| 1/1 [00:00<00:00, 18157.16it/s]
100%|██████████| 2/2 [00:00<00:00, 34807.50it/s]
100%|██████████| 2/2 [00:00<00:00, 29641.72it/s]
100%|██████████| 1/1 [00:00<00:00, 10922.67it/s]
100%|██████████| 1/1 [00:00<00:00, 11748.75it/s]


In [21]:
random.seed(2036)

question = "What is the average price of the co-viewed items with {item_title}?"

answer = "{average_price}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)


for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    if len(coview_item_ids)==0:
        continue
    
    all_price=[]
    for search_item_id in tqdm(coview_item_ids): 
        if search_item_id not in graph['item_nodes']:
            continue
        price = graph['item_nodes'][search_item_id]['features']['price']
        if price!='':
            all_price.append(price)
    
    if len(all_price)>0:
        generated_data.append({"item_title":item_title, "average_price": round(sum(all_price)/len(all_price),2) })
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

100%|██████████| 3/3 [00:00<00:00, 39945.75it/s]
100%|██████████| 2/2 [00:00<00:00, 24600.02it/s]
100%|██████████| 16/16 [00:00<00:00, 110740.70it/s]
100%|██████████| 2/2 [00:00<00:00, 31300.78it/s]
100%|██████████| 182/182 [00:00<00:00, 171349.79it/s]
100%|██████████| 6/6 [00:00<00:00, 60787.01it/s]
100%|██████████| 1/1 [00:00<00:00, 11814.94it/s]
100%|██████████| 1/1 [00:00<00:00, 17050.02it/s]
100%|██████████| 124/124 [00:00<00:00, 235443.05it/s]
100%|██████████| 1/1 [00:00<00:00, 16912.52it/s]
100%|██████████| 31/31 [00:00<00:00, 167988.92it/s]


In [22]:
random.seed(2037)

question = "What is the most popular category name of the bought-together items with {item_title}?"

answer = "{category}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)


for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    cobuy_item_ids = graph['item_nodes'][item_id]['neighbors']['bought_together_item']
    if len(cobuy_item_ids)==0:
        continue
    
    category_counter={}
    for search_item_id in tqdm(cobuy_item_ids): 
        if search_item_id not in graph['item_nodes']:
            continue
        category = graph['item_nodes'][search_item_id]['features']['category']
        
        if len(category)!=1: # a list with a string
            continue
    
        for cate in category:
            if cate in category_counter:
                category_counter[cate]+=1
            else:
                category_counter[cate]=1
        
    if len(category_counter)>0:
        most_popular_category= max(category_counter, key= lambda x: category_counter[x]) 
        generated_data.append({"item_title":item_title, "category": most_popular_category })
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

100%|██████████| 1/1 [00:00<00:00, 16131.94it/s]
100%|██████████| 2/2 [00:00<00:00, 33156.55it/s]
100%|██████████| 1/1 [00:00<00:00, 17772.47it/s]
100%|██████████| 2/2 [00:00<00:00, 21345.06it/s]
100%|██████████| 2/2 [00:00<00:00, 21399.51it/s]
100%|██████████| 2/2 [00:00<00:00, 22733.36it/s]
100%|██████████| 1/1 [00:00<00:00, 12087.33it/s]
100%|██████████| 2/2 [00:00<00:00, 33825.03it/s]
100%|██████████| 1/1 [00:00<00:00, 16513.01it/s]
100%|██████████| 1/1 [00:00<00:00, 16448.25it/s]
100%|██████████| 2/2 [00:00<00:00, 25653.24it/s]
100%|██████████| 2/2 [00:00<00:00, 27594.11it/s]
100%|██████████| 1/1 [00:00<00:00, 15534.46it/s]


In [23]:
random.seed(2038)

question = "What is the most popular category name of the co-viewed items with {item_title}?"

answer = "{category}"
generated_data = []

item_ids = list(graph['item_nodes'].keys())
random.shuffle(item_ids)


for item_id in item_ids:
    item_features= graph['item_nodes'][item_id]['features']
    item_title = item_features['title']
    
    if item_title=='':
        continue
    
    coview_item_ids = graph['item_nodes'][item_id]['neighbors']['also_viewed_item']
    if len(coview_item_ids)==0:
        continue
    
    category_counter={}
    for search_item_id in tqdm(coview_item_ids): 
        if search_item_id not in graph['item_nodes']:
            continue
        category = graph['item_nodes'][search_item_id]['features']['category']
        
        if len(category)!=1: # a list with a string
            continue
    
        for cate in category:
            if cate in category_counter:
                category_counter[cate]+=1
            else:
                category_counter[cate]=1
        

    if len(category_counter)>0:
        most_popular_category= max(category_counter, key= lambda x: category_counter[x]) 
        generated_data.append({"item_title":item_title, "category": most_popular_category })
  
    if len(generated_data) == k:
        break
all_generated_data[(question, answer)] = generated_data

100%|██████████| 64/64 [00:00<00:00, 150976.07it/s]
100%|██████████| 22/22 [00:00<00:00, 138758.93it/s]
100%|██████████| 2/2 [00:00<00:00, 34239.22it/s]
100%|██████████| 10/10 [00:00<00:00, 68422.58it/s]
100%|██████████| 3/3 [00:00<00:00, 30690.03it/s]
100%|██████████| 16/16 [00:00<00:00, 83991.07it/s]
100%|██████████| 15/15 [00:00<00:00, 87139.28it/s]
100%|██████████| 1/1 [00:00<00:00, 15420.24it/s]
100%|██████████| 37/37 [00:00<00:00, 113942.18it/s]
100%|██████████| 14/14 [00:00<00:00, 119593.19it/s]


## Inductive reasoning (hard)
### Recommendation - What item should be recommended to the user based on his history: {item_titles}?

In [24]:
import json
import gzip
from collections import defaultdict
from tqdm import tqdm

# Function to load reviews
def load_reviews(file_path):
    user_history = defaultdict(list)
    with open(file_path, 'r') as f:
        readin = f.readlines()
        for line in tqdm(readin):
            tmp = line.strip().split(',')
            user_history[tmp[0]].append((tmp[-1], tmp[1]))
    return user_history

# Load and preprocess reviews
user_history = load_reviews('/shared/data3/bowenj4/llm-graph-plugin/data/raw_data/amazon/item_dedup.csv')

100%|██████████| 82677131/82677131 [04:15<00:00, 323037.71it/s]


In [25]:
random.seed(2039)

question = "What next item should be recommended to the user based on his history: {item_titles}?"
answer = "{targe_item_title}"
generated_data = []

user_ids = list(user_history.keys())
random.shuffle(user_ids)

for user_id in user_ids:
    tmp_history = user_history[user_id]
    tmp_history.sort(key=lambda x: x[0])
    
    if len(tmp_history) < 2 or tmp_history[-1][-1] not in graph['item_nodes'] or graph['item_nodes'][tmp_history[-1][-1]]['features']['title'] == '':
        continue

    item_titles = [graph['item_nodes'][idd[-1]]['features']['title'] for idd in tmp_history[-8:-1] if idd[-1] in graph['item_nodes'] and graph['item_nodes'][idd[-1]]['features']['title'] != '']
    targe_item_title = graph['item_nodes'][tmp_history[-1][-1]]['features']['title']

    if targe_item_title != '' and len(item_titles) >= 5:
        generated_data.append({"item_titles": item_titles, "targe_item_title": targe_item_title})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

### Retrieval - What is the exact matched/substitute/complement item given this query: {item_titles}?

In [26]:
import pandas as pd
df_examples = pd.read_parquet('/shared/data3/bowenj4/llm-graph-plugin/data/raw_data/amazon/shopping_queries_dataset/shopping_queries_dataset_examples.parquet')
df_products = pd.read_parquet('/shared/data3/bowenj4/llm-graph-plugin/data/raw_data/amazon/shopping_queries_dataset/shopping_queries_dataset_products.parquet')
df_examples_products = pd.merge(
    df_examples,
    df_products,
    how='left',
    left_on=['product_locale','product_id'],
    right_on=['product_locale', 'product_id']
)

In [27]:
prod_ids = set(list(graph['item_nodes'].keys()))

In [28]:
df_task_2 = df_examples_products[df_examples_products["large_version"] == 1]
df_task_2_train = df_task_2[df_task_2["split"] == "train"]
df_task_2_test = df_task_2[df_task_2["split"] == "test"]

In [39]:
# exact match
np.random.seed(2040)

question = "What is the exact matched item given this query: {query_text}?"
answer = "{targe_item_title}"
generated_data = []

## Exact match
df_em = df_task_2_test[df_task_2_test["esci_label"] == "E"]

# shuffle the DataFrame rows
df_em = df_em.sample(frac = 1)

# process
for _, row in df_em.iterrows():
    cnt = len(df_em[(df_em.query_id == row['query_id'])])

    if row['product_locale'] == 'us' and cnt == 1 and row['product_id'] in prod_ids:
        #generated_data.append({"query_text": row['query'], "targe_item_title": row['product_title']})
        generated_data.append({"query_text": row['query'], "targe_item_title": graph['item_nodes'][row['product_id']]['features']['title']})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

In [40]:
# substitutive
np.random.seed(2041)

question = "What is the substitutive item given this query: {query_text}?"
answer = "{targe_item_title}"
generated_data = []

## Exact match
df_em = df_task_2_test[df_task_2_test["esci_label"] == "S"]

# shuffle the DataFrame rows
df_em = df_em.sample(frac = 1)

# process
for _, row in df_em.iterrows():
    cnt = len(df_em[(df_em.query_id == row['query_id'])])

    if row['product_locale'] == 'us' and cnt == 1 and row['product_id'] in prod_ids:
        #generated_data.append({"query_text": row['query'], "targe_item_title": row['product_title']})
        generated_data.append({"query_text": row['query'], "targe_item_title": graph['item_nodes'][row['product_id']]['features']['title']})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

In [41]:
# complementary
np.random.seed(2042)

question = "What is the complementary item given this query: {query_text}?"
answer = "{targe_item_title}"
generated_data = []

## Exact match
df_em = df_task_2_test[df_task_2_test["esci_label"] == "C"]

# shuffle the DataFrame rows
df_em = df_em.sample(frac = 1)

# process
for _, row in df_em.iterrows():
    cnt = len(df_em[(df_em.query_id == row['query_id'])])

    if row['product_locale'] == 'us' and cnt == 1 and row['product_id'] in prod_ids:
        #generated_data.append({"query_text": row['query'], "targe_item_title": row['product_title']})
        generated_data.append({"query_text": row['query'], "targe_item_title": graph['item_nodes'][row['product_id']]['features']['title']})

    if len(generated_data) == k:
        break

all_generated_data[(question, answer)] = generated_data

In [43]:
import json
pickle.dump(all_generated_data, open(os.path.join(f'preprocess_samples.pkl'), 'wb'))

In [42]:
print(len(all_generated_data))

20
