In [1]:
import torch
from torch_geometric.data import HeteroData

import pandas as pd
import numpy as np

import os
import json

import re
import ast

from utils.utils import *    # import custom functions from utils module for cleaning up name strings
from utils.kge import build_global_id_map, build_global_triples, build_hetero_graph

import warnings
warnings.filterwarnings('ignore')

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
pd.set_option('display.max_rows', None)
pd.set_option('display.max_columns', None)

In [3]:
DATA_PATH = './data'

# folder to save entity id mapping
ID_MAPPING = os.path.join(DATA_PATH, 'entity_id_map')
os.makedirs(ID_MAPPING, exist_ok=True)

# folder to save edge index
EDGE_INDEX = os.path.join(DATA_PATH, 'edge_index')
os.makedirs(EDGE_INDEX, exist_ok=True)

# folder to save graph data
GRAPH_DATA = os.path.join(DATA_PATH, 'graph_data')
os.makedirs(GRAPH_DATA, exist_ok=True)

# 1. Grab entities
Companies, symbols, mutual funds, institutions, C-level board

Relations:
* `company` (shortName) --- has symbol ---> `stock symbol`
* `stock symbol`--- listed on ---> `exchange`
* `company` --- belongs to ---> `industry`
* `industry`--- is part of ---> `sector`
* `company` --- employs C-level member ---> `person`
* `mutual fund`--- has symbol ---> `mutualfund symbol`
* `institution` --- holds ---> `stock symbol`
* `mutual fund` --- holds ---> `stock symbol`
* `stock symbol` --- co-mentioned in news with ---> `stock symbol`

In [4]:
stocks = pd.read_parquet(os.path.join(DATA_PATH, 'stocks_info.parquet'))

In [5]:
inst = pd.read_parquet(os.path.join(DATA_PATH, 'institutional_holders.parquet'))

In [6]:
funds = pd.read_parquet(os.path.join(DATA_PATH, 'mutualfund_holders.parquet'))

In [7]:
funds_symbol = pd.read_parquet(os.path.join(DATA_PATH, 'funds_symbol.parquet'))

In [8]:
funds_symbol.head()

Unnamed: 0,fund_name,symbols
0,vanguard total stock market index fund,VTI
1,vanguard total stock market index fund,VTI.MX
2,vanguard extended market index fund,VXF
3,vanguard extended market index fund,VEMPX
4,vanguard extended market index fund,VEXAX


## 1.1. Company, stock symbol, exchange, sector, industry

In [9]:
stocks.head(1)

Unnamed: 0,address1,city,state,zip,country,phone,fax,website,industry,industryKey,industryDisp,sector,sectorKey,sectorDisp,longBusinessSummary,fullTimeEmployees,companyOfficers,compensationAsOfEpochDate,executiveTeam,maxAge,priceHint,previousClose,open,dayLow,dayHigh,regularMarketPreviousClose,regularMarketOpen,regularMarketDayLow,regularMarketDayHigh,exDividendDate,payoutRatio,beta,forwardPE,volume,regularMarketVolume,averageVolume,averageVolume10days,averageDailyVolume10Day,bid,ask,bidSize,askSize,marketCap,fiftyTwoWeekLow,fiftyTwoWeekHigh,allTimeHigh,allTimeLow,priceToSalesTrailing12Months,fiftyDayAverage,twoHundredDayAverage,trailingAnnualDividendRate,trailingAnnualDividendYield,currency,tradeable,enterpriseValue,profitMargins,floatShares,sharesOutstanding,sharesShort,sharesShortPriorMonth,sharesShortPreviousMonthDate,dateShortInterest,sharesPercentSharesOut,heldPercentInsiders,heldPercentInstitutions,shortRatio,shortPercentOfFloat,impliedSharesOutstanding,bookValue,lastFiscalYearEnd,nextFiscalYearEnd,mostRecentQuarter,netIncomeToCommon,trailingEps,forwardEps,lastSplitFactor,lastSplitDate,enterpriseToRevenue,enterpriseToEbitda,52WeekChange,SandP52WeekChange,lastDividendValue,lastDividendDate,quoteType,currentPrice,targetHighPrice,targetLowPrice,targetMeanPrice,targetMedianPrice,recommendationMean,recommendationKey,numberOfAnalystOpinions,totalCash,totalCashPerShare,ebitda,totalDebt,quickRatio,currentRatio,totalRevenue,revenuePerShare,returnOnAssets,grossProfits,freeCashflow,operatingCashflow,revenueGrowth,grossMargins,ebitdaMargins,operatingMargins,financialCurrency,symbol,language,region,typeDisp,quoteSourceName,triggerable,customPriceAlertConfidence,marketState,corporateActions,preMarketTime,regularMarketTime,exchange,messageBoardId,exchangeTimezoneName,exchangeTimezoneShortName,gmtOffSetMilliseconds,market,esgPopulated,regularMarketChangePercent,regularMarketPrice,hasPrePostMarketData,firstTradeDateMilliseconds,preMarketChange,preMarketChangePercent,preMarketPrice,regularMarketChange,regularMarketDayRange,fullExchangeName,averageDailyVolume3Month,fiftyTwoWeekLowChange,fiftyTwoWeekLowChangePercent,fiftyTwoWeekRange,fiftyTwoWeekHighChange,fiftyTwoWeekHighChangePercent,fiftyTwoWeekChangePercent,dividendDate,earningsTimestampStart,earningsTimestampEnd,earningsCallTimestampStart,earningsCallTimestampEnd,isEarningsDateEstimate,epsTrailingTwelveMonths,epsForward,fiftyDayAverageChange,fiftyDayAverageChangePercent,twoHundredDayAverageChange,twoHundredDayAverageChangePercent,priceToBook,sourceInterval,exchangeDataDelayedBy,averageAnalystRating,cryptoTradeable,shortName,longName,displayName,trailingPegRatio,address2,auditRisk,boardRisk,compensationRisk,shareHolderRightsRisk,overallRisk,governanceEpochDate,debtToEquity,returnOnEquity,earningsTimestamp,epsCurrentYear,priceEpsCurrentYear,ipoExpectedDate,dividendRate,dividendYield,fiveYearAvgDividendYield,trailingPE,earningsQuarterlyGrowth,earningsGrowth,prevName,nameChangeDate,irWebsite,openInterest,pegRatio,newListingDate,prevTicker,tickerChangeDate,prevExchange,exchangeTransferDate,industrySymbol
0,9655 Maroon Circle,Englewood,CO,80112,United States,303 703 4906,800 495 6695,https://www.zynex.com,Medical Distribution,medical-distribution,Medical Distribution,Healthcare,healthcare,Healthcare,"Zynex, Inc., together with its subsidiaries, d...",1000.0,"[{'age': 65.0, 'exercisedValue': 0, 'fiscalYea...",1735603000.0,[],86400,4,0.7374,0.721,0.695,2.1,0.7374,0.721,0.695,2.1,1641341000.0,0.0,1.01,3.3125,136585146.0,136585146.0,4407053.0,27279780.0,27279780.0,1.29,2.06,2.0,2.0,48317932.0,0.38,8.72,27.027273,0.054545,0.446553,1.2252,2.4123,0.0,0.0,USD,False,106676928.0,-0.68352,15791758.0,30388635.0,3552328.0,3438928.0,1760486000.0,1763078000.0,0.1169,0.48163,0.13781,18.37,0.2247,30388635.0,-1.34,1735603000.0,1767139000.0,1759190000.0,-73958000.0,-2.42,0.48,11:10,1641341000.0,0.986,-3.295,-0.808894,0.12934,0.1,1641341000.0,EQUITY,1.59,3.7,3.7,3.7,3.7,1.0,strong_buy,1.0,13259000.0,0.436,-32378000.0,71618000.0,0.253,0.469,108202000.0,3.489,-0.27107,77912000.0,-2242375.0,-20566000.0,-0.733,0.72006,-0.29924,-0.98967,USD,ZYXI,en-US,US,Equity,Nasdaq Real Time Price,True,HIGH,PRE,[],1764337000.0,1764190801,NMS,finmb_3103657,America/New_York,EST,-18000000,us_market,False,115.622,1.59,True,1077719000000.0,-0.036,-2.264152,1.554,0.8526,0.695 - 2.1,NasdaqGS,4407053.0,1.21,3.184211,0.38 - 8.72,-7.13,-0.817661,-80.88942,1642723000.0,1763068000.0,1763068000.0,1763474000.0,1763474000.0,False,-2.42,0.48,0.3648,0.297747,-0.8223,-0.340878,-1.186567,15,0,1.0 - Strong Buy,False,"Zynex, Inc.","Zynex, Inc.",Zynex,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,,


In [10]:
# grab entities from the data
companies = stocks['shortName'].dropna().unique()
stock_symbols = stocks['symbol'].dropna().unique()
exchanges = stocks['exchange'].dropna().unique()
industries = stocks['industryKey'].dropna().unique()
sectors = stocks['sectorKey'].dropna().unique()

In [11]:
# build id maps
company2id = {name: i for i, name in enumerate(companies)}
stocksymbol2id = {name: i for i, name in enumerate(stock_symbols)}
exchange2id = {name: i for i, name in enumerate(exchanges)}
industry2id = {name: i for i, name in enumerate(industries)}
sector2id = {name: i for i, name in enumerate(sectors)}

In [12]:
# save to json
with open(os.path.join(ID_MAPPING, 'company2id.json'), 'w') as f:
    json.dump(company2id, f, indent=2)

with open(os.path.join(ID_MAPPING, 'stocksymbol2id.json'), 'w') as f:
    json.dump(stocksymbol2id, f, indent=2)

with open(os.path.join(ID_MAPPING, 'exchange2id.json'), 'w') as f:
    json.dump(exchange2id, f, indent=2)

with open(os.path.join(ID_MAPPING, 'industry2id.json'), 'w') as f:
    json.dump(industry2id, f, indent=2)

with open(os.path.join(ID_MAPPING, 'sector2id.json'), 'w') as f:
    json.dump(sector2id, f, indent=2)

## 1.2. Company and employed officers

In [13]:
# get list of officers
tmp = stocks[['shortName', 'companyOfficers']].copy()

tmp['officerNames'] = tmp['companyOfficers'].apply(extract_officer_names)


In [14]:
tmp.head(1).values

array([['Zynex, Inc.',
        array([{'age': 65.0, 'exercisedValue': 0, 'fiscalYear': 2024.0, 'maxAge': 1, 'name': 'Mr. Thomas  Sandgaard', 'title': 'Founder, President & Chairman', 'totalPay': 879352.0, 'unexercisedValue': 22462, 'yearBorn': 1959.0},
               {'age': 51.0, 'exercisedValue': 0, 'fiscalYear': 2024.0, 'maxAge': 1, 'name': 'Dr. Steven Lewis Dyson Ph.D.', 'title': 'CEO & Director', 'totalPay': None, 'unexercisedValue': 0, 'yearBorn': 1973.0},
               {'age': None, 'exercisedValue': 0, 'fiscalYear': 2024.0, 'maxAge': 1, 'name': 'Mr. Vikram  Bajaj', 'title': 'Chief Financial Officer', 'totalPay': None, 'unexercisedValue': 0, 'yearBorn': None},
               {'age': 51.0, 'exercisedValue': 0, 'fiscalYear': 2024.0, 'maxAge': 1, 'name': 'Mr. John T. Bibb', 'title': 'Chief Legal Officer', 'totalPay': None, 'unexercisedValue': 0, 'yearBorn': 1973.0},
               {'age': None, 'exercisedValue': 0, 'fiscalYear': 2024.0, 'maxAge': 1, 'name': 'Mr. Ajay  Gopal', 'tit

In [15]:
# get list of officers
officers = tmp['officerNames'].explode().dropna().unique()
# build id map
officer2id = {name: i for i, name in enumerate(officers)}
# save id map
with open(os.path.join(ID_MAPPING, 'officer2id.json'), 'w') as f:
    json.dump(officer2id, f, indent=2)

## 1.3. Institutions

In [16]:
# get list of institution
inst['holderNames'] = inst['holders'].apply(extract_institution_names)

institutions = inst['holderNames'].explode().dropna().unique()

# build id map
institution2id = {name: i for i, name in enumerate(institutions)}
with open(os.path.join(ID_MAPPING, 'institution2id.json'), 'w') as f:
    json.dump(institution2id, f, indent=2)

# institution2id

## 1.4. Mutual funds & mutual funds symbol

In [17]:
# funds.iloc[0].values

# get list of mutual funds
funds['fundNames'] = funds['holders'].apply(extract_mutualfund_names)

mutualfunds = funds['fundNames'].explode().dropna().unique()

# build id map
mutualfund2id = {name: i for i, name in enumerate(mutualfunds)}
with open(os.path.join(ID_MAPPING, 'mutualfund2id.json'), 'w') as f:
    json.dump(mutualfund2id, f, indent=2)

In [18]:
# get list of fund symbols
symbols = funds_symbol['symbols'].dropna().unique()

# build id map
fundsymbol2id = {name: i for i, name in enumerate(symbols)}
with open(os.path.join(ID_MAPPING, 'fundsymbol2id.json'), 'w') as f:
    json.dump(fundsymbol2id, f, indent=2)

# 2. Build edge lists for each relation

## 2.1. `company` --- has symbol ---> `stock symbol`

In [19]:
comp2sym_src = []
comp2sym_dst = []

# load id mappings
with open(os.path.join(ID_MAPPING, 'company2id.json')) as f:
    company2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'stocksymbol2id.json')) as f:
    stocksymbol2id = json.load(f)

edges = set()
for row in stocks.itertuples():
    if pd.isna(row.shortName) or pd.isna(row.symbol):   # skip rows where either company shortName or symbol is missing
        continue
    h = company2id[row.shortName]
    t = stocksymbol2id[row.symbol]
    edges.add((h, t))

comp2sym_src, comp2sym_dst = zip(*edges)

In [20]:
# sanity check - src and dst should have same length
assert len(comp2sym_src) == len(comp2sym_dst)

In [21]:
# save edge index as torch tensor
edge_index = torch.tensor([comp2sym_src, comp2sym_dst], dtype=torch.long)
torch.save(edge_index, os.path.join(EDGE_INDEX, "comp2sym.pt"))

## 2.2 `stock symbol` --- listed on ---> `exchange`

In [22]:
sym2ex_src = []
sym2ex_dst = []

# load id mappings
with open(os.path.join(ID_MAPPING, 'exchange2id.json')) as f:
    exchange2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'stocksymbol2id.json')) as f:
    stocksymbol2id = json.load(f)

edges = set()
for row in stocks.itertuples():
    if pd.isna(row.exchange) or pd.isna(row.symbol):   # skip rows where either company shortName or symbol is missing
        continue
    h = stocksymbol2id[row.symbol]
    t = exchange2id[row.exchange]
    edges.add((h, t))

sym2ex_src, sym2ex_dst = zip(*edges)

In [23]:
assert len(sym2ex_src) == len(sym2ex_dst)

In [24]:
# save edge index as torch tensor
edge_index = torch.tensor([sym2ex_src, sym2ex_dst], dtype=torch.long)
torch.save(edge_index, os.path.join(EDGE_INDEX, "sym2ex.pt"))

## 2.3. `company` --- belongs to ---> `industry`

In [25]:
comp2ind_src = []
comp2ind_dst = []

# load id mappings
with open(os.path.join(ID_MAPPING, 'company2id.json')) as f:
    company2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'industry2id.json')) as f:
    industry2id = json.load(f)

edges = set()
for row in stocks.itertuples():
    if pd.isna(row.shortName) or pd.isna(row.industryKey):   # skip rows where either company shortName or industry key is missing
        continue
    # comp2ind_src.append(company2id[row.shortName])
    # comp2ind_dst.append(industry2id[row.industryKey])
    h = company2id[row.shortName]
    t = industry2id[row.industryKey]
    edges.add((h, t))

comp2ind_src, comp2ind_dst = zip(*edges)

In [26]:
assert len(comp2ind_src) == len(comp2ind_dst)

In [27]:
# save edge index as torch tensor
edge_index = torch.tensor([comp2ind_src, comp2ind_dst], dtype=torch.long)
torch.save(edge_index, os.path.join(EDGE_INDEX, "comp2ind.pt"))

## 2.4. `industry` --- is part of ---> `sector`

In [28]:
ind2sec_src = []
ind2sec_dst = []

# load id mappings
with open(os.path.join(ID_MAPPING, 'industry2id.json')) as f:
    industry2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'sector2id.json')) as f:
    sector2id = json.load(f)

edges = set()
for row in stocks.itertuples():
    if pd.isna(row.industryKey) or pd.isna(row.sectorKey):   # skip rows where either sectorKey or industry key is missing
        continue
    h = industry2id[row.industryKey]
    t = sector2id[row.sectorKey]
    edges.add((h, t))

ind2sec_src, ind2sec_dst = zip(*edges)

In [29]:
assert len(ind2sec_src) == len(ind2sec_dst)

In [30]:
# save edge index as torch tensor
edge_index = torch.tensor([ind2sec_src, ind2sec_dst], dtype=torch.long)
torch.save(edge_index, os.path.join(EDGE_INDEX, "ind2sec.pt"))

## 2.5. `company` --- employs C-level officer ---> `officer`

In [31]:
comp2off_src = []
comp2off_dst = []

# load id mappings
with open(os.path.join(ID_MAPPING, 'company2id.json')) as f:
    company2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'officer2id.json')) as f:
    officer2id = json.load(f)

# get list of officers and corresponding company names
tmp = stocks[['shortName', 'companyOfficers']].copy()
tmp['officerNames'] = tmp['companyOfficers'].apply(extract_officer_names)
tmp = tmp.explode('officerNames')
tmp.head()

edges = set()
for row in tmp.itertuples():
    if pd.isna(row.shortName) or pd.isna(row.officerNames):   # skip rows where either sectorKey or industry key is missing
        continue
    h = company2id[row.shortName]
    t = officer2id[row.officerNames]
    edges.add((h, t))

comp2off_src, comp2off_dst = zip(*edges)

In [32]:
assert len(comp2off_src) == len(comp2off_dst)

In [33]:
# save edge index as torch tensor
edge_index = torch.tensor([comp2off_src, comp2off_dst], dtype=torch.long)
torch.save(edge_index, os.path.join(EDGE_INDEX, "comp2off.pt"))

## 2.6. `institution` --- holds ---> `stock symbol`

In [34]:
inst2sym_src = []
inst2sym_dst = []

# load id mappings
with open(os.path.join(ID_MAPPING, 'stocksymbol2id.json')) as f:
    stocksymbol2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'institution2id.json')) as f:
    institution2id = json.load(f)

# get list of officers and corresponding company names
tmp = inst[['symbol', 'holders']].copy()
tmp['holderNames'] = tmp['holders'].apply(extract_institution_names)
tmp = tmp.explode('holderNames')
tmp.head()

edges = set()
for row in tmp.itertuples():
    if pd.isna(row.symbol) or pd.isna(row.holderNames):   # skip rows where either sectorKey or industry key is missing
        continue
    h = institution2id[row.holderNames]
    t = stocksymbol2id[row.symbol]
    edges.add((h, t))

inst2sym_src, inst2sym_dst = zip(*edges)

In [35]:
assert len(inst2sym_src) == len(inst2sym_dst)

In [36]:
# save edge index as torch tensor
edge_index = torch.tensor([inst2sym_src, inst2sym_dst], dtype=torch.long)
torch.save(edge_index, os.path.join(EDGE_INDEX, "inst2sym.pt"))

## 2.7. `stock symbol` --- co-mentioned in news with ---> `stock symbol` (to be completed later)

In [37]:
# comention = pd.read_parquet(os.path.join(DATA_PATH, 'stocks_related_tickers.parquet'))

# comention.head()

## 2.8. `mutual fund` --- holds ---> `stock symbol`

In [38]:
fund2sym_src = []
fund2sym_dst = []

# load id mappings
with open(os.path.join(ID_MAPPING, 'stocksymbol2id.json')) as f:
    stocksymbol2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'mutualfund2id.json')) as f:
    mutualfund2id = json.load(f)

# get list of officers and corresponding company names
tmp = funds[['symbol', 'holders']].copy()
tmp['holderNames'] = tmp['holders'].apply(extract_mutualfund_names)
tmp = tmp.explode('holderNames')
tmp.head()

edges = set()
for row in tmp.itertuples():
    if pd.isna(row.symbol) or pd.isna(row.holderNames):   # skip rows where either sectorKey or industry key is missing
        continue
    h = mutualfund2id[row.holderNames]
    t = stocksymbol2id[row.symbol]
    edges.add((h, t))

fund2sym_src, fund2sym_dst = zip(*edges)

In [39]:
assert len(fund2sym_src) == len(fund2sym_dst)

In [40]:
# save edge index as torch tensor
edge_index = torch.tensor([fund2sym_src, fund2sym_dst], dtype=torch.long)
torch.save(edge_index, os.path.join(EDGE_INDEX, "fund2stocksym.pt"))

## 2.9. `mutual fund` --- has symbol ---> `fund symbol`

In [41]:
fund2sym_src = []
fund2sym_dst = []

# load id mappings
with open(os.path.join(ID_MAPPING, 'mutualfund2id.json')) as f:
    mutualfund2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'fundsymbol2id.json')) as f:
    fundsymbol2id = json.load(f)

edges = set()
for row in funds_symbol.itertuples():
    if pd.isna(row.fund_name) or pd.isna(row.symbols):   # skip rows where either company shortName or symbol is missing
        continue
    h = mutualfund2id[row.fund_name]
    t = fundsymbol2id[row.symbols]
    edges.add((h, t))

fund2sym_src, fund2sym_dst = zip(*edges)

In [42]:
# sanity check - src and dst should have same length
assert len(fund2sym_src) == len(fund2sym_dst)

In [43]:
# save edge index as torch tensor
edge_index = torch.tensor([fund2sym_src, fund2sym_dst], dtype=torch.long)
torch.save(edge_index, os.path.join(EDGE_INDEX, "fund2fundsym.pt"))

# 3. Build knowledge graph
Steps:
1. Convert local IDs of each entity type to a global ID
2. Convert every typed edge_index (src, dst) to triples (h, r, t)

## 3.1. Step 1: Build global ID map for entities

In [44]:
# grab entity id mappings
with open(os.path.join(ID_MAPPING, 'company2id.json')) as f:
    company2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'stocksymbol2id.json')) as f:
    stocksymbol2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'industry2id.json')) as f:
    industry2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'sector2id.json')) as f:
    sector2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'exchange2id.json')) as f:
    exchange2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'officer2id.json')) as f:
    officer2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'institution2id.json')) as f:
    institution2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'mutualfund2id.json')) as f:
    mutualfund2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'fundsymbol2id.json')) as f:
    fundsymbol2id = json.load(f)

In [45]:
entity_id_maps = [
    company2id, stocksymbol2id, industry2id, sector2id, 
    exchange2id, officer2id, institution2id, mutualfund2id, fundsymbol2id
]

entity_types = [
    'company', 'stock_symbol', 'industry', 'sector',
    'exchange', 'officer', 'institution', 'mutualfund', 'fund_symbol'
]

global_map, type_map, offsets = build_global_id_map(
    entity_id_maps=entity_id_maps, entity_types=entity_types
)

In [46]:
# save mappings
with open(os.path.join(ID_MAPPING, 'global_id.json'), 'w') as f:
    json.dump(global_map, f, indent=2)

with open(os.path.join(ID_MAPPING, 'global_type_map.json'), 'w') as f:
    json.dump(type_map, f, indent=2)

with open(os.path.join(ID_MAPPING, 'global_entity_id_offsets.json'), 'w') as f:
    json.dump(offsets, f, indent=2)

## 3.2. Build global edge index

In [47]:
# load local edge indices (edge indices of each relation type using local entity id mappings)
comp2ind = torch.load(os.path.join(EDGE_INDEX, 'comp2ind.pt'))   # company belongs to industry
comp2off = torch.load(os.path.join(EDGE_INDEX, 'comp2off.pt'))   # company employs officer
comp2sym = torch.load(os.path.join(EDGE_INDEX, 'comp2sym.pt'))   # company has stock symbol
ind2sec = torch.load(os.path.join(EDGE_INDEX, 'ind2sec.pt'))     # industry belongs to sector
inst2sym = torch.load(os.path.join(EDGE_INDEX, 'inst2sym.pt'))   # institution holds stock symbol
sym2ex = torch.load(os.path.join(EDGE_INDEX, 'sym2ex.pt'))       # stock symbol listed on exchange
fund2stocksym = torch.load(os.path.join(EDGE_INDEX, 'fund2stocksym.pt'))   # mutual fund holds stock symbol
fund2fundsym = torch.load(os.path.join(EDGE_INDEX, 'fund2fundsym.pt'))   # mutual fund has fund symbol

In [48]:
# create edge indices list for input 
edge_indices_list = [
    {
        "relation": "has_symbol",
        "head_type": "company",
        "tail_type": "stock_symbol",
        "edge_index": comp2sym,  # tensor [2, num_edges]
    },
    {
        "relation": "is_listed_on",
        "head_type": "stock_symbol",
        "tail_type": "exchange",
        "edge_index": sym2ex,
    },
    {
        "relation": "belongs_to",
        "head_type": "company",
        "tail_type": "industry",
        "edge_index": comp2ind,
    },
    {
        "relation": "is_part_of",
        "head_type": "industry",
        "tail_type": "sector",
        "edge_index": ind2sec,
    },
    {
        "relation": "employs",
        "head_type": "company",
        "tail_type": "officer",
        "edge_index": comp2off,
    },
    {
        "relation": "holds",
        "head_type": "institution",
        "tail_type": "stock_symbol",
        "edge_index": inst2sym,
    },
    {
        "relation": "holds",
        "head_type": "mutualfund",
        "tail_type": "stock_symbol",
        "edge_index": fund2stocksym,
    },
    {
        "relation": "holds",
        "head_type": "mutualfund",
        "tail_type": "fund_symbol",
        "edge_index": fund2fundsym,
    },
]

# grab global entity id mappings created earlier
with open(os.path.join(ID_MAPPING, 'global_id.json')) as f:
    global_id = json.load(f)

In [49]:
# create global triples
global_triples = build_global_triples(edge_indices_list, global_id)

# store
with open(os.path.join(EDGE_INDEX, 'global_triples.json'), 'w') as f:
    json.dump(global_triples, f, indent=2)

Building global triples for (company has_symbol stock_symbol)
Building global triples for (stock_symbol is_listed_on exchange)
Building global triples for (company belongs_to industry)
Building global triples for (industry is_part_of sector)
Building global triples for (company employs officer)
Building global triples for (institution holds stock_symbol)
Building global triples for (mutualfund holds stock_symbol)
Building global triples for (mutualfund holds fund_symbol)


## 3.3. Create HeteroData object

In [50]:
# load global triples and global entity type maps
with open(os.path.join(EDGE_INDEX, 'global_triples.json')) as f:
    global_triples = json.load(f)

with open(os.path.join(ID_MAPPING, 'global_type_map.json')) as f:
    type_map = json.load(f)

In [51]:
# create hetero data object
hetero_data = build_hetero_graph(global_triples, type_map)

In [52]:
hetero_data

HeteroData(
  (company, has_symbol, stock_symbol)={ edge_index=[2, 4508] },
  (stock_symbol, is_listed_on, exchange)={ edge_index=[2, 4607] },
  (company, belongs_to, industry)={ edge_index=[2, 4035] },
  (industry, is_part_of, sector)={ edge_index=[2, 145] },
  (company, employs, officer)={ edge_index=[2, 33334] },
  (institution, holds, stock_symbol)={ edge_index=[2, 39512] },
  (mutualfund, holds, stock_symbol)={ edge_index=[2, 37098] },
  (mutualfund, holds, fund_symbol)={ edge_index=[2, 6851] }
)

In [53]:
# save data object
torch.save(hetero_data, os.path.join(GRAPH_DATA, 'yfinance_kge.pt'))

In [54]:
# load it
graph_ = torch.load(os.path.join(GRAPH_DATA, 'yfinance_kge.pt'))

graph_

HeteroData(
  (company, has_symbol, stock_symbol)={ edge_index=[2, 4508] },
  (stock_symbol, is_listed_on, exchange)={ edge_index=[2, 4607] },
  (company, belongs_to, industry)={ edge_index=[2, 4035] },
  (industry, is_part_of, sector)={ edge_index=[2, 145] },
  (company, employs, officer)={ edge_index=[2, 33334] },
  (institution, holds, stock_symbol)={ edge_index=[2, 39512] },
  (mutualfund, holds, stock_symbol)={ edge_index=[2, 37098] },
  (mutualfund, holds, fund_symbol)={ edge_index=[2, 6851] }
)