<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Graph_Neural_Networks_(GNNs).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
pip install torch_geometric

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader
from torch_geometric.transforms import NormalizeFeatures

# Load Cora from the Planetoid collection (rooted at /tmp/Cora), applying
# the NormalizeFeatures transform to the data.
dataset = Planetoid(
    root='/tmp/Cora',
    name='Cora',
    transform=NormalizeFeatures(),
)

# Define the GCN model
class GCN(torch.nn.Module):
    """Two-layer graph convolutional network built from ``GCNConv`` layers.

    Args:
        num_features: Input width of the first convolution.
        num_classes: Output width of the second convolution.
        hidden_channels: Width of the hidden representation between the two
            convolutions. Defaults to 16, the value that used to be
            hard-coded, so existing two-argument callers are unaffected.
    """

    def __init__(self, num_features, num_classes, hidden_channels=16):
        super().__init__()
        self.conv1 = GCNConv(num_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, num_classes)

    def forward(self, data):
        """Compute log-probabilities for the graph held in ``data``.

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes, which
                are passed to both convolutions.

        Returns:
            ``log_softmax`` (over dimension 1) of the second convolution's
            output.
        """
        x, edge_index = data.x, data.edge_index
        x = F.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

# Build the model and its Adam optimizer. No separate loss object is created
# here; the NLL loss is computed inline during training.
model = GCN(
    num_features=dataset.num_features,
    num_classes=dataset.num_classes,
)
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=5e-4,
)

# Training step: one optimization pass (the epoch loop is further below)
def train():
    """Run one optimization step and return the training loss.

    Operates on the module-level ``model``, ``optimizer`` and ``dataset``.

    Returns:
        float: NLL loss over the entries selected by ``train_mask``.
    """
    # Index the dataset once per step: each ``dataset[0]`` lookup goes through
    # the dataset's indexing (and its transform) again, so repeating it three
    # times per step was pure overhead.
    data = dataset[0]

    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = F.nll_loss(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

# Test function
def test():
    """Evaluate the module-level ``model`` on the train/val/test masks.

    Returns:
        list[float]: Accuracies in the fixed order
        ``[train, validation, test]``.
    """
    # Fetch the data object once instead of re-indexing the dataset on every
    # loop iteration.
    data = dataset[0]

    model.eval()
    # Evaluation needs no gradients; skipping autograd bookkeeping saves
    # memory and time.
    with torch.no_grad():
        pred = model(data).argmax(dim=1)

    accs = []
    # Read the masks explicitly so the result order never depends on the
    # iteration order of the data object's attributes.
    for mask in (data.train_mask, data.val_mask, data.test_mask):
        correct = (pred[mask] == data.y[mask]).sum().item()
        accs.append(correct / mask.sum().item())
    return accs

# Run the training process
for epoch in range(200):
    loss = train()
    # Evaluate and report only on every tenth epoch (0, 10, ..., 190).
    if epoch % 10 != 0:
        continue
    accs = test()
    print(f'Epoch: {epoch:03d}, Loss: {loss:.4f}, Train: {accs[0]:.4f}, Val: {accs[1]:.4f}, Test: {accs[2]:.4f}')

# Report the test-split accuracy once training has finished
accs = test()
print(f'Final Test Accuracy: {accs[2]:.4f}')