In [None]:
import os
import json
import random
from tqdm import tqdm
import pickle
from copy import deepcopy

In [None]:
# Finally, we select shelves, author, publisher, language_code, format

## Read Data

In [None]:
# read book data
def read_json_lines(path):
    """Decode a file that stores one JSON document per line.

    The raw split files carry a ``.tsv`` extension, but each line is parsed
    with ``json.loads``.

    Args:
        path: location of the file to read.

    Returns:
        list with one decoded object per non-blank line, in file order.
    """
    with open(path) as f:
        # materialise the lines so tqdm can display a total
        lines = f.readlines()
    # whitespace-only lines (e.g. a trailing blank line) are skipped instead of
    # crashing json.loads
    return [json.loads(line) for line in tqdm(lines) if line.strip()]


train_data = read_json_lines('raw/train.tsv')
val_data = read_json_lines('raw/val.tsv')
test_data = read_json_lines('raw/test.tsv')

In [None]:
# read filtered shelves dict (1000~100000)
# NOTE(review): the 1000~100000 in the file name is presumably the shelf degree
# range that was kept -- confirm against the script that produced the pickle.
# pickle.load can execute arbitrary code: only load files you generated yourself.
with open('raw/shelves_degree_dict_1000_100000.pkl', 'rb') as f:
    # context manager guarantees the handle is closed after loading
    shelves_degree_dict = pickle.load(f)

## process training set

In [None]:
# Gather the id vocabularies (books, publishers, language codes, formats,
# shelves, authors) that occur in the training split.

bookid_set = set()
publisher_set = set()
language_code_set = set()
format_set = set()
shelves_set = set()
author_set = set()

for book in tqdm(train_data):
    bookid_set.add(book['book_id'])

    # scalar attributes: empty strings are not recorded
    for field, target in (('publisher', publisher_set),
                          ('language_code', language_code_set),
                          ('format', format_set)):
        value = book[field]
        if value != '':
            target.add(value)

    # shelves are kept only when they appear in the pre-filtered degree dict
    shelves_set.update(
        shelf['name']
        for shelf in book['popular_shelves']
        if shelf['name'] in shelves_degree_dict
    )

    # every author of a training book is kept
    author_set.update(author['author_id'] for author in book['authors'])

print(f'Book:{len(bookid_set)}, Publisher:{len(publisher_set)}, Language_code:{len(language_code_set)}, Format:{len(format_set)}, shelves:{len(shelves_set)}, author:{len(author_set)}')

In [None]:
# Restrict the similar_books list of every book in train/val/test to ids that
# exist in the training set (bookid_set). The resulting list order follows set
# iteration order.

for split in (train_data, val_data, test_data):
    for book in tqdm(split):
        book['similar_books'] = list(set(book['similar_books']) & bookid_set)

In [None]:
# book-book (similar_books) edge statistics over all three splits
cnt = 0

for split in (train_data, val_data, test_data):
    # records without a similar_books entry contribute nothing
    cnt += sum(len(d['similar_books']) for d in tqdm(split) if 'similar_books' in d)

print(f'paper-paper edge num:{cnt}')

In [None]:
# Keep, for every book in train/val/test, only the popular_shelves entries
# whose name is present in shelves_degree_dict (original order preserved).

for split in (train_data, val_data, test_data):
    for book in tqdm(split):
        book['popular_shelves'] = [
            shelf for shelf in book['popular_shelves']
            if shelf['name'] in shelves_degree_dict
        ]

In [None]:
# Map every shelf / author / publisher / language_code / format seen in the
# training split to a dense integer index (0 .. n-1).
# Indices follow set iteration order, which for string keys can differ between
# interpreter runs unless PYTHONHASHSEED is fixed.

shelves_id2idx_dict = {name: idx for idx, name in enumerate(tqdm(shelves_set))}

author_id2idx_dict = {name: idx for idx, name in enumerate(tqdm(author_set))}

publisher_id2idx_dict = {name: idx for idx, name in enumerate(tqdm(publisher_set))}

language_code_id2idx_dict = {name: idx for idx, name in enumerate(tqdm(language_code_set))}

format_id2idx_dict = {name: idx for idx, name in enumerate(tqdm(format_set))}

In [None]:
# Mean number of (filtered) similar books and shelves per training book.

# one pass over the training set, collecting both sizes per book
per_book_sizes = [
    (len(book['similar_books']), len(book['popular_shelves']))
    for book in tqdm(train_data)
]

similar_book_sum = sum(n_similar for n_similar, _ in per_book_sizes)
shelves_sum = sum(n_shelves for _, n_shelves in per_book_sizes)

print(f'Average similar book:{similar_book_sum / len(train_data)}, Average shelves:{shelves_sum / len(train_data)}')

## File Generation

In [None]:
# Number of neighbour slots emitted per book by mysampling.
# Output field sequence: book, shelves, author, publisher, language_code, format
# NOTE: the "generate neighbour dict" cell near the end of this notebook rebinds
# shelves_neighbour and author_neighbour to dicts; re-run this cell before
# calling mysampling again after that point.
book_neighbour, shelves_neighbour, author_neighbour = 5, 5, 2

In [None]:
# Index the training books by id. Values are the very same dict objects held in
# train_data, so in-place edits are visible through both.

train_book_dict = {}  # book_id -> book record

for book in tqdm(train_data):
    book_id = book['book_id']
    # duplicate ids in the training split are treated as a hard error
    assert book_id not in train_book_dict
    train_book_dict[book_id] = book

In [None]:
# Build one (query, key) training pair per book that still has similar books,
# and remove the sampled link from both endpoints so it cannot reappear as a
# neighbour later on.

train_pairs = []
simple_direction_cnt = 0  # sampled links that were not reciprocated by the key book
empty_cnt = 0             # books with no similar book left when visited

for book in tqdm(train_data):
    if not book['similar_books']:
        empty_cnt += 1
        continue

    query_id = book['book_id']

    # draw the key for this query
    sample_key = random.choice(book['similar_books'])

    # drop the link query -> key ...
    train_book_dict[query_id]['similar_books'].remove(sample_key)

    # ... and key -> query when it exists
    key_links = train_book_dict[sample_key]['similar_books']
    if query_id in key_links:
        key_links.remove(query_id)
    else:
        simple_direction_cnt += 1

    train_pairs.append((query_id, sample_key))

print(f'Num of train pairs:{len(train_pairs)}, simple_direction_cnt:{simple_direction_cnt}, empty_cnt:{empty_cnt}')

In [None]:
# sampling function for each book
# book_text, book_neighbour * 5, shelves * 5, author * 2, publisher, language_code, format

# book_neighbour = 5
# shelves_neighbour = 5
# author_neighbour = 2

def remove_next_line(text):
    """Flatten ``text`` onto a single tab-free line.

    Leading/trailing whitespace is stripped first; afterwards every newline
    and every tab is replaced by exactly one space (runs are not collapsed).
    """
    trimmed = text.strip()
    return trimmed.replace('\n', ' ').replace('\t', ' ')

def mysampling(book_info_dict):
    """Serialise one book plus its sampled neighbourhood as a tab-separated line.

    Field layout (in order): center book text, ``book_neighbour`` similar-book
    texts, ``shelves_neighbour`` shelf indices, ``author_neighbour`` author
    indices, publisher index, language_code index, format index.  Missing text
    slots are padded with ``''`` and missing index slots with ``'-1'``.

    Reads module-level state: ``book_neighbour``, ``shelves_neighbour``,
    ``author_neighbour``, ``train_book_dict``, the ``*_id2idx_dict`` mappings
    and ``author_set`` / ``publisher_set`` / ``language_code_set`` /
    ``format_set``.

    Args:
        book_info_dict: book record with keys 'title', 'description',
            'similar_books', 'popular_shelves', 'authors', 'publisher',
            'language_code' and 'format'.  Its 'similar_books' list may be
            shuffled in place, so pass a copy if the order matters.

    Returns:
        str: the fields joined by tab characters.
    """

    def book_text_of(info):
        # title and description concatenated, flattened to one line
        return remove_next_line(info['title'] + info['description'])

    def padded(values, size, filler):
        # right-pad values with filler up to size entries
        return values + [filler] * (size - len(values))

    # center book text
    result_list = [book_text_of(book_info_dict)]

    # sample book neighbour: random subset when enough are available,
    # otherwise all of them (unshuffled) plus padding
    similar_ids = book_info_dict['similar_books']
    if len(similar_ids) >= book_neighbour:
        random.shuffle(similar_ids)
        similar_ids = similar_ids[:book_neighbour]
    sampled_book_neighbours = [book_text_of(train_book_dict[bid]) for bid in similar_ids]
    result_list += padded(sampled_book_neighbours, book_neighbour, '')

    # sample shelves neighbour: highest count first.
    # Shelves without an index (never seen in the training split) are skipped,
    # mirroring how unknown authors/publishers are handled below.
    known_shelves = [ss for ss in book_info_dict['popular_shelves']
                     if ss['name'] in shelves_id2idx_dict]
    tmp_shelves_neighbours = sorted(known_shelves, key=lambda x: -int(x['count']))
    # fix: slice by shelves_neighbour (was book_neighbour), so the number of
    # shelf fields stays correct when the two settings differ
    sampled_shelves_neighbours = [str(shelves_id2idx_dict[ss['name']])
                                  for ss in tmp_shelves_neighbours[:shelves_neighbour]]
    result_list += padded(sampled_shelves_neighbours, shelves_neighbour, '-1')

    # sample author neighbour: only authors known from the training split
    book_authors = list({aa['author_id'] for aa in book_info_dict['authors']} & author_set)
    random.shuffle(book_authors)
    sampled_author_neighbours = [str(author_id2idx_dict[aa])
                                 for aa in book_authors[:author_neighbour]]
    result_list += padded(sampled_author_neighbours, author_neighbour, '-1')

    # publisher, language_code, format: index when known, '-1' otherwise
    for field, known_values, id2idx in (
        ('publisher', publisher_set, publisher_id2idx_dict),
        ('language_code', language_code_set, language_code_id2idx_dict),
        ('format', format_set, format_id2idx_dict),
    ):
        value = book_info_dict[field]
        if value != '' and value in known_values:
            result_list.append(str(id2idx[value]))
        else:
            result_list.append('-1')

    return '\t'.join(result_list)

# exp
#a = deepcopy(train_data[3])
#r = mysampling(a)
#print(train_data[3])
#print('**********************')
#print(r)

In [None]:
# train file generation
# Each output line is "<query fields>\$\$<key fields>", where both sides are
# produced by mysampling.

query_cnt = 0  # pairs dropped because the query side has an unexpected field count
key_cnt = 0    # pairs dropped because the key side has an unexpected field count

# Query/key separator: backslash, dollar, backslash, dollar. Written with
# explicit backslashes because '\$' is an invalid escape sequence; the
# resulting string is unchanged.
pair_separator = '\\$\\$'
expected_fields = 1 + book_neighbour + shelves_neighbour + author_neighbour + 3

# make sure the output directory exists before the (long) loop starts
os.makedirs('data/book', exist_ok=True)

with open('data/book/train.tsv', 'w') as fout:
    for query_id, key_id in tqdm(train_pairs):
        # deepcopy: mysampling may shuffle similar_books in place
        query_info_dict = deepcopy(train_book_dict[query_id])
        key_info_dict = deepcopy(train_book_dict[key_id])
        query_info = mysampling(query_info_dict)
        key_info = mysampling(key_info_dict)

        write_text = query_info + pair_separator + key_info + '\n'

        # round-trip check: the line must split back into exactly two sides
        # (fails when a book text itself contains the separator)
        a = write_text.strip().split(pair_separator)
        if len(a) != 2:
            print(a)
            raise ValueError('stop')
        query_all, key_all = a

        if len(query_all.split('\t')) != expected_fields:
            query_cnt += 1
            continue
        if len(key_all.split('\t')) != expected_fields:
            key_cnt += 1
            continue

        fout.write(write_text)

print(f'query_cnt:{query_cnt}, key_cnt:{key_cnt}')

In [None]:
# validation file generation
# One line per validation book that has at least one similar book in the
# training set: "<query fields>\$\$<key fields>".

blank_cnt = 0  # validation books without any similar book

# Query/key separator: backslash, dollar, backslash, dollar. Written with
# explicit backslashes because '\$' is an invalid escape sequence; the
# resulting string is unchanged.
pair_separator = '\\$\\$'
expected_fields = 1 + book_neighbour + shelves_neighbour + author_neighbour + 3

# make sure the output directory exists
os.makedirs('data/book', exist_ok=True)

with open('data/book/val.tsv', 'w') as fout:
    for b in tqdm(val_data):
        query_info_dict = deepcopy(b)
        if len(query_info_dict['similar_books']) == 0:
            blank_cnt += 1
            continue

        # sample key: a random similar book, removed from the query's neighbours
        random.shuffle(query_info_dict['similar_books'])
        sample_key = query_info_dict['similar_books'].pop(0)

        key_info_dict = deepcopy(train_book_dict[sample_key])

        # sampling
        query_info = mysampling(query_info_dict)
        key_info = mysampling(key_info_dict)

        # fix: compare the NUMBER of fields. The previous code compared the
        # list returned by split() with an int, which is always unequal, so
        # every row was skipped and the file stayed empty.
        if len(query_info.split('\t')) != expected_fields:
            continue
        if len(key_info.split('\t')) != expected_fields:
            continue

        fout.write(query_info + pair_separator + key_info + '\n')

print(f'Blank:{blank_cnt}')

In [None]:
# test file generation
# One line per test book that has at least one similar book in the training
# set: "<query fields>\$\$<key fields>".

blank_cnt = 0  # test books without any similar book

# Query/key separator: backslash, dollar, backslash, dollar. Written with
# explicit backslashes because '\$' is an invalid escape sequence; the
# resulting string is unchanged.
pair_separator = '\\$\\$'
expected_fields = 1 + book_neighbour + shelves_neighbour + author_neighbour + 3

# make sure the output directory exists
os.makedirs('data/book', exist_ok=True)

with open('data/book/test.tsv', 'w') as fout:
    for b in tqdm(test_data):
        query_info_dict = deepcopy(b)
        if len(query_info_dict['similar_books']) == 0:
            blank_cnt += 1
            continue

        # sample key: a random similar book, removed from the query's neighbours
        random.shuffle(query_info_dict['similar_books'])
        sample_key = query_info_dict['similar_books'].pop(0)

        key_info_dict = deepcopy(train_book_dict[sample_key])

        # sampling
        query_info = mysampling(query_info_dict)
        key_info = mysampling(key_info_dict)

        # fix: compare the NUMBER of fields. The previous code compared the
        # list returned by split() with an int, which is always unequal, so
        # every row was skipped and the file stayed empty.
        if len(query_info.split('\t')) != expected_fields:
            continue
        if len(key_info.split('\t')) != expected_fields:
            continue

        fout.write(query_info + pair_separator + key_info + '\n')

print(f'Blank:{blank_cnt}')

## generate pretrain raw file

In [None]:
# generate neighbour dict
# For every shelf / author / publisher / language_code / format that has an
# index, collect the flattened "title+description" text of each training book
# attached to it.
# WARNING: this cell rebinds shelves_neighbour and author_neighbour, which
# mysampling reads as integer slot counts. Re-run the cell that sets those
# counts before calling mysampling again.

shelves_neighbour = {}
author_neighbour = {}
publisher_neighbour = {}
language_code_neighbour = {}
format_neighbour = {}

for b in tqdm(train_data):
    # flattened once per book instead of once per shelf/author/attribute
    book_text = remove_next_line(b['title'] + b['description'])

    # shelves
    for ss in b['popular_shelves']:
        if ss['name'] in shelves_id2idx_dict:
            shelves_neighbour.setdefault(ss['name'], []).append(book_text)

    # author
    for aa in b['authors']:
        if aa['author_id'] in author_id2idx_dict:
            author_neighbour.setdefault(aa['author_id'], []).append(book_text)

    # publisher
    if b['publisher'] != '' and b['publisher'] in publisher_id2idx_dict:
        publisher_neighbour.setdefault(b['publisher'], []).append(book_text)

    # language_code
    if b['language_code'] != '' and b['language_code'] in language_code_id2idx_dict:
        language_code_neighbour.setdefault(b['language_code'], []).append(book_text)

    # format
    if b['format'] != '' and b['format'] in format_id2idx_dict:
        format_neighbour.setdefault(b['format'], []).append(book_text)

In [None]:
# number of shelves that have at least one training book text attached
len(shelves_neighbour)

In [None]:
# number of authors that have at least one training book text attached
len(author_neighbour)

In [None]:
# number of publishers that have at least one training book text attached
len(publisher_neighbour)

In [None]:
# number of language codes that have at least one training book text attached
len(language_code_neighbour)

In [None]:
# number of formats that have at least one training book text attached
len(format_neighbour)

In [None]:
# save id2idx
# Each mapping is pickled to data/book/neighbour/<name>_id2idx_dict.pkl.
os.makedirs('data/book/neighbour', exist_ok=True)

for prefix, mapping in (
    ('shelves', shelves_id2idx_dict),
    ('author', author_id2idx_dict),
    ('publisher', publisher_id2idx_dict),
    ('language_code', language_code_id2idx_dict),
    ('format', format_id2idx_dict),
):
    # context manager closes (and therefore flushes) every file deterministically
    with open(f'data/book/neighbour/{prefix}_id2idx_dict.pkl', 'wb') as fout:
        pickle.dump(mapping, fout)

In [None]:
# save neighbour dict
# Each dict is pickled to data/book/neighbour/<name>_neighbour.pkl.
os.makedirs('data/book/neighbour', exist_ok=True)

for prefix, neighbours in (
    ('shelves', shelves_neighbour),
    ('author', author_neighbour),
    ('publisher', publisher_neighbour),
    ('language_code', language_code_neighbour),
    ('format', format_neighbour),
):
    # context manager closes (and therefore flushes) every file deterministically
    with open(f'data/book/neighbour/{prefix}_neighbour.pkl', 'wb') as fout:
        pickle.dump(neighbours, fout)

In [None]:
# save number statistics
# NOTE(review): these five numbers are hard-coded. Presumably they are the
# vocabulary sizes for shelves, author, publisher, language_code and format
# from one particular run -- confirm against len() of the id2idx dicts.
# NOTE(review): the target name 'shelves_neighbour.pkl' is unusual for a list
# of statistics -- confirm what the downstream reader expects before renaming.
os.makedirs('data/book', exist_ok=True)
with open('data/book/shelves_neighbour.pkl', 'wb') as fout:
    # context manager closes (and therefore flushes) the file deterministically
    pickle.dump([6632, 205891, 62934, 139, 768], fout)