In [1]:
import dgl
import dgl.data
import torch
import torch.nn as nn
import torch.nn.functional as F


In [2]:
# Load (or reuse the cached copy of) the Cora citation dataset from the current directory.
dataset = dgl.data.CoraGraphDataset(raw_dir="./")

  NumNodes: 2708
  NumEdges: 10556
  NumFeats: 1433
  NumClasses: 7
  NumTrainingSamples: 140
  NumValidationSamples: 500
  NumTestSamples: 1000
Done loading data from cached files.


In [3]:
# Number of target classes the node classifier has to distinguish.
num_categories = dataset.num_classes
print("n of categories: ", num_categories)

n of categories:  7


In [4]:
# Take the graph stored at index 0 of the dataset; presumably Cora holds just this one graph
# (TODO confirm len(dataset) == 1).
g = dataset[0]

In [5]:
# Feature-matrix shape (nodes x feature dims), then the raw node- and edge-data mappings.
print("Cora数据集", g.ndata["feat"].shape)
for title, data_view in (("Node features", g.ndata), ("Edge features", g.edata)):
    print(title)
    print(data_view)

Cora数据集 torch.Size([2708, 1433])
Node features
{'train_mask': tensor([ True,  True,  True,  ..., False, False, False]), 'label': tensor([3, 4, 4,  ..., 3, 3, 3]), 'val_mask': tensor([False, False, False,  ..., False, False, False]), 'test_mask': tensor([False, False, False,  ...,  True,  True,  True]), 'feat': tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])}
Edge features
{}


In [6]:
# Count the nodes selected by the training mask and by the test mask.
for mask_key in ("train_mask", "test_mask"):
    print(g.ndata[mask_key].sum().item())

140
1000


In [7]:
from dgl.nn import GraphConv

In [8]:
class GCN(nn.Module):
    """Two-layer graph convolutional network for node classification.

    Architecture: GraphConv(in_feat -> h_feats) -> ReLU -> GraphConv(h_feats -> num_classes).
    No softmax is applied, so ``forward`` returns raw per-node class scores (logits),
    which is what ``F.cross_entropy`` expects.
    """

    def __init__(self, in_feat, h_feats, num_classes):
        super().__init__()
        # Hidden convolution followed by the output convolution.
        self.conv1 = GraphConv(in_feat, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        """Run both graph convolutions over ``g`` on node features ``in_feat``."""
        hidden = F.relu(self.conv1(g, in_feat))
        return self.conv2(g, hidden)

In [9]:
from tensorboardX import SummaryWriter

# One summary writer per data split, each logging into its own run directory under a
# common root, so every split's "acc" curve is recorded separately.
_LOG_ROOT = "./records/cora_gcn"
writer_train = SummaryWriter(f"{_LOG_ROOT}/train")
writer_test = SummaryWriter(f"{_LOG_ROOT}/test")
writer_val = SummaryWriter(f"{_LOG_ROOT}/val")

# NOTE(review): not read by any code visible in this notebook section — confirm it is needed.
n_epochs = 0


def _masked_accuracy(pred, labels, mask):
    """Fraction of nodes selected by boolean ``mask`` whose prediction equals the label."""
    return (pred[mask] == labels[mask]).float().mean()


def train(g, model, epochs=20, lr=0.01, log_every=10, writers=None):
    """Full-batch training of ``model`` on graph ``g`` with Adam and cross-entropy loss.

    Each epoch, train/val/test accuracy and the training loss are written to the summary
    writers. The loss is also printed every ``log_every`` epochs.

    Args:
        g: graph whose ``ndata`` holds "feat", "label", "train_mask", "val_mask" and
            "test_mask".
        model: module called as ``model(g, features)`` that returns per-node logits.
        epochs: number of optimisation steps (one full-graph forward/backward each).
        lr: Adam learning rate.
        log_every: print the loss when ``epoch % log_every == 0``; 0/None disables printing.
        writers: optional mapping with "train", "val" and "test" entries, each exposing
            ``add_scalar(tag, value, step)`` and ``close()``. Defaults to the module-level
            ``writer_train`` / ``writer_val`` / ``writer_test``.

    The writers are closed when training finishes.
    """
    if writers is None:
        writers = {"train": writer_train, "val": writer_val, "test": writer_test}

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    masks = {split: g.ndata[f"{split}_mask"] for split in ("train", "val", "test")}

    model.train()
    for e in range(epochs):
        logits = model(g, features)
        loss = F.cross_entropy(logits[masks["train"]], labels[masks["train"]])

        # Metrics come from this epoch's forward pass, i.e. before the parameter update.
        with torch.no_grad():
            pred = logits.argmax(1)
            for split in ("test", "train", "val"):
                writers[split].add_scalar(
                    "acc", _masked_accuracy(pred, labels, masks[split]), e
                )
            writers["train"].add_scalar("loss", loss.item(), e)

        # Clear gradients before backprop so leftovers from an earlier (e.g. interrupted)
        # run cannot be accumulated into this step.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if log_every and e % log_every == 0:
            print(e, "loss: ", loss.item())

    for writer in writers.values():
        writer.close()

In [10]:
# Show which node-data fields are attached to the graph (split masks, labels and features).
g.ndata.keys()

dict_keys(['train_mask', 'label', 'val_mask', 'test_mask', 'feat'])

In [11]:
# Seed torch's RNG so model initialisation (and hence the training curve) is repeatable on
# Restart & Run All. Assumes GraphConv draws its initial weights from torch's global RNG.
SEED = 0
HIDDEN_SIZE = 16  # width of the single hidden GCN layer
torch.manual_seed(SEED)

n_feat = g.ndata["feat"].shape[1]
n_classes = dataset.num_classes
# NOTE(review): moving the graph to "cuda" presumably requires a CUDA-enabled DGL build.
device = "cuda" if torch.cuda.is_available() else "cpu"

model = GCN(n_feat, HIDDEN_SIZE, n_classes).to(device)
g = g.to(device)

In [12]:
# Full-batch training for 1000 epochs; accuracy curves go to TensorBoard, loss is printed periodically.
train(g, model, epochs=1000)

0 loss:  1.9452993869781494
10 loss:  1.793276071548462
20 loss:  1.5462335348129272
30 loss:  1.225893259048462
40 loss:  0.8872310519218445
50 loss:  0.5988467931747437
60 loss:  0.3923974335193634
70 loss:  0.2590206563472748
80 loss:  0.17637573182582855
90 loss:  0.12505769729614258
100 loss:  0.09231523424386978
110 loss:  0.07068513333797455
120 loss:  0.05589540675282478
130 loss:  0.045385170727968216
140 loss:  0.03766738995909691
150 loss:  0.031824272125959396
160 loss:  0.027294205501675606
170 loss:  0.023703640326857567
180 loss:  0.02080441266298294
190 loss:  0.01842677779495716
200 loss:  0.016451487317681313
210 loss:  0.014789470471441746
220 loss:  0.013376478105783463
230 loss:  0.012164352461695671
240 loss:  0.01111654844135046
250 loss:  0.010203568264842033
260 loss:  0.009402435272932053
270 loss:  0.008695924654603004
280 loss:  0.008069016970694065
290 loss:  0.00750992214307189
300 loss:  0.0070089614018797874
310 loss:  0.006558095570653677
320 loss:  0.0