In [1]:
import yaml

# Read notebook settings (OpenAI key and ChromaDB location) from config.yml so
# that no credentials are hardcoded in the notebook itself.
with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

# A missing key raises KeyError here, before any later cell runs.
api_key = config['OPENAI_API_KEY']
chroma_path = config['CHROMA_PATH']
chroma_collection = config['CHROMA_COLLECTION']

# Load series.yml.
# NOTE(review): an earlier version of this comment promised a mapping from
# series_metadata_name to series_id, but no such mapping is built here and
# series_list is not referenced by any later cell in this notebook.
with open('series.yml', 'r') as file:
    series_list = yaml.safe_load(file)

# Run queries through story_sage module

In [2]:
import logging
from story_sage.story_sage import StorySage

# Configure the logger

class StorySageFilter(logging.Filter):
    """Let through only records whose logger name starts with ``story_sage``."""

    def filter(self, record):
        return record.name.startswith('story_sage')


logger = logging.getLogger('story_sage')
logger.setLevel(logging.DEBUG)

# logging.getLogger() hands back the same logger object for the whole kernel
# session, so re-running this cell would otherwise stack one more handler and
# filter each time and every message would be printed once per run. Remove
# whatever a previous run of this cell installed before adding fresh ones.
for _stale_handler in [h for h in logger.handlers
                       if getattr(h, '_story_sage_notebook', False)]:
    logger.removeHandler(_stale_handler)
# The class object is re-created on every run, so match on its name rather
# than with isinstance().
for _stale_filter in [f for f in logger.filters
                      if type(f).__name__ == 'StorySageFilter']:
    logger.removeFilter(_stale_filter)

# Create a console handler; the logger itself is at DEBUG but the console only
# shows INFO and above.
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)
# Marker used above to recognise this handler on a later re-run of the cell.
console_handler._story_sage_notebook = True

# Create a formatter and set it for the handler
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(console_handler)

# Filter out logs from other modules
logger.addFilter(StorySageFilter())

# Load the entity dictionary from entities.json.
# The file is parsed with yaml.safe_load rather than json.load.
# NOTE(review): this relies on the JSON in the file also being valid YAML — confirm,
# or switch to json.load for consistency with the cells that write this file.
with open('entities.json', 'r') as file:
    entities = yaml.safe_load(file)

# Build the StorySage instance from the settings read from config.yml.
# NOTE(review): n_chunks=10 is presumably the number of retrieved chunks per
# question — confirm against the StorySage constructor.
story_sage = StorySage(
    api_key=api_key,
    chroma_path=chroma_path,
    chroma_collection_name=chroma_collection,
    entities=entities,
    series_yml_path='series.yml',
    n_chunks=10
)


# Replace the instance's logger attribute with the logger configured in this cell
# (this assigns the attribute; it does not add a handler).
story_sage.logger = logger

def invoke_story_sage(data: dict):
    """Validate a request payload and forward it to ``story_sage.invoke``.

    Args:
        data: Request parameters. Must contain ``question``, ``book_number``,
            ``chapter_number`` and ``series_id``. Every key in ``data`` is
            passed to ``story_sage.invoke`` as a keyword argument.

    Returns:
        ``(result, context)`` from ``story_sage.invoke`` on success, or
        ``({'error': <message>}, 400)`` when a required key is missing.

    Raises:
        Exception: Anything raised by ``story_sage.invoke`` propagates to the
            caller unchanged.
    """
    required_keys = ['question', 'book_number', 'chapter_number', 'series_id']
    if not all(key in data for key in required_keys):
        return {'error': f'Missing parameter! Request must include {", ".join(required_keys)}'}, 400

    # Failures are deliberately not caught: the previous try/except re-raised
    # immediately, so its "Internal server error" return could never run.
    # Letting the exception propagate keeps the full traceback in the notebook.
    result, context = story_sage.invoke(**data)
    return result, context
    
# Example request payload; it carries exactly the four keys that
# invoke_story_sage requires.
data = {
    'question': 'Explain the interactions between Cenn and Rand',
    'book_number': 2,
    'chapter_number': 1,
    'series_id': 3
}

In [3]:
# invoke_story_sage returns a 2-tuple on both the success path (result, context)
# and the missing-parameter path (error dict, 400), so this unpacking works for either.
response, context = invoke_story_sage(data)
print(response)

2024-12-17 12:26:05,667 - story_sage - INFO - Processing question: Explain the interactions between Cenn and Rand


The interactions between Cenn and Rand primarily revolve around Cenn's role in the village and his often grumpy demeanor, which contrasts with Rand's more youthful perspective. Here are some key points from the excerpts that illustrate their interactions:

- **Cenn's Skepticism**: Cenn expresses skepticism about the village festivities and the expenses related to them, indicating a serious side that contrasts with the more carefree attitudes of the younger characters like Rand and Mat.
  - *“I still say it’s a foolish waste of money. And those fireworks you all insisted on sending off for.”* (Book 1, Chapter 1)

- **Cenn's Authority**: Cenn is shown to be a member of the Village Council, and he often asserts his views strongly. This authority can lead to tension, especially when others, including Rand's father, challenge him.
  - *“Act your age,” Bran added. “And for once remember you’re a member of the Council.”* (Book 1, Chapter 4)

- **Rand's Perspective**: Rand recognizes Cenn as a

# Configure and send queries to ChromaDB Directly

In [23]:
import chromadb
from langchain.embeddings import SentenceTransformerEmbeddings
# Thin wrapper around SentenceTransformerEmbeddings that makes instances callable.
# NOTE(review): presumably __call__(input) is meant to match the embedding-function
# interface chromadb expects — confirm against the chromadb version in use.
class EmbeddingAdapter(SentenceTransformerEmbeddings):

    def __init__(self, *args, **kwargs):
        # Forwards all arguments unchanged to the parent constructor.
        super().__init__(*args, **kwargs)

    def _embed_documents(self, texts):
        # Delegates to the parent class's embed_documents.
        return super().embed_documents(texts)  

    def __call__(self, input):
        # Calling the instance embeds `input` through _embed_documents.
        return self._embed_documents(input)  

# NOTE(review): this binds the EmbeddingAdapter class itself, not an instance, and
# `embedder` is not passed to anything in this notebook — confirm the intent.
embedder = EmbeddingAdapter
# Open the on-disk ChromaDB store and fetch the existing collection by name.
# NOTE: `client` is rebound to an OpenAI client in the entity-extraction cell
# further down, so this cell must be re-run before using `client` for ChromaDB again.
client = chromadb.PersistentClient(path=chroma_path)
vector_store = client.get_collection(name=chroma_collection)

In [None]:
# Metadata filter: documents from before book 1, or from book 1 before
# chapter 25, that ALSO carry the metadata flag '3_e_8' set to True.
# NOTE(review): '3_e_8' has the "<series>_e_<n>" shape that create_result_dict
# produces for entity ids — presumably it tags chunks mentioning that entity; confirm.
filter_dict = {'$and': [
                {'$or': [
                    {'book_number': {'$lt': 1}},
                    {'$and': [
                        {'book_number': 1}, 
                        {'chapter_number': {'$lt': 25}}
                    ]}
                ]}, 
                {'3_e_8': True}
               ]}

# Alternative filter without the entity flag, kept for manual experiments.
# filter_dict = {'$or': [
#                     {'book_number': {'$lt': 1}},
#                     {'$and': [
#                         {'book_number': 1}, 
#                         {'chapter_number': {'$lt': 25}}
#                     ]}
#                 ]}
#client.delete_collection('wot_retriever_test')
# Manual on/off switch for the query. NOTE: if it is set to False, the print
# below shows whatever `result` an earlier cell left behind (or raises NameError
# on a fresh kernel).
if True:
    result = vector_store.query(query_texts=['trolloc'],
                                n_results=5,
                                where=filter_dict,
                                include=['metadatas','documents'])

print(result)

# New entity extraction

In [6]:
import glob
import re
import pickle
import os

# Chapter pickles are expected to be named "<book>_<chapter>.pkl"; anything else
# in the directory is ignored.
# NOTE: pickle.load can execute arbitrary code; only point this at trusted files.
path_to_chunks = './chunks/wheel_of_time/semantic_chunks'
chunks = {}
for filepath in glob.glob(f'{path_to_chunks}/*.pkl'):
    match = re.match(r'(\d+)_(\d+)\.pkl', os.path.basename(filepath))
    if match is None:
        continue
    book_number, chapter_number = (int(group) for group in match.groups())
    with open(filepath, 'rb') as f:
        # chunks maps book_number -> chapter_number -> unpickled chapter payload
        chapters_of_book = chunks.setdefault(book_number, {})
        chapters_of_book[chapter_number] = pickle.load(f)

## OpenAI based entity extraction

TODO: Change this to only extract "people" and "other"

In [7]:
from pydantic import BaseModel
from openai import OpenAI
import httpx

# Create a custom HTTPX client with SSL verification disabled
# NOTE(review): verify=False turns off TLS certificate checks for every request sent
# through this client, including the requests that carry api_key. If this is a
# workaround for a proxy or a custom CA, prefer pointing `verify` at the CA bundle.
req_client = httpx.Client(verify=False)

# NOTE: this rebinds `client`, which the ChromaDB cell above used for its
# PersistentClient; from here on `client` is the OpenAI client.
client = OpenAI(api_key=api_key, http_client=req_client)

# NOTE(review): full_response is not referenced by any other cell in this notebook.
full_response = None

# Response schema passed as `response_format` to the OpenAI parse call in
# extract_named_entities: four lists of entity names, one per category.
# (Kept as comments on purpose: a class docstring on a pydantic model would be
# included in the generated JSON schema.)
class StorySageEntities(BaseModel):
  people: list[str]
  places: list[str]
  groups: list[str]
  animals: list[str]

def extract_named_entities(text):
    """Extract named entities from ``text`` with an OpenAI structured-output call.

    Sends a fixed system prompt plus ``text`` (as the user message) to the
    ``gpt-4o-mini`` model through the module-level ``client`` and asks for the
    reply to be parsed as ``StorySageEntities``.

    Args:
        text: The passage to analyse.

    Returns:
        A tuple ``(extracted_entity, usage_information)``: the parsed message of
        the first choice and the ``usage`` attribute of the completion.

    NOTE(review): the system prompt asks for "entities and a summary", but
    StorySageEntities has no summary field — confirm whether that is intended.
    """
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": """
                You are a highly advanced natural language processing agent that 
                is optimized to do named entity recognition (NER). Your goal is to
                extract entities and a summary from text provided to you.
                
                For example, if the text is:
                    Standing with the other Whitecloaks, Perrin saw the Lugard Road near the Manetherendrelle and the border of Murandy.
                    If dogs had been able to make footprints on stone, he would have said the tracks were the prints of a pack of large hounds.
             
                Extract:
                    People: Perrin
                    Places: Lugard Road, Manetherendrelle, Murandy
                    Groups: Whitecloaks, pack
                    Animals: dogs
                """},
            {"role": "user", "content": text},
        ],
        response_format=StorySageEntities
    )

    # Only the first choice is used.
    extracted_entity = completion.choices[0].message.parsed
    usage_information = completion.usage

    return extracted_entity, usage_information


#entities = extract_named_entities(chunk_to_extract)
#print(entities)

### Run chunks through OpenAI for extraction

In [8]:
import time

def extract_entities_from_chunks(chunks, book_number=1, extractor=None,
                                 len_cap=400000, wait_seconds=30):
    """Run entity extraction over every chapter of one book.

    Args:
        chunks: Nested mapping ``{book_number: {chapter_number: [chunk_text, ...]}}``.
        book_number: Key of the book in ``chunks`` to process. Defaults to 1.
        extractor: Callable applied to each chapter's text (its chunks joined
            with newlines). Defaults to ``extract_named_entities``, looked up
            when this function is called.
        len_cap: Number of characters that may be submitted before pausing.
        wait_seconds: Length of that pause, in seconds.

    Returns:
        A list with one ``extractor`` result per chapter, in ascending order of
        chapter number.

    Raises:
        KeyError: If ``book_number`` is not a key of ``chunks``.
    """
    if extractor is None:
        extractor = extract_named_entities

    book_chunks = chunks[book_number]
    result = []
    counter = 0  # characters submitted since the last pause

    # Walk the chapter keys that actually exist instead of assuming they are
    # exactly 0..n-1; a 1-based or gappy numbering would otherwise raise KeyError.
    for position, chapter_number in enumerate(sorted(book_chunks)):
        chapter_text = '\n'.join(book_chunks[chapter_number])
        chapter_len = len(chapter_text)
        if counter + chapter_len > len_cap:
            print(f'Waiting for {wait_seconds} seconds to avoid exceeding the character limit. Current chapter: {position + 1}. Current length: {counter}')
            time.sleep(wait_seconds)
            counter = 0
        result.append(extractor(chapter_text))
        counter += chapter_len

    print(f'Finished extracting from {len(result)} chapters')
    if result:
        # Show one sample result. Indexing [-5] unconditionally raised
        # IndexError for books with fewer than five chapters.
        print(result[-5] if len(result) >= 5 else result[-1])
    return result

### Dump result into a json so I don't have to run this every time

In [9]:
import json

# Flip to True to re-run the (slow, paid) extraction and refresh the cache file.
if False:
    result = extract_entities_from_chunks(chunks)
    with open('01_the_eye_of_the_world.json', 'w') as json_file:
        # default=...: objects json cannot serialise natively are written as their __dict__.
        json.dump(result, json_file, default=lambda o: o.__dict__, indent=4)

# Always continue from the cached file. The tuples returned by
# extract_named_entities come back from JSON as two-item lists, so each element
# of `result` is [entities, usage].
with open('01_the_eye_of_the_world.json', 'r') as json_file:
    result = json.load(json_file)

# Process Extracted Entities

In [10]:
# `result` holds one item per chapter, so the chapter count is the length of the
# list itself. len(result[0]) measured a single chapter's [entities, usage] pair
# and therefore always produced 2.
num_chapters = len(result)

# Skeleton of entities.json for series 3.
# NOTE(review): the placeholder keys below are spelled 'entities_by_*' while
# create_result_dict emits 'entity_by_*'. The placeholder is replaced wholesale
# in the last cell of the notebook — confirm which spelling consumers expect.
entities_dict = {
    'series': {
        3: {
            'series_metadata_name': 'wheel_of_time',
            'series_id': 3,
            'series_name': 'The Wheel of Time',
            'series_entities': {
                'people_by_id': {},
                'people_by_name': {},
                'entities_by_id': {},
                'entities_by_name': {}
            },
            'books': [
                {
                    'book_name': 'The Eye of the World',
                    'book_number': 1,
                    'chapter_count': num_chapters,
                    'chapters': []
                }
            ]
        }
    }
}

def collect_unique_values(result):
    """Collect normalised, de-duplicated entity names across all chapters.

    Names are lower-cased, stripped of every character that is neither a letter
    nor whitespace, and trimmed. De-duplication happens AFTER normalisation, so
    variants such as ``'Rand'`` and ``'rand.'`` collapse into a single entry.
    Names that normalise to an empty string are dropped.

    Args:
        result: Iterable of per-chapter extraction results. Element ``[0]`` of
            each item is a mapping from category name (e.g. ``'people'``,
            ``'places'``) to a list of name strings; other elements are ignored.

    Returns:
        A tuple ``(people, other_entities)`` of two alphabetically sorted lists:
        names from the ``'people'`` category, and names from every other category.
    """
    def _normalise(name):
        # Lower-case, keep letters and whitespace only, trim the ends.
        kept = ''.join(c for c in name.lower() if c.isalpha() or c.isspace())
        return kept.strip()

    people_names = set()
    other_names = set()

    for chapter in result:
        chapter_entities = chapter[0]
        for category, names in chapter_entities.items():
            bucket = people_names if category == 'people' else other_names
            bucket.update(_normalise(name) for name in (names or []))

    # Purely punctuation/numeric names normalise to '' and carry no information.
    people_names.discard('')
    other_names.discard('')

    # Sorting makes the output (and the prompts built from it) reproducible.
    return sorted(people_names), sorted(other_names)

series_people_list, series_entities_list = collect_unique_values(result)

# Write the skeleton to disk.
# NOTE: while this switch is True, every run of the cell overwrites entities.json,
# including any series_entities saved by a previous run of the final cell.
if True:
    with open('entities.json', 'w') as json_file:
        json.dump(entities_dict, json_file, indent=4)

# Read it back so the in-memory dict matches what is on disk. JSON object keys
# are always strings, so the integer series key 3 comes back as '3'.
with open('entities.json', 'r') as json_file:
    entities_dict = json.load(json_file)

## Use OpenAI to cluster similar characters

This step could be made MUCH smarter — for example, by using semantic understanding to recognize that characters who appear together in a scene are separate individuals and must not be merged, among other similar checks.

Maybe do this in multiple steps? First cluster the names with OpenAI, then check each cluster against the source text to see whether its names actually read as different characters.

In [11]:

# Response schema passed as `response_format` in group_similar_names: a list of
# groups, each group being a list of names.
# (Kept as comments on purpose: a class docstring on a pydantic model would be
# included in the generated JSON schema.)
class GroupedEntities(BaseModel):
    entities: list[list[str]]

def group_similar_names(names_to_group):
    """Ask the OpenAI ``gpt-4o`` model to group names that refer to the same thing.

    The names are joined with ``', '`` into a single user message and sent with a
    fixed system prompt through the module-level ``client``; the reply is parsed
    as ``GroupedEntities``.

    Args:
        names_to_group: Iterable of name strings.

    Returns:
        The parsed message of the first choice.

    NOTE(review): names that themselves contain a comma would be ambiguous in the
    joined message — confirm the inputs cannot contain one.
    """
    text = ', '.join(names_to_group)
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": """
                You are a highly advanced natural language processing agent that 
                is optimized to do named entity recognition (NER). Your goal is to
                group together names that represent the same thing from the text
                provided to you.
             
                Make sure all names in the input are present in the output.   
             
                For example:
                    Input: Bran, Mat, Bran al'Vere, Haral Luhhan, Breyan, Matrim Cauthon, Alsbet Luhhan, Master al'Vere, Mat Cauthon
                    Output: [['Bran', "Bran al'Vere", "Master al'Vere"], ['Mat', 'Matrim Cauthon', 'Mat Cauthon'], ['Breyan'], ['Haral Luhhan'], ['Alsbet Luhhan']]
             
                Another example:
                    Input: sword, axe, horse, spear, mare
                    Output: [['sword', 'axe', 'spear'], ['horse', 'mare']]
                """},
            {"role": "user", "content": text},
        ],
        response_format=GroupedEntities
    )

    # Only the first choice is used.
    return completion.choices[0].message.parsed

In [12]:
# Two separate model calls: one for people, one for every other entity category.
# Each call sends its whole list in a single request.
grouped_people = group_similar_names(series_people_list)
grouped_entities = group_similar_names(series_entities_list)

In [13]:
from typing import List

def remove_duplicate_elements(grouped_entities: GroupedEntities) -> List[List[str]]:
    """Keep only the first occurrence of every name across all groups.

    Groups are scanned in order. A name already placed (in an earlier group or
    earlier in the same group) is removed from later positions, and groups that
    end up empty are left out of the result.

    Args:
        grouped_entities: Object whose ``entities`` attribute is a list of
            name lists.

    Returns:
        The filtered groups, with group order and in-group order preserved.
    """
    already_placed = set()
    surviving_groups = []

    for members in grouped_entities.entities:
        # dict.fromkeys drops repeats inside the group while keeping first-seen order.
        newcomers = [name for name in dict.fromkeys(members) if name not in already_placed]
        already_placed.update(newcomers)
        if len(newcomers) > 0:
            surviving_groups.append(newcomers)

    return surviving_groups

# A name the model placed in more than one group is kept only in the first
# group that lists it.
deduped_people = remove_duplicate_elements(grouped_people)
deduped_entities = remove_duplicate_elements(grouped_entities)

# Process Extracted Entities into Series-level Metadata

In [14]:
def create_result_dict(people, entities, base_id):
    """Build id/name lookup tables for people and for other entities.

    Args:
        people: List of alias groups; each group is a list of names for one person.
        entities: List of alias groups for non-person entities.
        base_id: Prefix for generated ids (``"<base_id>_p_<i>"`` for people,
            ``"<base_id>_e_<j>"`` for entities, using each group's position).

    Returns:
        A dict with four keys: ``people_by_id`` and ``entity_by_id`` map an id to
        its list of names; ``people_by_name`` and ``entity_by_name`` map a name
        to its id. Entity names that are also people names are left out, and an
        entity group with no names left is skipped (its position number is then
        not used as an id).
    """
    people_by_id = {}
    people_by_name = {}
    for position, aliases in enumerate(people):
        key = f"{base_id}_p_{position}"
        people_by_id[key] = aliases
        people_by_name.update(dict.fromkeys(aliases, key))

    entity_by_id = {}
    entity_by_name = {}
    for position, group in enumerate(entities):
        # A name already registered as a person is never registered as an entity.
        non_people = [member for member in group if member not in people_by_name]
        if not non_people:
            continue
        key = f"{base_id}_e_{position}"
        entity_by_id[key] = non_people
        for member in non_people:
            entity_by_name[member] = key

    return {
        'people_by_id': people_by_id,
        'people_by_name': people_by_name,
        'entity_by_id': entity_by_id,
        'entity_by_name': entity_by_name,
    }

In [None]:
series_id = 3
series_entities = create_result_dict(deduped_people, deduped_entities, series_id)

# str(series_id): entities_dict was read back from JSON, where the series key is
# the string '3' rather than the integer 3.
# This replaces the whole 'series_entities' placeholder with the generated tables.
entities_dict['series'][str(series_id)]['series_entities'] = series_entities

# Persist the completed dictionary (overwrites entities.json).
with open('entities.json', 'w') as json_file:
    json.dump(entities_dict, json_file, indent=4)