In [1]:
# node2vec: learn node embeddings for ogbn-products with PyG's Node2Vec and save them to disk.

import argparse

import torch
from torch_geometric.nn import Node2Vec

from ogb.nodeproppred import PygNodePropPredDataset

In [2]:
# Use the first CUDA GPU when PyTorch reports one as available; otherwise fall back to the CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

cuda:0


In [3]:
# Open ogbn-products through OGB's PyG dataset class, with 'dataset/' as its root directory.
dataset = PygNodePropPredDataset(root='dataset/', name="ogbn-products")

# Index 0 is the graph object; printing it lists the stored tensors and their shapes.
data = dataset[0]
print(data)

Data(edge_index=[2, 123718280], x=[2449029, 100], y=[2449029, 1])


In [4]:
def save_embedding(model, path='embedding.pt'):
    """Write the model's embedding matrix to disk as a CPU tensor.

    The tensor is first written to a sibling temporary file and then moved
    into place with ``os.replace``. Interrupting a save therefore does not
    leave a truncated file where the previous checkpoint used to be.

    Args:
        model: Object exposing ``embedding.weight`` as a tensor (in this
            notebook, the Node2Vec model).
        path: Destination file, as a ``str`` or path-like object. Defaults to
            ``'embedding.pt'``, the location this helper previously always
            wrote to.
    """
    import os

    path = os.fspath(path)
    tmp_path = path + '.tmp'

    # detach() drops the autograd link; cpu() moves the weights off the GPU
    # so the saved file can be loaded on any machine.
    weight = model.embedding.weight.detach().cpu()

    try:
        torch.save(weight, tmp_path)
        os.replace(tmp_path, path)
    finally:
        # After a successful replace the temporary file is already gone;
        # after a failed or interrupted save, remove the partial file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)

In [5]:
# Hyperparameters for the Node2Vec run.
embedding_dim = 128   # dimensionality of each node embedding
walk_length = 40      # length of each random walk
context_size = 20     # context window size passed to Node2Vec
walks_per_node = 10   # number of walks per node
batch_size = 256      # batch size handed to the loader
lr = 0.01             # learning rate for the optimizer
epochs = 1            # passes over the loader
log_steps = 100       # the training loop prints the loss every `log_steps` batches

# sparse=True is paired with the SparseAdam optimizer created below.
model = Node2Vec(
    data.edge_index,
    embedding_dim=embedding_dim,
    walk_length=walk_length,
    context_size=context_size,
    walks_per_node=walks_per_node,
    sparse=True,
)
model = model.to(device)

# The model supplies its own loader; the training loop unpacks each batch
# into a (pos_rw, neg_rw) pair.
loader = model.loader(batch_size=batch_size, shuffle=True, num_workers=4)

optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=lr)

In [6]:
# Put the model in training mode before optimising.
model.train()

save_steps = 100  # Save model every 100 steps.
total_steps = len(loader)

for epoch in range(1, epochs + 1):
    # Steps are counted from 1 so the log and checkpoint checks read naturally.
    for step, (pos_rw, neg_rw) in enumerate(loader, start=1):
        pos_rw = pos_rw.to(device)
        neg_rw = neg_rw.to(device)

        optimizer.zero_grad()
        loss = model.loss(pos_rw, neg_rw)
        loss.backward()
        optimizer.step()

        if step % log_steps == 0:
            print(f'Epoch: {epoch:02d}, Step: {step:03d}/{total_steps}, '
                  f'Loss: {loss:.4f}')

        if step % save_steps == 0:
            save_embedding(model)

    # Save once more after the epoch's last batch.
    save_embedding(model)

Epoch: 01, Step: 100/9567, Loss: 9.4703
Epoch: 01, Step: 200/9567, Loss: 8.5829
Epoch: 01, Step: 300/9567, Loss: 7.7455
Epoch: 01, Step: 400/9567, Loss: 6.9069
Epoch: 01, Step: 500/9567, Loss: 6.1520
Epoch: 01, Step: 600/9567, Loss: 5.4254
Epoch: 01, Step: 700/9567, Loss: 4.7844
Epoch: 01, Step: 800/9567, Loss: 4.2113
Epoch: 01, Step: 900/9567, Loss: 3.7529
Epoch: 01, Step: 1000/9567, Loss: 3.3591
Epoch: 01, Step: 1100/9567, Loss: 3.0074
Epoch: 01, Step: 1200/9567, Loss: 2.7340
Epoch: 01, Step: 1300/9567, Loss: 2.4927
Epoch: 01, Step: 1400/9567, Loss: 2.3274
Epoch: 01, Step: 1500/9567, Loss: 2.1452
Epoch: 01, Step: 1600/9567, Loss: 1.9810
Epoch: 01, Step: 1700/9567, Loss: 1.8911
Epoch: 01, Step: 1800/9567, Loss: 1.7995
Epoch: 01, Step: 1900/9567, Loss: 1.7030
Epoch: 01, Step: 2000/9567, Loss: 1.6219
Epoch: 01, Step: 2100/9567, Loss: 1.5631
Epoch: 01, Step: 2200/9567, Loss: 1.4980
Epoch: 01, Step: 2300/9567, Loss: 1.4357
Epoch: 01, Step: 2400/9567, Loss: 1.3865
Epoch: 01, Step: 2500/956

KeyboardInterrupt: 