In [7]:
#pip install --no-cache-dir jupyter langchain_openai langchain_community langchain langgraph faiss-cpu sentence-transformers ipywidgets transformers nltk scikit-learn matplotlib markdown langchain_chroma

import os
from langchain_openai import ChatOpenAI
import faiss
from langchain import hub, PromptTemplate
from langchain_core.documents import Document
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from sentence_transformers import SentenceTransformer
import numpy as np
import yaml
import ipywidgets as widgets
from IPython.display import display
import httpx
import torch
import pickle
from tqdm import tqdm
import markdown
import glob
import re
import chromadb

from chromadb import Documents, EmbeddingFunction, Embeddings

from story_sage import StorySage, StorySageGraph

In [20]:
# Load settings (including the OpenAI API key) from config.yml and expose the key
# through the environment variable that ChatOpenAI reads.
with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)
os.environ['OPENAI_API_KEY'] = config['OPENAI_API_KEY']

# Prefer Apple's Metal (MPS) backend when present; otherwise run on CPU.
device = torch.device('mps' if torch.backends.mps.is_available() else 'cpu')
print('MPS backend available' if device.type == 'mps' else 'MPS backend not available. Using CPU')

# SECURITY(review): verify=False disables TLS certificate verification for every
# OpenAI request, and those requests carry the API key. If this is only needed to get
# through a TLS-intercepting proxy, prefer trusting that proxy's CA instead.
llm = ChatOpenAI(model='gpt-4o-mini', http_client=httpx.Client(verify=False))

# Sentence-embedding model used to encode ad-hoc queries. Despite the name it is not
# a tokenizer; `.encode` returns embeddings. (Alternative model: 'multi-qa-mpnet-base-dot-v1'.)
tokenizer = SentenceTransformer('all-MiniLM-L6-v2').to(device)

# NOTE(review): pickle.load can execute arbitrary code; only load a trusted, locally built file.
with open('merged_characters.pkl', 'rb') as f:
    character_dict = pickle.load(f)
print('Loaded character dictionary')

class Embedder(EmbeddingFunction):
    """Chroma embedding function backed by a SentenceTransformer model.

    The model is placed on the notebook-level ``device`` selected earlier in this cell.
    """

    def __init__(self, model_name: str = 'all-MiniLM-L6-v2'):
        """Load ``model_name`` and move it to ``device``."""
        self.model = SentenceTransformer(model_name).to(device)

    def __call__(self, input: Documents) -> Embeddings:
        """Embed a batch of texts as plain Python lists.

        Chroma calls this to embed ``query_texts`` for a collection that was opened
        with this object as its ``embedding_function``.
        """
        return self.model.encode(input).tolist()

    def embed_documents(self, documents: Documents) -> Embeddings:
        """Embed texts one at a time, showing a progress bar.

        Each embedding is converted to a plain list, which matches ``__call__`` and the
        ``Embeddings`` return annotation; ``encode`` on its own would return numpy arrays.
        """
        return [
            self.model.encode(document).tolist()
            for document in tqdm(documents, desc='Embedding documents')
        ]
  
# A single embedder instance; Chroma uses it to embed `query_texts` against this collection.
embedder = Embedder()

# Open the on-disk Chroma database. get_collection (not get_or_create_collection) is used,
# so the "wheel_of_time" collection is expected to have been built beforehand.
chroma_client = chromadb.PersistentClient(path='./chroma_data')
vector_store = chroma_client.get_collection(name="wheel_of_time", embedding_function=embedder)

MPS backend available
Loaded character dictionary


In [31]:
# RAG prompt used by generate():
#   {question} - the reader's question, as typed into the widget.
#   {context}  - the retrieved excerpts, joined with blank lines.
# NOTE(review): the leading "HUMAN" line is part of the template text and is therefore
# sent verbatim to the model. It looks like a leftover role label; confirm it is intended.
prompt = PromptTemplate(
    input_variables=['question', 'context'],
    template="""
HUMAN

You are an assistant to help a reader keep track of people, places, and plot points in books.
The following pieces of retrieved context are excerpts from the books related to the reader's question. Use them to generate your response.

Guidelines for the response:
* If you don't know the answer, just say that you don't know. 
* If you're not sure about something, you can say that you're not sure.
* Take as much time as you need to answer the question.
* Use as many words as you need to answer the question completely, but don't provide any irrelevant information.
* Use bullet points to provide examples from the context that support your answer.

Question: {question} 
Context: {context} 
Answer:
    """
)

def retrieve_chunks(query, vector_store: chromadb.Collection, book_number, chapter_number, characters=[], top_k=10):
    
    book_chapter_filter = {
        '$or': [
            {'book_number': {'$lt': book_number}},
            {'$and': [
                {'book_number': book_number},
                {'chapter_number': {'$lt': chapter_number}}
            ]}
        ]
    }

    if characters:
        characters_filter = []
        for character in characters:
            characters_filter.append({f'character_{character}': True})
        if len(characters_filter) == 1:
            characters_filter = characters_filter[0]
        else:
            characters_filter = {'$or': characters_filter}
        query_filter = {
            '$and': [
                characters_filter,
                book_chapter_filter
            ]
        }
    else:
        query_filter = book_chapter_filter


    
    chunks = vector_store.query(
        query_texts=[query],
        n_results=top_k,
        include=['metadatas', 'documents'],
        where=query_filter
    )
    # D, I = index.search(np.array(query_embedding), 20)
    # relevant_chunks = [
    #     doc_collection[i] for i in I[0]
    #     if int(index_metadata[i]['book_number']) < book_number or
    #        (int(index_metadata[i]['book_number'] == book_number and index_metadata[i]['chapter_number'] < chapter_number))
    # ]
    # sort relevant_chunks by book_number and chapter_number in descending order
    #relevant_chunks = sorted(relevant_chunks, key=lambda x: (x.metadata['book_number'], x.metadata['chapter_number']), reverse=True)
    #print('retrieved chunks:', len(relevant_chunks))
    #return relevant_chunks[:top_k]
    return chunks

def retrieve_chunks_from_chroma(query, vector_store: chromadb.Collection, book_number, chapter_number, top_k=10):
    query_embedding = tokenizer.encode([query])
    results = vector_store.query(
        query_embedding=query_embedding,
        n_results=top_k,
        include='metadatas'
    )
    return results

class State(TypedDict):
    """Shared state passed between the LangGraph nodes (get_characters -> retrieve -> generate)."""

    question: str  # the reader's question
    # Raw Chroma query result written by retrieve(). generate() and show_results() read
    # its 'documents' and 'metadatas' entries, each indexed [0] for the single query.
    context: dict
    answer: str  # LLM answer text written by generate()
    # Reading position: retrieval is limited to content before (book_number, chapter_number).
    book_number: int
    chapter_number: int
    top_k: int  # maximum number of chunks to retrieve
    # Values from character_dict for characters named in the question.
    # NOTE(review): declared int; confirm against the contents of merged_characters.pkl.
    characters: List[int]

def get_characters(state: State):
    """Graph node: find the characters mentioned in the question.

    Each key of ``character_dict`` is matched case-insensitively against the question
    as a whole word, so a short name does not match inside a longer word (e.g. a name
    like "Min" inside "reminded"). Empty keys are ignored. Matches are mapped through
    ``character_dict`` and de-duplicated in first-seen order.

    Returns:
        A partial state update ``{'characters': [...]}``.
    """
    print('get characters in question')
    question_lower = state['question'].lower()
    found = {}  # dict keys: de-duplicated, deterministic insertion order
    for name, character in character_dict.items():
        name_lower = name.lower()
        # Cheap substring test first; the regex then rejects matches embedded in a
        # longer word. Lookarounds (not \b) also work for names that start or end
        # with punctuation.
        if name_lower and name_lower in question_lower and re.search(
            r'(?<!\w)' + re.escape(name_lower) + r'(?!\w)', question_lower
        ):
            found[character] = None
    characters_in_question = list(found)
    print(f'characters in question: {characters_in_question}')
    return {'characters': characters_in_question}

def retrieve(state: State):
    """Graph node: fetch context chunks for the question.

    Chunks are filtered by reading position and by the characters found by get_characters.
    """
    print('retrieve')
    print(state)
    docs = retrieve_chunks(
        state['question'],
        vector_store,
        state['book_number'],
        state['chapter_number'],
        characters=state['characters'],
        top_k=state['top_k'],
    )
    return {'context': docs}

def generate(state: State):
    """Graph node: put the retrieved excerpts into the prompt and ask the LLM."""
    # 'documents' holds one list per query text; retrieve_chunks sends exactly one.
    excerpts = state['context']['documents'][0]
    docs_content = '\n\n'.join(excerpts)
    print(f'begin generation with {len(docs_content)} characters of context')
    messages = prompt.invoke({'question': state['question'], 'context': docs_content})
    print('generated message:', messages)
    return {'answer': llm.invoke(messages).content}

# Linear pipeline: detect characters -> retrieve filtered chunks -> generate the answer.
graph_builder = StateGraph(State)
graph_builder.add_sequence([get_characters, retrieve, generate])
graph_builder.add_edge(START, 'get_characters')
graph = graph_builder.compile()

In [33]:
# Define the input and output widgets.
# Only non-default keyword arguments are passed (e.g. disabled=False is the default).
input_box = widgets.Text(
    description='Question:',
    placeholder='Type your question here...',
    value='',
    # Sync `value` only on Enter / focus loss, not on every keystroke.
    continuous_update=False,
)
submit_button = widgets.Button(
    description='Submit',
    tooltip='Click to submit your question',
    icon='check',
)

# Reading position and retrieval depth.
book_number_box = widgets.IntText(description='Book Number:', value=10)
chapter_number_box = widgets.IntText(description='Chapter Number:', value=0)
top_k_box = widgets.IntText(description='Top K:', value=25)

# Output areas: status line, rendered answer, retrieved context.
status_box = widgets.Output(layout={'min_height': '50px'})
output_box = widgets.Output(layout={'min_height': '200px'})
context_box = widgets.Output(layout={'min_height': '200px'})

# Font Awesome spinner shown while a query runs.
spinner = widgets.HTML(
    value="""<i class="fa fa-spinner fa-spin" style="font-size:24px; color:#2a9df4;"></i>""",
    placeholder='Loading...',
)

# Make sure Font Awesome is available for the spinner's icon classes.
display(widgets.HTML("<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-awesome.min.css'>"))

def wrap_answer(answer):
    """Render Markdown text to HTML and wrap it in a light-grey rounded panel."""
    body = markdown.markdown(answer)
    panel_style = 'background-color: #f9f9f9; padding: 10px; border-radius: 5px;'
    return f"<div style='{panel_style}'>{body}</div>"

def show_results(state):
    """Render the answer and every retrieved context chunk into their output widgets.

    ``state['context']`` is the Chroma query result. Its ``metadatas`` and ``documents``
    each hold one inner list per query text, and the graph sends a single query, so
    both are read from index 0.
    """
    with output_box:
        output_box.clear_output()
        display(widgets.HTML("<h3>Answer</h3>" + wrap_answer(state['answer'])))

    metadatas = state['context']['metadatas'][0]
    documents = state['context']['documents'][0]
    parts = ["<h3>Context</h3>"]
    # Iterate over the chunks of the (only) query, not over the outer per-query list:
    # the outer list has length 1, which would show just the first chunk.
    for meta, content in zip(metadatas, documents):
        parts.append(f"<p><strong>Book Number:</strong> {meta['book_number']} <strong>Chapter Number:</strong> {meta['chapter_number']}</p>")
        parts.append(f"<p>{content}</p>")
    with context_box:
        context_box.clear_output()
        display(widgets.HTML(wrap_answer("".join(parts))))

def send_query(state):
    """Run the graph on ``state``, copy answer and context back into it, and render them."""
    result = graph.invoke(state)
    state.update(answer=result['answer'], context=result['context'])
    show_results(state)
    

# Define the function to handle the button click
def submit_question(b):
    """Widget callback: build the initial graph state from the inputs and run the query.

    ``b`` is the clicked Button (via ``on_click``) or a traitlets change dict (via
    ``observe``). It is not used.
    """
    question = input_box.value.strip()
    if not question:
        # Nothing to ask; avoid an LLM call with an empty question.
        return
    # Resolve the fallback once so the status message matches what is actually retrieved.
    top_k = top_k_box.value or 10
    with status_box:
        status_box.clear_output()
        display(widgets.HTML(f"<h3>Retrieving top {top_k} relevant chunks...</h3>"))
        # Run inside output_box so node prints and any traceback appear in the UI
        # instead of being lost from the widget callback.
        with output_box:
            output_box.clear_output()
            display(spinner)

            state = State(
                question=question,
                book_number=book_number_box.value or 1,
                chapter_number=chapter_number_box.value or 0,
                top_k=top_k,
                context=[],
                answer='',
            )
            send_query(state)

# Attach the handler to the button. The clear() removes any previously registered
# handlers (it uses a private ipywidgets attribute). submit_button is created earlier
# in this same cell, so this is currently a no-op safeguard.
submit_button._click_handlers.callbacks.clear()
submit_button.on_click(submit_question)

# Submit when the question text is committed (Enter or focus loss, since the Text widget
# has continuous_update=False). Observe only the 'value' trait: without names=..., the
# handler would fire for every trait change on the widget.
input_box.observe(submit_question, names='value')

# Display the widgets
display(status_box, book_number_box, chapter_number_box, top_k_box, input_box, submit_button, output_box, context_box)

HTML(value="<link rel='stylesheet' href='https://cdnjs.cloudflare.com/ajax/libs/font-awesome/4.7.0/css/font-aw…

Output(layout=Layout(min_height='50px'))

IntText(value=10, description='Book Number:')

IntText(value=0, description='Chapter Number:')

IntText(value=25, description='Top K:')

Text(value='', continuous_update=False, description='Question:', placeholder='Type your question here...')

Button(description='Submit', icon='check', style=ButtonStyle(), tooltip='Click to submit your question')

Output(layout=Layout(min_height='200px'))

Output(layout=Layout(min_height='200px'))