In [1]:
!pip install torch-geometric \
  torch-sparse \
  torch-scatter \
  torch-cluster \
  torch-cluster \
  pyg-lib \
  -f https://data.pyg.org/whl/torch-2.4.0+cu124.html

Looking in links: https://data.pyg.org/whl/torch-2.4.0+cu124.html
Collecting torch-geometric
  Downloading torch_geometric-2.6.1-py3-none-any.whl.metadata (63 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m63.1/63.1 kB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-sparse
  Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu124/torch_sparse-0.6.18%2Bpt24cu124-cp310-cp310-linux_x86_64.whl (5.1 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.1/5.1 MB[0m [31m19.7 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-scatter
  Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu124/torch_scatter-2.1.2%2Bpt24cu124-cp310-cp310-linux_x86_64.whl (10.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.7/10.7 MB[0m [31m17.5 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting torch-cluster
  Downloading https://data.pyg.org/whl/torch-2.4.0%2Bcu124/torch_cluster-1.6.3%2Bpt24cu124-cp310-cp310-linux_x86_64.whl (3.4 M

In [2]:
import networkx as nx
import random
import numpy as np
from typing import List
from tqdm import tqdm
from gensim.models.word2vec import Word2Vec


class DeepWalk:
    """
    DeepWalk-style node embeddings: sample random walks over a graph and train a
    Word2Vec model on them, treating every walk as a sentence and every node as a word.
    """

    def __init__(self, window_size: int, embedding_size: int, walk_length: int, walks_per_node: int):
        """
        :param window_size: window size for the Word2Vec model
        :param embedding_size: size of the final embedding
        :param walk_length: number of steps taken by each walk (a complete walk holds walk_length + 1 nodes)
        :param walks_per_node: number of walks started from every node
        """
        self.window_size = window_size
        self.embedding_size = embedding_size
        self.walk_length = walk_length
        # Stored under the singular name `walk_per_node`; kept as-is for backward compatibility.
        self.walk_per_node = walks_per_node

    def random_walk(self, g: "nx.Graph", start: str, use_probabilities: bool = False) -> List[str]:
        """
        Generate a random walk starting on start
        :param g: Graph
        :param start: starting node for the random walk
        :param use_probabilities: if True, the next node is drawn proportionally to the "weight"
            attribute of the edges leaving the current node; otherwise uniformly
        :return: the visited nodes in order, start included. The walk holds walk_length + 1 nodes
            unless it reaches a node without neighbours, in which case it stops early.
        :raises KeyError: if use_probabilities is True and an edge has no "weight" attribute
        """
        walk = [start]
        for _ in range(self.walk_length):
            current = walk[-1]
            neighs = list(g.neighbors(current))
            if not neighs:
                # Dead end (isolated node, or a sink in a directed graph): nothing to sample
                # from, so return the walk collected so far instead of raising.
                break
            if use_probabilities:
                weights = [g.get_edge_data(current, neig)["weight"] for neig in neighs]
                total_weight = sum(weights)
                probabilities = [weight / total_weight for weight in weights]
                # Sample an index rather than the node itself: handing the node list to
                # np.random.choice converts it to an array, which changes the node type
                # (e.g. str -> np.str_) and fails for tuple nodes.
                next_node = neighs[np.random.choice(len(neighs), p=probabilities)]
            else:
                next_node = random.choice(neighs)
            walk.append(next_node)
        return walk

    def get_walks(self, g: "nx.Graph", use_probabilities: bool = False) -> List[List[str]]:
        """
        Generate all the random walks
        :param g: Graph
        :param use_probabilities: forwarded to random_walk (weighted vs. uniform neighbour choice)
        :return: list of walks; every node is used as a starting point walks_per_node times
        """
        random_walks = []
        for _ in range(self.walk_per_node):
            # Visit the starting nodes in a fresh random order on every pass.
            random_nodes = list(g.nodes)
            random.shuffle(random_nodes)
            for node in tqdm(random_nodes):
                random_walks.append(self.random_walk(g=g, start=node, use_probabilities=use_probabilities))
        return random_walks

    def compute_embeddings(self, walks: List[List[str]]):
        """
        Compute the node embeddings for the generated walks
        :param walks: List of walks
        :return: the keyed vectors (``model.wv``) of the trained Word2Vec model
        """
        # min_count=1 keeps every node that appears in at least one walk. The gensim default
        # (min_count=5) silently drops rarely visited nodes, leaving them without an embedding.
        model = Word2Vec(
            sentences=walks,
            window=self.window_size,
            vector_size=self.embedding_size,
            min_count=1,
        )
        return model.wv

In [3]:
from torch_geometric.nn import Node2Vec
import os.path as osp
import torch
from torch_geometric.datasets import Planetoid
from tqdm.notebook import tqdm

# Download the Cora dataset (or load it from ./data/Cora if already present).
# `dataset` first holds the dataset name and is then rebound to the dataset object.
dataset = 'Cora'
path = osp.join('.', 'data', dataset)
dataset = Planetoid(root=path, name=dataset)
data = dataset[0]

# Use the GPU when CUDA is available, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Node2Vec model built on the graph's edge index; with p = q = 1 the walks are unbiased.
model = Node2Vec(
    edge_index=data.edge_index,
    embedding_dim=128,
    walk_length=20,
    context_size=10,
    walks_per_node=10,
    num_negative_samples=1,
    p=1,
    q=1,
    sparse=True,
)
model = model.to(device)

# Loader yielding batches of positive / negative random walks.
loader = model.loader(batch_size=128, shuffle=True, num_workers=4)
# The model is created with sparse=True, so it is paired with SparseAdam here.
optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=0.01)


def train():
    """
    Run one training epoch over the module-level ``loader``.

    Uses the module-level ``model``, ``optimizer`` and ``device``.

    :return: mean loss per batch for the epoch
    """
    model.train()  # switch the model to training mode
    running_loss = 0
    for walks_pos, walks_neg in tqdm(loader):
        optimizer.zero_grad()  # reset gradients left over from the previous batch
        batch_loss = model.loss(walks_pos.to(device), walks_neg.to(device))
        batch_loss.backward()
        optimizer.step()  # apply the parameter update
        running_loss += batch_loss.item()
    batch_count = len(loader)
    return running_loss / batch_count


# Train for 99 epochs (range(1, 100)) and report the mean batch loss of each one.
for epoch in range(1, 100):
    loss = train()
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}')

# Fetch the embedding of every node with a single forward pass and a single
# device-to-host copy; no gradients are needed for the export.
with torch.no_grad():
    embeddings = model(torch.arange(data.num_nodes, device=device)).cpu().numpy()

# One node per line (row i is node i), components separated by tabs, each line
# terminated by a newline. Built with join instead of repeated string `+=`.
all_vectors = "".join(
    "\t".join(str(value) for value in row) + "\n" for row in embeddings
)
# save the vectors
with open("vectors.txt", "w") as f:
    f.write(all_vectors)
# save the labels (same node order as the vectors, one label per line)
with open("labels.txt", "w") as f:
    f.write("\n".join(str(label) for label in data.y.numpy()))

Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 01, Loss: 8.1545


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 02, Loss: 6.0943


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 03, Loss: 4.9933


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 04, Loss: 4.1718


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 05, Loss: 3.5144


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 06, Loss: 2.9815


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 07, Loss: 2.5705


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 08, Loss: 2.2333


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 09, Loss: 1.9683


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 10, Loss: 1.7495


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 11, Loss: 1.5808


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 12, Loss: 1.4407


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 13, Loss: 1.3262


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 14, Loss: 1.2392


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 15, Loss: 1.1672


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 16, Loss: 1.1105


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 17, Loss: 1.0625


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 18, Loss: 1.0258


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 19, Loss: 0.9984


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 20, Loss: 0.9733


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 21, Loss: 0.9530


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 22, Loss: 0.9391


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 23, Loss: 0.9245


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 24, Loss: 0.9124


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 25, Loss: 0.9022


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 26, Loss: 0.8936


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 27, Loss: 0.8874


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 28, Loss: 0.8810


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 29, Loss: 0.8759


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 30, Loss: 0.8714


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 31, Loss: 0.8680


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 32, Loss: 0.8619


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 33, Loss: 0.8607


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 34, Loss: 0.8565


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 35, Loss: 0.8547


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 36, Loss: 0.8510


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 37, Loss: 0.8505


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 38, Loss: 0.8481


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 39, Loss: 0.8448


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 40, Loss: 0.8443


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 41, Loss: 0.8440


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 42, Loss: 0.8414


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 43, Loss: 0.8400


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 44, Loss: 0.8389


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 45, Loss: 0.8377


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 46, Loss: 0.8374


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 47, Loss: 0.8374


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 48, Loss: 0.8365


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 49, Loss: 0.8362


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 50, Loss: 0.8345


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 51, Loss: 0.8328


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 52, Loss: 0.8317


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 53, Loss: 0.8316


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 54, Loss: 0.8325


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 55, Loss: 0.8305


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 56, Loss: 0.8302


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 57, Loss: 0.8298


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 58, Loss: 0.8297


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 59, Loss: 0.8293


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 60, Loss: 0.8292


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 61, Loss: 0.8289


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 62, Loss: 0.8293


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 63, Loss: 0.8279


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 64, Loss: 0.8275


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 65, Loss: 0.8282


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 66, Loss: 0.8262


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 67, Loss: 0.8288


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 68, Loss: 0.8276


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 69, Loss: 0.8271


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 70, Loss: 0.8269


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 71, Loss: 0.8267


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 72, Loss: 0.8267


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 73, Loss: 0.8260


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 74, Loss: 0.8263


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 75, Loss: 0.8264


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 76, Loss: 0.8253


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 77, Loss: 0.8255


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 78, Loss: 0.8248


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 79, Loss: 0.8260


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 80, Loss: 0.8244


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 81, Loss: 0.8253


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 82, Loss: 0.8256


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 83, Loss: 0.8248


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 84, Loss: 0.8262


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 85, Loss: 0.8258


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 86, Loss: 0.8261


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 87, Loss: 0.8258


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 88, Loss: 0.8258


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 89, Loss: 0.8237


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 90, Loss: 0.8251


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 91, Loss: 0.8246


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 92, Loss: 0.8248


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 93, Loss: 0.8249


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 94, Loss: 0.8247


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 95, Loss: 0.8248


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 96, Loss: 0.8252


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 97, Loss: 0.8247


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 98, Loss: 0.8255


  0%|          | 0/22 [00:00<?, ?it/s]

Epoch: 99, Loss: 0.8244
