In [190]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from itertools import combinations
from datetime import datetime
from torch_geometric.data import InMemoryDataset, HeteroData
from torch.utils.data import Dataset, DataLoader
import gc

In [230]:
# --- Load inputs ------------------------------------------------------------
# NOTE: pd.read_pickle executes arbitrary code embedded in the file; only load
# pickles that come from a trusted source.
universe = pd.read_pickle('data/universe.pkl')
universe.index = pd.to_datetime(universe.index)
permno_universe = pd.read_pickle('data/permno_universe.pkl')
permno_universe.index = pd.to_datetime(permno_universe.index)
mapper_df = pd.read_pickle('data/mapper_df.pkl')
hist_ret_df = pd.read_pickle('data/hist_ret_df.pkl.gz')
weekly_ret_df = pd.read_pickle('data/weekly_ret_df.pkl')
sector_df = pd.read_pickle('data/sector_df.pkl')
all_sectors = sector_df['gsector'].dropna().astype(int).unique()
supchain_df = pd.read_pickle('data/supchain_df.pkl')

# Sample dates: every distinct index value of weekly_ret_df in the study window.
date_lst = list(weekly_ret_df.loc['2012-01-01':'2022-01-01'].index.unique())

# --- Build the per-sample-date feature table ---------------------------------
# For each sample date take the most recent hist_ret_df cross-section dated on
# or before it.  The level-0 (date) labels do not change inside the loop, so
# they are extracted once instead of on every iteration, and the snapshot is
# selected in one step rather than via an intermediate "<= date" copy of the
# whole history.
hist_dates = hist_ret_df.index.get_level_values(0)

concat_lst = []
for date in date_lst:
    latest_hist_date = hist_dates[hist_dates <= date].max()
    cur_hist_ret_df = hist_ret_df[hist_dates == latest_hist_date].droplevel(0,axis=0)
    # stack() flattens the cross-section into one Series per date (NaNs dropped).
    concat_lst.append(cur_hist_ret_df.stack())

# One row per sample date, then move the first column level back into the index
# so the result is indexed by (date, PERMNO) with one column per feature.
weekly_hist_ret_df = pd.DataFrame(concat_lst,index=date_lst).stack(level=0)
weekly_hist_ret_df

# hist_ret_df is intentionally kept alive: a later scratch cell still reads it.
# del hist_ret_df
# gc.collect()

Unnamed: 0_level_0,Unnamed: 1_level_0,annret_1,annret_10,annret_126,annret_21,annret_42,annret_5,annret_63,annvol_10,annvol_126,annvol_21,annvol_42,annvol_5,annvol_63
Unnamed: 0_level_1,PERMNO,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
2012-01-06,10104,3.222324,0.112760,-0.002830,-0.087781,-0.024585,0.436061,-0.005498,0.058949,0.039306,0.105381,0.069426,0.072915,0.052101
2012-01-06,10107,3.869208,0.222130,0.001213,0.053021,0.011504,0.786250,0.005002,0.060353,0.024998,0.043627,0.037206,0.090714,0.028821
2012-01-06,10138,1.815660,0.095697,0.000700,0.013955,0.017810,0.236638,0.012142,0.057792,0.047099,0.069306,0.059986,0.063400,0.065391
2012-01-06,10145,-1.858500,0.047227,-0.000295,0.007538,0.005288,0.074652,0.014259,0.058987,0.037104,0.056291,0.047043,0.088532,0.043731
2012-01-06,10516,-1.465128,0.036580,-0.000233,-0.008644,0.001917,0.117372,0.008581,0.060245,0.035331,0.059386,0.044757,0.073511,0.042580
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
2021-12-31,92121,-0.587412,0.001988,-0.000059,0.051337,0.002889,0.019495,-0.004914,0.110845,0.029668,0.074016,0.049005,0.026931,0.042631
2021-12-31,92293,-4.487616,0.034985,-0.001638,-0.000515,-0.037571,-0.378585,-0.018213,0.077906,0.036500,0.065435,0.071886,0.068659,0.051525
2021-12-31,92602,1.978452,0.050712,-0.000318,0.059409,0.004672,0.222768,0.000360,0.032780,0.016222,0.032681,0.027309,0.022702,0.021535
2021-12-31,92611,-1.341648,0.029358,-0.001217,0.076301,0.004149,0.004476,-0.003039,0.066726,0.022921,0.054307,0.046845,0.026555,0.038972


In [241]:
# def get_gvkey_dict():
#     gvkey_dict = set(supchain_df['gvkey']).intersection(supchain_df['cgvkey'])
#     N = len(gvkey_dict)
#     gvkey_dict = dict(zip(list(gvkey_dict),range(N)))
#     return gvkey_dict

def get_mapper(date,permno=None,gvkey=None):
    """Translate between PERMNO and GVKEY ids using the latest mapping snapshot.

    Parameters
    ----------
    date : datetime-like
        As-of date.  The rows of ``mapper_df`` whose ``date`` is the largest
        value ``<= date`` form the snapshot that is used.
    permno : list-like, optional
        PERMNO(s) to translate into GVKEY(s).  Used whenever it is not None.
    gvkey : list-like, optional
        GVKEY(s) to translate into PERMNO(s); required when ``permno`` is None.

    Returns
    -------
    list of int
        The translated ids, in the order of the requested keys.

    Raises
    ------
    AssertionError
        If no snapshot exists on or before ``date``, or if neither ``permno``
        nor ``gvkey`` is supplied.
    """
    history = mapper_df[mapper_df['date'] <= date]
    latest = history['date'].max()
    snapshot = history[history['date'] == latest].drop('date', axis=1)
    assert len(snapshot) > 0

    if permno is None:
        assert gvkey is not None
        by_gvkey = snapshot.set_index('gvkey')
        return by_gvkey.loc[gvkey, 'permno'].astype(int).tolist()

    by_permno = snapshot.set_index('permno')
    return by_permno.loc[permno, 'gvkey'].astype(int).tolist()

def get_permnogvkey_toidx_mapper(date,permno=None,gvkey=None):
    """Map PERMNO or GVKEY ids to node indices.

    A node index is the row position of the id inside the mapping snapshot
    in force at ``date`` (the ``mapper_df`` rows carrying the largest
    ``date`` value that is ``<= date``), so indices are only comparable
    between calls that resolve to the same snapshot.

    Parameters
    ----------
    date : datetime-like
        As-of date used to pick the snapshot.
    permno, gvkey : list-like, optional
        Ids to look up.  At most one of the two may be given, and one of them
        is required.

    Returns
    -------
    list of int
        Row positions of the requested ids.

    Raises
    ------
    AssertionError
        If both id kinds are given, if neither is given, or if there is no
        snapshot on or before ``date``.
    """
    assert (permno is None) or (gvkey is None)

    history = mapper_df[mapper_df['date'] <= date]
    is_latest = history['date'] == history['date'].max()
    snapshot = history[is_latest].drop('date', axis=1).reset_index(drop=True)
    # Record each row's position so it survives the re-indexing below.
    snapshot = snapshot.assign(index=snapshot.index)
    assert len(snapshot) > 0

    if permno is None:
        assert gvkey is not None
        key_column, keys = 'gvkey', gvkey
    else:
        key_column, keys = 'permno', permno
    return snapshot.set_index(key_column, drop=True).loc[keys, 'index'].tolist()

def get_idxto_permnogvkey_mapper(date,idx,permno=None,gvkey=None):
    """Map node indices back to PERMNO or GVKEY ids.

    Parameters
    ----------
    date : datetime-like
        As-of date; the ``mapper_df`` rows with the largest ``date`` value
        ``<= date`` form the snapshot whose row positions are the indices.
    idx : list-like of int
        Row positions inside that snapshot.
    permno, gvkey : truthy flag
        Exactly one must be truthy; it selects which id column is returned.

    Returns
    -------
    list
        Values of the selected id column at the requested positions.

    Raises
    ------
    AssertionError
        If the flags are both truthy or both falsy, or if there is no
        snapshot on or before ``date``.
    """
    # Exactly one of the two flags has to be set.
    assert bool(permno) != bool(gvkey)

    history = mapper_df[mapper_df['date'] <= date]
    latest = history['date'].max()
    snapshot = history[history['date'] == latest].drop('date', axis=1).reset_index(drop=True)
    assert len(snapshot) > 0

    id_column = 'permno' if permno else 'gvkey'
    return snapshot.loc[idx, id_column].tolist()

def get_gvkey_universe(date):
    """Return the GVKEY universe in force at ``date`` as a list of ints.

    Takes the ``universe`` row with the latest index value ``<= date``,
    drops its missing entries and casts the remainder to int.
    """
    history = universe.loc[universe.index <= date]
    latest_row = history.loc[history.index.max()]
    members = latest_row.dropna()
    return members.astype(int).tolist()
    
def get_permno_universe(date):
    """Return the PERMNO universe in force at ``date`` as a list of ints.

    Takes the ``permno_universe`` row with the latest index value
    ``<= date``, drops its missing entries and casts the remainder to int.
    """
    history = permno_universe.loc[permno_universe.index <= date]
    latest_row = history.loc[history.index.max()]
    members = latest_row.dropna()
    return members.astype(int).tolist()

def get_sector_edges(date,gvkey_to_idx_mapper,gvkey_universe):
    """Build same-sector edges between the universe's nodes.

    Every unordered pair of universe members that share a ``gsector`` value in
    the latest ``sector_df`` snapshot (largest ``datadate`` that is
    ``<= date``) yields one edge.  Each pair appears once, in the order
    produced by ``itertools.combinations``; the reverse direction is not added.

    Parameters
    ----------
    date : datetime-like
        As-of date for the sector snapshot and for the index mapper.
    gvkey_to_idx_mapper : callable
        Called as ``gvkey_to_idx_mapper(date, gvkey=<array of GVKEYs>)`` and
        expected to return the matching node indices.
    gvkey_universe : list-like
        GVKEYs allowed to appear in the graph.

    Returns
    -------
    numpy.ndarray
        Integer array of shape ``(n_edges, 2)``.  It is ``(0, 2)`` when no two
        universe members share a sector, so callers always receive a
        two-column array.
    """
    as_of = sector_df[sector_df['datadate'] <= date]
    # NOTE(review): only rows stamped with the single latest datadate are kept;
    # this assumes each datadate holds a full cross-section -- TODO confirm.
    snapshot = as_of[as_of['datadate'] == as_of['datadate'].max()]
    snapshot = snapshot[snapshot['GVKEY'].isin(gvkey_universe)]

    edge_lst = []
    for sector in all_sectors:
        gvkeys = snapshot.loc[snapshot['gsector'] == sector, 'GVKEY'].unique()
        node_ids = gvkey_to_idx_mapper(date, gvkey=gvkeys)
        edge_lst.extend(combinations(node_ids, 2))

    # reshape keeps the two-column layout even when edge_lst is empty.
    return np.array(edge_lst, dtype=int).reshape(-1, 2)

def get_supply_chain_edges(date,gvkey_to_idx_mapper,gvkey_universe):
    """Build supplier/customer edges from the trailing three years of links.

    Rows of ``supchain_df`` are kept when ``srcdate`` lies in
    ``[date - 3 years, date]`` and both ``gvkey`` and ``cgvkey`` belong to
    ``gvkey_universe``.  ``gvkey`` is treated as the supplier ("sup") side and
    ``cgvkey`` as the customer ("con") side.  Rows are not de-duplicated, so a
    pair that occurs in several rows inside the window yields repeated edges.

    Parameters
    ----------
    date : datetime-like
        End of the look-back window and as-of date for the index mapper.
    gvkey_to_idx_mapper : callable
        Called as ``gvkey_to_idx_mapper(date, gvkey=<Series of GVKEYs>)`` and
        expected to return one node index per input element.
    gvkey_universe : list-like
        GVKEYs allowed to appear in the graph.

    Returns
    -------
    (numpy.ndarray, numpy.ndarray)
        ``c_edge_list`` with columns ``[sup, con]`` and ``s_edge_list`` with
        the same edges reversed, ``[con, sup]``.  Both are int64 arrays of
        shape ``(n_edges, 2)``; ``(0, 2)`` when no link qualifies.
    """
    window_start = date - pd.DateOffset(years=3)
    in_window = (supchain_df['srcdate'] <= date) & (supchain_df['srcdate'] >= window_start)
    cur_df = supchain_df[in_window]
    cur_df = cur_df[cur_df['gvkey'].isin(gvkey_universe) & cur_df['cgvkey'].isin(gvkey_universe)]

    # Handle the empty window explicitly: the previous DataFrame.apply based
    # mapping did not produce a two-column frame for an empty selection.
    if cur_df.empty:
        return np.empty((0, 2), dtype=np.int64), np.empty((0, 2), dtype=np.int64)

    sup_idx = np.asarray(gvkey_to_idx_mapper(date, gvkey=cur_df['gvkey']), dtype=np.int64)
    con_idx = np.asarray(gvkey_to_idx_mapper(date, gvkey=cur_df['cgvkey']), dtype=np.int64)

    # column_stack materialises fresh arrays for each direction (no reversed
    # views), which keeps them safe to hand to tensor conversion later.
    c_edge_list = np.column_stack((sup_idx, con_idx))
    s_edge_list = np.column_stack((con_idx, sup_idx))
    return c_edge_list, s_edge_list

def get_hist_ret_df(date,gvkey_to_idx_mapper,permno_to_gvkey_mapper,gvkey_universe,permno_universe,num_nodes=500):
    """Return the node-feature matrix for ``date``.

    The ``weekly_hist_ret_df`` cross-section for ``date`` (rows keyed by
    PERMNO) is restricted to ``permno_universe``, re-keyed PERMNO -> GVKEY,
    restricted to ``gvkey_universe``, re-keyed GVKEY -> node index and finally
    laid out on the fixed node range ``0 .. num_nodes - 1``.

    Parameters
    ----------
    date : datetime-like
        Sample date to read from ``weekly_hist_ret_df`` and to pass to both
        mappers.
    gvkey_to_idx_mapper : callable
        Called as ``gvkey_to_idx_mapper(date, gvkey=<index of GVKEYs>)``.
    permno_to_gvkey_mapper : callable
        Called as ``permno_to_gvkey_mapper(date, permno=<index of PERMNOs>)``.
    gvkey_universe, permno_universe : list-like
        Ids allowed in the graph.  (The ``permno_universe`` argument shadows
        the module-level frame of the same name inside this function.)
    num_nodes : int, default 500
        Number of rows in the output.  Previously hard-coded to 500.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(num_nodes, n_features)``.  Rows for node indices
        without data are NaN; rows whose index falls outside
        ``0 .. num_nodes - 1`` are dropped by the reindex.
    """
    cur_df = weekly_hist_ret_df.loc[date]
    cur_df = cur_df[cur_df.index.isin(permno_universe)]
    cur_df.index = permno_to_gvkey_mapper(date,permno=cur_df.index)
    cur_df = cur_df[cur_df.index.isin(gvkey_universe)]
    cur_df.index = gvkey_to_idx_mapper(date,gvkey=cur_df.index)
    # Fixed-size layout so every sample has the same number of node rows.
    cur_df = cur_df.reindex(np.arange(num_nodes))
    return cur_df.values

def get_weekly_ret_df(date,gvkey_to_idx_mapper,permno_to_gvkey_mapper,gvkey_universe,permno_universe,num_nodes=500):
    """Return the per-node weekly return vector for ``date``.

    The ``weekly_ret_df`` row for ``date`` (entries keyed by PERMNO) is
    restricted to ``permno_universe``, re-keyed PERMNO -> GVKEY, restricted to
    ``gvkey_universe``, re-keyed GVKEY -> node index and finally laid out on
    the fixed node range ``0 .. num_nodes - 1``.

    Parameters
    ----------
    date : datetime-like
        Date to read from ``weekly_ret_df`` and to pass to both mappers.
    gvkey_to_idx_mapper : callable
        Called as ``gvkey_to_idx_mapper(date, gvkey=<index of GVKEYs>)``.
    permno_to_gvkey_mapper : callable
        Called as ``permno_to_gvkey_mapper(date, permno=<index of PERMNOs>)``.
    gvkey_universe, permno_universe : list-like
        Ids allowed in the graph.  (The ``permno_universe`` argument shadows
        the module-level frame of the same name inside this function.)
    num_nodes : int, default 500
        Length of the output.  Previously hard-coded to 500.

    Returns
    -------
    numpy.ndarray
        Values laid out by node index, with ``num_nodes`` entries along the
        first axis.  Entries for node indices without data are NaN; entries
        whose index falls outside ``0 .. num_nodes - 1`` are dropped by the
        reindex.
    """
    cur_df = weekly_ret_df.loc[date]
    cur_df = cur_df[cur_df.index.isin(permno_universe)]
    cur_df.index = permno_to_gvkey_mapper(date,permno=cur_df.index)
    cur_df = cur_df[cur_df.index.isin(gvkey_universe)]
    cur_df.index = gvkey_to_idx_mapper(date,gvkey=cur_df.index)
    # Fixed-size layout so every sample has the same number of node entries.
    cur_df = cur_df.reindex(np.arange(num_nodes))
    return cur_df.values

In [25]:
# Round-trip sanity check of the two index mappers as of 2012-01-01:
# first GVKEY -> node index, then node index -> GVKEY.  The second printed
# list should reproduce the GVKEYs passed to the first call.
print(get_permnogvkey_toidx_mapper(datetime(2012,1,1),gvkey=[7875,2184]))
print(get_idxto_permnogvkey_mapper(datetime(2012,1,1),[92,105],gvkey=True))


[92, 105]
[7875, 2184]


In [251]:
class GNNDataset(Dataset):
    """Map-style dataset yielding one weekly graph sample per index.

    Sample ``i`` pairs the graph and node features observed on
    ``date_lst[i]`` with the weekly returns of ``date_lst[i + 1]`` as labels.
    All inputs are read from the module-level frames through the helper
    functions defined earlier in the notebook.
    """

    def __init__(self):
        super().__init__()
        # Distinct weekly dates in the study window; consecutive entries form
        # the (feature date, label date) pairs.
        self.date_lst = list(weekly_ret_df.loc['2012-01-01':'2022-01-01'].index.unique())

    def __len__(self):
        # The last date has no following week to use as a label.  Clamp at 0
        # so an empty date list does not produce a negative length.
        return max(len(self.date_lst) - 1, 0)

    def __getitem__(self,idx):
        """Return the sample for position ``idx``.

        Returns
        -------
        tuple
            ``(sector_edges, consumer_edges, supplier_edges, features,
            labels, mask)`` where ``mask`` is True for nodes whose features
            and label are all non-NaN.

        Raises
        ------
        IndexError
            If ``idx`` is outside ``[-len(self), len(self))``.
        """
        n_samples = len(self)
        # Support negative indices explicitly: without this, idx == -1 would
        # pair the last date with date_lst[0] as its "next" week.
        if idx < 0:
            idx += n_samples
        if not 0 <= idx < n_samples:
            raise IndexError('GNNDataset index out of range')

        cur_date = self.date_lst[idx]
        next_date = self.date_lst[idx+1]
        cur_universe = get_gvkey_universe(cur_date)
        cur_permno_universe = get_permno_universe(cur_date)

        # Node indices are row positions inside the mapping snapshot of a given
        # date.  Pin both mappers to cur_date so that edges, features and
        # labels all address nodes through the same snapshot, even though the
        # labels are read at next_date.
        def idx_mapper(_date, **ids):
            return get_permnogvkey_toidx_mapper(cur_date, **ids)

        def id_mapper(_date, **ids):
            return get_mapper(cur_date, **ids)

        # Sector edges
        sector_edge_lst = get_sector_edges(cur_date, idx_mapper, cur_universe)
        # Consumer and supplier edges
        c_edge_lst, s_edge_lst = get_supply_chain_edges(cur_date, idx_mapper, cur_universe)

        # Individual stock features (observed at cur_date)
        cur_hist_ret_df = get_hist_ret_df(cur_date, idx_mapper, id_mapper, cur_universe, cur_permno_universe)
        # Labels (returns realised at next_date, for the cur_date universe)
        cur_weekly_ret_df = get_weekly_ret_df(next_date, idx_mapper, id_mapper, cur_universe, cur_permno_universe)

        # Keep only nodes with complete features and a label.
        mask = ~(np.isnan(cur_hist_ret_df).any(axis=1) | np.isnan(cur_weekly_ret_df))

        return sector_edge_lst,c_edge_lst,s_edge_lst,\
            cur_hist_ret_df,cur_weekly_ret_df, mask
    
    

In [253]:
# Smoke test: build the dataset and pull its first two samples in order.
# Only the result of the final `next(it)` (the second sample) is displayed.
dataset = GNNDataset()
it = iter(dataset)
next(it)
next(it)

(array([[194, 325],
        [194, 125],
        [194, 277],
        ...,
        [ 61, 250],
        [ 61,  68],
        [250,  68]]),
 array([[216, 320],
        [216, 218],
        [ 88, 196],
        [333, 209],
        [333, 209],
        [333, 358],
        [333, 263],
        [333, 209],
        [306, 305],
        [306, 332],
        [332,   3],
        [419, 320],
        [118, 320],
        [178, 320],
        [  2, 320],
        [  2, 320],
        [  2, 320],
        [183, 249],
        [ 96, 196],
        [136, 320],
        [136, 320],
        [136, 320],
        [394, 320],
        [424, 320],
        [393, 320],
        [305,   3],
        [305, 282],
        [468, 148],
        [287, 320],
        [ 56,  78],
        [ 13, 148],
        [ 56, 232],
        [399, 308],
        [359, 209],
        [312, 320],
        [364, 128],
        [ 35, 482],
        [  6, 149],
        [  6, 322],
        [454, 320],
        [ 10, 148],
        [403,  75],
        [403, 320],
     

Scratch work: ignore everything below this point.

In [None]:
# Scratch check of the label builder for the first sample.  The two universes
# were previously undefined at notebook level, so this cell failed on a fresh
# kernel; they are now derived from the first sample date, mirroring what
# GNNDataset.__getitem__ does for idx == 0.
cur_universe = get_gvkey_universe(date_lst[0])
cur_permno_universe = get_permno_universe(date_lst[0])
get_weekly_ret_df(date_lst[1],get_permnogvkey_toidx_mapper,get_mapper,cur_universe,cur_permno_universe)

In [165]:
# Scratch check: take the third sample and display its maximum value.
# NOTE(review): with the GNNDataset defined earlier in this notebook each
# sample is a tuple of six arrays of differing shapes, so np.max over the
# whole tuple presumably no longer works -- this cell looks like it was run
# against an older version of the class; TODO confirm or remove.
dataset = GNNDataset()
it = iter(dataset)
next(it)
next(it)
np.max(next(it))

499

In [None]:
# Inspect the raw hist_ret_df cross-section for a single day (2010-01-04).
hist_ret_df.loc[datetime(2010,1,4)]

Unnamed: 0_level_0,annret_1,annvol_1,annret_5,annvol_5,annret_10,annvol_10,annret_21,annvol_21,annret_42,annvol_42,annret_63,annvol_63,annret_126,annvol_126
PERMNO,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1
10104,3.287340,,-0.038254,0.074304,0.213888,0.105945,0.054751,0.064146,0.024216,0.039024,0.013312,0.029485,0.003207,0.020415
10107,3.885840,,-0.012509,0.096912,0.114325,0.061271,0.022615,0.034766,0.015887,0.025850,0.014284,0.025950,0.004976,0.021882
10138,5.442192,,0.048344,0.097122,0.126701,0.066989,0.059446,0.051427,0.015909,0.037412,0.015982,0.044955,0.005592,0.032035
10145,7.392924,,0.123601,0.116363,0.052834,0.061361,0.009202,0.044124,0.017167,0.032108,0.008880,0.029630,0.005041,0.022064
10516,1.287720,,-0.030794,0.054109,0.076409,0.039563,-0.001673,0.037840,0.005520,0.032710,0.007138,0.028877,0.003010,0.023746
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
92121,-1.713096,,-0.203908,0.032341,-0.048535,0.045073,-0.020068,0.094363,0.005779,0.062299,-0.001387,0.056357,0.006126,0.040792
92293,-4.489884,,-0.247958,0.085527,-0.060077,0.065448,0.009275,0.062089,0.012841,0.044765,0.010730,0.036902,0.004972,0.028169
92602,6.065892,,0.095760,0.091937,0.034181,0.054357,0.009718,0.043091,0.004474,0.032328,0.002394,0.026910,0.002324,0.018986
92611,1.959300,,0.215641,0.058422,0.033602,0.060510,0.041111,0.052744,0.019500,0.031475,0.017170,0.027461,0.006209,0.021246


In [None]:
def get_heterodata(*args, **kwargs):
    """Placeholder for the planned HeteroData builder.

    The original cell contained only ``def get_heterodata`` with no signature
    or body, which is a SyntaxError.  Until the function is written, calling
    it fails loudly instead of the cell failing to parse.

    Raises
    ------
    NotImplementedError
        Always.
    """
    # TODO: decide on the real signature and implement.
    raise NotImplementedError('get_heterodata is not implemented yet')

In [None]:
class GNNData(HeteroData):
    """Work-in-progress ``HeteroData`` subclass that loads the raw input frames.

    Each frame is read from its pickle under ``data/`` when an instance is
    created and stored as an attribute of the instance.
    """
    def __init__(self):
        super().__init__()
        # NOTE: pd.read_pickle executes code embedded in the file; only load
        # pickles from a trusted source.
        self.universe = pd.read_pickle('data/universe.pkl')
        self.mapper_df = pd.read_pickle('data/mapper_df.pkl')
        self.hist_ret_df = pd.read_pickle('data/hist_ret_df.pkl.gz')
        self.sector_df = pd.read_pickle('data/sector_df.pkl')
        # Derive the sector codes from the frame loaded just above rather than
        # from whatever module-level `sector_df` happens to exist in the kernel.
        # NOTE(review): this uses the 'gind' column while the active pipeline
        # uses 'gsector' -- confirm which classification level is intended.
        self.all_sectors = self.sector_df['gind'].dropna().astype(int).unique()
        self.supchain_df = pd.read_pickle('data/supchain_df.pkl')
        
    # Expected node keys: stock_features
    # Expected edge keys: supply_to, supply_from, sector, [correlated]
    def __cat_dim__(self, key, value, *args, **kwargs):
         # Returns None for the key 'foo' and otherwise defers to the parent
         # implementation.  NOTE(review): 'foo' looks like a placeholder key
         # -- TODO replace with the real attribute name or drop the override.
         if key == 'foo':
             return None
         else:
             return super().__cat_dim__(key, value, *args, **kwargs)

In [None]:
class GNNDataset(InMemoryDataset):
    """Work-in-progress ``InMemoryDataset`` variant of the graph dataset.

    NOTE(review): this class re-uses the name ``GNNDataset``, so running this
    cell replaces the ``torch.utils.data.Dataset``-based class defined earlier
    in the notebook.  No ``process`` method is visible in this cell -- TODO
    confirm whether the class is complete.
    """
    def __init__(self,root=None, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        # The raw frames are loaded only after the base-class constructor has
        # returned.
        self.universe = pd.read_pickle('data/universe.pkl')
        self.mapper_df = pd.read_pickle('data/mapper_df.pkl')
        self.hist_ret_df = pd.read_pickle('data/hist_ret_df.pkl.gz')
        self.sector_df = pd.read_pickle('data/sector_df.pkl')
        # NOTE(review): reads the module-level `sector_df`, not the
        # `self.sector_df` loaded on the previous line, and uses the 'gind'
        # column whereas the active pipeline uses 'gsector' -- confirm.
        self.all_sectors = sector_df['gind'].dropna().astype(int).unique()
        self.supchain_df = pd.read_pickle('data/supchain_df.pkl')
    
    @property
    def raw_file_names(self):
        """Names of the raw files for this dataset; currently an empty list."""
        return []

    @property
    def processed_file_names(self):
        """Names of the processed files for this dataset; currently an empty list."""
        return []

    def download(self):
        """No-op: nothing is downloaded."""
        # Download to `self.raw_dir`.
        pass