In [1]:
import os
import torch
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def visualize_graph(G, color):
    plt.figure(figsize=(7,7))
    plt.xticks([])
    plt.yticks([])
    nx.draw_networkx(G, pos=nx.spring_layout(G, seed=42), with_labels=False,
                     node_color=color, cmap="Set2")
    plt.show()

In [3]:
from torch_geometric.datasets import KarateClub

dataset = KarateClub()

In [4]:
data = dataset[0]
data

Data(x=[34, 34], edge_index=[2, 156], y=[34], train_mask=[34])

In [9]:
import torch
from torch.nn import Linear
from torch_geometric.nn import GCNConv


class GCN(torch.nn.Module):
    def __init__(self):
        super(GCN, self).__init__()
        torch.manual_seed(12345)
        self.conv1 = GCNConv(1, 4)
        self.conv2 = GCNConv(4, 4)
        self.conv3 = GCNConv(4, 2)
        self.classifier = Linear(2, 34)

    def forward(self, x, edge_index):
        print(x)
        h = self.conv1(x, edge_index)
        h = h.tanh()
        h = self.conv2(h, edge_index)
        h = h.tanh()
        h = self.conv3(h, edge_index)
        h = h.tanh()  # Final GNN embedding space.
        
        # Apply a final (linear) classifier.
        out = self.classifier(h)

        return out, h

model = GCN()
print(model)

GCN(
  (conv1): GCNConv(34, 4)
  (conv2): GCNConv(4, 4)
  (conv3): GCNConv(4, 2)
  (classifier): Linear(in_features=2, out_features=4, bias=True)
)


In [14]:
from torch_geometric.data import InMemoryDataset, Data
N = 34
for x in range(34):
    cur_features = torch.zeros((N,1))
    cur_features[x] = 1
    cur_data = Data(x=torch.tensor(cur_features).float(), edge_index=data.edge_index)
    print(cur_data)

Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
Data(x=[34, 1], edge_index=[2, 156])
D

  cur_data = Data(x=torch.tensor(cur_features).float(), edge_index=data.edge_index)


In [11]:
model = GCN()
criterion = torch.nn.CrossEntropyLoss()  #Initialize the CrossEntropyLoss function.
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)  # Initialize the Adam optimizer.

def train(data):
    optimizer.zero_grad()  # Clear gradients.
    print(data.edge_index)
    out, h = model(data.x, data.edge_index)  # Perform a single forward pass.
    loss = criterion(out[data.train_mask], data.y[data.train_mask])  # Compute the loss solely based on the training nodes.
    loss.backward()  # Derive gradients.
    optimizer.step()  # Update parameters based on gradients.
    return loss, h

for epoch in range(401):
    loss, h = train(data)
    print(f'Epoch: {epoch}, Loss: {loss}')

tensor([[ 0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  1,  1,
          1,  1,  1,  1,  1,  1,  1,  2,  2,  2,  2,  2,  2,  2,  2,  2,  2,  3,
          3,  3,  3,  3,  3,  4,  4,  4,  5,  5,  5,  5,  6,  6,  6,  6,  7,  7,
          7,  7,  8,  8,  8,  8,  8,  9,  9, 10, 10, 10, 11, 12, 12, 13, 13, 13,
         13, 13, 14, 14, 15, 15, 16, 16, 17, 17, 18, 18, 19, 19, 19, 20, 20, 21,
         21, 22, 22, 23, 23, 23, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 27, 27,
         27, 27, 28, 28, 28, 29, 29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 31, 31,
         31, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33,
         33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33, 33],
        [ 1,  2,  3,  4,  5,  6,  7,  8, 10, 11, 12, 13, 17, 19, 21, 31,  0,  2,
          3,  7, 13, 17, 19, 21, 30,  0,  1,  3,  7,  8,  9, 13, 27, 28, 32,  0,
          1,  2,  7, 12, 13,  0,  6, 10,  0,  6, 10, 16,  0,  4,  5, 16,  0,  1,
          2,  3,  0,  2, 30, 32, 33,  2, 33,  0,  4

KeyboardInterrupt: 