# Load Cora Dataset 

In [1]:
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

# Everything in this notebook runs on the CPU.
device = torch.device('cpu')

# NOTE(review): torch.load unpickles arbitrary Python objects, so only load a
# trusted file. PyTorch >= 2.6 defaults to weights_only=True, which may reject
# non-tensor objects (such as a graph) stored in this dict -- confirm the
# installed version if loading fails.
data = torch.load('data.pth')

# Graph/adjacency, node features and labels are moved onto the compute device.
g, feat, label = (data[key].to(device) for key in ('g', 'feat', 'label'))

# Node index splits are used as loaded (no device transfer).
train_nodes = data['train_nodes']
val_nodes = data['val_nodes']
test_nodes = data['test_nodes']

def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's legacy global RNG and
    Python's ``random`` module. It also configures cuDNN so that kernel
    selection is deterministic across runs.

    Parameters
    ----------
    seed : int
        Seed value applied to every generator.
    """
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    # Benchmark mode autotunes algorithms at runtime and may choose different
    # kernels between runs; it must be off for reproducible results.
    torch.backends.cudnn.benchmark = False


setup_seed(20)

Using backend: pytorch


# Define GCN

In [2]:
from dgl.nn import GraphConv

class GCN(nn.Module):
    """Two-layer Graph Convolution Network (GCN).

    Layout: dropout -> GraphConv -> ReLU -> dropout -> GraphConv.
    The output has no final activation, i.e. raw per-node class scores.

    Parameters
    ----------
    in_feats : int
        Dimensionality of the input node features.
    out_feats : int
        Number of output channels (e.g. number of classes).
    hid : int, default 16
        Width of the single hidden layer.
    dropout : float, default 0.5
        Dropout probability applied before each convolution.

    Example
    -------
    # GCN with one hidden layer
    >>> model = GCN(100, 10, hid=32)
    """
    def __init__(self,
                 in_feats: int,
                 out_feats: int,
                 hid: int = 16,
                 dropout: float = 0.5):
        super().__init__()
        self.conv1 = GraphConv(in_feats, hid)
        self.conv2 = GraphConv(hid, out_feats)
        self.dropout = nn.Dropout(dropout)

    def forward(self, g, feat):
        """Compute per-node output scores.

        Parameters
        ----------
        g : torch.Tensor or graph
            If ``g`` is a tensor, it is multiplied in directly as a dense
            propagation matrix. No self-loops or normalisation are added on
            that path. NOTE(review): presumably ``g`` is already normalised
            with self-loops (e.g. for gradient-based structure attacks); confirm.
            Otherwise ``g`` is treated as a DGL graph. Self-loops are added
            and the GraphConv layers are applied.
        feat : torch.Tensor
            Node feature matrix.

        Returns
        -------
        torch.Tensor
            Output of the second layer, without activation.
        """
        if torch.is_tensor(g):
            # Dense path: reuse the GraphConv layers' own weight and bias so
            # both paths share the same parameters.
            feat = self.dropout(feat)
            feat = g @ (feat @ self.conv1.weight) + self.conv1.bias
            feat = F.relu(feat)
            feat = self.dropout(feat)
            feat = g @ (feat @ self.conv2.weight) + self.conv2.bias
            return feat

        g = g.add_self_loop()
        feat = self.dropout(feat)
        feat = self.conv1(g, feat)
        feat = F.relu(feat)
        feat = self.dropout(feat)
        feat = self.conv2(g, feat)
        return feat


# Train

In [3]:
def train():
    """Run one full-batch optimisation step on the training nodes.

    Uses the notebook-level ``model``, ``optimizer``, ``loss_fn``, ``g``,
    ``feat``, ``label`` and ``train_nodes``.
    """
    model.train()
    optimizer.zero_grad()
    logits = model(g, feat)
    loss = loss_fn(logits[train_nodes], label[train_nodes])
    loss.backward()
    optimizer.step()


@torch.no_grad()
def test():
    """Evaluate the notebook-level ``model`` in eval mode.

    Returns
    -------
    list of torch.Tensor
        Accuracy (0-dim float tensors) on the train, validation and test
        node splits, in that order.
    """
    model.eval()
    scores = model(g, feat)
    split_accuracies = []
    for split in (train_nodes, val_nodes, test_nodes):
        predicted = scores[split].max(dim=1).indices
        hits = predicted.eq(label[split])
        split_accuracies.append(hits.float().mean())
    return split_accuracies


import copy

num_feats = feat.size(1)
num_classes = int(label.max() + 1)
model = GCN(num_feats, num_classes).to(device)

optimizer = torch.optim.Adam([
    dict(params=model.conv1.parameters(), weight_decay=5e-4),
    dict(params=model.conv2.parameters(), weight_decay=0)
], lr=0.01)  # Only perform weight-decay on first convolution.

loss_fn = nn.CrossEntropyLoss()

NUM_EPOCHS = 100

# Model selection by validation accuracy. The reported test accuracy is the
# one at the best-validation epoch, so the weights from that same epoch are
# kept and restored after training. This makes later evaluation and the
# saved model match the reported numbers.
best_val_acc = test_acc = 0
best_state = None
for epoch in range(1, NUM_EPOCHS + 1):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
        best_state = copy.deepcopy(model.state_dict())
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')

if best_state is not None:
    model.load_state_dict(best_state)

Epoch: 001, Train: 0.5081, Val: 0.4498, Test: 0.4386
Epoch: 002, Train: 0.5806, Val: 0.5382, Test: 0.5045
Epoch: 003, Train: 0.6492, Val: 0.5502, Test: 0.5337
Epoch: 004, Train: 0.6815, Val: 0.5703, Test: 0.5578
Epoch: 005, Train: 0.6935, Val: 0.5904, Test: 0.5795
Epoch: 006, Train: 0.7177, Val: 0.6185, Test: 0.6021
Epoch: 007, Train: 0.7460, Val: 0.6426, Test: 0.6408
Epoch: 008, Train: 0.7823, Val: 0.7108, Test: 0.6781
Epoch: 009, Train: 0.8145, Val: 0.7430, Test: 0.7108
Epoch: 010, Train: 0.8347, Val: 0.7631, Test: 0.7309
Epoch: 011, Train: 0.8468, Val: 0.7671, Test: 0.7445
Epoch: 012, Train: 0.8548, Val: 0.7751, Test: 0.7545
Epoch: 013, Train: 0.8629, Val: 0.7751, Test: 0.7545
Epoch: 014, Train: 0.8790, Val: 0.7831, Test: 0.7676
Epoch: 015, Train: 0.8831, Val: 0.7831, Test: 0.7676
Epoch: 016, Train: 0.8911, Val: 0.7992, Test: 0.7817
Epoch: 017, Train: 0.9234, Val: 0.8233, Test: 0.8048
Epoch: 018, Train: 0.9274, Val: 0.8434, Test: 0.8154
Epoch: 019, Train: 0.9355, Val: 0.8594, Test: 

# Evaluate the trained model on a single target node

In [4]:
target = 1
target_label = label[target]
print("target label: ", target_label)

# Switch to inference mode explicitly. Otherwise the prediction depends on
# whichever mode the last executed cell left the model in (dropout is active
# after train()). no_grad avoids building an unneeded autograd graph.
model.eval()
with torch.no_grad():
    target_pred = model(g, feat)[target].argmax()
print("model predict: ", target_pred)

target label:  tensor(2)
model predict:  tensor(2)


In [5]:
# Save the trained victim model's parameters (state_dict only) to disk.
victim_state = model.state_dict()
torch.save(victim_state, 'model.pth')