In [2]:
import os

# Never hardcode credentials in a notebook: anything shared or committed leaks the key.
# Reuse a key that is already exported in the environment; otherwise prompt for it
# without echoing the input.
# NOTE(review): the key that was previously committed on this line should be revoked and rotated.
if not os.environ.get('GEMINI_API_KEY'):
    from getpass import getpass

    os.environ['GEMINI_API_KEY'] = getpass('Enter your Gemini API key: ')

In [3]:
from pypdf import PdfReader

def load_pdf(file_path):
    """Return the text of every page of a PDF as one string.

    Args:
        file_path: Location of the PDF; passed straight to ``PdfReader``.

    Returns:
        The ``extract_text()`` result of each page, concatenated in page
        order with no separator inserted between pages.
    """
    pdf = PdfReader(file_path)
    return "".join(pg.extract_text() for pg in pdf.pages)

# Read the whole PDF into a single string for the chunking step.
# Replace the path with your own file path.
pdf_text = load_pdf(file_path = 'state_of_the_union.pdf')

In [4]:
import re

def split_text(text):
    """Break ``text`` into paragraph chunks.

    A paragraph boundary is a newline, one space, and another newline.
    Pieces that come out of the split as empty strings are dropped; every
    other piece is kept unchanged.

    Args:
        text: The string to split.

    Returns:
        A list of the non-empty pieces, in their original order.
    """
    paragraphs = []
    for piece in re.split('\n \n', text):
        if piece != "":
            paragraphs.append(piece)
    return paragraphs

# Split the extracted PDF text into paragraph chunks; each chunk is stored as one document later on.
chunked_text = split_text(pdf_text)

In [5]:
import google.generativeai as genai
from chromadb import Documents, EmbeddingFunction, Embeddings

class GeminiEmbeddingFunction(EmbeddingFunction):
    """Chroma embedding function that delegates to ``genai.embed_content``."""

    def __call__(self, input: Documents) -> Embeddings:
        """Embed ``input`` with the ``models/embedding-001`` model.

        The API key is read from the ``GEMINI_API_KEY`` environment variable
        and ``genai`` is (re)configured with it on every call.

        Raises:
            ValueError: If ``GEMINI_API_KEY`` is unset or empty.
        """
        api_key = os.getenv("GEMINI_API_KEY")
        if not api_key:
            raise ValueError("Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable")
        genai.configure(api_key=api_key)

        # NOTE(review): task_type and title are fixed, so the same settings apply
        # to every text this function is asked to embed -- confirm that is intended
        # for query text as well as stored documents.
        response = genai.embed_content(
            model="models/embedding-001",
            content=input,
            task_type="retrieval_document",
            title="Custom query",
        )
        return response["embedding"]

In [6]:
import chromadb
from typing import List
def create_chroma_db(documents:List, path:str, name:str):
    """Create (or reopen) a persistent Chroma collection and store ``documents`` in it.

    The call is repeatable: running it again against the same ``path`` and
    ``name`` reuses the existing collection instead of failing.

    Args:
        documents: Text chunks to store. The chunk at position ``i`` is
            stored under the id ``str(i)``.
        path: Directory in which Chroma persists its data.
        name: Name of the collection.

    Returns:
        A ``(collection, name)`` tuple.
    """
    chroma_client = chromadb.PersistentClient(path = path)
    # create_collection raises once the collection already exists on disk, which
    # breaks every run of the notebook after the first one.
    db = chroma_client.get_or_create_collection(name = name, embedding_function = GeminiEmbeddingFunction())

    # upsert rather than add, so ids left behind by an earlier run are overwritten
    # with the current chunk instead of colliding with it.
    for i, d in enumerate(documents):
        db.upsert(documents = d, ids = str(i))

    return db, name

# Store the chunks in a collection named "state_of_the_union", persisted under the current working directory.
db,name = create_chroma_db(documents = chunked_text, path = os.getcwd(), name = "state_of_the_union")

In [7]:
def load_chroma_collection(path, name):
    """Open an existing collection from a persistent Chroma store.

    Args:
        path: Directory holding the persisted Chroma data.
        name: Name of the collection to open.

    Returns:
        The collection, bound to a fresh ``GeminiEmbeddingFunction``.
    """
    client = chromadb.PersistentClient(path=path)
    return client.get_collection(
        name=name,
        embedding_function=GeminiEmbeddingFunction(),
    )

# Reopen the persisted collection by name from the current working directory.
db = load_chroma_collection(path = os.getcwd(), name = "state_of_the_union")

In [8]:
def get_relevant_passage(query, db, n_results):
    """Look up the stored documents that best match ``query``.

    Args:
        query: The search text; sent to ``db.query`` as a one-element list.
        db: Object exposing ``query(query_texts=..., n_results=...)`` whose
            result has a ``'documents'`` entry holding one list per query text.
        n_results: Number of results requested from ``db``.

    Returns:
        The documents list for the single query that was sent.
    """
    results = db.query(query_texts=[query], n_results=n_results)
    # Only one query text was submitted, so its documents are at index 0.
    docs_for_query = results['documents'][0]
    return docs_for_query

# Example usage: fetch the 3 stored chunks that best match the query.
relevant_text = get_relevant_passage(query = "Sanctions on Russia", db = db, n_results = 3)

In [9]:
# Display the retrieved chunks to check by eye that they are on topic.
relevant_text

['Together with our allies, we are right now enforcing powerful economic sanctions. We are cutting off \nRussia’s largest banks from the international financial system. Preventing Russia’s central bank from \ndefending the Russian Ruble, making Putin’s $630 Bill ion “war fund” worthless. We are choking off \nRussia’s access to technology that will sap its economic strength and weaken its military for years to \ncome.  ',
 'And tonight I am announcing that we will join our allies in closing off American airspace to all Russian \nflights – further isolating Russia – and adding an additional squeeze on their economy. The Ruble has lost \n30% of its value. The Russian stock market h as lost 40% of its value and trading remains suspended. \nRussia’s economy is reeling and Putin alone is to blame.  ',
 'We meet tonight in an America that has lived through two of the hardest years this nation has ever \nfaced. The pandemic has been punishing.  ']

In [10]:
def make_rag_prompt(query, relevant_passage):
  """Build the prompt that asks the model to answer ``query`` from a passage.

  Args:
    query: The user's question, inserted into the prompt as given.
    relevant_passage: Reference text, either one string or an iterable of
      strings; the pieces are concatenated with no separator.

  Returns:
    The prompt string, with the passage stripped of single and double
    quotes and with each newline turned into a space.
  """
  # One translation pass: delete both quote characters, flatten newlines to spaces.
  cleanup = str.maketrans({"'": None, '"': None, "\n": " "})
  flattened = ''.join(relevant_passage).translate(cleanup)

  template = ("""You are a helpful and informative bot that answers questions using text from the reference passage included below. \
  Be sure to respond in a complete sentence, being comprehensive, including all relevant background information. \
  However, you are talking to a non-technical audience, so be sure to break down complicated concepts and \
  strike a friendly and converstional tone. \
  If the passage is irrelevant to the answer, you may ignore it.
  QUESTION: '{query}'
  PASSAGE: '{relevant_passage}'

  ANSWER:
  """)

  return template.format(query = query, relevant_passage = flattened)

In [15]:
import google.generativeai as genai
def generate_answer1(prompt):
    """Send ``prompt`` to the ``gemini-pro`` model and return the reply text.

    The API key is read from the ``GEMINI_API_KEY`` environment variable
    and ``genai`` is configured with it before the model is created.

    Args:
        prompt: The full prompt passed to ``generate_content``.

    Returns:
        The ``text`` attribute of the model's response.

    Raises:
        ValueError: If ``GEMINI_API_KEY`` is unset or empty.
    """
    api_key = os.getenv("GEMINI_API_KEY")
    if not api_key:
        raise ValueError("Gemini API Key not provided. Please provide GEMINI_API_KEY as an environment variable")
    genai.configure(api_key=api_key)

    response = genai.GenerativeModel('gemini-pro').generate_content(prompt)
    return response.text

In [16]:
def generate_answer(db, query, n_results=3):
    """Answer ``query`` using chunks retrieved from ``db`` as context.

    Retrieves the best-matching chunks, concatenates them into a single
    passage, builds the RAG prompt, and asks the model for an answer.

    Args:
        db: Collection to search; passed to ``get_relevant_passage``.
        query: The user's question.
        n_results: How many chunks to retrieve. Defaults to 3, the value
            that was previously hard-coded, so existing calls are unchanged.

    Returns:
        The answer text returned by ``generate_answer1``.
    """
    top_chunks = get_relevant_passage(query, db, n_results=n_results)
    # The chunks are concatenated with no separator to form a single passage.
    passage = "".join(top_chunks)
    prompt = make_rag_prompt(query, relevant_passage=passage)
    return generate_answer1(prompt)

In [17]:
# Reopen the persisted collection, then run the whole pipeline
# (retrieve chunks -> build prompt -> generate) for one question.
db=load_chroma_collection(path=os.getcwd(), #replace with path of your persistent directory
                          name="state_of_the_union") #replace with the collection name

answer = generate_answer(db,query="what sanctions have been placed on Russia")
print(answer)

International organizations are imposing these severe economic sanctions against Russia:
  * Removing access to the international financial system for Russia's major banks.
  * Preventing Russia's central bank from defending the ruble, making Putin's "$630 Billion 'war fund'" worthless.
  * Restricting Russia's access to technology, which will weaken its economy and military for years to come.
  * Closing off the airspace of the United States and its allies to all Russian aircraft, further isolating Russia and putting additional strain on its economy.
  * Imposing restrictions on Russia's energy sector.
  * Freezing the assets of Russian oligarchs and their family members, as well as the assets of the Russian Central Bank.
  * Banning the export of luxury goods to Russia.
  * Suspending Russia from international organizations, such as the G8 and the World Trade Organization
