In [None]:
import random
import numpy as np
import pandas as pd
import os.path as osp
from tqdm import tqdm
from copy import deepcopy
from collections import defaultdict
from sklearn.preprocessing import QuantileTransformer


# 自定义py文件
from transform import transform_pub
from get_78_train_data import get_train_sample, get_train_feature
from get_78_test_data import get_test_feature, cleanName
from get_sementic_data import get_sementic_feature, generate_embedding_list, calc_sims
from utils import *

# 读取用于训练的数据
+ public：每篇论文的信息
+ author_profile: 每个姓名对应的同名作者ID以及发表的论文

In [None]:
# Raw training inputs: per-paper records and the per-name author profiles.
train_pub_path = "data/v3/train/train_pub.json"
train_author_path = "data/v3/train/train_author.json"

public = load_json(train_pub_path)
author_profile = load_json(train_author_path)

## 将论文信息的title和abstract去停用词等简单处理

In [None]:
# Cache of the preprocessed training papers. Note that the file is written and
# read with dump_json/load_json even though its name carries a .pkl suffix.
train_pub_cache = "data/v3/processed/train_pub.pkl"
if osp.exists(train_pub_cache):
    # A cached copy exists: reuse it.
    public = load_json(train_pub_cache)
else:
    # No cache yet: regenerate it (this may be somewhat slow) and store it.
    public = transform_pub(public)
    dump_json(train_pub_cache, public)

In [None]:
# Keep only the names shared by more than `negNum` authors as training data.
negNum = 1
paper_and_author = []    # (paper, person, name) triples, one per kept paper
author_paper_train = {}  # person -> independent copy of that person's papers
for name, persons in tqdm(author_profile.items(), desc='sampling'):
    if len(persons) <= negNum:
        # Too few same-name authors under this name: skip it.
        continue
    for person, papers in persons.items():
        # deepcopy so later changes to this mapping cannot touch author_profile.
        author_paper_train[person] = deepcopy(papers)
        paper_and_author.extend((paper, person, name) for paper in papers)

print(len(paper_and_author))

## 手工建模特征

## 为每篇论文采样负样本
+ `pos_neg_sample` 包括一系列采样得到的训练样本，每个样本是：paperId, authorId, name, label
+ 如果这篇论文是这个 authorId 发表的，则 label 等于 1
+ 如果这篇论文不是这个 authorId 发表的，但是是与他同名的 author 发表的，则 label 等于 0

## 采样完样本后，自己手动构建特征矩阵


In [None]:
# Fix both random number generators before any sampling takes place.
seed = 42
random.seed(seed)
np.random.seed(seed)

# Flatten the samples that get_train_sample returns for every
# (paper, person, name) triple into one list.
pos_neg_sample = [
    sample
    for item in tqdm(paper_and_author)
    for sample in get_train_sample(item, author_profile, negNum=negNum)
]


In [None]:
# Upper bound on how many sampled pairs get featurised. None means "all of
# them" (slicing with a None bound keeps the whole list). The previous code
# hard-coded a [:10] slice, so the pickle written below only ever contained
# features for the first ten samples.
# Set this to a small int for a quick smoke test; be aware that doing so also
# overwrites samples_feature.pkl with the truncated result.
MAX_FEATURE_SAMPLES = None

samples_feature = []
for item in tqdm(pos_neg_sample[:MAX_FEATURE_SAMPLES]):
    samples_feature.append(get_train_feature(item, public, author_paper_train))

dump_pickle("data/v3/processed/samples_feature.pkl", samples_feature)
print(len(samples_feature))

# 验证集数据处理
## 加载验证集数据

In [None]:
# Validation inputs: the unassigned items and their papers, plus the full set
# of existing author profiles and the papers those profiles refer to.
val_unass_path = 'data/v3/cna_data/cna_valid_unass.json'
val_unass_pub_path = 'data/v3/cna_data/cna_valid_unass_pub.json'
whole_pub_path = 'data/v3/cna_data/whole_author_profiles_pub.json'
whole_profiles_path = 'data/v3/cna_data/whole_author_profiles.json'

val_data = load_json(val_unass_path)
val_public = load_json(val_unass_pub_path)
whole_public = load_json(whole_pub_path)
whole_author_profiles = load_json(whole_profiles_path)

In [None]:
# Cache of the preprocessed validation papers. Note that the file is written
# and read with dump_json/load_json even though its name carries a .pkl suffix.
val_pub_cache = "data/v3/processed/val_pub.pkl"
if osp.exists(val_pub_cache):
    # A cached copy exists: reuse it.
    val_public = load_json(val_pub_cache)
else:
    # No cache yet: regenerate it (this may be somewhat slow) and store it.
    val_public = transform_pub(val_public)
    dump_json(val_pub_cache, val_public)

In [None]:
# Index the existing profiles by name: cleanName(profile name) -> list of the
# profile keys that carry that name.
author_data = defaultdict(list)
for author_id, profile in whole_author_profiles.items():
    author_data[cleanName(profile["name"])].append(author_id)

In [None]:
# Build, for every unassigned validation item, the list of candidate authors
# and one feature result per (paper, candidate) pair.
NIL = []              # items for which no candidate author could be found
classifySet_all = []  # per kept item: list of get_test_feature results
candidate_all = []    # per kept item: candidate author ids (same order)
paper_ids = []        # per kept item: the paper id

for item in tqdm(val_data):
    # Each item is "<paperId>-<authorIndex>". Split on the LAST '-' only, so a
    # paper id that itself contains '-' no longer breaks the unpacking.
    paperId, index = item.rsplit('-', 1)
    paperInfo = val_public[paperId]
    name = cleanName(paperInfo["authors"][int(index)]["name"])

    ###### key processing: clean up the author name ######
    name = name_remove_comma(name)
    name = name_remove_zero(name)
    name = ch2en(name)
    name = name2name(name)

    # Candidate lookup, from strictest to loosest. The original code repeated
    # the name_reverse lookup twice verbatim; it is done once here.
    # .get() is used on purpose: unlike author_data[name] it does not insert
    # an empty entry into the defaultdict when the name is missing.
    candidate = author_data.get(name, [])
    if not candidate:
        # Computed only when needed, exactly as before.
        reversed_name = name_reverse(name)
        candidate = author_data.get(reversed_name, [])
        if not candidate:
            # Last resort: every profile name find_all_candidate proposes for
            # either spelling. A fresh list is built so the lists stored in
            # author_data are never mutated.
            # NOTE(review): the two calls may propose the same name, which
            # would duplicate author ids here — confirm whether that matters.
            candidate_name = (find_all_candidate(name, author_data)
                              + find_all_candidate(reversed_name, author_data))
            candidate = [author_id
                         for c in candidate_name
                         for author_id in author_data[c]]

    if not candidate:
        NIL.append(item)
        continue

    ################################################

    classifySet = []
    for personId in candidate:
        # Pass the pair as a tuple instead of a '-'-joined string.
        exam = (paperId, personId)
        temp = get_test_feature(exam, val_public, whole_author_profiles, whole_public)
        classifySet.append(temp)

    classifySet_all.append(classifySet)
    paper_ids.append(paperId)
    candidate_all.append(candidate)

print(f'第一次未匹配文章数 {len(NIL)}')

dump_pickle("data/v3/processed/valid_features.pkl", classifySet_all)
dump_pickle("data/v3/processed/valid_paper_ids.pkl", paper_ids)
dump_pickle("data/v3/processed/valid_candidate_all.pkl", candidate_all)

# 测试集数据处理
## 加载测试集数据

In [None]:
# Test inputs: the unassigned items and their papers, plus the full set of
# existing author profiles and the papers those profiles refer to.
test_unass_path = 'data/v3/cna_test_data/cna_test_unass.json'
test_unass_pub_path = 'data/v3/cna_test_data/cna_test_unass_pub.json'
whole_pub_path = 'data/v3/cna_data/whole_author_profiles_pub.json'
whole_profiles_path = 'data/v3/cna_data/whole_author_profiles.json'

test_data = load_json(test_unass_path)
test_public = load_json(test_unass_pub_path)
whole_public = load_json(whole_pub_path)
whole_author_profiles = load_json(whole_profiles_path)

In [None]:
# Cache of the preprocessed test papers. Note that the file is written and
# read with dump_json/load_json even though its name carries a .pkl suffix.
test_pub_cache = "data/v3/processed/test_pub.pkl"
if osp.exists(test_pub_cache):
    # A cached copy exists: reuse it.
    test_public = load_json(test_pub_cache)
else:
    # No cache yet: regenerate it (this may be somewhat slow) and store it.
    test_public = transform_pub(test_public)
    dump_json(test_pub_cache, test_public)

In [None]:
# Index the existing profiles by name: cleanName(profile name) -> list of the
# profile keys that carry that name.
author_data = defaultdict(list)
for author_id, profile in whole_author_profiles.items():
    author_data[cleanName(profile["name"])].append(author_id)

In [None]:
# Build, for every unassigned test item, the list of candidate authors and one
# feature result per (paper, candidate) pair.
NIL = []              # items for which no candidate author could be found
classifySet_all = []  # per kept item: list of get_test_feature results
candidate_all = []    # per kept item: candidate author ids (same order)
paper_ids = []        # per kept item: the paper id

for item in tqdm(test_data):
    # Each item is "<paperId>-<authorIndex>". Split on the LAST '-' only, so a
    # paper id that itself contains '-' no longer breaks the unpacking.
    paperId, index = item.rsplit('-', 1)
    paperInfo = test_public[paperId]
    name = cleanName(paperInfo["authors"][int(index)]["name"])

    ###### key processing: clean up the author name ######
    name = name_remove_comma(name)
    name = name_remove_zero(name)
    name = ch2en(name)
    name = name2name(name)

    # Candidate lookup, from strictest to loosest. The original code repeated
    # the name_reverse lookup twice verbatim; it is done once here.
    # .get() is used on purpose: unlike author_data[name] it does not insert
    # an empty entry into the defaultdict when the name is missing.
    candidate = author_data.get(name, [])
    if not candidate:
        # Computed only when needed, exactly as before.
        reversed_name = name_reverse(name)
        candidate = author_data.get(reversed_name, [])
        if not candidate:
            # Last resort: every profile name find_all_candidate proposes for
            # either spelling. A fresh list is built so the lists stored in
            # author_data are never mutated.
            # NOTE(review): the two calls may propose the same name, which
            # would duplicate author ids here — confirm whether that matters.
            candidate_name = (find_all_candidate(name, author_data)
                              + find_all_candidate(reversed_name, author_data))
            candidate = [author_id
                         for c in candidate_name
                         for author_id in author_data[c]]

    if not candidate:
        NIL.append(item)
        continue

    ################################################

    classifySet = []
    for personId in candidate:
        # Pass the pair as a tuple instead of a '-'-joined string.
        exam = (paperId, personId)
        temp = get_test_feature(exam, test_public, whole_author_profiles, whole_public)
        classifySet.append(temp)

    classifySet_all.append(classifySet)
    paper_ids.append(paperId)
    candidate_all.append(candidate)

print(f'第一次未匹配文章数 {len(NIL)}')

dump_pickle("data/v3/processed/test_features.pkl", classifySet_all)
dump_pickle("data/v3/processed/test_paper_ids.pkl", paper_ids)
dump_pickle("data/v3/processed/test_candidate_all.pkl", candidate_all)