# Construct homogeneous stock graph using HiDy 

In [2]:
import os
import re
import sys

sys.path.append('..')  # make the local FinNeo module importable

import numpy as np
import pandas as pd
from tqdm import tqdm

import FinNeo

## Connect to remote KB

In [3]:
# Connection to the remote knowledge base. Each setting can be overridden with
# an environment variable (FINNEO_URL / FINNEO_USER / FINNEO_PWD) so that
# endpoint and credentials do not have to be edited into the notebook; when a
# variable is unset the original hard-coded value is used.
fintech = FinNeo.FinNeo(
    url=os.environ.get('FINNEO_URL', "bolt://143.89.126.57:5001"),
    user=os.environ.get('FINNEO_USER', 'neo4j'),
    pwd=os.environ.get('FINNEO_PWD', 'csproject'),
)

## Load stock list

In [7]:
# Read the stock list: one stock code per line.
path = '../../Data/csi300_stock_codes_22.txt'
with open(path, 'r') as file:
    content = file.readlines()

# Skip blank lines: an empty code would later be spliced into
# "n.code contains ''", which matches every company in the knowledge base.
csi = [line.strip() for line in content if line.strip()]

# Digits-only form of every code (despite its name this is a list, and later
# cells iterate over it in file order). Entries that contain no digit at all
# are dropped for the same reason as blank lines.
all_num_dict = [
    digits
    for digits in (re.sub("[^0-9]", "", x) for x in csi)
    if digits
]

## 3-level SW Industry 

In [None]:
# The SW industry classification has three levels. The three queries below
# differ only in how many SW_industry nodes lie on the path between the two
# companies (1, 3 and 5 respectively), so they share one helper.

_NON_DIGITS = re.compile(r"[^0-9]")


def _sw_peer_pairs(path_pattern):
    """Return ``[code, peer_code]`` pairs for stocks linked by *path_pattern*.

    For every code in ``all_num_dict`` the knowledge base is asked for the
    ``code`` of each company ``c`` reachable from a company ``n`` (whose code
    contains the stock code) along *path_pattern*. A pair is kept only when
    the digits of the peer's code are themselves in ``all_num_dict``.

    Args:
        path_pattern: Cypher path pattern binding the start company as ``n``
            and the end company as ``c``.

    Returns:
        A list of two-element lists ``[code, peer_code]``, both digits-only
        strings. Distinct raw codes are visited in first-seen order; missing,
        empty or non-string codes are skipped.
    """
    known_codes = set(all_num_dict)  # O(1) membership instead of a list scan
    pairs = []
    for i in tqdm(all_num_dict):
        # `i` contains digits only (it was produced by stripping every
        # non-digit), so it cannot break out of the quoted Cypher literal.
        query = 'match ' + path_pattern + ' where n.code contains \'' + i + '\' return c.code as code '
        records = fintech.get_read_query(query)
        # dict.fromkeys removes duplicate codes while keeping first-seen order.
        raw_codes = dict.fromkeys(dict(record)['code'] for record in records)
        for raw in raw_codes:
            if not isinstance(raw, str) or not raw:
                continue
            peer = _NON_DIGITS.sub("", raw)
            if peer in known_codes:
                pairs.append([i, peer])
    return pairs


# SW level 1: both companies are attached to the same SW_industry node.
SW_set_L1 = _sw_peer_pairs('(n:company) -[r1]- (i:SW_industry) -[r2]- (c:company)')

# SW level 2: the path runs through three SW_industry nodes.
SW_set_L2 = _sw_peer_pairs('(n:company)-[r1]-(i1:SW_industry)-[r2]-(i2:SW_industry)-[r3]-(i3:SW_industry)-[r4]-(c:company)')

# SW level 3: the path runs through five SW_industry nodes.
SW_set_L3 = _sw_peer_pairs('(n:company)-[r1]-(i1:SW_industry)-[r2]-(i2:SW_industry)-[r3]-(i3:SW_industry)-[r4]-(i4:SW_industry)-[r5]-(i5:SW_industry)-[r6]-(c:company)')


## Get dynamic relations

In [None]:
# Collect the distinct relationship types that connect two company nodes.
# Cypher's type() returns the type name of a relationship and DISTINCT
# de-duplicates on the server, so only one row per type is transferred instead
# of every relationship object (whose repr previously had to be regex-parsed).
query = 'match (n:company)-[r1]-(c:company) return distinct type(r1) as rel_type'
node = fintech.get_read_query(query)
node = [dict(record)['rel_type'] for record in node]
# Sorted so that the order of relation_type is reproducible between runs.
relation_type = sorted(set(node))

In [None]:
# Build, for every relation type, the list of time-stamped edges between
# stocks of the list: dyset[k] holds [code, peer_code, time, relation] rows
# for relation_type[k].

_NON_DIGITS = re.compile(r"[^0-9]")


def _is_missing(value):
    """Return True for None, the empty string and NaN."""
    return value is None or value == '' or value != value


known_codes = set(all_num_dict)  # O(1) membership instead of a list scan
dyset = []
for relation in relation_type:
    # Backtick-quote the type name so that a name containing characters that
    # are not valid in a bare Cypher identifier still yields a valid query;
    # a literal backtick inside the name is escaped by doubling it.
    quoted_relation = '`' + relation.replace('`', '``') + '`'
    temp_set = []
    for i in tqdm(all_num_dict):
        # `i` contains digits only, so it cannot break out of the quoted literal.
        query = 'match (n:company)-[r1:' + quoted_relation + ']-(c:company) where n.code contains \'' + i + '\' return c.code as code, r1.time as time '
        rows = [dict(record) for record in fintech.get_read_query(query)]
        # dict.fromkeys removes duplicate (code, time) rows, keeping first-seen order.
        unique_rows = dict.fromkeys((row['code'], row['time']) for row in rows)
        for raw_code, time in unique_rows:
            # Rows without a usable code or without a time are skipped.
            if not isinstance(raw_code, str) or not raw_code or _is_missing(time):
                continue
            peer = _NON_DIGITS.sub("", raw_code)
            if peer in known_codes:
                temp_set.append([i, peer, time, relation])

    dyset.append(temp_set)
        

## Data preprocessing

In [None]:
def add_SH_SZ(value):
    """Prefix a stock code with its exchange tag.

    The value is converted with ``str``; if the text starts with ``'6'`` the
    result is ``'SH'`` followed by the text, otherwise ``'SZ'`` followed by
    the text.
    """
    code = str(value)
    exchange = 'SH' if code.startswith('6') else "SZ"
    return exchange + code

In [None]:
# Stack the pairs of all three SW levels into one edge table and tag every
# row with the relation name 'SW_industry'.
SW_new_csi = pd.DataFrame(
    SW_set_L1 + SW_set_L2 + SW_set_L3,
    columns=["e1", "e2"],
).assign(r='SW_industry')

In [None]:
# Flatten the per-relation edge lists into a single table.
merged = [edge for relation_edges in dyset for edge in relation_edges]
dy_df = pd.DataFrame(merged, columns=["e1", "e2", "time", 'r'])
# Parse the 'time' column into datetimes, then order the edges by time.
dy_df = dy_df.assign(time=pd.to_datetime(dy_df['time'])).sort_values('time')

In [None]:
# Normalise both endpoint columns of both edge tables: left-pad the code with
# zeros to six characters, then prepend the exchange tag via add_SH_SZ.
for frame in (SW_new_csi, dy_df):
    for column in ('e1', 'e2'):
        padded = frame[column].astype(str).str.zfill(6)
        frame[column] = padded.apply(add_SH_SZ)

In [None]:
# Remove rows that are exact duplicates in either edge table.
SW_new_csi, dy_df = (frame.drop_duplicates() for frame in (SW_new_csi, dy_df))

In [None]:
# Write both edge tables to disk. The dynamic-relation table is `dy_df`; the
# previous reference to `df` named a variable that is never defined here.
dy_df.to_csv('../../Data/csi22_dytuple.csv')
SW_new_csi.to_csv('../../Data/csi22_SW.csv')