In [13]:
# Dependencies (uncomment to install them into the kernel's environment):
# %pip install -q -U sentence-transformers
# %pip install -q -U google-generativeai
# %pip install torch
# %pip install tiktoken

In [4]:
from dateutil.parser import parse
import pandas as pd
import requests

from sentence_transformers import SentenceTransformer, util
import numpy as np
import torch

import tiktoken

import google.generativeai as genai
import os

  from tqdm.autonotebook import tqdm, trange


In [25]:
# Read the Gemini API key from the API_KEY environment variable so the secret
# is never stored in the notebook itself (raises KeyError if it is not set).
genai.configure(api_key=os.environ["API_KEY"])

In [26]:
# Print the name of every available model that supports content generation
generative_models = (
    model
    for model in genai.list_models()
    if 'generateContent' in model.supported_generation_methods
)
for model in generative_models:
    print(model.name)

models/gemini-1.0-pro-latest
models/gemini-1.0-pro
models/gemini-pro
models/gemini-1.0-pro-001
models/gemini-1.0-pro-vision-latest
models/gemini-pro-vision
models/gemini-1.5-pro-latest
models/gemini-1.5-pro-001
models/gemini-1.5-pro
models/gemini-1.5-pro-exp-0801
models/gemini-1.5-pro-exp-0827
models/gemini-1.5-flash-latest
models/gemini-1.5-flash-001
models/gemini-1.5-flash-001-tuning
models/gemini-1.5-flash
models/gemini-1.5-flash-exp-0827
models/gemini-1.5-flash-8b-exp-0827


In [27]:
# The system instruction asks the model not to answer in markdown
system_instructions = ["Do not generate the output in markdown format"]

genai_model = genai.GenerativeModel(
    "gemini-1.5-flash-001",
    system_instruction=system_instructions,
)

In [None]:
# Get the Wikipedia page for "2024" as a plain-text extract (JSON response)
resp = requests.get(
    "https://en.wikipedia.org/w/api.php",
    params={
        "action": "query",
        "prop": "extracts",
        "exlimit": 1,
        "titles": "2024",
        "explaintext": 1,
        "formatversion": 2,
        "format": "json",
    },
    timeout=30,  # fail instead of hanging forever if the server stalls
)
# Surface HTTP errors here rather than as a confusing JSON/KeyError later on
resp.raise_for_status()

In [5]:
# Load page text into a dataframe, one row per line of the extract
page_lines = resp.json()["query"]["pages"][0]["extract"].split("\n")
df = pd.DataFrame({"text": page_lines})

# Clean up text to remove empty lines and headings (lines starting with "==")
is_blank = df["text"].str.len() == 0
is_heading = df["text"].str.startswith("==")
df = df[~is_blank & ~is_heading]
df.head()

Unnamed: 0,text
0,"2024 (MMXXIV) is the current year, and is a le..."
1,"So far, this year has seen the continuation of..."
2,"Approximately 79 countries, representing aroun..."
9,January 1
10,"Egypt, Ethiopia, Iran and the United Arab Emir..."


In [6]:
# Prepare Dataset: Loading and Wrangling Data
#
# Rows whose text parses as a date become the prefix for the rows that follow,
# so every kept row ends up in the form "<date> – <text>".

date_separator = " – "  # en dash surrounded by spaces

prefix = ""
prefixed_texts = []
for text in df["text"]:
    if date_separator in text:
        # The row already has the needed date prefix
        prefixed_texts.append(text)
        continue
    try:
        parse(text)
    except (ValueError, OverflowError):
        # The row's text isn't a date, so add the current prefix
        prefixed_texts.append(prefix + date_separator + text)
    else:
        # The row's text is a date: remember it as the new prefix
        prefix = text
        prefixed_texts.append(text)

# Build the new column in one step instead of writing to the rows yielded by
# iterrows(), which may be copies whose changes never reach the dataframe
df = df.assign(text=prefixed_texts)

# Keep only prefixed rows; this drops the bare date rows
df = df[df["text"].str.contains(date_separator, regex=False)]
df.head()

Unnamed: 0,text
0,"– 2024 (MMXXIV) is the current year, and is a..."
1,"– So far, this year has seen the continuation..."
2,"– Approximately 79 countries, representing ar..."
10,"January 1 – Egypt, Ethiopia, Iran and the Unit..."
11,January 1 – The Republic of Artsakh is formall...


In [8]:
# Generate embeddings
# (SentenceTransformer is already imported in the imports cell at the top)

embeddings_model = SentenceTransformer('paraphrase-MiniLM-L6-v2') # or any other suitable model

# Encode every row of text in a single call
knowledge_base = df["text"].tolist()
all_embeddings = embeddings_model.encode(knowledge_base)

# Store one embedding vector per row; assign() returns a new dataframe, which
# avoids setting a column on the filtered slice that `df` currently is
df = df.assign(embeddings=list(all_embeddings))
df.head()

Unnamed: 0,text,embeddings
0,"– 2024 (MMXXIV) is the current year, and is a...","[-0.35123074, 0.1045167, -0.10050657, -0.41728..."
1,"– So far, this year has seen the continuation...","[0.017497486, 0.199402, 0.09975363, -0.0497397..."
2,"– Approximately 79 countries, representing ar...","[0.4582122, 0.1434245, -0.23085447, -0.1877673..."
10,"January 1 – Egypt, Ethiopia, Iran and the Unit...","[0.15804261, -0.020021707, -0.39913028, -0.214..."
11,January 1 – The Republic of Artsakh is formall...,"[0.14056459, 0.11909537, -0.13620375, -0.07031..."


In [9]:
# Save the text and embeddings to "embeddings.csv" in the working directory.
# NOTE(review): the "embeddings" cells hold arrays, which CSV can only store
# as text — confirm the file can be parsed back before relying on it as a cache.
df.to_csv("embeddings.csv")

In [20]:
# Create a Function that Finds Related Pieces of Text from the knowledge base for a Given Question

def get_rows_sorted_by_relevance(question, df, embedding_model):
    """
    Rank the rows of ``df`` by how similar their embeddings are to a question.

    Parameters:
        question: The question text to compare against the knowledge base.
        df: Dataframe with an "embeddings" column holding one equal-length
            vector per row. It is not modified.
        embedding_model: Object with an ``encode(text)`` method; it is used to
            embed the question.

    Returns:
        A copy of ``df`` with an added "distances" column, sorted so the most
        relevant row comes first. Despite its name (kept so existing callers
        keep working), "distances" holds cosine *similarities*: a higher value
        means more relevant.
    """
    df_copy = df.copy()

    # Nothing to rank: return the empty copy with the expected column
    if df_copy.empty:
        df_copy["distances"] = pd.Series(dtype="float32")
        return df_copy

    # Embed the question with the model passed in by the caller
    question_embedding = embedding_model.encode(question)

    # Stack the per-row vectors into one (rows, dims) matrix so the
    # similarities are computed in a single call
    corpus_embeddings = np.stack(df_copy["embeddings"].to_numpy())
    cosine_similarities = util.cos_sim(question_embedding, corpus_embeddings)

    # cos_sim returns one similarity per row for the single question;
    # flatten it to a 1-D array before storing it as a column
    df_copy["distances"] = cosine_similarities.reshape(-1).cpu().numpy()

    # Higher similarity = more relevant, so sort in descending order
    return df_copy.sort_values("distances", ascending=False)

In [21]:
# Try it out: rank every row of the knowledge base against a sample question
get_rows_sorted_by_relevance("Where were the Olympics held in 2024?", df, embeddings_model)

Unnamed: 0,text,embeddings,distances
150,July 26 – August 11 – The 2024 Summer Olympics...,"[0.15916829, 0.20322593, -0.3195624, -0.034673...",0.504496
174,August 28 – September 8 – The 2024 Summer Para...,"[0.36730748, 0.27258122, -0.18684757, -0.29181...",0.418873
125,June 14 – July 14 – UEFA Euro 2024 is held in ...,"[0.028763652, -0.08352816, -0.29847503, -0.236...",0.387371
126,June 20 – July 14 – The 2024 Copa América is h...,"[0.096463755, -0.24353446, -0.011640335, -0.36...",0.373329
195,November 24 – 2024 Romanian presidential elect...,"[-0.6341786, 0.4318894, -0.29025233, -0.201624...",0.282455
...,...,...,...
13,January 1 – Ethiopia announces an agreement wi...,"[-0.2743133, -0.17097084, -0.42006847, -0.0958...",-0.053272
25,January 16 – Iran carries out a series of miss...,"[-0.08747466, 0.42818704, -0.31315857, -0.1089...",-0.062791
59,March 25 – The UN Security Council passes a re...,"[-0.2280021, 0.17325671, -0.23184435, -0.20882...",-0.071245
26,January 18 – Pakistan conducts retaliatory air...,"[-0.14258529, 0.21218927, -0.19479619, 0.00249...",-0.092434


In [22]:
# Create a Function that Composes a Text Prompt

# We want to fit as much of our dataset as possible into the "context" part of
# the prompt without exceeding the number of tokens allowed by our model. To be
# on the safe side, we treat the model's max context length as 4000 tokens (in
# reality it's far larger than this, and it keeps growing with newer models).

def create_prompt(question, df, max_token_count):
    """
    Given a question and a dataframe containing rows of text and their
    embeddings, return a text prompt to send to a text-generation model.

    Rows are taken in the order returned by get_rows_sorted_by_relevance and
    added to the context until the next row would push the estimated prompt
    size past the budget.

    Parameters:
        question: The question to answer.
        df: Dataframe with "text" and "embeddings" columns.
        max_token_count: Upper bound on the estimated number of prompt tokens.

    Returns:
        The prompt template filled in with the selected context and question.
    """
    # Tokenizer used only to estimate the size of the prompt.
    # NOTE(review): this encoding is not tied to the embedding or generation
    # models used in this notebook, so treat the counts as an estimate.
    tokenizer = tiktoken.get_encoding("cl100k_base")

    prompt_template = """
Answer the question based on the context below, and if the question
can't be answered based on the context, say "I don't know"

Context:

{}

---

Question: {}
Answer:"""

    # Text placed between consecutive context rows; its tokens count too
    separator = "\n\n###\n\n"
    separator_token_count = len(tokenizer.encode(separator))

    # Start from the tokens in the prompt template and the question
    current_token_count = len(tokenizer.encode(prompt_template)) + \
                            len(tokenizer.encode(question))

    # Relies on the notebook-level `embeddings_model` to embed the question
    ranked_texts = get_rows_sorted_by_relevance(question, df, embeddings_model)["text"].values

    context = []
    for text in ranked_texts:
        # Cost of this row, plus the separator that joins it to the previous row
        text_token_count = len(tokenizer.encode(text))
        if context:
            text_token_count += separator_token_count

        # Stop at the first row that would exceed the budget
        if current_token_count + text_token_count > max_token_count:
            break

        current_token_count += text_token_count
        context.append(text)

    return prompt_template.format(separator.join(context), question)

In [32]:
# Preview the prompt built for a sample question with a small 200-token budget
print(create_prompt("Where were the Olympics held in 2024?", df, 200))


Answer the question based on the context below, and if the question
can't be answered based on the context, say "I don't know"

Context: 

July 26 – August 11 – The 2024 Summer Olympics are held in Paris, France. The controversial opening ceremony and the boxing match of Luca Hámori and Imane Khelif spark international debate.

###

August 28 – September 8 – The 2024 Summer Paralympics are held in Paris, France.

###

June 20 – July 14 – The 2024 Copa América is held in the United States, and is won by Argentina.

###

October 3–20 – The 2024 ICC Women's T20 World Cup is scheduled to be held in the United Arab Emirates.

---

Question: Where were the Olympics held in 2024?
Answer:


In [30]:
# Create a Function that Answers a Question

def answer_question(
    question, df, max_prompt_tokens=1800, max_answer_tokens=150
):
    """
    Given a question, a dataframe containing rows of text, and a maximum
    number of desired tokens in the prompt and response, return the
    answer to the question according to the notebook's `genai_model`.

    Parameters:
        question: The question to answer.
        df: Dataframe of knowledge-base rows, passed on to create_prompt.
        max_prompt_tokens: Token budget used when building the prompt.
        max_answer_tokens: Maximum number of tokens requested for the answer.

    Returns:
        The answer text. If the model produces an error, the error is
        printed and an empty string is returned.
    """

    prompt = create_prompt(question, df, max_prompt_tokens)

    try:
        # Pass the answer budget to the model; without this the
        # max_answer_tokens argument would have no effect
        response = genai_model.generate_content(
            prompt,
            generation_config={"max_output_tokens": max_answer_tokens},
        )
        return response.text
    except Exception as e:
        # Best effort by design: report the problem and return no answer
        print(e)
        return ""

In [33]:
# Ask the sample question end to end, using the default token limits
custom_answer = answer_question("Where were the Olympics held in 2024?", df)
print(custom_answer)

Paris, France 

