In [5]:
import pandas as pd
from neo4j import GraphDatabase
from tqdm import tqdm


In [None]:
%pip install tqdm

Collecting tqdm
  Using cached tqdm-4.67.1-py3-none-any.whl.metadata (57 kB)
Using cached tqdm-4.67.1-py3-none-any.whl (78 kB)
Installing collected packages: tqdm
Successfully installed tqdm-4.67.1


In [6]:
DATA_PATH = "../data/raw/"

TRAIN_FILE = DATA_PATH + "train.txt"
VALID_FILE = DATA_PATH + "valid.txt"
TEST_FILE  = DATA_PATH + "test.txt"


In [7]:
def load_triples(path):
    df = pd.read_csv(path, sep="\t", header=None)
    df.columns = ["head", "relation", "tail"]
    return df

train_df = load_triples(TRAIN_FILE)
valid_df = load_triples(VALID_FILE)
test_df  = load_triples(TEST_FILE)

train_df.head()


Unnamed: 0,head,relation,tail
0,/m/027rn,/location/country/form_of_government,/m/06cx9
1,/m/017dcd,/tv/tv_program/regular_cast./tv/regular_tv_app...,/m/06v8s0
2,/m/07s9rl0,/media_common/netflix_genre/titles,/m/0170z3
3,/m/01sl1q,/award/award_winner/awards_won./award/award_ho...,/m/044mz_
4,/m/0cnk2q,/soccer/football_team/current_roster./sports/s...,/m/02nzb8


In [8]:
print("Train triples:", len(train_df))
print("Valid triples:", len(valid_df))
print("Test triples:", len(test_df))

print("Unique entities:", pd.concat([
    train_df["head"], train_df["tail"]
]).nunique())

print("Unique relations:", train_df["relation"].nunique())


Train triples: 272115
Valid triples: 17535
Test triples: 20466
Unique entities: 14505
Unique relations: 237


In [28]:
import os
from dotenv import load_dotenv

load_dotenv()

NEO4J_URI = os.getenv("NEO4J_URI")
USERNAME = os.getenv("USERNAME")
NEO4J_PASSWORD = os.getenv("NEO4J_PASSWORD")

driver = GraphDatabase.driver(
    "neo4j+s://cb4b49e2.databases.neo4j.io",
    auth=("neo4j", "nq5V7b0rJ4a9MvY3_ldBgHQQ4D1H3kdoMeWOXk1VaIM")
)


python-dotenv could not parse statement starting at line 2


In [14]:
%pip install langchain langchain-community langchain-neo4j

Collecting langchain-community
  Downloading langchain_community-0.4.1-py3-none-any.whl.metadata (3.0 kB)
Collecting aiohttp<4.0.0,>=3.8.3 (from langchain-community)
  Downloading aiohttp-3.13.2-cp313-cp313-win_amd64.whl.metadata (8.4 kB)
Collecting dataclasses-json<0.7.0,>=0.6.7 (from langchain-community)
  Using cached dataclasses_json-0.6.7-py3-none-any.whl.metadata (25 kB)
Collecting httpx-sse<1.0.0,>=0.4.0 (from langchain-community)
  Downloading httpx_sse-0.4.3-py3-none-any.whl.metadata (9.7 kB)
Collecting aiohappyeyeballs>=2.5.0 (from aiohttp<4.0.0,>=3.8.3->langchain-community)
  Using cached aiohappyeyeballs-2.6.1-py3-none-any.whl.metadata (5.9 kB)
Collecting aiosignal>=1.4.0 (from aiohttp<4.0.0,>=3.8.3->langchain-community)
  Using cached aiosignal-1.4.0-py3-none-any.whl.metadata (3.7 kB)
Collecting attrs>=17.3.0 (from aiohttp<4.0.0,>=3.8.3->langchain-community)
  Downloading attrs-25.4.0-py3-none-any.whl.metadata (10 kB)
Collecting frozenlist>=1.1.1 (from aiohttp<4.0.0,>=3.8.

In [25]:
from langchain_neo4j import Neo4jGraph

graph = Neo4jGraph(refresh_schema=False,username="neo4j",
                    password="nq5V7b0rJ4a9MvY3_ldBgHQQ4D1H3kdoMeWOXk1VaIM",
                   url="neo4j+s://cb4b49e2.databases.neo4j.io")

In [29]:
def create_constraints():
    query = """
    CREATE CONSTRAINT IF NOT EXISTS
    FOR (e:Entity)
    REQUIRE e.name IS UNIQUE
    """
    with driver.session() as session:
        session.run(query)

create_constraints()


In [30]:
def load_to_neo4j(df, batch_size=1000):
    query = """
    UNWIND $rows AS row
    MERGE (h:Entity {name: row.head})
    MERGE (t:Entity {name: row.tail})
    MERGE (h)-[:RELATION {type: row.relation}]->(t)
    """

    with driver.session() as session:
        for i in tqdm(range(0, len(df), batch_size)):
            batch = df.iloc[i:i+batch_size].to_dict("records")
            session.run(query, rows=batch)


In [31]:
load_to_neo4j(train_df)


100%|██████████| 273/273 [00:54<00:00,  5.03it/s]


In [32]:
with driver.session() as session:
    result = session.run("MATCH (n:Entity) RETURN count(n)")
    print(result.single())


<Record count(n)=14529>


In [33]:
query = """
MATCH (h:Entity)-[r]->(t:Entity)
RETURN h.name, r.type, t.name
LIMIT 10
"""

with driver.session() as session:
    for record in session.run(query):
        print(record)


<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>
<Record h.name=None r.type=None t.name=None>


In [38]:
query = """
MATCH (p:Entity {name: "person"})-[r]->(x)
RETURN p.name, r.type, x.name
"""

with driver.session() as session:
    for r in session.run(query):
        print(r)
