In [1]:
import os
import pandas as pd
import urllib.request
import json 
from operator import itemgetter
import numpy as np
import requests
import pyreadr
import re
from sentence_transformers import SentenceTransformer
from fact_checking import fact_check_and_add
from collections import Counter
import graph_utils

  from tqdm.autonotebook import tqdm, trange


Connection to graph database established.


## Neo4J AuraDB Setup

### Resetting database

In [2]:
graph_utils.reset_graph()

In [3]:
graph_utils.reset_constraints()

## Schema Constraints

### Region Node

In [4]:
# m49 is key
graph_utils.execute_query('''
CREATE CONSTRAINT region_m49_key IF NOT EXISTS
FOR (r:Region) REQUIRE r.m49 IS NODE KEY''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2820d0>, keys=[])

In [5]:
# name is unique
graph_utils.execute_query('''
CREATE CONSTRAINT region_name_unique IF NOT EXISTS
FOR (r:Region) REQUIRE r.name IS UNIQUE''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d265e50>, keys=[])

In [6]:
# name fulltext index
graph_utils.execute_query('''
CREATE FULLTEXT INDEX region_name_index IF NOT EXISTS
FOR (r:Region) ON EACH [r.name]''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d282c90>, keys=[])

### Country Node

In [7]:
# iso3 is key
graph_utils.execute_query('''
CREATE CONSTRAINT country_iso3_key IF NOT EXISTS
FOR (c:Country) REQUIRE c.iso3 IS NODE KEY''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d280ed0>, keys=[])

In [8]:
# iso2 is unique
graph_utils.execute_query('''
CREATE CONSTRAINT country_iso2_unique IF NOT EXISTS
FOR (c:Country) REQUIRE c.iso2 IS UNIQUE''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d278590>, keys=[])

In [9]:
# name is unique
graph_utils.execute_query('''
CREATE CONSTRAINT country_name_unique IF NOT EXISTS
FOR (c:Country) REQUIRE c.name IS UNIQUE''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d279610>, keys=[])

In [10]:
# aliases fulltext index
graph_utils.execute_query('''
CREATE FULLTEXT INDEX country_aliases_index IF NOT EXISTS
FOR (c:Country) ON EACH [c.aliases]''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d279b10>, keys=[])

### Sector Node

In [11]:
# gics is key
graph_utils.execute_query('''
CREATE CONSTRAINT sector_gics_key IF NOT EXISTS
FOR (s:Sector) REQUIRE s.gics IS NODE KEY''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d291c90>, keys=[])

In [12]:
# name is unique
graph_utils.execute_query('''
CREATE CONSTRAINT country_name_unique IF NOT EXISTS
FOR (c:Country) REQUIRE c.name IS UNIQUE''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2ac490>, keys=[])

### Industry Node

In [13]:
# gics is key
graph_utils.execute_query('''
CREATE CONSTRAINT industry_gics_key IF NOT EXISTS
FOR (i:Industry) REQUIRE i.gics IS NODE KEY''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2a3d50>, keys=[])

In [14]:
# name is unique
graph_utils.execute_query('''
CREATE CONSTRAINT industry_name_unique IF NOT EXISTS
FOR (i:Industry) REQUIRE i.name IS UNIQUE''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d267ed0>, keys=[])

In [15]:
# description vector index
graph_utils.execute_query('''
CREATE VECTOR INDEX industry_description_index IF NOT EXISTS
FOR (i:Industry)
ON i.embedding
OPTIONS { indexConfig: {
 `vector.quantization.enabled`: false
}}''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2b8b10>, keys=[])

### Company Node

In [16]:
# ticker is key
graph_utils.execute_query('''CREATE CONSTRAINT company_ticker_key IF NOT EXISTS
FOR (c:Company) REQUIRE c.ticker IS NODE KEY''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15b598810>, keys=[])

In [17]:
# names fulltext index
graph_utils.execute_query('''CREATE FULLTEXT INDEX company_names_index IF NOT EXISTS
FOR (c:Company) ON EACH [c.names]''')

EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2a3850>, keys=[])

## Adding Initial Data

### Region Nodes

Our country and region nodes, as well as the IS_IN relationships from countries to regions, are sourced from [UNSD](https://unstats.un.org/unsd/methodology/m49/overview/).

In [18]:
df_m49 = pd.read_csv('data/UNSD_m49.csv', sep=';')

In [19]:
continents = df_m49[['Region Code', 'Region Name']]\
                    .dropna()\
                    .drop_duplicates()\
                    .rename(columns={
                        'Region Code': 'm49',
                        'Region Name': 'name'
                    })

In [20]:
subregions = df_m49[['Sub-region Code', 'Sub-region Name']]\
                    .dropna()\
                    .drop_duplicates()\
                    .rename(columns={
                        'Sub-region Code': 'm49',
                        'Sub-region Name': 'name'
                    })

In [21]:
itdregions = df_m49[['Intermediate Region Code', 'Intermediate Region Name']]\
                    .dropna()\
                    .drop_duplicates()\
                    .rename(columns={
                        'Intermediate Region Code': 'm49',
                        'Intermediate Region Name': 'name'
                    })

In [22]:
regions = pd.concat([continents, subregions, itdregions], ignore_index=True)\
            .astype({'m49': int})

In [23]:
param_dicts = regions.to_dict('records')
param_dicts[:5]

[{'m49': 2, 'name': 'Africa'},
 {'m49': 19, 'name': 'Americas'},
 {'m49': 142, 'name': 'Asia'},
 {'m49': 150, 'name': 'Europe'},
 {'m49': 9, 'name': 'Oceania'}]

In [24]:
graph_utils.execute_query_with_params("MERGE (:Region{m49: $m49, name: $name})",
                                      *param_dicts)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c4210>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c4ad0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c5410>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c5e50>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c6a90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c76d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c8310>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c8f90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2c9c90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

### Country Nodes

In [25]:
countries = df_m49[['ISO-alpha3 Code', 'ISO-alpha2 Code', 'Country or Area']]\
                    .dropna()\
                    .drop_duplicates()\
                    .rename(columns={
                        'ISO-alpha3 Code': 'iso3',
                        'ISO-alpha2 Code': 'iso2',
                        'Country or Area': 'name'
                    })

In [26]:
param_dicts = countries.to_dict('records')
param_dicts[:5]

[{'iso3': 'DZA', 'iso2': 'DZ', 'name': 'Algeria'},
 {'iso3': 'EGY', 'iso2': 'EG', 'name': 'Egypt'},
 {'iso3': 'LBY', 'iso2': 'LY', 'name': 'Libya'},
 {'iso3': 'MAR', 'iso2': 'MA', 'name': 'Morocco'},
 {'iso3': 'SDN', 'iso2': 'SD', 'name': 'Sudan'}]

In [27]:
graph_utils.execute_query_with_params("MERGE (:Country{iso3: $iso3, name: $name, iso2: $iso2})",
                                      *param_dicts)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2f9350>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2f9c90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2fa610>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2fb090>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2fb890>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15da00390>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15da00fd0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15da01c10>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15da02810>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

### Country IS_IN Region Relationship

In [28]:
country_continent = df_m49[['ISO-alpha3 Code', 'Region Code']]\
                            .dropna()\
                            .drop_duplicates()\
                            .rename(columns={
                                'ISO-alpha3 Code': 'iso3',
                                'Region Code': 'm49'
                            })

In [29]:
country_subregion = df_m49[['ISO-alpha3 Code', 'Sub-region Code']]\
                            .dropna()\
                            .drop_duplicates()\
                            .rename(columns={
                                'ISO-alpha3 Code': 'iso3',
                                'Sub-region Code': 'm49'
                            })

In [30]:
country_itdregion = df_m49[['ISO-alpha3 Code', 'Intermediate Region Code']]\
                            .dropna()\
                            .drop_duplicates()\
                            .rename(columns={
                                'ISO-alpha3 Code': 'iso3',
                                'Intermediate Region Code': 'm49'
                            })

In [31]:
country_region = pd.concat([country_continent, country_subregion, country_itdregion], ignore_index=True)

In [32]:
param_dicts = country_region.to_dict('records')
param_dicts[:5]

[{'iso3': 'DZA', 'm49': 2.0},
 {'iso3': 'EGY', 'm49': 2.0},
 {'iso3': 'LBY', 'm49': 2.0},
 {'iso3': 'MAR', 'm49': 2.0},
 {'iso3': 'SDN', 'm49': 2.0}]

In [33]:
graph_utils.execute_query_with_params('''
MATCH
    (c:Country{iso3: $iso3}),
    (r:Region{m49: $m49})
MERGE (c)-[:IS_IN]->(r)''', *param_dicts)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2ee690>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15d2ef1d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db3d450>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db5b150>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db6c250>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db6d310>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db6e3d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db6f3d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db74450>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

### Country Aliases Property

Alternative names for countries. Source: [Kaggle](https://www.kaggle.com/datasets/wbdill/country-aliaseslist-of-alternative-country-names)/[Wikipedia](https://en.wikipedia.org/wiki/List_of_alternative_country_names)

In [34]:
df_alias = pd.read_csv('data/country_aliases.csv')

In [35]:
def split_alias(row):
    '''
    Splits a row if Alias contains multiple aliases seperated by " or "
    '''
    if ' or ' in row['Alias']:
        values = row['Alias'].split(' or ')
        return pd.DataFrame({'iso3': [row['iso3']] * len(values), 'Alias': values})
    return pd.DataFrame({'iso3': [row['iso3']], 'Alias': [row['Alias']]})

In [36]:
aliases = pd.concat([split_alias(row) for _, row in df_alias.iterrows()],
                  ignore_index=True)\
        .dropna()\
        .drop_duplicates()\
        .rename(columns={'Alias': 'alias'})

In [37]:
param_dicts = aliases.to_dict('records')
param_dicts[:5]

[{'iso3': 'AFG', 'alias': 'Afghanistan'},
 {'iso3': 'AFG', 'alias': 'Islamic Republic of Afghanistan'},
 {'iso3': 'AFG', 'alias': 'Da Afganistan Islami Jumhoryat'},
 {'iso3': 'AFG', 'alias': 'Jomhuriyyeh Eslamiyyeh Afganestan'},
 {'iso3': 'ALB', 'alias': 'Albania'}]

In [38]:
graph_utils.execute_query_with_params('''
MERGE (c:Country {iso3: $iso3})
SET c.aliases = 
    CASE
        WHEN c.aliases IS NULL THEN [$alias]
        WHEN NOT $alias IN c.aliases THEN c.aliases + $alias
        ELSE c.aliases
    END''', *param_dicts)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db48dd0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db496d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db49fd0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db4a9d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db4b5d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db58210>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db58dd0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db59990>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15db5a510>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

### Country Stats

Yearly stats for each country are sourced from [Worldbank](data.worldbank.org). Corporate Tax Rates sourced from [Tax Foundation](https://taxfoundation.org/data/all/global/corporate-tax-rates-by-country-2023).

In [39]:
def get_worldbank(indicator: str) -> pd.DataFrame:
    '''
    Get indicator data using worldbank API
    '''
    with urllib.request.urlopen(f"https://api.worldbank.org/v2/country/all/indicator/{indicator}?format=json&per_page=20000") as url:
        data = json.load(url)[1]
    ind = data[0]['indicator']['value']
    iso3 = map(itemgetter('countryiso3code'), data)
    year = map(itemgetter('date'), data)
    value = map(itemgetter('value'), data)
    return pd.DataFrame({
        'iso3': iso3,
        'year': year,
        ind: value
    }).replace('', np.nan)\
      .dropna()\
      .set_index(['iso3', 'year'])

In [40]:
population = get_worldbank('SP.POP.TOTL')

In [41]:
gdp = get_worldbank('NY.GDP.MKTP.CD')

In [42]:
pv = get_worldbank('PV.EST')

In [43]:
ctr = pd.read_excel('data/corp_tax_rate.xlsx')\
        .melt(id_vars='iso_3',
              value_vars=range(1980, 2024),
              var_name='year',
              value_name='corporate_tax_rate')\
        .rename(columns={'iso_3': 'iso3'})\
        .astype({'year': str})\
        .set_index(['iso3', 'year'])

In [44]:
stats = pd.concat([population, gdp, pv, ctr], axis=1).sort_index()\
          .reset_index()\
          .rename(columns={
              'Population, total': 'population',
              'GDP (current US$)': 'gdp',
              'Political Stability and Absence of Violence/Terrorism: Estimate': 'pv',
              'corporate_tax_rate': 'corporate_tax_rate'
          })

We use 2022 stats for now

In [45]:
param_dicts = stats[stats['year'] == '2022'].to_dict('records')
param_dicts[:3]

[{'iso3': 'ABW',
  'year': '2022',
  'population': 106445.0,
  'gdp': 3544707788.05664,
  'pv': 1.47468435764313,
  'corporate_tax_rate': 25.0},
 {'iso3': 'AFE',
  'year': '2022',
  'population': 720859132.0,
  'gdp': 1183962133998.87,
  'pv': nan,
  'corporate_tax_rate': nan},
 {'iso3': 'AFG',
  'year': '2022',
  'population': 41128771.0,
  'gdp': 14502158192.0904,
  'pv': -2.5508017539978,
  'corporate_tax_rate': 20.0}]

In [46]:
graph_utils.execute_query_with_params('''
MATCH (c:Country {iso3: $iso3})
SET
    c.population = $population,
    c.gdp = $gdp,
    c.pv = $pv,
    c.corporate_tax_rate = $corporate_tax_rate''', *param_dicts)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15e9af890>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15e9ac390>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea12fd0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea122d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea11690>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea10a90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea03fd0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea03590>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea02950>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

### Sector Node


The data used for Sector/Industry nodes comes from [bautheac/GICS](https://github.com/bautheac/GICS) packages the Global Industry Classification Standards (GICS) dataset for consumption in R.  The GICS hierarchy begins with 11 sectors and is followed by 24 industry groups, 68 industries, and 157 sub-industries. 

In [47]:
url = 'https://github.com/bautheac/GICS/raw/0c2b0e4c0ca56a0e520301fd978fc095ed4fc328/data/standards.rda'
response = requests.get(url)

rda_file_path = './data/standards.rda'
with open(rda_file_path, 'wb') as file:
    file.write(response.content)

# Load the .rda file using pyreadr
result = pyreadr.read_r(rda_file_path)

print(result.keys())  

df = result[list(result.keys())[0]]  

# Save the DataFrame as a CSV file and remove the rda file
df.to_csv('./data/standards.csv', index=False)

os.remove(rda_file_path)

print("Data has been saved as standards.csv")

odict_keys(['standards'])
Data has been saved as standards.csv


In [48]:
# data wrangling for industry/sector

def wrangling(csv_path):
    df = pd.read_csv(csv_path)
    
    df = df.dropna()

    df = df.drop_duplicates()
    
    df = df.rename(columns={
        'sector id': 'sector_id',
        'sector name': 'sector_name',
        'industry group id': 'industry_group_id',
        'industry group name': 'industry_group_name',
        'industry id': 'industry_id',
        'industry name': 'industry_name',
        'subindustry id': 'subindustry_id',
        'subindustry name': 'subindustry_name',
        'description': 'primary_activity'
    })

    
    df['sector_id'] = df['sector_id'].astype('Int64')  
    df['industry_group_id'] = df['industry_group_id'].astype('Int64')
    df['industry_id'] = df['industry_id'].astype('Int64')
    df['subindustry_id'] = df['subindustry_id'].astype('Int64')

    df.reset_index(drop=True, inplace=True)
    df.index += 1

    return df

df_standards = wrangling("./data/standards.csv")

In [49]:
sector = df_standards[['sector_id', 'sector_name']] \
        .drop_duplicates() \
        .rename(columns={
            'sector_id': 'gics',
            'sector_name': 'name'
        })

In [50]:
param_dicts = sector.to_dict('records')
param_dicts[:5]

[{'gics': 10, 'name': 'Energy'},
 {'gics': 15, 'name': 'Materials'},
 {'gics': 20, 'name': 'Industrials'},
 {'gics': 25, 'name': 'Consumer Discretionary'},
 {'gics': 30, 'name': 'Consumer Staples'}]

In [51]:
graph_utils.execute_query_with_params("MERGE (:Sector{gics: $gics, name: $name})", *param_dicts)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea29b90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea29250>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea28950>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea139d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15eb53e90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15eb529d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15eb51d90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15eb51550>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15eb50790>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

### Industry Node

In [52]:
industry = df_standards[['subindustry_id', 'subindustry_name', 'primary_activity']] \
           .drop_duplicates() \
           .rename(columns={
               'subindustry_id': 'gics',
               'subindustry_name': 'name',
               'primary_activity': 'description'
            })

In [53]:
EMBEDDING_MODEL = SentenceTransformer("all-MiniLM-L6-v2")

In [54]:
industry_desc_embed = EMBEDDING_MODEL.encode(industry['description'].to_numpy())
industry['embedding'] = list(map(list, industry_desc_embed))

In [55]:
param_dicts = industry.to_dict('records')

In [56]:
graph_utils.execute_query_with_params("MERGE (:Industry{gics: $gics, name: $name, description: $description, embedding: $embedding})", *param_dicts)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3d3e90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3d35d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3d2c50>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3d22d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3d1610>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3d0a10>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3c3e50>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3c32d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3c26d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

### Industry PART_OF Sector Relationship

In [57]:
industry_sector = df_standards[['subindustry_id', 'sector_id']] \
                  .drop_duplicates() \
                  .rename(columns={
                      'subindustry_id': 'industry_gics',
                      'sector_id': 'sector_gics'
                  })

In [58]:
param_dicts = industry_sector.to_dict('records')
param_dicts[:5]

[{'industry_gics': 10101010, 'sector_gics': 10},
 {'industry_gics': 10101020, 'sector_gics': 10},
 {'industry_gics': 10102010, 'sector_gics': 10},
 {'industry_gics': 10102020, 'sector_gics': 10},
 {'industry_gics': 10102030, 'sector_gics': 10}]

In [59]:
graph_utils.execute_query_with_params('''
MATCH
    (i:Industry{gics: $industry_gics}),
    (s:Sector{gics: $sector_gics})
MERGE (i)-[:PART_OF]->(s)''', *param_dicts)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f1da350>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f1b42d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f1ab2d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f1aa790>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f1a90d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f1a8410>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f1e9a50>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f11fb50>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f11edd0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

## Adding Company Data

### Import JSON files

In this part, import all the json files, `nasdaq_kg_schema.json`, `nasdaq_kg_schema_rank_1-10.json` and `nasdaq_kg_schema_rank_11-32.json`.

In [60]:
json_files = ['nasdaq_kg_schema.json', 'nasdaq_kg_schema_rank_1-10.json', 'nasdaq_kg_schema_rank_11-32.json']

def merge_json_files(json_list: list[str]) -> dict:

    merged_json = {
        "nodes": {},
        "relationships": {}
    }

    for file_path in json_list:
        with open(file_path, 'r') as file:
            try:
                data = json.load(file)
                
                if "nodes" in data:
                    for node_type, nodes in data["nodes"].items():
                        if node_type not in merged_json["nodes"]:
                            merged_json["nodes"][node_type] = []
                        merged_json["nodes"][node_type].extend(nodes)

               
                if "relationships" in data:
                    for relationship_type, relationships in data["relationships"].items():
                        if relationship_type not in merged_json["relationships"]:
                            merged_json["relationships"][relationship_type] = []
                        merged_json["relationships"][relationship_type].extend(relationships)

            except json.JSONDecodeError as e:
                print(f"Error decoding {file_path}: {e}")
                
            except FileNotFoundError:
                print("Error: The file was not found.")

    return merged_json

merged_json = merge_json_files(json_files)

### Data Validation

* This part is only for company nodes, where we want to ensure data integrity by checking the primary key ticker_code and ensuring other fields like company name adhere to specific formatting rules and constraints. 
* Validating ticker_code as a valid string consisting of 4 to 5 uppercase letters, is for accurate indexing in financial markets.
* The process identifies duplicates and keeps the more-info version of an entry over the other duplicates. 
* Standardizing company names into title case and eliminating special symbols maintains consistency in representation and usability.
* All data entry whose ticker_code and company_name does not meet these criteria will be removed to maintain a clean dataset for the reporting of company information.

In [61]:
def is_valid_ticker(ticker_code):
    """Helper function to check if the ticker code is valid (str, 4 to 5 letters, all upper case)."""
    return isinstance(ticker_code, str) and 4 <= len(ticker_code) <= 5 and ticker_code.isupper()

def remove_invalid_ticker_companies(data):
    """Remove companies whose ticker_code doesn't meet the 3, 4, or 5 letter criteria."""
    if isinstance(data, str):
        print("Warning: data is a string, attempting to load as JSON.")
        data = json.loads(data)  
    
    if "nodes" in data and "Company" in data["nodes"]:
        companies = data["nodes"]["Company"]
        filtered_companies = [company for company in companies if is_valid_ticker(company.get("ticker_code"))]
        data["nodes"]["Company"] = filtered_companies
    else:
        print("Warning: The expected structure is not found in the data.")

    return data

def is_more_comprehensive(entry1, entry2):
    """Helper function to determine which duplicate has more comprehensive details."""
    return sum(1 for v in entry1.values() if v) > sum(1 for v in entry2.values() if v)

def remove_duplicates(data):
    seen_tickers = {}
    for company in data["nodes"]["Company"]:   
        ticker = company.get("ticker_code", "")
        if ticker in seen_tickers:
            if is_more_comprehensive(company, seen_tickers[ticker]):
                seen_tickers[ticker] = company
        else:
            seen_tickers[ticker] = company 
    data["nodes"]["Company"] = list(seen_tickers.values())
    return data

def standarize_case(data):
    """function to standardize title case and no other special."""
    for company in data["nodes"]["Company"]:
        company["name"] = company["name"].title() 
        company["name"] = re.sub(r'[^a-zA-Z0-9\s&.-]', '',company["name"]) 
    return data

In [62]:
def company_validate(data): 
    data = remove_invalid_ticker_companies(data)
    # data = remove_duplicates(data)
    # data = standarize_case(data)
    print("Completed.")
    return data

validated_data = company_validate(merged_json)

Completed.


### Adding Company Nodes

In [63]:
companies = validated_data['nodes']['Company']
for company in companies:
    company['founded_year'] = company['founded_year'] or ""
companies[:5]

[{'name': 'Apple Inc.', 'ticker_code': 'AAPL', 'founded_year': ''},
 {'name': 'Apple', 'ticker_code': 'AAPL', 'founded_year': ''},
 {'name': 'AirPods', 'ticker_code': 'AAPL', 'founded_year': ''},
 {'name': 'Apple', 'ticker_code': 'AAPL', 'founded_year': ''},
 {'name': 'Major League Soccer (MLS)',
  'ticker_code': 'MLFB',
  'founded_year': ''}]

In [64]:
graph_utils.execute_query_with_params('''
MERGE (c:Company {ticker: $ticker_code})
SET c.names = 
    CASE
        WHEN c.names IS NULL THEN [$name]
        WHEN NOT $name IN c.names THEN c.names + $name
        ELSE c.names
    END,
    c.founded_year = $founded_year''', *companies)

[EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea02d10>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea03e50>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15ea13150>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15de619d0>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f38c050>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15e3b8650>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15e89fd50>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15f3a8d90>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15e88f690>, keys=[]),
 EagerResult(records=[], summary=<neo4j._work.summary.ResultSummary object at 0x15

### Adding Product Nodes

In [65]:
# products = validated_data['nodes']['Product']
# products[:5]

In [66]:
# execute_query_with_params("MERGE (p:Product {name: $product_name})", *products)

### Company-Industry relationship

One of the key relationships defined in the data is "IS_INVOLVED_IN", which helps in modeling how companies are categorized based on their primary activities and the sectors they contribute to.


"IS_INVOLVED_IN": [
            {
                "company_name": "McKinsey &#38; Company",
                "industry_name": "management consulting"}]

In [67]:
company_industries = validated_data['relationships']['IS_INVOLVED_IN']
company_industries[:5]

[{'company_name': 'Apple', 'industry_name': 'accounting'},
 {'company_name': 'Apple', 'industry_name': 'technology'},
 {'company_name': 'Apple', 'industry_name': 'information technology'},
 {'company_name': 'Brightstar Corporation',
  'industry_name': 'wireless device services'},
 {'company_name': 'Applied Materials, Inc.',
  'industry_name': 'semiconductor equipment'}]

In [68]:
for company_industry in company_industries:
    industry_name = company_industry['industry_name']
    company_industry['embedding'] = EMBEDDING_MODEL.encode(industry_name).tolist()

In [69]:
query_results = graph_utils.execute_query_with_params("""
CALL db.index.fulltext.queryNodes('company_names_index', $company_name)
    YIELD node AS c, score AS company_score
CALL db.index.vector.queryNodes('industry_description_index', 10, $embedding)
    YIELD node AS i, score AS industry_score
WHERE company_score > 1
AND industry_score > 0.7
RETURN
    c.ticker AS ticker,
    i.gics AS gics""", *company_industries)

In [70]:
edges_to_add = []
for records, _, _ in query_results:
    for src_key_value, dst_key_value in records:
        edges_to_add.append((src_key_value, dst_key_value, "Company", "Industry", "IS_INVOLVED_IN"))

In [71]:
patterns = fact_check_and_add(edges_to_add)

Processing group: ('AMAT', 'Company', 'Industry', 'IS_INVOLVED_IN')
23 edges found in group.
0 patterns found for group.
Adding 23 filtered edges.

Processing group: ('CMCSA', 'Company', 'Industry', 'IS_INVOLVED_IN')
36 edges found in group.
1 patterns found for group.
Adding 0 filtered edges.

Processing group: ('KLAC', 'Company', 'Industry', 'IS_INVOLVED_IN')
35 edges found in group.
1 patterns found for group.
Adding 0 filtered edges.

Processing group: ('NVDA', 'Company', 'Industry', 'IS_INVOLVED_IN')
21 edges found in group.
1 patterns found for group.
Adding 0 filtered edges.

Processing group: ('GILD', 'Company', 'Industry', 'IS_INVOLVED_IN')
14 edges found in group.
1 patterns found for group.
Adding 0 filtered edges.

Processing group: ('LRCX', 'Company', 'Industry', 'IS_INVOLVED_IN')
19 edges found in group.
1 patterns found for group.
Adding 0 filtered edges.

Processing group: ('INTU', 'Company', 'Industry', 'IS_INVOLVED_IN')
15 edges found in group.
1 patterns found for gr

In [72]:
patterns

defaultdict(list,
            {('Company',
              'Industry',
              'IS_INVOLVED_IN'): [{'supp': 1,
               'conf': 1,
               'relations': [{'srcLabel': 'Industry',
                 'edgeLabel': 'PART_OF',
                 'dstLabel': 'Sector'},
                {'srcLabel': 'Industry',
                 'edgeLabel': 'PART_OF',
                 'dstLabel': 'Sector'}]}, {'supp': 1,
               'conf': 1,
               'relations': [{'srcLabel': 'Industry',
                 'edgeLabel': 'PART_OF',
                 'dstLabel': 'Sector'},
                {'srcLabel': 'Industry',
                 'edgeLabel': 'PART_OF',
                 'dstLabel': 'Sector'}]}, {'supp': 1,
               'conf': 1,
               'relations': [{'srcLabel': 'Industry',
                 'edgeLabel': 'PART_OF',
                 'dstLabel': 'Sector'},
                {'srcLabel': 'Industry',
                 'edgeLabel': 'PART_OF',
                 'dstLabel': 'Sector'}]}, {'supp