# Building a knowledge graph with an LLM

This notebook shows how to build a knowledge graph from unstructured data using a large language model (LLM). The approach is useful when you have a lot of unstructured text, such as meeting notes or short articles, and want to automatically surface the relationships between the concepts it mentions.

Our approach starts by extracting a list of nodes and edges using Anthropic's Claude 3 model via Amazon Bedrock. We store the resulting nodes and edges in Amazon Neptune, a graph database. Then we can use the typical set of graph visualizations and queries to understand the data.

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: MIT-0

## Load local configuration

Create the file `config.yml` and then add settings for your Neptune graph ID and AWS region. For example:

    aws:
        region: us-east-1
    neptune:
        graphid: your_neptune_analytics_graph_id

You should not include `config.yml` in your version control. If you use Git, add it to your `.gitignore` file.
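
A small check like the following can catch a missing or misnamed setting before any AWS call is made. This is a minimal sketch: `check_config` is a hypothetical helper, and the required keys are inferred from how `config` is read later in this notebook (`config['aws']['region']` and `config['neptune']['graphid']`).

```python
# Hypothetical helper: verify that the settings read later in this notebook exist.
REQUIRED_KEYS = [("aws", "region"), ("neptune", "graphid")]

def check_config(cfg):
    """Return a list of missing 'section.key' names (empty if all are present)."""
    missing = []
    for section, key in REQUIRED_KEYS:
        if not isinstance(cfg.get(section), dict) or key not in cfg[section]:
            missing.append(f"{section}.{key}")
    return missing

# Made-up config with a misnamed key, for illustration only
example = {"aws": {"region": "us-east-1"}, "neptune": {"graph": "g-123"}}
print(check_config(example))
```

Running the check right after loading `config.yml` turns a later `KeyError` inside a query cell into an early, readable message.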

In [None]:
import yaml

# Read the local settings (AWS region and Neptune graph ID)
with open("config.yml") as f:
    config = yaml.safe_load(f)

In [None]:
%opencypher_status

## Install dependencies and load data

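We'll install the `datasets` and `neo4j` libraries along with the AWS SDK. The sample documents are then read from a local, xz-compressed TSV file (`kleister-nda/train/in.tsv.xz`), taking the third tab-separated field of each line as the document text.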

In [None]:
%pip install --upgrade --quiet boto3 botocore langchain datasets neo4j

In [None]:
import lzma

lines = []
with lzma.open('/home/ec2-user/SageMaker/kleister-nda/train/in.tsv.xz', mode='rt', encoding='utf-8') as fid:
    for line in fid:
        fields = line.split('\t')
        lines.append(fields[2])
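
The loading step above assumes every row has at least three tab-separated fields. A slightly more defensive variant is sketched below; `read_column` is a hypothetical helper, and the sample file is made up so the example can run without the real dataset.

```python
import lzma
import os
import tempfile

def read_column(path, column=2):
    """Read one tab-separated column from an xz-compressed TSV file, skipping short rows."""
    values = []
    with lzma.open(path, mode="rt", encoding="utf-8") as fid:
        for line in fid:
            fields = line.rstrip("\n").split("\t")
            if len(fields) > column:
                values.append(fields[column])
    return values

# Tiny made-up file to exercise the function
with tempfile.TemporaryDirectory() as tmp:
    sample_path = os.path.join(tmp, "sample.tsv.xz")
    with lzma.open(sample_path, mode="wt", encoding="utf-8") as out:
        out.write("a.pdf\tkeys\tfirst document text\n")
        out.write("short\trow\n")
        out.write("b.pdf\tkeys\tsecond document text\n")
    sample = read_column(sample_path)

print(sample)
```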

## Bedrock setup

Here we'll define helper functions for both Claude and Meta's Llama 2 model. These include functions that invoke the models for regular chat, and functions whose prompts are designed for node and edge extraction.

In [None]:
import boto3
import json

In [None]:
llamaModelId = 'meta.llama2-70b-chat-v1' 
bedrock_runtime = boto3.client(
    service_name='bedrock-runtime', 
    region_name=config['aws']['region']
)

def call_llama(query):

    prompt = f"[INST]{query}[/INST]"
    llamaPayload = json.dumps({
        'prompt': prompt,
        'max_gen_len': 512,
        'top_p': 0.9,
        'temperature': 0.2
    })

    response = bedrock_runtime.invoke_model(
        body=llamaPayload, 
        modelId=llamaModelId, 
        accept='application/json', 
        contentType='application/json'
    )

    body = response.get('body').read().decode('utf-8')
    response_body = json.loads(body)
    return response_body['generation'].strip()

In [None]:
call_llama("Tell me a story about Mars")

In [None]:
claudeModelId = 'anthropic.claude-3-sonnet-20240229-v1:0' 

def call_claude(query):

    claudePayload = json.dumps({ 
        "anthropic_version": "bedrock-2023-05-31",
        'max_tokens': 2048,
        "messages": [
          {
            "role": "user",
            "content": [
              {
                "type": "text",
                "text": query
              }
            ]
          }
        ]
    })
    

    response = bedrock_runtime.invoke_model(
        body=claudePayload, 
        modelId=claudeModelId, 
        accept='application/json', 
        contentType='application/json'
    )

    body = response.get('body').read().decode('utf-8')

    response_body = json.loads(body)
    return response_body['content'][0]['text']

In [None]:
call_claude("Tell me a story about Mars")

In [None]:
def call_llama_kg(query):

    prompt = """[INST]You are a robot that extracts information from financial news to build a knowledge graph. You only output JSON. Nodes represent entities, like a company.  Edges represent the relationships between nodes, like the fact that a person is the CEO of a company. When extracting nodes, it's vital to ensure consistency. If a node, such as "Acme Corp", is mentioned multiple times in the text but is referred to by different names (e.g., "Acme"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "Acme Corp" as the node ID. 

Example input: "John Doe was recently named the CEO of Acme Corp."
Example output: 

{
"nodes": [
   {
        "label": "person",
        "id": "John Doe",
        "firstName": "john",
        "lastName": "doe"
    },
    {
        "label": "company",
        "id": "Acme Corp"
    }
],
"edges": [
    {
        "label": "executive",
        "id": "e-john-doe-acme-corp",
        "node1": "John Doe",
        "node2": "Acme Corp"
    }
]
}

Use the given format to extract information from the following input, responding only with JSON and no extra text:
"""
    
    llamaPayload = json.dumps({
        'prompt': prompt + query + "[/INST]",
        'max_gen_len': 2048,
        'top_p': 0.9,
        'temperature': 0.2
    })

    response = bedrock_runtime.invoke_model(
        body=llamaPayload, 
        modelId=llamaModelId, 
        accept='application/json', 
        contentType='application/json'
    )

    body = response.get('body').read().decode('utf-8')
    response_body = json.loads(body)
    return response_body['generation'].strip()

In [None]:
def format_llama_kg(j):
    c = j.replace("\n", "").replace("\t", "")
    idx = c.find('{')
    return json.loads(c[idx:])
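
`format_llama_kg` finds the first `{` and parses everything after it, which fails if the model appends a closing remark after the JSON. One possible alternative, sketched here with a hypothetical `extract_first_json` helper, uses the standard library's `json.JSONDecoder.raw_decode`, which stops at the end of the first complete JSON value. The model reply is made up for illustration.

```python
import json

def extract_first_json(text):
    """Parse the first JSON object found in text, ignoring anything before or after it."""
    start = text.find("{")
    if start == -1:
        raise ValueError("No JSON object found in model output")
    # raw_decode returns the parsed value and the index where parsing stopped
    obj, _end = json.JSONDecoder().raw_decode(text, start)
    return obj

# Made-up reply with chatter on both sides of the JSON
reply = 'Sure! Here is the graph:\n{"nodes": [{"id": "Acme Corp"}], "edges": []}\nLet me know if you need more.'
print(extract_first_json(reply))
```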

In [None]:
def call_claude_kg(query):
    
    prompt_template = """

Below is an article from a financial news source. Your job is to extract nodes and edges to build a knowledge graph. A node is an entity like a company. An edge is a relationship between two nodes, like "John Smith is the CEO of Acme Corp". When extracting nodes, it's vital to ensure consistency. If a node, such as "Acme Corp", is mentioned multiple times in the text but is referred to by different names (e.g., "Acme"), always use the most complete identifier for that entity throughout the knowledge graph. In this example, use "Acme Corp" as the node ID. Use snake case for the node id, like "acme_corp" instead of "Acme Corp". If you find additional information, add it as a property on the node or edge. For example, if Acme Corp is a mining company, you can add a property "industry" set to "mining". 

Each node should have at least an `id` field and a `type` field. The `id` is the unique identifier, and the `type` is the type of entity, like 'company' or 'executive'. You can include other properties if you find them.

Example output:

<json>
{
  "nodes": [
      {
          "id": "acme_corp",
          "type": "company",
          "name": "Acme Corp",
          "industry": "chemicals"
      },
      {
          "id": "john_doe",
          "type": "executive",
          "name": "John Doe"
      }
  ],
  "edges": [
      {
          "source": "acme_corp",
          "target": "john_doe",
          "type": "employee",
          "employee_type": "CEO"
      }
  ]
}
</json>

<article>
ARTICLE_HERE
</article>

You must output only valid JSON. Be concise - do not provide any extra text before or after the JSON.
"""

    prompt = prompt_template.replace("ARTICLE_HERE", query)
    
    claudePayload = json.dumps({ 
        "anthropic_version": "bedrock-2023-05-31",
        'max_tokens': 2048,
        "messages": [
          {
            "role": "user",
            "content": [
              {
                "type": "text",
                "text": prompt
              }
            ]
          }
        ]
    })
    
    
    response = bedrock_runtime.invoke_model(
        body=claudePayload, 
        modelId=claudeModelId, 
        accept='application/json', 
        contentType='application/json'
    )

    body = response.get('body').read().decode('utf-8')

    response_body = json.loads(body)
    return response_body['content'][0]['text']

In [None]:
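def format_claude_kg(j):
    # Claude may wrap the JSON in <json> tags, in a ```json fence, or return it bare
    if '<json>' in j:
        idx1 = j.find('<json>')
        idx2 = j.find('</json>')
        s = j[idx1+6:idx2] if idx2 != -1 else j[idx1+6:]
    elif '```json' in j:
        idx1 = j.find('```json')
        idx2 = j.rfind('```')
        s = j[idx1+7:idx2]
    else:
        # The prompt asks for JSON only, so try the raw text as a last resort
        s = j
    try:
        return json.loads(s)
    except json.JSONDecodeError as err:
        raise ValueError("Unknown Claude response format") from err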
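
Parsing succeeds for any valid JSON, even when it does not follow the structure the prompt asks for. A light structural check, sketched below, can flag nodes without an `id` or `type` and edges that point at unknown nodes before anything is written to the graph. `validate_kg` is a hypothetical helper and the sample graph is made up; the field names come from the prompt above.

```python
def validate_kg(kg):
    """Return a list of human-readable problems found in an extracted graph."""
    problems = []
    node_ids = set()
    for i, node in enumerate(kg.get("nodes", [])):
        for field in ("id", "type"):
            if field not in node:
                problems.append(f"node {i} is missing '{field}'")
        if "id" in node:
            node_ids.add(node["id"])
    for i, edge in enumerate(kg.get("edges", [])):
        for field in ("source", "target", "type"):
            if field not in edge:
                problems.append(f"edge {i} is missing '{field}'")
        for end in ("source", "target"):
            if end in edge and edge[end] not in node_ids:
                problems.append(f"edge {i} references unknown node '{edge[end]}'")
    return problems

# Made-up extraction result with two deliberate problems
sample_kg = {
    "nodes": [{"id": "acme_corp", "type": "company"}, {"id": "john_doe"}],
    "edges": [{"source": "acme_corp", "target": "jane_roe", "type": "employee"}],
}
for problem in validate_kg(sample_kg):
    print(problem)
```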

In [None]:
embedModelId = 'amazon.titan-embed-text-v1' 

def call_embed(query):

    accept = 'application/json' 
    content_type = 'application/json'
    body = json.dumps({
        "inputText": query,
    })

    # Invoke model 
    response = bedrock_runtime.invoke_model(
        body=body, 
        modelId=embedModelId, 
        accept=accept, 
        contentType=content_type
    )

    # Parse the response and pull out the embedding vector
    response_body = json.loads(response['body'].read())
    embedding = response_body.get('embedding')

    return embedding

In [None]:
call_embed("example text")

## Node and edge extraction

Let's look at a single article and test our extraction methods.

In [None]:
text = lines[0]

In [None]:
text

In [None]:
j = call_llama_kg(text)

In [None]:
j

In [None]:
j = call_claude_kg(text)

In [None]:
j

In [None]:
print(format_claude_kg(j))

### Neptune

Let's check connectivity to the Neptune Analytics graph and then try a few openCypher queries through the boto3 `neptune-graph` client.

In [None]:
import boto3
graph_client = boto3.client('neptune-graph')

In [None]:
response = graph_client.execute_query(
    graphIdentifier=config['neptune']['graphid'],
    queryString='MATCH (p:company) RETURN p.name AS name',
    language='OPEN_CYPHER',
)

In [None]:
response['payload'].read().decode('utf-8')

In [None]:
def query_graph(graph_client, query):
    response = graph_client.execute_query(
        graphIdentifier=config['neptune']['graphid'],
        queryString=query,
        language='OPEN_CYPHER',
    )
    return json.loads(response['payload'].read().decode('utf-8'))

In [None]:
query_graph(graph_client, 'MATCH (p:company) RETURN p.name AS name')

In [None]:
def update_graph(graph_client, query):
    graph_client.execute_query(
        graphIdentifier=config['neptune']['graphid'],
        queryString=query,
        language='OPEN_CYPHER',
    )

### Process a few articles

Here we'll process the first five articles in the dataset.

In [None]:
article_indices = [0,1,2,3,4]

In [None]:
max_articles = len(lines)
max_articles

In [None]:
def sanitize_prop(s):
    t = str(s)
    return t.replace('"', '').replace("'", "").replace('[', '').replace(']', '')

def insert_node(nid, nlabel, nprops, graph_client):
    propstr = []
    for p in nprops.keys():
        propstr.append(f"{p}: '{sanitize_prop(nprops[p])}'")
    q = "MERGE (:" + nlabel + " {" + ",".join(propstr) + "})"
    print(f"Query: {q}")
    update_graph(graph_client, q)
    
def insert_edge(elabel, en1, en2, et1, et2, eprops, graph_client):
    eprops['name'] = elabel
    propstr = []
    for p in eprops.keys():
        propstr.append(f"{p}: '{sanitize_prop(eprops[p])}'")
    print(f"eprops: {json.dumps(eprops)}")
    q = "MATCH (" + en1 + ":" + et1 + " {name: '" + en1 + "'}), (" + en2 + ":" + et2 + " {name: '" + en2 + "'}) " + \
        "CREATE (" + en1 + ")-[:" + elabel+ " {" + ",".join(propstr) + "}]->(" + en2 + ")"
    print(f"Query: {q}")
    update_graph(graph_client, q)
    
def add_embedding(nid, embedding, graph_client):
    q = "MATCH (n) WHERE n.name in ['" + nid + "'] CALL neptune.algo.vectors.upsert(n, " + str(embedding) + ") " + \
        "YIELD node, embedding, success RETURN node, embedding, success"
    update_graph(graph_client, q)

def process_article(a, text_embed, graph_client):
    n = a['nodes']
    e = a['edges']
    n_types = []
    e_types = []
    id_label_map = {}
    
    print(f"Processing nodes: {len(n)}")
    for node in n:
        try:
            nid = node['id']
            nlabel = node['type']
            n_types.append(nlabel)

            nprops = {}
            nprops['name'] = nid
            for k in node.keys():
                if k in ['id', 'type', 'name']:
                    continue
                else:
                    nprops[k] = node[k]
            if 'name' in node:
                nprops['nname'] = node['name']

            insert_node(nid, nlabel, nprops, graph_client)
            add_embedding(nid, text_embed, graph_client)
            id_label_map[nid] = nlabel
        except Exception as ee: 
            print(f"Unable to process node {node} - {ee}")
    print(f"Processing edges: {len(e)}")
    for edge in e:
        try:
            elabel = edge['type']
            e_types.append(elabel)
            en1 = edge['source']
            en2 = edge['target']
            et1 = id_label_map[en1]
            et2 = id_label_map[en2]

            eprops = {}
            for k in edge.keys():
                if k in ['source', 'type', 'target']:
                    continue
                else:
                    eprops[k] = edge[k]

            insert_edge(elabel, en1, en2, et1, et2, eprops, graph_client)
        except Exception as ee: 
            print(f"Unable to process edge {edge} - {ee}")
          
    return n_types, e_types
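
The insert helpers above build queries by string concatenation and rely on `sanitize_prop` to strip quotes. An alternative is to keep values out of the query text by using openCypher parameters. The sketch below only builds the query and its parameter map; `build_merge_node` is a hypothetical helper. We believe the boto3 `execute_query` call accepts a `parameters` argument, and that `SET n += $props` is supported, but both are assumptions to verify against your Neptune Analytics version.

```python
import re

# Labels cannot be passed as parameters, so restrict them to a safe pattern instead.
LABEL_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")

def build_merge_node(label, name, props):
    """Return (query, parameters) for a MERGE that keeps values out of the query text."""
    if not LABEL_RE.match(label):
        raise ValueError(f"Unsafe node label: {label!r}")
    query = f"MERGE (n:{label} {{name: $name}}) SET n += $props"
    return query, {"name": name, "props": props}

# Made-up node; note the apostrophe survives because it never enters the query string
query, params = build_merge_node("company", "acme_corp", {"industry": "O'Neil mining"})
print(query)
print(params)
```

With this approach, property values no longer need to have their quotes and brackets removed before being stored.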
            

In [None]:
for adx in article_indices:
    print(f"Article number {adx}")
    text = lines[adx]
    text_embed = call_embed(text)
    raw = call_claude_kg(text)
    print(f"Got Claude answer: {raw}")
    answer = format_claude_kg(raw)
    print(f"Claude JSON: {json.dumps(answer)}")
    process_article(answer, text_embed, graph_client)

In [None]:
%%oc

MATCH (n) 
WHERE n.name in ['albitar_oncology_consulting']
CALL neptune.algo.vectors.get(n)
YIELD node, embedding
RETURN n.name, embedding

## Explore the data

Now we can use regular Neptune queries to explore and visualize the data. For example, let's say we have a company named `albitar_oncology_consulting`. First we can make sure this company is in the graph.

In [None]:
%%oc

MATCH (a:company {name: 'albitar_oncology_consulting'}) RETURN a

Next we can run a Cypher query to show this company and all its relationships.

In [None]:
%%oc

MATCH (n {name: 'albitar_oncology_consulting'}) 
MATCH (n)-[r]-(m)
RETURN n,r, m

This Gremlin query is similar but will label each node and edge with a more descriptive label.

## Graph RAG

A more sophisticated way to use the graph is to retrieve context from it when answering a question:

* First, create an embedding of the question.
* Second, query the graph for related nodes using vector search.
* Third, extract a subgraph that includes the related nodes to a certain depth.
* Finally, include the subgraph as context in the prompt that generates the response.
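
To make the steps concrete, here is a toy in-memory version that runs without Bedrock or Neptune. Everything in it is made up for illustration: the three-dimensional vectors stand in for real embeddings, the `edges` list stands in for the graph, and cosine similarity is used as the ranking score (the similarity measure used by Neptune's vector search may differ).

```python
import math

def cosine(a, b):
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / norm if norm else 0.0

# Made-up embeddings and edges, for illustration only
node_vectors = {
    "acme_corp": [1.0, 0.0, 0.0],
    "john_doe": [0.9, 0.1, 0.0],
    "globex": [0.0, 1.0, 0.0],
}
edges = [("acme_corp", "employee", "john_doe"), ("globex", "competitor", "acme_corp")]

def top_k(query_vec, k=1):
    """Step 2: rank nodes by similarity to the query embedding."""
    ranked = sorted(node_vectors, key=lambda n: cosine(query_vec, node_vectors[n]), reverse=True)
    return ranked[:k]

def one_hop(node):
    """Step 3: collect every edge that touches the node (depth 1)."""
    return [e for e in edges if node in (e[0], e[2])]

query_vec = [1.0, 0.0, 0.0]  # Step 1: stands in for the embedding of the question
hits = top_k(query_vec, k=1)
context = [e for n in hits for e in one_hop(n)]
# Step 4: `context` would be placed in the prompt sent to the model
print(hits, context)
```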

In [None]:
def graph_vector_search(graph_client, embedding):
    q = "CALL neptune.algo.vectors.topKByEmbedding(" + str(embedding) + ", {topK: 3})" + \
        " YIELD node, score RETURN node, score"
    
    response = graph_client.execute_query(
        graphIdentifier=config['neptune']['graphid'],
        queryString=q,
        language='OPEN_CYPHER',
    )
    return json.loads(response['payload'].read().decode('utf-8'))

In [None]:
graph_vector_search(graph_client, text_embed)

In [None]:
g = graph_vector_search(graph_client, text_embed)

In [None]:
g['results'][0]['node']['~properties']['name']

In [None]:
r = query_graph(graph_client, "MATCH (src {name: 'albitar_oncology_consulting'}) MATCH (src)-[rel]-(tgt) RETURN src,rel,tgt")
r['results']

In [None]:
def call_claude_graph_rag(query, relationships):

    prompt_template = """

Below is a question asked by a person. To help you answer, we include relationship information about the concepts in the question, extracted from a knowledge graph. Use the information from the knowledge graph to answer the question.

Here's an example.

<example_question>
Can you tell me about Acme Corp?
</example_question>

<example_relationships>
{'src': {'name': 'acme_corp'}, 'rel': ({'name': 'acme_corp'}, 'leadership', {'name': 'john_doe'}), 'tgt': {'name': 'john_doe'}}
</example_relationships>

<example_output>
Acme Corp employs John Doe as a senior leader.
</example_output>

<question>
QUESTION_HERE
</question>

<relationships>
RELS_HERE
</relationships>

Be concise.
"""
    if isinstance(relationships, list):
        rel_str =  "\n".join([json.dumps(x) for x in relationships])
        prompt = prompt_template.replace("QUESTION_HERE", query).replace("RELS_HERE", rel_str)
    else:
        prompt = prompt_template.replace("QUESTION_HERE", query).replace("RELS_HERE", json.dumps(relationships))
    claudePayload = json.dumps({ 
        "anthropic_version": "bedrock-2023-05-31",
        'max_tokens': 2048,
        "messages": [
          {
            "role": "user",
            "content": [
              {
                "type": "text",
                "text": prompt
              }
            ]
          }
        ]
    })
    
    
    response = bedrock_runtime.invoke_model(
        body=claudePayload, 
        modelId=claudeModelId, 
        accept='application/json', 
        contentType='application/json'
    )

    body = response.get('body').read().decode('utf-8')

    response_body = json.loads(body)
    return response_body['content'][0]['text']

In [None]:
def graph_rag(query):
    q_embed = call_embed(query)
    related_nodes = graph_vector_search(graph_client, q_embed)
    
    subgraphs = []
    for c in related_nodes['results']:
        nid = c['node']['~properties']['name']
        records = query_graph(graph_client, 
            "MATCH (src {name: '" + nid + "'}) MATCH (src)-[rel]-(tgt) RETURN src,rel,tgt"
        )
    
        for r in records['results']:
            subgraphs.append(r)
    
    print(f"Found {len(subgraphs)} subgraphs")
    print(subgraphs)
    return call_claude_graph_rag(query, subgraphs)

In [None]:
graph_rag("Which executives work at Albitar Oncology?")