In [None]:
import yaml

# Read runtime settings from config.yml so that the API key and local paths
# are not hard-coded in the notebook.
with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

# All three keys are required; a missing key raises KeyError here.
api_key = config['OPENAI_API_KEY']
chroma_path = config['CHROMA_PATH']
chroma_collection = config['CHROMA_COLLECTION']

# Load the parsed contents of series.yml into `series_list`.
# NOTE(review): no series_metadata_name -> series_id mapping is built from
# `series_list` in this cell, and the name is not read again in the visible
# cells (StorySage is given the file path instead) — confirm it is still needed.
with open('series.yml', 'r') as file:
    series_list = yaml.safe_load(file)

# Run queries through story_sage module

In [None]:
import logging
from story_sage.story_sage import StorySage

# Configure the 'story_sage' logger for notebook use.
#
# logging.getLogger() returns the same logger object on every call, and
# notebook cells get re-run. Anything this cell attached on a previous run is
# therefore removed first; otherwise each re-run adds one more console handler
# and every message is printed once per run of the cell.
CONSOLE_HANDLER_NAME = 'story_sage_notebook_console'

logger = logging.getLogger('story_sage')
logger.setLevel(logging.DEBUG)

# Drop the console handler left behind by an earlier run of this cell.
for stale_handler in [h for h in logger.handlers if h.get_name() == CONSOLE_HANDLER_NAME]:
    logger.removeHandler(stale_handler)

# Console handler: the logger accepts DEBUG, but only INFO and above is printed.
console_handler = logging.StreamHandler()
console_handler.set_name(CONSOLE_HANDLER_NAME)
console_handler.setLevel(logging.INFO)

# Create a formatter and set it for the handler
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(console_handler)


# Filter out logs from other modules
class StorySageFilter(logging.Filter):
    """Accept only records whose logger name starts with 'story_sage'."""

    def filter(self, record):
        return record.name.startswith('story_sage')


# The class object is recreated on every run of the cell, so a filter from a
# previous run is matched by class name rather than with isinstance().
for stale_filter in [f for f in logger.filters if type(f).__name__ == 'StorySageFilter']:
    logger.removeFilter(stale_filter)

# NOTE(review): a filter attached to a logger is only consulted for records
# logged directly on that logger, not for records propagated from child
# loggers — confirm this is the intended place for it.
logger.addFilter(StorySageFilter())

# Load the entity data from entities.json; it is handed to StorySage below.
# NOTE(review): the file is parsed with yaml.safe_load rather than the json
# module — presumably this relies on the JSON content also being valid YAML;
# confirm, or switch to json.load.
with open('entities.json', 'r') as file:
    entities = yaml.safe_load(file)

# Build the StorySage instance from the values read out of config.yml, the
# entity data loaded above, and the path to series.yml.
story_sage = StorySage(
    api_key=api_key,
    chroma_path=chroma_path,
    chroma_collection_name=chroma_collection,
    entities=entities,
    series_yml_path='series.yml',
    n_chunks=10
)


# Replace the instance's `logger` attribute with the notebook logger, so that
# whatever StorySage logs through that attribute uses the console handler
# configured in this cell.
story_sage.logger = logger

def invoke_story_sage(data: dict):
    """Validate a request dict and forward it to ``story_sage.invoke``.

    Args:
        data: Keyword arguments for ``story_sage.invoke``. Must contain the
            keys 'question', 'book_number', 'chapter_number' and 'series_id'.

    Returns:
        ``(result, context)`` as produced by ``story_sage.invoke`` on success,
        or ``({'error': <message>}, 400)`` when a required key is missing.

    Raises:
        Exception: Anything raised by ``story_sage.invoke`` propagates
            unchanged. The previous ``except ... raise e`` wrapper did the
            same; the ``return ..., 500`` after it could never run and has
            been removed.
    """
    required_keys = ['question', 'book_number', 'chapter_number', 'series_id']
    if not all(key in data for key in required_keys):
        return {'error': f'Missing parameter! Request must include {", ".join(required_keys)}'}, 400

    # Unpack explicitly so the return value is always a 2-tuple.
    result, context = story_sage.invoke(**data)
    return result, context
    
# Example request; the keys match the ones invoke_story_sage requires.
data = dict(
    question='Explain the interactions between Cenn and Rand',
    book_number=2,
    chapter_number=1,
    series_id=3,
)

# Disabled by default; change the guard to True to send the example request.
if False:
    response, context = invoke_story_sage(data)
    print(response)

# Configure and send queries to ChromaDB Directly

In [None]:
import chromadb
from story_sage.story_sage_embedder import StorySageEmbedder
from langchain.embeddings import SentenceTransformerEmbeddings
# Adapter that makes a SentenceTransformerEmbeddings object callable: calling
# an instance with a list of texts returns the parent's embed_documents()
# result for those texts.
# NOTE(review): the `input` parameter name on __call__ presumably matches the
# signature Chroma expects from an embedding function — confirm before
# renaming it.
class EmbeddingAdapter(SentenceTransformerEmbeddings):

    def __init__(self, *args, **kwargs):
        # Pure pass-through to the parent constructor.
        super().__init__(*args, **kwargs)

    def _embed_documents(self, texts):
        """Delegate to the parent class's public ``embed_documents``."""
        return super().embed_documents(texts)  

    def __call__(self, input):
        """Embed ``input`` by forwarding it to ``_embed_documents``."""
        return self._embed_documents(input)  

# NOTE(review): this binds the EmbeddingAdapter class itself, not an instance,
# and `embedder` is not referenced again in the visible cells — confirm intent.
embedder = EmbeddingAdapter
# Open the Chroma database stored at `chroma_path` and fetch the existing
# collection by name.
# NOTE: a later cell rebinds `client` to the OpenAI client, so anything that
# needs the Chroma client afterwards must re-run this cell first.
client = chromadb.PersistentClient(path=chroma_path)
vector_store = client.get_collection(name=chroma_collection)

In [None]:
# Metadata clauses shared by both filter variants: anything before book 1, or
# book 1 before chapter 25.
before_book_one = {'book_number': {'$lt': 1}}
early_in_book_one = {'$and': [
    {'book_number': 1},
    {'chapter_number': {'$lt': 25}},
]}

# Variant 1: the position clause AND the 'a_3_12' flag. It is immediately
# superseded by the reassignment below.
filter_dict = {'$and': [
    {'$or': [before_book_one, early_in_book_one]},
    {'a_3_12': True},
]}

# Variant 2 (the one in effect): the position clause only.
filter_dict = {'$or': [before_book_one, early_in_book_one]}
#client.delete_collection('wot_retriever_test')
if False:
    vector_store.query(
        query_texts=['trolloc'],
        n_results=5,
        where=filter_dict,
        include=['metadatas', 'documents'],
    )

# New entity extraction

In [None]:
import glob
import re
import pickle
import os

# Build chunks[book_number][chapter_number] from the pickled chunk files.
# Only files named "<book>_<chapter>.pkl" are loaded; others are ignored.
path_to_chunks = './chunks/wheel_of_time/semantic_chunks'
chunks = {}
chunk_name_pattern = re.compile(r'(\d+)_(\d+)\.pkl')
for filepath in glob.glob(f'{path_to_chunks}/*.pkl'):
    match = chunk_name_pattern.match(os.path.basename(filepath))
    if not match:
        continue
    book_number, chapter_number = (int(part) for part in match.groups())
    with open(filepath, 'rb') as f:
        chapters_for_book = chunks.setdefault(book_number, {})
        # SECURITY: pickle.load can execute arbitrary code; only load chunk
        # files that were produced locally by a trusted run.
        chapters_for_book[chapter_number] = pickle.load(f)

## OpenAI based entity extraction

TODO: Change this to only extract "people" and "other"

In [None]:
from pydantic import BaseModel
from openai import OpenAI
import httpx

# Create a custom HTTPX client with SSL verification disabled.
# SECURITY NOTE(review): verify=False turns off TLS certificate checking for
# every request made through this client, including the ones that carry the
# API key. If this is a workaround for an intercepting proxy, prefer pointing
# `verify` at the proxy's CA bundle — confirm whether it is still required.
req_client = httpx.Client(verify=False)

# NOTE: this rebinds `client`, which an earlier cell bound to the Chroma
# client; from here on `client` is the OpenAI client.
client = OpenAI(api_key=api_key, http_client=req_client)

# Not read anywhere in the visible cells.
full_response = None

# Schema of one extraction response: four lists of entity names. It is passed
# as `response_format` in extract_named_entities.
# (Documented with comments rather than a class docstring so that the schema
# generated from this model is left unchanged.)
class StorySageEntities(BaseModel):
  people: list[str]
  places: list[str]
  groups: list[str]
  animals: list[str]

def extract_named_entities(text: str):
    """Ask the model to extract named entities from ``text``.

    Sends ``text`` as the user message to ``client.beta.chat.completions.parse``
    with model "gpt-4o-mini" and ``StorySageEntities`` as the response format.

    Args:
        text: The passage to analyse.

    Returns:
        A tuple ``(extracted_entity, usage_information)``: the ``parsed``
        attribute of the first choice's message, and the ``usage`` attribute of
        the completion.

    NOTE(review): the system prompt asks for "entities and a summary", but
    StorySageEntities has no summary field — confirm which one is intended.
    NOTE(review): ``parsed`` is returned without a None check — confirm whether
    the SDK can return None here (e.g. on a refusal).
    """
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": """
                You are a highly advanced natural language processing agent that 
                is optimized to do named entity recognition (NER). Your goal is to
                extract entities and a summary from text provided to you.
                
                For example, if the text is:
                    Standing with the other Whitecloaks, Perrin saw the Lugard Road near the Manetherendrelle and the border of Murandy.
                    If dogs had been able to make footprints on stone, he would have said the tracks were the prints of a pack of large hounds.
             
                Extract:
                    People: Perrin
                    Places: Lugard Road, Manetherendrelle, Murandy
                    Groups: Whitecloaks, pack
                    Animals: dogs
                """},
            {"role": "user", "content": text},
        ],
        response_format=StorySageEntities
    )

    # First (only requested) choice; `parsed` is the structured result.
    extracted_entity = completion.choices[0].message.parsed
    usage_information = completion.usage

    return extracted_entity, usage_information


#entities = extract_named_entities(chunk_to_extract)
#print(entities)

### Run chunks through OpenAI for extraction

In [None]:
import time

# Module-level placeholders. extract_entities_from_chunks assigns its own
# locals with these names, so calling it does not modify these two globals.
num_chapters = 0
result = []

def extract_entities_from_chunks(chunks, book_number=1, len_cap=400000, wait_seconds=30):
    """Run entity extraction over every chapter of one book.

    Each chapter's chunks are joined with newlines and passed to
    ``extract_named_entities``. A running character count throttles the calls:
    when adding the next chapter would push the count above ``len_cap``, the
    function sleeps ``wait_seconds`` and resets the count.

    Args:
        chunks: Mapping of book number to that book's chapters. The chapters
            may be a dict keyed by chapter number or a sequence; each chapter
            is an iterable of strings.
        book_number: Key of the book to process. Defaults to 1, which is the
            book the function was previously hard-coded to.
        len_cap: Character budget between pauses.
        wait_seconds: Length of each pause, in seconds.

    Returns:
        A list with one ``extract_named_entities`` result per chapter, in
        ascending chapter order.

    Raises:
        KeyError: If ``book_number`` is not present in ``chunks``.
    """
    book_chunks = chunks[book_number]

    # Walk the chapters that actually exist, in order, instead of assuming the
    # keys are exactly 0..n-1.
    if isinstance(book_chunks, dict):
        chapter_keys = sorted(book_chunks)
    else:
        chapter_keys = range(len(book_chunks))

    result = []
    counter = 0
    for position, chapter_key in enumerate(chapter_keys):
        chapter_text = '\n'.join(book_chunks[chapter_key])
        chapter_len = len(chapter_text)
        if counter + chapter_len > len_cap:
            print(f'Waiting for {wait_seconds} seconds to avoid exceeding the character limit. Current chapter: {position + 1}. Current length: {counter}')
            time.sleep(wait_seconds)
            counter = 0
        result.append(extract_named_entities(chapter_text))
        counter += chapter_len

    print(f'Finished extracting from {len(result)} chapters')
    if result:
        # Print one sample result. result[-5] raised IndexError for books with
        # fewer than five chapters, so fall back to the last entry.
        print(result[-5] if len(result) >= 5 else result[-1])
    return result

### Dump the result to a JSON file so the extraction doesn't have to be re-run every time

In [None]:
import json

# Disabled by default; change the guard to True to re-run the OpenAI
# extraction and overwrite the cached JSON file.
if False:
    result = extract_entities_from_chunks(chunks)
    with open('01_the_eye_of_the_world.json', 'w') as json_file:
        # default=... serialises objects json cannot handle natively via their
        # attribute dict.
        json.dump(result, json_file, default=lambda o: o.__dict__, indent=4)

# Always load the cached result. This raises FileNotFoundError if the file has
# never been written.
with open('01_the_eye_of_the_world.json', 'r') as json_file:
    result = json.load(json_file)

# Process Extracted Entities

In [None]:
# `result` holds one entry per chapter, so the chapter count is the length of
# the list itself. (len(result[0]) measured a single chapter's
# [entities, usage] pair instead, which is always 2.)
num_chapters = len(result)


def _empty_entity_maps():
    """Return a fresh set of empty <kind>_by_id / <kind>_by_name maps."""
    return {
        f'{kind}_by_{index}': {}
        for kind in ('people', 'places', 'groups', 'animals')
        for index in ('id', 'name')
    }


# Skeleton of the entity file: one series containing one book. The chapter
# list and the lookup maps start out empty.
entities_dict = {
    'series': {
        'wheel_of_time': {
            'series_metadata_name': 'wheel_of_time',
            'series_id': 3,
            'series_name': 'The Wheel of Time',
            'series_entities': _empty_entity_maps(),
            'books': [
                {
                    'book_name': 'The Eye of the World',
                    'book_number': 1,
                    'chapter_count': num_chapters,
                    'chapters': [],
                    'book_entities': _empty_entity_maps(),
                }
            ]
        }
    }
}

def _normalize_names(names):
    """Straighten curly apostrophes, lower-case, and de-duplicate ``names``."""
    return list({
        name.replace('’', "'").replace('‘', "'").lower()
        for name in names
    })


# Append one normalised record per chapter to the first book's chapter list.
book_chapters = entities_dict['series']['wheel_of_time']['books'][0]['chapters']
for i, chapter_entities in enumerate(result):
    # Element 0 of each entry is the extracted entity dict.
    extracted = chapter_entities[0]
    raw_names = {kind: extracted[kind] for kind in ('people', 'places', 'groups', 'animals')}
    entities_obj = {'chapter': i}
    for kind, names in raw_names.items():
        entities_obj[kind] = _normalize_names(names)
    book_chapters.append(entities_obj)

# Persist the per-chapter entities, overwriting any existing entities.json.
with open('entities.json', 'w') as json_file:
    json.dump(entities_dict, json_file, indent=4)

# Read the file straight back, so `entities_dict` from here on is exactly what
# a later session would load from disk.
with open('entities.json', 'r') as json_file:
    entities_dict = json.load(json_file)

## Use OpenAI to cluster similar characters

This step is an opportunity to make the grouping MUCH smarter — for example, by using semantic understanding of the text to recognize that characters who appear together in a scene are separate individuals.

Maybe do this in multiple steps: cluster the names with OpenAI first, then check each cluster against the text to see whether the names sound like different characters?

In [None]:
# Gather every distinct name of each entity kind across all chapters of all
# books in the series.
all_people, all_places, all_groups, all_animals = set(), set(), set(), set()
names_by_kind = {
    'people': all_people,
    'places': all_places,
    'groups': all_groups,
    'animals': all_animals,
}

for book in entities_dict['series']['wheel_of_time']['books']:
    for chapter in book['chapters']:
        for kind, collected in names_by_kind.items():
            collected.update(chapter[kind])

# Schema of the grouping response: a list of groups, each a list of names.
# It is passed as `response_format` in group_similar_names.
# (Documented with comments rather than a class docstring so that the schema
# generated from this model is left unchanged.)
class GroupedEntities(BaseModel):
    all_people: list[list[str]]

def group_similar_names(names_to_group):
    """Ask the model to group names that refer to the same person.

    Args:
        names_to_group: Iterable of name strings (typically a set).

    Returns:
        The ``parsed`` attribute of the first choice's message, produced with
        ``GroupedEntities`` as the response format.

    The names are sorted before being joined. Iteration order of a set of
    strings can differ between interpreter sessions, so joining the set
    directly produced a differently ordered prompt on each kernel restart.

    NOTE(review): the prompt asks the model to keep every input name, but the
    response is not checked for dropped names here.
    """
    text = ', '.join(sorted(names_to_group))
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": """
                You are a highly advanced natural language processing agent that 
                is optimized to do named entity recognition (NER). Your goal is to
                group together names that represent the same person from the text provided to you.
             
                Names usually follow a standard pattern. Haral Luhhan and Alsbet Luhhan are likely to be different people, but Haral Luhhan and Master Luhhan are likely to be the same person.

                Make sure all names in the input are present in the output.   
             
                For example:
                    Input: Bran, Mat, Bran al'Vere, Haral Luhhan, Breyan, Matrim Cauthon, Alsbet Luhhan, Master al'Vere, Mat Cauthon
                    Output: [['Bran', "Bran al'Vere", "Master al'Vere"], ['Mat', 'Matrim Cauthon', 'Mat Cauthon'], ['Breyan'], ['Haral Luhhan'], ['Alsbet Luhhan']]
                """},
            {"role": "user", "content": text},
        ],
        response_format=GroupedEntities
    )

    return completion.choices[0].message.parsed

In [None]:
# Group the collected people names with the model.
# NOTE(review): `all_people` is rebound below from the flat collection of
# names to the grouped list of lists, so running this cell a second time
# passes the grouped lists back into group_similar_names. Re-run the cell
# that builds `all_people` first.
grouped_entities = group_similar_names(all_people)
all_people = grouped_entities.all_people

In [None]:
# Display the grouped names for manual inspection.
all_people

# Process Extracted Entities into Series-level Metadata

In [None]:
# Create names->id and id->name maps
series_entities = entities_dict['series']['wheel_of_time']['series_entities']
series_id = 3


def _build_id_maps(prefix, entries, series_id):
    """Assign "<prefix>_<series_id>_<n>" ids to ``entries`` in the given order.

    Each entry is either a single name (str) or a list of alias names.

    Returns:
        ``(by_id, by_name)``: ``by_id`` maps each id to its entry unchanged;
        ``by_name`` maps every name (or alias) to the id of its entry.
    """
    by_id, by_name = {}, {}
    for idx, entry in enumerate(entries):
        entity_id = f"{prefix}_{series_id}_{idx}"
        by_id[entity_id] = entry
        aliases = [entry] if isinstance(entry, str) else entry
        for alias in aliases:
            by_name[alias] = entity_id
    return by_id, by_name


# People are already grouped into alias lists; keep the order they arrived in.
series_entities['people_by_id'], series_entities['people_by_name'] = _build_id_maps(
    'p', all_people, series_id)

# Places, groups and animals are sets. They are sorted before numbering so the
# same input always yields the same ids; enumerating a set directly gives an
# order that can change from one kernel session to the next.
series_entities['places_by_id'], series_entities['places_by_name'] = _build_id_maps(
    'pl', sorted(all_places), series_id)
series_entities['groups_by_id'], series_entities['groups_by_name'] = _build_id_maps(
    'g', sorted(all_groups), series_id)
series_entities['animals_by_id'], series_entities['animals_by_name'] = _build_id_maps(
    'a', sorted(all_animals), series_id)

# Persist the series-level maps alongside the per-chapter entities.
with open('entities.json', 'w') as json_file:
    json.dump(entities_dict, json_file, indent=4)