In [34]:
import yaml

with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

api_key = config['OPENAI_API_KEY']
chroma_path = config['CHROMA_PATH']
chroma_collection = config['CHROMA_COLLECTION']

In [35]:
TARGET_SERIES_ID = 3 # wheel of time
TARGET_BOOK_NUMBER = 4


# Load series.yml to create a mapping from series_metadata_name to series_id
with open('series.yml', 'r') as file:
    series_list = yaml.safe_load(file)

target_series_info = next(series for series in series_list if series['series_id'] == TARGET_SERIES_ID)
target_book_info = next(book for book in target_series_info['books'] if book['number_in_series'] == TARGET_BOOK_NUMBER)

series_metadata_name = target_series_info['series_metadata_name']
book_metadata_name = target_book_info['book_metadata_name']

# Run queries through story_sage module

In [36]:
import logging
from story_sage.story_sage import StorySage

# Configure the logger

logger = logging.getLogger('story_sage')
logger.setLevel(logging.DEBUG)
# Create a console handler
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.INFO)

# Create a formatter and set it for the handler
formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
console_handler.setFormatter(formatter)

# Add the handler to the logger
logger.addHandler(console_handler)

# Filter out logs from other modules
class StorySageFilter(logging.Filter):
    def filter(self, record):
        return record.name.startswith('story_sage')

logger.addFilter(StorySageFilter())

# Load all character dictionaries and merge them using the metadata_to_id mapping
# Load entities.json
with open('entities.json', 'r') as file:
    entities = yaml.safe_load(file)

story_sage = StorySage(
    api_key=api_key,
    chroma_path=chroma_path,
    chroma_collection_name=chroma_collection,
    entities=entities,
    series_yml_path='series.yml',
    n_chunks=10
)


# Add a handler to the StorySage logger
story_sage.logger = logger

def invoke_story_sage(data: dict):
    required_keys = ['question', 'book_number', 'chapter_number', 'series_id']
    if not all(key in data for key in required_keys):
        return {'error': f'Missing parameter! Request must include {", ".join(required_keys)}'}, 400

    try:
        result, context = story_sage.invoke(**data)
        return result, context
    except Exception as e:
        raise e
        return {'error': 'Internal server error.'}, 500
    
data = {
    'question': 'Explain the interactions between Cenn and Rand',
    'book_number': 2,
    'chapter_number': 1,
    'series_id': 3
}

In [37]:
#response, context = invoke_story_sage(data)
#print(response)

# Configure and send queries to ChromaDB Directly

In [38]:
import chromadb
from langchain.embeddings import SentenceTransformerEmbeddings
class EmbeddingAdapter(SentenceTransformerEmbeddings):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def _embed_documents(self, texts):
        return super().embed_documents(texts)  

    def __call__(self, input):
        return self._embed_documents(input)  

embedder = EmbeddingAdapter
client = chromadb.PersistentClient(path=chroma_path)
vector_store = client.get_collection(name=chroma_collection)

In [39]:
filter_dict = {'$and': [
                {'$or': [
                    {'book_number': {'$lt': 1}},
                    {'$and': [
                        {'book_number': 1}, 
                        {'chapter_number': {'$lt': 25}}
                    ]}
                ]}, 
                {'3_e_8': True}
               ]}

# filter_dict = {'$or': [
#                     {'book_number': {'$lt': 1}},
#                     {'$and': [
#                         {'book_number': 1}, 
#                         {'chapter_number': {'$lt': 25}}
#                     ]}
#                 ]}
#client.delete_collection('wot_retriever_test')
if False:
    result = vector_store.query(query_texts=['trolloc'],
                                n_results=5,
                                where=filter_dict,
                                include=['metadatas','documents'])
    print(result)

# New entity extraction

Load chunks dict:

```python
chunks = {
    int: {
        int: List[str]
    }
}
```

where keys are `chunks[book_number][chapter_number]` and values are lists of strings, each string is a chunk of text.

In [40]:
import glob
import re
import pickle
import os

path_to_chunks = f'./chunks/{series_metadata_name}/semantic_chunks'
chunks = {}
for filepath in glob.glob(f'{path_to_chunks}/*.pkl'):
    match = re.match(r'(\d+)_(\d+)\.pkl', os.path.basename(filepath))
    if match:
        book_number, chapter_number = map(int, match.groups())
        with open(filepath, 'rb') as f:
            if book_number not in chunks:
                chunks[book_number] = {}
            chunks[book_number][chapter_number] = pickle.load(f)


## OpenAI based entity extraction

In [45]:
from pydantic import BaseModel
from openai import OpenAI
import httpx
import time

# Create a custom HTTPX client with SSL verification disabled
req_client = httpx.Client(verify=False)

client = OpenAI(api_key=api_key, http_client=req_client)

full_response = None

class StorySageEntities(BaseModel):
  people: list[str]
  places: list[str]
  groups: list[str]
  animals: list[str]
  objects: list[str]

def extract_named_entities(text):
    completion = client.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": """
                You are a highly advanced natural language processing agent that 
                is optimized to do named entity recognition (NER). Your goal is to
                extract entities and a summary from text provided to you.
                
                For example, if the text is:
                    Standing with the other Whitecloaks, Perrin saw the Lugard Road near the Manetherendrelle and the border of Murandy.
                    If dogs had been able to make footprints on stone, he would have said the tracks were the prints of a pack of large hounds.
                    He hefted his axe and kicked aside the basket on the road.
             
                Extract:
                    People: Perrin
                    Places: Lugard Road, Manetherendrelle, Murandy
                    Groups: Whitecloaks, pack
                    Animals: dogs
                    Objects: axe, basket
                """},
            {"role": "user", "content": text},
        ],
        response_format=StorySageEntities
    )

    extracted_entity = completion.choices[0].message.parsed
    usage_information = completion.usage

    return extracted_entity, usage_information

def extract_entities_from_chunks(book_chunks: dict, token_per_min_limit: int = 200000, cooldown_secs: int = 30) -> list:
    """
    Extract named entities from chunks of text in a book.
    Args:
        book_chunks (dict): A dictionary where keys are chapter indices and values are lists of text chunks.
        token_per_min_limit (int, optional): The limit on the number of tokens processed per minute. Defaults to 200000.
        cooldown_secs (int, optional): The cooldown period in seconds to wait if the character limit is exceeded. Defaults to 30.
    Raises:
        ValueError: If cooldown_secs is greater than 30.
    Returns:
        list: A list of extracted named entities from each chapter.
    """
    
    # Raise an error if cooldown_secs > 30
    if cooldown_secs > 30:
        raise ValueError('Cooldown seconds cannot exceed 30 seconds.')
    
    # Calculate the number of chapters and initialize an empty list to store results
    num_chapters = len(book_chunks)
    result = []

    # Set a limit on the number of tokens processed per minute
    len_cap = (token_per_min_limit * 4) / (60 / cooldown_secs)  # ~ 4 characters per token, divide by 2 for 30s cooldown
    
    # Keep track of the number of text characters processed
    counter = 0

    # Iterate over the chapters
    for i, chapter_chunks in book_chunks.items():
        # Extract the chapter chunks and join them into a single text
        chapter_text = '\n'.join(chapter_chunks)

        # Check if the chapter length exceeds the limit and wait if necessary
        chapter_len = len(chapter_text)
        if counter + chapter_len > len_cap:
            print(f'Waiting for 30 seconds to avoid exceeding the character limit. Current chapter: {i + 1}. Current length: {counter}')
            time.sleep(30)
            counter = 0
        
        # Extract named entities from the chapter text
        result.append(extract_named_entities(chapter_text))

        # Update the character counter
        counter += chapter_len

    print(f'Finished extracting from {num_chapters} chapters')
    
    return result

### Dump result into a json so I don't have to run this every time

```python
series_info = {
    'series_id': 3,
    'series_name': 'Wheel of Time',
    'series_metadata_name': 'wheel_of_time',
    'books': [{
        'number_in_series': 1,
        'title': 'The Eye of the World',
        'book_metadata_name': '01_the_eye_of_the_world',
        'number_of_chapters': 53
    }]
}

target_book_info = {
    'number_in_series': 2,
    'title': 'The Great Hunt',
    'book_metadata_name': '02_the_great_hunt',
    'number_of_chapters': 50
}
```


In [46]:
import json

target_file_path = f'./entities/{series_metadata_name}'
if not os.path.exists(target_file_path):
    os.makedirs(target_file_path)
target_filename = f'{target_file_path}/{book_metadata_name}.json'

# Extract entities from the chunks of the target book
if True:
    extracted_entities_dict = extract_entities_from_chunks(chunks[TARGET_BOOK_NUMBER])
    with open(target_filename, 'w') as json_file:
        json.dump(extracted_entities_dict, json_file, default=lambda o: o.__dict__, indent=4)

Waiting for 30 seconds to avoid exceeding the character limit. Current chapter: 46. Current length: 396248
Waiting for 30 seconds to avoid exceeding the character limit. Current chapter: 21. Current length: 373705
Waiting for 30 seconds to avoid exceeding the character limit. Current chapter: 25. Current length: 372095
Waiting for 30 seconds to avoid exceeding the character limit. Current chapter: 40. Current length: 377324
Waiting for 30 seconds to avoid exceeding the character limit. Current chapter: 2. Current length: 341166
Finished extracting from 58 chapters


```python
extracted_entities_dict = [
    [
        {
            'people': List[str],
            'places': List[str],
            'groups': List[str],
            'animals': List[str],
            'objects': List[str],
        },
        {
            <OpenAI Usage information>
        }
    ]
]
```

where index in `extracted_entities_dict` corresponds to the chapter number (chapter 0 is anything before chapter 1)

# Collect all extracted entities from the series

# Process Extracted Entities

In [19]:
def collect_unique_values(extracted_entities_dict: dict) -> tuple[list, list]:
    series_people_set = set()
    series_entities_set = set()
    
    for chapter in extracted_entities_dict:
        entities = chapter[0]
        series_people_set.update(entities.get('people', []))
        
        for key, values in entities.items():
            if key != 'people':
                series_entities_set.update(values)

    series_people_list = []
    series_entities_list = []
    
    for person in series_people_set:
        person = person.lower()
        person = ''.join(c for c in person if c.isalpha() or c.isspace())
        series_people_list.append(person)

    for entity in series_entities_set:
        entity = entity.lower()
        entity = ''.join(c for c in entity if c.isalpha() or c.isspace())
        series_entities_list.append(entity)
    
    return series_people_list, series_entities_list

series_people_list, series_entities_list = collect_unique_values(extracted_entities_dict=extracted_entities_dict)

## Use OpenAI to cluster similar characters

This is an opportunity to make it MUCH smarter. Think things like using semantic understanding to differentiate characters who appear together as separate individuals or other things like that.

Maybe do multiple steps? OpenAI cluster, then bounce those against the text to see if it sounds like they're different characters?

In [20]:

class GroupedEntities(BaseModel):
    entities: list[list[str]]

def group_similar_names(names_to_group: list[str]) -> GroupedEntities:
    text = ', '.join(names_to_group)
    completion = client.beta.chat.completions.parse(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": """
                You are a highly advanced natural language processing agent that 
                is optimized to do named entity recognition (NER). Your goal is to
                group together names that represent the same thing from the text
                provided to you.
             
                Make sure all names in the input are present in the output.   
             
                For example:
                    Input: Bran, Mat, Bran al'Vere, Haral Luhhan, Breyan, Matrim Cauthon, Alsbet Luhhan, Master al'Vere, Mat Cauthon
                    Output: [['Bran', "Bran al'Vere", "Master al'Vere"], ['Mat', 'Matrim Cauthon', 'Mat Cauthon'], ['Breyan'], ['Haral Luhhan'], ['Alsbet Luhhan']]
             
                Another example:
                    Input: sword, axe, horse, spear, mare
                    Output: [['sword', 'axe', 'spear'], ['horse', 'mare']]
                """},
            {"role": "user", "content": text},
        ],
        response_format=GroupedEntities
    )

    return completion.choices[0].message.parsed

In [21]:
grouped_people = group_similar_names(series_people_list)
grouped_entities = group_similar_names(series_entities_list)

In [24]:
from typing import List

def remove_duplicate_elements(grouped_entities: GroupedEntities) -> List[List[str]]:
    # Create a set to track seen names
    seen_names = set()
    filtered_groups = []

    # Iterate through each group in grouped_entities
    for group in grouped_entities.entities:
        # Filter out any names we've seen before
        filtered_group = []
        for name in group:
            if name not in seen_names:
                filtered_group.append(name)
                seen_names.add(name)
                
        # Only keep groups that still have elements after filtering
        if filtered_group:
            filtered_groups.append(filtered_group)

    return filtered_groups

deduped_people = remove_duplicate_elements(grouped_people)
deduped_entities = remove_duplicate_elements(grouped_entities)

# Process Extracted Entities into Series-level Metadata

In [25]:
def create_result_dict(people, entities, base_id):
    result = {
        'people_by_id': {},
        'people_by_name': {},
        'entity_by_id': {},
        'entity_by_name': {}
    }
    
    # Populate people_by_id and people_by_name
    for i, person_list in enumerate(people):
        person_id = f"{base_id}_p_{i}"
        result['people_by_id'][person_id] = person_list
        for name in person_list:
            result['people_by_name'][name] = person_id
    
    # Populate entity_by_id and entity_by_name
    for j, entity_list in enumerate(entities):
        filtered_entities = [entity for entity in entity_list if entity not in result['people_by_name']]
        if filtered_entities:
            entity_id = f"{base_id}_e_{j}"
            result['entity_by_id'][entity_id] = filtered_entities
            for entity in filtered_entities:
                result['entity_by_name'][entity] = entity_id
    
    return result

In [None]:
series_entities = create_result_dict(deduped_people, deduped_entities, TARGET_SERIES_ID)

entities_dict['series'][str(series_id)]['series_entities'] = series_entities

with open('entities.json', 'w') as json_file:
    json.dump(entities_dict, json_file, indent=4)