In [7]:
import dgl
import torch
import torch.nn as nn
from openhgnn.dataset.NodeClassificationDataset import OHGB_NodeClassification
from dgl.nn import MetaPath2Vec
from torch.optim import SparseAdam
from torch.utils.data import DataLoader
from tqdm import tqdm

acm = OHGB_NodeClassification(
    dataset_name="ohgbn-acm", raw_dir="./dataset", logger=None
)

  from .autonotebook import tqdm as notebook_tqdm


Extracting file to ./openhgnn/dataset\ohgbn-acm
Done saving data into cached files.


In [8]:
hg = acm.g
meta_paths_dict = acm.meta_paths_dict

In [9]:
import openhgnn

## test Mp2Vec

In [11]:
m2v_negative_size = 5
m2v_emb_dim = 128
m2v_window_size = 3
m2v_lr = 0.001

m2v_batch_size = 256
m2v_epoch = 20

# m2v_rw_walk_length= 10
# m2v_rw_walks_per_node= 3
device = "cuda"


def train_mp2vec(
    hg,
    category,
    metapaths_dict,
    mp2vec_feat_dim,
    mp2vec_window_size,
    mp2vec_negative_size,
    mp2vec_train_lr,
    mp2vec_train_epoch,
    mp2vec_batch_size,
):
    hg = hg.to(device)
    num_nodes = hg.num_nodes(category)
    embs = torch.zeros(num_nodes,mp2vec_feat_dim).to(device)
    
    # for each metapath
    for mp_name, mp in metapaths_dict.items():
        print("Metapath:", mp_name)
        m2v_model = MetaPath2Vec(
            hg, mp, mp2vec_window_size, mp2vec_feat_dim, mp2vec_negative_size
        ).to(device)
        m2v_model.train()
        dataloader = DataLoader(
            torch.arange(num_nodes),
            batch_size=mp2vec_batch_size,
            shuffle=True,
            collate_fn=m2v_model.sample,
        )
        optimizer = SparseAdam(m2v_model.parameters(), lr=mp2vec_train_lr)
        for _ in tqdm(range(mp2vec_train_epoch)):
            for pos_u, pos_v, neg_v in dataloader:
                loss = m2v_model(pos_u.to(device), pos_v.to(device), neg_v.to(device))
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        # get the embeddings
        nids = torch.LongTensor(m2v_model.local_to_global_nid[category]).to(device)
        emb = m2v_model.node_embed(nids)
        # embs.append(emb)
        embs+=emb

        del m2v_model, nids, pos_u, pos_v, neg_v
        if device == "cuda":
            torch.cuda.empty_cache()

    # concat these emb of each metapath
    # return torch.concat(embs, dim=1).detach(),
    return embs/len(metapaths_dict)

In [12]:
mp2vec_feat = train_mp2vec(
    hg,
    "paper",
    metapaths_dict=meta_paths_dict,
    mp2vec_feat_dim=m2v_emb_dim,
    mp2vec_window_size=m2v_window_size,
    mp2vec_negative_size=m2v_negative_size,
    mp2vec_train_lr=m2v_lr,
    mp2vec_train_epoch=m2v_epoch,
    mp2vec_batch_size=m2v_batch_size,
)

Metapath: PAP


100%|██████████| 3025/3025 [00:02<00:00, 1186.36it/s]
100%|██████████| 20/20 [00:01<00:00, 10.19it/s]

Metapath: PSP



100%|██████████| 3025/3025 [00:02<00:00, 1224.54it/s]
100%|██████████| 20/20 [00:01<00:00, 10.26it/s]


In [18]:
mp2vec_feat.shape

torch.Size([3025, 128])

In [20]:
mp2vec_feat

tensor([[ 0.0611, -0.0082,  0.0148,  ..., -0.0573, -0.0157, -0.0115],
        [ 0.0560, -0.0079,  0.0225,  ..., -0.0509, -0.0034, -0.0080],
        [ 0.0632, -0.0057,  0.0166,  ..., -0.0595, -0.0013, -0.0015],
        ...,
        [ 0.0666, -0.0157,  0.0204,  ..., -0.0679, -0.0094, -0.0078],
        [ 0.0459, -0.0072,  0.0072,  ..., -0.0501, -0.0030, -0.0004],
        [ 0.0994,  0.0335, -0.0345,  ..., -0.1149, -0.0066,  0.0330]],
       device='cuda:0', grad_fn=<DivBackward0>)