In [1]:
!pip install torch_geometric --quiet

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/661.6 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m194.6/661.6 kB[0m [31m5.7 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m655.4/661.6 kB[0m [31m10.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m8.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Building wheel for torch_geometric (pyproject.toml) ... [?25l[?25hdone


In [2]:
import torch_geometric
from torch_geometric.datasets import Planetoid
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import GATConv

In [3]:
dataset = Planetoid(root='tutorial2', name='Cora')

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


In [4]:
print(dataset)
print("Number of graphs", len(dataset))
print("Number of classes", dataset.num_classes)
print("Number of node features", dataset.num_node_features)
print("Number of edge features", dataset.num_edge_features)

Cora()
Number of graphs 1
Number of classes 7
Number of node features 1433
Number of edge features 0


In [5]:
print(dataset.data)

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])




In [6]:
torch.manual_seed(123)

<torch._C.Generator at 0x79ca7a6eccb0>

In [7]:
class Net(nn.Module):
  def __init__(self, in_dim, n_heads, hidden_dim, out_dim, dropout=0.2):
    super().__init__()
    self.dropout = dropout
    self.conv1 = GATConv(in_dim, hidden_dim, n_heads, dropout=dropout)
    self.conv2 = GATConv(hidden_dim*n_heads, hidden_dim, n_heads, dropout=dropout)
    self.out = nn.Linear(hidden_dim*n_heads, out_dim)

  def forward(self, data):
    hidden = F.relu(self.conv1(data.x, data.edge_index))
    hidden = F.relu(self.conv2(hidden, data.edge_index))
    out = self.out(hidden)
    return out


In [8]:
data = dataset[0]
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(dataset.num_node_features, 8, 128, dataset.num_classes).to(device)
data = data.to(device)
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
loss_fn = nn.CrossEntropyLoss()
epochs = 100

In [9]:
def train(data, opt):
  opt.zero_grad()
  out = model(data)[data.train_mask]
  loss = loss_fn(out, data.y[data.train_mask])
  loss.backward()
  opt.step()

def test(data):
  model.eval()
  with torch.no_grad():
    y_pred = model(data)[data.test_mask]
    loss = loss_fn(y_pred, data.y[data.test_mask])
    y_pred = F.softmax(y_pred, dim=-1)
    y_pred = y_pred.argmax(dim=-1)
    test_acc = (y_pred == data.y[data.test_mask]).sum()/len(data.y[data.test_mask])
    return loss, test_acc


In [10]:
for epoch in range(epochs):
  train(data, opt)
  loss, test_acc = test(data)
  if epoch %5 ==0 :
    print(f"Epoch: {epoch} | Loss: {loss.item():3f} | Test accuracy: {test_acc.item():3f}")

Epoch: 0 | Loss: 1.934225 | Test accuracy: 0.124000
Epoch: 5 | Loss: 1.857206 | Test accuracy: 0.758000
Epoch: 10 | Loss: 1.732530 | Test accuracy: 0.810000
Epoch: 15 | Loss: 1.530774 | Test accuracy: 0.808000
Epoch: 20 | Loss: 1.249939 | Test accuracy: 0.807000
Epoch: 25 | Loss: 0.942373 | Test accuracy: 0.805000
Epoch: 30 | Loss: 0.713364 | Test accuracy: 0.808000
Epoch: 35 | Loss: 0.608776 | Test accuracy: 0.809000
Epoch: 40 | Loss: 0.587530 | Test accuracy: 0.803000
Epoch: 45 | Loss: 0.602117 | Test accuracy: 0.800000
Epoch: 50 | Loss: 0.627292 | Test accuracy: 0.802000
Epoch: 55 | Loss: 0.653813 | Test accuracy: 0.803000
Epoch: 60 | Loss: 0.678684 | Test accuracy: 0.805000
Epoch: 65 | Loss: 0.699459 | Test accuracy: 0.806000
Epoch: 70 | Loss: 0.715332 | Test accuracy: 0.806000
Epoch: 75 | Loss: 0.726906 | Test accuracy: 0.805000
Epoch: 80 | Loss: 0.735326 | Test accuracy: 0.804000
Epoch: 85 | Loss: 0.741631 | Test accuracy: 0.805000
Epoch: 90 | Loss: 0.746503 | Test accuracy: 0.80