In [1]:
import dgl
import os
import time
import pandas as pd
import torch as th
import torch.nn.functional as F
import numpy as np
from dgl.dataloading.neighbor import MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader

from models import GraphSageModel, GraphConvModel, GraphAttnModel
from utils import load_dgl_graph, time_diff
from model_train import load_subtensor

Using backend: pytorch


In [55]:
# paths: experiment output directory and checkpoint to evaluate, plus the dataset root
output_dir, data_path = "./output/experiment-2021-11-18-46422", "../../dataset"
model_name = "model-best-val-acc-0.56798.pth"  # alternative checkpoint: "dgl_model-009116.pth"
model_path = os.path.join(output_dir, model_name)

In [56]:
# model settings
gnn_model = "graphsage"
in_feat = 300                # node feature dimension
hidden_dim = [192, 64]       # hidden sizes; with n_layers=3 the GraphSage model is 300 -> 192 -> 64 -> n_classes
n_layers = 3
fanouts = [15, 15, 15]       # one neighbor fanout per GNN layer: change it together with n_layers
batch_size = 4096
num_workers = 4
device_id = 0

n_classes = 23
LABELS = [chr(ord('A') + i) for i in range(n_classes)]  # class index i -> letter label 'A', 'B', ...

# Enforce the fanout/layer rule instead of only documenting it: the sampler needs one fanout per layer.
assert len(fanouts) == n_layers, f"len(fanouts)={len(fanouts)} must equal n_layers={n_layers}"

In [4]:
# Retrieve preprocessed data, then add reverse edges and self-loops.
graph, labels, train_nid, val_nid, test_nid, node_feat = load_dgl_graph(data_path)

# Only predict the first N_PREDICT_NODES test nodes. Original note: "only predict the nodes in
# validation" — presumably the competition's validation-phase test set; TODO confirm against the split.
N_PREDICT_NODES = 591972
test_nid = test_nid[:N_PREDICT_NODES]

graph = dgl.to_bidirected(graph, copy_ndata=True)
graph = dgl.add_self_loop(graph)

################ Graph info: ###############
Graph(num_nodes=3655452, num_edges=29168650,
      ndata_schemes={}
      edata_schemes={})
################ Label info: ################
Total labels (including not labeled): 3655452
               Training label number: 939963
             Validation label number: 104454
                   Test label number: 592391
################ Feature info: ###############
Node's feature shape:torch.Size([3655452, 300])


In [5]:
# Peek at the node feature matrix (the load step above reports its shape as [num_nodes, 300]).
# NOTE(review): the trailing rows in the recorded output are identical tiny values (~1e-14) —
# presumably placeholder features for some nodes; confirm this is intended.
node_feat

tensor([[ 9.0308e-01,  7.9809e-01, -2.0559e-01,  ..., -1.0074e+00,
          7.0118e-01, -4.9786e-01],
        [-5.9863e-01,  4.4366e-01, -1.0016e+00,  ..., -2.0384e+00,
          1.0898e+00,  7.3255e-01],
        [ 1.5540e+00,  3.0408e+00,  4.8199e-01,  ...,  1.1197e+00,
         -5.3127e-01, -1.8786e+00],
        ...,
        [-2.4577e-14,  8.7639e-15,  2.0788e-14,  ...,  2.0809e-14,
         -1.6891e-14, -3.0355e-14],
        [-2.4577e-14,  8.7639e-15,  2.0788e-14,  ...,  2.0809e-14,
         -1.6891e-14, -3.0355e-14],
        [-2.4577e-14,  8.7639e-15,  2.0788e-14,  ...,  2.0809e-14,
         -1.6891e-14, -3.0355e-14]])

In [57]:
# Neighbor sampler with one fanout per GNN layer. The loader walks the test nodes in a fixed order
# (shuffle=False) and keeps the last partial batch (drop_last=False), so every test node is predicted.
# NOTE(review): neighbor sampling is presumably random and no seed is set, so logits may differ between runs.
sampler = MultiLayerNeighborSampler(fanouts)
test_dataloader = NodeDataLoader(graph,
                                  test_nid,
                                  sampler,
                                  batch_size=batch_size,  # use the configured value (was hard-coded 4096)
                                  shuffle=False,
                                  drop_last=False,
                                  num_workers=num_workers,
                                  )

In [58]:
# Build the GNN selected by `gnn_model`. Constructors are wrapped in lambdas so only the chosen one runs.
_model_factories = {
    'graphsage': lambda: GraphSageModel(in_feat, hidden_dim, n_layers, n_classes),
    'graphconv': lambda: GraphConvModel(in_feat, hidden_dim, n_layers, n_classes,
                                        norm='both', activation=F.relu, dropout=0),
    'graphattn': lambda: GraphAttnModel(in_feat, hidden_dim, n_layers, n_classes,
                                        heads=([5] * n_layers), activation=F.relu, feat_drop=0, attn_drop=0),
}
if gnn_model not in _model_factories:
    raise NotImplementedError('So far, only support three algorithms: GraphSage, GraphConv, and GraphAttn')
model = _model_factories[gnn_model]()
del _model_factories

In [59]:
# Load the trained model parameters.
# map_location="cpu" lets the checkpoint load even if it was saved from a GPU that is not present;
# the model is moved to `device_id` right afterwards anyway.
# NOTE(review): th.load unpickles the file — only load checkpoints from a trusted source.
params = th.load(model_path, map_location="cpu")
model.load_state_dict(params)
model.to(device_id)

GraphSageModel(
  (dropout): Dropout(p=0, inplace=False)
  (layers): ModuleList(
    (0): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=300, out_features=192, bias=False)
      (fc_neigh): Linear(in_features=300, out_features=192, bias=False)
    )
    (1): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=192, out_features=64, bias=False)
      (fc_neigh): Linear(in_features=192, out_features=64, bias=False)
    )
    (2): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=64, out_features=23, bias=False)
      (fc_neigh): Linear(in_features=64, out_features=23, bias=False)
    )
  )
)

In [60]:
n = 0  # running total of parameter elements, accumulated by the loop that follows


def prod(x):
    """Return the product of the elements of ``x`` (e.g. a tensor shape).

    An empty sequence (the shape of a 0-dim tensor) yields 1, the empty product,
    so the result can always be added to a running total.
    """
    result = 1
    for dim in x:
        result *= dim
    return result
# Print each parameter tensor's shape and element count, accumulating the total into `n`.
for param_name, tensor in params.items():
    n_elems = prod(tensor.shape)
    print(f"{param_name}: {tensor.shape} {n_elems}")
    n += n_elems
print(f"total: {n}")

layers.0.bias: torch.Size([192]) 192
layers.0.fc_self.weight: torch.Size([192, 300]) 57600
layers.0.fc_neigh.weight: torch.Size([192, 300]) 57600
layers.1.bias: torch.Size([64]) 64
layers.1.fc_self.weight: torch.Size([64, 192]) 12288
layers.1.fc_neigh.weight: torch.Size([64, 192]) 12288
layers.2.bias: torch.Size([23]) 23
layers.2.fc_self.weight: torch.Size([23, 64]) 1472
layers.2.fc_neigh.weight: torch.Size([23, 64]) 1472
total: 142999


In [61]:
# Run inference on every test node, collecting the logits and the node ids they belong to.
# Row i of `test_logits` holds the logits for node `test_idx[i]` (both are filled batch by batch).
model.eval()
logit_chunks = []
test_idx = []
with th.no_grad():  # inference only: don't build autograd graphs (saves GPU memory and time)
    for input_nodes, seeds, blocks in test_dataloader:
        # forward; the second value returned by load_subtensor (batch labels) is not needed here
        batch_inputs = load_subtensor(node_feat, labels, seeds, input_nodes, device_id)[0]
        blocks = [block.to(device_id) for block in blocks]

        test_batch_logits = model(blocks, batch_inputs)
        logit_chunks.append(test_batch_logits.cpu().numpy())
        test_idx.extend(seeds.cpu().tolist())

# Concatenate once instead of re-concatenating the growing array every batch (quadratic).
# The empty float64 seed keeps the previous dtype and yields shape (0, n_classes) for an empty loader.
test_logits = np.concatenate([np.zeros((0, n_classes))] + logit_chunks, axis=0)
del logit_chunks
print("Predict Done ...")

Predict Done ...


In [62]:
# Node-id -> paper-id / label table, read from the same dataset directory as the graph.
# NOTE(review): the recorded output shows a pandas warning from this read (likely a mixed-dtype
# DtypeWarning); consider passing explicit dtypes.
id_labels = pd.read_csv(os.path.join(data_path, "IDandLabels.csv"))

  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


In [63]:
# Sanity check: the id_labels rows for the predicted nodes, in prediction order.
# .loc looks rows up by index label; this assumes the default index coincides with node ids
# (in the recorded output the index matches node_idx) — NOTE(review): confirm.
id_labels.loc[test_idx]

Unnamed: 0,node_idx,paper_id,Label,Split_ID
3063061,3063061,c39457cc34fa969b03819eaa4f9b7a52,,1
3063062,3063062,668b9d0c53e9b6e2c6b1093102f976b3,,1
3063063,3063063,ca5c7bc1b40c0ef3c3f864aed032ca90,,1
3063064,3063064,44f810c0c000cda27ce618add55e815f,,1
3063065,3063065,3c206335d88637d36d83c2942586be98,,1
...,...,...,...,...
3655028,3655028,5e231ec5d4167c541055092ee6e65a74,,1
3655029,3655029,25f30607d6bfd52ca2780d8ea928e77e,,1
3655030,3655030,703abf983edaaef1d34091eabb4ffd20,,1
3655031,3655031,d051d9bec90a57152776fc5e9b08e5b2,,1


In [101]:
# Hard class prediction per test node, plus the matching paper ids (same order as the test_logits rows).
test_pred = np.argmax(test_logits, axis=1)
ids = id_labels.loc[test_idx, 'paper_id']

## 读取推断出的测试节点标签，进行替换 (Load the inferred test-node labels and substitute them into the predictions)

In [103]:
# Load separately inferred test-node labels (infer_nodes.pkl); they replace the model's predictions.
# NOTE(review): pickle.load can execute arbitrary code — only load this file from a trusted source.
import pickle as pkl
with open(os.path.join(data_path, 'infer_nodes.pkl'), 'rb') as f:
    infer_nodes = pkl.load(f)
infer_test_idx = infer_nodes['test_idx']
infer_test_label = infer_nodes['test_lab']
# Node ids and labels are used later as element-wise pairs, so their lengths must match.
assert len(infer_test_idx) == len(infer_test_label), (len(infer_test_idx), len(infer_test_label))
infer_test_label.shape

(14214,)

In [102]:
# How many inferred nodes does the model already predict with the inferred label?
# Run this BEFORE the substitution cell; afterwards every inferred node matches trivially.
# Rows are located via test_idx (node id -> row in test_pred) instead of assuming that test node ids
# are contiguous starting at test_idx[0]; an inferred node that was not predicted raises KeyError.
row_of_node = {nid: row for row, nid in enumerate(test_idx)}
infer_rows = np.array([row_of_node[nid] for nid in np.asarray(infer_test_idx).tolist()], dtype=np.int64)
(test_pred[infer_rows] == infer_test_label).sum()

10708

In [95]:
# Overwrite the model's predictions for the inferred nodes with the inferred labels.
# Rows are found via test_idx (node id -> row in test_pred) rather than assuming contiguous node ids;
# re-running is harmless because the same labels are written again.
row_of_node = {nid: row for row, nid in enumerate(test_idx)}
infer_rows = np.array([row_of_node[nid] for nid in np.asarray(infer_test_idx).tolist()], dtype=np.int64)
test_pred[infer_rows] = infer_test_label

## 生成提交 (Generate the submission)

In [97]:
# Submission table indexed like `ids`: paper id plus (for now) the integer class index.
sub = ids.to_frame(name='id').assign(label=test_pred)

In [98]:
# Convert integer class indices to letter labels (index i -> LABELS[i]).
# The dtype guard makes re-running the cell a no-op instead of failing on already-converted letters.
if pd.api.types.is_integer_dtype(sub['label']):
    sub['label'] = np.asarray(LABELS, dtype=object)[sub['label'].to_numpy()]

In [99]:
# First 15 rows of the submission table.
sub.iloc[:15]

Unnamed: 0,id,label
3063061,c39457cc34fa969b03819eaa4f9b7a52,D
3063062,668b9d0c53e9b6e2c6b1093102f976b3,N
3063063,ca5c7bc1b40c0ef3c3f864aed032ca90,G
3063064,44f810c0c000cda27ce618add55e815f,F
3063065,3c206335d88637d36d83c2942586be98,K
3063066,c380307ccc10012f0a2a28e82f596745,D
3063067,34954a463b1a3dc3b02efcf439c6dfcf,D
3063068,367437da3355555ef1420de6b03ff6a6,D
3063069,951be1a3e76e22c97d216be961a50abd,P
3063070,b2ea7f2f33ec16ba9d2b3ddd2af1f22c,P


In [100]:
# Save the submission next to the checkpoint as "<checkpoint name without extension>-<unix time>.csv".
# os.path.splitext strips only the final extension and keeps extension-less names intact
# (the previous split/join produced an empty stem for them).
fn = os.path.join(output_dir, f"{os.path.splitext(model_name)[0]}-{int(time.time())}.csv")
sub.to_csv(fn, index=False)
print(f"Saved to {fn} ...")

Saved to ./output/experiment-2021-11-18-46422/model-best-val-acc-0.56798-1637331077.csv ...


## CAN 后处理 (CAN post-processing)

In [None]:
from collections import Counter

from utils import CAN

In [None]:
# Class frequencies over all labeled nodes (rows whose Label is not missing).
t_labels = Counter(id_labels['Label'].dropna())
t_labels

In [None]:
# Class prior = label frequency among labeled nodes, ordered like LABELS. That is the logit-column
# order, since class index i is mapped to LABELS[i] for the submission.
# Building it from LABELS keeps the vector length n_classes and column-aligned even if some class never
# occurs (sorting the Counter keys would then silently shift every later class).
assert set(t_labels) <= set(LABELS), f"unexpected labels: {set(t_labels) - set(LABELS)}"
prior = np.array([t_labels[label] for label in LABELS], dtype=float)
prior = prior / prior.sum()
prior

In [None]:
# Per-node class probabilities: float32 softmax over the class axis of the test logits.
tt = th.from_numpy(test_logits).float().softmax(dim=1)

In [None]:
# Shape of the probability matrix as a plain tuple: (num test nodes, num classes).
tuple(tt.shape)

In [None]:
# Number of test nodes whose top class probability is at least 0.5.
tt.max(dim=1)[0].ge(0.5).sum()

In [141]:
%%time
# 太慢了
adjusted = CAN(tt.data.numpy(), prior, tau=.8)

(251626, 23) (340346, 23)
0
10000
20000
30000
40000
50000
60000
70000
80000
90000
100000
110000
120000
130000
140000
150000
160000
170000
180000
190000
200000
210000
220000
230000
240000
250000
260000
270000
280000
290000
300000
310000
320000
330000
340000
CPU times: user 3h 57min 55s, sys: 5.11 s, total: 3h 58min
Wall time: 3h 57min 53s


In [143]:
# CAN should return one adjusted distribution per test node, i.e. the same shape as its input.
assert adjusted.shape == tuple(tt.shape), (adjusted.shape, tuple(tt.shape))
adjusted.shape

(591972, 23)