In [13]:
import yaml
from story_sage.utils.embedding import load_chunk_from_disk, update_tagged_entities, Embedder
import chromadb
import glob
from story_sage.utils.local_entity_extractor import StorySageEntityExtractor
from story_sage.story_sage_entity import StorySageEntityCollection
from story_sage.vector_store import StorySageRetriever


In [2]:
# Read project settings from config.yml and expose the values later cells need.
with open('config.yml', 'r') as file:
    config = yaml.safe_load(file)

api_key, chroma_path, chroma_collection, series_path = (
    config['OPENAI_API_KEY'],
    config['CHROMA_PATH'],
    config['CHROMA_COLLECTION'],
    config['SERIES_PATH'],
)

# Which series / book this run operates on.
TARGET_SERIES_ID = 3 # wheel of time
TARGET_BOOK_NUMBER = 2

In [3]:
# Load series.yml to create a mapping from series_metadata_name to series_id
# Load the series catalogue (series.yml) and select the entry for TARGET_SERIES_ID.
# Its series_metadata_name is the directory name used under ./chunks/ and ./entities/.
with open(series_path, 'r') as file:
    series_list = yaml.safe_load(file)

target_series_info = next(
    (series for series in series_list if series['series_id'] == TARGET_SERIES_ID),
    None,
)
if target_series_info is None:
    # A bare next() would surface here as an unexplained StopIteration.
    raise ValueError(f"series_id {TARGET_SERIES_ID} not found in {series_path}")

series_metadata_name = target_series_info['series_metadata_name']

In [4]:
# Collect every semantic-chunk file for the target book.
# glob() yields matches in arbitrary (filesystem-dependent) order, so the paths are
# sorted (lexicographically) to make the resulting chunk order reproducible.
# NOTE(review): these are .pkl files — presumably unpickled by load_chunk_from_disk;
# only load files from a trusted source.
chunks_path = f'./chunks/{series_metadata_name}/semantic_chunks/{TARGET_BOOK_NUMBER}_*.pkl'

chunks = []

for chunk_path in sorted(glob.glob(chunks_path)):
    chunks.extend(load_chunk_from_disk(chunk_path))

In [5]:
# Load the saved entity collection for this series from its JSON file.
entities_file = f'./entities/{series_metadata_name}/entities.json'
with open(entities_file, 'r') as file:
    serialized_entities = file.read()
entity_collection = StorySageEntityCollection.from_json(serialized_entities)

In [6]:
# Open (or create) the persistent Chroma collection used as the vector store.
# Use the chroma_path / chroma_collection values from the config cell, the same ones
# the StorySageRetriever cell uses, so both always point at the same store.
chroma_client = chromadb.PersistentClient(path=chroma_path)
embedder = Embedder()

# To rebuild from scratch, uncomment to drop the collection first.
# NOTE(review): delete_collection likely raises if the collection does not exist.
#chroma_client.delete_collection(chroma_collection)

# Get or create a collection in the vector store
vector_store = chroma_client.get_or_create_collection(
    name=chroma_collection,
    embedding_function=embedder
)

In [7]:
# Presumably refreshes the entity tags on stored documents for the target series/book.
# NOTE(review): the recorded output reports 0 items processed — confirm the
# collection actually holds documents for this series_id / book_number.
update_tagged_entities(
    vector_store=vector_store,
    series_id=TARGET_SERIES_ID,
    book_number=TARGET_BOOK_NUMBER,
    entity_collection=entity_collection,
)

Updating tagged entities: 0it [00:00, ?it/s]

In [14]:
# Build a retriever over the configured Chroma store, using the loaded entity collection.
retriever = StorySageRetriever(
    entities=entity_collection,
    chroma_collection_name=chroma_collection,
    chroma_path=chroma_path,
)

In [None]:

# Optional (disabled): re-extract entities from the loaded `chunks`, passing the
# current entity_collection as existing_collection. If re-enabled, re-run the
# update_tagged_entities and StorySageRetriever cells afterwards: they were built
# from the collection loaded from entities.json, not from this new one.
# NOTE(review): device='mps' presumably targets Apple-silicon GPUs; adjust for other hardware.
#extractor = StorySageEntityExtractor(series = target_series_info, device='mps', existing_collection=entity_collection, similarity_threshold=0.5)

#entity_collection: StorySageEntityCollection = extractor.get_grouped_entities(chunks)

In [9]:
# Optional (disabled): write entity_collection back to the same entities.json this
# notebook loads from. This overwrites the file, so only enable it together with
# the re-extraction cell above.
#entity_json = entity_collection.to_json()
#with open(f'./entities/{series_metadata_name}/entities.json', 'w') as file:
#    file.write(entity_json)