In [1]:
import pandas as pd
import scipy as sc
import networkx as nx
import numpy as np
import json
import os
from tqdm import tqdm

In [13]:
# Length of the zero vector that to_feat() returns for non-int (stock) nodes.
# NOTE(review): presumably must equal the width of the loaded embeddings — confirm.
feat_dim = 300

def text_wrap(text):
    """Remove every '######' marker from ``text`` and append two trailing spaces."""
    marker, suffix = '######', '  '
    cleaned = text.replace(marker, '')
    return cleaned + suffix

def rel_orgs_wrap(text):
    """Split a '|'-separated string into a list of its parts.

    Args:
        text: The '|'-delimited string, or a missing value.

    Returns:
        ``text.split('|')`` for a non-empty string; an empty list for an empty
        string, ``None`` or any other non-string value.
    """
    # A pandas missing value is a float NaN, which is truthy and has no
    # ``.split`` — treat every non-string input as "no entries" instead of
    # raising AttributeError.
    if not isinstance(text, str) or not text:
        return []
    return text.split('|')

# NOTE(review): never read or written in the visible cells — possibly a leftover; confirm before removing.
node_dic = {}
def to_feat(k):
    """Return the feature vector for graph node ``k``.

    Int keys are treated as document nodes and read their 'feat' attribute from
    the module-level graph ``G``; every other key is treated as a stock node and
    gets a zero vector of length ``feat_dim``.
    """
    if not isinstance(k, int):
        # Stock node: no stored features, use zeros.
        return np.zeros(feat_dim)
    # Document node: features were attached to the graph node.
    return G.nodes[k]['feat']

class NewsDataSet:
    """Iterable over news shards, yielding one DataFrame per entry of ``data_path_list``.

    For each name ``n`` in ``data_path_list`` the iterator reads
    ``<row_path>/<n>.pkl`` with ``pd.read_pickle`` and ``<feat_path>/<n>.npy``
    with ``np.load`` and combines them into a frame with the columns
    ``id``, ``rel_orgs``, ``date`` and ``feat``.

    Args:
        row_path: Directory holding the ``.pkl`` files.
        feat_path: Directory holding the ``.npy`` files.
        data_path_list: Shard names (file names without extension).
    """

    def __init__(self, row_path, feat_path, data_path_list):
        self.row_path = row_path
        self.feat_path = feat_path
        self.data_path_list = data_path_list
        self.index = 0  # position of the next shard to load

    def __len__(self):
        """Number of shards (lets progress bars such as tqdm show a total)."""
        return len(self.data_path_list)

    def __iter__(self):
        """Start a fresh pass over the shards and return ``self``."""
        # Rewind so the same object can be looped over more than once;
        # without this a second ``for`` loop would silently yield nothing.
        self.index = 0
        return self

    def __next__(self):
        """Load and return the next shard as a DataFrame.

        Raises:
            StopIteration: When every shard has been returned.
        """
        if self.index >= len(self.data_path_list):
            raise StopIteration
        name = self.data_path_list[self.index]
        self.index += 1

        # NOTE(security): read_pickle and np.load(allow_pickle=True) both
        # unpickle, which can execute arbitrary code — only use trusted files.
        raw = pd.read_pickle(os.path.join(self.row_path, name + '.pkl'))
        feat = np.load(os.path.join(self.feat_path, name + '.npy'), allow_pickle=True)

        out = pd.DataFrame()
        out['id'] = raw['id']
        out['rel_orgs'] = raw['rel_org_a_companies_code'].map(rel_orgs_wrap)
        out['date'] = raw['date_time'].apply(lambda x: x.date())
        # Each row's index label selects its entry in ``feat`` (same lookup as
        # a row-wise apply on ``x.name``, without building a Series per row).
        out['feat'] = [feat[label] for label in raw.index]
        return out

In [14]:
# Shard names to load: '2020-1' ... '2020-12'.
data_sets = ['2020-{}'.format(month) for month in range(1, 13)]
loader = NewsDataSet('data/news_row', 'data/news_embed', data_sets)

# Graph linking each document id to every entry of its 'rel_orgs' list.
G = nx.Graph()
for df_month in tqdm(loader):
    # Column-wise zip avoids the per-row Series that iterrows() builds.
    for doc_id, orgs, doc_feat in zip(df_month['id'], df_month['rel_orgs'], df_month['feat']):
        if not orgs:
            # Documents without related orgs are not added to the graph.
            continue
        # Attach the features once per document rather than once per edge.
        G.add_node(doc_id, feat=doc_feat)
        G.add_edges_from((doc_id, org) for org in orgs)

# to_scipy_sparse_matrix was removed in NetworkX 3.0; fall back to the array
# API there and convert so ``ssm`` is a scipy sparse matrix either way.
if hasattr(nx.convert_matrix, 'to_scipy_sparse_matrix'):
    ssm = nx.convert_matrix.to_scipy_sparse_matrix(G)
else:
    adj = nx.to_scipy_sparse_array(G)
    ssm = sc.sparse.csr_matrix(adj)

# One feature row per node, in G.nodes order (the same order the adjacency uses).
feats = np.stack([to_feat(node) for node in G.nodes])
feats.shape

100%|██████████| 12/12 [11:22<00:00, 56.88s/it]


(620056, 300)

In [15]:
# Output directory for the exported graph files.
data_prefix = 'data/22mnews/'
os.makedirs(data_prefix, exist_ok=True)

# The same adjacency is written as both the full and the training graph.
sc.sparse.save_npz(data_prefix + 'adj_full.npz', ssm)
sc.sparse.save_npz(data_prefix + 'adj_train.npz', ssm)
np.save(data_prefix + 'feats.npy', feats)

num_nodes = G.number_of_nodes()

# Every node index gets class 0; ``with`` makes sure the file is flushed and closed.
with open(data_prefix + 'class_map.json', 'w') as class_map_file:
    json.dump(dict.fromkeys(range(num_nodes), 0), class_map_file)

# Every node index goes into the 'tr' split; 'va' and 'te' stay empty.
role = {'tr': list(range(num_nodes)), 'va': [], 'te': []}
with open(data_prefix + 'role.json', 'w') as role_file:
    json.dump(role, role_file)