In [32]:
import dgl
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
import torch
from itertools import combinations

def fully_connected(num_nodes):
    """Return (src, dst) id tensors for every unordered node pair of a complete graph.

    Only one direction is produced: each pair (u, v) with u < v appears once, in
    lexicographic order. Callers that need a bidirectional graph add the reverse
    edges themselves (e.g. with ``dgl.add_reverse_edges``).

    Args:
        num_nodes: number of nodes, ids ``0 .. num_nodes - 1``.

    Returns:
        Tuple ``(src, dst)`` of 1-D int64 tensors of length ``num_nodes * (num_nodes - 1) // 2``.
    """
    pairs = list(combinations(range(num_nodes), 2))
    # Force an integer dtype: for num_nodes < 2 the lists are empty and
    # torch.tensor([]) would otherwise default to float32, which is not a valid
    # node-id tensor.
    src = torch.tensor([u for u, _ in pairs], dtype=torch.int64)
    dst = torch.tensor([v for _, v in pairs], dtype=torch.int64)
    return src, dst

# Build three small complete graphs of different sizes, give each one random node
# features, and batch them into a single (disconnected) graph.
# See https://docs.dgl.ai/en/0.6.x/generated/dgl.batch.html
torch.manual_seed(0)  # make the random node features reproducible on Restart & Run All


def complete_graph(num_nodes):
    """Return a homogeneous DGL graph with an edge in both directions between every node pair."""
    # Pass num_nodes explicitly so the node count is not inferred from the edge
    # list (which is empty for a 1-node graph).
    g = dgl.graph(fully_connected(num_nodes), num_nodes=num_nodes)
    return dgl.add_reverse_edges(dgl.to_homogeneous(g))


g1 = complete_graph(4)
a = torch.randn(4, 3)  # num nodes x num features
g2 = complete_graph(3)
b = torch.randn(3, 3)
g3 = complete_graph(5)
c = torch.randn(5, 3)

bg = dgl.batch([g1, g2, g3])
print(bg.batch_size)
print(bg.batch_num_nodes())
print(bg.batch_num_edges())


3
tensor([4, 3, 5])
tensor([12,  6, 20])


In [35]:
import dgl
import numpy as np
import torch as th
from dgl.nn import GraphConv
from dgl.nn import AvgPooling
# Stack the per-graph node features row-wise (along dim 0, the node axis) so that
# row i lines up with node i of the batched graph `bg` (g1's nodes, then g2's, then g3's).
feats = torch.cat([a, b, c], dim=0)
assert feats.shape[0] == bg.num_nodes(), "need exactly one feature row per batched node"
print(feats.shape)  # (4 + 3 + 5, num features)

# GraphConv(in_feats, out_feats): take in_feats from the data rather than hard-coding it.
conv = GraphConv(feats.shape[1], 8, norm='both', weight=True, bias=True)
res = conv(bg, feats)  # one 8-dim vector per node

print(res.shape)

avgpool = AvgPooling()  # create an average pooling layer
res2 = avgpool(bg, res)  # number of graphs x features

print(res2)

torch.Size([12, 3])
torch.Size([12, 8])
tensor([[-0.3521, -0.0996, -0.6933,  0.7285,  0.4690,  0.3049, -0.5690, -0.0132],
        [ 0.1832,  0.1552,  0.3596, -0.3692, -0.3798, -0.3125,  0.3494, -0.1877],
        [-0.4008, -0.0981,  0.0583,  0.0159, -0.0427,  0.5975, -0.3282,  0.1869]],
       grad_fn=<DivBackward0>)


In [5]:

class SAGE(nn.Module):
    """Two-layer GraphSAGE model using mean neighbour aggregation.

    The first SAGEConv maps ``in_feats -> hid_feats`` and is followed by a ReLU;
    the second maps ``hid_feats -> out_feats`` and its output is returned without
    any activation.
    """

    def __init__(self, in_feats, hid_feats, out_feats):
        super().__init__()
        self.conv1 = dglnn.SAGEConv(in_feats=in_feats,
                                    out_feats=hid_feats,
                                    aggregator_type='mean')
        self.conv2 = dglnn.SAGEConv(in_feats=hid_feats,
                                    out_feats=out_feats,
                                    aggregator_type='mean')

    def forward(self, graph, inputs):
        """Run both layers over `graph`; `inputs` holds the input node features."""
        hidden = F.relu(self.conv1(graph, inputs))
        return self.conv2(graph, hidden)
        

In [4]:
# Full-graph node-classification training loop for the SAGE model.
# NOTE(review): `graph` and `evaluate` are not defined in any visible cell; they must
# come from an earlier cell (presumably a dataset load and an accuracy helper), and
# the SAGE class cell must have run first. On a fresh kernel this cell raises
# NameError until those exist.
NUM_EPOCHS = 10
HIDDEN_FEATS = 100

node_features = graph.ndata['feat']
node_labels = graph.ndata['label']
train_mask = graph.ndata['train_mask']
valid_mask = graph.ndata['val_mask']
test_mask = graph.ndata['test_mask']
n_features = node_features.shape[1]
n_labels = int(node_labels.max().item() + 1)  # assumes labels are 0..C-1 — confirm

torch.manual_seed(0)  # reproducible weight initialisation
model = SAGE(in_feats=n_features, hid_feats=HIDDEN_FEATS, out_feats=n_labels)
opt = torch.optim.Adam(model.parameters())

for epoch in range(NUM_EPOCHS):
    model.train()
    # forward propagation by using all nodes
    logits = model(graph, node_features)
    # compute loss on the training nodes only
    loss = F.cross_entropy(logits[train_mask], node_labels[train_mask])
    # compute validation accuracy (reported below instead of being discarded)
    acc = evaluate(model, graph, node_features, node_labels, valid_mask)
    # backward propagation
    opt.zero_grad()
    loss.backward()
    opt.step()
    print(f"epoch {epoch}: loss {loss.item():.4f}, val acc {acc}")

# Save model if necessary.  Omitted in this example.

SAGE(
  (conv1): SAGEConv(
    (feat_drop): Dropout(p=0.0, inplace=False)
    (fc_self): Linear(in_features=10, out_features=5, bias=False)
    (fc_neigh): Linear(in_features=10, out_features=5, bias=False)
  )
  (conv2): SAGEConv(
    (feat_drop): Dropout(p=0.0, inplace=False)
    (fc_self): Linear(in_features=5, out_features=2, bias=False)
    (fc_neigh): Linear(in_features=5, out_features=2, bias=False)
  )
)
