# GraphRAG API Demo

This notebook is written as an advanced tutorial/demonstration on how to use the GraphRAG solution accelerator API. It builds on top of the concepts covered in the `1-Quickstart` notebook.

## Existing APIs

| HTTP Method | Endpoint |
|-------------|----------|
| GET         | /data |
| POST        | /data |
| DELETE      | /data/{storage_name} |
| GET         | /index |
| POST        | /index |
| DELETE      | /index/{index_name} |
| GET         | /index/status/{index_name} |
| POST        | /query/global |
| POST        | /query/local |
| GET         | /index/config/prompts |
| GET         | /source/report/{index_name}/{report_id} |
| GET         | /source/text/{index_name}/{text_unit_id} |
| GET         | /source/entity/{index_name}/{entity_id} |
| GET         | /source/claim/{index_name}/{claim_id} |
| GET         | /source/relationship/{index_name}/{relationship_id} |
| GET         | /graph/graphml/{index_name} |
| GET         | /graph/stats/{index_name} |
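
The path templates in the table can be expanded into full URLs with a small helper. The following is a minimal sketch, not part of the accelerator: `build_url` is a hypothetical helper and the gateway URL is a placeholder used only for illustration.

```python
from urllib.parse import quote


def build_url(base_url: str, path_template: str, **path_params: str) -> str:
    """Expand a path template such as '/index/status/{index_name}' into a full URL.

    NOTE: `build_url` is a hypothetical helper for illustration; it is not part of the API.
    """
    # percent-encode each path parameter so special characters cannot break the URL
    encoded = {key: quote(str(value), safe="") for key, value in path_params.items()}
    return base_url.rstrip("/") + path_template.format(**encoded)


# the gateway URL below is a placeholder, not a real deployment
example_url = build_url(
    "https://example-apim.azure-api.net",
    "/index/status/{index_name}",
    index_name="my index",
)
print(example_url)
```

Encoding the path parameters guards against index or storage names that contain spaces or slashes.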

## Prerequisites
Install the third-party packages used in this notebook; they are not part of the Python Standard Library.

In [None]:
! pip install devtools pandas python-magic requests tqdm

In [None]:
import getpass
import json
import os
import sys
import time
from pathlib import Path
from zipfile import ZipFile

import magic
import pandas as pd
import requests
from devtools import pprint
from tqdm import tqdm

## (REQUIRED) User Configuration
Set the API subscription key, API base endpoint, and some file directory names that will be referenced later in this notebook.

#### API subscription key

Azure API Management (APIM) supports multiple forms of authentication and access control (e.g. managed identity). This notebook uses a **[subscription key](https://learn.microsoft.com/en-us/azure/api-management/api-management-subscriptions)**. To locate the key, visit the Azure Portal, where it can be found under `<my_resource_group> --> <API Management service> --> <APIs> --> <Subscriptions> --> <Built-in all-access subscription> Primary Key`. For multiple API users, individual subscription keys can be generated.

In [None]:
ocp_apim_subscription_key = getpass.getpass(
    "Enter the subscription key for the GraphRAG APIM: "
)

"""
"Ocp-Apim-Subscription-Key": 
    This is a custom HTTP header used by Azure API Management service (APIM) to 
    authenticate API requests. The value for this key should be set to the subscription 
    key provided by the Azure APIM instance in your GraphRAG resource group.
"""
headers = {"Ocp-Apim-Subscription-Key": ocp_apim_subscription_key}
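
For non-interactive runs (for example, a scheduled job), prompting with `getpass` is not practical. One possible alternative, sketched below, reads the key from an environment variable. The variable names and the function `headers_from_env` are assumptions made for this example, not names defined by the accelerator.

```python
import os


def headers_from_env(var_name: str = "GRAPHRAG_API_KEY") -> dict:
    """Build the APIM request headers from an environment variable.

    NOTE: the default variable name is a hypothetical choice for this sketch.
    """
    key = os.environ.get(var_name, "").strip()
    if not key:
        raise ValueError(f"Environment variable {var_name} is not set or is empty")
    return {"Ocp-Apim-Subscription-Key": key}


# demonstration with a dummy value (never hard-code a real key)
os.environ["GRAPHRAG_DEMO_KEY"] = "dummy-key"
demo_headers = headers_from_env("GRAPHRAG_DEMO_KEY")
```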

#### Setup directories and API endpoint

For demonstration purposes, use the provided `get-wiki-articles.py` script to download a small set of Wikipedia articles, or provide your own data. GraphRAG requires `.txt` files to be UTF-8 encoded.

In [None]:
"""
These parameters must be defined by the notebook user:

- file_directory: a local directory of text files. The file structure should be flat,
                  with no nested directories. (i.e. file_directory/file1.txt, file_directory/file2.txt, etc.)
- storage_name:   a unique name to identify a blob storage container in Azure where files
                  from `file_directory` will be uploaded.
- index_name:     a unique name to identify a single graphrag knowledge graph index.
                  Note: Multiple indexes may be created from the same `storage_name` blob storage container.
- endpoint:       the base/endpoint URL for the GraphRAG API (this is the Gateway URL found in the APIM resource).
"""

file_directory = ""
storage_name = ""
index_name = ""
endpoint = ""

In [None]:
assert (
    file_directory != "" and storage_name != "" and index_name != "" and endpoint != ""
)
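
The `assert` above only checks that the values are non-empty. A slightly stricter check, sketched here under the assumption that the directory must exist and the endpoint must be an `http(s)` URL, might look like the following. `validate_config` is a hypothetical helper, and the demonstration uses a temporary directory and a placeholder URL.

```python
import tempfile
from pathlib import Path
from urllib.parse import urlparse


def validate_config(
    file_directory: str, storage_name: str, index_name: str, endpoint: str
) -> str:
    """Validate the user configuration and return the endpoint without a trailing slash.

    NOTE: hypothetical helper; the specific checks are assumptions for this sketch.
    """
    if not storage_name or not index_name:
        raise ValueError("storage_name and index_name must be non-empty")
    if not file_directory or not Path(file_directory).is_dir():
        raise ValueError(f"file_directory does not exist: {file_directory!r}")
    parsed = urlparse(endpoint)
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError(f"endpoint is not a valid URL: {endpoint!r}")
    # the helper functions build URLs as endpoint + "/data", so drop any trailing slash
    return endpoint.rstrip("/")


# demonstration with a temporary directory and a placeholder URL
with tempfile.TemporaryDirectory() as tmp_dir:
    clean_endpoint = validate_config(
        tmp_dir, "demo-storage", "demo-index", "https://example.com/"
    )
```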

### Helper Functions

To keep the rest of the notebook readable, the helper functions below wrap the HTTP requests made to each API endpoint.

In [None]:
def upload_files(
    file_directory: str,
    storage_name: str,
    batch_size: int = 100,
    overwrite: bool = True,
    max_retries: int = 5,
) -> requests.Response:
    """
    Upload files to a blob storage container.

    Args:
    file_directory - a local directory of .txt files to upload. All files must be in utf-8 encoding.
    storage_name - a unique name for the Azure storage container.
    batch_size - the number of files to upload in a single batch.
    overwrite - whether or not to overwrite files if they already exist in the storage container.
    max_retries - the maximum number of times to retry uploading a batch of files if the API is busy.

    NOTE: Uploading files may sometimes fail if the blob container was recently deleted
    (i.e. a few seconds before). In practice, the solution is to sleep a few seconds and try again.
    """
    url = endpoint + "/data"

    def upload_batch(
        files: list, storage_name: str, overwrite: bool, max_retries: int
    ) -> requests.Response:
        for _ in range(max_retries):
            response = requests.post(
                url=url,
                files=files,
                params={"storage_name": storage_name, "overwrite": overwrite},
                headers=headers,
            )
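            # API may be busy, retry
            if response.status_code == 500:
                print("API busy. Sleeping and will try again.")
                time.sleep(10)
                # requests reads each file object to the end when building the request,
                # so rewind the files before the next attempt
                for _name, f in files:
                    f.seek(0)
                continue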
            return response
        return response

    batch_files = []
    accepted_file_types = ["text/plain"]
    filepaths = list(Path(file_directory).iterdir())
    response = None
    for file in tqdm(filepaths):
        # validate that the path is a file, has a .txt extension, and has an acceptable MIME type
        if (
            not file.is_file()
            or file.suffix != ".txt"
            or magic.from_file(str(file), mime=True) not in accepted_file_types
        ):
            print(f"Skipping invalid file: {file}")
            continue
        # open and decode file as utf-8, ignore bad characters
        batch_files.append(
            ("files", open(file=file, mode="r", encoding="utf-8", errors="ignore"))
        )
        # upload batch of files
        if len(batch_files) == batch_size:
            response = upload_batch(batch_files, storage_name, overwrite, max_retries)
            # close the file handles of the uploaded batch
            for _name, f in batch_files:
                f.close()
            batch_files.clear()
            # if response is not ok, return early
            if not response.ok:
                return response
    # upload remaining files
    if len(batch_files) > 0:
        response = upload_batch(batch_files, storage_name, overwrite, max_retries)
        for _name, f in batch_files:
            f.close()
    # without this check, an empty or invalid directory would leave `response` undefined
    if response is None:
        raise ValueError(f"No valid .txt files found in {file_directory}")
    return response


def delete_files(storage_name: str) -> requests.Response:
    """Delete a blob storage container."""
    url = endpoint + f"/data/{storage_name}"
    return requests.delete(url=url, headers=headers)


def list_files() -> requests.Response:
    """List all data storage containers."""
    url = endpoint + "/data"
    return requests.get(url=url, headers=headers)


def build_index(
    storage_name: str,
    index_name: str,
    entity_extraction_prompt_filepath: str | None = None,
    community_prompt_filepath: str | None = None,
    summarize_description_prompt_filepath: str | None = None,
) -> requests.Response:
    """Create a search index.
    This function kicks off a job that builds a knowledge graph (KG) index from files located in a blob storage container.
    """
    url = endpoint + "/index"
    prompt_filepaths = {
        "entity_extraction_prompt": entity_extraction_prompt_filepath,
        "community_report_prompt": community_prompt_filepath,
        "summarize_descriptions_prompt": summarize_description_prompt_filepath,
    }
    prompt_files = dict()
    try:
        # open only the prompt files that were provided
        for name, filepath in prompt_filepaths.items():
            if filepath:
                prompt_files[name] = open(filepath, "r")
        return requests.post(
            url,
            files=prompt_files if len(prompt_files) > 0 else None,
            params={"index_name": index_name, "storage_name": storage_name},
            headers=headers,
        )
    finally:
        # close the prompt files once the request has been sent
        for f in prompt_files.values():
            f.close()


def delete_index(index_name: str) -> requests.Response:
    """Delete a search index."""
    url = endpoint + f"/index/{index_name}"
    return requests.delete(url, headers=headers)


def list_indexes() -> list | requests.Response:
    """List all search indexes. Returns the raw response if it cannot be parsed as JSON."""
    url = endpoint + "/index"
    response = requests.get(url, headers=headers)
    try:
        indexes = json.loads(response.text)
        return indexes["index_name"]
    except json.JSONDecodeError:
        print(response.text)
        return response


def index_status(index_name: str) -> requests.Response:
    """Get the status of an indexing job."""
    url = endpoint + f"/index/status/{index_name}"
    return requests.get(url, headers=headers)


def global_search(index_name: str | list[str], query: str) -> requests.Response:
    """Run a global query over the knowledge graph(s) associated with one or more indexes"""
    url = endpoint + "/query/global"
    request = {"index_name": index_name, "query": query}
    return requests.post(url, json=request, headers=headers)


def global_search_streaming(
    index_name: str | list[str], query: str
) -> None:
    """Run a global query across one or more indexes, print the streamed response, and display the context data"""
    url = endpoint + "/experimental/query/global/streaming"
    request = {"index_name": index_name, "query": query}
    context_list = []
    with requests.post(url, json=request, headers=headers, stream=True) as r:
        r.raise_for_status()
        for chunk in r.iter_lines(chunk_size=256 * 1024, decode_unicode=True):
            try:
                payload = json.loads(chunk)
                token = payload["token"]
                context = payload["context"]
                if token != "<EOM>":
                    print(token, end="")
                elif (token == "<EOM>") and not context:
                    print("\n")  # transition from output message to context
                else:
                    context_list.append(context)
            except json.JSONDecodeError:
                print(type(chunk), len(chunk), sys.getsizeof(chunk), chunk, end="\n")
    display(pd.DataFrame.from_dict(context_list).head(10))


def local_search(index_name: str | list[str], query: str) -> requests.Response:
    """Run a local query over the knowledge graph(s) associated with one or more indexes"""
    url = endpoint + "/query/local"
    request = {"index_name": index_name, "query": query}
    return requests.post(url, json=request, headers=headers)


def get_graph_stats(index_name: str) -> requests.Response:
    """Get basic statistics about the knowledge graph constructed by GraphRAG."""
    url = endpoint + f"/graph/stats/{index_name}"
    return requests.get(url, headers=headers)


def save_graphml_file(index_name: str, graphml_file_name: str) -> None:
    """Retrieve and save a graphml file that represents the knowledge graph.
    The file is downloaded in chunks and saved to the local file system.
    """
    url = endpoint + f"/graph/graphml/{index_name}"
    if Path(graphml_file_name).suffix != ".graphml":
        raise ValueError(f"{graphml_file_name} must have a .graphml file extension")
    with requests.get(url, headers=headers, stream=True) as r:
        r.raise_for_status()
        with open(graphml_file_name, "wb") as f:
            for chunk in r.iter_content(chunk_size=1024):
                f.write(chunk)


def get_report(index_name: str, report_id: str) -> requests.Response:
    """Retrieve a report generated by GraphRAG for a specific index."""
    url = endpoint + f"/source/report/{index_name}/{report_id}"
    return requests.get(url, headers=headers)


def get_entity(index_name: str, entity_id: str) -> requests.Response:
    """Retrieve an entity generated by GraphRAG for a specific index."""
    url = endpoint + f"/source/entity/{index_name}/{entity_id}"
    return requests.get(url, headers=headers)


def get_relationship(index_name: str, relationship_id: str) -> requests.Response:
    """Retrieve a relationship generated by GraphRAG for a specific index."""
    url = endpoint + f"/source/relationship/{index_name}/{relationship_id}"
    return requests.get(url, headers=headers)


def get_claim(index_name: str, claim_id: str) -> requests.Response:
    """Retrieve a claim/covariate generated by GraphRAG for a specific index."""
    url = endpoint + f"/source/claim/{index_name}/{claim_id}"
    return requests.get(url, headers=headers)


def get_text_unit(index_name: str, text_unit_id: str) -> requests.Response:
    """Retrieve a text unit generated by GraphRAG for a specific index."""
    url = endpoint + f"/source/text/{index_name}/{text_unit_id}"
    return requests.get(url, headers=headers)


def parse_query_response(
    response: requests.Response, return_context_data: bool = False
) -> requests.Response | dict[str, list[dict]]:
    """
    Prints response['result'] value and optionally
    returns associated context data.
    """
    if response.ok:
        print(json.loads(response.text)["result"])
        if return_context_data:
            return json.loads(response.text)["context_data"]
        return response
    else:
        print(response.reason)
        print(response.content)
        return response


def generate_prompts(storage_name: str, zip_file_name: str, limit: int = 1) -> None:
    """Generate graphrag prompts using data provided in a specific storage container."""
    url = endpoint + "/index/config/prompts"
    params = {"storage_name": storage_name, "limit": limit}
    with requests.get(url, params=params, headers=headers, stream=True) as r:
        r.raise_for_status()
        with open(zip_file_name, "wb") as f:
            # iter_content() defaults to 1-byte chunks; use a larger chunk size
            for chunk in r.iter_content(chunk_size=1024):
                f.write(chunk)
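
The retry loop in `upload_batch` sleeps for a fixed 10 seconds between attempts. An alternative, shown here only as a sketch, is exponential backoff, where the wait doubles after each failed attempt. The helper name `retry_with_backoff`, its parameters, and the stand-in `flaky` function are hypothetical; the predicate that decides whether to retry is supplied by the caller.

```python
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    call: Callable[[], T],
    should_retry: Callable[[T], bool],
    max_retries: int = 5,
    base_delay: float = 1.0,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call `call` until `should_retry(result)` is False or the attempts run out.

    NOTE: hypothetical helper; waits base_delay * 2**attempt between attempts.
    """
    if max_retries < 1:
        raise ValueError("max_retries must be at least 1")
    for attempt in range(max_retries):
        result = call()
        if not should_retry(result):
            return result
        # do not sleep after the final attempt
        if attempt < max_retries - 1:
            sleep(base_delay * 2**attempt)
    return result


# demonstration with a stand-in that "fails" twice before succeeding
attempts = []
delays = []


def flaky() -> int:
    attempts.append(1)
    return 500 if len(attempts) < 3 else 200


# record the delays instead of actually sleeping
status = retry_with_backoff(flaky, lambda code: code == 500, sleep=delays.append)
```

With `requests`, `call` could wrap the `requests.post(...)` call and `should_retry` could test `response.status_code == 500`, mirroring the check used in `upload_batch`.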

## Upload files

Use the API to upload a collection of local files. The API creates a new blob storage container to host them. For a set of large files, consider reducing the batch size so that uploads do not overwhelm the API endpoint or cause out-of-memory problems.

In [None]:
response = upload_files(
    file_directory=file_directory,
    storage_name=storage_name,
    batch_size=100,
    overwrite=True,
)
if not response.ok:
    print(response.text)
else:
    print(response)
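
To illustrate how `batch_size` affects an upload, the sketch below splits a list of file names into batches in the same order that `upload_files` accumulates them: full batches first, then one final partial batch. `make_batches` is a hypothetical helper written for this illustration, and the file names are made up.

```python
from typing import Iterator, List, Sequence


def make_batches(items: Sequence[str], batch_size: int) -> Iterator[List[str]]:
    """Yield consecutive slices of `items` with at most `batch_size` elements each."""
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")
    for start in range(0, len(items), batch_size):
        yield list(items[start : start + batch_size])


# 250 hypothetical file names split with two different batch sizes
file_names = [f"file{i}.txt" for i in range(250)]
sizes_100 = [len(batch) for batch in make_batches(file_names, 100)]
sizes_25 = [len(batch) for batch in make_batches(file_names, 25)]
print(sizes_100)  # [100, 100, 50]
print(len(sizes_25))  # 10
```

A smaller batch size means more requests, but each request carries less data.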

#### To list all existing data storage containers:

In [None]:
response = list_files()
print(response)
pprint(response.json())

#### To remove files from the GraphRAG service:

In [None]:
# # uncomment this cell to delete data container
# response = delete_files(storage_name)
# print(response)
# pprint(response.text)

## Auto-Template Generation (Optional)

GraphRAG constructs a knowledge graph from data based on the ability to identify entities and the relationships between them. To improve the quality of the knowledge graph constructed by GraphRAG over private data, we provide a feature called "Automatic Templating". This capability takes user-provided data samples and generates custom-tailored prompts based on characteristics of that data. These custom prompts contain few-shot examples of entities and relationships, which can then be used to build a graphrag index.

In [None]:
generate_prompts(storage_name=storage_name, limit=5, zip_file_name="prompts.zip")
with ZipFile("prompts.zip", "r") as zip_ref:
    zip_ref.extractall()

After running the previous cell, a new local directory (`prompts`) is created. It contains the prompts generated from the user-provided data: `prompts/entity_extraction.txt`, `prompts/community_report.txt`, and `prompts/summarize_descriptions.txt`. Take some time to inspect and, where needed, modify these prompts, taking into account the characteristics of your data and the kind of knowledge you want to extract and model with GraphRAG.
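
One way to get a quick overview of the generated prompts before editing them is to print each file's size and first line. The sketch below assumes the prompts were extracted to a local directory as described above. `summarize_prompts` is a hypothetical helper, and the demonstration uses a temporary directory with placeholder content rather than real generated prompts.

```python
import tempfile
from pathlib import Path


def summarize_prompts(prompt_dir: str) -> dict:
    """Map each .txt prompt file name to (character count, first line)."""
    summary = {}
    for path in sorted(Path(prompt_dir).glob("*.txt")):
        text = path.read_text(encoding="utf-8")
        first_line = text.splitlines()[0] if text else ""
        summary[path.name] = (len(text), first_line)
    return summary


# demonstration with placeholder content (not a real generated prompt)
with tempfile.TemporaryDirectory() as tmp_dir:
    Path(tmp_dir, "entity_extraction.txt").write_text(
        "-Goal-\nplaceholder body", encoding="utf-8"
    )
    prompt_summary = summarize_prompts(tmp_dir)

for name, (n_chars, first_line) in prompt_summary.items():
    print(f"{name}: {n_chars} characters, starts with {first_line!r}")
```

In the notebook, the call would be `summarize_prompts("prompts")`.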

## Build an Index

After data files have been uploaded and (optionally) custom prompts have been generated, it is time to construct a knowledge graph by building an index. If custom prompts are not provided (as demonstrated in the `1-Quickstart` notebook), built-in default prompts are used, which we find generally work well.

#### Start a new indexing job

In [None]:
# check if prompt files exist
entity_extraction_prompt_filepath = "prompts/entity_extraction.txt"
community_prompt_filepath = "prompts/community_report.txt"
summarize_description_prompt_filepath = "prompts/summarize_descriptions.txt"
entity_prompt = (
    entity_extraction_prompt_filepath
    if os.path.isfile(entity_extraction_prompt_filepath)
    else None
)
community_prompt = (
    community_prompt_filepath if os.path.isfile(community_prompt_filepath) else None
)
summarize_prompt = (
    summarize_description_prompt_filepath
    if os.path.isfile(summarize_description_prompt_filepath)
    else None
)

response = build_index(
    storage_name=storage_name,
    index_name=index_name,
    entity_extraction_prompt_filepath=entity_prompt,
    community_prompt_filepath=community_prompt,
    summarize_description_prompt_filepath=summarize_prompt,
)
if response.ok:
    pprint(response.json())
else:
    print(f"Failed to submit job.\nStatus: {response.text}")

Note: An indexing job may sometimes fail because the tokens-per-minute (TPM) quota of the Azure OpenAI model is insufficient. In this situation, restart the job by re-running the cell above with the same parameters. `graphrag` caches previous indexing results as a cost-saving measure, so a restarted indexing job will "pick up" where the last job stopped.

#### Check the status of an indexing job

Please wait for your index to reach 100 percent completion before continuing on to the next section (running queries). You may rerun the next cell multiple times to monitor status. Note: the indexing speed of graphrag is directly correlated to the TPM quota of the Azure OpenAI model you are using.

In [None]:
response = index_status(index_name)
print(response)
pprint(response.json())
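
Instead of rerunning the cell by hand, the status check can be wrapped in a polling loop. The following is a sketch under stated assumptions: `wait_for_index` is a hypothetical helper, and the `percent_complete` field name is an assumption about the status payload, so inspect the JSON printed above and adjust the key if your deployment differs. The status source is passed in as a callable so the loop can be demonstrated without calling the API.

```python
import time
from typing import Callable


def wait_for_index(
    fetch_status: Callable[[], dict],
    poll_seconds: float = 30.0,
    timeout_seconds: float = 3600.0,
    sleep: Callable[[float], None] = time.sleep,
) -> dict:
    """Poll `fetch_status` until the index reports 100 percent or the timeout expires.

    NOTE: "percent_complete" is an assumed field name, not confirmed by this notebook.
    """
    waited = 0.0
    while True:
        status = fetch_status()
        if float(status.get("percent_complete") or 0) >= 100:
            return status
        if waited >= timeout_seconds:
            raise TimeoutError(f"Index not complete after {waited:.0f} seconds")
        sleep(poll_seconds)
        waited += poll_seconds


# demonstration with canned payloads in place of real API responses
fake_payloads = iter([{"percent_complete": 40}, {"percent_complete": 100}])
final_status = wait_for_index(
    lambda: next(fake_payloads), sleep=lambda seconds: None
)
```

In the notebook, `fetch_status` could be `lambda: index_status(index_name).json()`.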

#### List indexes
To view a list of all indexes that exist in the GraphRAG service:

In [None]:
all_indexes = list_indexes()
pprint(all_indexes)

#### Delete an index
If an index is no longer needed, remove it from the GraphRAG service.

In [None]:
# # uncomment this cell to delete an index
# response = delete_index(index_name)
# print(response)
# pprint(response.json())

## Query

Once an indexing job is complete, the knowledge graph is ready to query. Two types of queries (global and local) are currently supported. We encourage you to try both and experience the difference in responses. Note that query response time is also correlated to the TPM quota of the Azure OpenAI model you are using.

#### Global Search

Global search queries are resource-intensive, but give good responses to questions that require an understanding of the dataset as a whole.

In [None]:
%%time
# pass a single index name as a string; to query across multiple indexes, pass a list, e.g. index_name=["myindex1", "myindex2"]
global_response = global_search(
    index_name=index_name, query="Summarize the main topics of this data"
)
# print the result and save context data in a variable
global_response_data = parse_query_response(global_response, return_context_data=True)
global_response_data

An *experimental* API endpoint has been designed to support streaming back the graphrag response while executing a global query (useful in applications like a chatbot).

In [None]:
global_search_streaming(
    index_name=index_name, query="Summarize the main topics of this data"
)

#### Local Search

Local search queries are best suited for narrowly focused questions that require an understanding of specific entities mentioned in the documents (e.g. "What are the healing properties of chamomile?").

In [None]:
%%time
# pass a single index name as a string; to query across multiple indexes, pass a list, e.g. index_name=["myindex1", "myindex2"]
local_response = local_search(
    index_name=index_name, query="Who are the primary actors in these communities?"
)
# print the result and save context data in a variable
local_response_data = parse_query_response(local_response, return_context_data=True)
local_response_data
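
The context data returned by a query can be large. A quick way to see what it contains is to count the records under each key. This sketch assumes the context data is a dictionary that maps a source type to a list of records; `summarize_context` is a hypothetical helper, and the sample payload and its key names are made up for the demonstration.

```python
def summarize_context(context_data: dict) -> dict:
    """Count the records under each key of a context-data dictionary.

    NOTE: assumes each value is a list of records; any other value counts as 0.
    """
    return {
        key: len(value) if isinstance(value, list) else 0
        for key, value in context_data.items()
    }


# made-up sample; real key names and record fields may differ
sample_context = {
    "reports": [
        {"id": "0", "title": "Community A"},
        {"id": "1", "title": "Community B"},
    ],
    "entities": [{"id": "7", "entity": "EXAMPLE"}],
    "claims": [],
}
context_counts = summarize_context(sample_context)
print(context_counts)  # {'reports': 2, 'entities': 1, 'claims': 0}
```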

## Sources

A query response often contains citations that support GraphRAG's answer. Multiple types of sources may be cited: Reports, Entities, Relationships, Claims, and Text Units.

The API provides an endpoint for each source type so that the cited material can be retrieved for data provenance.

#### Get a Report

In [None]:
report_response = get_report(index_name, "0")
if report_response.ok:
    print(report_response.json()["text"])
else:
    print(report_response.reason)
    print(report_response.content)

#### Get an Entity

In [None]:
entity_response = get_entity(index_name, 0)
entity_response.json() if entity_response.ok else (
    entity_response.reason,
    entity_response.content,
)

#### Get a Relationship

In [None]:
relationship_response = get_relationship(index_name, 1)
relationship_response.json() if relationship_response.ok else (
    relationship_response.reason,
    relationship_response.content,
)

#### Get a Claim
Note: claims are supported only if the solution accelerator deployment was initially configured to enable claims.

In [None]:
claim_response = get_claim(index_name, 1)
if claim_response.ok:
    pprint(claim_response.json())
else:
    print(claim_response)
    print(claim_response.text)

#### Get a Text Unit

In [None]:
# get a text unit id from one of the previous Source endpoint results (look for 'text_units' in the response)
text_unit_id = "a1ea6b4d13016fa863f4d76a0dd532e3"
if not text_unit_id:
    raise ValueError(
        "Must provide a text_unit_id from previous source results. Look for 'text_units' in the response."
    )
text_unit_response = get_text_unit(index_name, text_unit_id)
if text_unit_response.ok:
    print(text_unit_response.json()["text"])
else:
    print(text_unit_response.reason)
    print(text_unit_response.content)
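
The comment in the cell above says to look for `text_units` in an earlier source response. A sketch of that lookup follows. It assumes the response JSON is a dictionary whose `text_units` value is either a list of ids or a comma-separated string; `first_text_unit_id` is a hypothetical helper and the sample payload is invented.

```python
from typing import Optional


def first_text_unit_id(source_json: dict) -> Optional[str]:
    """Return the first text unit id referenced by a source response, if any.

    NOTE: assumes "text_units" holds a list of ids or a comma-separated string.
    """
    text_units = source_json.get("text_units") or []
    if isinstance(text_units, str):
        text_units = [part.strip() for part in text_units.split(",") if part.strip()]
    return str(text_units[0]) if text_units else None


# invented payload for demonstration; real responses may contain other fields
sample_entity = {"name": "EXAMPLE", "text_units": ["abc123", "def456"]}
example_id = first_text_unit_id(sample_entity)
```

In the notebook, the input could be `entity_response.json()` or `relationship_response.json()` from the earlier cells.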

## Exploring the GraphRAG knowledge graph
The API currently provides some basic functionality to better understand the knowledge graph that was constructed during the indexing process.

In addition, an option is available to export the graph to a graphml file which can be imported by other open source visualization software (we recommend [Gephi](https://gephi.org/)) for deeper exploration.

#### Basic knowledge graph statistics

In [None]:
response = get_graph_stats(index_name)
print(response)
print(response.text)

#### Get a GraphML file

In [None]:
# will save graphml file to the current local directory
save_graphml_file(index_name, "knowledge_graph.graphml")
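
Before opening the file in a visualization tool, a quick sanity check is to count its nodes and edges. The sketch below uses only the standard library and relies on the standard GraphML XML namespace. `count_graphml` is a hypothetical helper, and the tiny graph in the demonstration is invented rather than produced by GraphRAG.

```python
import tempfile
import xml.etree.ElementTree as ET
from pathlib import Path

GRAPHML_NS = {"g": "http://graphml.graphdrawing.org/xmlns"}


def count_graphml(graphml_path: str) -> tuple:
    """Return (number of nodes, number of edges) in a GraphML file."""
    root = ET.parse(graphml_path).getroot()
    nodes = root.findall(".//g:node", GRAPHML_NS)
    edges = root.findall(".//g:edge", GRAPHML_NS)
    return len(nodes), len(edges)


# invented two-node graph used only to demonstrate the helper
SAMPLE = """<?xml version="1.0" encoding="UTF-8"?>
<graphml xmlns="http://graphml.graphdrawing.org/xmlns">
  <graph id="G" edgedefault="undirected">
    <node id="A"/>
    <node id="B"/>
    <edge source="A" target="B"/>
  </graph>
</graphml>
"""
with tempfile.TemporaryDirectory() as tmp_dir:
    sample_path = Path(tmp_dir, "sample.graphml")
    sample_path.write_text(SAMPLE, encoding="utf-8")
    graph_counts = count_graphml(str(sample_path))
print(graph_counts)  # (2, 1)
```

In the notebook, the call would be `count_graphml("knowledge_graph.graphml")`.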