

# Issue [#17](https://github.com/ai-cfia/llamaindex-db/issues/17)


In [1]:
# %pip install -r ../../requirements.txt

In [2]:
import logging
import sys
from llama_index.core import StorageContext
from llama_index.core import VectorStoreIndex
from llama_index.vector_stores.postgres import PGVectorStore
from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
from llama_index.llms.azure_openai import AzureOpenAI
from llama_index.core import Settings
import os
from dotenv import load_dotenv
from llama_index.storage.index_store.postgres import PostgresIndexStore
from llama_index.storage.docstore.postgres import PostgresDocumentStore
from llama_index.core.node_parser import SentenceSplitter
from pprint import pprint
import psycopg
import pickle
from llama_index.readers.web import SimpleWebPageReader
from llama_index.core.extractors import QuestionsAnsweredExtractor
from psycopg.sql import SQL, Identifier
import nest_asyncio

nest_asyncio.apply()
load_dotenv()

# # Uncomment to see debug logs
# logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
# logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))

True

In [4]:
def save_to_pickle(data, filename):
    with open(filename, "wb") as file:
        pickle.dump(data, file)


def load_from_pickle(filename):
    with open(filename, "rb") as file:
        return pickle.load(file)

## Set up the LLM and embedding model


In [5]:
llm = AzureOpenAI(
    model="gpt-4",
    deployment_name="ailab-llm",
    api_key=os.getenv("API_KEY"),
    azure_endpoint=os.getenv("AZURE_ENDPOINT"),
    api_version=os.getenv("API_VERSION"),
)

embed_model = AzureOpenAIEmbedding(
    model="text-embedding-ada-002",
    deployment_name="ada",
    api_key=os.getenv("API_KEY"),
    azure_endpoint=os.getenv("AZURE_ENDPOINT"),
    api_version=os.getenv("API_VERSION"),
    max_retries=1000,
)

Settings.llm = llm
Settings.embed_model = embed_model

## Variables


In [7]:
database = os.getenv("DB_NAME")
host = os.getenv("DB_HOST")
password = os.getenv("DB_PASSWORD")
port = os.getenv("DB_PORT")
user = os.getenv("DB_USER")
llamaindex_db = "llamaindex_db_legacy"
llamaindex_schema = "v_0_0_2"

def verify(context_nodes, answer_urls):
    found = False
    for i, n in enumerate(context_nodes):
        if n.metadata["url"] in answer_urls:
            found = True
            print(f"Position: {i+1}", n.metadata["title"])

    if not found:
        print("Not found")

## Observed problem


In [60]:
vector_store_1 = PGVectorStore.from_params(
    database=llamaindex_db,
    host=host,
    password=password,
    port=port,
    user=user,
    embed_dim=1536,
    schema_name="v_0_0_1",
)

retriever_1 = VectorStoreIndex.from_vector_store(
    vector_store=vector_store_1
).as_retriever(similarity_top_k=100)

In [61]:
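# Exact French title of the target page ("Operational procedure: Inspection of preventive food controls - Verification of implementation")
question = "Procédure opérationnelle : Inspection des contrôles préventifs des aliments - Vérification de la mise en oeuvre"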
answer_urls_1 = [
    "https://inspection.canada.ca/fra/1679330680137/1679330831583",
]
context_nodes_1 = retriever_1.retrieve(question)

In [62]:
verify(context_nodes_1, answer_urls_1)

Not found


We can see that the page is not found among the top 100 retrieved chunks, despite searching for its exact title.


## Root cause

After investigating the data in `louis_v005`, I noticed that, for a given URL, the page title does not appear in the content of any of the generated chunks. Assuming the embeddings were generated from the chunk contents only, the observed problem is expected: nothing in the embedded text matches the title.


## Solution: add metadata in embeddings

We'll need to regenerate the embeddings. This can be done by providing LlamaIndex nodes (chunks) without embeddings: their metadata are then automatically included in the text that gets embedded, and specific keys can be excluded through `excluded_embed_metadata_keys`. We want to include at least the document `title`, so that searching by title yields accurate results. The `keywords` declared in the webpage's `head` section would also be worth including.
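
To make the mechanism concrete, here is a minimal pure-Python sketch of how the text sent to the embedding model can be assembled from a node's metadata and content. It assumes LlamaIndex's default templates (each kept entry rendered as `"{key}: {value}"`, entries joined with newlines, then `"{metadata_str}\n\n{content}"`); the real logic lives in `TextNode.get_content(metadata_mode=MetadataMode.EMBED)`, and the helper `build_embed_text` and the example URL are hypothetical.

```python
def build_embed_text(
    text: str,
    metadata: dict[str, str],
    excluded_embed_metadata_keys: list[str],
) -> str:
    """Hypothetical helper mimicking LlamaIndex's default embedding text.

    Assumed templates: each kept metadata entry becomes "key: value", entries
    are joined with newlines, and the block is placed before the node text,
    separated by a blank line.
    """
    kept = [
        f"{key}: {value}"
        for key, value in metadata.items()
        if key not in excluded_embed_metadata_keys
    ]
    metadata_str = "\n".join(kept).strip()
    if not metadata_str:
        # Nothing left after exclusion: only the raw text is embedded.
        return text
    return f"{metadata_str}\n\n{text}".strip()


example_metadata = {
    "title": "Certifications biologiques révoquées",
    "keywords": "annulation, suspension, 2023",
    "url": "https://inspection.canada.ca/example",  # hypothetical URL
    "language": "fra",
}
embed_text = build_embed_text("Chunk body text.", example_metadata, ["url", "language"])
print(embed_text)
```

Only `title` and `keywords` end up in front of the chunk text, which is what lets a title query match chunks whose body never mentions the title.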

Example of a `head` section:

```html
<head>
  <meta charset="utf-8" />
  <title>
    Certifications biologiques révoquées - Agence canadienne d'inspection des
    aliments
  </title>
  <meta content="width=device-width,initial-scale=1" name="viewport" />
  <meta name="dcterms.title" content="Certifications biologiques révoquées" />
  <meta
    name="description"
    content="Les organismes de certification accrédités par ..."
  />
  <meta
    name="dcterms.description"
    content="Les organismes de certification ..."
  />
  <meta
    name="keywords"
    content="Loi sur les produits agricoles au Canada, R&#232;glement sur les produits biologiques, 2009, Organisme de certification accr&#233;dit&#233;, annulation, suspension, 2023"
  />
  <meta
    name="dcterms.subject"
    content="inspection,législation,réglementation"
    title="gccore"
  />
  <meta
    name="dcterms.creator"
    content="Gouvernement du Canada,Agence canadienne d'inspection des aliments"
  />
  <meta name="dcterms.language" content="fra" title="ISO639-2" />
  <meta name="dcterms.issued" content="2023-05-12" title="W3CDTF" />
  <meta name="dcterms.modified" content="2024-04-29" title="W3CDTF" />
  <meta
    name="dcterms.type"
    content="législation et règlements,matériel de référence,statistiques"
    title="gctype"
  />
  <meta
    name="dcterms.audience"
    content="entreprises,grand public,gouvernement"
    title="gcaudience"
  />
  <meta property="dcterms:service" content="CFIA-ACIA" />
  <meta property="dcterms:accessRights" content="2" />
  ...
</head>
```

This is important information. We should probably not include all of it in the embeddings, since every embedded field costs additional tokens, but all of it should be stored in the node metadata.

I haven't researched the optimal chunk size (in tokens) in depth, but the general recommendation seems to be between `256` and `1024`. If chunks are too small, they may not carry enough context; if they are too large, the relevant information can be diluted. Both extremes can reduce retrieval accuracy.


## Implementation

### Get the list of all `url_ids`

`url_ids` are unique webpage identifiers, e.g. `/eng/1569265402509/1569265408887`. Prepend `https://inspection.canada.ca` to access a page. They were introduced in [fix #21](https://github.com/ai-cfia/llamaindex-db/pull/21) to avoid storing the same page several times under alternate URLs, which makes them an effective way of removing duplicates.


In [6]:
conn_string = (
    f"dbname={database} "
    f"user={user} "
    f"password={password} "
    f"host={host} "
    f"port={port}"
)
query = """
    SELECT DISTINCT url_id
    FROM louis_v005.unique_documents
    WHERE url_id IS NOT NULL;
    """
with psycopg.connect(conn_string) as conn:
    with conn.cursor() as cur:
        results = cur.execute(query).fetchall()
        url_ids = [r[0] for r in results]

pprint(url_ids[0:5])
save_to_pickle(url_ids, "url_ids.pkl")

### Create documents from `url_ids`

Documents are just bigger nodes: in LlamaIndex, a `Document` holds a whole web page and is later split into smaller nodes (chunks).

#### Utility classes and functions


In [25]:
import re
from datetime import datetime, timezone
from typing import Callable
from urllib.parse import urljoin
import asyncio
import aiohttp

import html2text
from bs4 import BeautifulSoup
from llama_index.core.schema import Document


DEFAULT_BASE_URL = "https://inspection.canada.ca"
DEFAULT_N_WORKERS = 100
DEFAULT_EXCLUDED_KEYS = [  # <------ excluded metadata. "title" and "keywords" are not.
    "url_id",
    "last_crawled",
    "url",
    "description",
    "viewport",
    "language",
    "type",
    "subject",
    "creator",
    "issued",
    "modified",
    "audience",
]
DEFAULT_URL_ID_REGEX = r"/[a-z]{3}/[0-9]+/[0-9]+$"


class AsyncWorkerQueue:
    """A class representing an asynchronous worker queue."""

    def __init__(self, num_workers, maxsize=0):
        self.num_workers = num_workers
        self.queue = asyncio.Queue(maxsize=maxsize)
        self.futures = []
        self.workers = []

    async def worker(self, name):
        """A coroutine that processes tasks from the queue."""
        while True:
            future, task, args, kwargs = await self.queue.get()
            if task is None:
                break
            try:
                result = await task(*args, **kwargs)
                future.set_result(result)
            except Exception as e:
                print(f"Worker {name} encountered an error: {e}")
                # error_trace = traceback.format_exc()
                # print(f"Worker {name} encountered an error: {e}\n{error_trace}")
                future.set_exception(e)
            finally:
                self.queue.task_done()

    async def run(self, tasks_with_params):
        """
        A coroutine that runs the worker tasks and processes the given tasks
        with parameters.
        """
        self.workers = [
            asyncio.create_task(self.worker(f"Worker-{i}"))
            for i in range(self.num_workers)
        ]

        for task, args, kwargs in tasks_with_params:
            future = asyncio.Future()
            self.futures.append(future)
            await self.queue.put((future, task, args, kwargs))

        await self.queue.join()

        for _ in self.workers:
            await self.queue.put((None, None, None, None))
        await asyncio.gather(*self.workers, return_exceptions=True)

        results = []
        for future in self.futures:
            try:
                results.append(await future)
            except Exception:
                # The worker already printed the error; skip failed tasks.
                continue

        return results


class WebPageReader:
    @classmethod
    async def get_html_from_url(cls, url):
        """Retrieve HTML content from a given URL asynchronously"""
        async with aiohttp.ClientSession() as session:
            async with session.get(url) as response:
                response.raise_for_status()
                content_type = response.headers.get("content-type", "")
                if "text/html" not in content_type:
                    raise ValueError("The returned content is not HTML")
                return await response.text()


class MetadataExtractor:  # <--------- This extracts the metadata from the head section
    @classmethod
    async def extract_from_soup(cls, soup: BeautifulSoup):
        """Extract metadata from the soup"""
        metadata: dict[str, str] = {}
        if header := soup.find("head"):
            metas = header.find_all("meta")
            for meta in metas:
                name = meta.get("name", "").replace("dcterms.", "")
                content = meta.get("content")
                if name and content:
                    metadata[name] = content
        return metadata


class AiLabWebPageReader:
    """AiLab web page reader."""

    html_to_text: bool
    base_url: str
    _metadata_fn: Callable[[BeautifulSoup], dict] | None
    url_id_regex: str

    def __init__(
        self,
        html_to_text: bool = False,
        base_url: str = DEFAULT_BASE_URL,
        url_id_regex: str = DEFAULT_URL_ID_REGEX,
        metadata_fn: Callable[[BeautifulSoup], dict] | None = None,
    ) -> None:
        """Initialize with parameters."""
        self.html_to_text = html_to_text
        self.base_url = base_url
        self.url_id_regex = url_id_regex
        self._metadata_fn = metadata_fn

    async def load_document(
        self, url_id: str, excluded_keys: list[str] = DEFAULT_EXCLUDED_KEYS
    ):
        """Fetch one page by its URL ID and convert it into a Document."""
        assert re.match(self.url_id_regex, url_id), "Invalid URL ID"
        url = urljoin(self.base_url, url_id)
        response = await WebPageReader.get_html_from_url(url)
        response = BeautifulSoup(response, "html.parser")

        metadata: dict[str, str] | None = None  # <----------- metadata extraction here
        if self._metadata_fn is not None:
            metadata = await self._metadata_fn(response)
            metadata["url"] = url
            metadata["url_id"] = url_id
            metadata["last_crawled"] = datetime.now(timezone.utc).isoformat()

        # prune irrelevant sections like header, footer, etc.
        response = (response.find("main") or response).prettify()

        if self.html_to_text:
            response = html2text.html2text(response)

        return Document(
            text=response,
            id_=url,
            metadata=metadata or {},  # <---------- metadata are added to the nodes here
            excluded_embed_metadata_keys=excluded_keys,
            excluded_llm_metadata_keys=excluded_keys,
        )

    async def load_documents(
        self,
        url_ids: list[str],
        n_workers: int = DEFAULT_N_WORKERS,
        excluded_keys: list[str] = DEFAULT_EXCLUDED_KEYS,
    ) -> list[Document]:
        """Fetch pages concurrently; pages that fail to load are skipped."""
        assert isinstance(url_ids, list), "urls must be a list of strings."
        tasks = [(self.load_document, (u, excluded_keys), {}) for u in url_ids]
        async_queue = AsyncWorkerQueue(num_workers=n_workers)
        documents = await async_queue.run(tasks)
        return documents
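
`AsyncWorkerQueue` bounds concurrency with a fixed pool of worker coroutines and drops failed tasks. For comparison, a common alternative (not what this notebook uses) is an `asyncio.Semaphore` combined with `asyncio.gather(..., return_exceptions=True)`. The sketch below uses a hypothetical `fake_load_document` coroutine in place of `AiLabWebPageReader.load_document` so it runs without network access; `run_bounded` is also a hypothetical name.

```python
import asyncio
from typing import Any, Awaitable, Callable


async def run_bounded(
    coro_fn: Callable[[Any], Awaitable[Any]], items: list[Any], limit: int
) -> list[Any]:
    """Await coro_fn(item) for every item, with at most `limit` calls in flight.

    Results keep the input order; failed calls are dropped, like AsyncWorkerQueue.run.
    """
    semaphore = asyncio.Semaphore(limit)

    async def guarded(item: Any) -> Any:
        async with semaphore:  # waits while `limit` calls are already running
            return await coro_fn(item)

    results = await asyncio.gather(
        *(guarded(item) for item in items), return_exceptions=True
    )
    return [r for r in results if not isinstance(r, BaseException)]


async def fake_load_document(url_id: str) -> str:
    """Hypothetical stand-in for load_document; fails for IDs ending in 'bad'."""
    await asyncio.sleep(0.01)
    if url_id.endswith("bad"):
        raise ValueError(f"cannot load {url_id}")
    return f"document for {url_id}"


docs = asyncio.run(
    run_bounded(fake_load_document, ["/eng/1/1", "/eng/2/bad", "/eng/3/3"], limit=2)
)
print(docs)  # ['document for /eng/1/1', 'document for /eng/3/3']
```

Inside the notebook you could simply `await run_bounded(...)`. The trade-off is that `gather` creates one task per URL up front, whereas the worker pool keeps only `n_workers` tasks alive; at around 11k pages either approach should be manageable.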

#### Creating the documents


In [None]:
url_ids = load_from_pickle("url_ids.pkl")
pprint(url_ids[2:4])
documents = await AiLabWebPageReader(
    html_to_text=True, metadata_fn=MetadataExtractor.extract_from_soup
).load_documents(url_ids)
save_to_pickle(documents, "documents.pkl")

In [28]:
documents = load_from_pickle("documents.pkl")
print(len(documents))

11056


### Create nodes from documents

By trial and error, I found that a chunk size of `512` tokens raises an error: for some nodes, the metadata alone is already longer than `512` tokens, leaving no room for the actual node text. I settled on `768`, which is still within the recommended range.
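
The failure follows from simple arithmetic: since the metadata string is prepended to every chunk, only `chunk_size - metadata_tokens` tokens remain for the text. Below is a minimal sketch of that check, assuming it mirrors what `SentenceSplitter` does internally; the function name and the 600-token metadata figure are made up for illustration.

```python
def effective_chunk_size(chunk_size: int, metadata_tokens: int) -> int:
    """Tokens left for node text once the metadata string is counted.

    Assumed to mirror SentenceSplitter's check: metadata is prepended to every
    chunk, so it is subtracted from the chunk budget.
    """
    remaining = chunk_size - metadata_tokens
    if remaining <= 0:
        raise ValueError(
            f"Metadata length ({metadata_tokens}) is longer than chunk size ({chunk_size})"
        )
    return remaining


metadata_tokens = 600  # hypothetical size of an unusually long title + keywords
for size in (512, 768):
    try:
        print(size, "->", effective_chunk_size(size, metadata_tokens), "tokens of text")
    except ValueError as err:
        print(size, "->", err)
```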

The `chunk_overlap` parameter used below (50 tokens) makes consecutive chunks share some text, which helps ensure that contextual information isn't lost when the text is cut abruptly from one chunk to the next.
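
A simplified, fixed-window sketch of what overlap means, over a pre-tokenized list of words. `SentenceSplitter` is more elaborate (it prefers to cut on sentence boundaries), so treat `split_with_overlap` as a hypothetical illustration of the idea, not its implementation.

```python
def split_with_overlap(
    tokens: list[str], chunk_size: int, chunk_overlap: int
) -> list[list[str]]:
    """Fixed-size windows where consecutive chunks share `chunk_overlap` tokens."""
    if not 0 <= chunk_overlap < chunk_size:
        raise ValueError("chunk_overlap must be in [0, chunk_size)")
    step = chunk_size - chunk_overlap  # how far each window advances
    chunks = []
    for start in range(0, len(tokens), step):
        chunks.append(tokens[start : start + chunk_size])
        if start + chunk_size >= len(tokens):
            break  # the last window already reaches the end of the text
    return chunks


tokens = [f"t{i}" for i in range(10)]
for chunk in split_with_overlap(tokens, chunk_size=4, chunk_overlap=1):
    print(chunk)
```

Each chunk starts with the last token of the previous one, so a sentence cut at a boundary still appears with some surrounding context.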


In [20]:
parser = SentenceSplitter(chunk_size=768, chunk_overlap=50)  # <--------- chunk size 768
nodes = parser.get_nodes_from_documents(documents, show_progress=True)
save_to_pickle(nodes, "nodes.pkl")

Parsing nodes: 100%|██████████| 11056/11056 [00:48<00:00, 229.83it/s]


In [31]:
nodes = load_from_pickle("nodes.pkl")
print(len(nodes))
pprint(parser._get_metadata_str(nodes[0]))

80818
('title: Alpha Meat Packers Ltd. brand Beef Burgers and Lean Ground Beef '
 'recalled due to E. coli O157:H7\n'
 'food safety, food borne illness, food poisoning, foodborne pathogens, E. '
 'coli, Alpha Meat Packers Ltd., Beef Burgers, Lean Ground Beef,')


- ✅ `title` and `keywords` metadata will be included in embeddings


### Create the tables


In [None]:
# nodes = load_from_pickle("nodes.pkl")
vector_store = PGVectorStore.from_params(
    database=llamaindex_db,
    host=host,
    password=password,
    port=port,
    user=user,
    embed_dim=1536,
    schema_name=llamaindex_schema,
)

document_store = PostgresDocumentStore.from_params(
    database=llamaindex_db,
    host=host,
    password=password,
    port=port,
    user=user,
    schema_name=llamaindex_schema,
)

index_store = PostgresIndexStore.from_params(
    database=llamaindex_db,
    host=host,
    password=password,
    port=port,
    user=user,
    schema_name=llamaindex_schema,
)

storage_context = StorageContext.from_defaults(
    docstore=document_store,
    index_store=index_store,
    vector_store=vector_store,
)

storage_context.docstore.add_documents(nodes, batch_size=512)

index = VectorStoreIndex(
    nodes, storage_context=storage_context, show_progress=True, insert_batch_size=512
)

### Create the index


In [7]:
connection_string = (
    f"dbname={llamaindex_db} "
    f"user={user} "
    f"password={password} "
    f"host={host} "
    f"port={port}"
)

schema = Identifier(llamaindex_schema)
query = SQL(
    "CREATE INDEX ON {}.data_llamaindex USING hnsw (embedding vector_cosine_ops)"
).format(schema)

with psycopg.connect(connection_string) as conn:
    conn.autocommit = True
    with conn.cursor() as cur:
        cur.execute(query)

## Testing


In [8]:
vector_store = PGVectorStore.from_params(
    database=llamaindex_db,
    host=host,
    password=password,
    port=port,
    user=user,
    embed_dim=1536,
    schema_name=llamaindex_schema,
)

retriever = VectorStoreIndex.from_vector_store(vector_store=vector_store).as_retriever(
    similarity_top_k=5
)

In [9]:
# Same French page title as in the observed problem above
question = "Procédure opérationnelle : Inspection des contrôles préventifs des aliments - Vérification de la mise en oeuvre"
answer_urls = [
    "https://inspection.canada.ca/fra/1679330680137/1679330831583",
]
context_nodes = retriever.retrieve(question)

In [10]:
verify(context_nodes, answer_urls)

Position: 1 Procédure opérationnelle : Inspection des contrôles préventifs des aliments – Vérification de la mise en oeuvre
Position: 2 Procédure opérationnelle : Inspection des contrôles préventifs des aliments – Vérification de la mise en oeuvre
Position: 3 Procédure opérationnelle : Inspection des contrôles préventifs des aliments – Vérification de la mise en oeuvre
Position: 4 Procédure opérationnelle : Inspection des contrôles préventifs des aliments – Vérification de la mise en oeuvre
Position: 5 Procédure opérationnelle : Inspection des contrôles préventifs des aliments – Vérification de la mise en oeuvre


- ✅ searching by page title now returns the expected page at position 1 (all top 5 chunks come from it)

In [11]:
pprint(context_nodes[0].dict()) 

{'class_name': 'NodeWithScore',
 'node': {'class_name': 'TextNode',
          'embedding': None,
          'end_char_idx': 30370,
          'excluded_embed_metadata_keys': ['url_id',
                                           'last_crawled',
                                           'url',
                                           'description',
                                           'viewport',
                                           'language',
                                           'type',
                                           'subject',
                                           'creator',
                                           'issued',
                                           'modified',
                                           'audience'],
          'excluded_llm_metadata_keys': ['url_id',
                                         'last_crawled',
                                         'url',
                                         'description',
     