In [58]:
import torch
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.nn import GCNConv
from torch_geometric.utils import to_dense_adj
from torch_geometric.nn import global_mean_pool as gap, global_max_pool as gmp

In [80]:
edge_index = torch.tensor([
    [0, 1],
    [0, 2],
    [0, 3],
    [1, 0],
    [2, 0],
    [3, 0]
], dtype=torch.long)

x = torch.tensor([[0], [0], [0], [0]], dtype=torch.float) # configs: [0, 0, 0, 0], [1, 1, 1, 1]

y = torch.tensor([[3.0]]) # ranks

data = Data(x=x, edge_index=edge_index.t().contiguous(), y=y)

In [81]:
data.num_nodes, data.num_edges, data.num_node_features, data.has_isolated_nodes(), data.has_self_loops(), data.is_directed()

(4, 6, 1, False, False, False)

In [82]:
data.keys()

['y', 'x', 'edge_index']

In [83]:
adjacency = to_dense_adj(edge_index.t().contiguous())[0]

In [84]:
adjacency

tensor([[0., 1., 1., 1.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])

In [85]:
data.y

tensor([[3.]])

In [86]:
class GCN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = GCNConv(1, 16)
        self.conv2 = GCNConv(16, 16)
        self.out = torch.nn.Linear(16, 1)

    def forward(self, data):
        x, edge_index = data.x, data.edge_index

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, training=self.training)
        x = self.conv2(x, edge_index)
        x = F.relu(x)
        x = gap(x, None)

        # print(x)
        # return F.log_softmax(x, dim=1)
        return self.out(x)

In [88]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = GCN().to(device)
data = data.to(device)
# data = dataset[0].to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

model.train()
for epoch in range(200):
    optimizer.zero_grad()
    out = model(data)
    # loss = F.nll_loss(out, data.y)
    loss = F.mse_loss(out, data.y)
    print("Loss:", loss)
    loss.backward()
    optimizer.step()

Loss: tensor(8.2198, grad_fn=<MseLossBackward0>)
Loss: tensor(8.1626, grad_fn=<MseLossBackward0>)
Loss: tensor(8.1055, grad_fn=<MseLossBackward0>)
Loss: tensor(8.0487, grad_fn=<MseLossBackward0>)
Loss: tensor(7.9921, grad_fn=<MseLossBackward0>)
Loss: tensor(7.9357, grad_fn=<MseLossBackward0>)
Loss: tensor(7.8795, grad_fn=<MseLossBackward0>)
Loss: tensor(7.8236, grad_fn=<MseLossBackward0>)
Loss: tensor(7.7678, grad_fn=<MseLossBackward0>)
Loss: tensor(7.7123, grad_fn=<MseLossBackward0>)
Loss: tensor(7.6571, grad_fn=<MseLossBackward0>)
Loss: tensor(7.6020, grad_fn=<MseLossBackward0>)
Loss: tensor(7.5472, grad_fn=<MseLossBackward0>)
Loss: tensor(7.4927, grad_fn=<MseLossBackward0>)
Loss: tensor(7.4383, grad_fn=<MseLossBackward0>)
Loss: tensor(7.3843, grad_fn=<MseLossBackward0>)
Loss: tensor(7.3304, grad_fn=<MseLossBackward0>)
Loss: tensor(7.2768, grad_fn=<MseLossBackward0>)
Loss: tensor(7.2235, grad_fn=<MseLossBackward0>)
Loss: tensor(7.1704, grad_fn=<MseLossBackward0>)
Loss: tensor(7.1176,