In [None]:
import os
import torch


In [None]:
from torch_geometric.datasets import KarateClub

# Load PyG's built-in KarateClub dataset and print a short summary of it.
dataset = KarateClub()
print("\n".join([
    f"dataset: {dataset}",
    f"Number of Graphs: {len(dataset)}",
    f"Number of Features: {dataset.num_features}",
    f"Number of Classes: {dataset.num_classes}",
]))

In [None]:
# Take the first graph from the dataset and report basic structural statistics.
data = dataset[0]
print(data)
print(f'Number of Nodes: {data.num_nodes}')
print(f'Number of Edges: {data.num_edges}')
# edge_index holds directed entries, so this ratio is the mean out-degree; for an
# undirected graph stored in both directions (see is_undirected() below) it equals
# the mean node degree.
print(f'Average Node Degree: {data.num_edges / data.num_nodes:.2f}')
# Cast the 0-d count tensor to a Python int once, so the output reads "N" rather
# than "tensor(N)", and reuse it for the label rate.
num_train = int(data.train_mask.sum())
print(f'Number of Training Nodes: {num_train}')
print(f'Training Node Label Rate: {num_train / data.num_nodes:.2f}')
# The `contains_*` methods are deprecated aliases in PyG 2.x; use the `has_*` names.
print(f'Contains Isolated Nodes: {data.has_isolated_nodes()}')
print(f'Contains Self-Loops: {data.has_self_loops()}')
print(f'Is Undirected: {data.is_undirected()}')


In [None]:
# Connectivity tensor of the graph as stored on the PyG Data object; the next cell
# densifies it into an adjacency matrix.
edge_index = data.edge_index
print(edge_index)

In [None]:
from torch_geometric.utils import to_dense_adj

# Expand the sparse edge list into a dense adjacency tensor and show it.
# NOTE(review): to_dense_adj returns a batched tensor (a leading batch dimension,
# of size 1 for a single graph) — confirm before indexing into `adj`.
adj = to_dense_adj(edge_index=edge_index)
print(adj)

In [None]:
%matplotlib inline
import networkx as nx
import matplotlib.pyplot as plt
def visualize_graph(G, color, seed=42, figsize=(7, 7)):
    """Draw a networkx graph with a reproducible spring layout and labelled nodes.

    Args:
        G: networkx graph to draw.
        color: per-node colour values passed to ``node_color``; numeric values are
            mapped through the ``"Set2"`` colormap.
        seed: random seed for ``nx.spring_layout``, so repeated calls give the same
            layout (default 42, the previously hard-coded value).
        figsize: matplotlib figure size (default ``(7, 7)``, as before).
    """
    _, ax = plt.subplots(figsize=figsize)
    # Spring-layout coordinates are arbitrary, so axis ticks carry no information.
    ax.set_xticks([])
    ax.set_yticks([])
    pos = nx.spring_layout(G, seed=seed)
    # Draw on the explicit Axes instead of relying on pyplot's implicit "current axes".
    nx.draw_networkx(G, pos=pos, ax=ax, with_labels=True, node_color=color, cmap="Set2")
    plt.show()

# Use the public `torch_geometric.utils` export (as the to_dense_adj import does)
# rather than the internal `torch_geometric.utils.convert` module path.
from torch_geometric.utils import to_networkx

# Convert the PyG graph into an undirected networkx graph and plot it, colouring
# each node by `data.y`.
G = to_networkx(data, to_undirected=True)
visualize_graph(G, color=data.y)

In [None]:
import torch
import torch.nn as nn
from torch.nn import Linear
from torch_geometric.nn import GCNConv
class GCN(nn.Module):
    """Three-layer graph convolutional network with a linear classification head.

    Node features pass through three ``GCNConv`` layers, each followed by ``tanh``.
    The last layer's output is the node embedding ``h``, which a ``Linear`` layer
    maps to per-class scores.

    Args:
        num_features: input feature size per node. Defaults to the notebook-level
            ``dataset.num_features`` (the original, implicit behaviour).
        num_classes: number of output classes. Defaults to ``dataset.num_classes``.
        hidden_channels: width of the first two GCN layers (default 4, as before).
        embedding_channels: width of the last GCN layer, i.e. the size of the
            returned embedding ``h`` (default 2, as before).
    """

    def __init__(self, num_features=None, num_classes=None,
                 hidden_channels=4, embedding_channels=2):
        super().__init__()
        # Only fall back to the global `dataset` when sizes are not supplied, so
        # `GCN()` behaves exactly as before but the class is reusable elsewhere.
        if num_features is None:
            num_features = dataset.num_features
        if num_classes is None:
            num_classes = dataset.num_classes
        # Layers are created in the original order so parameter initialisation
        # consumes the RNG identically.
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, embedding_channels)
        self.classifier = Linear(embedding_channels, num_classes)

    def forward(self, x, edge_index):
        """Run the network on one graph.

        Args:
            x: node feature matrix, presumably shaped (num_nodes, num_features)
                — TODO confirm against the caller.
            edge_index: graph connectivity, passed unchanged to every GCN layer.

        Returns:
            Tuple ``(out, h)``:
                - ``out``: the classifier's raw (un-normalised) class scores.
                - ``h``: the tanh-activated output of the last GCN layer, i.e. the
                  node embedding.
        """
        h = self.conv1(x, edge_index).tanh()
        h = self.conv2(h, edge_index).tanh()
        h = self.conv3(h, edge_index).tanh()
        out = self.classifier(h)
        return out, h

# Instantiate the network with its default, dataset-derived sizes and print the
# module's repr (its registered sub-layers).
model = GCN()
print(model)