In [4]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data

# Define the GNN model
class GNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(input_dim, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, output_dim)

    def forward(self, x, edge_index):
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        print("x before log softmax", x)
        return F.log_softmax(x, dim=1)

# Create a toy graph
# Number of nodes: 4
# Edges: (0 -> 1), (1 -> 2), (2 -> 3), (3 -> 0)
edge_index = torch.tensor([
    [0, 1, 2, 3],
    [1, 2, 3, 0]
], dtype=torch.long)

# Node features (4 nodes, each with 3 features)
x = torch.tensor([
    [1.0, 0.0, 1.0],
    [0.0, 1.0, 1.0],
    [1.0, 1.0, 0.0],
    [0.0, 0.0, 1.0]
], dtype=torch.float)

# Labels (binary classification)
y = torch.tensor([0, 1, 0, 1], dtype=torch.long)

# Define the dataset
data = Data(x=x, edge_index=edge_index, y=y)

# Initialize the model, optimizer, and loss function
model = GNN(input_dim=3, hidden_dim=16, output_dim=2)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss_fn = torch.nn.CrossEntropyLoss()

# Training loop
for epoch in range(50):
    model.train()
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    print("out", out, "y", data.y)
    loss = loss_fn(out, data.y)
    loss.backward()
    optimizer.step()
    print(f"Epoch {epoch+1}, Loss: {loss.item():.4f}")

# Testing
model.eval()
_, pred = model(data.x, data.edge_index).max(dim=1)
accuracy = (pred == data.y).sum().item() / data.y.size(0)
print(f"Accuracy: {accuracy:.4f}")


x before log softmax tensor([[-0.5952,  0.7045],
        [-0.7346,  0.8649],
        [-0.6142,  0.6207],
        [-0.4748,  0.4603]], grad_fn=<AddBackward0>)
out tensor([[-1.5408, -0.2411],
        [-1.7834, -0.1840],
        [-1.4902, -0.2553],
        [-1.2663, -0.3311]], grad_fn=<LogSoftmaxBackward0>) y tensor([0, 1, 0, 1])
Epoch 1, Loss: 0.8865
x before log softmax tensor([[-0.5281,  0.5923],
        [-0.6581,  0.7390],
        [-0.5389,  0.4968],
        [-0.4089,  0.3500]], grad_fn=<AddBackward0>)
out tensor([[-1.4027, -0.2823],
        [-1.6181, -0.2210],
        [-1.3395, -0.3038],
        [-1.1430, -0.3840]], grad_fn=<LogSoftmaxBackward0>) y tensor([0, 1, 0, 1])
Epoch 2, Loss: 0.8368
x before log softmax tensor([[-0.4642,  0.4868],
        [-0.5849,  0.6203],
        [-0.4659,  0.3773],
        [-0.3452,  0.2438]], grad_fn=<AddBackward0>)
out tensor([[-1.2777, -0.3267],
        [-1.4672, -0.2621],
        [-1.2011, -0.3579],
        [-1.0304, -0.4414]], grad_fn=<LogSoftmaxBack