In [1]:
!pip install torch_geometric > /dev/null

In [2]:
class Args:
  """Plain container for the notebook's paths and hyperparameters.

  All values are fixed defaults set in ``__init__``; ``num_features`` and
  ``num_classes`` start at 0 and are meant to be filled in later (see
  ``add_data_features``).

  The dropout value used to be stored under the misspelled name ``droput``.
  It now lives in ``dropout``; ``droput`` is kept as a read/write alias so
  existing code that uses the old spelling keeps working.
  """

  def __init__(self):
    # Filesystem locations (Colab-style absolute paths).
    self.root_dir = "/content"
    self.data_dir = "/content/data"
    # Training schedule.
    self.epochs = 300
    self.runs = 5
    # Regularisation / optimisation settings.
    self.dropout = 0.4
    self.lr = 0.001  # presumably the learning rate -- confirm against the consumer
    self.wd = 0.001  # presumably the weight decay -- confirm against the consumer
    # Model size.
    self.num_layers = 2
    self.num_hidden = 256
    # Dataset-dependent values; 0 means "not filled in yet".
    self.num_features = 0 # placeholder
    self.num_classes = 0 # placeholder

  @property
  def droput(self):
    """Backward-compatible alias for the correctly spelled ``dropout``."""
    return self.dropout

  @droput.setter
  def droput(self, value):
    self.dropout = value

def add_data_features(args, data):
  """Fill the dataset-dependent fields of ``args`` from ``data``.

  Args:
    args: object with writable ``num_features`` / ``num_classes`` attributes.
    data: object exposing ``x`` (at least 2-D; size of dim 1 is taken as the
      feature count) and ``y`` (labels, array/tensor-like with ``.shape``
      and ``.max()``).

  Returns:
    The same ``args`` object, mutated in place.

  ``num_classes`` is derived from the labels themselves rather than from
  ``y.shape[0]``, which is the number of labelled rows, not of classes:
    * multi-column ``y`` (last dim > 1): one column per class is assumed;
    * empty ``y``: 0;
    * otherwise: ``max(y) + 1`` -- assumes integer class ids starting at 0
      (TODO confirm for the dataset in use).
  """
  args.num_features = data.x.shape[1]

  labels = data.y
  label_shape = tuple(labels.shape)
  if len(label_shape) > 1 and label_shape[-1] > 1:
    args.num_classes = label_shape[-1]
  elif 0 in label_shape:
    # No labels at all: max() would raise on an empty array/tensor.
    args.num_classes = 0
  else:
    args.num_classes = int(labels.max()) + 1
  return args

In [3]:
from torch_geometric.datasets import AMiner

# Default settings object (class defined in the previous cell).
args = Args()
# Build the AMiner dataset under args.root_dir. The captured output below
# shows this run downloading, extracting and processing the raw files.
dataset = AMiner(root=args.root_dir)
# Show the first dataset entry; the captured output is a HeteroData summary
# with author / venue / paper node types and four edge types.
print(dataset[0])

Downloading https://www.dropbox.com/s/1bnz8r7mofx0osf/net_aminer.zip?dl=1
Extracting /content/net_aminer.zip
Downloading https://www.dropbox.com/s/nkocx16rpl4ydde/label.zip?dl=1
Extracting /content/raw/label.zip
Processing...
Done!


HeteroData(
  author={
    y=[246678],
    y_index=[246678],
    num_nodes=1693531,
  },
  venue={
    y=[134],
    y_index=[134],
    num_nodes=3883,
  },
  paper={ num_nodes=3194405 },
  (paper, written_by, author)={ edge_index=[2, 9323605] },
  (author, writes, paper)={ edge_index=[2, 9323605] },
  (paper, published_in, venue)={ edge_index=[2, 3194405] },
  (venue, publishes, paper)={ edge_index=[2, 3194405] }
)


In [5]:
import torch
from torch_geometric.nn import MetaPath2Vec
# Work on the first entry of the dataset loaded in the previous cell.
data = dataset[0]

# The metapath as a chain of (source, relation, target) edge types:
# author -> paper -> venue -> paper -> author.
metapath = [
    ('author', 'writes', 'paper'),
    ('paper', 'published_in', 'venue'),
    ('venue', 'publishes', 'paper'),
    ('paper', 'written_by', 'author'),
]

# Prefer the GPU when one is available, and tell the user which device is used.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"DEVICE USED IN THIS SESSION: {device}")

# MetaPath2Vec embedding model over the graph's edge types.
model = MetaPath2Vec(
    data.edge_index_dict,
    embedding_dim=128,
    metapath=metapath,
    walk_length=50,
    context_size=7,
    walks_per_node=5,
    num_negative_samples=5,
    sparse=True,  # paired with the SparseAdam optimizer created below
)
model = model.to(device)

# Batches of (positive, negative) random walks for training.
loader = model.loader(
    batch_size=128,
    shuffle=True,
    num_workers=6,
)

optimizer = torch.optim.SparseAdam(
    list(model.parameters()),
    lr=0.01,
)


def train(epoch, log_steps=100, eval_steps=2000):
    """Run one pass over ``loader``, logging loss and periodically evaluating.

    Relies on the module-level ``model``, ``loader``, ``optimizer``,
    ``device`` and ``test`` defined in this cell.

    Args:
        epoch: epoch number; only used in the printed log lines.
        log_steps: print the mean loss of the last ``log_steps`` batches
            every ``log_steps`` batches.
        eval_steps: call ``test()`` and print its result every
            ``eval_steps`` batches.

    Returns:
        None. Progress is reported through ``print``.
    """
    model.train()

    total_loss = 0
    for i, (pos_rw, neg_rw) in enumerate(loader):
        optimizer.zero_grad()
        loss = model.loss(pos_rw.to(device), neg_rw.to(device))
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        if (i + 1) % log_steps == 0:
            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '
                   f'Loss: {total_loss / log_steps:.4f}'))
            # Start a fresh window so each log line covers only the last
            # ``log_steps`` batches.
            total_loss = 0

        if (i + 1) % eval_steps == 0:
            acc = test()
            # test() switches the model to eval mode; switch back so the
            # remaining batches of this epoch are trained in train mode.
            model.train()
            print((f'Epoch: {epoch}, Step: {i + 1:05d}/{len(loader)}, '
                   f'Acc: {acc:.4f}'))


def test(train_ratio=0.1):
    """Evaluate the current author embeddings on the labelled authors.

    Relies on the module-level ``model``, ``data`` and ``device``. The
    labelled authors are shuffled with a fresh ``torch.randperm`` on every
    call (no seed is set here), so the split differs between calls.

    Args:
        train_ratio: fraction of the labelled authors passed to
            ``model.test`` as the training part; the rest is the test part.

    Returns:
        Whatever ``model.test`` returns; callers print it as an accuracy.
    """
    model.eval()

    # Evaluation only: skip building an autograd graph for the embedding
    # lookup, which would otherwise be kept alive by ``z``.
    with torch.no_grad():
        z = model('author', batch=data['author'].y_index.to(device))
    y = data['author'].y

    perm = torch.randperm(z.size(0))
    split = int(z.size(0) * train_ratio)
    train_perm = perm[:split]
    test_perm = perm[split:]

    return model.test(z[train_perm], y[train_perm], z[test_perm], y[test_perm],
                      max_iter=150)


# Driver loop: range(1, 5) yields epochs 1 through 4. The epoch count is
# hard-coded here; Args.epochs is not consulted.
for epoch in range(1, 5):
  # One pass over the loader (with its own periodic loss/accuracy logging).
  train(epoch)
  # End-of-epoch evaluation on a fresh random split of the labelled authors.
  test_acc = test()
  print(f"Epoch: {epoch} | Test Accuracy: {test_acc:.4f} ")

DEVICE USED IN THIS SESSION: cuda
Epoch: 1, Step: 00100/13231, Loss: 9.1106
Epoch: 1, Step: 00200/13231, Loss: 7.5786
Epoch: 1, Step: 00300/13231, Loss: 6.4659
Epoch: 1, Step: 00400/13231, Loss: 5.8515
Epoch: 1, Step: 00500/13231, Loss: 5.5902
Epoch: 1, Step: 00600/13231, Loss: 5.4337
Epoch: 1, Step: 00700/13231, Loss: 5.2960
Epoch: 1, Step: 00800/13231, Loss: 5.1691
Epoch: 1, Step: 00900/13231, Loss: 5.0432
Epoch: 1, Step: 01000/13231, Loss: 4.9267
Epoch: 1, Step: 01100/13231, Loss: 4.8128
Epoch: 1, Step: 01200/13231, Loss: 4.7019
Epoch: 1, Step: 01300/13231, Loss: 4.5944
Epoch: 1, Step: 01400/13231, Loss: 4.4930
Epoch: 1, Step: 01500/13231, Loss: 4.3938
Epoch: 1, Step: 01600/13231, Loss: 4.2975
Epoch: 1, Step: 01700/13231, Loss: 4.2022
Epoch: 1, Step: 01800/13231, Loss: 4.1149
Epoch: 1, Step: 01900/13231, Loss: 4.0240
Epoch: 1, Step: 02000/13231, Loss: 3.9382
Epoch: 1, Step: 02000/13231, Acc: 0.2858
Epoch: 1, Step: 02100/13231, Loss: 3.8537
Epoch: 1, Step: 02200/13231, Loss: 3.7717
E