# Intro to GraphRAG with Google Cloud Spanner Graph and LangChain

Spanner Graph now [integrates seamlessly with LangChain](https://cloud.google.com/python/docs/reference/langchain-google-spanner/latest#spanner-graph-store-usage), making it easier to build GraphRAG applications.

Instead of simply retrieving relevant text snippets based on keyword similarity, GraphRAG takes a more sophisticated, structured approach to Retrieval Augmented Generation. It involves creating a knowledge graph from the text, organizing it hierarchically, summarizing key concepts, and then using this structured information to enhance the accuracy and depth of responses.


### Objectives

In this tutorial, you will walk through building a question-answering system using the GraphRAG method. You'll learn how to create a knowledge graph from scratch, store it in Spanner Graph, and build a functional FAQ system on top of it with LangChain.

Google Cloud [Spanner](https://cloud.google.com/spanner) is a highly scalable database that combines horizontal scalability with relational semantics, such as secondary indexes, strong consistency, schemas, and SQL, while providing up to 99.999% availability.

This notebook goes over how to use `Spanner Graph` for GraphRAG with the custom retriever `SpannerGraphVectorContextRetriever` and compares the response of GraphRAG with conventional RAG.

### Library Installation
The integration lives in its own `langchain-google-spanner` package, so we need to install it.

In [None]:
%pip install --quiet spanner-graph-notebook
%pip install --quiet google-cloud-spanner==3.57.0
%pip install --quiet langchain-google-spanner==0.9.0
%pip install --quiet langchain-google-vertexai==2.0.28
%pip install --quiet json-repair networkx==3.5 
%pip install --quiet langchain-text-splitters==0.3.9 
%pip install --quiet langchain-experimental==0.3.4

#### Automatically restart kernel after installs so that your environment can access the new packages

In [None]:
import IPython

app = IPython.Application.instance()
app.kernel.do_shutdown(True)

### Configure warnings

In [None]:
import warnings
warnings.filterwarnings("ignore", category=DeprecationWarning) 

In [None]:
EMBEDDING_MODEL = "text-embedding-004"
GENERATIVE_MODEL = "gemini-2.5-flash"

In [None]:
PROJECT_ID = !gcloud config list --format 'value(core.project)'
PROJECT_ID = PROJECT_ID[0]
REGION = "us-central1"
%env GOOGLE_CLOUD_PROJECT={PROJECT_ID}

### Spanner API Enablement
The `langchain-google-spanner` package requires that you [enable the Spanner API](https://console.cloud.google.com/flows/enableapi?apiid=spanner.googleapis.com) in your Google Cloud Project.

In [None]:
!gcloud services enable spanner.googleapis.com

## Usage

### Set Spanner database values
Find your database values in the [Spanner Instances page](https://console.cloud.google.com/spanner). The cells below create a new instance and database with the names you set here.

In [None]:
INSTANCE = "graphrag-instance-v1"
DATABASE = "graphrag"
GRAPH_NAME = "retail_demo_graph"

In [None]:
!gcloud spanner instances create {INSTANCE} --config=regional-us-central1 --description="Graph RAG Instance" --nodes=1 --edition=ENTERPRISE

In [None]:
# Create the Spanner database that will hold the graph data.
def create_database(project_id, instance_id, database_id):
    """Creates an empty Spanner database (no tables are defined here)."""
    from google.cloud.spanner_admin_database_v1.types import spanner_database_admin
    from google.cloud import spanner
    spanner_client = spanner.Client(project_id)
    database_admin_api = spanner_client.database_admin_api

    request = spanner_database_admin.CreateDatabaseRequest(
        parent=database_admin_api.instance_path(spanner_client.project, instance_id),
        create_statement=f"CREATE DATABASE `{database_id}`",
        extra_statements= [])

    operation = database_admin_api.create_database(request=request)

    print("Waiting for operation to complete...")
    OPERATION_TIMEOUT_SECONDS = 60
    database = operation.result(OPERATION_TIMEOUT_SECONDS)

    print(
        "Created database {} on instance {}".format(
            database.name,
            database_admin_api.instance_path(spanner_client.project, instance_id)
        )
    )



In [None]:
from google.cloud import spanner

create_database(PROJECT_ID, INSTANCE, DATABASE)

### SpannerGraphStore

To initialize the `SpannerGraphStore` class, you need to provide three required arguments. All other arguments are optional and only need to be passed when their values differ from the defaults.

1.   a Spanner instance id;
2.   a Spanner database id that belongs to the above instance;
3.   a Spanner graph name used to create a graph in the above database.

In [None]:
from langchain_google_spanner import SpannerGraphStore

graph_store = SpannerGraphStore(
    instance_id=INSTANCE,
    database_id=DATABASE,
    graph_name=GRAPH_NAME,
)

#### Add Graph Documents
The following cells load the source documents, extract graph documents from them, and add those graph documents to the graph store.

In [None]:
# @title Load Documents
import os
from langchain_community.document_loaders import DirectoryLoader
from langchain_community.document_loaders import TextLoader
from langchain_core.documents import Document

!wget https://raw.githubusercontent.com/googleapis/langchain-google-spanner-python/main/samples/retaildata.zip

In [None]:
!mkdir content
!unzip -o "retaildata.zip"

In [None]:
path = "retaildata/"
directories = [
    item for item in os.listdir(path) if os.path.isdir(os.path.join(path, item))
]

document_lists = []
for directory in directories:
    loader = DirectoryLoader(
        os.path.join(path, directory), glob="**/*.txt", loader_cls=TextLoader
    )
    document_lists.append(loader.load())
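
To make it clearer what the loading loop collects, here is a minimal standard-library sketch of the same traversal. It assumes the layout used above (one subdirectory per group under the data folder, with `.txt` files somewhere beneath it). The helper name `load_text_by_directory` is hypothetical and is not part of LangChain.

```python
from pathlib import Path


def load_text_by_directory(root):
    """Return {subdirectory name: [(file path, text), ...]} for all .txt files.

    Hypothetical helper mirroring the loop above: one list of documents per
    immediate subdirectory of `root`, searched recursively for *.txt files.
    """
    grouped = {}
    for directory in sorted(p for p in Path(root).iterdir() if p.is_dir()):
        files = sorted(directory.glob("**/*.txt"))
        grouped[directory.name] = [
            (str(f), f.read_text(encoding="utf-8")) for f in files
        ]
    return grouped
```

Files that sit directly in the root folder are ignored, just as in the loop above, because only subdirectories are scanned.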

#### Extract Nodes and Edges

In [None]:
import copy
from langchain_experimental.graph_transformers import LLMGraphTransformer
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings


def print_graph(graph_documents):
    for doc in graph_documents:
        print(doc.source.page_content[:100])
        nodes = copy.deepcopy(doc.nodes)
        for node in nodes:
            if "embedding" in node.properties:
                node.properties["embedding"] = "..."
        print(nodes)
        print(doc.relationships)
        print()


llm = ChatVertexAI(model=GENERATIVE_MODEL, temperature=0)
llm_transformer = LLMGraphTransformer(
    llm=llm,
    allowed_nodes=["Category", "Segment", "Tag", "Product", "Bundle", "Deal"],
    allowed_relationships=[
        "In_Category",
        "Tagged_With",
        "In_Segment",
        "In_Bundle",
        "Is_Accessory_Of",
        "Is_Upgrade_Of",
        "Has_Deal",
    ],
    node_properties=[
        "name",
        "price",
        "weight",
        "deal_end_date",
        "features",
    ],
)

graph_documents = []
for document_list in document_lists:
    graph_documents.extend(llm_transformer.convert_to_graph_documents(document_list))

# Add embeddings to the graph documents for Product nodes
embedding_service = VertexAIEmbeddings(model_name=EMBEDDING_MODEL)
for graph_document in graph_documents:
    for node in graph_document.nodes:
        if "features" in node.properties:
            node.properties["embedding"] = embedding_service.embed_query(
                node.properties["features"]
            )

print_graph(graph_documents)

#### Post process extracted nodes and edges
Apply your domain knowledge to clean up the graph generated in the previous step
and make any fixes you need.

In [None]:
# set of all valid products
products = set()


def prune_invalid_products():
    for graph_document in graph_documents:
        nodes_to_remove = []
        for node in graph_document.nodes:
            if node.type != "Product":
                continue
            if "features" not in node.properties:
                # Products without a `features` property are treated as invalid.
                nodes_to_remove.append(node)
            else:
                products.add(node.id)
        for node in nodes_to_remove:
            graph_document.nodes.remove(node)


def prune_invalid_segments(valid_segments):
    for graph_document in graph_documents:
        nodes_to_remove = []
        for node in graph_document.nodes:
            if node.type == "Segment" and node.id not in valid_segments:
                nodes_to_remove.append(node)
        for node in nodes_to_remove:
            graph_document.nodes.remove(node)


def is_not_a_listed_product(node):
    if node.type == "Product" and node.id not in products:
        return True
    return False


def fix_directions(relation_name, wrong_source_type):
    for graph_document in graph_documents:
        for relationship in graph_document.relationships:
            if relationship.type == relation_name:
                if relationship.source.type == wrong_source_type:
                    source = relationship.source
                    target = relationship.target
                    relationship.source = target
                    relationship.target = source


def prune_dangling_relationships():
    # now remove all dangling relationships
    for graph_document in graph_documents:
        relationships_to_remove = []
        for relationship in graph_document.relationships:
            if is_not_a_listed_product(relationship.source) or is_not_a_listed_product(
                relationship.target
            ):
                relationships_to_remove.append(relationship)
        for relationship in relationships_to_remove:
            graph_document.relationships.remove(relationship)


def prune_unwanted_relationships(relation_name, source, target):
    node_types = set([source, target])
    for graph_document in graph_documents:
        relationships_to_remove = []
        for relationship in graph_document.relationships:
            if (
                relationship.type == relation_name
                and set([relationship.source.type, relationship.target.type])
                == node_types
            ):
                relationships_to_remove.append(relationship)
        for relationship in relationships_to_remove:
            graph_document.relationships.remove(relationship)


prune_invalid_products()
prune_invalid_segments(set(["Home", "Office", "Fitness"]))
prune_unwanted_relationships("IN_CATEGORY", "Bundle", "Category")
prune_unwanted_relationships("IN_CATEGORY", "Deal", "Category")
prune_unwanted_relationships("IN_SEGMENT", "Bundle", "Segment")
prune_unwanted_relationships("IN_SEGMENT", "Deal", "Segment")
prune_dangling_relationships()
fix_directions("HAS_DEAL", "Deal")
fix_directions("IN_BUNDLE", "Bundle")
print_graph(graph_documents)
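
The two most common fixes in the cell above, dropping dangling relationships and reversing misdirected ones, can be tried out on toy data without an LLM. The sketch below is an illustration only: `Node` and `Relationship` are simplified stand-ins for the LangChain graph document classes, and the node names are made up.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Node:
    id: str
    type: str


@dataclass
class Relationship:
    source: Node
    target: Node
    type: str


def drop_dangling(relationships, valid_product_ids):
    """Keep only relationships whose Product endpoints are known products."""

    def is_unlisted(node):
        return node.type == "Product" and node.id not in valid_product_ids

    return [
        r
        for r in relationships
        if not (is_unlisted(r.source) or is_unlisted(r.target))
    ]


def fix_direction(relationships, relation_name, wrong_source_type):
    """Swap source and target when an edge starts at the wrong node type."""
    for r in relationships:
        if r.type == relation_name and r.source.type == wrong_source_type:
            r.source, r.target = r.target, r.source
    return relationships


# Made-up sample data.
drone = Node("Drone A", "Product")
ghost = Node("Unknown Gadget", "Product")
deal = Node("Spring Sale", "Deal")

edges = [
    Relationship(deal, drone, "HAS_DEAL"),  # reversed: should be Product -> Deal
    Relationship(ghost, deal, "HAS_DEAL"),  # dangling: not a listed product
]

edges = drop_dangling(edges, valid_product_ids={"Drone A"})
edges = fix_direction(edges, "HAS_DEAL", wrong_source_type="Deal")
print([(e.source.id, e.type, e.target.id) for e in edges])
```

Only the `Drone A` to `Spring Sale` edge survives, now pointing from the product to the deal.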

#### Load data to Spanner Graph
Clean up the graph left over from previous iterations, then load the new graph documents.
!!! THIS COULD REMOVE DATA FROM YOUR DATABASE !!!

In [None]:
graph_store.cleanup()
graph_store.add_graph_documents(graph_documents)

### Visualization

In [None]:
%load_ext spanner_graphs

In [None]:
%%spanner_graph --project {PROJECT_ID} --instance {INSTANCE} --database {DATABASE}

GRAPH retail_demo_graph
MATCH p = ()->()
RETURN TO_JSON(p) AS path_json

## GraphRAG flow using Spanner Graph

In [None]:
# @title Question Answering Prompt
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts.prompt import PromptTemplate
from langchain_google_vertexai import ChatVertexAI, VertexAIEmbeddings

from IPython.display import Markdown
import textwrap
import json
import pprint

# Print the retrieved graph context and join it into a single string for the prompt.
def format_docs(docs):
    print("Context Retrieved: \n")
    for doc in docs:
        print("-"*80)
        pprint.pprint(json.loads(doc.page_content)[0], width=80, indent=4)
        #print(json.dumps(json.loads(doc.page_content)[0], indent=4))
        print("-"*80)
        print("\n")

    context = "\n\n".join(doc.page_content for doc in docs)
    return context


SPANNERGRAPH_QA_TEMPLATE = """
You are a helpful and friendly AI assistant for question answering tasks for an electronics
retail online store.
Create a human readable answer for the question.
You should only use the information provided in the context and not use your internal knowledge.
Don't add any information.
Here is an example:

Question: Which funds own assets over 10M?
Context:[name:ABC Fund, name:Star fund]
Helpful Answer: ABC Fund and Star fund have assets over 10M.

Follow this example when generating answers.
You are given the following information:
- `Question`: the natural language question from the user
- `Graph Schema`: contains the schema of the graph database
- `Context`: The response from the graph database as context. The context has nodes and edges. Use the relationships.
Information:
Question: {question}
Graph Schema: {graph_schema}
Context: {context}

Format your answer to be human readable.
Use the relationships in the context to answer the question.
Only include information that is relevant to a customer.
Helpful Answer:"""

prompt = PromptTemplate(
    template=SPANNERGRAPH_QA_TEMPLATE,
    input_variables=["question", "graph_schema", "context"],
)

llm = ChatVertexAI(model=GENERATIVE_MODEL, temperature=0)

chain = prompt | llm | StrOutputParser()
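
As a rough illustration of what happens to the template before it reaches the model, the sketch below fills the same three placeholders with Python's built-in `str.format`. It assumes the default f-string style template format. The shortened template, the `fill_prompt` helper, and the sample values are all made up for this example.

```python
TEMPLATE = """Question: {question}
Graph Schema: {graph_schema}
Context: {context}
Helpful Answer:"""


def fill_prompt(template, **values):
    """Fill {placeholder} fields; raises KeyError if a value is missing."""
    return template.format(**values)


filled = fill_prompt(
    TEMPLATE,
    question="Which drones are on sale?",
    graph_schema="(Product)-[HAS_DEAL]->(Deal)",
    # Braces inside a value are safe; only braces in the template are parsed.
    context='[{"name": "Example Drone"}]',
)
print(filled)
```

A missing input variable fails loudly instead of silently producing a prompt with a hole in it, which is also why the dictionary passed to the chain must provide every name the template uses.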

### Specify user query:

In [None]:
USER_QUERY = "Give me recommendations for a beginner drone with a good battery and camera"

In [None]:
# @title GraphRAG using Vector Search and Graph Expansion
import textwrap
from langchain_google_spanner import SpannerGraphVectorContextRetriever
from langchain_google_vertexai import VertexAIEmbeddings


def use_node_vector_retriever(
    question, graph_store, embedding_service, label_expr, expand_by_hops
):
    retriever = SpannerGraphVectorContextRetriever.from_params(
        graph_store=graph_store,
        embedding_service=embedding_service,
        label_expr=label_expr,
        expand_by_hops=expand_by_hops,
        top_k=3,
        #k=10,
    )
    context = format_docs(retriever.invoke(question))
    return context


embedding_service = VertexAIEmbeddings(model_name=EMBEDDING_MODEL)

context = use_node_vector_retriever(
    USER_QUERY, graph_store, embedding_service, label_expr="Product", expand_by_hops=1
)

answer = chain.invoke(
    {"question": USER_QUERY, "graph_schema": graph_store.get_schema, "context": context}
)

print("\n\nAnswer:\n")
print(textwrap.fill(answer, width=80))
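
The idea behind `top_k` and `expand_by_hops` can be sketched in plain Python: rank nodes by embedding similarity, keep the best `top_k` as seeds, then pull in their graph neighbours. This is a conceptual sketch under stated assumptions, not the retriever's actual implementation. The function names, the two-dimensional vectors, and the choice to follow edges in both directions are all assumptions made for illustration.

```python
import math


def cosine(a, b):
    """Cosine similarity of two equal-length vectors (0.0 if either is zero)."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0


def retrieve_with_expansion(query_vec, node_vectors, edges, top_k, hops):
    """Pick the top_k most similar nodes, then add neighbours up to `hops` away."""
    ranked = sorted(
        node_vectors,
        key=lambda name: cosine(query_vec, node_vectors[name]),
        reverse=True,
    )
    seeds = ranked[:top_k]
    selected = set(seeds)
    frontier = set(seeds)
    for _ in range(hops):
        next_frontier = set()
        for src, _rel, dst in edges:
            # Assumption: edges are followed in both directions.
            if src in frontier and dst not in selected:
                next_frontier.add(dst)
            if dst in frontier and src not in selected:
                next_frontier.add(src)
        selected |= next_frontier
        frontier = next_frontier
    context_edges = [e for e in edges if e[0] in selected and e[2] in selected]
    return seeds, context_edges


# Made-up embeddings and edges.
node_vectors = {
    "drone_a": [0.9, 0.1],
    "drone_b": [0.8, 0.3],
    "laptop": [0.0, 1.0],
}
edges = [
    ("drone_a", "IN_CATEGORY", "drones"),
    ("drone_a", "HAS_DEAL", "deal_1"),
    ("laptop", "IN_CATEGORY", "computers"),
    ("battery_x", "IS_ACCESSORY_OF", "drone_a"),
]

seeds, context_edges = retrieve_with_expansion(
    [1.0, 0.0], node_vectors, edges, top_k=1, hops=1
)
print(seeds)
print(context_edges)
```

With one hop, the context for the best-matching product includes its category, its deal, and its accessory, which is the relational information a plain vector lookup over text chunks would not return.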

## Compare with Conventional RAG

In [None]:
TABLE_NAME = "rag_table"

In [None]:
# @title Setup and load data for vector search
from langchain_google_spanner import SpannerVectorStore
from langchain_google_vertexai import VertexAIEmbeddings
from langchain_text_splitters import RecursiveCharacterTextSplitter

import uuid


def load_data_for_vector_search(splits):
    embeddings = VertexAIEmbeddings(model_name=EMBEDDING_MODEL)

    SpannerVectorStore.init_vector_store_table(
        instance_id=INSTANCE,
        database_id=DATABASE,
        table_name=TABLE_NAME,
    )
    db = SpannerVectorStore(
        instance_id=INSTANCE,
        database_id=DATABASE,
        table_name=TABLE_NAME,
        embedding_service=embeddings,
    )
    # Add the chunks to Spanner Vector Store with embeddings
    ids = [str(uuid.uuid4()) for _ in range(len(splits))]
    row_ids = db.add_documents(splits, ids)


# Create splits for documents
text_splitter = RecursiveCharacterTextSplitter(chunk_size=250, chunk_overlap=100)
splits = text_splitter.split_documents(
    [document for document_list in document_lists for document in document_list]
)

# Initialize table and load data
embeddings = VertexAIEmbeddings(model_name=EMBEDDING_MODEL)
load_data_for_vector_search(splits)
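
To get a feel for how `chunk_size=250` and `chunk_overlap=100` interact, here is a simplified fixed-window chunker. It is not what `RecursiveCharacterTextSplitter` does internally: the real splitter prefers to break on separators such as paragraph breaks, line breaks, and spaces, so its chunk boundaries will differ. The function name `sliding_window_chunks` is hypothetical.

```python
def sliding_window_chunks(text, chunk_size, chunk_overlap):
    """Split text into windows of chunk_size that share chunk_overlap characters."""
    if chunk_overlap >= chunk_size:
        raise ValueError("chunk_overlap must be smaller than chunk_size")
    step = chunk_size - chunk_overlap
    chunks = []
    for start in range(0, len(text), step):
        chunks.append(text[start:start + chunk_size])
        if start + chunk_size >= len(text):
            break  # the last window already reached the end of the text
    return chunks


# Made-up 600-character sample text.
sample = "".join(chr(ord("a") + i % 26) for i in range(600))
chunks = sliding_window_chunks(sample, chunk_size=250, chunk_overlap=100)
print([len(c) for c in chunks])
```

Each window starts 150 characters after the previous one, so every chunk repeats the last 100 characters of its predecessor. A larger overlap reduces the chance that a fact is cut in half at a boundary, at the cost of storing and embedding more chunks.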

In [None]:
from langchain_core.runnables import RunnablePassthrough
from langchain_google_spanner import SpannerVectorStore
import textwrap


# Print the retrieved text chunks and join them into a single context string.
def format_docs(docs):
    print("Context Retrieved: \n")
    for doc in docs:
        print("-"*80)
        print(textwrap.fill(doc.page_content, width=80))
        print("-"*80)
        print("\n")

    context = "\n\n".join(doc.page_content for doc in docs)
    return context


prompt = PromptTemplate(
    template="""
    You are a friendly digital shopping assistant.
    Use the following pieces of retrieved context to answer the question.
    If you don't know the answer, just say that you don't know.
    Question: {question}
    Context: {context}
    Answer:
  """,
    input_variables=["context", "question"],
)

# Create a rag chain
embeddings = VertexAIEmbeddings(model_name=EMBEDDING_MODEL)

db = SpannerVectorStore(
    instance_id=INSTANCE,
    database_id=DATABASE,
    table_name=TABLE_NAME,
    embedding_service=embeddings,
)
vector_retriever = db.as_retriever(search_kwargs={"k": 3})
rag_chain = (
    {
        "context": vector_retriever | format_docs,
        "question": RunnablePassthrough(),
    }
    | prompt
    | llm
    | StrOutputParser()
)
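
The data flow of the chain above can be written out with ordinary function calls, which may help when debugging. In the sketch below, every component is a stand-in invented for illustration: `fake_retrieve` is a keyword matcher rather than a vector search, and `fake_llm` simply echoes its prompt rather than calling a model.

```python
def make_rag_chain(retrieve, format_docs, render_prompt, llm, parse):
    """Compose the same steps as the chain above using plain function calls."""

    def run(question):
        prompt_inputs = {
            # Corresponds to `vector_retriever | format_docs`.
            "context": format_docs(retrieve(question)),
            # Corresponds to `RunnablePassthrough()`: the input is forwarded as-is.
            "question": question,
        }
        return parse(llm(render_prompt(**prompt_inputs)))

    return run


# Made-up corpus and stand-in components.
corpus = ["Drone A has a 30 minute battery.", "Laptop B weighs 1.2 kg."]


def fake_retrieve(question):
    words = set(question.lower().split())
    return [
        doc for doc in corpus if words & set(doc.lower().rstrip(".").split())
    ]


def fake_llm(prompt_text):
    # Not a model: it echoes the prompt so the data flow is visible.
    return {"content": "ECHO: " + prompt_text}


chain_sketch = make_rag_chain(
    retrieve=fake_retrieve,
    format_docs=lambda docs: "\n\n".join(docs),
    render_prompt="Question: {question}\nContext: {context}\nAnswer:".format,
    llm=fake_llm,
    parse=lambda message: message["content"],
)
print(chain_sketch("drone battery"))
```

The single input string is used twice: once to fetch context and once, unchanged, as the question in the prompt.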

In [None]:
import textwrap

resp = rag_chain.invoke(USER_QUERY)
print("\n\nRag Response:\n")
print(textwrap.fill(resp, width=80))

## Clean up the graph

> USE IT WITH CAUTION!

**Clean up all the nodes/edges in your graph and remove your graph definition.**

In [None]:
#graph_store.cleanup()

Copyright 2025 Google LLC

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

     https://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.