필요한 Library

In [1]:
from torch_geometric.datasets import Reddit
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt

Reddit Dataset

In [2]:
dataset = Reddit(root='./Reddit_data')
data = dataset[0]

Downloading https://data.dgl.ai/dataset/reddit.zip
Extracting Reddit_data\raw\reddit.zip
Processing...
Done!


Graph 구조와 특징

In [3]:
# Print information about the dataset
print(f'Dataset: {dataset}')
print('-------------------')
print(f'Number of graphs: {len(dataset)}')
print(f'Number of nodes: {data.x.shape[0]}')
print(f'Number of features: {dataset.num_features}')
print(f'Number of classes: {dataset.num_classes}')

# Print information about the graph
print(f'\nGraph:')
print('------')
print(f'Training nodes: {sum(data.train_mask).item()}')
print(f'Evaluation nodes: {sum(data.val_mask).item()}')
print(f'Test nodes: {sum(data.test_mask).item()}')
print(f'Edges are directed: {data.is_directed()}')
print(f'Graph has isolated nodes: {data.has_isolated_nodes()}')
print(f'Graph has loops: {data.has_self_loops()}')

Dataset: Reddit()
-------------------
Number of graphs: 1
Number of nodes: 232965
Number of features: 602
Number of classes: 41

Graph:
------
Training nodes: 153431
Evaluation nodes: 23831
Test nodes: 55703
Edges are directed: False
Graph has isolated nodes: False
Graph has loops: False


Minibatch 생성


Minibatch란 대용량의 Graph를 나누어 학습하기 위해 나온 것으로
학습할 노드들을 랜덤하게 선택하여 하나의 Minibatch를 생성

한 batch 당 약 16000(2^14)개의 Training노드로 총 9개의 batch 생성

In [9]:
# Create batches with neighbor sampling
train_loader = NeighborLoader(
    data,
    num_neighbors=[5,5],
    batch_size=16384,
    input_nodes=data.train_mask,
)

생성된 Minibatch의 설명

In [10]:
# Print each subgraph
for i, subgraph in enumerate(train_loader):
    print(f'Subgraph {i}: {subgraph}')

Subgraph 0: Data(x=[139350, 602], edge_index=[2, 343014], y=[139350], train_mask=[139350], val_mask=[139350], test_mask=[139350], n_id=[139350], e_id=[343014], input_id=[16384], batch_size=16384)
Subgraph 1: Data(x=[138968, 602], edge_index=[2, 342314], y=[138968], train_mask=[138968], val_mask=[138968], test_mask=[138968], n_id=[138968], e_id=[342314], input_id=[16384], batch_size=16384)
Subgraph 2: Data(x=[139633, 602], edge_index=[2, 344114], y=[139633], train_mask=[139633], val_mask=[139633], test_mask=[139633], n_id=[139633], e_id=[344114], input_id=[16384], batch_size=16384)
Subgraph 3: Data(x=[139827, 602], edge_index=[2, 344004], y=[139827], train_mask=[139827], val_mask=[139827], test_mask=[139827], n_id=[139827], e_id=[344004], input_id=[16384], batch_size=16384)
Subgraph 4: Data(x=[139957, 602], edge_index=[2, 344865], y=[139957], train_mask=[139957], val_mask=[139957], test_mask=[139957], n_id=[139957], e_id=[344865], input_id=[16384], batch_size=16384)
Subgraph 5: Data(x=[

Minibatch 시각화

In [None]:
# Plot each subgraph
fig = plt.figure(figsize=(12,12))
for idx, (subdata, pos) in enumerate(zip(train_loader, [331, 332, 333, 334, 335, 336, 337, 338, 339])):
    G = to_networkx(subdata, to_undirected=True)
    ax = fig.add_subplot(pos)
    ax.set_title(f'Subgraph {idx}', fontsize=10)
    plt.axis('off')
    nx.draw_networkx(G,
                    pos=nx.spring_layout(G, seed=0),
                    with_labels=False,
                    node_color=subdata.y,
                    node_size = 100
                    )
    print("done,{pos}")
plt.show()

시간이 너무 오래 걸려서 중단

![Alt text](image-1.png)

![image.png](attachment:image.png)

In [15]:
import torch
torch.manual_seed(-1)
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv

GraphSAGE 모델 구현

In [16]:
class GraphSAGE(torch.nn.Module):
    def __init__(self, dim_in, dim_h, dim_out):
        super().__init__()
        self.sage1 = SAGEConv(dim_in, dim_h) # default = mean aggregator
        self.sage2 = SAGEConv(dim_h, dim_out)

    def forward(self, x, edge_index):
        h = self.sage1(x, edge_index)
        h = torch.relu(h)
        h = F.dropout(h, p=0.5, training=self.training)
        h = self.sage2(h, edge_index)
        h = F.log_softmax(h, dim=1)
        return h

    def fit(self, data, epochs):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)

        self.train()
        for epoch in range(epochs+1):
            total_loss = 0
            acc = 0
            val_loss = 0
            val_acc = 0
            # Train on batches
            for batch in train_loader:
                optimizer.zero_grad()
                out = self(batch.x, batch.edge_index)
                loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
                total_loss += loss.item()
                acc += accuracy(out[batch.train_mask].argmax(dim=1), batch.y[batch.train_mask])
                loss.backward()
                optimizer.step()

                # Validation
                val_loss += criterion(out[batch.val_mask], batch.y[batch.val_mask])
                val_acc += accuracy(out[batch.val_mask].argmax(dim=1), batch.y[batch.val_mask])
                print(f'Batch : {batch}')
            # Print metrics every 10 epochs
            #if epoch % 20 == 0:
            print(f'Epoch {epoch:>3} | Train Loss: {loss/len(train_loader):.3f} | Train Acc: {acc/len(train_loader)*100:>6.2f}% | Val Loss: {val_loss/len(train_loader):.2f} | Val Acc: {val_acc/len(train_loader)*100:.2f}%')


    @torch.no_grad()
    def test(self, data):
        self.eval()
        out = self(data.x, data.edge_index)
        acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
        return acc

def accuracy(pred_y, y):
    return ((pred_y == y).sum() / len(y)).item()

200번의 학습 (Epoch)을 통해

Epoch 200 | Train Loss: 0.084 | Train Acc:  79.71% | Val Loss: 0.99 | Val Acc: 73.38%

GraphSAGE test accuracy: 95.79%

정확도: 95.79% 를 얻음

In [None]:
# Create GraphSAGE
graphsage = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
print(graphsage)

# Train
graphsage.fit(data, 200)

# Test
acc = graphsage.test(data)
print(f'GraphSAGE test accuracy: {acc*100:.2f}%')

![Alt text](image.png)