# 论文消歧（赛道一）开源
### 解决思路
1. 将处理后的Title+Abstract+Keywords作为训练样本，利用word2vec训练词向量
2. 计算每个词的IDF，将IDF与词向量加权求和，作为每篇论文的嵌入向量
3. 两两计算每篇论文之间的Cosine相似度，如果相似度大于0.7，则判断为同个作者
4. 通过networkx构建图，将论文作为nodes，如果两篇论文属于同个作者，则添加一条edge，最后输出图的各个连通子图，就是比赛的输出
### 线上分数
~0.248
### 参考
[预处理参考这里](https://biendata.com/models/category/3000/L_notebook/ )
### 备注
代码在谷歌的colab环境下运行，如果在本地运行，请去掉相关的挂载命令并修改文件路径 

In [0]:
from google.colab import drive
#drive.mount('/name-disambiguation/input')
drive.mount('/content/gdrive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/gdrive


In [0]:
import os, json, re
data_path = "gdrive/My Drive/name-disambiguation/input"
checkpoint_path = "gdrive/My Drive/name-disambiguation/checkpoint"
output_path = "gdrive/My Drive/name-disambiguation/output"
train_pub_path = os.path.join(data_path,'train_pub.json')
train_author_path = os.path.join(data_path,'train_author.json')
valid_author_raw_path = os.path.join(data_path,'sna_valid_author_raw.json')
valid_pub_path = os.path.join(data_path,'sna_valid_pub.json')
word2vect_model_path = os.path.join(checkpoint_path,'w2v_train_100.model')
pub_embedding_path = os.path.join(output_path,'pub_embedding.json')


from nltk.stem.porter import PorterStemmer  #todo: 还有其他词干抽取器
import multiprocessing as mlp
import pickle as pkl
from tqdm import tqdm, tqdm_notebook
import numpy as np

import time
from tqdm import tqdm_notebook, tqdm
from functools import partial
from sklearn.svm import SVC, NuSVC, LinearSVC
from sklearn.preprocessing import MinMaxScaler
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, confusion_matrix
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import pandas as pd
from itertools import combinations
from tqdm import tqdm_notebook,tqdm
from sklearn.metrics.pairwise import cosine_similarity
import random

from gensim.models import Word2Vec, KeyedVectors
EMBEDDING_DIM = 100

# 预处理名字
def preprocess_name(name):   
    name = name.lower().replace(' ', '_')
    name = name.replace('.', '_')
    name = name.replace('-', '')
    name = re.sub(r"_{2,}", "_", name) 
    return name

def preprocess_org(org):
    if org != "":
        org = org.replace('Sch.', 'School')
        org = org.replace('Dept.', 'Department')
        org = org.replace('Coll.', 'College')
        org = org.replace('Inst.', 'Institute')
        org = org.replace('Univ.', 'University')
        org = org.replace('Lab ', 'Laboratory ')
        org = org.replace('Lab.', 'Laboratory')
        org = org.replace('Natl.', 'National')
        org = org.replace('Comp.', 'Computer')
        org = org.replace('Sci.', 'Science')
        org = org.replace('Tech.', 'Technology')
        org = org.replace('Technol.', 'Technology')
        org = org.replace('Elec.', 'Electronic')
        org = org.replace('Engr.', 'Engineering')
        org = org.replace('Aca.', 'Academy')
        org = org.replace('Syst.', 'Systems')
        org = org.replace('Eng.', 'Engineering')
        org = org.replace('Res.', 'Research')
        org = org.replace('Appl.', 'Applied')
        org = org.replace('Chem.', 'Chemistry')
        org = org.replace('Prep.', 'Petrochemical')
        org = org.replace('Phys.', 'Physics')
        org = org.replace('Phys.', 'Physics')
        org = org.replace('Mech.', 'Mechanics')
        org = org.replace('Mat.', 'Material')
        org = org.replace('Cent.', 'Center')
        org = org.replace('Ctr.', 'Center')
        org = org.replace('Behav.', 'Behavior')
        org = org.replace('Atom.', 'Atomic')
        org = org.lower()
        org = org.split(';')[0]  # 多个机构只取第一个
    return org

def clean_sent(s, prefix = None):
    '''
    为区别各字段，不同字段前的词加不同的前缀
    '''
    words = re.sub('[^ \-_a-z]', ' ', s.lower()).split()
    stemer = PorterStemmer()
    return [ '__%s__%s'%(prefix, stemer.stem(w)) for w in words]

class PubInfo():
    def __init__(self, authors_name, authors_org, authors_pair, keywords, year, venue):
        self.authors_name = list(set(authors_name))
        self.authors_org = list(set(authors_org))
        self.authors_pair = authors_pair
        self.keywords = keywords
        self.year = int(year)
        self.venue = venue

def generate_pubinfo(pub_info):
    pub_authors_pair, pub_authors_name,pub_authors_org = [],[],[]
    pub_authors = pub_info['authors']
    pub_keywords = pub_info['keywords']
    pub_year = pub_info['year']
    pub_venue = pub_info['venue']
    for i in pub_authors:
        n = preprocess_name(i['name'])
        o = preprocess_org(i['org'])
        pub_authors_name.append(n)
        pub_authors_org.append(o)
        pub_authors_pair.append((n,o))
    # print(pub_authors_pair)
    return PubInfo(pub_authors_name, pub_authors_org, pub_authors_pair, pub_keywords, pub_year, pub_venue)
    
def cal_distance(left_id, right_id,pub_embedding):
    # print(left_doc,pub_embedding[left_doc])
    left_embedding = pub_embedding.get(left_id,np.zeros((EMBEDDING_DIM,)))
    right_embedding = pub_embedding.get(right_id,np.zeros((EMBEDDING_DIM,)))
    
    return cosine_similarity(left_embedding.reshape((1,-1)),right_embedding.reshape((1,-1)))[0][0]
    # return left_embedding.reshape((1,-1)) - right_embedding.reshape((1,-1))

def pair_to_cluster(v):
    G = nx.Graph()
    for pub_pair, author_name, f, l in v:
            left_pub, right_pub  = pub_pair
            if left_pub not in G:
                G.add_node(left_pub)
            if right_pub not in G:
                G.add_node(right_pub)
            if f > 0.7:
                G.add_edge(left_pub, right_pub)

    cluster = list(nx.connected_components(G))
    cluster = [list(i) for i in cluster]
    cnt_cluster = len(cluster)
    sum_cluster = sum([len(i) for i in cluster])
    max_cluster = max([len(i) for i in cluster])
    min_cluster = min([len(i) for i in cluster])
    avg_cluster = sum_cluster / cnt_cluster
    
    print(cnt_cluster, sum_cluster, max_cluster, min_cluster, avg_cluster)
    return (author_name, cluster)

In [0]:
with open(train_pub_path,'r') as f:
    train_pub = json.load(f)
    for k, v in train_pub.items():
        print(k,'->',v)
        break

P9a1gcvg -> {'authors': [{'name': 'Fenghe Qiu', 'org': 'Institute of Pharmacology and Toxicology'}, {'name': 'Li Liu', 'org': 'Institute of Pharmacology and Toxicology'}, {'name': 'Li Guo', 'org': 'Institute of Pharmacology and Toxicology'}], 'title': 'Rapid determination of central nervous drugs in plasma by solid-phase extraction and GC-FID and GC-MS', 'abstract': 'Objective: To establish a simultaneous determination method of central nervous drugs including barbitals, benzodiazepines, phenothiozines and tricyclic antidepressants in human plasma. Methods: Drugs in plasma were extracted and purified by using X-5 resin solid phase extraction columns, followed by identification and quantitation using capillary GC-FID and GC-MS. Results: More than 20 drugs were simultaneously extracted from human plasma, and effectively separated in GC and TIC spectra. The correlation coefficient of standard curves was larger than 0.99, and relative standard differences (RSD) were less than 10% for most 

In [0]:
with open(valid_pub_path,'r') as f:
    valid_pub = json.load(f)
    for k, v in valid_pub.items():
        print(k,'->',v)
        break

JIfXICrQ -> {'id': 'JIfXICrQ', 'title': 'Analysis of volatile compounds and determination of key odors in Sichuan Paocai (Part Ⅱ)', 'abstract': "Combined solid phase micro-extraction and gas chromatography-mass spectrometry(SPME-GC/MS) were taken to analyze the volatile components of Sichuan Paocai made with natural fermentation or artificial fermentation which is inoculated with old saline or lactic acid bacteria(LAB) starter cultures.Alcohols,aldehyde and alkene are the main volatiles of Sichuan Paocai,which amount 90% of the pickles' volatiles.Meanwhile key odors of Sichuan Paocai,sulfide,alkene and aldeyde are determined.Esters play only a limited role in the odor.There are big differences between natural fermented Sichuan Paocai and the artificial fermented Sichuan Paocai in the key odors.However,key odors of the artificial Sichuan Paocai inoculated with old saline or LAB starter cultures are similar.", 'keywords': ['technology', 'pickle', 'Sichuan Paocai', 'volatile compound', 'a

In [0]:
with open(train_author_path,'r') as f:
    train_author = json.load(f)
    for k, v in train_author.items():
        print(k,'->',v)
        break
train_author_raw = {}
for k, v in train_author.items():
    ids = [vv for kk, vv in v.items()]
    train_author_raw[k] = sum(ids, []) # Flattening
for k, v in train_author_raw.items():
    print(k,'->',v)
    break

li_guo -> {'EShnTfSe': ['P9a1gcvg'], 'sCKCrny5': ['Rg5fAeTd', 'lJPsGNBE', 'er5gTz90', 'UG32p2zs', 'FTweBQNS', 'MqQOkfGH', 'HcrsMIFk', 'O14akltW', 'gkO4nqtJ', 'pW2XD7jw', '9VmupIoc', 'pPjljRgW', 'rakLh2IJ', '7zYwBnXA', '0raOqO0t', 'ct2XgIla', 'Vh2Yy9vh', 'Yqpk8Jyt', 'xiuPYHlY', 'xWDIT1E6', 'bUEtpqgY', 'L98DkDap', 'UnhHr74J', 'CvJPSUxQ', 'AriXov6L', 'elqOgbWS', 'dikitupP', 'nlwdBqLy', 'pT5EU4aH', 'ub0fgak2', 'eThzWtj0', '3p1wV47k', '9iQCpuHx', 'QFAYRwHl', 'lioaGlmw', 'vW26N7Jw', 'Y8dEQrwa', 'AmLblp6l', 'gsL5mZJL', 'TXiFqOQn', 'dKTBlS35', 'fKt8D2xH', 'qaykvKo3', 'W0P2muEr', 'sLjQJNj8', '0ytlj6Jw', '3P3o3ZQ9', 'sqGfmgG1'], 'SY2FEqNj': ['2EGB9ZLn'], 'XXFI1LfP': ['wd3ugFQM', 'pYKTrEQj', 'ZUfWc0Zj', 'jMErL46k', 'g2a7ae8r', 'C0LwTl95', '84iQz3xI', 'saq73SMR', '7609LUea', 'mO4rXXob', 'Fhz4Q5J5', 'lUuqOzwx', 'Q8KEEDfP', 'vKhZWMKD', 'V2XqV0Gp', 'ceGwETv1', 'aSqmClGz', 'DRx4shZr', '1yu8n8Sn', 'kMO2ex8C', 'tLWwSkZD', '8a1YPO1b', '2j7eBdaC', 'kmyawALz', 'KM70zskA', 'APIk8oID', 'GSwb2A2E', 'b6L5kmwf'

In [0]:
with open(valid_author_raw_path,'r') as f:
    valid_author_raw = json.load(f)
    for k, v in valid_author_raw.items():
        print(k,'->',v)
        break

li_guo -> {'EShnTfSe': ['P9a1gcvg'], 'sCKCrny5': ['Rg5fAeTd', 'lJPsGNBE', 'er5gTz90', 'UG32p2zs', 'FTweBQNS', 'MqQOkfGH', 'HcrsMIFk', 'O14akltW', 'gkO4nqtJ', 'pW2XD7jw', '9VmupIoc', 'pPjljRgW', 'rakLh2IJ', '7zYwBnXA', '0raOqO0t', 'ct2XgIla', 'Vh2Yy9vh', 'Yqpk8Jyt', 'xiuPYHlY', 'xWDIT1E6', 'bUEtpqgY', 'L98DkDap', 'UnhHr74J', 'CvJPSUxQ', 'AriXov6L', 'elqOgbWS', 'dikitupP', 'nlwdBqLy', 'pT5EU4aH', 'ub0fgak2', 'eThzWtj0', '3p1wV47k', '9iQCpuHx', 'QFAYRwHl', 'lioaGlmw', 'vW26N7Jw', 'Y8dEQrwa', 'AmLblp6l', 'gsL5mZJL', 'TXiFqOQn', 'dKTBlS35', 'fKt8D2xH', 'qaykvKo3', 'W0P2muEr', 'sLjQJNj8', '0ytlj6Jw', '3P3o3ZQ9', 'sqGfmgG1'], 'SY2FEqNj': ['2EGB9ZLn'], 'XXFI1LfP': ['wd3ugFQM', 'pYKTrEQj', 'ZUfWc0Zj', 'jMErL46k', 'g2a7ae8r', 'C0LwTl95', '84iQz3xI', 'saq73SMR', '7609LUea', 'mO4rXXob', 'Fhz4Q5J5', 'lUuqOzwx', 'Q8KEEDfP', 'vKhZWMKD', 'V2XqV0Gp', 'ceGwETv1', 'aSqmClGz', 'DRx4shZr', '1yu8n8Sn', 'kMO2ex8C', 'tLWwSkZD', '8a1YPO1b', '2j7eBdaC', 'kmyawALz', 'KM70zskA', 'APIk8oID', 'GSwb2A2E', 'b6L5kmwf'

In [0]:
# 对论文的信息做预处理，并保存
def generate_text(x):
    pub_id, doc = x
    title = clean_sent(doc['title'], 'T') if doc.get('title',None) else []
    venue = clean_sent(doc['venue'], 'V') if doc.get('venue',None) else []
    abstract = clean_sent(doc['abstract'], 'A') if doc.get('abstract',None) else []
    keywords = clean_sent(' '.join(doc['keywords']), 'K') if doc.get('keywords',None) else []
    # print(title+venue+abstract+keywords)
    return pub_id, title+venue+abstract+keywords

def prepare_text_to_word2vec(train_pub, valid_pub=None):

    material_train_path = os.path.join(output_path,'material_train.json')
    material_valid_path = os.path.join(output_path,'material_valid.json')
    # if os.path.exists(material_train_path):
    #     material = pkl.load(open(material_train_path,'rb'))
    #     print(material_train_path)
    #     return material
    # with mlp.Pool(20) as pool:
    pool = mlp.Pool(5)
    for pub, save_path in zip([train_pub, valid_pub], [material_train_path, material_valid_path]):
        if not pub:
            continue

        # material = pool.map( generate_text, [(pub_id, pub_info) for pub_id, pub_info in pub.items()])
        material = tqdm_notebook(pool.imap( generate_text, zip(pub.keys(), pub.values())), total=len(pub))
        material = dict(material)
        pkl.dump(material, open(save_path,'wb'))

        for k, v in material.items():
            print(k,'->',v)
            break
    pool.close()
    pool.join()
    return material

prepare_text_to_word2vec(train_pub, valid_pub)

HBox(children=(IntProgress(value=0, max=203184), HTML(value='')))




HBox(children=(IntProgress(value=0, max=45416), HTML(value='')))


JIfXICrQ -> ['__T__analysi', '__T__of', '__T__volatil', '__T__compound', '__T__and', '__T__determin', '__T__of', '__T__key', '__T__odor', '__T__in', '__T__sichuan', '__T__paocai', '__T__part', '__V__china', '__V__brew', '__A__combin', '__A__solid', '__A__phase', '__A__micro-extract', '__A__and', '__A__ga', '__A__chromatography-mass', '__A__spectrometri', '__A__spme-gc', '__A__ms', '__A__were', '__A__taken', '__A__to', '__A__analyz', '__A__the', '__A__volatil', '__A__compon', '__A__of', '__A__sichuan', '__A__paocai', '__A__made', '__A__with', '__A__natur', '__A__ferment', '__A__or', '__A__artifici', '__A__ferment', '__A__which', '__A__is', '__A__inocul', '__A__with', '__A__old', '__A__salin', '__A__or', '__A__lactic', '__A__acid', '__A__bacteria', '__A__lab', '__A__starter', '__A__cultur', '__A__alcohol', '__A__aldehyd', '__A__and', '__A__alken', '__A__are', '__A__the', '__A__main', '__A__volatil', '__A__of', '__A__sichuan', '__A__paocai', '__A__which', '__A__amount', '__A__of', '__A__

In [0]:
# 读取训练集和验证集的处理后的样本
import gc
material_train_path = os.path.join(output_path,'material_train.json')
material_valid_path = os.path.join(output_path,'material_valid.json')
material_train = pkl.load(open(material_train_path,'rb')) if os.path.exists(material_train_path) else None
material_valid = pkl.load(open(material_valid_path,'rb')) if os.path.exists(material_valid_path) else None
material = {**material_train, **material_valid}
del material_train
del material_valid
gc.collect()
print(len(material))

246750


In [0]:
# 训练词向量模型
if os.path.exists(word2vect_model_path):
    # docs = pkl.load(open(material_path,'rb'))
    model_w2v =Word2Vec.load(word2vect_model_path)
    pass
else:
    model_w2v = Word2Vec([v for k,v in material.items()], size=EMBEDDING_DIM, window=5, min_count=5, workers=20)
    model_w2v.save(word2vect_model_path)

for k,v in dict([(k, idx+1)for idx,(k,v) in enumerate(model_w2v.wv.vocab.items())]).items():
    print(k,v)
    break

  'See the migration notes for details: %s' % _MIGRATION_NOTES_URL


__T__rapid 1


In [0]:
%%time
# 训练IDF
from sklearn.feature_extraction.text import CountVectorizer, TfidfTransformer, TfidfVectorizer

w2v_vocab = dict([(k, idx)for idx,(k,v) in enumerate(model_w2v.wv.vocab.items())])
vectorizer = TfidfVectorizer(max_df=0.5, vocabulary=w2v_vocab, lowercase=False) # token_pattern='(?u)\b\w\w+\b')
X = vectorizer.fit_transform([' '.join(v) for k,v in material.items()])
print(vectorizer.idf_.shape, len(vectorizer.vocabulary_), len(w2v_vocab))
# for idx, (k, v) in enumerate( vectorizer.vocabulary_.items()):
for idx, (k, v) in enumerate(w2v_vocab.items()):
    print(k,v, vectorizer.idf_[v])
    if idx == 9:
        break

(119160,) 119160 119160
__T__rapid 0 6.457686616674764
__T__determin 1 5.396193322295055
__T__of 2 1.4251137458591194
__T__central 3 6.838273652251373
__T__nervou 4 9.046687157505398
__T__drug 5 6.189925999871748
__T__in 6 2.0715690267800912
__T__plasma 7 5.899157785368098
__T__by 8 3.317697640492864
__T__solid-phas 9 13.41613500997242


In [0]:
%%time
# 构造一个包含idf和embedding的矩阵，加快查询的速度
def create_embedding_idx(w2v_vocab, vectorizer, model_w2v):
    embedding_idx = np.zeros((len(w2v_vocab), 1+EMBEDDING_DIM))
    for word, idx in tqdm_notebook(w2v_vocab.items(), total=len(w2v_vocab)):

        idf = vectorizer.idf_[idx]
        embedding = model_w2v.wv[word]
        embedding_idx[idx,0] = idf
        embedding_idx[idx,1:] = embedding
    return embedding_idx
embedding_idx = create_embedding_idx(w2v_vocab, vectorizer, model_w2v)

HBox(children=(IntProgress(value=0, max=119160), HTML(value='')))


CPU times: user 1min 53s, sys: 2.83 s, total: 1min 56s
Wall time: 1min 56s


In [0]:
%%time
# 计算每篇论文的嵌入向量
pub_embedding = {}


def get_doc_embedding(x,w2v_vocab,embedding_idx):
    pub_id, doc  = x
    word_vecs = []
    sum_weight = 0.0
    for word in doc:
        idx = w2v_vocab.get(word,None)
        if idx is not None:
            idf = embedding_idx[idx,0]
            embedding = embedding_idx[idx,1:]
            word_vecs.append( embedding * idf  )
            sum_weight += idf
            # print(word, idx, )
        else :
            pass
            # print(word, idx, )
    if len(word_vecs) == 0:
        weight = np.zeros((EMBEDDING_DIM,))
    else:
        weight = np.sum(word_vecs, axis = 0) / sum_weight
    return (pub_id, weight)

partial_work = partial(get_doc_embedding, w2v_vocab=w2v_vocab, embedding_idx=embedding_idx) 


pub_embedding = [i for i in tqdm_notebook(map(partial_work, zip(material.keys(),material.values())), total=len(material))]
pub_embedding = dict(pub_embedding)
for k, v in pub_embedding.items():
    print(k,v)
    break

HBox(children=(IntProgress(value=0, max=246750), HTML(value='')))


P9a1gcvg [-8.64553512e-02 -1.63969892e-01 -2.06351956e-01 -7.14733807e-01
  5.41466049e-01 -9.11201232e-01 -5.15561797e-01  4.26689449e-01
  3.39575201e-01 -1.56413616e-02  9.44672687e-04 -8.11502777e-03
 -7.90831292e-01 -1.78913322e-01  8.06763190e-02  6.73173190e-01
 -2.94236301e-01  5.02934129e-01 -4.61919279e-01 -2.21791543e-01
 -3.85466347e-01 -2.93281641e-01  7.36233718e-01  1.27524352e-01
 -3.75883989e-02 -4.92628301e-01 -1.15766725e+00  5.12122219e-02
  7.48636098e-01 -6.72230819e-02  7.43336650e-01 -2.77199857e-01
 -8.41933462e-01 -3.48576929e-01 -1.32203252e+00  4.29998770e-01
 -2.16396810e-02  5.60531949e-02 -1.82455632e-01 -8.56544296e-01
 -6.99641964e-01 -8.34725843e-01  6.04369094e-03 -2.00801046e-01
  2.49417198e-01  3.97216460e-01  3.95595166e-01  4.27633690e-01
  1.76337263e-01 -1.16787606e+00 -7.17508186e-01 -1.66466880e+00
 -1.39840926e-01 -1.42020113e+00  1.05808044e+00  2.44048674e-02
  8.40416867e-01 -5.96625649e-01  5.31029597e-02  3.60150173e-01
  4.56295704e-0

In [0]:
pkl.dump(pub_embedding, open(pub_embedding_path,'wb'))
# pub_embedding = pkl.load(open(pub_embedding_path,'rb')) if os.path.exists(pub_embedding_path) else {}}

In [0]:
pub_embedding = pkl.load(open(pub_embedding_path,'rb')) 

In [0]:
# 读取验证集的数据，输出结果
def generate_cluster(x, pub_embedding):
    
    author_name, author_pub_ids = x
    if len(author_pub_ids) == 0:
        return (author_name,[])


    feat = []
    same_pair = []
    # random.shuffle(author_pub_ids)
    all_pair = [tuple(set(i)) for i in combinations(author_pub_ids,2) if i[0] != i[1]]
    print(author_name, len(same_pair), len(all_pair), len(author_pub_ids))
    diff_pair = all_pair[:] # list(set(all_pair) - set(same_pair))
    # random.shuffle(diff_pair)
    
    # cnt_same = min(len(same_pair),num_sample)
    # cnt_diff = min(len(diff_pair),4*num_sample)
    # same_pair = same_pair[:num_sample]
    # diff_pair = diff_pair[:4*num_sample]
    sample_pair = same_pair + diff_pair
    author_names = [author_name] * len(sample_pair)
    label = [1]*len(same_pair) + [0]*len(diff_pair)
    for pub_pair in sample_pair: # tqdm_notebook(sample_pair, total=len(sample_pair)):
        left_pub_id, right_pub_id = pub_pair
        feat.append(cal_distance(left_pub_id, right_pub_id,pub_embedding ))
    return pair_to_cluster(zip(sample_pair, author_names, np.array(feat).reshape((-1,1)), np.array(label)))


def generate_output(num_sample, pub_embedding):
    drive.mount('/content/gdrive')
    with open(valid_author_raw_path,'r') as f:
        valid_author_raw = json.load(f)
    print(valid_author_raw_path, len(valid_author_raw))
    for k, v in valid_author_raw.items():
        print(k,'->',v)
        break
    print(valid_author_raw['j_yu'])
    result = []
    # for author_name, author_pub_ids in valid_author_raw.items():
    partial_work = partial(generate_cluster, pub_embedding=pub_embedding) 
    pool = mlp.Pool(2)
    result = [i for i in tqdm_notebook(pool.imap(partial_work,valid_author_raw.items()), total=len(valid_author_raw.items()))]
    pool.close()
    pool.join()
    # result = [i for i in map(partial_work,valid_author_raw.items())]
    
    return result
    # return all_pub_pair, author_names, np.array(feat).reshape((-1,1)), np.array(label)
# all_pub_pair, author_names, feat, label = generate_feature(10000,pub_embedding)
result = generate_output(10000000000,pub_embedding)

import json
result = dict(result)
with open("gdrive/My Drive/name-disambiguation/output/result.json",'w') as f:
    json.dump(result,f)

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).
gdrive/My Drive/name-disambiguation/input/sna_valid_author_raw.json 50
heng_li -> ['zszavvJh', 'a2Tl88Xu', 'X3Y0AYBw', '0kUbjIMW', '4bKZMNxE', 'vbQeLjWI', 'ZfSbG5ME', 'ierOTBz1', '6cXZA7zu', 'cXBXYa63', '141Kx9Ce', 'Gfl2GwXG', '98xkec8D', '2MIIwgBH', 'yVCAyLJo', 'NaPOuXXF', 'YoRqTILd', 'V8UOM2uU', 'wtiZkPVr', 'qggv8F0K', 'uDsLiLZd', 'FIVcmK6w', 'M2YJ0Dfa', 'QqQNHt3b', 'Cpi5vz3j', '7wJu2eAh', 'dUlZsS5z', 'cZnZW9uz', 'bxkStxS7', 'TH4B6A7Y', 'xn3xJkGU', 'xsbQHWyq', 'B5HNDTzE', 'MsT3Sml8', 'PO61SLCN', 'SCvnHfAx', 'pFmfB3Hm', 'qFyL0b7h', 'MdbnxBrH', 'WDe4ltiG', '4eHVOVB7', '8OiM0qmf', 'Q62Py6ew', 'Jmcve9yI', '475s6I47', 'OnwXGiRR', 'CUbTTLzf', '4DaB8tVN', 'oim5ZjhO', 'OyhZ1MiB', 'KZzOrPpy', 's0UYiQ6B', '0GzWz3O4', 'kWe0u2Ih', 'AyFd9xZZ', 'N7J7fEeF', 'C1WBo3Cs', 'Lg5JGKID', 'W6leQ8fX', 'qyh46ML8', 'ojqAGDyS', 'aA4FOt2Y', 'x3h6uA0c', 'tWxrqzEO', 'goiijZU0', 'wyion

HBox(children=(IntProgress(value=0, max=50), HTML(value='')))

heng_li 0 294528 768
fang_chen 0 49141 314
1 314 314 314 314.0
lin_zhou 0 479709 980
4 768 765 1 192.0
akio_yamamoto 0 36856 272
2 272 166 106 136.0
y_luo 0 13041 162
3 162 113 1 54.0
chun_li 0 50403 318
1 318 318 318 318.0
shiyi_chen 0 26106 229
2 229 225 4 114.5
zhigang_chen 0 918690 1356
3 979 977 1 326.3333333333333
chun_wang 0 24310 221
1 221 221 221 221.0
y_guo 0 250276 708
4 706 703 1 176.5
g_li 0 65341 362
1 362 362 362 362.0
jing_huang 0 627760 1121
1 1356 1356 1356 1356.0
atsushi_takeda 0 3003 78
1 78 78 78 78.0
fei_gao 0 905185 1346
7 1121 1115 1 160.14285714285714
rajendra_prasad 0 13530 165
1 165 165 165 165.0
qi_li 0 3605947 2686
3 1346 1344 1 448.6666666666667
bin_ren 0 68635 371
1 371 371 371 371.0
yong_cao 0 897130 1340
2 1340 1339 1 670.0
jing_zhou 0 46056 304
1 304 304 304 304.0
hong_jiang 0 3640951 2699
6 2678 2673 1 446.3333333333333
liang_zhou 0 22578 213
1 213 213 213 213.0
ming_xu 0 746031 1222
4 1222 1219 1 305.5
jie_sun 0 989121 1407
1 1407 1407 1407 1407.0
6 