# Use embeddings to search the dataset with a query string

In [1]:
import os
import pandas as pd
import openai
from openai.embeddings_utils import get_embedding, cosine_similarity

In [2]:
%load_ext dotenv
%dotenv

In [3]:
openai.api_key = os.getenv("OPENAI_API_KEY")
assert openai.api_key is not None

Load the dataset and the embeddings:

In [4]:
import numpy as np

def load_df(folder_path, file_name) -> pd.DataFrame:
    """Load dataset from a csv file into a pandas DataFrame. Ensure the types are correct.
    
    Expected columns:
    - index: int
    - file_name: str
    - content: str
    - combined: str
    - n_tokens: int
    - embedding: np.ndarray
    """
    df = pd.read_csv(os.path.join(folder_path, file_name))
    df["n_tokens"] = df["n_tokens"].astype(int)
    df["embedding"] = df.embedding.apply(eval).apply(np.array)
    return df

In [5]:
# FOLDER = "tmp/dnd-notes-main"
FOLDER = "/home/tyto/Documents/dnd/DND.SRD.Wiki-0.5.1/"
df = load_df(FOLDER, "embeddings-sections.csv")
df.head()

Unnamed: 0,filename,heading,content,combined,n_tokens,embedding
0,Paladin.md,Class Features,"As a paladin, you gain the following class fea...","Paladin.md:Class Features\nAs a paladin, you g...",19,"[0.0018332210602238774, -0.021341584622859955,..."
1,Paladin.md,Hit Points,**Hit Dice:** 1d10 per paladin level\n**Hit Po...,Paladin.md:Hit Points\n**Hit Dice:** 1d10 per ...,64,"[0.00903189368546009, -0.009907001629471779, 0..."
2,Paladin.md,Proficiencies,"**Armor:** All armor, shields\n**Weapons:** Si...",Paladin.md:Proficiencies\n**Armor:** All armor...,60,"[0.010927453637123108, 6.236664194148034e-05, ..."
3,Paladin.md,Equipment,"You start with the following equipment, in add...",Paladin.md:Equipment\nYou start with the follo...,821,"[0.018981102854013443, -0.0007141521782614291,..."
4,Paladin.md,Divine Sense,The presence of strong evil registers on your ...,Paladin.md:Divine Sense\nThe presence of stron...,188,"[0.031038733199238777, -0.004664028529077768, ..."


In [6]:
EMBEDDINGS_MODEL = "text-embedding-ada-002"

def search_df(df, query, filter=None, top_n: int=10) -> pd.DataFrame:
    query_embedding = get_embedding(query, engine=EMBEDDINGS_MODEL)
    df = df.copy()
    if filter is not None:
        df = df[df.filename.apply(filter)]
    df["similarity"] = df.embedding.apply(
        lambda x: cosine_similarity(query_embedding, x)
    )
    return df.sort_values("similarity", ascending=False).head(top_n)

In [7]:
COMPLETION_MODEL = "gpt-3.5-turbo"
MAX_TOKENS = 3500  # 4097 is the max, but we want to leave some room for the prompt

def generate_answer(prompt, context=None) -> str:
    if context is not None:
        context_str = "\n".join(context)
        prompt_with_context = f"{prompt}\n\nContext:\n{context_str}"
    else:
        prompt_with_context = prompt
    try:
        response = openai.ChatCompletion.create(
            model=COMPLETION_MODEL,
            messages=[
                {
                    "content": prompt_with_context,
                    "role": "user",
                },
            ],
        )
    except openai.error.InvalidRequestError as e:
        print(f"Error: {e}, prompt: {prompt_with_context}, context: {context}")
        return ""
    return response.choices[0].message.content.strip()

def generate_answer_with_search(prompt, top_n: int = 5, show_included=False) -> str:
    context_df = search_df(df, prompt, top_n=top_n)
    context = trim_context(prompt, context_df).combined
    if show_included:
        # show the included context
        print("Included context: " + '\n'.join(context_df.filename + context_df.heading + context_df.similarity.apply(str)))
    return generate_answer(prompt, context)

import tiktoken
encoding = tiktoken.encoding_for_model(COMPLETION_MODEL)

def trim_context(prompt, context: pd.DataFrame):
    """Trim context to fit within the max tokens."""
    prompt_tokens = encoding.encode(prompt)
    # ensure context is sorted by similarity
    context = context.sort_values("similarity", ascending=False)
    # trim context to fit within MAX_TOKENS
    context_tokens = len(prompt_tokens)
    for i, row in enumerate(context.iterrows()):
        _, row = row
        tokens = len(encoding.encode(row.combined))
        if context_tokens + tokens > MAX_TOKENS:
            context = context.iloc[:i - 1]
            break
        context_tokens += tokens
    return context

In [14]:
query = "Is the spell Legend Lore in the SRD?"
response = generate_answer_with_search(query, top_n=50, show_included=True)
from IPython.display import Markdown
Markdown(f"**{query}**\n\n{response}")

Included context: Legend Lore.mdLegend Lore0.8282462080328881
# Monster Statistics.mdA Legendary Creature's Lair0.8277920096890437
# Monster Statistics.mdLegendary Creatures0.8180924857981615
# Spellcasting.mdSPELLCASTING0.8128123544261647
# Monster Statistics.mdSpellcasting0.8122918715921268
Ring of Spell Turning.mdRing of Spell Turning0.8114901333387204
Bard.mdRitual Casting0.8109401658329012
# Spellcasting.mdCasting a Spell0.8100590903287319
Robe of the Archmagi.mdRobe of the Archmagi0.8087173343407306
Commune with Nature.mdCommune with Nature0.8084635663105468
Sorcerer.mdSorcerous Origin0.8077411046332814
Ranger.mdSpellcasting0.8073649321335876
Spell Scroll.mdSpell Scroll0.8061377090995736
Bard.mdMagical Secrets0.8060528813956777
Warlock.mdBook of Ancient Secrets0.8059130962478845
Ranger.mdPrimeval Awareness0.8055250831625166
Bard.mdBard College0.8042201417947717
Warlock.mdPact of the Tome0.8041390964383894
Ring of Three Wishes.mdRing of Three Wishes0.8029386232296213
# Monster Sta

**Is the spell Legend Lore in the SRD?**

Yes, the spell Legend Lore is in the SRD.