<a href="https://colab.research.google.com/github/Vinothsuku/graphai-insights/blob/main/telecom_data_graphRAG_subgraph.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install neo4j faiss-cpu transformers torch numpy

Collecting neo4j
  Downloading neo4j-5.28.1-py3-none-any.whl.metadata (5.9 kB)
Collecting faiss-cpu
  Downloading faiss_cpu-1.11.0-cp311-cp311-manylinux_2_28_x86_64.whl.metadata (4.8 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)


In [2]:
from neo4j import GraphDatabase
import numpy as np
import random
import spacy
from spacy.training.example import Example
from spacy.util import minibatch, compounding
from sentence_transformers import SentenceTransformer
import faiss
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import re
from transformers import pipeline
import json

In [3]:
#neo4j configs
NEO4J_URI="neo4j+s://652b8c5a.databases.neo4j.io"
NEO4J_USERNAME="neo4j"
NEO4J_PASSWORD="6z2zbwT5fWNfRehB_nB64zFiM5oPMONsXBz8qNLb2HI"

In [4]:
driver = GraphDatabase.driver(NEO4J_URI, auth=(NEO4J_USERNAME, NEO4J_PASSWORD))

In [5]:
#get nodes and relationships from neo4j
def fetch_graph_data():
    with driver.session() as session:
        result = session.run("MATCH (p:person) RETURN p.oid AS oid, p.name AS name")
        nodes = [{"oid": record["oid"], "name": record["name"]} for record in result]

    with driver.session() as session:
        result = session.run("MATCH (p:person)-[r:USES_DEVICE]->(d:device) RETURN p.oid AS person_oid, d.oid AS device_oid, r.starttime AS starttime, r.endtime AS endtime")
        relationships = [{"person_oid": record["person_oid"], "device_oid": record["device_oid"], "starttime": record["starttime"], "endtime": record["endtime"]} for record in result]

    return nodes, relationships

nodes, relationships = fetch_graph_data()
print(f"Nodes: {len(nodes)}, Relationships: {len(relationships)}")


Nodes: 4, Relationships: 10000


In [6]:
def run_query(query):
    with driver.session() as session:
        result = session.run(query)
        return [record.data() for record in result]

#node labels
node_labels = run_query("CALL db.labels()")

#relationship types
rel_types = run_query("CALL db.relationshipTypes()")

#node property keys by label
node_props = run_query("""
MATCH (n)
WITH labels(n) AS label, keys(n) AS props
UNWIND label AS l
RETURN DISTINCT l AS node_label, collect(DISTINCT props) AS property_keys
LIMIT 100
""")

#relationship property keys
rel_props = run_query("""
MATCH ()-[r]->()
RETURN DISTINCT type(r) AS rel_type, collect(DISTINCT keys(r)) AS property_keys
LIMIT 100
""")

print("Node Labels:")
for item in node_labels:
    print(item)

print("\nRelationship Types:")
for item in rel_types:
    print(item)

print("\nNode Properties:")
for item in node_props:
    print(item)

print("\nRelationship Properties:")
for item in rel_props:
    print(item)


Node Labels:
{'label': 'person'}
{'label': 'device'}
{'label': 'phonenumber'}
{'label': 'celltower'}

Relationship Types:
{'relationshipType': 'USES_DEVICE'}
{'relationshipType': 'CALLS'}
{'relationshipType': 'CONNECTED_TO'}

Node Properties:
{'node_label': 'person', 'property_keys': [['oid', 'activities']]}
{'node_label': 'device', 'property_keys': [['oid']]}
{'node_label': 'phonenumber', 'property_keys': [['oid']]}
{'node_label': 'celltower', 'property_keys': [['oid']]}

Relationship Properties:
{'rel_type': 'USES_DEVICE', 'property_keys': [['endtime', 'starttime']]}
{'rel_type': 'CALLS', 'property_keys': [['endtime', 'starttime', 'duration']]}
{'rel_type': 'CONNECTED_TO', 'property_keys': [['endtime', 'starttime']]}


In [7]:
#get intents (multiple if relevant) and entities
classifier = pipeline("zero-shot-classification", model="facebook/bart-large-mnli")
ner_model = spacy.load("en_core_web_sm")

intent_labels = [
    "get_devices_by_person",
    "get_calls_by_person",
    "get_activities_by_person",
    "get_phonenumbers_by_person",
    "get_activities_by_celltower",
    "get_calls_between_people"
]

intent_keywords = {
    "get_devices_by_person": ["device", "devices", "used"],
    "get_calls_by_person": ["call", "calls", "called"],
    "get_activities_by_person": ["activity", "activities", "did"],
    "get_phonenumbers_by_person": ["phone number", "phone numbers", "number"],
    "get_activities_by_celltower": ["celltower", "tower", "connected"],
    "get_calls_between_people": ["between", "call between", "calls between"]
}

intent_to_entity_label = {
    "get_devices_by_person": ["person"],
    "get_calls_by_person": ["person"],
    "get_activities_by_person": ["person"],
    "get_phonenumbers_by_person": ["person"],
    "get_activities_by_celltower": ["celltower"],
    "get_calls_between_people": ["person"]
}

def classify_intents_and_entities(query):
    result = classifier(query, candidate_labels=intent_labels, multi_label=True)
    predicted_intents = [
        (label, score) for label, score in zip(result["labels"], result["scores"])
        if score >= 0.6
    ]

    #entity extraction
    doc = ner_model(query)
    raw_entities = [(ent.label_, ent.text) for ent in doc.ents]
    numbers = re.findall(r"\b\d+\b", query)
    for num in numbers:
        raw_entities.append(("CARDINAL", num))

    query_lower = query.lower()
    intent_entity_map = {}

    for intent, score in predicted_intents:
        if not any(kw in query_lower for kw in intent_keywords[intent]):
            continue

        expected_labels = intent_to_entity_label.get(intent, [])
        found = None
        for label, value in raw_entities:
            if any(e in label.lower() for e in expected_labels):
                found = value
                break
        if not found:
            match = re.search(r"(person|device|tower|number)\s+(\d+)", query_lower)
            if match and any(t in intent for t in match.group(1)):
                found = match.group(2)

        intent_entity_map[intent] = found

    return intent_entity_map


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/1.15k [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/1.63G [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/26.0 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

Device set to use cuda:0


In [8]:
query = "devices used by person 94 and phone numbers called"
intent = classify_intents_and_entities(query)
print(f"Identified intent: {intent}")

Identified intent: {'get_devices_by_person': '94', 'get_calls_by_person': '94', 'get_phonenumbers_by_person': '94'}


In [9]:
#generating data for Spacy - custom NER training
persons = [f"person {i}" for i in range(1, 50)]
devices = [f"device {i}" for i in range(1, 20)]
phonenumbers = [f"phonenumber {i}" for i in range(100, 150)]
celltowers = [f"celltower {i}" for i in range(20, 40)]
dates = [f"2004-09-{str(i).zfill(2)}" for i in range(1, 30)]

templates = [
    ("Which devices did {person} use?", [("PERSON", "person")]),
    ("What is the phone number used by {person}?", [("PERSON", "person")]),
    ("Show me calls between {phone1} and {phone2}", [("PHONENUMBER", "phone1"), ("PHONENUMBER", "phone2")]),
    ("List activities for {person}", [("PERSON", "person")]),
    ("Which celltower was used on {date}?", [("DATE", "date")]),
    ("Who used {device}?", [("DEVICE", "device")]),
    ("Which phone numbers were connected to {celltower}?", [("CELLTOWER", "celltower")]),
    ("Did {person} connect to {celltower} on {date}?", [("PERSON", "person"), ("CELLTOWER", "celltower"), ("DATE", "date")]),
    ("Between {date1} and {date2}, who used {device}?", [("DATE", "date1"), ("DATE", "date2"), ("DEVICE", "device")]),
]

TRAIN_DATA = []

for _ in range(300):
    template, entity_labels = random.choice(templates)

    phone1 = random.choice(phonenumbers)
    phone2 = random.choice([x for x in phonenumbers if x != phone1])

    entity_map = {
        "person": random.choice(persons),
        "device": random.choice(devices),
        "phone1": phone1,
        "phone2": phone2,
        "celltower": random.choice(celltowers),
        "date": random.choice(dates),
        "date1": random.choice(dates),
        "date2": random.choice(dates),
    }

    sentence = template.format(**entity_map)

    seen_spans = set()
    entities = []
    for label, key in entity_labels:
        val = entity_map[key]
        for match in re.finditer(re.escape(val), sentence):
            span = (match.start(), match.end(), label)
            if span not in seen_spans:
                entities.append(span)
                seen_spans.add(span)
            break

    if entities:
        TRAIN_DATA.append((sentence, {"entities": entities}))

In [10]:
TRAIN_DATA[20:30]

[('Which phone numbers were connected to celltower 39?',
  {'entities': [(38, 50, 'CELLTOWER')]}),
 ('What is the phone number used by person 16?',
  {'entities': [(33, 42, 'PERSON')]}),
 ('Who used device 18?', {'entities': [(9, 18, 'DEVICE')]}),
 ('Show me calls between phonenumber 103 and phonenumber 117',
  {'entities': [(22, 37, 'PHONENUMBER'), (42, 57, 'PHONENUMBER')]}),
 ('Between 2004-09-12 and 2004-09-06, who used device 4?',
  {'entities': [(8, 18, 'DATE'), (23, 33, 'DATE'), (44, 52, 'DEVICE')]}),
 ('List activities for person 22', {'entities': [(20, 29, 'PERSON')]}),
 ('Did person 27 connect to celltower 25 on 2004-09-04?',
  {'entities': [(4, 13, 'PERSON'), (25, 37, 'CELLTOWER'), (41, 51, 'DATE')]}),
 ('Which celltower was used on 2004-09-28?', {'entities': [(28, 38, 'DATE')]}),
 ('Between 2004-09-21 and 2004-09-06, who used device 19?',
  {'entities': [(8, 18, 'DATE'), (23, 33, 'DATE'), (44, 53, 'DEVICE')]}),
 ('Who used device 4?', {'entities': [(9, 17, 'DEVICE')]})]

In [11]:
#training custom NER model
nlp = spacy.blank("en")

if "ner" not in nlp.pipe_names:
    ner = nlp.add_pipe("ner")
else:
    ner = nlp.get_pipe("ner")


for _, annotations in TRAIN_DATA:
    for ent in annotations.get("entities"):
        ner.add_label(ent[2])

other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"]
with nlp.disable_pipes(*other_pipes):
    optimizer = nlp.begin_training()
    for itn in range(12):
        random.shuffle(TRAIN_DATA)
        losses = {}
        batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
        for batch in batches:
            for text, annotations in batch:
                doc = nlp.make_doc(text)
                example = Example.from_dict(doc, annotations)
                nlp.update([example], drop=0.3, losses=losses)
        print(f"Iteration {itn+1} - Losses: {losses}")

output_dir = "custom_ner_model"
nlp.to_disk(output_dir)
print(f"Model saved to {output_dir}")


Iteration 1 - Losses: {'ner': np.float32(438.3538)}
Iteration 2 - Losses: {'ner': np.float32(0.21570101)}
Iteration 3 - Losses: {'ner': np.float32(0.00039316108)}
Iteration 4 - Losses: {'ner': np.float32(0.0012614067)}
Iteration 5 - Losses: {'ner': np.float32(0.0059129535)}
Iteration 6 - Losses: {'ner': np.float32(3.0511909)}
Iteration 7 - Losses: {'ner': np.float32(5.945227)}
Iteration 8 - Losses: {'ner': np.float32(11.053511)}
Iteration 9 - Losses: {'ner': np.float32(6.0718756)}
Iteration 10 - Losses: {'ner': np.float32(0.010025921)}
Iteration 11 - Losses: {'ner': np.float32(2.9227538)}
Iteration 12 - Losses: {'ner': np.float32(1.4020135e-05)}
Model saved to custom_ner_model


In [12]:
#checking entity extraction from custom ner model
nlp = spacy.load("custom_ner_model")
text = "did person 94 connected to tower 1 on 2024-09-04?"
doc = nlp(text)
print("Entities:", [(ent.text, ent.label_) for ent in doc.ents])


Entities: [('person 94', 'PERSON'), ('tower 1', 'CELLTOWER'), ('2024-09-04', 'DATE')]


In [13]:
text = "is phonenumber 2233 is from device nokia connected to celltower 26"
doc = nlp(text)
print("Entities:", [(ent.text, ent.label_) for ent in doc.ents])

Entities: [('phonenumber 2233', 'PHONENUMBER'), ('celltower 26', 'CELLTOWER')]


In [14]:
model = SentenceTransformer("all-MiniLM-L6-v2")

all_docs = []
all_metadata = []

def chunk_text(text, max_words=150):
    words = text.split()
    return [" ".join(words[i:i + max_words]) for i in range(0, len(words), max_words)]

with driver.session() as session:

    ###Person level Subgraphs
    person_query = "MATCH (p:person) RETURN p.oid AS oid"
    for row in session.run(person_query):
        oid = row["oid"]
        parts = [f"Person {oid}"]

        #Including activities
        activity_q = """
            MATCH (p:person {oid: $oid}) RETURN p.activities AS activities
        """
        activity_row = session.run(activity_q, oid=oid).single()
        if activity_row and activity_row["activities"]:
            activities = activity_row["activities"]
            if isinstance(activities, list):
                parts.append("has activity logs: " + "; ".join(str(a) for a in activities))
            else:
                parts.append(f"has activity logs: {activities}")

        #device usage
        dev_q = """
            MATCH (p:person {oid: $oid})-[r:USES_DEVICE]->(d:device)
            RETURN d.oid AS device, r.starttime AS starttime, r.endtime AS endtime
        """
        for dev in session.run(dev_q, oid=oid):
            parts.append(f"used device {dev['device']} from {dev['starttime']} to {dev['endtime']}.")

        #Multi hop calls
        call_q = """
            MATCH (p:person {oid: $oid})-[:USES_DEVICE]->(d:device)
            WITH d.oid AS phone_oid
            MATCH (ph:phonenumber {oid: phone_oid})-[:CALLS]->(dst:phonenumber)
            RETURN ph.oid AS src, dst.oid AS dst
        """
        for call in session.run(call_q, oid=oid):
            parts.append(f"via phone {call['src']} called {call['dst']}.")

        #Multi hop calls
        mob_q = """
            MATCH (p:person {oid: $oid})-[:USES_DEVICE]->(d:device)
            WITH d.oid AS phone_oid
            MATCH (ph:phonenumber {oid: phone_oid})-[r:CONNECTED_TO]->(c:celltower)
            RETURN c.oid AS celltower, r.starttime AS starttime, r.endtime AS endtime
        """
        for mob in session.run(mob_q, oid=oid):
            parts.append(f"connected to celltower {mob['celltower']} from {mob['starttime']} to {mob['endtime']}.")

        full_text = " ".join(parts)
        chunks = chunk_text(full_text)
        all_docs.extend(chunks)
        all_metadata.extend([{"label": "person", "oid": oid}] * len(chunks))

    ###Device level Subgraphs
    device_query = "MATCH (d:device) RETURN d.oid AS oid"
    for row in session.run(device_query):
        oid = row["oid"]
        parts = [f"Device {oid}"]

        usage_q = """
            MATCH (p:person)-[r:USES_DEVICE]->(d:device {oid: $oid})
            RETURN p.oid AS person, r.starttime AS starttime, r.endtime AS endtime
        """
        for usage in session.run(usage_q, oid=oid):
            parts.append(f"was used by person {usage['person']} from {usage['starttime']} to {usage['endtime']}.")

        call_q = """
            MATCH (ph:phonenumber {oid: $oid})-[:CALLS]->(dst:phonenumber)
            RETURN dst.oid AS dst
        """
        for call in session.run(call_q, oid=oid):
            parts.append(f"enabled call to {call['dst']}.")

        tower_q = """
            MATCH (ph:phonenumber {oid: $oid})-[r:CONNECTED_TO]->(c:celltower)
            RETURN c.oid AS celltower, r.starttime AS starttime, r.endtime AS endtime
        """
        for conn in session.run(tower_q, oid=oid):
            parts.append(f"connected to celltower {conn['celltower']} from {conn['starttime']} to {conn['endtime']}.")

        full_text = " ".join(parts)
        chunks = chunk_text(full_text)
        all_docs.extend(chunks)
        all_metadata.extend([{"label": "device", "oid": oid}] * len(chunks))

    ###Phone Number level Subgraphs
    phone_query = "MATCH (ph:phonenumber) RETURN ph.oid AS oid"
    for row in session.run(phone_query):
        oid = row["oid"]
        parts = [f"Phone number {oid}"]

        call_q = """
            MATCH (ph:phonenumber {oid: $oid})-[r:CALLS]->(dst:phonenumber)
            RETURN dst.oid AS dst, r.starttime AS starttime, r.endtime AS endtime, r.duration AS duration
        """
        for call in session.run(call_q, oid=oid):
            parts.append(f"called {call['dst']} from {call['starttime']} to {call['endtime']} lasting {call['duration']} seconds.")

        conn_q = """
            MATCH (ph:phonenumber {oid: $oid})-[r:CONNECTED_TO]->(c:celltower)
            RETURN c.oid AS celltower, r.starttime AS starttime, r.endtime AS endtime
        """
        for conn in session.run(conn_q, oid=oid):
            parts.append(f"connected to celltower {conn['celltower']} from {conn['starttime']} to {conn['endtime']}.")

        full_text = " ".join(parts)
        chunks = chunk_text(full_text)
        all_docs.extend(chunks)
        all_metadata.extend([{"label": "phonenumber", "oid": oid}] * len(chunks))

    ###Celltower level Subgraphs
    tower_query = "MATCH (c:celltower) RETURN c.oid AS oid"
    for row in session.run(tower_query):
        oid = row["oid"]
        parts = [f"Celltower {oid}"]

        conn_q = """
            MATCH (ph:phonenumber)-[r:CONNECTED_TO]->(c:celltower {oid: $oid})
            RETURN ph.oid AS phonenumber, r.starttime AS starttime, r.endtime AS endtime
        """
        for conn in session.run(conn_q, oid=oid):
            parts.append(f"was connected by phone {conn['phonenumber']} from {conn['starttime']} to {conn['endtime']}.")

        full_text = " ".join(parts)
        chunks = chunk_text(full_text)
        all_docs.extend(chunks)
        all_metadata.extend([{"label": "celltower", "oid": oid}] * len(chunks))

#Embedding documents
embeddings = model.encode(all_docs, convert_to_numpy=True)

#create and save FAISS index
dimension = embeddings.shape[1]
faiss_index = faiss.IndexFlatL2(dimension)
faiss_index.add(embeddings)
faiss.write_index(faiss_index, "graph_subgraph_entities.index")

#save chunks and metadata
np.save("graph_subgraph_entities_texts.npy", np.array(all_docs))
np.save("graph_embeddings_all_entities.npy", embeddings)
with open("graph_metadata_all_entities.json", "w") as f:
    json.dump(all_metadata, f)

print(f"FAISS index created. Total chunks: {len(all_docs)}")


modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

README.md:   0%|          | 0.00/10.5k [00:00<?, ?B/s]

sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]

Xet Storage is enabled for this repo, but the 'hf_xet' package is not installed. Falling back to regular HTTP download. For better performance, install the package with: `pip install huggingface_hub[hf_xet]` or `pip install hf_xet`


model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

FAISS index created. Total chunks: 2728


In [15]:
#LLM for cypher generation
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
model = AutoModelForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")

#Embedding model
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")

#Load FAISS index
faiss_index = faiss.read_index("graph_subgraph_entities.index")
with open("graph_subgraph_entities_texts.npy", "rb") as f:
    all_docs = np.load(f, allow_pickle=True).tolist()
with open("graph_metadata_all_entities.json", "r") as f:
    all_metadata = json.load(f)

assert len(all_docs) == faiss_index.ntotal == len(all_metadata)


tokenizer_config.json:   0%|          | 0.00/1.29k [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/551 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/608 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.20G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/124 [00:00<?, ?B/s]

In [16]:
#query embedding
def get_query_embedding(query):
    return embedding_model.encode([query]).astype("float32")

In [17]:
def search_faiss_index(query_embedding, top_k=20, entity_filter=None, faiss_k=50):
    distances, indices = faiss_index.search(query_embedding, faiss_k)
    retrieved_docs = [all_docs[i] for i in indices[0]]
    retrieved_meta = [all_metadata[i] for i in indices[0]]

    candidates = list(zip(retrieved_docs, retrieved_meta))

    if entity_filter:
        extended_filters = set()
        if isinstance(entity_filter, dict):
            for label, value in entity_filter.items():
                extended_filters.add(str(value).lower())
                extended_filters.add(f"{label.lower()} {value}".lower())
        elif isinstance(entity_filter, list):
            extended_filters = set(str(v).lower() for v in entity_filter)

        def relevance_score(doc):
            score = 0
            doc_lower = doc.lower()
            for f in extended_filters:
                if f in doc_lower:
                    score += 5 if ' ' in f else 2
            return score

        #keeping only non 0 relevance score
        filtered = [(doc, meta) for doc, meta in candidates if relevance_score(doc) > 0]
        filtered_sorted = sorted(filtered, key=lambda x: relevance_score(x[0]), reverse=True)
        return filtered_sorted[:top_k] if filtered_sorted else candidates[:top_k]

    return candidates[:top_k]


In [18]:
def extract_entities(text):
    doc = nlp(text)
    entity_dict = {}
    for ent in doc.ents:
        if ent.label_ not in entity_dict:
            entity_dict[ent.label_] = ent.text.strip()
        else:
            if isinstance(entity_dict[ent.label_], list):
                entity_dict[ent.label_].append(ent.text.strip())
            else:
                entity_dict[ent.label_] = [entity_dict[ent.label_], ent.text.strip()]
    return entity_dict


In [34]:
def generate_cypher_query(query, intent, entities, docs):
    examples = """
Example 1:
Query: Which devices did person 123 use?
Intent: get_devices_by_person
Entities: PERSON=123
Graph context:
- person 123 used device 10 from 2004-09-01 to 2004-09-10
- person 123 used device 12 from 2004-09-11 to 2004-09-12
Cypher:
MATCH (p:person {oid: '123'})-[r:USES_DEVICE]->(d:device)
RETURN d.oid AS device_id, r.starttime, r.endtime

Example 2:
Query: What phone numbers did person 456 call?
Intent: get_calls_by_person
Entities: PERSON=456
Graph context:
- person 456 called phonenumber 789 from 2004-09-01 to 2004-09-10
Cypher:
MATCH (p:person {oid: '456'})-[r:CALLS]->(ph:phonenumber)
RETURN ph.oid AS phonenumber, r.starttime, r.endtime, r.duration

Example 3:
Query: What activities are logged for person 789?
Intent: get_activities_by_person
Entities: PERSON=789
Graph context:
- person 789 has recorded activities between 2004-09-05 and 2004-09-08
Cypher:
MATCH (p:person {oid: '789'})
RETURN p.activities

Example 4:
Query: How many devices did person 94 use?
Intent: count_devices_by_person
Entities: PERSON=94
Graph context:
- person 94 used device 5 from 2004-09-01 to 2004-09-10
Cypher:
MATCH (p:person {oid: '94'})-[:USES_DEVICE]->(d:device)
RETURN COUNT(d) AS device_count

Example 5:
Query: What cell towers was person 321 connected to?
Intent: get_celltowers_by_person
Entities: PERSON=321
Graph context:
- person 321 was connected to celltower A from 2004-09-03 to 2004-09-04
- person 321 was connected to celltower B from 2004-09-05 to 2004-09-06
Cypher:
MATCH (p:person {oid: '321'})-[r:CONNECTED_TO]->(ct:celltower)
RETURN ct.oid AS celltower_id, r.starttime, r.endtime
"""

    facts = "\n".join(f"- {d}" for d, _ in docs[:5])
    normalized_entities = {k: v.split()[-1] for k, v in entities.items()}
    entity_context = ", ".join([f"{k}={v}" for k, v in normalized_entities.items()])

    prompt = f"""
You are a Neo4j Cypher expert. Given a user query, its mapped intent, extracted entities, and relevant graph facts, your task is to generate a valid and optimized Cypher query.

### SCHEMA DEFINITION ###
Node Labels:
- person
- device
- phonenumber
- celltower

Relationship Types:
- USES_DEVICE
- CALLS
- CONNECTED_TO

Node Properties:
- person: oid, activities
- device: oid
- phonenumber: oid
- celltower: oid

Relationship Properties:
- USES_DEVICE: starttime, endtime
- CALLS: starttime, endtime, duration
- CONNECTED_TO: starttime, endtime

Use exact property and label names. If a relationship includes time information (e.g., starttime, endtime), include it in the `RETURN` clause when relevant. When referring to a person's activities, use `p.activities`.

{examples}

### NEW QUERY ###
Query: {query}
Intent: {intent}
Entities: {entity_context}
Graph context:
{facts}
Cypher:
"""

    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    #attention_mask = attention_mask.to(model.device)
    outputs = model.generate(
        inputs["input_ids"],
        max_new_tokens=100,
        do_sample=True,
        #attention_mask=attention_mask,
        #temperature=0.5,
        pad_token_id=tokenizer.eos_token_id
    )

    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    if "Cypher:" in decoded:
        decoded = decoded.split("Cypher:")[-1].strip()

    cypher_lines = decoded.splitlines()
    for i, line in enumerate(cypher_lines):
        if "MATCH" in line:
            return "\n".join(cypher_lines[i:]).strip()

    return decoded.strip()


In [None]:
query = "which tower is used highly by person 94"

In [20]:
query ="what activities were logged for person 94"

In [21]:
intent = classify_intents_and_entities(query)
intent

{'get_activities_by_person': '94'}

In [22]:
entities = extract_entities(query)
entities

{'PERSON': 'person 94'}

In [23]:
query_embedding = get_query_embedding(query)
#query_embedding

In [24]:
context = search_faiss_index(query_embedding, top_k=20,entity_filter=entities)
#context

In [25]:
context

[('Device 184 was used by person 94 from 2004-09-29T14:02:52 to 2004-09-29T14:02:52.',
  {'label': 'device', 'oid': 184}),
 ('Device 839 was used by person 94 from 2005-03-26T14:20:00 to 2005-03-26T14:25:12.',
  {'label': 'device', 'oid': 839}),
 ('used by person 94 from 2005-01-21T13:29:25 to 2005-01-21T14:19:17. was used by person 94 from 2005-02-28T16:08:04 to 2005-02-28T16:08:04.',
  {'label': 'device', 'oid': 35}),
 ('Device 199 was used by person 94 from 2004-10-03T13:31:16 to 2004-10-03T13:31:16.',
  {'label': 'device', 'oid': 199}),
 ('Device 729 was used by person 94 from 2005-02-21T11:31:45 to 2005-02-21T11:31:45.',
  {'label': 'device', 'oid': 729}),
 ('Device 174 was used by person 94 from 2004-09-20T16:06:35 to 2004-09-20T16:22:14. was used by person 94 from 2004-09-20T17:35:22 to 2004-09-20T17:46:11. was used by person 94 from 2004-10-25T16:13:18 to 2004-10-25T16:13:18. was used by person 94 from 2004-11-01T16:07:35 to 2004-11-01T16:23:15. was used by person 94 from 2004-

In [35]:
cypher_query = generate_cypher_query(query, intent, entities, context)

In [36]:
print(f"\nGenerated Cypher:\n{cypher_query}")


Generated Cypher:
MATCH (p:person {oid: '94'})-[:USES_DEVICE]->(d:device)
-> (f:activity {startime: $start_time, endtime: $end_time})
RETURN f

Notice that there are no use cases with an `IF` statement. We will add those with a future refactor. This query is now a bit shorter.


In [37]:
model

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=256, bias=False)
          (v_proj): Linear(in_features=2048, out_features=256, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (up_proj): Linear(in_features=2048, out_features=5632, bias=False)
          (down_proj): Linear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb): 