-
Notifications
You must be signed in to change notification settings - Fork 16
/
node_clf_mb.py
38 lines (30 loc) · 1.18 KB
/
node_clf_mb.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
"""使用邻居采样的顶点分类GNN
https://docs.dgl.ai/en/latest/guide/minibatch-node.html
"""
import torch
import torch.nn.functional as F
import torch.optim as optim
from dgl.data import CiteseerGraphDataset
from dgl.dataloading import MultiLayerFullNeighborSampler, NodeDataLoader
from gnn.dgl.model import GCN
def main(epochs=30, batch_size=32, hidden_dim=100):
    """Train a GCN on Citeseer with mini-batch neighbor sampling.

    For every mini-batch of training nodes, a two-hop full-neighbor sampler
    produces the list of message-flow blocks. The model runs on those
    blocks, and cross-entropy is computed against the labels of the
    destination nodes of the last block. The mean mini-batch loss is
    printed once per epoch.

    :param epochs: number of passes over the training nodes (default 30).
    :param batch_size: number of seed (output) nodes per mini-batch
        (default 32).
    :param hidden_dim: hidden feature size passed to ``GCN`` (default 100).
    """
    data = CiteseerGraphDataset()
    g = data[0]
    # Seed nodes for training: indices where train_mask is set.
    train_idx = g.ndata['train_mask'].nonzero(as_tuple=True)[0]
    # Two hops of full neighborhoods; presumably this must match the number
    # of layers in gnn.dgl.model.GCN -- TODO confirm.
    sampler = MultiLayerFullNeighborSampler(2)
    # shuffle=True reorders the seed nodes every epoch, so mini-batch
    # composition differs between epochs (standard for SGD training).
    dataloader = NodeDataLoader(
        g, train_idx, sampler, batch_size=batch_size, shuffle=True
    )
    model = GCN(g.ndata['feat'].shape[1], hidden_dim, data.num_classes)
    optimizer = optim.Adam(model.parameters())
    for epoch in range(epochs):
        model.train()
        losses = []
        # Only the blocks are needed; the input/output node ids are unused.
        for _, _, blocks in dataloader:
            # Input features are read from the source nodes of the first
            # block; labels from the destination nodes of the last block.
            logits = model(blocks, blocks[0].srcdata['feat'])
            loss = F.cross_entropy(logits, blocks[-1].dstdata['label'])
            losses.append(loss.item())
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        print('Epoch {:d} | Loss {:.4f}'.format(epoch + 1, torch.tensor(losses).mean().item()))
if __name__ == '__main__':
    # Start training only when run as a script, not when imported.
    main()