![RelationalAI - Getting started with GraphRAG](assets/header.png)

In [None]:
import logging
import os
import re
import sys
import warnings

import pandas as pd
import relationalai as rai
from relationalai.clients.snowflake import Session, Snowflake
from relationalai.std import aggregates, alias
from relationalai.std.graphs import Graph

# Solution packages.
sys.path.append("../python/src/")

from sf_rai_graphrag.rai import *
from sf_rai_graphrag.snowflake import *
from sf_rai_graphrag.util import *

In [None]:
logging.getLogger().setLevel(logging.ERROR)

## Snowflake session setup
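
Rather than pasting credentials into the notebook, the connection settings can be read from environment variables. The following is a minimal sketch: the `SNOWFLAKE_*` variable names and the `snowflake_configs_from_env` helper are assumptions made for illustration and are not part of the solution packages.

```python
import os


def snowflake_configs_from_env(env=None):
    """Build a connection-settings dict from environment variables.

    The SNOWFLAKE_* names are hypothetical; adapt them to your own setup.
    """
    env = os.environ if env is None else env
    required = {
        "user": "SNOWFLAKE_USER",
        "password": "SNOWFLAKE_PASSWORD",
        "account": "SNOWFLAKE_ACCOUNT",
    }
    # Fail early, listing every missing variable at once.
    missing = sorted(var for var in required.values() if not env.get(var))
    if missing:
        raise KeyError("Missing environment variables: " + ", ".join(missing))

    configs = {key: env[var] for key, var in required.items()}
    # Defaults mirror the object names used throughout this notebook.
    for key in ("database", "schema", "warehouse"):
        configs[key] = env.get(f"SNOWFLAKE_{key.upper()}", "graph_rag")
    configs["role"] = env.get("SNOWFLAKE_ROLE", "accountadmin")
    return configs
```

The resulting dictionary could then be passed to `Session.builder.configs(...)` in place of the literal values used in the next cell.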

In [None]:
session = Session.builder.configs({
    "user": "<your_snowflake_username>", 
    "password": "<your_snowflake_password>", 
    "account": "<your_snowflake_account_identifier>", 
    "database": "graph_rag", 
    "schema": "graph_rag", 
    "role": "accountadmin", 
    "warehouse": "graph_rag"
}).create()

In [None]:
session.use_role("accountadmin")
session.use_database("graph_rag")
session.use_schema("graph_rag")
session.use_warehouse("graph_rag")

## Creating a graph with the RelationalAI Native App

### Requirements

- The [RelationalAI Native App](https://app.snowflake.com/marketplace/listing/GZTYZOOIX8H/relationalai-relationalai?search=relationalai&originTab=provider&providerName=RelationalAI&profileGlobalName=GZTYZOOIX7W) **must have been installed** in the Snowflake account being used.
- The [RelationalAI CLI](https://relational.ai/docs/reference/cli/) **must have been installed**.
- A `raiconfig.toml` must be present in the project root. To create a `raiconfig.toml`, execute `rai init` in the project root and provide the required information. More information is available on the [RelationalAI CLI reference](https://relational.ai/docs/reference/cli/) page.

The required RelationalAI resources can be created either with the CLI or, for convenience, with the functions provided in this notebook.

### Manual resource provisioning

- Provision a RAI engine by executing `rai engines:create` in the project root and providing the required information.
  
- Set up the Snowflake - RelationalAI Data Streams:
    - `rai imports:stream --source graph_rag.graph_rag.nodes --model graph_rag`
    - `rai imports:stream --source graph_rag.graph_rag.edges --model graph_rag`

Having submitted the streams, we should wait for the data to synchronize, that is, until the status of both streams is `LOADED`. The status of the streams can be checked with the following CLI command:

- `rai imports:list --model graph_rag`
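
The waiting step can be automated with a simple polling loop. The sketch below is illustrative only: `fetch_statuses` is a hypothetical callable that returns one status string per stream (for example, by parsing the output of `rai imports:list`), and the timing parameters are arbitrary defaults.

```python
import time


def wait_until_loaded(fetch_statuses, expected=2, poll_seconds=5.0, timeout_seconds=600.0):
    """Poll until at least `expected` streams report LOADED, or raise on timeout.

    `fetch_statuses` is a hypothetical callable returning a list of status
    strings, one per stream.
    """
    deadline = time.monotonic() + timeout_seconds
    while True:
        statuses = fetch_statuses()
        loaded = sum(1 for status in statuses if status == "LOADED")
        if loaded >= expected:
            return statuses
        if time.monotonic() >= deadline:
            raise TimeoutError(f"Only {loaded} of {expected} streams are LOADED")
        time.sleep(poll_seconds)
```

Passing the status source as a callable keeps the loop independent of how the statuses are obtained, which also makes it easy to exercise with a stub.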

### Automatic resource provisioning

#### Provisioning a RelationalAI engine

The following convenience function wraps the `rai engines:create` CLI command.

Note: the `engine_size` and `engine_pool` values can be found by running `rai engines:create` manually and taking note of the `Engine size` and `Compute pool` prompt options.

In [None]:
%%time

# Note: this process takes a few minutes to complete if a new engine must be provisioned.
output = create_engine(
    config={
        "engine": "graph_rag", 
        "engine_size": "<your_engine_size>", 
        "engine_pool": "<your_engine_pool>"
    }
)

#### Provisioning Data Streams

The following convenience function wraps the `rai imports:stream --source graph_rag.graph_rag.nodes --model graph_rag` and `rai imports:stream --source graph_rag.graph_rag.edges --model graph_rag` CLI commands to create the Data Streams. It then blocks until both Data Streams (`2` in total) reach the `LOADED` status, which is required before we can proceed.

In [None]:
%%time

output = setup_cdc_and_wait(
    config={
        "database": "graph_rag", 
        "schema": "graph_rag", 
        "model_name": "graph_rag"
    }
)

### Creating a RelationalAI model

A RelationalAI model is the abstraction that defines the data structures for representing the graph. The model is tightly coupled to the data synchronized through the Data Streams already provisioned.

To proceed, we shall first define a model named `graph_rag`. Subsequently, we shall define the `Entity` and `Relation` model types. Observe how the sources of data for these types are the `nodes` and `edges` tables, respectively.
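
As a rough analogy for how each `Relation` is linked to its endpoint entities in the cell that follows, `src_node_id` and `dst_node_id` behave like foreign keys into the `id` column of `nodes`. The sketch below reproduces that lookup with pandas on made-up rows. The sample data and the `resolve_endpoints` helper are hypothetical, and the code does not use RelationalAI.

```python
import pandas as pd

# Made-up rows shaped like the `nodes` and `edges` tables.
nodes = pd.DataFrame({
    "id": ["A", "B", "C"],
    "type": ["PERSON", "PERSON", "ORG"],
})
edges = pd.DataFrame({
    "src_node_id": ["A", "B"],
    "dst_node_id": ["C", "C"],
    "type": ["FOUNDED", "FUNDED"],
})


def resolve_endpoints(edges, nodes):
    """Attach source and destination node attributes to every edge."""
    src = nodes.add_prefix("src_")  # columns: src_id, src_type
    dst = nodes.add_prefix("dst_")  # columns: dst_id, dst_type
    return (
        edges
        .merge(src, left_on="src_node_id", right_on="src_id")
        .merge(dst, left_on="dst_node_id", right_on="dst_id")
    )


resolved = resolve_endpoints(edges, nodes)
print(resolved[["src_id", "src_type", "type", "dst_id", "dst_type"]])
```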

In [None]:
%%time

# Defining a RelationalAI model named `graph_rag`.
rai_model = rai.Model("graph_rag", dry_run=False)
snowflake_model = Snowflake(rai_model)

# Defining an `Entity` type. An `Entity` is a node in the graph.
Entity = rai_model.Type("Entity", source="graph_rag.graph_rag.nodes")

# Defining a `Relation` type. A `Relation` is an edge in the graph.
# Note how we attach `src` and `dst` properties to the `Relation`, 
# indicating the source and destination entities, respectively.
Relation = rai_model.Type("Relation", source="graph_rag.graph_rag.edges")
Relation.define(
    src=(Entity, "src_node_id", "id"), 
    dst=(Entity, "dst_node_id", "id")
)

#### (Optional) Exploring the RelationalAI model data

Having created our model, let's take a look at how we can inspect the properties of the types created. Also, we shall fetch some data to inspect how they are being represented in our model:

In [None]:
%%time

# Get the list of properties of each defined type.
print(
    f"""
    Entity.known_properties(): {Entity.known_properties()}
    Relation.known_properties(): {Relation.known_properties()}
    """
)

Querying our model to retrieve all `Entity` instances:

In [None]:
%%time

with rai_model.query() as select:
    entity = Entity()
    response = select.distinct(entity, entity.id, entity.type)
model_entities = response.results

In [None]:
model_entities.head(3)

We can also count how many `Entity` instances we have in our model:

In [None]:
%%time

with rai_model.query() as select:
    entity = Entity()
    response = select(aggregates.count(entity))
print("Entity count:", response.results)

Repeating the same for `Relation` instances:

In [None]:
%%time

with rai_model.query() as select:
    relation = Relation()
    response = select.distinct(relation, relation.src_node_id, relation.dst_node_id, relation.type)
model_relations = response.results

In [None]:
model_relations.head(3)

In [None]:
%%time

with rai_model.query() as select:
    relation = Relation()
    response = select(aggregates.count(relation))
print("Relation count:", response.results)

### Creating the graph, computing community identifiers and visualizing the graph

We shall now use the RelationalAI model we have defined previously and get a graph out of it.

First, we shall define a `Graph` data structure out of our RelationalAI model.

Subsequently, we shall execute the [Louvain](https://relational.ai/docs/reference/python/std/graphs/Compute/louvain/) community detection algorithm on the graph to identify communities.

Finally, we shall visualize the graph.
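
For intuition about what Louvain optimizes: the algorithm searches for a partition of the nodes with high modularity, a score that rewards communities with more internal edges than expected by chance. The following is a minimal pure-Python sketch of the modularity score for an undirected, unweighted graph. It is not RelationalAI's implementation, and the toy graph and the `modularity` helper are made up for illustration.

```python
from collections import defaultdict


def modularity(edges, community_of):
    """Modularity of a partition of an undirected, unweighted graph.

    Q = sum over communities c of  L_c / m - (d_c / (2 * m)) ** 2
    where m is the number of edges, L_c the number of edges inside c,
    and d_c the sum of the degrees of the nodes in c.
    """
    m = len(edges)
    internal = defaultdict(int)
    degree_sum = defaultdict(int)
    for u, v in edges:
        cu, cv = community_of[u], community_of[v]
        degree_sum[cu] += 1
        degree_sum[cv] += 1
        if cu == cv:
            internal[cu] += 1
    return sum(internal[c] / m - (degree_sum[c] / (2 * m)) ** 2 for c in degree_sum)


# Two triangles joined by a single bridge edge (2, 3).
toy_edges = [(0, 1), (1, 2), (0, 2), (3, 4), (4, 5), (3, 5), (2, 3)]
two_communities = {0: "a", 1: "a", 2: "a", 3: "b", 4: "b", 5: "b"}
one_community = {node: "a" for node in range(6)}

print(round(modularity(toy_edges, two_communities), 3))  # 0.357
print(modularity(toy_edges, one_community))              # 0.0
```

Splitting the toy graph along the bridge scores higher than lumping every node together, which is the kind of partition a modularity-based method prefers.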

In [None]:
community_color_map = get_random_color_map(communities_length=200)

In [None]:
%%time

# Creating a graph representation of our model.
graph = Graph(model=rai_model, undirected=True)

# Applying the Louvain community detection on the model.
with rai_model.rule():
    entity = Entity()
    community_id = graph.compute.louvain(node=entity, max_levels=5, max_sweeps=10, level_tolerance=1e-2, sweep_tolerance=1e-4)
    entity.set(community_id=community_id)

with rai_model.rule():
    relation = Relation()
    graph.Node.extend(Entity, id=Entity.id, corpus_id=str(Entity.corpus_id), type=Entity.type, community_id=Entity.community_id)
    graph.Edge.add(from_=relation.src, to=relation.dst, corpus_id=str(relation.corpus_id), type=relation.type) # label=concat(relation.src.id, relation.dst.id)

Note that the rules above only define the computations, which are applied lazily, when a query needs their results.

For example, let's see how many communities we were able to identify:

In [None]:
%%time

with rai_model.query() as select:
    entity = Entity()
    response = select(aggregates.count(entity.community_id))
print("Communities count:", response.results)

Let's also visualize the graph, making sure that each community has a unique color:

In [None]:
vis = graph.visualize(
    three=False, 
    graph_height=750, 
    show_node_label=True, 
    show_edge_label=True, 
    layout_algorithm_active = True, 
    layout_algorithm = "hierarchicalRepulsion", 
    avoid_overlap = 1.0,
    style={
        "node": {
            "label": lambda n: f"{n.get('id')} ({n.get('type')})", 
            "color": lambda n: community_color_map.get(n["community_id"], "black"), 
            "size": 30,
            "border_color": "white",
            "border_size": 1,
            "hover": lambda n: f"{n.get('id')} (type: {n.get('type')}, community: {n.get('community_id')})"
        }, 
        "edge": {
            "label": lambda e: e.get("type"), 
            "color": "grey", 
            "hover": lambda e: e.get("type")
        }
    }
)

vis.display(inline=True)

## Community-based summarization

Having identified graph communities, we shall now produce summaries out of the text of all corpus items in a community.
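
The grouping step used later in this section (from per-entity rows to a mapping of community to corpus IDs) can be sketched on toy data as follows. The rows and the `corpus_ids_by_community` helper are hypothetical; the notebook's own cells achieve the same result with a `groupby` and a multi-index.

```python
import pandas as pd


def corpus_ids_by_community(df):
    """Map each community_id to the sorted, de-duplicated corpus_ids it covers."""
    pairs = (
        df[["community_id", "corpus_id"]]
        .drop_duplicates()
        .sort_values(["community_id", "corpus_id"])
    )
    mapping = {}
    # tolist() converts NumPy scalars to plain Python values.
    for community_id, corpus_id in pairs.to_numpy().tolist():
        mapping.setdefault(community_id, []).append(corpus_id)
    return mapping


# Made-up query results: one row per entity.
toy_results = pd.DataFrame({
    "id": ["A", "B", "C", "D"],
    "community_id": [1, 1, 1, 2],
    "corpus_id": [10, 10, 11, 12],
})

mapping = corpus_ids_by_community(toy_results)
print(mapping)  # {1: [10, 11], 2: [12]}

# The comma-separated form used to build an SQL `IN (...)` list.
print(", ".join(str(i) for i in mapping[1]))  # 10, 11
```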

In [None]:
%%time

with rai_model.query() as select:
    entity = Entity()
    response = select(alias(entity.id, "id"), alias(entity.community_id, "community_id"), alias(entity.corpus_id, "corpus_id"))

In [None]:
results_df = response.results
results_df.sort_values(by=["community_id", "corpus_id"]).head(3)

In [None]:
# Getting a (community, corpus-id) multi-index.
communities_count_df = results_df.groupby(by=["community_id", "corpus_id"]).count().rename(columns={"id": "entities_count"}).sort_index()

# Convert the multi index to a dict.
index = communities_count_df.index.to_flat_index()
d = {}
for x, y in index:
    d.setdefault(x, []).append(y)

For each community, we produce a summary of all its corpus items using a Snowflake LLM:

In [None]:
%%time

execute_statement(
    session=session, 
    statement="TRUNCATE TABLE community_summary"
)

# Summarize all corpus items of a community.
for k, v in d.items():
    corpus_ids = ", ".join([str(i) for i in v])
    logger.info(f"Producing summarized versions of IDs ({corpus_ids}) for community {k}")
    try:
        execute_statement(
            session=session, 
            statement="""
                INSERT INTO community_summary(COMMUNITY_ID, CONTENT)
                WITH c AS (
                    SELECT 
                        LISTAGG(content, '\n\n') WITHIN GROUP(ORDER BY id) AS content 
                    FROM 
                        CORPUS
                    WHERE 
                        id IN ({CORPUS_IDS})
                )
                SELECT 
                    {COMMUNITY_ID} AS community_id
                    , PARSE_JSON(LLM_EXTRACT_JSON(r.response)):answer AS response
                FROM 
                    c
                JOIN TABLE(LLM_SUMMARIZE('llama3-70b', c.content)) AS r;
            """, 
            parameters={
                "COMMUNITY_ID": str(k), 
                "CORPUS_IDS": corpus_ids
            }
        )
    except Exception as error:
        # Log the underlying error as well, so failures can be diagnosed.
        logger.error(f"Error producing summarized versions of IDs ({corpus_ids}) for community {k}: {error}")

## Question answering

In [None]:
question = "Describe in detail the connection between Samuel Altman and Elon Musk, if one exists."

### Querying with community summaries as context

We concatenate community summaries into a context window and ask the LLM whether the question can be answered from the evidence in that window.

LLM calls: `#community summaries / window`
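
The call count above can be illustrated with a small batching sketch. The implementation of `LLM_ANSWER_SUMMARIES` is not shown in this notebook, so the code below only demonstrates the arithmetic, assuming that summaries are grouped into consecutive windows and that a final partial window still costs one call. The `window_summaries` helper and the placeholder summaries are hypothetical.

```python
import math


def window_summaries(summaries, window):
    """Split summaries into consecutive batches of at most `window` items."""
    if window < 1:
        raise ValueError("window must be at least 1")
    return [summaries[i:i + window] for i in range(0, len(summaries), window)]


# 65 placeholder summaries with a window of 30 give batches of 30, 30 and 5.
summaries = [f"summary {i}" for i in range(65)]
batches = window_summaries(summaries, window=30)

print([len(batch) for batch in batches])  # [30, 30, 5]
print(math.ceil(len(summaries) / 30))     # 3
```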

In [None]:
%%time

# The second-to-last parameter of the procedure call is the summarization window, i.e. how many per-community summaries to include as context for the answer.
# A smaller value helps ensure that we do not exceed the LLM token limit, whereas a larger one provides richer context to the LLM.
execute_statement(
    session=session, 
    statement="""
        CALL LLM_ANSWER_SUMMARIES('llama3-70b', 30, '{QUESTION}');
    """, 
    parameters={
        "QUESTION": question
    }
)

# Gather results.
answer = execute_statement(
    session=session, 
    statement="""
        SELECT 
            * 
        FROM 
            TABLE(result_scan(last_query_id()));
    """
)[0][0]

print(f"""
    Q: {question}
    A: {answer}
""")

#### Visualizing the graph again

This time, we are highlighting the entities mentioned in the answer to validate against the graph.

In [None]:
entities_of_interest = ["Samuel Harris Altman", "Elon Musk", "OpenAI"]

vis = graph.visualize(
    three=False, 
    graph_height=1000, 
    show_node_label=True, 
    show_edge_label=True, 
    layout_algorithm_active = True, 
    layout_algorithm = "hierarchicalRepulsion", 
    avoid_overlap = 1.0,
    style={
        "node": {
            "label": lambda n: f"{n.get('id')} ({n.get('type')})", 
            "color": lambda n: community_color_map.get(n["community_id"], "black"), 
            "size": lambda n: 60 if n.get('id') in entities_of_interest else 30,
            "border_color":  lambda n: "black" if n.get('id') in entities_of_interest else "white",
            "border_size": lambda n: 3 if n.get('id') in entities_of_interest else 1,
            "hover": lambda n: f"{n.get('id')} (type: {n.get('type')}, community: {n.get('community_id')})"
        }, 
        "edge": {
            "label": lambda e: e.get("type"), 
            "color": "grey", 
            "hover": lambda e: e.get("type")
        }
    }
)

vis.display(inline=True)

#### Verify that the entities are connected in the graph

Testing the reachability between entities `Elon Musk` and `Samuel Harris Altman` by querying the model.
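
For intuition about what a reachability check computes, here is a minimal breadth-first search over an undirected edge list. It is a sketch that is independent of RelationalAI's `is_reachable`; the `reachable_bfs` helper and the toy edges are made up for illustration.

```python
from collections import defaultdict, deque


def reachable_bfs(edges, source, target):
    """Return True if `target` can be reached from `source` in an undirected graph."""
    adjacency = defaultdict(set)
    for u, v in edges:
        adjacency[u].add(v)
        adjacency[v].add(u)

    seen = {source}
    queue = deque([source])
    while queue:
        node = queue.popleft()
        if node == target:
            return True
        # Visit each unseen neighbor exactly once.
        for neighbor in adjacency[node] - seen:
            seen.add(neighbor)
            queue.append(neighbor)
    return False


toy_edges = [("A", "B"), ("B", "C"), ("D", "E")]
print(reachable_bfs(toy_edges, "A", "C"))  # True
print(reachable_bfs(toy_edges, "A", "E"))  # False
```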

In [None]:
# Can Elon reach Sam?
with rai_model.query() as select:
    entity_1 = Entity(id="Elon Musk")
    entity_2 = Entity(id="Samuel Harris Altman")
    with rai_model.match() as reachable:
        with rai_model.case():
            graph.compute.is_reachable(entity_1, entity_2)
            reachable.add(True)
        with rai_model.case():
            reachable.add(False)
    response = select(alias(reachable, "connected"))
print(f"Are Elon and Sam connected? {response.results['connected'][0]}")

---