In [2]:
import os
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
import faiss
from langchain_community.vectorstores import FAISS, InMemoryVectorStore
from langchain import hub
from langchain_core.documents import Document
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langgraph.graph import START, StateGraph
from typing_extensions import List, TypedDict
from sentence_transformers import SentenceTransformer
import numpy as np
import yaml

with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

os.environ['OPENAI_API_KEY'] = config['OPENAI_API_KEY']

llm = ChatOpenAI(model='gpt-4o-mini')
tokenizer = SentenceTransformer('all-MiniLM-L6-v2')

  from tqdm.autonotebook import tqdm, trange


In [12]:
import glob
from collections import OrderedDict

def read_text_file(file_path):
    text_dict = OrderedDict()
    for file in glob.glob(file_path):
        fname = os.path.basename(file)
        print(fname)
        with open(file, 'r') as f:
            text_dict[fname] = f.read()
    return text_dict

file_path = './books/*.txt'
text_dict = read_text_file(file_path)
all_text = "\n".join(text_dict.values())
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=200)
all_splits = text_splitter.split_text(all_text)

01_the_eye_of_the_world.txt
06_lord_of_chaos.txt
07_crown_of_swords.txt
09_winters_heart.txt
03_the_dragon_reborn.txt
05_fires_of_heaven.txt
04_shadow_rising.txt
02_the_great_hunt.txt


In [14]:
embeddings = tokenizer.encode(all_splits)
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

In [15]:
prompt = hub.pull('rlm/rag-prompt')

def retrieve_chunks(query, index, top_k=5):
    query_embedding = tokenizer.encode([query])
    D, I = index.search(np.array(query_embedding), top_k)
    relevant_chunks = [all_splits[i] for i in I[0]]
    return relevant_chunks

class State(TypedDict):
    question: str
    context: List[str]
    answer: str

def retrieve(state: State):
    retrieved_docs = retrieve_chunks(state['question'], index)
    return {'context': retrieved_docs}

def generate(state: State):
    docs_content = "\n\n".join(text for text in state['context'])
    messages = prompt.invoke({'question': state['question'], 'context': docs_content})
    response = llm.invoke(messages)
    return {'answer': response.content}

graph_builder = StateGraph(State).add_sequence([retrieve, generate])
graph_builder.add_edge(START, 'retrieve')
graph = graph_builder.compile()



In [16]:
import ipywidgets as widgets
from IPython.display import display

# Define the input and output widgets
input_box = widgets.Text(
    value='',
    placeholder='Type your question here...',
    description='Question:',
    continuous_update=False,
    disabled=False
)

submit_button = widgets.Button(
    description='Submit',
    disabled=False,
    button_style='',
    tooltip='Click to submit your question',
    icon='check'
)

output_box = widgets.Output()

# Define the function to handle the button click
def submit_question(b):
    with output_box:
        output_box.clear_output()
        state = State(question=input_box.value, context=[], answer='')
        result = graph.invoke(state)
        wrapped_answer = f"<div style='background-color: #f9f9f9; padding: 10px; border-radius: 5px;'>{result['answer']}</div>"
        display(widgets.HTML(wrapped_answer))

# Attach the handler to the button
submit_button.on_click(submit_question)

# Attach the handler to the input box for the return key
input_box.observe(submit_question)

# Display the widgets
display(input_box, submit_button, output_box)

Text(value='', continuous_update=False, description='Question:', placeholder='Type your question here...')

Button(description='Submit', icon='check', style=ButtonStyle(), tooltip='Click to submit your question')

Output()