In [21]:
import os
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import faiss
from langchain_community.vectorstores import FAISS, InMemoryVectorStore
from langchain import hub, PromptTemplate
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from sentence_transformers import SentenceTransformer
import numpy as np
import yaml
import ipywidgets as widgets
from IPython.display import display
import httpx
import threading
import torch
import pickle

if torch.backends.mps.is_available():
    device = torch.device('mps')
    print('MPS backend available')
else:
    device = torch.device('cpu')
    print('MPS backend not available. Using CPU')

with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

os.environ['OPENAI_API_KEY'] = config['OPENAI_API_KEY']

llm = ChatOpenAI(model='gpt-4o', http_client = httpx.Client(verify=False))
tokenizer = SentenceTransformer('all-MiniLM-L6-v2')
tokenizer = tokenizer.to(device)

MPS backend available


In [24]:
import glob
from collections import OrderedDict
import re

def read_text_file(file_path):
    text_dict = OrderedDict()
    for file in glob.glob(file_path):
        fname = os.path.basename(file)
        book_number = int(re.match(r'^(\d+)_', fname).group(1))
        print(f'Name: {fname} Book Number: {book_number}')
        with open(file, 'r') as f:
            book_info = {'book_number': book_number, 'chapters': {0: []}}
            # Remove any line breaks between the word "chapter" and following digits
            content = f.read()
            content = re.sub(r'(CHAPTER)\s+(\d+)', r'\1 \2', content, flags=re.IGNORECASE)
            chapter_number = 0
            for line in content.split('\n'):
                line = line.strip()
                if len(line) == 0:
                    continue
                if re.match(r'CHAPTER \d+', line, re.IGNORECASE):
                    chapter_number += 1
                    if chapter_number not in book_info['chapters']:
                        book_info['chapters'][chapter_number] = []
                book_info['chapters'][chapter_number].append(line)
            text_dict[fname] = book_info
    return text_dict

file_path = './books/*.txt'
text_dict = read_text_file(file_path)
doc_collection = []
for book_name, book_info in text_dict.items():
    book_number = book_info['book_number']
    for chapter_number, chapter_text in book_info['chapters'].items():
        for chunk in chapter_text:
            chunk = chunk.strip()
            if len(chunk) > 0:
                doc = Document(
                    page_content=chunk,
                    metadata={
                        'book_number': book_number,
                        'chapter_number': chapter_number
                    }
                )
                doc_collection.append(doc)

Name: 01_the_eye_of_the_world.txt Book Number: 1
Name: 06_lord_of_chaos.txt Book Number: 6
Name: 07_crown_of_swords.txt Book Number: 7
Name: 09_winters_heart.txt Book Number: 9
Name: 03_the_dragon_reborn.txt Book Number: 3
Name: 05_fires_of_heaven.txt Book Number: 5
Name: 04_shadow_rising.txt Book Number: 4
Name: 02_the_great_hunt.txt Book Number: 2


In [25]:
def encode_documents(doc_collection):
  # Encode the documents and store the embeddings along with metadata
  embeddings = []
  metadata = []

  for doc in doc_collection:
    embedding = tokenizer.encode([doc.page_content])[0]
    embeddings.append(embedding)
    metadata.append(doc.metadata)

  embeddings = np.array(embeddings)

  # Create and populate the FAISS index
  index = faiss.IndexFlatL2(embeddings.shape[1])
  index.add(embeddings)

  return index, metadata

if True:
  index, index_metadata = encode_documents(doc_collection)

  # Save the index and index_metadata to disk for further use
  faiss.write_index(index, 'index.faiss')
  with open('index_metadata.pkl', 'wb') as f:
    pickle.dump(index_metadata, f)
else:
  index = faiss.read_index('index.faiss')
  with open('index_metadata.pkl', 'rb') as f:
    index_metadata = pickle.load(f)

In [26]:
prompt = PromptTemplate(
    input_variables=['question', 'context'],
    template="""
HUMAN

You are an assistant to help a reader keep track of people, places, and plot points in books.
The following pieces of retrieved context are excerpts from the books related to the reader's question. Use them to generate your response.

Guidelines for the response:
* If you don't know the answer, just say that you don't know. 
* If you're not sure about something, you can say that you're not sure.
* Take as much time as you need to answer the question.
* Use as many words as you need to answer the question completely, but don't provide any irrelevant information.
* Use bullet points to provide examples from the context that support your answer.

Question: {question} 
Context: {context} 
Answer:
    """
)

def retrieve_chunks(query, index, book_number, chapter_number, top_k=10):
    query_embedding = tokenizer.encode([query])
    D, I = index.search(np.array(query_embedding), 1000)
    relevant_chunks = [
        doc_collection[i] for i in I[0]
        if int(index_metadata[i]['book_number']) < book_number or
           (int(index_metadata[i]['book_number'] == book_number and index_metadata[i]['chapter_number'] < chapter_number))
    ]
    # sort relevant_chunks by book_number and chapter_number in descending order
    relevant_chunks = sorted(relevant_chunks, key=lambda x: (x.metadata['book_number'], x.metadata['chapter_number']), reverse=True)
    return relevant_chunks[:top_k]

class State(TypedDict):
    question: str
    context: List[str]
    answer: str
    book_number: int
    chapter_number: int
    top_k: int

def retrieve(state: State):
    retrieved_docs = retrieve_chunks(state['question'], index, state['book_number'], state['chapter_number'], state['top_k'])
    state['references'] = retrieved_docs
    return {'context': retrieved_docs}

def generate(state: State):
    docs_content = "\n\n".join(doc.page_content for doc in state['context'])
    messages = prompt.invoke({'question': state['question'], 'context': docs_content})
    response = llm.invoke(messages)
    return {'answer': response.content}

graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, 'retrieve')
graph = graph_builder.compile()

In [28]:
import IPython
import markdown

# Define the input and output widgets
input_box = widgets.Text(
    value='',
    placeholder='Type your question here...',
    description='Question:',
    continuous_update=False,
    disabled=False
)

submit_button = widgets.Button(
    description='Submit',
    disabled=False,
    button_style='',
    tooltip='Click to submit your question',
    icon='check'
)

book_number_box = widgets.IntText(
    value=10,
    description='Book Number:',
    disabled=False
)

chapter_number_box = widgets.IntText(
    value=None,
    description='Chapter Number:',
    disabled=False
)

top_k_box = widgets.IntText(
    value=25,
    description='Top K:',
    disabled=False
)

status_box = widgets.Output(layout={'min_height': '50px'})
output_box = widgets.Output(layout={'min_height': '200px'})
context_box = widgets.Output(layout={'min_height': '200px'})

# Create a spinner widget
spinner = widgets.HTML(
    value="""<i class="fa fa-spinner fa-spin" style="font-size:24px; color:#2a9df4;"></i>""",
    placeholder='Loading...',
    description=''
)

# Make sure Font Awesome is available
display(widgets.HTML("<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css'>"))

def wrap_answer(answer):
    html = markdown.markdown(answer)
    return f"<div style='background-color: #f9f9f9; padding: 10px; border-radius: 5px;'>{html}</div>"

def show_results(state):
    with output_box:
        output_box.clear_output()
        output_box_contents = []
        output_box_contents.append("<h3>Answer</h3>")
        output_box_contents.append(wrap_answer(state['answer']))
        display(widgets.HTML(''.join(output_box_contents)))
        with context_box:
            context_box.clear_output()
            context_box_contents = []
            context_box_contents.append("<h3>Context</h3>")
            for doc in state['context']:
                context_box_contents.append(f"<p><strong>Book Number:</strong> {doc.metadata['book_number']} <strong>Chapter Number:</strong> {doc.metadata['chapter_number']}</p>")
                context_box_contents.append(f"<p>{doc.page_content}</p>")
            display(widgets.HTML(wrap_answer("".join(context_box_contents))))

def send_query(state):
    result = graph.invoke(state)
    state['answer'] = result['answer']
    state['context'] = result['context']
    show_results(state)
    

# Define the function to handle the button click
def submit_question(b):
    top_k = top_k_box.value
    with status_box:
        status_box.clear_output()
        display(widgets.HTML(f"<h3>Retrieving top {top_k} relevant chunks...</h3>"))
        with output_box:
            output_box.clear_output()
            display(spinner)

            state = State(
                question=input_box.value,
                book_number=book_number_box.value or 1,
                chapter_number=chapter_number_box.value or 0,
                top_k=top_k_box.value or 10,
                context=[],
                references=[],
                answer=''
            )
            send_query(state)

# Attach the handler to the button
submit_button._click_handlers.callbacks.clear()
submit_button.on_click(submit_question)

# Attach the handler to the input box for the return key
input_box.observe(submit_question)

# Display the widgets
display(status_box, book_number_box, chapter_number_box, top_k_box, input_box, submit_button, output_box, context_box)

HTML(value="<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-aw…

Output(layout=Layout(min_height='50px'))

IntText(value=10, description='Book Number:')

IntText(value=0, description='Chapter Number:')

IntText(value=25, description='Top K:')

Text(value='', continuous_update=False, description='Question:', placeholder='Type your question here...')

Button(description='Submit', icon='check', style=ButtonStyle(), tooltip='Click to submit your question')

Output(layout=Layout(min_height='200px'))

Output(layout=Layout(min_height='200px'))