diff --git a/hivemind_summarizer/activities.py b/hivemind_summarizer/activities.py index 4e0a72d..7350a90 100644 --- a/hivemind_summarizer/activities.py +++ b/hivemind_summarizer/activities.py @@ -6,9 +6,11 @@ from bson import ObjectId from tc_hivemind_backend.db.qdrant import QdrantSingleton from tc_hivemind_backend.db.mongo import MongoSingleton +from tc_hivemind_backend.ingest_qdrant import CustomIngestionPipeline from temporalio import activity, workflow from qdrant_client.models import Filter, FieldCondition, MatchValue +from qdrant_client.http import models with workflow.unsafe.imports_passed_through(): from hivemind_summarizer.schema import ( @@ -40,7 +42,7 @@ def extract_summary_text(node_content: dict[str, Any]) -> str: @activity.defn -async def get_collection_name(input: TelegramGetCollectionNameInput) -> str: +async def get_platform_name(input: TelegramGetCollectionNameInput) -> str: """ Activity that extracts collection name from MongoDB based on platform_id and community_id. @@ -52,7 +54,7 @@ async def get_collection_name(input: TelegramGetCollectionNameInput) -> str: Returns ------- str - The collection name in format [communityId]_[platformName]_summary + The platform name Raises ------ @@ -83,11 +85,7 @@ async def get_collection_name(input: TelegramGetCollectionNameInput) -> str: if not platform_name: raise Exception(f"Platform name not found for platform_id {platform_id}") - # Construct collection name - collection_name = f"{community_id}_{platform_name}_summary" - - logging.info(f"Generated collection name: {collection_name}") - return collection_name + return platform_name except Exception as e: logging.error(f"Error getting collection name: {str(e)}") @@ -113,11 +111,13 @@ async def fetch_telegram_summaries_by_date( """ date = input.date extract_text_only = input.extract_text_only - collection_name = input.collection_name + collection_name = f"{input.community_id}_{input.platform_name}_summary" + community_id = input.community_id logging.info("Started fetch_telegram_summaries_by_date!") - if not collection_name: - raise ValueError("Collection name is required but was not provided") + + if not input.platform_name: + raise ValueError("Platform name is required but was not provided") logging.info( f"Fetching summaries for date: {date} from collection: {collection_name}" @@ -128,19 +128,46 @@ async def fetch_telegram_summaries_by_date( qdrant_client = QdrantSingleton.get_instance().get_client() # Create filter for the specified date - filter_conditions = [FieldCondition(key="date", match=MatchValue(value=date))] - - date_filter = Filter(must=filter_conditions) - - # Query Qdrant for all summaries matching the date using the provided collection name - search_results = qdrant_client.search( - collection_name=collection_name, - query_vector=[0] * 1024, - query_filter=date_filter, - limit=100, - with_payload=True, - with_vectors=False, - ) + if date is not None: + filter_conditions = [ + FieldCondition(key="date", match=MatchValue(value=date)) + ] + date_filter = Filter(must=filter_conditions) + + # Query Qdrant for all summaries matching the date using the provided collection name + search_results = qdrant_client.search( + collection_name=collection_name, + query_vector=[0] * 1024, + query_filter=date_filter, + limit=100, + with_payload=True, + with_vectors=False, + ) + else: + # pipeline requires a different format for the collection name + pipeline = CustomIngestionPipeline( + community_id=community_id, + collection_name=f"{input.platform_name}_summary", + ) + # get the latest date from the collection + latest_date = pipeline.get_latest_document_date( + field_name="date", field_schema=models.PayloadSchemaType.DATETIME + ) + + filter_conditions = [ + FieldCondition( + key="date", match=MatchValue(value=latest_date.strftime("%Y-%m-%d")) + ) + ] + date_filter = Filter(must=filter_conditions) + search_results = qdrant_client.search( + collection_name=collection_name, + query_vector=[0] * 1024, + query_filter=date_filter, + limit=100, + with_payload=True, + with_vectors=False, + ) summaries = [] for point in search_results: @@ -189,7 +216,7 @@ async def fetch_telegram_summaries_by_date_range( Parameters ---------- input : TelegramSummariesRangeActivityInput - Input object containing start_date, end_date, collection_name and extract_text_only + Input object containing start_date, end_date, platform_name and community_id Returns ------- @@ -199,15 +226,15 @@ async def fetch_telegram_summaries_by_date_range( Raises ------ ValueError - If end_date is before start_date or collection_name is not provided + If end_date is before start_date or platform_name is not provided """ start_date = input.start_date end_date = input.end_date extract_text_only = input.extract_text_only - collection_name = input.collection_name - - if not collection_name: - raise ValueError("Collection name is required but was not provided") + platform_name = input.platform_name + community_id = input.community_id + if not platform_name: + raise ValueError("Platform name is required but was not provided") logging.info( f"Fetching summaries for date range: {start_date} to {end_date} from collection: {collection_name}" @@ -235,7 +262,8 @@ async def fetch_telegram_summaries_by_date_range( date_input = TelegramSummariesActivityInput( date=date, extract_text_only=extract_text_only, - collection_name=collection_name, + platform_name=input.platform_name, + community_id=community_id, ) summaries = await fetch_telegram_summaries_by_date(date_input) result[date] = summaries diff --git a/hivemind_summarizer/schema.py b/hivemind_summarizer/schema.py index efe8ebf..47388bf 100644 --- a/hivemind_summarizer/schema.py +++ b/hivemind_summarizer/schema.py @@ -2,16 +2,18 @@ class TelegramSummariesActivityInput(BaseModel): - date: str + date: str | None = None extract_text_only: bool = True - collection_name: str | None = None + platform_name: str | None = None + community_id: str | None = None class TelegramSummariesRangeActivityInput(BaseModel): start_date: str end_date: str extract_text_only: bool = True - collection_name: str | None = None + platform_name: str | None = None + community_id: str | None = None class TelegramGetCollectionNameInput(BaseModel): @@ -22,6 +24,6 @@ class TelegramGetCollectionNameInput(BaseModel): class TelegramFetchSummariesWorkflowInput(BaseModel): platform_id: str community_id: str - start_date: str + start_date: str | None = None end_date: str | None = None extract_text_only: bool = True diff --git a/hivemind_summarizer/workflows.py b/hivemind_summarizer/workflows.py index b825789..3331db5 100644 --- a/hivemind_summarizer/workflows.py +++ b/hivemind_summarizer/workflows.py @@ -9,7 +9,7 @@ from .activities import ( fetch_telegram_summaries_by_date, fetch_telegram_summaries_by_date_range, - get_collection_name, + get_platform_name, ) from .schema import ( TelegramSummariesActivityInput, @@ -54,8 +54,8 @@ async def run( logging.info("Getting collection name!") # First, get the collection name - collection_name = await workflow.execute_activity( - get_collection_name, + platform_name = await workflow.execute_activity( + get_platform_name, TelegramGetCollectionNameInput( platform_id=input.platform_id, community_id=input.community_id ), @@ -70,7 +70,8 @@ async def run( fetch_telegram_summaries_by_date, TelegramSummariesActivityInput( date=input.start_date, - collection_name=collection_name, + platform_name=platform_name, + community_id=input.community_id, extract_text_only=input.extract_text_only, ), schedule_to_close_timeout=timedelta(minutes=2), @@ -84,7 +85,8 @@ async def run( TelegramSummariesRangeActivityInput( start_date=input.start_date, end_date=input.end_date, - collection_name=collection_name, + platform_name=platform_name, + community_id=input.community_id, extract_text_only=input.extract_text_only, ), schedule_to_close_timeout=timedelta(minutes=2), diff --git a/registry.py b/registry.py index 200b50b..1225ceb 100644 --- a/registry.py +++ b/registry.py @@ -12,7 +12,7 @@ from hivemind_summarizer.activities import ( fetch_telegram_summaries_by_date, fetch_telegram_summaries_by_date_range, - get_collection_name, + get_platform_name, ) from workflows import ( CommunityWebsiteWorkflow, @@ -42,5 +42,5 @@ say_hello, fetch_telegram_summaries_by_date, fetch_telegram_summaries_by_date_range, - get_collection_name, + get_platform_name, ]