Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 58 additions & 30 deletions hivemind_summarizer/activities.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,11 @@
from bson import ObjectId
from tc_hivemind_backend.db.qdrant import QdrantSingleton
from tc_hivemind_backend.db.mongo import MongoSingleton
from tc_hivemind_backend.ingest_qdrant import CustomIngestionPipeline

from temporalio import activity, workflow
from qdrant_client.models import Filter, FieldCondition, MatchValue
from qdrant_client.http import models

with workflow.unsafe.imports_passed_through():
from hivemind_summarizer.schema import (
Expand Down Expand Up @@ -40,7 +42,7 @@ def extract_summary_text(node_content: dict[str, Any]) -> str:


@activity.defn
async def get_collection_name(input: TelegramGetCollectionNameInput) -> str:
async def get_platform_name(input: TelegramGetCollectionNameInput) -> str:
"""
Activity that extracts collection name from MongoDB based on platform_id and community_id.

Expand All @@ -52,7 +54,7 @@ async def get_collection_name(input: TelegramGetCollectionNameInput) -> str:
Returns
-------
str
The collection name in format [communityId]_[platformName]_summary
The platform name

Raises
------
Expand Down Expand Up @@ -83,11 +85,7 @@ async def get_collection_name(input: TelegramGetCollectionNameInput) -> str:
if not platform_name:
raise Exception(f"Platform name not found for platform_id {platform_id}")

# Construct collection name
collection_name = f"{community_id}_{platform_name}_summary"

logging.info(f"Generated collection name: {collection_name}")
return collection_name
return platform_name

except Exception as e:
logging.error(f"Error getting collection name: {str(e)}")
Expand All @@ -113,11 +111,13 @@ async def fetch_telegram_summaries_by_date(
"""
date = input.date
extract_text_only = input.extract_text_only
collection_name = input.collection_name
collection_name = f"{input.community_id}_{input.platform_name}_summary"
community_id = input.community_id

logging.info("Started fetch_telegram_summaries_by_date!")
if not collection_name:
raise ValueError("Collection name is required but was not provided")

if not input.platform_name:
raise ValueError("Platform name is required but was not provided")

logging.info(
f"Fetching summaries for date: {date} from collection: {collection_name}"
Expand All @@ -128,19 +128,46 @@ async def fetch_telegram_summaries_by_date(
qdrant_client = QdrantSingleton.get_instance().get_client()

# Create filter for the specified date
filter_conditions = [FieldCondition(key="date", match=MatchValue(value=date))]

date_filter = Filter(must=filter_conditions)

# Query Qdrant for all summaries matching the date using the provided collection name
search_results = qdrant_client.search(
collection_name=collection_name,
query_vector=[0] * 1024,
query_filter=date_filter,
limit=100,
with_payload=True,
with_vectors=False,
)
if date is not None:
filter_conditions = [
FieldCondition(key="date", match=MatchValue(value=date))
]
date_filter = Filter(must=filter_conditions)

# Query Qdrant for all summaries matching the date using the provided collection name
search_results = qdrant_client.search(
collection_name=collection_name,
query_vector=[0] * 1024,
query_filter=date_filter,
limit=100,
with_payload=True,
with_vectors=False,
)
else:
# pipeline requires a different format for the collection name
pipeline = CustomIngestionPipeline(
community_id=community_id,
collection_name=f"{input.platform_name}_summary",
)
# get the latest date from the collection
latest_date = pipeline.get_latest_document_date(
field_name="date", field_schema=models.PayloadSchemaType.DATETIME
)

filter_conditions = [
FieldCondition(
key="date", match=MatchValue(value=latest_date.strftime("%Y-%m-%d"))
)
]
date_filter = Filter(must=filter_conditions)
search_results = qdrant_client.search(
collection_name=collection_name,
query_vector=[0] * 1024,
query_filter=date_filter,
limit=100,
with_payload=True,
with_vectors=False,
)

summaries = []
for point in search_results:
Expand Down Expand Up @@ -189,7 +216,7 @@ async def fetch_telegram_summaries_by_date_range(
Parameters
----------
input : TelegramSummariesRangeActivityInput
Input object containing start_date, end_date, collection_name and extract_text_only
Input object containing start_date, end_date, platform_name and community_id

Returns
-------
Expand All @@ -199,15 +226,15 @@ async def fetch_telegram_summaries_by_date_range(
Raises
------
ValueError
If end_date is before start_date or collection_name is not provided
If end_date is before start_date or platform_name is not provided
"""
start_date = input.start_date
end_date = input.end_date
extract_text_only = input.extract_text_only
collection_name = input.collection_name

if not collection_name:
raise ValueError("Collection name is required but was not provided")
platform_name = input.platform_name
community_id = input.community_id
if not platform_name:
raise ValueError("Platform name is required but was not provided")

logging.info(
f"Fetching summaries for date range: {start_date} to {end_date} from collection: {collection_name}"
Expand Down Expand Up @@ -235,7 +262,8 @@ async def fetch_telegram_summaries_by_date_range(
date_input = TelegramSummariesActivityInput(
date=date,
extract_text_only=extract_text_only,
collection_name=collection_name,
platform_name=input.platform_name,
community_id=community_id,
)
summaries = await fetch_telegram_summaries_by_date(date_input)
result[date] = summaries
Expand Down
10 changes: 6 additions & 4 deletions hivemind_summarizer/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,18 @@


class TelegramSummariesActivityInput(BaseModel):
date: str
date: str | None = None
extract_text_only: bool = True
collection_name: str | None = None
platform_name: str | None = None
community_id: str | None = None


class TelegramSummariesRangeActivityInput(BaseModel):
start_date: str
end_date: str
extract_text_only: bool = True
collection_name: str | None = None
platform_name: str | None = None
community_id: str | None = None


class TelegramGetCollectionNameInput(BaseModel):
Expand All @@ -22,6 +24,6 @@ class TelegramGetCollectionNameInput(BaseModel):
class TelegramFetchSummariesWorkflowInput(BaseModel):
platform_id: str
community_id: str
start_date: str
start_date: str | None = None
end_date: str | None = None
extract_text_only: bool = True
12 changes: 7 additions & 5 deletions hivemind_summarizer/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .activities import (
fetch_telegram_summaries_by_date,
fetch_telegram_summaries_by_date_range,
get_collection_name,
get_platform_name,
)
from .schema import (
TelegramSummariesActivityInput,
Expand Down Expand Up @@ -54,8 +54,8 @@ async def run(

logging.info("Getting collection name!")
# First, get the collection name
collection_name = await workflow.execute_activity(
get_collection_name,
platform_name = await workflow.execute_activity(
get_platform_name,
TelegramGetCollectionNameInput(
platform_id=input.platform_id, community_id=input.community_id
),
Expand All @@ -70,7 +70,8 @@ async def run(
fetch_telegram_summaries_by_date,
TelegramSummariesActivityInput(
date=input.start_date,
collection_name=collection_name,
platform_name=platform_name,
community_id=input.community_id,
extract_text_only=input.extract_text_only,
),
schedule_to_close_timeout=timedelta(minutes=2),
Expand All @@ -84,7 +85,8 @@ async def run(
TelegramSummariesRangeActivityInput(
start_date=input.start_date,
end_date=input.end_date,
collection_name=collection_name,
platform_name=platform_name,
community_id=input.community_id,
extract_text_only=input.extract_text_only,
),
schedule_to_close_timeout=timedelta(minutes=2),
Expand Down
4 changes: 2 additions & 2 deletions registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from hivemind_summarizer.activities import (
fetch_telegram_summaries_by_date,
fetch_telegram_summaries_by_date_range,
get_collection_name,
get_platform_name,
)
from workflows import (
CommunityWebsiteWorkflow,
Expand Down Expand Up @@ -42,5 +42,5 @@
say_hello,
fetch_telegram_summaries_by_date,
fetch_telegram_summaries_by_date_range,
get_collection_name,
get_platform_name,
]