필요한 Library

In [21]:
from torch_geometric.datasets import Amazon
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch

Amazon Dataset

In [18]:
dataset = Amazon(root='./Amazon_data', name='computers')
data = dataset[0]

In [22]:
# Split the dataset into training, validation, and test sets
num_nodes = data.x.shape[0]
num_train = int(0.8 * num_nodes)
num_val = int(0.1 * num_nodes)
num_test = num_nodes - num_train - num_val

# Create masks for training, validation, and test nodes
data.train_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.val_mask = torch.zeros(num_nodes, dtype=torch.bool)
data.test_mask = torch.zeros(num_nodes, dtype=torch.bool)

# Assign nodes to masks
data.train_mask[:num_train] = 1
data.val_mask[num_train:num_train + num_val] = 1
data.test_mask[num_train + num_val:] = 1

Graph 구조와 특징

In [24]:
# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Print information about the graph
print(f'\nGraph:')
print('------')
print(f'Training nodes: {num_train}')
print(f'Evaluation nodes: {num_val}')
print(f'Test nodes: {num_test}')
print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')

Dataset: AmazonComputers()
-------------------
Number of graphs: 1
Number of nodes: 13752
Number of features: 767
Number of classes: 10

Graph:
------
Training nodes: 11001
Evaluation nodes: 1375
Test nodes: 1376
Edges are directed: False
Graph has isolated nodes: True
Graph has loops: False


Minibatch 생성


Minibatch란 대용량의 Graph를 나누어 학습하기 위해 나온 것으로
학습할 노드들을 랜덤하게 선택하여 하나의 Minibatch를 생성

본 데이터 셋은 데이터의 양이 적어 Full-batch로 진행

In [26]:
# Create batches with neighbor sampling
train_loader = NeighborLoader(
    data,
    num_neighbors=[5,5],
    batch_size=1024,
    input_nodes=data.train_mask,
)

생성된 Minibatch의 설명

In [27]:
# Print each subgraph
for i, subgraph in enumerate(train_loader):
    print(f'Subgraph {i}: {subgraph}')

Subgraph 0: Data(x=[8417, 767], edge_index=[2, 19740], y=[8417], train_mask=[8417], val_mask=[8417], test_mask=[8417], n_id=[8417], e_id=[19740], input_id=[1024], batch_size=1024)
Subgraph 1: Data(x=[8443, 767], edge_index=[2, 19827], y=[8443], train_mask=[8443], val_mask=[8443], test_mask=[8443], n_id=[8443], e_id=[19827], input_id=[1024], batch_size=1024)
Subgraph 2: Data(x=[8491, 767], edge_index=[2, 20010], y=[8491], train_mask=[8491], val_mask=[8491], test_mask=[8491], n_id=[8491], e_id=[20010], input_id=[1024], batch_size=1024)
Subgraph 3: Data(x=[8377, 767], edge_index=[2, 19724], y=[8377], train_mask=[8377], val_mask=[8377], test_mask=[8377], n_id=[8377], e_id=[19724], input_id=[1024], batch_size=1024)
Subgraph 4: Data(x=[8460, 767], edge_index=[2, 19962], y=[8460], train_mask=[8460], val_mask=[8460], test_mask=[8460], n_id=[8460], e_id=[19962], input_id=[1024], batch_size=1024)
Subgraph 5: Data(x=[8399, 767], edge_index=[2, 19880], y=[8399], train_mask=[8399], val_mask=[8399],

Minibatch 시각화

In [None]:
# Plot each subgraph
fig = plt.figure(figsize=(8,8))
for idx, (subdata, pos) in enumerate(zip(train_loader, [221, 222, 223, 224])):
    G = to_networkx(subdata, to_undirected=True)
    ax = fig.add_subplot(pos)
    ax.set_title(f'Subgraph {idx}', fontsize=10)
    plt.axis('off')
    nx.draw_networkx(G,
                    pos=nx.spring_layout(G, seed=0),
                    with_labels=False,
                    node_color=subdata.y,
                    node_size = 100
                    )
plt.show()

In [29]:
import torch
torch.manual_seed(-1)
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

GraphSAGE 모델 구현

In [30]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h) # default = mean aggregator
        self.sage2 = SAGEConv(dim_h, dim_out)

    def forward(self, x, edge_index):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(h, edge_index)
        h = F.log_softmax(h, dim=1)
        return h

    def fit(self, data, epochs):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)

        self.train()
        for epoch in range(epochs+1):
            total_loss = 0
            acc = 0
            val_loss = 0
            val_acc = 0
            # Train on batches
            for batch in train_loader:
                optimizer.zero_grad()
                out = self(batch.x, batch.edge_index)
                loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
                total_loss += loss.item()
                acc += accuracy(out[batch.train_mask].argmax(dim=1), batch.y[batch.train_mask])
                loss.backward()
                optimizer.step()

                # Validation
                val_loss += criterion(out[batch.val_mask], batch.y[batch.val_mask])
                val_acc += accuracy(out[batch.val_mask].argmax(dim=1), batch.y[batch.val_mask])
                print(f'Batch : {batch}')
            # Print metrics every 10 epochs
            #if epoch % 20 == 0:
            print(f'Epoch {epoch:>3} | Train Loss: {loss/len(train_loader):.3f} | Train Acc: {acc/len(train_loader)*100:>6.2f}% | Val Loss: {val_loss/len(train_loader):.2f} | Val Acc: {val_acc/len(train_loader)*100:.2f}%')


    @torch.no_grad()
    def test(self, data):
        self.eval()
        out = self(data.x, data.edge_index)
        acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
        return acc

def accuracy(pred_y, y):
    return ((pred_y == y).sum() / len(y)).item()

200번의 학습 (Epoch)을 통해

정확도: 77.9% 를 얻음

In [31]:
# Create GraphSAGE
graphsage = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
print(graphsage)

# Train
graphsage.fit(data, 200)

# Test
acc = graphsage.test(data)
print(f'GraphSAGE test accuracy: {acc*100:.2f}%')

GraphSAGE(
  (sage1): SAGEConv(767, 64, aggr=mean)
  (sage2): SAGEConv(64, 10, aggr=mean)
)
Batch : Data(x=[8325, 767], edge_index=[2, 19800], y=[8325], train_mask=[8325], val_mask=[8325], test_mask=[8325], n_id=[8325], e_id=[19800], input_id=[1024], batch_size=1024)
Batch : Data(x=[8413, 767], edge_index=[2, 19762], y=[8413], train_mask=[8413], val_mask=[8413], test_mask=[8413], n_id=[8413], e_id=[19762], input_id=[1024], batch_size=1024)
Batch : Data(x=[8469, 767], edge_index=[2, 19972], y=[8469], train_mask=[8469], val_mask=[8469], test_mask=[8469], n_id=[8469], e_id=[19972], input_id=[1024], batch_size=1024)
Batch : Data(x=[8461, 767], edge_index=[2, 19667], y=[8461], train_mask=[8461], val_mask=[8461], test_mask=[8461], n_id=[8461], e_id=[19667], input_id=[1024], batch_size=1024)
Batch : Data(x=[8443, 767], edge_index=[2, 19936], y=[8443], train_mask=[8443], val_mask=[8443], test_mask=[8443], n_id=[8443], e_id=[19936], input_id=[1024], batch_size=1024)
Batch : Data(x=[8413, 767], 