In [1]:
!pip install torch torch_geometric

Collecting torch_geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/63.1 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m3.5 MB/s[0m eta [36m0:00:00[0m
Downloading torch_geometric-2.6.1-py3-none-any.whl (1.1 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.1/1.1 MB[0m [31m34.6 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torch_geometric
Successfully installed torch_geometric-2.6.1


In [10]:
import torch
from torch_geometric.datasets import CoraFull
from torch_geometric.data import DataLoader
dataset = CoraFull(root="cora")
data = dataset[0]

In [5]:
class MeanAggregator(torch.nn.Module):
    def __init__(self, in_features, out_features, neigh_input_dim):
        super(MeanAggregator, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.neigh_input_dim = neigh_input_dim

        self.w = torch.nn.Parameter(torch.empty(neigh_input_dim + in_features, out_features))
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.w)

    def forward(self, x, sampled_neighbors):
        neighbors_message = torch.mean(sampled_neighbors, dim=0)
        message = torch.concat((x, neighbors_message), 0)
        weighted_message = torch.matmul(message, self.w)
        return weighted_message

In [9]:
class GraphSAGE(torch.nn.Module):
  def __init__(self):
    super(GraphSAGE, self).__init__()
    self.aggr1 = MeanAggregator(8710, 5000, 8710)
    self.aggr2 = MeanAggregator(5000, 1000, 8710)
    self.aggr3 = MeanAggregator(1000, 70, 8710)

  def forward(self, input):
    x, neighbors = input
    h = torch.nn.functional.relu(self.aggr1(x, neighbors))
    h = torch.nn.functional.relu(self.aggr2(h, neighbors))
    logits = self.aggr3(h, neighbors)
    return logits

IndentationError: expected an indented block after function definition on line 8 (<ipython-input-9-315a5666b769>, line 9)

In [7]:
model = GraphSAGE()
x = data.x[0]
mask = data.edge_index[0] == 0
node_neighbors = torch.stack([data.x[i] for i in data.edge_index[1][mask]])

y = data.y[0]
model((x, node_neighbors))
# print(model.aggr2.w)

tensor([-0.0450,  0.0170,  0.0062, -0.0871,  0.0332,  0.1264, -0.0984, -0.0367,
         0.0455, -0.1104, -0.0043, -0.0519, -0.0020,  0.0112, -0.0288, -0.0112,
         0.0553,  0.0520, -0.0274, -0.0068,  0.0446,  0.0610, -0.0944,  0.0494,
        -0.0100, -0.0078, -0.0226,  0.0992, -0.0707,  0.0604,  0.0511, -0.0168,
         0.0122,  0.0554,  0.0358,  0.0296,  0.0853,  0.0031, -0.0239,  0.0617,
        -0.0314, -0.0271,  0.0633,  0.0466,  0.0431,  0.0557,  0.0735, -0.0234,
         0.0585, -0.0052, -0.0531, -0.0743,  0.0513, -0.0027,  0.0340,  0.0230,
        -0.0544,  0.0193, -0.0364,  0.0120, -0.0284, -0.0634,  0.0141, -0.1021,
         0.0586,  0.0361,  0.0428, -0.0134, -0.0504, -0.0847],
       grad_fn=<SqueezeBackward4>)

In [8]:
model = GraphSAGE()
loss_fn= torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)

epochs = 1
for epoch in range(100):
  for i in range(len(data.x[0])):
    optimizer.zero_grad()
    x = data.x[i]
    mask = data.edge_index[0] == i
    node_neighbors = torch.stack([data.x[i] for i in data.edge_index[1][mask]])

    # node_neighbors = torch.stack([data.x[v] for u, v in zip(data.edge_index[0], data.edge_index[1]) if data.edge_index[0][u] == i])
    y = data.y[i]
    y = y.unsqueeze(0)
    pred = model((x, node_neighbors))
    pred = pred.unsqueeze(0)
    loss = loss_fn(pred, y)
    loss.backward()
    optimizer.step()
    print(loss)

tensor(4.2856, grad_fn=<NllLossBackward0>)
tensor(4.2110, grad_fn=<NllLossBackward0>)
tensor(4.0313, grad_fn=<NllLossBackward0>)
tensor(4.0094, grad_fn=<NllLossBackward0>)
tensor(3.6952, grad_fn=<NllLossBackward0>)
tensor(3.5298, grad_fn=<NllLossBackward0>)
tensor(3.8693, grad_fn=<NllLossBackward0>)
tensor(3.6966, grad_fn=<NllLossBackward0>)
tensor(3.7482, grad_fn=<NllLossBackward0>)
tensor(3.6847, grad_fn=<NllLossBackward0>)
tensor(3.4109, grad_fn=<NllLossBackward0>)
tensor(3.3922, grad_fn=<NllLossBackward0>)
tensor(4.3490, grad_fn=<NllLossBackward0>)
tensor(3.1704, grad_fn=<NllLossBackward0>)
tensor(2.9795, grad_fn=<NllLossBackward0>)
tensor(2.8110, grad_fn=<NllLossBackward0>)
tensor(2.6908, grad_fn=<NllLossBackward0>)
tensor(2.3116, grad_fn=<NllLossBackward0>)
tensor(2.1776, grad_fn=<NllLossBackward0>)
tensor(4.2774, grad_fn=<NllLossBackward0>)
tensor(2.0424, grad_fn=<NllLossBackward0>)


KeyboardInterrupt: 