把原始数据的标签转换成数字形式，并完成Train/Validation/Test的分割。这里的划分用于比赛中的模型训练和模型选择，与原始数据文件的命名并不一一对应。

In [1]:
import pandas as pd
import numpy as np
import os
import pickle

import dgl

Using backend: pytorch


In [2]:
# Locations of the prepared competition data on disk.
base_path = './final_data'
publish_path = ''

# CSV that lists every node together with its raw label and split id.
_path_parts = (base_path, publish_path, 'IDandLabels.csv')
nodes_path = os.path.join(*_path_parts)

### 读取节点列表

In [3]:
# Load the node table, reading the Label column as strings.
label_as_text = {'Label': str}
nodes_df = pd.read_csv(nodes_path, dtype=label_as_text)
print(nodes_df.shape)
# Show the last four rows.
nodes_df.iloc[-4:]

(5346177, 4)


Unnamed: 0,node_idx,paper_id,Label,Split_ID
5346173,5346173,1b8ab3d079dca59f31b846fd79e5ebb5,,1
5346174,5346174,38684c9ad0cbb959bbfd66c12938b227,,1
5346175,5346175,613fbc81d975a8d604ad71c48036b02e,,1
5346176,5346176,f58fbe42664299820e3b3b50b9a5983f,,1


In [4]:
# Show the first twenty rows of the node table.
nodes_df.iloc[:20]

Unnamed: 0,node_idx,paper_id,Label,Split_ID
0,0,bfdee5ab86ef5e68da974d48a138c28e,S,0
1,1,78f43b8b62f040347fec0be44e5f08bd,,0
2,2,a971601a0286d2701aa5cde46e63a9fd,G,0
3,3,ac4b88a72146bae66cedfd1c13e1552d,,0
4,4,a48c92cc8f67a8327adac7ff62d24a53,W,0
5,5,4736ef4d2512bb23954118adcb605b5e,H,0
6,6,c50a868bea34f9295afb3af544c14504,,0
7,7,917de373f8b3cb2dfe245b25ac72a73e,,0
8,8,76c9ca903451d3620a17e7ece4907585,,0
9,9,f22eb5767466d2870afa286b9aaf50bd,,0


In [5]:
# Number of rows whose Split_ID is 0.
int((nodes_df.Split_ID == 0).sum())

3655033

In [6]:
# Fraction of the Split_ID == 0 rows that have a non-null Label.
in_train = nodes_df.Split_ID == 0
has_label = nodes_df.Label.notnull()
int((in_train & has_label).sum()) / int(in_train.sum())

0.28574762526083897

### 转换标签为数字

In [7]:
# Inspect the label distribution first: count the rows under each Label value.
rows_by_label = nodes_df.groupby('Label')
label_dist = rows_by_label.count()
print(label_dist.shape)
label_dist

(23, 3)


Unnamed: 0_level_0,node_idx,paper_id,Split_ID
Label,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
A,2670,2670,2670
B,65303,65303,65303
C,111502,111502,111502
D,104005,104005,104005
E,45014,45014,45014
F,32876,32876,32876
G,43452,43452,43452
H,71824,71824,71824
I,23994,23994,23994
J,25241,25241,25241


#### 可以看到一共有23个标签，A类最少，C类最多，基本每类都有几万个。下面从0开始，重新编码标签


In [8]:
# Encode the letter labels as integers, numbered from 0 in the order in which
# they appear in label_dist's index (A..W).
label_to_id = {letter: idx for idx, letter in enumerate(label_dist.index.to_list())}

# Translate the whole column in one pass. Rows without a label have no entry
# in the mapping, come back as NaN and are marked with -1.
# The finished column is assigned back explicitly: calling
# `nodes_df.label.fillna(-1, inplace=True)` is a chained in-place update that
# pandas 2.2 warns about and that no longer modifies nodes_df once
# Copy-on-Write is active.
nodes_df['label'] = nodes_df.Label.map(label_to_id).fillna(-1).astype('int')
nodes_df.head(4)

Unnamed: 0,node_idx,paper_id,Label,Split_ID,label
0,0,bfdee5ab86ef5e68da974d48a138c28e,S,0,18
1,1,78f43b8b62f040347fec0be44e5f08bd,,0,-1
2,2,a971601a0286d2701aa5cde46e63a9fd,G,0,6
3,3,ac4b88a72146bae66cedfd1c13e1552d,,0,-1


#### 只保留新的node index、标签和原始的分割标签

In [9]:
# Keep only the node index, the numeric label and the original split id.
kept_columns = ['node_idx', 'label', 'Split_ID']
nodes = nodes_df[kept_columns]
# Show the last four rows of the reduced table.
nodes.iloc[-4:]

Unnamed: 0,node_idx,label,Split_ID
5346173,5346173,-1,1
5346174,5346174,-1,1
5346175,5346175,-1,1
5346176,5346176,-1,1


## 划分Train/Validation/Test

由于只有原始的Train_nodes文件里面包括了标签，所以这里的Train/Validation是对原始Train_nodes中带标签节点的再次分割。

这里按照9:1的比例划分Train/Validation。Test就是原来的validation_nodes里面的index。

In [10]:
# Rows with Split_ID == 0 and a real label (label >= 0) are divided into
# train/validation; rows with Split_ID == 1 form the test set.
tr_val_labels_df = nodes[(nodes.Split_ID == 0) & (nodes.label >= 0)]
test_label_df = nodes[nodes.Split_ID == 1]

# Fraction of each class that goes to train; the remainder goes to validation.
split_ratio = 0.9

# Each array starts with a placeholder 0, because the following cell drops
# the first element of both arrays with `[1:]`.
tr_parts = [np.array([0])]
val_parts = [np.array([0])]

# Split class by class so every label keeps the same train/validation ratio.
# groupby yields the classes in ascending label order and keeps the original
# row order inside each class, so no class count needs to be hard-coded.
for _, class_rows in tr_val_labels_df.groupby('label'):
    class_idx = class_rows.node_idx.to_numpy()
    split_point = int(class_idx.shape[0] * split_ratio)

    # Leading part goes to train, trailing part to validation.
    tr_parts.append(class_idx[:split_point])
    val_parts.append(class_idx[split_point:])

# Join the pieces once instead of re-copying the arrays on every iteration.
tr_labels_idx = np.concatenate(tr_parts)
val_labels_idx = np.concatenate(val_parts)

In [11]:
# Final Train/Validation/Test node indices.
# Drop the first element of the train and validation arrays.
tr_labels_idx, val_labels_idx = tr_labels_idx[1:], val_labels_idx[1:]

# Test indices are the node_idx values of test_label_df.
test_labels_idx = test_label_df['node_idx'].to_numpy()

In [12]:
# Full label vector: one entry per row of `nodes`, as a NumPy array.
labels = nodes['label'].to_numpy()

In [13]:
# Save the labels and the Train/Validation/Test indices in binary form so
# they can be loaded quickly at modelling time.
label_path = os.path.join(base_path, publish_path, 'labels.pkl')

split_payload = {
    'tr_label_idx': tr_labels_idx,
    'val_label_idx': val_labels_idx,
    'test_label_idx': test_labels_idx,
    'label': labels,
}

with open(label_path, 'wb') as out_file:
    pickle.dump(split_payload, out_file)

In [15]:
# Peek at the first ten training node indices.
tr_labels_idx[:10]

array([ 6018,  9505,  9507,  9508, 12132, 14077, 16761, 19551, 24098,
       25789])

In [16]:
# Peek at the first ten validation node indices.
val_labels_idx[:10]

array([2343425, 2348861, 2349640, 2349726, 2351857, 2352245, 2353021,
       2353324, 2353414, 2355763])

In [17]:
# Peek at the first ten test node indices.
test_labels_idx[:10]

array([3655033, 3655034, 3655035, 3655036, 3655037, 3655038, 3655039,
       3655040, 3655041, 3655042])

In [18]:
# Peek at the first ten entries of the full label vector (-1 marks an unlabeled node).
labels[:10]

array([18, -1,  6, -1, 22,  7, -1, -1, -1, -1])