In [1]:
import torch
from torch_geometric.data import HeteroData

import pandas as pd
import numpy as np

import os
import json

In [8]:
# Lift pandas' display limits so frames are printed without truncating rows or columns.
pd.options.display.max_rows = None
pd.options.display.max_columns = None

In [44]:
# Root folder for all notebook inputs and outputs.
DATA_PATH = './data'

# Output folders: one for the entity id mappings, one for the edge index tensors.
ID_MAPPING = os.path.join(DATA_PATH, 'entity_id_map')
EDGE_INDEX = os.path.join(DATA_PATH, 'edge_index')

# Create both folders up front; exist_ok makes re-running this cell harmless.
for output_dir in (ID_MAPPING, EDGE_INDEX):
    os.makedirs(output_dir, exist_ok=True)

# 1. Grab entities
Companies, symbols, mutual funds, institutions, C-level board

Relations:
* `company` (shortName) --- has symbol ---> `stock symbol`
* `stock symbol` --- listed on ---> `exchange`
* `stock symbol` --- is in ---> `industry`
* `mutual fund` --- has symbol ---> `mutualfund symbol`
* `institution` --- holds ---> `stock symbol`
* `mutual fund` --- holds ---> `stock symbol`

## 1.1. Company, stock symbol, and exchange

In [10]:
# Load the stocks table from the parquet file under DATA_PATH.
stocks_path = os.path.join(DATA_PATH, 'stocks.parquet')
stocks = pd.read_parquet(stocks_path)

In [28]:
# Display the column index of the stocks table to see which fields are available.
stocks.columns

Index(['language', 'region', 'quoteType', 'typeDisp', 'quoteSourceName',
       'triggerable', 'customPriceAlertConfidence', 'marketCap', 'currency',
       'gmtOffSetMilliseconds', 'esgPopulated', 'tradeable', 'cryptoTradeable',
       'bid', 'ask', 'exchange', 'fiftyTwoWeekHigh', 'fiftyTwoWeekLow',
       'averageAnalystRating', 'regularMarketChangePercent',
       'financialCurrency', 'shortName', 'hasPrePostMarketData',
       'firstTradeDateMilliseconds', 'priceHint', 'postMarketChangePercent',
       'postMarketTime', 'postMarketPrice', 'postMarketChange',
       'regularMarketChange', 'regularMarketTime', 'regularMarketPrice',
       'regularMarketDayHigh', 'regularMarketDayRange', 'regularMarketDayLow',
       'regularMarketVolume', 'regularMarketPreviousClose', 'bidSize',
       'askSize', 'market', 'messageBoardId', 'fullExchangeName', 'longName',
       'regularMarketOpen', 'averageDailyVolume3Month',
       'averageDailyVolume10Day', 'corporateActions', 'fiftyTwoWeekLowChan

In [31]:
# Entities are the distinct, non-missing values of each source column:
# shortName -> company, symbol -> stock_symbols, exchange -> exchanges.
company, stock_symbols, exchanges = (
    stocks[column].dropna().unique()
    for column in ('shortName', 'symbol', 'exchange')
)

In [35]:
# Assign each entity a consecutive integer id (0, 1, 2, ...) in order of appearance.
company2id = dict(zip(company, range(len(company))))
stocksymbol2id = dict(zip(stock_symbols, range(len(stock_symbols))))
exchange2id = dict(zip(exchanges, range(len(exchanges))))

In [39]:
# Write each id mapping to its own JSON file inside the ID_MAPPING folder.
id_maps_to_save = (
    ('company2id.json', company2id),
    ('stocksymbol2id.json', stocksymbol2id),
    ('exchange2id.json', exchange2id),
)
for map_filename, id_map in id_maps_to_save:
    with open(os.path.join(ID_MAPPING, map_filename), 'w') as f:
        json.dump(id_map, f, indent=2)

## 1.2. Mutual fund, mutualfund symbol, institutions

# 2. Build edge lists for each relation

## 2.1. `company` --- has symbol ---> `stock symbol`

In [40]:
# Edge list for: company (shortName) --- has symbol ---> stock symbol.
# Produces two parallel lists of integer ids: comp2sym_src[i] -> comp2sym_dst[i].

# load id mappings
with open(os.path.join(ID_MAPPING, 'company2id.json')) as f:
    company2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'stocksymbol2id.json')) as f:
    stocksymbol2id = json.load(f)

# keep only rows where both the company shortName and the symbol are present
comp2sym_rows = stocks[['shortName', 'symbol']].dropna()

# vectorized id lookup (replaces the per-row Python loop);
# a name that is absent from the mapping comes back as NaN
comp2sym_src_ids = comp2sym_rows['shortName'].map(company2id)
comp2sym_dst_ids = comp2sym_rows['symbol'].map(stocksymbol2id)

# fail loudly on an unmapped name, as a plain dict lookup would
if comp2sym_src_ids.isna().any() or comp2sym_dst_ids.isna().any():
    raise KeyError('found a shortName or symbol with no entry in the id mappings')

comp2sym_src = comp2sym_src_ids.astype('int64').tolist()
comp2sym_dst = comp2sym_dst_ids.astype('int64').tolist()

In [43]:
# sanity check - every edge needs both endpoints, so the two id lists must line up
assert len(comp2sym_dst) == len(comp2sym_src)

In [48]:
# Stack source ids (row 0) and destination ids (row 1) into one long tensor
# and write it to the EDGE_INDEX folder.
comp2sym_path = os.path.join(EDGE_INDEX, "comp2sym.pt")
edge_index = torch.tensor([comp2sym_src, comp2sym_dst], dtype=torch.long)
torch.save(edge_index, comp2sym_path)

## 2.2 `stock symbol` --- listed on ---> `exchange`

In [50]:
# Edge list for: stock symbol --- listed on ---> exchange.
# Produces two parallel lists of integer ids: sym2ex_src[i] -> sym2ex_dst[i].

# load id mappings
with open(os.path.join(ID_MAPPING, 'exchange2id.json')) as f:
    exchange2id = json.load(f)

with open(os.path.join(ID_MAPPING, 'stocksymbol2id.json')) as f:
    stocksymbol2id = json.load(f)

# keep only rows where both the symbol and the exchange are present
sym2ex_rows = stocks[['symbol', 'exchange']].dropna()

# vectorized id lookup (replaces the per-row Python loop);
# a name that is absent from the mapping comes back as NaN
sym2ex_src_ids = sym2ex_rows['symbol'].map(stocksymbol2id)
sym2ex_dst_ids = sym2ex_rows['exchange'].map(exchange2id)

# fail loudly on an unmapped name, as a plain dict lookup would
if sym2ex_src_ids.isna().any() or sym2ex_dst_ids.isna().any():
    raise KeyError('found a symbol or exchange with no entry in the id mappings')

sym2ex_src = sym2ex_src_ids.astype('int64').tolist()
sym2ex_dst = sym2ex_dst_ids.astype('int64').tolist()

In [51]:
# sanity check - every edge needs both endpoints, so the two id lists must line up
assert len(sym2ex_dst) == len(sym2ex_src)

In [52]:
# Stack source ids (row 0) and destination ids (row 1) into one long tensor
# and write it to the EDGE_INDEX folder.
sym2ex_path = os.path.join(EDGE_INDEX, "sym2ex.pt")
edge_index = torch.tensor([sym2ex_src, sym2ex_dst], dtype=torch.long)
torch.save(edge_index, sym2ex_path)

In [None]:
## 2.3. 