In [1]:
%%capture
import json
import networkx as nx
from tqdm.notebook import tqdm
import numpy as np
import scipy as sp

# Reference for the truncated SVD routine (scipy.sparse.linalg.svds):
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.linalg.svds.html

In [16]:
# Read the full corpus: iterated below as a sequence of paper records, each
# carrying an 'id' and a list of 'references'.
with open('./data/aan_full.json') as corpus_file:
    full_set = json.load(corpus_file)

# One directed edge per citation, pointing from the citing paper to the paper
# it references. Edges are inserted in corpus order, so node order (and hence
# matrix row/column order) follows first appearance in the corpus.
directed_citation_graph = nx.DiGraph()
directed_citation_graph.add_edges_from(
    (citing['id'], cited_id)
    for citing in full_set
    for cited_id in citing['references']
)

# The adjacency matrix is built with the graph's default node order, which is
# the same order `nodelist` iterates in; `node_index` maps node id -> position.
nodelist = directed_citation_graph.nodes()
adj_matrix = nx.adjacency_matrix(directed_citation_graph).astype(np.float64)
node_index = {node: position for position, node in enumerate(nodelist)}

In [11]:
# Rank of the truncated decomposition (number of singular triplets kept).
SVD_RANK = 100

# Truncated SVD of the citation adjacency matrix: adj_matrix ~= U @ diag(s) @ V_t.
U, s, V_t = sp.sparse.linalg.svds(adj_matrix, k=SVD_RANK)

# Dense diagonal matrix of singular values, and its inverse, used to fold new
# query rows into the latent space.
S = np.diag(s)
S_i = np.linalg.inv(S)

In [37]:
# Load the evaluation queries; each entry is iterated below as a collection of
# paper records carrying an 'id' key.
with open('./data/aan_test_tripletfromref.json') as f:
    test_set = json.load(f)

# Encode every query as a 1 x N sparse indicator row (N = number of graph
# nodes) holding 1.0 in the column of each input paper.
input_vectors = []
for input_papers in test_set:
    # A paper id that is not a node of the citation graph raises KeyError here.
    # The explicit integer dtype keeps an empty query a valid (all-zero) row.
    indices = np.array([node_index[input_paper['id']] for input_paper in input_papers],
                       dtype=np.int64)
    data = np.ones(len(indices), dtype=np.float64)
    # Single-row CSR: the row spans entries [0, nnz). The end pointer must be
    # the actual number of entries; a hard-coded 2 is only valid for queries
    # with exactly two input papers.
    indptr = np.array([0, len(indices)])

    input_vector = sp.sparse.csr_matrix((data, indices, indptr),
                                        shape=(1, len(nodelist)),
                                        dtype=np.float64)

    input_vectors.append(input_vector)

# Stack the per-query rows into one (num_queries x N) sparse matrix.
input_array = sp.sparse.vstack(input_vectors)

# Fold the queries into the latent space, then map them back onto node scores.
# NOTE(review): assuming S_i is the inverse of S, the two factors cancel
# algebraically and this equals input_array @ V_t.T @ V_t up to rounding.
U_test = (input_array.dot(V_t.T)).dot(S_i)
score_array = (U_test.dot(S)).dot(V_t)

In [43]:
# Build one {'input': [...], 'output': [...]} record per query.
results = []

for input_papers, scores in tqdm(zip(test_set, score_array.tolist())):
    query_ids = [input_paper['id'] for input_paper in input_papers]
    excluded_ids = set(query_ids)

    # Rank every graph node by descending score. The sort is stable, so nodes
    # with equal scores keep their graph order.
    ranked = sorted(zip(nodelist, scores), key=lambda pair: pair[1], reverse=True)

    # Drop the query's own papers, then keep the 100 best remaining candidates.
    recommended_ids = [node_id for node_id, _score in ranked
                       if node_id not in excluded_ids]

    results.append({'input': query_ids, 'output': recommended_ids[:100]})

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))




In [44]:
# Persist the per-query recommendations as a single JSON document.
with open('./results/tripletfromref_base_svd.json', 'w') as results_file:
    json.dump(results, results_file)