In [1]:
import dgl
import os
import time
import pandas as pd
import torch as th
import torch.nn.functional as F
import numpy as np
from dgl.dataloading.neighbor import MultiLayerNeighborSampler
from dgl.dataloading.pytorch import NodeDataLoader

from models import GraphSageModel, GraphConvModel, GraphAttnModel
from utils import load_dgl_graph, time_diff
from model_train import load_subtensor

Using backend: pytorch


In [93]:
# Filesystem locations: the dataset directory, the experiment output
# directory, and the checkpoint file inside it.
data_path = "../../dataset"
output_dir = "./output/experiment-2021-10-23-1694"
model_name = "model-best-val-acc-0.523.pth"
# Full path of the checkpoint that is restored further down.
model_path = os.path.join(output_dir, model_name)

In [94]:
# Model / sampling configuration. These values are passed to the model
# constructors and the dataloader in the cells below.
gnn_model = "graphsage"   # one of: graphsage, graphconv, graphattn
in_feat = 300             # input feature size handed to the model
hidden_dim = 64
n_layers = 2
fanouts = [20, 20]        # passed to the neighbour sampler
batch_size = 4096
num_workers = 4
device_id = 0

# Class count and the letter used for each class index (0 -> 'A', 1 -> 'B', ...).
n_classes = 23
LABELS = list(map(chr, range(ord('A'), ord('A') + n_classes)))

In [95]:
# Load the preprocessed graph together with labels, the train/val/test node-id
# splits and the node features.
graph, labels, train_nid, val_nid, test_nid, node_feat = load_dgl_graph(data_path)

# Only predict the nodes in validation: keep the first 591972 test node ids.
test_nid = test_nid[:591972]

# Add reverse edges (keeping node data), then a self-loop on every node.
graph = dgl.add_self_loop(dgl.to_bidirected(graph, copy_ndata=True))

################ Graph info: ###############
Graph(num_nodes=3655452, num_edges=29168650,
      ndata_schemes={}
      edata_schemes={})
################ Label info: ################
Total labels (including not labeled): 3655452
               Training label number: 939963
             Validation label number: 104454
                   Test label number: 592391
################ Feature info: ###############
Node's feature shape:torch.Size([3655452, 300])


In [96]:
# Neighbour sampler configured with the per-layer fanouts from the settings cell.
sampler = MultiLayerNeighborSampler(fanouts)

# Mini-batch loader over the test node ids. Order is preserved (shuffle=False)
# and the final partial batch is kept (drop_last=False), so every id in
# test_nid is visited once.
# batch_size comes from the settings cell instead of a duplicated literal, so
# changing the configuration actually takes effect here.
test_dataloader = NodeDataLoader(
    graph,
    test_nid,
    sampler,
    batch_size=batch_size,
    shuffle=False,
    drop_last=False,
    num_workers=num_workers,
)

In [97]:
# Map each supported architecture name to a zero-argument factory; only the
# selected factory is invoked, so just one model is ever constructed.
model_factories = {
    'graphsage': lambda: GraphSageModel(in_feat, hidden_dim, n_layers, n_classes),
    'graphconv': lambda: GraphConvModel(in_feat, hidden_dim, n_layers, n_classes,
                                        norm='both', activation=F.relu, dropout=0),
    'graphattn': lambda: GraphAttnModel(in_feat, hidden_dim, n_layers, n_classes,
                                        heads=([5] * n_layers), activation=F.relu,
                                        feat_drop=0, attn_drop=0),
}

# Unknown names are rejected up front rather than falling through silently.
if gnn_model not in model_factories:
    raise NotImplementedError('So far, only support three algorithms: GraphSage, GraphConv, and GraphAttn')

model = model_factories[gnn_model]()

In [98]:
# Load model parameters.
# map_location="cpu" deserializes the checkpoint onto the CPU no matter which
# device it was saved from, so loading does not fail when that device is
# missing here; the weights are moved to `device_id` by model.to() below.
params = th.load(model_path, map_location="cpu")
model.load_state_dict(params)
# Last expression of the cell: moves the model and displays its structure.
model.to(device_id)

GraphSageModel(
  (dropout): Dropout(p=0, inplace=False)
  (layers): ModuleList(
    (0): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=300, out_features=64, bias=False)
      (fc_neigh): Linear(in_features=300, out_features=64, bias=False)
    )
    (1): SAGEConv(
      (feat_drop): Dropout(p=0.0, inplace=False)
      (fc_self): Linear(in_features=64, out_features=23, bias=False)
      (fc_neigh): Linear(in_features=64, out_features=23, bias=False)
    )
  )
)

In [99]:
# Run the model over every test mini-batch and collect the logits together
# with the node ids (seeds) they belong to, in dataloader order.
model.eval()
batch_logits = []
test_idx = []

# Inference only: no_grad() skips autograd bookkeeping for the forward pass.
with th.no_grad():
    for input_nodes, seeds, blocks in test_dataloader:
        # load_subtensor also returns labels for the seeds; they are not
        # needed for prediction.
        batch_inputs, _unused_labels = load_subtensor(node_feat, labels, seeds, input_nodes, device_id)
        blocks = [block.to(device_id) for block in blocks]

        batch_logits.append(model(blocks, batch_inputs).cpu().numpy())
        test_idx.extend(seeds.cpu().tolist())

# Concatenate once instead of once per batch (which re-copies everything each
# step). The leading empty float64 array keeps the result well-defined when
# the dataloader yields nothing and keeps the dtype the same as before.
test_logits = np.concatenate([np.zeros((0, n_classes))] + batch_logits, axis=0)

print("Predict Done ...")

Predict Done ...


In [100]:
# Table with one row per node (columns include node_idx, paper_id, Label and
# Split_ID). Read from the configured dataset directory rather than a second
# hard-coded copy of that path.
# NOTE(review): later cells index this frame with .loc[test_idx], which assumes
# the default row index equals node_idx - confirm against the CSV.
id_labels = pd.read_csv(os.path.join(data_path, "IDandLabels.csv"))

  has_raised = await self.run_ast_nodes(code_ast.body, cell_name,


In [101]:
# Preview the id/label rows of the predicted nodes (all columns).
id_labels.loc[test_idx, :]

Unnamed: 0,node_idx,paper_id,Label,Split_ID
3063061,3063061,c39457cc34fa969b03819eaa4f9b7a52,,1
3063062,3063062,668b9d0c53e9b6e2c6b1093102f976b3,,1
3063063,3063063,ca5c7bc1b40c0ef3c3f864aed032ca90,,1
3063064,3063064,44f810c0c000cda27ce618add55e815f,,1
3063065,3063065,3c206335d88637d36d83c2942586be98,,1
...,...,...,...,...
3655028,3655028,5e231ec5d4167c541055092ee6e65a74,,1
3655029,3655029,25f30607d6bfd52ca2780d8ea928e77e,,1
3655030,3655030,703abf983edaaef1d34091eabb4ffd20,,1
3655031,3655031,d051d9bec90a57152776fc5e9b08e5b2,,1


In [102]:
# paper_id of every predicted node, in prediction order. A single .loc with
# row and column selectors avoids building the full intermediate frame that
# chained indexing (.loc[rows]['paper_id']) would create.
ids = id_labels.loc[test_idx, 'paper_id']
# Predicted class index = position of the largest logit in each row.
test_pred = test_logits.argmax(axis=1)

In [103]:
# Submission table: one row per predicted node, paper id plus predicted class index.
sub = pd.DataFrame(dict(id=ids, label=test_pred))

In [104]:
# Replace class indices by their letter labels. The letters are derived from
# test_pred (the array `sub` was built from, same row order) instead of from
# sub['label'] itself, so re-running this cell gives the same result rather
# than failing when it tries to index LABELS with an already-mapped letter.
sub['label'] = [LABELS[class_idx] for class_idx in test_pred]

In [105]:
# Show the first five submission rows (head() defaults to 5).
sub.head()

Unnamed: 0,id,label
3063061,c39457cc34fa969b03819eaa4f9b7a52,D
3063062,668b9d0c53e9b6e2c6b1093102f976b3,N
3063063,ca5c7bc1b40c0ef3c3f864aed032ca90,R
3063064,44f810c0c000cda27ce618add55e815f,F
3063065,3c206335d88637d36d83c2942586be98,L


In [106]:
# Name the submission after the checkpoint (extension stripped) plus the
# current Unix timestamp. splitext removes only the final extension and, unlike
# joining split('.')[:-1], still returns the name when it contains no dot.
model_stem = os.path.splitext(model_name)[0]
fn = os.path.join(output_dir, f"{model_stem}-{int(time.time())}.csv")
# index=False: write only the id and label columns, not the row index.
sub.to_csv(fn, index=False)
print(f"Saved to {fn} ...")

Saved to ./output/experiment-2021-10-23-1694/model-best-val-acc-0-1635056614.csv ...


'model-best-val-acc-0.523'