In [1]:
import os
import random
import json
import pickle
from copy import deepcopy
from tqdm import tqdm
from collections import defaultdict

from transformers import BertTokenizerFast

In [2]:
from google.cloud import bigquery  # python client for Google BigQuery
import pandas as pd  # python library for [tabular] data analysis and manipulation
import numpy as np  # python's mathematical library for n-dimensional arrays

## download data with Google Cloud SQL

In [3]:
os.environ["GOOGLE_APPLICATION_CREDENTIALS"]="google_cloud.json"

In [4]:
# make connection
client = bigquery.Client()  # create BigQuery client to communicate with the database

# create a reference to stackoverflow dataset
dataset_ref = client.dataset("stackoverflow", project="bigquery-public-data")

# get the dataset resource
dataset = client.get_dataset(dataset_ref)

In [5]:
# all the data
tables = list(client.list_tables(dataset))  # list all the tables in the dataset

#### Only focus on comments table

In [43]:
# Construct a reference to the "posts_answers" table
comments_table_ref = dataset_ref.table("comments")

# API request - fetch the table
comments_table = client.get_table(comments_table_ref)

# Preview the first five lines of the "posts_answers" table
#data_df = client.list_rows(comments_table, max_results=2000000).to_dataframe()
data_df = client.list_rows(comments_table).to_dataframe()

In [44]:
# save to json
#json_result = data_df.to_json(orient="records")
#with open('stackoverflow/data.json','w') as f:
#    json.dump(json_result, f, indent = 6)
    
data_df.to_json('stackoverflow/data.json', orient="records")

In [7]:
# scrape data with SQL (no permission)
######################### (no permission) #########################
query = """
        SELECT text, post_id, user_id
        FROM `bigquery-public-data.stackoverflow.comments` 
        """

In [8]:
# Set up the query (cancel the query if it would use too much of 
# your quota, with the limit set to 1 GB)
######################### (no permission) #########################

safe_config = bigquery.QueryJobConfig(maximum_bytes_billed=10**10)
query_job = client.query(query, job_config=safe_config) # Your code goes here

# API request - run the query, and return a pandas DataFrame
comments_results = query_job.to_dataframe() # Your code goes here

# Preview results
print(comments_results.head())

Forbidden: 403 POST https://bigquery.googleapis.com/bigquery/v2/projects/proud-sweep-344117/jobs?prettyPrint=false: Access Denied: Project proud-sweep-344117: User does not have bigquery.jobs.create permission in project proud-sweep-344117.

Location: None
Job ID: 623dbeb3-7592-40e8-b33d-d3352028ad9f


# read review data

In [3]:
random.seed(0)

In [4]:
dataset = 'stackoverflow'
data_name = 'data'
output_dir='xxx'

In [5]:
# read raw data
with open(f'{dataset}/{data_name}.json') as f:
    data = json.load(f)
random.shuffle(data)

In [6]:
len(data)

83160601

In [7]:
data[7]

{'id': 69179198,
 'text': 'Please also take some effort to format your post and your code properly. I applied some basic indentation to your code.',
 'creation_date': 1480978297137,
 'post_id': 40984235,
 'user_id': 721644.0,
 'user_display_name': None,
 'score': 1}

In [8]:
# text processing function
def text_process(text):
    p_text = ' '.join(text.split('\r\n'))
    p_text = ' '.join(text.split('\n\r'))
    p_text = ' '.join(text.split('\n'))
    p_text = ' '.join(p_text.split('\t'))
    p_text = ' '.join(p_text.split('\rm'))
    p_text = ' '.join(p_text.split('\r'))
    p_text = ''.join(p_text.split('$'))
    p_text = ''.join(p_text.split('*'))

    return p_text

In [9]:
## rate distribution

rate_dict = defaultdict(int)
all_rates = 0

for d in tqdm(data):
    rate_dict[d['score']] += 1
    all_rates += 1
    
print(rate_dict)
print(all_rates)

100%|██████████| 83160601/83160601 [02:25<00:00, 570566.48it/s]

defaultdict(<class 'int'>, {0: 69307730, 2: 2471553, 1: 9326709, 3: 945067, 6: 123687, 5: 205822, 4: 427859, 11: 21220, 19: 4742, 55: 439, 10: 25103, 31: 1702, 17: 6338, 24: 2818, 7: 82440, 72: 205, 23: 3086, 8: 51639, 15: 7846, 13: 13290, 48: 569, 12: 16424, 9: 36934, 26: 2385, 88: 116, 14: 10899, 32: 1457, 25: 2570, 18: 5265, 16: 6823, 28: 1872, 51: 510, 39: 857, 53: 450, 80: 142, 22: 3352, 49: 500, 70: 200, 63: 291, 40: 796, 35: 1119, 38: 944, 41: 872, 71: 242, 69: 217, 37: 1066, 30: 1602, 215: 9, 162: 28, 21: 3918, 76: 180, 27: 2257, 50: 483, 33: 1323, 57: 409, 29: 1886, 60: 298, 114: 65, 36: 1100, 52: 438, 127: 45, 47: 692, 159: 23, 42: 794, 101: 81, 146: 27, 163: 24, 373: 3, 34: 1277, 106: 64, 46: 576, 86: 116, 140: 42, 79: 161, 20: 4043, 372: 3, 289: 1, 73: 196, 58: 378, 431: 2, 111: 48, 85: 125, 296: 8, 64: 298, 56: 391, 44: 679, 54: 420, 82: 152, 43: 716, 45: 668, 102: 63, 183: 22, 649: 1, 105: 62, 1331: 1, 161: 26, 94: 90, 125: 45, 148: 31, 268: 8, 89: 121, 153: 24, 110: 58, 




In [10]:
## user/item statistics
### we see 5 score as positive edge(review), 1-4 as negative ones.
### user_pos_reviews/user_neg_reviews: key<-userID, value<-list(reviews)
### item_pos_reviews/item_neg_reviews: key<-productID, value<-list(reviews)
### user_reviews_dict_save/item_reviews_dict_save: key<-userID/productID, value<-list(tuple(reviews,p/n))

user_pos_reviews = defaultdict(list)
user_neg_reviews = defaultdict(list)
item_pos_reviews = defaultdict(list)
item_neg_reviews = defaultdict(list)
user_set = set()
item_set = set()

user_reviews_dict_save = defaultdict(dict)
item_reviews_dict_save = defaultdict(dict)

blank_review_cnt = 0

for d in tqdm(data):
    if 'text' not in d or d['text'] == '' or d['text'] == None:
        blank_review_cnt += 1
        continue
    
    text = text_process(d['text'])
    user_set.add(d['user_id'])
    item_set.add(d['post_id'])
    if d['score'] >=0 :
        user_pos_reviews[d['user_id']].append(text)
        item_pos_reviews[d['post_id']].append(text)
        
        if d['post_id'] not in user_reviews_dict_save[d['user_id']]:
            user_reviews_dict_save[d['user_id']][d['post_id']] = [text,d['creation_date']]
        else:
            user_reviews_dict_save[d['user_id']][d['post_id']][0] += text
            user_reviews_dict_save[d['user_id']][d['post_id']][1] = max(user_reviews_dict_save[d['user_id']][d['post_id']][1], d['creation_date'])

        if d['post_id'] not in item_reviews_dict_save[d['post_id']]:
            item_reviews_dict_save[d['post_id']][d['user_id']] = [text,d['creation_date']]
        else:
            item_reviews_dict_save[d['post_id']][d['user_id']][0] += text
            item_reviews_dict_save[d['post_id']][d['user_id']][1] = max(item_reviews_dict_save[d['post_id']][d['user_id']][1], d['creation_date'])            
    else:
        raise ValueError('Error!')
        
print(f'Number of blank review:{blank_review_cnt}')
print(f'Number of user:{len(user_set)}, Number of item:{len(item_set)}')
print(f'user_pos_reviews.len:{len(user_pos_reviews)},user_neg_reviews.len:{len(user_neg_reviews)}')
print(f'item_pos_reviews.len:{len(item_pos_reviews)},item_neg_reviews.len:{len(item_neg_reviews)}')
print(f'user.avg.review:{all_rates/len(user_set)}, item.avg.review:{all_rates/len(item_set)}')

100%|██████████| 83160601/83160601 [43:57<00:00, 31535.34it/s]  

Number of blank review:58
Number of user:3468581, Number of item:27738624
user_pos_reviews.len:3468581,user_neg_reviews.len:0
item_pos_reviews.len:27738624,item_neg_reviews.len:0
user.avg.review:23.97539541385944, item.avg.review:2.998007435408476





In [12]:
print(f'Number of blank review:{blank_review_cnt}')
print(f'Number of user:{len(user_set)}, Number of item:{len(item_set)}')
print(f'user_pos_reviews.len:{len(user_pos_reviews)},user_neg_reviews.len:{len(user_neg_reviews)}')
print(f'item_pos_reviews.len:{len(item_pos_reviews)},item_neg_reviews.len:{len(item_neg_reviews)}')
print(f'user.avg.review:{all_rates/len(user_set)}, item.avg.review:{all_rates/len(item_set)}')

Number of blank review:58
Number of user:3468581, Number of item:27738624
user_pos_reviews.len:3468581,user_neg_reviews.len:0
item_pos_reviews.len:27738624,item_neg_reviews.len:0
user.avg.review:23.97539541385944, item.avg.review:2.998007435408476


In [13]:
# transfer to the original format

user_reviews_dict = {}
item_reviews_dict = {}

for i in tqdm(user_reviews_dict_save):
    user_reviews_dict[i] = [(j, user_reviews_dict_save[i][j][0], user_reviews_dict_save[i][j][1]) for j in user_reviews_dict_save[i]]

for i in tqdm(item_reviews_dict_save):
    item_reviews_dict[i] = [(j, item_reviews_dict_save[i][j][0], item_reviews_dict_save[i][j][1]) for j in item_reviews_dict_save[i]]

100%|██████████| 3468581/3468581 [06:05<00:00, 9490.74it/s]  
100%|██████████| 27738624/27738624 [03:09<00:00, 146286.50it/s]


In [16]:
# dump
pickle.dump(user_reviews_dict,open(f'{output_dir}/user_reviews_dict.pkl','wb'))
pickle.dump(item_reviews_dict,open(f'{output_dir}/item_reviews_dict.pkl','wb'))

In [19]:
## filter

line = 10
none_cnt = 0

item_reviews_cnt_dict = {}
item_reviews_filtered_dict = {}
user_reviews_filtered_dict = defaultdict(list)

for i in tqdm(item_reviews_dict):
    ## add it into cnt dict
    if i in item_reviews_cnt_dict:
        raise ValueError('stop')
    item_reviews_cnt_dict[i] = len(item_reviews_dict[i])

    ## reorder it
    item_reviews_dict[i].sort(key=lambda x:x[2])
    
    ## statistics for line
    if len(item_reviews_dict[i]) >= line:
        item_reviews_filtered_dict[i] = deepcopy(item_reviews_dict[i])
        for r in item_reviews_filtered_dict[i]:
            if r[0] == None:
                none_cnt += 1
            else:
                # user_reviews_filtered_dict[int(r[0])].append((r[0],i,r[2],1))
                user_reviews_filtered_dict[int(r[0])].append((i,r[1],r[2]))

print(f'Number of above line user:{len(user_reviews_filtered_dict)}')
print(f'Number of above line item:{len(item_reviews_filtered_dict)}')
print(f'None cnt:{none_cnt}')

100%|██████████| 27738624/27738624 [01:30<00:00, 307698.91it/s]

Number of above line user:138816
Number of above line item:26027
None cnt:3172





In [20]:
# print filtered statistics

edge_cnt = 0
for u in tqdm(user_reviews_filtered_dict):
    edge_cnt += len(user_reviews_filtered_dict[u])

print(f'Average degree of filtered user:{edge_cnt / len(user_reviews_filtered_dict)}')
print(f'Average degree of filtered item:{edge_cnt / len(item_reviews_filtered_dict)}')

100%|██████████| 138816/138816 [00:00<00:00, 1360352.95it/s]

Average degree of filtered user:2.304446173351775
Average degree of filtered item:12.290851807738118





In [22]:
## split train/val/test as 7:1:2 or 8:1:1
### user_pos_reviews/user_neg_reviews: key<-userID, value<-list(reviews)
### item_pos_reviews/item_neg_reviews: key<-productID, value<-list(reviews)
### train_user_neighbor: key<-userID, value<-list(tuple(reviews,p/n))
### train_item_neighbor: key<-userID, value<-list(tuple(reviews,p/n))

random.seed(0)

train_tuples = []
val_tuples = []
test_tuples = []
train_user_set = set()
user_id2idx = {}
item_id2idx = {}
train_user_neighbor = defaultdict(list)
train_item_neighbor = defaultdict(list)

for iid in tqdm(item_reviews_filtered_dict):
    if iid not in item_id2idx:
        item_id2idx[iid] = len(item_id2idx)
    
    for i in range(int(len(item_reviews_filtered_dict[iid])*0.7)):
        train_tuples.append((iid,item_reviews_filtered_dict[iid][i]))
        train_user_set.add(item_reviews_filtered_dict[iid][i][0])

        # add to user_id2idx
        if item_reviews_filtered_dict[iid][i][0] not in user_id2idx:
            user_id2idx[item_reviews_filtered_dict[iid][i][0]] = len(user_id2idx)

        # add to train_user_neighbor/train_item_neighbor        
        train_item_neighbor[iid].append(item_reviews_filtered_dict[iid][i])
        train_user_neighbor[item_reviews_filtered_dict[iid][i][0]].append((iid,item_reviews_filtered_dict[iid][i][1],
                                                                            item_reviews_filtered_dict[iid][i][2]))
        
    for i in range(int(len(item_reviews_filtered_dict[iid])*0.7),int(len(item_reviews_filtered_dict[iid])*0.8)):
        val_tuples.append((iid,item_reviews_filtered_dict[iid][i]))

    for i in range(int(len(item_reviews_filtered_dict[iid])*0.8),len(item_reviews_filtered_dict[iid])):
        test_tuples.append((iid,item_reviews_filtered_dict[iid][i]))
        
print(f'Number of user appearing in train_set:{len(train_user_set)} or {len(user_id2idx)}')
print(f'Train/Val/Test size:{len(train_tuples)},{len(val_tuples)},{len(test_tuples)}')

100%|██████████| 26027/26027 [00:01<00:00, 17180.49it/s]

Number of user appearing in train_set:103295 or 103295
Train/Val/Test size:217822,31652,73592





In [23]:
print(f'Number of item appearing in train_set:{len(item_id2idx)}')

Number of item appearing in train_set:26027


In [24]:
train_tuples[0]

(40544215,
 (100297.0,
  'In those 7 years, you have posted 5 questions on Meta (that still survive today), plus another [20 on Meta.SE](https://meta.stackexchange.com/users/150595/jigar-joshi?tab=activity&sort=posts). Do you plan to become more active?',
  1478855118773))

In [39]:
# generate and save train file
## user pos neighbor: 2, user neg neighbor: 0
## item pos neighbor: 5, item neg neighbor: 0

upos = 2
ipos = 5
uneg = 0
ineg = 0

random.seed(0)

with open(f'{output_dir}/{dataset}/train.tsv','w') as fout:
    for d in tqdm(train_tuples):

        # prepare sample pool for user and item
        item_pos_pool = set(deepcopy(train_item_neighbor[d[0]]))
        user_pos_pool = set(deepcopy(train_user_neighbor[d[1][0]]))
        
        # remove train sample from neighbor file
        item_pos_pool.remove(d[1])
        user_pos_pool.remove((d[0],d[1][1],d[1][2]))
        
        user_pos_pool = list(user_pos_pool)
        item_pos_pool = list(item_pos_pool)
        random.shuffle(user_pos_pool)
        random.shuffle(item_pos_pool)
        
        user_neg_pool = []
        item_neg_pool = []
        
        # sample for user
        if len(user_pos_pool) >= upos:
            user_pos_samples = user_pos_pool[:upos]
        else:
            user_pos_samples = user_pos_pool + [(-1,'')] * (upos-len(user_pos_pool))
        
        if len(user_neg_pool) >= uneg:
            user_neg_samples = user_neg_pool[:uneg]
        else:
            user_neg_samples = user_neg_pool + [(-1,'')] * (uneg-len(user_neg_pool))
        
        # sample for item
        if len(item_pos_pool) >= ipos:
            item_pos_samples = item_pos_pool[:ipos]
        else:
            item_pos_samples = item_pos_pool + [(-1,'')] * (ipos-len(item_pos_pool))
        
        if len(item_neg_pool) >= ineg:
            item_neg_samples = item_neg_pool[:ineg]
        else:
            item_neg_samples = item_neg_pool + [(-1,'')] * (ineg-len(item_neg_pool))
        
        # prepare for writing file
        user_pos_text = '\t'.join([up[1] for up in user_pos_samples])
        user_pos_neighbor = '\t'.join([str(item_id2idx[up[0]]) if up[0] != -1 else str(-1) for up in user_pos_samples])
        user_neg_text = '\t'.join([un[1] for un in user_neg_samples])
        user_neg_neighbor = '\t'.join([str(item_id2idx[un[0]]) if un[0] != -1 else str(-1) for un in user_neg_samples])
        
        item_pos_text = '\t'.join([ip[1] for ip in item_pos_samples])
        item_pos_neighbor = '\t'.join([str(user_id2idx[ip[0]]) if ip[0] != -1 else str(-1) for ip in item_pos_samples])
        item_neg_text = '\t'.join([inn[1] for inn in item_neg_samples])
        item_neg_neighbor = '\t'.join([str(user_id2idx[inn[0]]) if inn[0] != -1 else str(-1) for inn in item_neg_samples])
        
        user_line = str(user_id2idx[d[1][0]]) + '\*\*' + user_pos_text + '\*\*' + user_neg_text + '\*\*' + user_pos_neighbor + '\*\*' + user_neg_neighbor
        item_line = str(item_id2idx[d[0]]) + '\*\*' + item_pos_text + '\*\*' + item_neg_text + '\*\*' + item_pos_neighbor + '\*\*' + item_neg_neighbor
        
        # fout.write(user_line+'\$\$'+item_line+'\$\$'+str(d[1][3])+'\n')
        fout.write(item_line+'\$\$'+user_line+'\$\$'+str(1)+'\n')

100%|██████████| 217822/217822 [01:01<00:00, 3521.10it/s]


In [41]:
# generate and save val file (make sure to delete items that are not in train set)

random.seed(0)

valid_dev_edges = 0

with open(f'{output_dir}/{dataset}/val.tsv','w') as fout:
    for d in tqdm(val_tuples):
        # if item not in train item set, continue
        if d[1][0] not in train_user_set:
            continue

        # counting
        valid_dev_edges += 1

        # prepare sample pool for user and item
        user_neg_pool = []
        item_neg_pool = []
        
        item_pos_pool = deepcopy(train_item_neighbor[d[0]])
        user_pos_pool = deepcopy(train_user_neighbor[d[1][0]])
        
        random.shuffle(user_pos_pool)
        random.shuffle(item_pos_pool)
        
        # sample for user
        if len(user_pos_pool) >= upos:
            user_pos_samples = user_pos_pool[:upos]
        else:
            user_pos_samples = user_pos_pool + [(-1,'')] * (upos-len(user_pos_pool))
        
        if len(user_neg_pool) >= uneg:
            user_neg_samples = user_neg_pool[:uneg]
        else:
            user_neg_samples = user_neg_pool + [(-1,'')] * (uneg-len(user_neg_pool))
        
        # sample for item
        if len(item_pos_pool) >= ipos:
            item_pos_samples = item_pos_pool[:ipos]
        else:
            item_pos_samples = item_pos_pool + [(-1,'')] * (ipos-len(item_pos_pool))
        
        if len(item_neg_pool) >= ineg:
            item_neg_samples = item_neg_pool[:ineg]
        else:
            item_neg_samples = item_neg_pool + [(-1,'')] * (ineg-len(item_neg_pool))
        
        # prepare for writing file
        user_pos_text = '\t'.join([up[1] for up in user_pos_samples])
        user_pos_neighbor = '\t'.join([str(item_id2idx[up[0]]) if up[0] != -1 else str(-1) for up in user_pos_samples])
        user_neg_text = '\t'.join([un[1] for un in user_neg_samples])
        user_neg_neighbor = '\t'.join([str(item_id2idx[un[0]]) if un[0] != -1 else str(-1) for un in user_neg_samples])
        
        item_pos_text = '\t'.join([ip[1] for ip in item_pos_samples])
        item_pos_neighbor = '\t'.join([str(user_id2idx[ip[0]]) if ip[0] != -1 else str(-1) for ip in item_pos_samples])
        item_neg_text = '\t'.join([inn[1] for inn in item_neg_samples])
        item_neg_neighbor = '\t'.join([str(user_id2idx[inn[0]]) if inn[0] != -1 else str(-1) for inn in item_neg_samples])
        
        user_line = str(user_id2idx[d[1][0]]) + '\*\*' + user_pos_text + '\*\*' + user_neg_text + '\*\*' + user_pos_neighbor + '\*\*' + user_neg_neighbor
        item_line = str(item_id2idx[d[0]]) + '\*\*' + item_pos_text + '\*\*' + item_neg_text + '\*\*' + item_pos_neighbor + '\*\*' + item_neg_neighbor
        
        #fout.write(user_line+'\$\$'+item_line+'\$\$'+str(d[1][3])+'\n')
        fout.write(item_line+'\$\$'+user_line+'\$\$'+str(1)+'\n')
        
print(f'Number of Valid Dev Edges:{valid_dev_edges} | Total:{len(val_tuples)}')

100%|██████████| 31652/31652 [00:07<00:00, 4451.23it/s]


Number of Valid Dev Edges:20246 | Total:31652


In [42]:
# generate and save test file (make sure to delete items that are not in train set)

random.seed(0)

valid_test_edges = 0

with open(f'{output_dir}/{dataset}/test.tsv','w') as fout:
    for d in tqdm(test_tuples):
        # if item not in train item set, continue
        if d[1][0] not in train_user_set:
            continue

        # counting
        valid_test_edges += 1

        # prepare sample pool for user and item
        user_neg_pool = []
        item_neg_pool = []
        
        item_pos_pool = deepcopy(train_item_neighbor[d[0]])
        user_pos_pool = deepcopy(train_user_neighbor[d[1][0]])
        
        random.shuffle(user_pos_pool)
        random.shuffle(item_pos_pool)
        
        # sample for user
        if len(user_pos_pool) >= upos:
            user_pos_samples = user_pos_pool[:upos]
        else:
            user_pos_samples = user_pos_pool + [(-1,'')] * (upos-len(user_pos_pool))
        
        if len(user_neg_pool) >= uneg:
            user_neg_samples = user_neg_pool[:uneg]
        else:
            user_neg_samples = user_neg_pool + [(-1,'')] * (uneg-len(user_neg_pool))
        
        # sample for item
        if len(item_pos_pool) >= ipos:
            item_pos_samples = item_pos_pool[:ipos]
        else:
            item_pos_samples = item_pos_pool + [(-1,'')] * (ipos-len(item_pos_pool))
        
        if len(item_neg_pool) >= ineg:
            item_neg_samples = item_neg_pool[:ineg]
        else:
            item_neg_samples = item_neg_pool + [(-1,'')] * (ineg-len(item_neg_pool))
        
        # prepare for writing file
        user_pos_text = '\t'.join([up[1] for up in user_pos_samples])
        user_pos_neighbor = '\t'.join([str(item_id2idx[up[0]]) if up[0] != -1 else str(-1) for up in user_pos_samples])
        user_neg_text = '\t'.join([un[1] for un in user_neg_samples])
        user_neg_neighbor = '\t'.join([str(item_id2idx[un[0]]) if un[0] != -1 else str(-1) for un in user_neg_samples])
        
        item_pos_text = '\t'.join([ip[1] for ip in item_pos_samples])
        item_pos_neighbor = '\t'.join([str(user_id2idx[ip[0]]) if ip[0] != -1 else str(-1) for ip in item_pos_samples])
        item_neg_text = '\t'.join([inn[1] for inn in item_neg_samples])
        item_neg_neighbor = '\t'.join([str(user_id2idx[inn[0]]) if inn[0] != -1 else str(-1) for inn in item_neg_samples])

        user_line = str(user_id2idx[d[1][0]]) + '\*\*' + user_pos_text + '\*\*' + user_neg_text + '\*\*' + user_pos_neighbor + '\*\*' + user_neg_neighbor
        item_line = str(item_id2idx[d[0]]) + '\*\*' + item_pos_text + '\*\*' + item_neg_text + '\*\*' + item_pos_neighbor + '\*\*' + item_neg_neighbor
        
        #fout.write(user_line+'\$\$'+item_line+'\$\$'+str(d[1][3])+'\n')
        fout.write(item_line+'\$\$'+user_line+'\$\$'+str(1)+'\n')
        
print(f'Number of Valid Test Edges:{valid_test_edges} | Total:{len(test_tuples)}')

100%|██████████| 73592/73592 [00:14<00:00, 5042.30it/s]


Number of Valid Test Edges:43589 | Total:73592


In [50]:
# save side files

pickle.dump([ipos,ineg,upos,uneg],open(f'{output_dir}/{dataset}/neighbor_sampling.pkl','wb'))
pickle.dump(user_id2idx,open(f'{output_dir}/{dataset}/user_id2idx.pkl','wb'))
pickle.dump(item_id2idx,open(f'{output_dir}/{dataset}/item_id2idx.pkl','wb'))
pickle.dump([len(item_id2idx),len(user_id2idx),2],open(f'{output_dir}/{dataset}/node_num.pkl','wb'))

In [44]:
# save neighbor file

pickle.dump(train_user_neighbor,open(f'{output_dir}/{dataset}/neighbor/train_user_pos_neighbor.pkl','wb'))
pickle.dump([],open(f'{output_dir}/{dataset}/neighbor/train_user_neg_neighbor.pkl','wb'))
pickle.dump(train_item_neighbor,open(f'{output_dir}/{dataset}/neighbor/train_item_pos_neighbor.pkl','wb'))
pickle.dump([],open(f'{output_dir}/{dataset}/neighbor/train_item_neg_neighbor.pkl','wb'))