In [62]:
import os.path as osp

import torch
import torch.nn.functional as F

from torch_geometric.data import Data
from torch_geometric.datasets import Planetoid
from torch_geometric.nn import SGConv

In [52]:
dataset = 'Cora'
path = '/Users/aqib/desktop/bccl/rna_tn_fusion'

In [53]:
dataset = Planetoid(path, dataset)
data = dataset[0]
len(dataset), dataset.num_classes, dataset.num_node_features

(1, 7, 1433)

In [61]:
data

Data(x=[2708, 1433], edge_index=[2, 10556], y=[2708], train_mask=[2708], val_mask=[2708], test_mask=[2708])

In [55]:
data.is_undirected(), data.train_mask.sum().item(), data.val_mask.sum().item(), data.test_mask.sum().item()

(True, 140, 500, 1000)

In [59]:
data.edge_index

tensor([[   0,    0,    0,  ..., 2707, 2707, 2707],
        [ 633, 1862, 2582,  ...,  598, 1473, 2706]])

In [23]:
class Net(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = SGConv(dataset.num_features, dataset.num_classes, K=2,
                            cached=True)

    def forward(self):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        return F.log_softmax(x, dim=1)

In [24]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model, data = Net().to(device), data.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.2, weight_decay=0.005)

In [29]:
def train():
    model.train()
    optimizer.zero_grad()
    F.nll_loss(model()[data.train_mask], data.y[data.train_mask]).backward()
    optimizer.step()

In [35]:
def test():
    model.eval()
    logits, accs = model(), []
    for _, mask in data('train_mask', 'val_mask', 'test_mask'):
        pred = logits[mask].max(1)[1]
        acc = pred.eq(data.y[mask]).sum().item() / mask.sum().item()
        accs.append(acc)
    return accs

In [39]:
best_val_acc = test_acc = 0
for epoch in range(1, 101):
    train()
    train_acc, val_acc, tmp_test_acc = test()
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        test_acc = tmp_test_acc
    print(f'Epoch: {epoch:03d}, Train: {train_acc:.4f}, '
          f'Val: {best_val_acc:.4f}, Test: {test_acc:.4f}')

Epoch: 001, Train: 0.9643, Val: 0.7200, Test: 0.7460
Epoch: 002, Train: 0.9857, Val: 0.7700, Test: 0.7810
Epoch: 003, Train: 0.9929, Val: 0.7780, Test: 0.7840
Epoch: 004, Train: 0.9929, Val: 0.7780, Test: 0.7840
Epoch: 005, Train: 0.9929, Val: 0.7780, Test: 0.7840
Epoch: 006, Train: 1.0000, Val: 0.7780, Test: 0.7840
Epoch: 007, Train: 1.0000, Val: 0.7780, Test: 0.7840
Epoch: 008, Train: 1.0000, Val: 0.7780, Test: 0.7840
Epoch: 009, Train: 0.9929, Val: 0.7780, Test: 0.7840
Epoch: 010, Train: 0.9929, Val: 0.7780, Test: 0.7840
Epoch: 011, Train: 0.9929, Val: 0.7780, Test: 0.7840
Epoch: 012, Train: 0.9929, Val: 0.7780, Test: 0.7840
Epoch: 013, Train: 0.9929, Val: 0.7780, Test: 0.7840
Epoch: 014, Train: 1.0000, Val: 0.7820, Test: 0.8170
Epoch: 015, Train: 1.0000, Val: 0.7820, Test: 0.8170
Epoch: 016, Train: 1.0000, Val: 0.7820, Test: 0.8170
Epoch: 017, Train: 1.0000, Val: 0.7820, Test: 0.8170
Epoch: 018, Train: 1.0000, Val: 0.7820, Test: 0.8170
Epoch: 019, Train: 1.0000, Val: 0.7860, Test: 