In [1]:
import pandas as pd
import neo4j
import psycopg2

class IvoryTradeGraph:
    def __init__(self, postgres_params, neo4j_uri, neo4j_user, neo4j_password, csv_file_path):
        self.postgres_connection = self.connect_to_postgres(**postgres_params)
        self.neo4j_session = self.connect_to_neo4j(neo4j_uri, neo4j_user, neo4j_password)
        self.csv_file_path = csv_file_path
        self.data = None
        self.selected_year = None
        self.graph_type = None 
        self.relationship_type = None

    def connect_to_postgres(self, user, password, host, port, database):
        """Return a new psycopg2 connection built from the given settings."""
        settings = {
            'user': user,
            'password': password,
            'host': host,
            'port': port,
            'database': database,
        }
        return psycopg2.connect(**settings)

    def connect_to_neo4j(self, uri, user, password):
        """Create a Neo4j driver for ``uri`` and return a session opened on it."""
        credentials = (user, password)
        return neo4j.GraphDatabase.driver(uri, auth=credentials).session()

    def select_query_pandas(self, query, rollback_before=False, rollback_after=False):
        if rollback_before:
            self.postgres_connection.rollback()
        
        df = pd.read_sql_query(query, self.postgres_connection)
        
        if rollback_after:
            self.postgres_connection.rollback()

        for column in df:
            if df[column].dtype == "float64":
                if not any(df[column] % 1):
                    df[column] = df[column].astype('Int64')
        
        return df
    
    def group_dataframe(self, df):
        df = df[df.year == self.selected_year]
        agg_functions = {'standardized_quantity': 'sum'}
        df_grouped = df.groupby(['exporter', 'importer']).aggregate(agg_functions).reset_index().sort_values(by = ['standardized_quantity'], ascending = False).reset_index(drop = True)
        
        return df_grouped
    
    def check_and_load_sql_data(self):
        """
        Make sure the ivory_trade table exists in Postgres, then cache its rows.

        If no table named 'ivory_trade' is listed in information_schema.tables,
        the table is created and filled from self.csv_file_path with COPY.
        In both cases the full table is then read into self.sql_data.

        Raises:
            psycopg2.Error: if creating or loading the table fails; the
                transaction is rolled back before the error is re-raised.
        """
        check_query = """
        SELECT EXISTS (
            SELECT FROM information_schema.tables
            WHERE table_name = 'ivory_trade'
        );
        """
        table_exists = self.select_query_pandas(check_query, True, True).iloc[0, 0]

        if not table_exists:
            print("No ivory trade table found in the database. Creating it now.....")
            self.postgres_connection.rollback()
            # The CSV path is bound as a query parameter rather than spliced
            # into the SQL text, so a path containing a quote cannot break
            # (or alter) the statement.
            # NOTE(review): COPY ... FROM '<file>' is expected to read the file
            # on the database server's filesystem - confirm the path is valid there.
            create_query = """
            CREATE TABLE ivory_trade (
                year SMALLINT,
                taxon VARCHAR(255),
                term VARCHAR(255),
                importer VARCHAR(255),
                exporter VARCHAR(255),
                standardized_quantity NUMERIC,
                standardized_unit VARCHAR(255)
            );
            COPY ivory_trade
            FROM %s DELIMITER ',' CSV HEADER;
            """
            try:
                with self.postgres_connection.cursor() as cursor:
                    cursor.execute(create_query, (self.csv_file_path,))
                self.postgres_connection.commit()
            except psycopg2.Error:
                # Do not leave a half-finished transaction behind.
                self.postgres_connection.rollback()
                raise
            print('Table Created and Data Loaded Successfully!')
        else:
            print('The ivory trade table already exists in postgres.')

        self.sql_data = self.select_query_pandas("SELECT * FROM ivory_trade", True, True)

    def run_neo4j_query(self, query, **params):
        result = self.neo4j_session.run(query, **params)
        return pd.DataFrame([record.values() for record in result], columns=result.keys())

    def select_year(self, year):
        self.selected_year = year
        print(f'Creating graph for Year {self.selected_year}')
        self.data = self.group_dataframe(self.sql_data)

    def set_graph_type(self, graph_type):
        if graph_type.lower() not in ['exporter', 'importer']:
            raise ValueError("Graph type must be either 'exporter' or 'importer'")
            
        self.graph_type = graph_type.lower()
    
    def set_relationship_type(self):
        if self.graph_type == 'exporter':
            self.relationship_type = 'EXPORTS_TO'
        else:
            self.relationship_type = 'IMPORTS_FROM'
            
        print(f'Creating the {self.graph_type.capitalize()} Graph')

    def wipe_out_graph_db(self):
        query = "MATCH (n) DETACH DELETE n"
        self.neo4j_session.run(query)
        print('Neo4j memory has been wiped clean!')

    def create_nodes(self):
        unique_countries = pd.concat([self.data['importer'], self.data['exporter']]).unique()
        for country in unique_countries:
            if pd.notna(country):
                query = "MERGE (:Country {name: $country_name})"
                self.neo4j_session.run(query, country_name=country)
                
    def my_neo4j_create_relationship_one_way(self, from_country, to_country, quantity, year):
        query = f"""
        MATCH (from:Country), 
              (to:Country)
        WHERE from.name = $from_country and to.name = $to_country
        CREATE (from)-[:{self.relationship_type} {{quantity: $quantity, year: $year, graph_type: '{self.graph_type}'}}]->(to)
        """
        self.neo4j_session.run(query, from_country=from_country, to_country=to_country, quantity=quantity, year=year)

    def create_relationships(self):
        if self.selected_year is None or self.graph_type is None:
            raise ValueError("Please select a year and graph type first")

        for _, row in self.data.iterrows():
            exporter = row['exporter']
            importer = row['importer']
            quantity = row['standardized_quantity']

            if self.graph_type == 'exporter':
                self.my_neo4j_create_relationship_one_way(exporter, importer, quantity, self.selected_year)
            else:
                self.my_neo4j_create_relationship_one_way(importer, exporter, quantity, self.selected_year)
                
    def print_graph_statistics(self):
        # Print the number of nodes
        node_query = "MATCH (n) RETURN n.name AS node_name, labels(n) AS labels ORDER BY n.name"
        node_df = self.run_neo4j_query(node_query)
        number_nodes = node_df.shape[0]
        print("-------------------------")
        print("Nodes:", number_nodes)

        # Print the number of relationships
        relationship_query = """
        MATCH (n1)-[r]->(n2)
        RETURN n1.name AS node_name_1, labels(n1) AS node_1_labels, 
               type(r) AS relationship_type, n2.name AS node_name_2, labels(n2) AS node_2_labels
        ORDER BY node_name_1, node_name_2
        """
        relationship_df = self.run_neo4j_query(relationship_query)
        number_relationships = relationship_df.shape[0]
        print("Relationships:", number_relationships)
        print("-------------------------")

        # Print the graph density
        if number_nodes > 1:
            density = (2 * number_relationships) / (number_nodes * (number_nodes - 1))
            print("Density:", f'{density:.2f}')
        else:
            print("Density: N/A (requires at least 2 nodes)")
        print("-------------------------")

    def project_graph(self, graph_name):
        self.neo4j_session.run(f"CALL gds.graph.drop('{graph_name}', false)")
        print(f"Dropped graph {graph_name}")

        query = f"""
        CALL gds.graph.project(
            '{graph_name}',
            'Country',
            '{self.relationship_type}',
            {{relationshipProperties: 'quantity'}})
        """
        self.neo4j_session.run(query)

    def run_page_rank(self, graph_name, max_iterations=20, damping_factor=0.05):
        query = f"""
        CALL gds.pageRank.stream(
            '{graph_name}',
            {{ maxIterations: $max_iterations, dampingFactor: $damping_factor }}
        )
        YIELD nodeId, score
        WITH gds.util.asNode(nodeId) AS node, score
        SET node.pageRank = score
        RETURN node.name AS name, score AS page_rank
        ORDER BY page_rank DESC
        """
        print('Calculating and storing PageRank...')
        return self.run_neo4j_query(query, max_iterations=max_iterations, damping_factor=damping_factor)

    def run_harmonic(self, graph_name):
        query = f"""
        CALL gds.alpha.closeness.harmonic.stream('{graph_name}', {{}})
        YIELD nodeId, centrality
        WITH gds.util.asNode(nodeId) AS node, centrality
        SET node.closeness = centrality
        RETURN node.name AS name, centrality AS harmonic_centrality
        ORDER BY harmonic_centrality DESC
        """
        print('Calculating and storing Harmonic centrality...')
        return self.run_neo4j_query(query)

    def run_betweenness(self, graph_name):
        query = f"""
        CALL {{
            WITH '{graph_name}' AS graphName
            CALL gds.betweenness.stream(graphName, {{relationshipWeightProperty: 'quantity'}})
            YIELD nodeId, score
            WITH gds.util.asNode(nodeId) AS node, score
            SET node.betweenness = score
            RETURN node.name AS name, score AS betweenness_centrality
            ORDER BY betweenness_centrality DESC
        }}
        RETURN name, betweenness_centrality
        """
        print('Calculating and storing Betweenness centrality...')
        return self.run_neo4j_query(query)

    def run_louvain(self, graph_name):
        query = f"""
        CALL gds.louvain.stream(
            '{graph_name}',
            {{includeIntermediateCommunities: true}})
        YIELD nodeId, communityId, intermediateCommunityIds
        RETURN gds.util.asNode(nodeId).name AS name, communityId AS community, intermediateCommunityIds AS intermediate_community
        ORDER BY community, name ASC
        """
        print("Louvain Modularity:")
        return self.run_neo4j_query(query)
    
    def wipe_out_mst_relationships(self):
        query = "MATCH ()-[r:MST]->() DELETE r"
        self.neo4j_session.run(query)

    def run_minimum_spanning_tree(self, graph_name, source_node_name):
        # Wipe out existing MST relationships
        self.wipe_out_mst_relationships()

        # Drop the graph if it already exists
        self.neo4j_session.run(f"CALL gds.graph.drop('{graph_name}', false)")

        # Project the graph with undirected relationships
        query = f"""
        CALL gds.graph.project(
            '{graph_name}',
            'Country',
            {{
                {self.relationship_type.upper()}: {{
                    properties: 'quantity',
                    orientation: 'UNDIRECTED'
                }}
            }}
        )
        """
        self.neo4j_session.run(query)

        # Compute MST creation metrics
        query = f"""
        MATCH (n:Country {{name: $source_node_name}})
        CALL gds.beta.spanningTree.write(
            '{graph_name}',
            {{
                sourceNode: id(n),
                relationshipWeightProperty: 'quantity',
                writeProperty: 'writeCost',
                writeRelationshipType: 'MST'
            }}
        )
        YIELD preProcessingMillis, computeMillis, writeMillis, effectiveNodeCount
        RETURN preProcessingMillis AS data_prep_time, computeMillis AS mst_compute_time, writeMillis AS write_to_graph_time, effectiveNodeCount AS mst_nodes_included
        """
        
        
        creation_metrics = self.run_neo4j_query(query, source_node_name=source_node_name)
        
        # Examine MST relationships
        query = """
        MATCH path = (n:Country {name: $source_node_name})-[:MST*]-()
        WITH relationships(path) AS rels
        UNWIND rels AS rel
        WITH DISTINCT rel AS rel
        RETURN startNode(rel).name AS source, endNode(rel).name AS destination, rel.writeCost AS cost
        """
        
        relationship_metrics = self.run_neo4j_query(query, source_node_name=source_node_name)
        
        self.wipe_out_mst_relationships()
        return creation_metrics, relationship_metrics
    
    def run_algos(self, graph_name, algo):
        if algo.lower() == 'pagerank':
            page_rank_results = self.run_page_rank(graph_name)
            display(page_rank_results.head(6).style.set_caption(f'{graph_name}'))
        elif algo.lower() == 'harmonic':
            harmonic_results = self.run_harmonic(graph_name)
            display(harmonic_results.head(6).style.set_caption(f'{graph_name}'))
        elif algo.lower() == 'between':
            betweenness_results = self.run_betweenness(graph_name)
            display(betweenness_results.head(6).style.set_caption(f'{graph_name}'))
        else:
            pass

#         louvain_results = self.run_louvain(graph_name)
#         display(louvain_results)
        
#         if not page_rank_results.empty:
#             db_analytics, mst_results = self.run_minimum_spanning_tree(graph_name, page_rank_results.iloc[0]['name'])
#             display(mst_results)
#         else:
#             print(f"No PageRank results available for {self.graph_type.capitalize()} in {self.selected_year}. Skipping MST.")

    def plot_graph(self, year, graph_type, algo=None):
        """
        Parameters:
            year (int)
            graph_type (str): 'exporter' or 'importer'
            algo (str): 'between', 'harmonic', 'pagerank'
        """
        
        self.check_and_load_sql_data()

        self.wipe_out_graph_db()

        self.select_year(year)
        self.set_graph_type(graph_type)
        self.set_relationship_type()
        
        self.create_nodes()
        self.create_relationships()
        self.print_graph_statistics()

        graph_name = f"G_{str(self.selected_year)[-2:]}_{self.graph_type.capitalize()}"
        self.project_graph(graph_name)
        
        if algo:
            self.run_algos(graph_name, algo)

            if algo == 'pagerank':
                algo_type = 'pageRank'
            elif algo == 'harmonic':
                algo_type = 'closeness'
            elif algo == 'between':
                algo_type = 'betweenness'

            print(f"""
            // Extract top 5 nodes based on selected algorithm
            MATCH (n)
            WHERE n.{algo_type} IS NOT NULL
            WITH n
            ORDER BY n.{algo_type} DESC
            LIMIT 5

            // Plot the top 5 nodes with relationships between them
            WITH collect(n) AS topNodes
            UNWIND topNodes AS tn1
            UNWIND topNodes AS tn2
            MATCH (tn1)-[r]->(tn2)
            RETURN tn1, r, tn2;
            """)
        else:
            print(f"""
            CYPHER QUERY:

            //The below query will return the top 5 countries in the network based on total quantity traded.

            MATCH (n)-[r:{self.relationship_type}]->(m)
            WHERE r.year = {self.selected_year}
            RETURN n, sum(r.quantity) AS total_quantity
            ORDER BY total_quantity DESC
            LIMIT 5
            """)

In [4]:
### Configs

# CSV path handed to IvoryTradeGraph as csv_file_path.
file_name = "ivory_data.csv"

# Keyword arguments for the Postgres connection.
postgres_params = dict(
    user='postgres',
    password='ucb',
    host='postgres',
    port='5432',
    database='postgres',
)

# Neo4j endpoint and credentials.
neo4j_uri = "neo4j://neo4j:7687"
neo4j_user = "neo4j"
neo4j_password = "ucb_mids_w205"

# Instantiating the class opens both database connections.
graph = IvoryTradeGraph(
    postgres_params,
    neo4j_uri,
    neo4j_user,
    neo4j_password,
    file_name,
)

In [3]:
# Build the 1988 graph once for every (graph type, algorithm) combination.
graph_types = ['exporter', 'importer']
algos = ['pagerank', 'between', 'harmonic']

for graph_type, algo in [(kind, measure) for kind in graph_types for measure in algos]:
    graph.plot_graph(1988, graph_type, algo)

The ivory trade table already exists in postgres.
Neo4j memory has been wiped clean!
Creating graph for Year 1988
Creating the Exporter Graph
-------------------------
Nodes: 135
Relationships: 690
-------------------------
Density: 0.08
-------------------------
Dropped graph G_88_Exporter
Calculating and storing PageRank...


Unnamed: 0,name,page_rank
0,GB,1.973632
1,US,1.704455
2,CA,1.258811
3,HK,1.238496
4,FR,1.230911
5,IT,1.22351



            // Extract top 5 nodes based on selected algorithm
            MATCH (n)
            WHERE n.pageRank IS NOT NULL
            WITH n
            ORDER BY n.pageRank DESC
            LIMIT 5

            // Plot the top 5 nodes with relationships between them
            WITH collect(n) AS topNodes
            UNWIND topNodes AS tn1
            UNWIND topNodes AS tn2
            MATCH (tn1)-[r]->(tn2)
            RETURN tn1, r, tn2;
            
The ivory trade table already exists in postgres.
Neo4j memory has been wiped clean!
Creating graph for Year 1988
Creating the Exporter Graph
-------------------------
Nodes: 135
Relationships: 690
-------------------------
Density: 0.08
-------------------------
Dropped graph G_88_Exporter
Calculating and storing Betweenness centrality...


Unnamed: 0,name,betweenness_centrality
0,GB,7624.5
1,ZA,4730.5
2,MW,3897.0
3,DE,2714.0
4,BW,2424.5
5,ZW,2066.0



            // Extract top 5 nodes based on selected algorithm
            MATCH (n)
            WHERE n.betweenness IS NOT NULL
            WITH n
            ORDER BY n.betweenness DESC
            LIMIT 5

            // Plot the top 5 nodes with relationships between them
            WITH collect(n) AS topNodes
            UNWIND topNodes AS tn1
            UNWIND topNodes AS tn2
            MATCH (tn1)-[r]->(tn2)
            RETURN tn1, r, tn2;
            
The ivory trade table already exists in postgres.
Neo4j memory has been wiped clean!
Creating graph for Year 1988
Creating the Exporter Graph
-------------------------
Nodes: 135
Relationships: 690
-------------------------
Density: 0.08
-------------------------
Dropped graph G_88_Exporter
Calculating and storing Harmonic centrality...


Unnamed: 0,name,harmonic_centrality
0,US,0.533582
1,GB,0.528607
2,CA,0.452736
3,HK,0.445274
4,FR,0.44403
5,JP,0.442786



            // Extract top 5 nodes based on selected algorithm
            MATCH (n)
            WHERE n.closeness IS NOT NULL
            WITH n
            ORDER BY n.closeness DESC
            LIMIT 5

            // Plot the top 5 nodes with relationships between them
            WITH collect(n) AS topNodes
            UNWIND topNodes AS tn1
            UNWIND topNodes AS tn2
            MATCH (tn1)-[r]->(tn2)
            RETURN tn1, r, tn2;
            
The ivory trade table already exists in postgres.
Neo4j memory has been wiped clean!
Creating graph for Year 1988
Creating the Importer Graph
-------------------------
Nodes: 135
Relationships: 690
-------------------------
Density: 0.08
-------------------------
Dropped graph G_88_Importer
Calculating and storing PageRank...


Unnamed: 0,name,page_rank
0,ZW,2.067122
1,ZA,1.534254
2,MW,1.459534
3,DE,1.369669
4,CM,1.328331
5,CN,1.195371



            // Extract top 5 nodes based on selected algorithm
            MATCH (n)
            WHERE n.pageRank IS NOT NULL
            WITH n
            ORDER BY n.pageRank DESC
            LIMIT 5

            // Plot the top 5 nodes with relationships between them
            WITH collect(n) AS topNodes
            UNWIND topNodes AS tn1
            UNWIND topNodes AS tn2
            MATCH (tn1)-[r]->(tn2)
            RETURN tn1, r, tn2;
            
The ivory trade table already exists in postgres.
Neo4j memory has been wiped clean!
Creating graph for Year 1988
Creating the Importer Graph
-------------------------
Nodes: 135
Relationships: 690
-------------------------
Density: 0.08
-------------------------
Dropped graph G_88_Importer
Calculating and storing Betweenness centrality...


Unnamed: 0,name,betweenness_centrality
0,GB,7624.5
1,ZA,4730.5
2,MW,3897.0
3,DE,2714.0
4,BW,2424.5
5,ZW,2066.0



            // Extract top 5 nodes based on selected algorithm
            MATCH (n)
            WHERE n.betweenness IS NOT NULL
            WITH n
            ORDER BY n.betweenness DESC
            LIMIT 5

            // Plot the top 5 nodes with relationships between them
            WITH collect(n) AS topNodes
            UNWIND topNodes AS tn1
            UNWIND topNodes AS tn2
            MATCH (tn1)-[r]->(tn2)
            RETURN tn1, r, tn2;
            
The ivory trade table already exists in postgres.
Neo4j memory has been wiped clean!
Creating graph for Year 1988
Creating the Importer Graph
-------------------------
Nodes: 135
Relationships: 690
-------------------------
Density: 0.08
-------------------------
Dropped graph G_88_Importer
Calculating and storing Harmonic centrality...


Unnamed: 0,name,harmonic_centrality
0,ZW,0.633085
1,MW,0.562189
2,ZA,0.548507
3,CD,0.472015
4,DE,0.465796
5,BW,0.463308



            // Extract top 5 nodes based on selected algorithm
            MATCH (n)
            WHERE n.closeness IS NOT NULL
            WITH n
            ORDER BY n.closeness DESC
            LIMIT 5

            // Plot the top 5 nodes with relationships between them
            WITH collect(n) AS topNodes
            UNWIND topNodes AS tn1
            UNWIND topNodes AS tn2
            MATCH (tn1)-[r]->(tn2)
            RETURN tn1, r, tn2;
            
