In [1]:
import torch
from torch_geometric.data import HeteroData
from torch_geometric.loader import NeighborLoader

import pandas as pd
import numpy as np

import os
import json

from utils.hake_dataset import *

from hake.models import HAKE
from hake.data import DataReader, TrainDataset, BatchType, BidirectionalOneShotIterator
from torch.utils.data import DataLoader

import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Root folders: general project data and everything HAKE-related.
DATA_PATH = './data'
HAKE_PATH = './hake'

# Sub-folders of DATA_PATH.
ID_MAPPING = os.path.join(DATA_PATH, 'entity_id_map')  # entity id mapping
EDGE_INDEX = os.path.join(DATA_PATH, 'edge_index')     # edge index
GRAPH_DATA = os.path.join(DATA_PATH, 'graph_data')     # graph data

# Sub-folders of HAKE_PATH: HAKE input data and the embeddings written later.
HAKE_DATA = os.path.join(HAKE_PATH, 'data/yfinance_kge')
HAKE_EMBEDDINGS = os.path.join(HAKE_PATH, 'embeddings/yfinance_kge')

# Create every folder up front; exist_ok=True makes re-running this cell harmless.
for _folder in (ID_MAPPING, EDGE_INDEX, GRAPH_DATA, HAKE_PATH, HAKE_DATA, HAKE_EMBEDDINGS):
    os.makedirs(_folder, exist_ok=True)

# 1. Create the input data for the HAKE model

In [3]:
# # prepare triples input
# hake_triples = make_hake_triples(id_mapping_dir=ID_MAPPING, edge_index_dir=EDGE_INDEX,)

In [4]:
# # convert the triples of the dataset into a HAKE compatible dataset
# make_hake_dataset(
#     triples=hake_triples,
#     # entity2gid=entity2gid,
#     out_dir = HAKE_DATA,
# )

# 2. Use HAKE to generate node embeddings

In [5]:
# !touch hake/data/yfinance_kge/valid.txt
# !touch hake/data/yfinance_kge/test.txt

In [6]:
# Read the dataset stored under HAKE_DATA.
data_reader = DataReader(HAKE_DATA)

# Vocabulary sizes: the lengths of the reader's `entity_dict` and `relation_dict`.
num_entity, num_relation = len(data_reader.entity_dict), len(data_reader.relation_dict)

print(f"Num entities: {num_entity}", f"Num relations: {num_relation}", sep="\n")

Num entities: 20050
Num relations: 4


In [13]:
# Reproducibility: seed the global RNGs of the libraries this notebook imports,
# so a fresh-kernel re-run starts from the same random state.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Model settings (all but MODEL are forwarded to the HAKE constructor).
MODEL = 'HAKE'  # label only; not referenced by the cells shown in this notebook
HIDDEN_DIM = 200
GAMMA = 12.0
MODULUS_WEIGHT = 1.0
PHASE_WEIGHT = 0.5

# Data-loading settings (used when the training DataLoaders are built).
BATCH_SIZE = 512
NEGATIVE_SAMPLE_SIZE = 512
CPU_NUM = 4  # passed as DataLoader num_workers

# Optimisation settings.
LEARNING_RATE = 0.0001
MAX_STEPS = 10000
ADVERSARIAL_TEMPERATURE = 1.0

In [14]:
# -----------------------------
# Model
# -----------------------------
# Prefer the GPU when CUDA is usable, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the HAKE model from the vocabulary sizes and hyperparameter constants,
# then move it onto the selected device.
model = HAKE(
    num_entity,
    num_relation,
    HIDDEN_DIM,
    GAMMA,
    MODULUS_WEIGHT,
    PHASE_WEIGHT,
).to(device)

In [15]:
# -----------------------------
# DataLoader & iterator
# -----------------------------
def _make_train_dataloader(batch_type):
    """Build a shuffled training DataLoader for one batch type.

    Args:
        batch_type: The ``BatchType`` member (``HEAD_BATCH`` or ``TAIL_BATCH``)
            forwarded to ``TrainDataset``.

    Returns:
        A ``DataLoader`` over ``TrainDataset(data_reader, NEGATIVE_SAMPLE_SIZE,
        batch_type)`` configured with the notebook-level ``BATCH_SIZE`` and
        ``CPU_NUM`` and with ``TrainDataset.collate_fn`` as the collate function.
    """
    return DataLoader(
        TrainDataset(data_reader, NEGATIVE_SAMPLE_SIZE, batch_type),
        batch_size=BATCH_SIZE,
        shuffle=True,
        num_workers=CPU_NUM,
        collate_fn=TrainDataset.collate_fn,
    )


# The two loaders differ only in the batch type handed to TrainDataset.
train_dataloader_head = _make_train_dataloader(BatchType.HEAD_BATCH)
train_dataloader_tail = _make_train_dataloader(BatchType.TAIL_BATCH)

# Single iterator built from both loaders; the training loop pulls batches from it.
train_iterator = BidirectionalOneShotIterator(train_dataloader_head, train_dataloader_tail)

In [16]:
# -----------------------------
# Optimizer
# -----------------------------
# Adam over all of the model's parameters, with LEARNING_RATE as the step size.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=LEARNING_RATE,
)

In [17]:
# -----------------------------
# Training loop
# -----------------------------
# The only setting handed to `train_step` through `args` is
# `adversarial_temperature`. Build that stand-in once, rather than creating a
# brand-new class on every iteration of the loop.
train_args = type('args', (), {'adversarial_temperature': ADVERSARIAL_TEMPERATURE})

LOG_EVERY = 200  # print the loss once every LOG_EVERY steps

for step in range(MAX_STEPS):
    # One optimisation step; `log` is indexed by 'loss' below.
    log = model.train_step(model, optimizer, train_iterator, args=train_args)

    if step % LOG_EVERY == 0:
        print(f"Step {step}: loss={log['loss']:.4f}")

Step 0: loss=3.3810
Step 200: loss=2.0333
Step 400: loss=0.9507
Step 600: loss=0.6238
Step 800: loss=0.5006
Step 1000: loss=0.4516
Step 1200: loss=0.3859
Step 1400: loss=0.3614
Step 1600: loss=0.3239
Step 1800: loss=0.2811
Step 2000: loss=0.2830
Step 2200: loss=0.2532
Step 2400: loss=0.2454
Step 2600: loss=0.2137
Step 2800: loss=0.2258
Step 3000: loss=0.2047
Step 3200: loss=0.1963
Step 3400: loss=0.1795
Step 3600: loss=0.1703
Step 3800: loss=0.1788
Step 4000: loss=0.1629
Step 4200: loss=0.1619
Step 4400: loss=0.1575
Step 4600: loss=0.1478
Step 4800: loss=0.1460
Step 5000: loss=0.1449
Step 5200: loss=0.1418
Step 5400: loss=0.1276
Step 5600: loss=0.1330
Step 5800: loss=0.1291
Step 6000: loss=0.1346
Step 6200: loss=0.1254
Step 6400: loss=0.1398
Step 6600: loss=0.1255
Step 6800: loss=0.1236
Step 7000: loss=0.1169
Step 7200: loss=0.1144
Step 7400: loss=0.1112
Step 7600: loss=0.1056
Step 7800: loss=0.1087
Step 8000: loss=0.1051
Step 8200: loss=0.1066
Step 8400: loss=0.1227
Step 8600: loss=0.

In [18]:
# -----------------------------
# Save embeddings
# -----------------------------
# Target file name -> embedding to persist inside HAKE_EMBEDDINGS.
embedding_files = {
    'entity_embedding.pt': model.entity_embedding,
    'relation_embedding.pt': model.relation_embedding,
}

for filename, embedding in embedding_files.items():
    # Detach from autograd and copy to the CPU before writing to disk.
    torch.save(embedding.detach().cpu(), os.path.join(HAKE_EMBEDDINGS, filename))

print("Embeddings saved!")

Embeddings saved!
