# Shopping Cart Analysis with Neo4j

In [1]:
import pandas as pd 
import numpy as np
import matplotlib.pyplot as plt

## Data preparation and ingestion

In [2]:
df = pd.read_csv('./data/dataset_group.csv', delimiter=',', names=['date', 'transaction', 'product'])
df.head()

Unnamed: 0,date,transaction,product
0,2000-01-01,1,yogurt
1,2000-01-01,1,pork
2,2000-01-01,1,sandwich bags
3,2000-01-01,1,lunch meat
4,2000-01-01,1,all- purpose


In [3]:
transactions = df.groupby(['transaction', 'date'])['product'].apply(list).reset_index(name='products')
transactions.head()

Unnamed: 0,transaction,date,products
0,1,2000-01-01,"[yogurt, pork, sandwich bags, lunch meat, all-..."
1,2,2000-01-01,"[toilet paper, shampoo, hand soap, waffles, ve..."
2,3,2000-01-02,"[soda, pork, soap, ice cream, toilet paper, di..."
3,4,2000-01-02,"[cereals, juice, lunch meat, soda, toilet pape..."
4,5,2000-01-02,"[sandwich loaves, pasta, tortillas, mixes, han..."


In [4]:
from neo4j import GraphDatabase

# Inspiration taken from: https://towardsdatascience.com/create-a-graph-database-in-neo4j-using-python-4172d40f89c4
class Neo4jConnection:

    def __init__(self, uri, user, password):
        self.driver = GraphDatabase.driver(uri, auth=(user, password))

    def close(self):
        self.driver.close()

    def print_greeting(self, message):
        with self.driver.session() as session:
            greeting = session.execute_write(self._create_and_return_greeting, message)
            print(greeting)
    
    def query(self, query, parameters=None, db=None):
        assert self.driver is not None, "Driver not initialized!"
        session = None
        response = None
        try: 
            session = self.driver.session(database=db) if db is not None else self.driver.session() 
            response = list(session.run(query, parameters))
        except Exception as e:
            print("Query failed:", e)
        finally: 
            if session is not None:
                session.close()
        return response

In [5]:
conn = Neo4jConnection("bolt://localhost:7687", "neo4j", "!Random_Password1234")

In [6]:
# Create entities and apply unique constraints to avoid duplicates
conn.query('CREATE CONSTRAINT transactions IF NOT EXISTS ON (t:Transactions) ASSERT t.id IS UNIQUE')
conn.query('CREATE CONSTRAINT products IF NOT EXISTS ON (p:Products) ASSERT p.name IS UNIQUE')


[]

In [7]:
import time
def insert_data(query, rows, batch_size=10000):
    total = 0
    batch = 0
    start = time.time()
    result = None
        
    while batch * batch_size < len(rows):

        res = conn.query(query, 
                            parameters = {'rows': rows[batch*batch_size:(batch+1)*batch_size]})
        total += res[0]['total'] if len(res) > 0 else 0
        batch += 1
        result = {"total":total, 
                    "batches":batch, 
                    "time":time.time()-start}
        print(result)
            
    return result

def add_products(rows):
    query = '''
            UNWIND $rows AS row
            MERGE (:Products {name: row})
            RETURN count(*) as total
            '''
    return insert_data(query, rows)
    
def add_transactions(rows, batch_size=5000):
    query = '''
        UNWIND $rows as row
        MERGE (t:Transactions {id:row.transaction}) ON CREATE SET t.date = row.date
        
        WITH row, t
        UNWIND row.products AS product
        MATCH (p:Products {name: product})
        // connect products to transactions
        MERGE (p)-[:IN]->(t)
        // connect transactions to products
        MERGE (t)-[:CONTAINS]->(p)
        '''
        
    return insert_data(query, rows, batch_size)

In [8]:
products = df['product'].unique()
add_products(list(products))


{'total': 38, 'batches': 1, 'time': 0.4343256950378418}


{'total': 38, 'batches': 1, 'time': 0.4343256950378418}

In [9]:
transaction_transformed = transactions.to_dict(orient='records')
add_transactions(transaction_transformed)

{'total': 0, 'batches': 1, 'time': 7.65553092956543}


{'total': 0, 'batches': 1, 'time': 7.65553092956543}

In [10]:
conn.close()