From cc76487aa23fdfa72fbd99e065991ca8820ae7ea Mon Sep 17 00:00:00 2001 From: Bary Levy Date: Sun, 26 Mar 2023 12:52:47 +0000 Subject: [PATCH 01/13] Add the ability to remove documents --- ...792a820e9374_document_id_in_data_source.py | 24 ++++++++++++ app/api/data_source.py | 19 +++++++++ app/data_source_api/base_data_source.py | 2 +- app/data_source_api/basic_document.py | 6 ++- app/data_source_api/utils.py | 17 ++++++-- app/data_sources/confluence.py | 10 ++++- app/indexing/bm25_index.py | 29 ++++++++++---- app/indexing/faiss_index.py | 5 +++ app/indexing/index_documents.py | 32 +++++++++++++++ app/main.py | 39 +++++++++++-------- app/models.py | 2 +- app/schemas/data_source.py | 15 ++++++- app/schemas/document.py | 3 +- 13 files changed, 170 insertions(+), 33 deletions(-) create mode 100644 app/alembic/versions/792a820e9374_document_id_in_data_source.py diff --git a/app/alembic/versions/792a820e9374_document_id_in_data_source.py b/app/alembic/versions/792a820e9374_document_id_in_data_source.py new file mode 100644 index 0000000..1a4ae41 --- /dev/null +++ b/app/alembic/versions/792a820e9374_document_id_in_data_source.py @@ -0,0 +1,24 @@ +"""document id_in_data_source + +Revision ID: 792a820e9374 +Revises: 9c2f5b290b16 +Create Date: 2023-03-26 11:27:05.341609 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '792a820e9374' +down_revision = '9c2f5b290b16' +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column('document', sa.Column('id_in_data_source', sa.String(length=64), default='__none__')) + + +def downgrade() -> None: + op.drop_column('document', 'id_in_data_source') diff --git a/app/api/data_source.py b/app/api/data_source.py index 0b1a4ac..3f24ec7 100644 --- a/app/api/data_source.py +++ b/app/api/data_source.py @@ -54,6 +54,13 @@ async def list_connected_data_sources() -> List[str]: return [data_source.type.name for data_source in data_sources] +@router.get("/list") +async def list_connected_data_sources() -> List[dict]: + with Session() as session: + data_sources = session.query(DataSource).all() + return [{'id': data_source.id} for data_source in data_sources] + + class AddDataSource(BaseModel): name: str config: dict @@ -86,3 +93,15 @@ async def add_integration(dto: AddDataSource, background_tasks: BackgroundTasks) background_tasks.add_task(data_source.index) return {"success": "Data source added successfully"} + + +@router.post("/delete") +async def delete_integration(data_source_id: int): + with Session() as session: + data_source: DataSource = session.query(DataSource).filter_by(id=data_source_id).first() + if data_source is None: + return {"error": "Data source does not exist"} + + session.delete(data_source) + session.commit() + return {"success": "Data source deleted successfully"} \ No newline at end of file diff --git a/app/data_source_api/base_data_source.py b/app/data_source_api/base_data_source.py index a67465e..8329385 100644 --- a/app/data_source_api/base_data_source.py +++ b/app/data_source_api/base_data_source.py @@ -89,7 +89,7 @@ def _set_last_index_time(self) -> None: def index(self) -> None: try: - self._set_last_index_time() self._feed_new_documents() + self._set_last_index_time() except Exception as e: logging.exception("Error while indexing data source") diff --git a/app/data_source_api/basic_document.py b/app/data_source_api/basic_document.py index d84d018..265562d 100644 --- a/app/data_source_api/basic_document.py +++ b/app/data_source_api/basic_document.py @@ 
-32,7 +32,7 @@ def from_mime_type(cls, mime_type: str): @dataclass class BasicDocument: - id: int + id: int | str data_source_id: int type: DocumentType title: str @@ -44,3 +44,7 @@ class BasicDocument: url: str file_type: FileType = None + @property + def id_in_data_source(self): + return str(self.data_source_id) + '_' + str(self.id) + diff --git a/app/data_source_api/utils.py b/app/data_source_api/utils.py index 43cedba..b1c2d72 100644 --- a/app/data_source_api/utils.py +++ b/app/data_source_api/utils.py @@ -29,15 +29,26 @@ def get_class_by_data_source_name(data_source_name: str): f"make sure you named the class correctly (it should be DataSource)") -def parse_with_workers(method_name: callable, items: list, **kwargs): +def _wrap_with_try_except(func): + def wrapper(*args, **kwargs): + try: + return func(*args, **kwargs) + except Exception as e: + logger.exception("Failed to parse data source", exc_info=e) + raise e + + return wrapper + + +def parse_with_workers(method: callable, items: list, **kwargs): workers = 10 # should be a config value - logger.info(f'Parsing {len(items)} documents (with {workers} workers)...') + logger.info(f'Parsing {len(items)} documents using {method} (with {workers} workers)...') with concurrent.futures.ThreadPoolExecutor(max_workers=workers) as executor: futures = [] for i in range(workers): - futures.append(executor.submit(method_name, items[i::workers], **kwargs)) + futures.append(executor.submit(_wrap_with_try_except(method), items[i::workers], **kwargs)) concurrent.futures.wait(futures) for w in futures: e = w.exception() diff --git a/app/data_sources/confluence.py b/app/data_sources/confluence.py index 7ac7b40..03241e2 100644 --- a/app/data_sources/confluence.py +++ b/app/data_sources/confluence.py @@ -3,6 +3,8 @@ from typing import List, Dict from atlassian import Confluence +from atlassian.errors import ApiError +from requests import HTTPError from data_source_api.basic_document import BasicDocument, DocumentType from data_source_api.base_data_source import BaseDataSource, ConfigField, HTMLInputType @@ -93,7 +95,13 @@ def _parse_documents_worker(self, raw_docs: List[Dict]): continue doc_id = raw_page['id'] - fetched_raw_page = self._confluence.get_page_by_id(doc_id, expand='body.storage,history') + try: + fetched_raw_page = self._confluence.get_page_by_id(doc_id, expand='body.storage,history') + except HTTPError as e: + logging.warning(f'Confluence returned status code {e.response.status_code} for document {doc_id} ({raw_page["title"]}). skipping.') + continue + except ApiError as e: + logging.warning(f'unable to access document {doc_id} ({raw_page["title"]}). reason: "{e.reason}". 
skipping.') author = fetched_raw_page['history']['createdBy']['displayName'] author_image = fetched_raw_page['history']['createdBy']['profilePicture']['path'] diff --git a/app/indexing/bm25_index.py b/app/indexing/bm25_index.py index 8c8d859..622f10d 100644 --- a/app/indexing/bm25_index.py +++ b/app/indexing/bm25_index.py @@ -3,6 +3,7 @@ import nltk import numpy as np from rank_bm25 import BM25Okapi +from sqlalchemy import Connection from typing import List from db_engine import Session @@ -10,6 +11,8 @@ from paths import BM25_INDEX_PATH + + def _add_metadata_for_indexing(paragraph: Paragraph) -> str: result = paragraph.content if paragraph.document.title is not None: @@ -45,13 +48,25 @@ def __init__(self) -> None: self.index = None self.id_map = [] - def update(self): - with Session() as session: - all_paragraphs = session.query(Paragraph).all() - corpus = [nltk.word_tokenize(_add_metadata_for_indexing(paragraph)) for paragraph in all_paragraphs] - id_map = [paragraph.id for paragraph in all_paragraphs] - self.index = BM25Okapi(corpus) - self.id_map = id_map + def _update(self, session): + all_paragraphs = session.query(Paragraph).all() + if len(all_paragraphs) == 0: + self.index = None + self.id_map = [] + return + + corpus = [nltk.word_tokenize(_add_metadata_for_indexing(paragraph)) for paragraph in all_paragraphs] + id_map = [paragraph.id for paragraph in all_paragraphs] + self.index = BM25Okapi(corpus) + self.id_map = id_map + + def update(self, session = None): + if session is None: + with Session() as session: + self._update(session) + else: + self._update(session) + self._save() def search(self, query: str, top_k: int) -> List[int]: diff --git a/app/indexing/faiss_index.py b/app/indexing/faiss_index.py index 2e1beeb..8623007 100644 --- a/app/indexing/faiss_index.py +++ b/app/indexing/faiss_index.py @@ -37,6 +37,11 @@ def update(self, ids: torch.LongTensor, embeddings: torch.FloatTensor): faiss.write_index(self.index, FAISS_INDEX_PATH) + def remove(self, ids: torch.LongTensor): + self.index.remove_ids(torch.tensor(ids)) + + faiss.write_index(self.index, FAISS_INDEX_PATH) + def search(self, queries: torch.FloatTensor, top_k: int, *args, **kwargs): if queries.ndim == 1: queries = queries.unsqueeze(0) diff --git a/app/indexing/index_documents.py b/app/indexing/index_documents.py index 5022d9c..fdc51fc 100644 --- a/app/indexing/index_documents.py +++ b/app/indexing/index_documents.py @@ -20,6 +20,19 @@ class Indexer: def index_documents(documents: List[BasicDocument]): logger.info(f"Indexing {len(documents)} documents") + ids_in_data_source = [document.id_in_data_source for document in documents] + + with Session() as session: + documents_to_delete = session.query(Document).filter(Document.id_in_data_source.in_(ids_in_data_source)).all() + + logging.info(f'removing documents that were updated and need to be re-indexed.') + Indexer.remove_documents(documents_to_delete, session) + for document in documents_to_delete: + # Currently bulk deleting doesn't cascade. So we need to delete them one by one. 
+ # See https://stackoverflow.com/a/19245058/3541901 + session.delete(document) + session.commit() + with Session() as session: db_documents = [] for document in documents: @@ -28,6 +41,7 @@ def index_documents(documents: List[BasicDocument]): # Create a new document in the database db_document = Document( data_source_id=document.data_source_id, + id_in_data_source=document.id_in_data_source, type=document.type.value, file_type=document.file_type.value if document.file_type is not None else None, title=document.title, @@ -94,3 +108,21 @@ def _add_metadata_for_indexing(paragraph: Paragraph) -> str: if paragraph.document.title is not None: result += '; ' + paragraph.document.title return result + + @staticmethod + def remove_documents(documents: List[Document], session = None): + logger.info(f"Removing {len(documents)} documents") + + # Get the paragraphs from the documents + db_paragraphs = [paragraph for document in documents for paragraph in document.paragraphs] + + # Remove the paragraphs from the index + paragraph_ids = [paragraph.id for paragraph in db_paragraphs] + + logger.info(f"Removing documents from faiss index...") + FaissIndex.get().remove(paragraph_ids) + + logger.info(f"Removing documents from BM25 index...") + Bm25Index.get().update(session=session) + + logger.info(f"Finished removing {len(documents)} documents => {len(db_paragraphs)} paragraphs") \ No newline at end of file diff --git a/app/main.py b/app/main.py index 0e4e1e2..6ac6b60 100644 --- a/app/main.py +++ b/app/main.py @@ -59,22 +59,25 @@ async def catch_exceptions_middleware(request: Request, call_next): app.include_router(data_source_router) -# @app.on_event("startup") -# @repeat_every(seconds=60) -# def check_for_new_documents(): -# with Session() as session: -# data_sources: List[DataSource] = session.query(DataSource).all() -# for data_source in data_sources: -# # data source should be checked once every hour -# if (datetime.now() - data_source.last_indexed_at).total_seconds() <= 60 * 60: -# continue -# -# logger.info(f"Checking for new docs in {data_source.type.name} (id: {data_source.id})") -# data_source_cls = get_class_by_data_source_name(data_source.type.name) -# config = json.loads(data_source.config) -# data_source_instance = data_source_cls(config=config, data_source_id=data_source.id, -# last_index_time=data_source.last_indexed_at) -# data_source_instance.index() +def _check_for_new_documents(force=False): + with Session() as session: + data_sources: List[DataSource] = session.query(DataSource).all() + for data_source in data_sources: + # data source should be checked once every hour + if (datetime.now() - data_source.last_indexed_at).total_seconds() <= 60 * 60 and not force: + continue + + logger.info(f"Checking for new docs in {data_source.type.name} (id: {data_source.id})") + data_source_cls = get_class_by_data_source_name(data_source.type.name) + config = json.loads(data_source.config) + data_source_instance = data_source_cls(config=config, data_source_id=data_source.id, + last_index_time=data_source.last_indexed_at) + data_source_instance.index() + +@app.on_event("startup") +@repeat_every(seconds=60) +def check_for_new_documents(): + _check_for_new_documents(force=False) @app.on_event("startup") @@ -151,6 +154,10 @@ async def clear_index(): session.commit() +@app.post("/check-for-new-documents") +async def check_for_new_documents_endpoint(): + _check_for_new_documents(force=True) + try: app.mount('/', StaticFiles(directory=UI_PATH, html=True), name='ui') except Exception as e: diff --git 
a/app/models.py b/app/models.py index bf27efc..51c5d3c 100644 --- a/app/models.py +++ b/app/models.py @@ -8,4 +8,4 @@ cross_encoder_small = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2') cross_encoder_large = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') -qa_model = pipeline('question-answering', model='deepset/roberta-base-squad2') +qa_model = pipeline('question-answering', model='deepset/minilm-uncased-squad2') diff --git a/app/schemas/data_source.py b/app/schemas/data_source.py index f71dea9..730e7ec 100644 --- a/app/schemas/data_source.py +++ b/app/schemas/data_source.py @@ -1,9 +1,10 @@ from typing import Optional from schemas.base import Base -from sqlalchemy import ForeignKey, Column, Integer +from sqlalchemy import ForeignKey, Column, Integer, Connection from sqlalchemy.orm import Mapped, mapped_column, relationship from sqlalchemy import String, DateTime +from sqlalchemy import event class DataSource(Base): @@ -15,4 +16,14 @@ class DataSource(Base): config: Mapped[Optional[str]] = mapped_column(String(512)) last_indexed_at: Mapped[Optional[DateTime]] = mapped_column(DateTime()) created_at: Mapped[Optional[DateTime]] = mapped_column(DateTime()) - documents = relationship("Document", back_populates="data_source") + documents = relationship("Document", back_populates="data_source", cascade='all, delete, delete-orphan') + + +@event.listens_for(DataSource, 'before_delete') +def receive_before_delete(mapper, connection: Connection, target): + # import here to avoid circular imports + from indexing.index_documents import Indexer + from db_engine import Session + + with Session(bind=connection) as session: + Indexer.remove_documents(target.documents, session=session) diff --git a/app/schemas/document.py b/app/schemas/document.py index 35a7683..3fd5fb3 100644 --- a/app/schemas/document.py +++ b/app/schemas/document.py @@ -9,6 +9,7 @@ class Document(Base): __tablename__ = 'document' id: Mapped[int] = mapped_column(primary_key=True) + id_in_data_source: Mapped[str] = mapped_column(String(64)) data_source_id = Column(Integer, ForeignKey('data_source.id')) data_source = relationship("DataSource", back_populates="documents") type: Mapped[Optional[str]] = mapped_column(String(32)) @@ -19,5 +20,5 @@ class Document(Base): url: Mapped[Optional[str]] = mapped_column(String(512)) location: Mapped[Optional[str]] = mapped_column(String(512)) timestamp: Mapped[Optional[DateTime]] = mapped_column(DateTime()) - paragraphs = relationship("Paragraph", back_populates="document", cascade='all, delete-orphan', + paragraphs = relationship("Paragraph", back_populates="document", cascade='all, delete, delete-orphan', foreign_keys="Paragraph.document_id") From 8c98a4f3857d6452bdc3dfea54f8130e97170c5d Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 01:33:03 +0300 Subject: [PATCH 02/13] Initial commit for task-queue --- ...0b16_add_fields_to_datasourcetype_model.py | 4 +- app/api/data_source.py | 43 ++--- .../__init__.py | 0 .../base_data_source.py | 13 +- .../basic_document.py | 0 app/data_source/context.py | 83 +++++++++ app/data_source/dynamic_loader.py | 96 +++++++++++ .../exception.py | 0 .../sources}/__init__.py | 0 app/data_source/sources/bookstack/__init__.py | 0 .../sources/bookstack}/bookstack.py | 102 +++++------ .../sources/confluence/__init__.py | 0 .../sources/confluence}/confluence.py | 98 +++++------ .../sources/confluence}/confluence_cloud.py | 6 +- .../sources/gogle_drive/__init__.py | 0 .../sources/gogle_drive}/google_drive.py | 161 +++++++++--------- 
.../sources/mattermost/__init__.py | 0 .../sources/mattermost}/mattermost.py | 41 ++--- .../sources/rocketchat/__init__.py | 0 .../sources/rocketchat}/rocketchat.py | 94 +++++----- app/data_source/sources/slack/__init__.py | 0 .../sources/slack}/slack.py | 33 ++-- app/indexing/background_indexer.py | 2 +- app/indexing/index_documents.py | 2 +- app/main.py | 44 ++--- app/paths.py | 1 + app/queues/__init__.py | 0 app/{ => queues}/index_queue.py | 10 +- app/queues/task_queue.py | 50 ++++++ app/search_logic.py | 4 +- app/slaves.py | 52 ++++++ 31 files changed, 559 insertions(+), 380 deletions(-) rename app/{data_source_api => data_source}/__init__.py (100%) rename app/{data_source_api => data_source}/base_data_source.py (85%) rename app/{data_source_api => data_source}/basic_document.py (100%) create mode 100644 app/data_source/context.py create mode 100644 app/data_source/dynamic_loader.py rename app/{data_source_api => data_source}/exception.py (100%) rename app/{data_sources => data_source/sources}/__init__.py (100%) create mode 100644 app/data_source/sources/bookstack/__init__.py rename app/{data_sources => data_source/sources/bookstack}/bookstack.py (69%) create mode 100644 app/data_source/sources/confluence/__init__.py rename app/{data_sources => data_source/sources/confluence}/confluence.py (53%) rename app/{data_sources => data_source/sources/confluence}/confluence_cloud.py (87%) create mode 100644 app/data_source/sources/gogle_drive/__init__.py rename app/{data_sources => data_source/sources/gogle_drive}/google_drive.py (59%) create mode 100644 app/data_source/sources/mattermost/__init__.py rename app/{data_sources => data_source/sources/mattermost}/mattermost.py (81%) create mode 100644 app/data_source/sources/rocketchat/__init__.py rename app/{data_sources => data_source/sources/rocketchat}/rocketchat.py (67%) create mode 100644 app/data_source/sources/slack/__init__.py rename app/{data_sources => data_source/sources/slack}/slack.py (82%) create mode 100644 app/queues/__init__.py rename app/{ => queues}/index_queue.py (82%) create mode 100644 app/queues/task_queue.py create mode 100644 app/slaves.py diff --git a/app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py b/app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py index 0f9626c..694a039 100644 --- a/app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py +++ b/app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py @@ -10,7 +10,7 @@ from alembic import op import sqlalchemy as sa -from data_source_api.utils import get_class_by_data_source_name +from data_source.dynamic_loader import DynamicLoader from db_engine import Session from schemas import DataSourceType @@ -29,7 +29,7 @@ def upgrade() -> None: # update existing data sources data_source_types = session.query(DataSourceType).all() for data_source_type in data_source_types: - data_source_class = get_class_by_data_source_name(data_source_type.name) + data_source_class = DynamicLoader.get_data_source_class(data_source_type.name) config_fields = data_source_class.get_config_fields() data_source_type.config_fields = json.dumps([config_field.dict() for config_field in config_fields]) diff --git a/app/api/data_source.py b/app/api/data_source.py index 0b1a4ac..bceee8c 100644 --- a/app/api/data_source.py +++ b/app/api/data_source.py @@ -1,15 +1,12 @@ import base64 import json -from datetime import datetime from typing import List -from fastapi import APIRouter, BackgroundTasks +from fastapi import APIRouter from 
pydantic import BaseModel -from starlette.responses import Response -from data_source_api.base_data_source import ConfigField -from data_source_api.exception import KnownException -from data_source_api.utils import get_class_by_data_source_name +from data_source.base_data_source import ConfigField +from data_source.context import DataSourceContext from db_engine import Session from schemas import DataSourceType, DataSource @@ -60,29 +57,11 @@ class AddDataSource(BaseModel): @router.post("/add") -async def add_integration(dto: AddDataSource, background_tasks: BackgroundTasks): - with Session() as session: - data_source_type = session.query(DataSourceType).filter_by(name=dto.name).first() - if data_source_type is None: - return {"error": "Data source type does not exist"} - - data_source_class = get_class_by_data_source_name(dto.name) - try: - data_source_class.validate_config(dto.config) - except KnownException as e: - return Response(e.message, status_code=501) - - config_str = json.dumps(dto.config) - ds = DataSource(type_id=data_source_type.id, config=config_str, created_at=datetime.now()) - session.add(ds) - session.commit() - - data_source_id = session.query(DataSource).filter_by(type_id=data_source_type.id)\ - .order_by(DataSource.id.desc()).first().id - data_source = data_source_class(config=dto.config, data_source_id=data_source_id) - - # in main.py we have a background task that runs every 5 minutes and indexes the data source - # but here we want to index the data source immediately - background_tasks.add_task(data_source.index) - - return {"success": "Data source added successfully"} +async def add_integration(dto: AddDataSource): + data_source = DataSourceContext.create_data_source(name=dto.name, config=dto.config) + + # in main.py we have a background task that runs every 5 minutes and indexes the data source + # but here we want to index the data source immediately + data_source.add_task_to_queue(data_source.index) + + return {"success": "Data source added successfully"} diff --git a/app/data_source_api/__init__.py b/app/data_source/__init__.py similarity index 100% rename from app/data_source_api/__init__.py rename to app/data_source/__init__.py diff --git a/app/data_source_api/base_data_source.py b/app/data_source/base_data_source.py similarity index 85% rename from app/data_source_api/base_data_source.py rename to app/data_source/base_data_source.py index a67465e..30e6b7a 100644 --- a/app/data_source_api/base_data_source.py +++ b/app/data_source/base_data_source.py @@ -2,12 +2,13 @@ from abc import abstractmethod, ABC from datetime import datetime from enum import Enum -from typing import Dict, List, Optional +from typing import Dict, List, Optional, Callable import re from pydantic import BaseModel from db_engine import Session +from queues.task_queue import TaskQueue, Task from schemas import DataSource @@ -87,6 +88,16 @@ def _set_last_index_time(self) -> None: data_source.last_indexed_at = datetime.now() session.commit() + def add_task_to_queue(self, function: Callable, **kwargs): + task = Task(data_source_id=self._data_source_id, + function_name=function.__name__, + kwargs=kwargs) + TaskQueue.get_instance().add_task(task) + + def run_task(self, function_name: str, **kwargs) -> None: + function = getattr(self, function_name) + function(**kwargs) + def index(self) -> None: try: self._set_last_index_time() diff --git a/app/data_source_api/basic_document.py b/app/data_source/basic_document.py similarity index 100% rename from app/data_source_api/basic_document.py rename to 
app/data_source/basic_document.py diff --git a/app/data_source/context.py b/app/data_source/context.py new file mode 100644 index 0000000..b0345c9 --- /dev/null +++ b/app/data_source/context.py @@ -0,0 +1,83 @@ +import json +from datetime import datetime +from typing import Dict, List + +from data_source.base_data_source import BaseDataSource +from data_source.dynamic_loader import DynamicLoader, ClassInfo +from data_source.exception import KnownException +from db_engine import Session +from schemas import DataSourceType, DataSource + + +class DataSourceContext: + _initialized = False + _data_sources: Dict[int, BaseDataSource] = {} + + @classmethod + def get_data_source(cls, data_source_id: int) -> BaseDataSource: + if not cls._initialized: + cls.init() + cls._initialized = True + + return cls._data_sources[data_source_id] + + @classmethod + def create_data_source(cls, name: str, config: dict) -> BaseDataSource: + with Session() as session: + data_source_type = session.query(DataSourceType).filter_by(name=name).first() + if data_source_type is None: + raise KnownException(message=f"Data source type {name} does not exist") + + data_source_class = DynamicLoader.get_data_source_class(name) + data_source_class.validate_config(config) + config_str = json.dumps(config) + + data_source_row = DataSource(type_id=data_source_type.id, config=config_str, created_at=datetime.now()) + session.add(data_source_row) + session.commit() + + data_source_id = session.query(DataSource).filter_by(type_id=data_source_type.id) \ + .order_by(DataSource.id.desc()).first().id + data_source = data_source_class(config=config, data_source_id=data_source_id) + cls._data_sources[data_source_id] = data_source + + return data_source + + @classmethod + def init(cls): + cls._add_data_sources_to_db() + cls._load_context_from_db() + + @classmethod + def _load_context_from_db(cls): + with Session() as session: + data_sources: List[DataSource] = session.query(DataSource).all() + for data_source in data_sources: + data_source_cls = DynamicLoader.get_data_source_class(data_source.type.name) + config = json.loads(data_source.config) + data_source_instance = data_source_cls(config=config, data_source_id=data_source.id, + last_index_time=data_source.last_indexed_at) + cls._data_sources[data_source.id] = data_source_instance + + cls._initialized = True + + @classmethod + def _add_data_sources_to_db(cls): + data_sources: Dict[str, ClassInfo] = DynamicLoader.find_data_sources() + + with Session() as session: + for source_name in data_sources.keys(): + if session.query(DataSourceType).filter_by(name=source_name).first(): + continue + + class_info = data_sources[source_name] + data_source_class = DynamicLoader.get_class(file_path=class_info.file_path, + class_name=class_info.name) + + config_fields = data_source_class.get_config_fields() + config_fields_str = json.dumps([config_field.dict() for config_field in config_fields]) + new_data_source = DataSourceType(name=source_name, + display_name=data_source_class.get_display_name(), + config_fields=config_fields_str) + session.add(new_data_source) + session.commit() diff --git a/app/data_source/dynamic_loader.py b/app/data_source/dynamic_loader.py new file mode 100644 index 0000000..a1cf81c --- /dev/null +++ b/app/data_source/dynamic_loader.py @@ -0,0 +1,96 @@ +import ast +import os +import re +from dataclasses import dataclass +from typing import Dict +import importlib + +from data_source.utils import snake_case_to_pascal_case + + +@dataclass +class ClassInfo: + name: str + file_path: str + 
+ +class DynamicLoader: + SOURCES_PATH = 'data_source/sources' + + @staticmethod + def extract_classes(file_path: str): + with open(file_path, 'r') as f: + file_ast = ast.parse(f.read()) + classes = {} + for node in file_ast.body: + if isinstance(node, ast.ClassDef): + classes[node.name] = {'node': node, 'file': file_path} + return classes + + @staticmethod + def get_data_source_class(data_source_name: str): + class_name = f"{snake_case_to_pascal_case(data_source_name)}DataSource" + class_file_path = DynamicLoader.find_class_file(DynamicLoader.SOURCES_PATH, class_name) + return DynamicLoader.get_class(class_file_path, class_name) + + @staticmethod + def get_class(file_path: str, class_name: str): + module_name = file_path.replace("/", ".").replace(".py", "") + module = importlib.import_module(module_name) + try: + return getattr(module, class_name) + except AttributeError: + raise AttributeError(f"Class {class_name} not found in module {module}," + f"make sure you named the class correctly (it should be DataSource)") + + @staticmethod + def find_class_file(directory, class_name): + for root, dirs, files in os.walk(directory): + for file in files: + if file.endswith('.py'): + file_path = os.path.join(root, file) + classes = DynamicLoader.extract_classes(file_path) + if class_name in classes: + return file_path + return None + + @staticmethod + def find_data_sources() -> Dict[str, ClassInfo]: + all_classes = {} + # First, extract all classes and their file paths + for root, dirs, files in os.walk(DynamicLoader.SOURCES_PATH): + for file in files: + if file.endswith('.py'): + file_path = os.path.join(root, file) + all_classes.update(DynamicLoader.extract_classes(file_path)) + + def is_base_data_source(class_name: str): + if class_name not in all_classes: + return False + + class_info = all_classes[class_name] + node = class_info['node'] + + for base in node.bases: + if isinstance(base, ast.Name): + if base.id == 'BaseDataSource': + return True + elif is_base_data_source(base.id): + return True + + return False + + data_sources = {} + # Then, check if each class inherits from BaseDataSource + for class_name, class_info in all_classes.items(): + if is_base_data_source(class_name): + snake_case = re.sub('([a-z0-9])([A-Z])', r'\1_\2', + re.sub('(.)([A-Z][a-z]+)', r'\1_\2', class_name)).lower() + clas_name = snake_case.replace('_data_source', '') + data_sources[clas_name] = ClassInfo(class_name, class_info['file']) + + return data_sources + + +if __name__ == '__main__': + print(DynamicLoader.find_data_sources()) diff --git a/app/data_source_api/exception.py b/app/data_source/exception.py similarity index 100% rename from app/data_source_api/exception.py rename to app/data_source/exception.py diff --git a/app/data_sources/__init__.py b/app/data_source/sources/__init__.py similarity index 100% rename from app/data_sources/__init__.py rename to app/data_source/sources/__init__.py diff --git a/app/data_source/sources/bookstack/__init__.py b/app/data_source/sources/bookstack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data_sources/bookstack.py b/app/data_source/sources/bookstack/bookstack.py similarity index 69% rename from app/data_sources/bookstack.py rename to app/data_source/sources/bookstack/bookstack.py index 1c52aad..9bafd6a 100644 --- a/app/data_sources/bookstack.py +++ b/app/data_source/sources/bookstack/bookstack.py @@ -1,19 +1,18 @@ import logging from datetime import datetime +from time import sleep from typing import List, Dict +from urllib.parse import urljoin 
-from data_source_api.basic_document import BasicDocument, DocumentType -from data_source_api.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source_api.exception import InvalidDataSourceConfig -from data_source_api.utils import parse_with_workers -from index_queue import IndexQueue -from parsers.html import html_to_text from pydantic import BaseModel from requests import Session, HTTPError from requests.auth import AuthBase -from urllib.parse import urljoin -from time import sleep +from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.basic_document import BasicDocument, DocumentType +from data_source.exception import InvalidDataSourceConfig +from parsers.html import html_to_text +from queues.index_queue import IndexQueue logger = logging.getLogger(__name__) @@ -153,59 +152,44 @@ def _feed_new_documents(self) -> None: logger.info("Feeding new documents with BookStack") books = self._list_books() - raw_docs = [] for book in books: - raw_docs.extend(self._list_book_pages(book)) - - parse_with_workers(self._parse_documents_worker, raw_docs) - - def _parse_documents_worker(self, raw_docs: List[Dict]): - logger.info(f"Worker parsing {len(raw_docs)} documents") - - parsed_docs = [] - total_fed = 0 - for raw_page in raw_docs: - last_modified = datetime.strptime(raw_page["updated_at"], "%Y-%m-%dT%H:%M:%S.%fZ") - if last_modified < self._last_index_time: - continue - - page_id = raw_page["id"] - page_content = self._book_stack.get_page(page_id) - author_name = page_content["created_by"]["name"] - - author_image_url = "" - author = self._book_stack.get_user(raw_page["created_by"]) - if author: - author_image_url = author["avatar_url"] - - plain_text = html_to_text(page_content["html"]) - - url = urljoin(self._config.get('url'), f"/books/{raw_page['book_slug']}/page/{raw_page['slug']}") - - parsed_docs.append(BasicDocument(title=raw_page["name"], - content=plain_text, - author=author_name, - author_image_url=author_image_url, - timestamp=last_modified, - id=page_id, - data_source_id=self._data_source_id, - location=raw_page["book"]["name"], - url=url, - type=DocumentType.DOCUMENT)) - if len(parsed_docs) >= 50: - total_fed += len(parsed_docs) - IndexQueue.get_instance().put(docs=parsed_docs) - parsed_docs = [] - - IndexQueue.get_instance().put(docs=parsed_docs) - total_fed += len(parsed_docs) - if total_fed > 0: - logging.info(f"Worker fed {total_fed} documents") - - def _list_book_pages(self, book: Dict) -> List[Dict]: - logger.info(f"Getting documents from book {book['name']} ({book['id']})") - return self._book_stack.get_all_pages_from_book(book) + self.add_task_to_queue(self._feed_book, book=book) + def _feed_book(self, book: Dict): + logger.info(f"Getting documents from book {book['name']} ({book['id']})") + pages = self._book_stack.get_all_pages_from_book(book) + for page in pages: + self.add_task_to_queue(self._feed_page, raw_page=page) + + def _feed_page(self, raw_page: Dict): + last_modified = datetime.strptime(raw_page["updated_at"], "%Y-%m-%dT%H:%M:%S.%fZ") + if last_modified < self._last_index_time: + return + + page_id = raw_page["id"] + page_content = self._book_stack.get_page(page_id) + author_name = page_content["created_by"]["name"] + + author_image_url = "" + author = self._book_stack.get_user(raw_page["created_by"]) + if author: + author_image_url = author["avatar_url"] + + plain_text = html_to_text(page_content["html"]) + + url = urljoin(self._config.get('url'), 
f"/books/{raw_page['book_slug']}/page/{raw_page['slug']}") + + document = BasicDocument(title=raw_page["name"], + content=plain_text, + author=author_name, + author_image_url=author_image_url, + timestamp=last_modified, + id=page_id, + data_source_id=self._data_source_id, + location=raw_page["book"]["name"], + url=url, + type=DocumentType.DOCUMENT) + IndexQueue.get_instance().put_single(doc=document) # if __name__ == "__main__": # import os diff --git a/app/data_source/sources/confluence/__init__.py b/app/data_source/sources/confluence/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data_sources/confluence.py b/app/data_source/sources/confluence/confluence.py similarity index 53% rename from app/data_sources/confluence.py rename to app/data_source/sources/confluence/confluence.py index b81fcc3..d4f907f 100644 --- a/app/data_sources/confluence.py +++ b/app/data_source/sources/confluence/confluence.py @@ -3,15 +3,13 @@ from typing import List, Dict from atlassian import Confluence - -from data_source_api.basic_document import BasicDocument, DocumentType -from data_source_api.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source_api.exception import InvalidDataSourceConfig -from data_source_api.utils import parse_with_workers -from index_queue import IndexQueue -from parsers.html import html_to_text from pydantic import BaseModel +from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.basic_document import BasicDocument, DocumentType +from data_source.exception import InvalidDataSourceConfig +from parsers.html import html_to_text +from queues.index_queue import IndexQueue logger = logging.getLogger(__name__) @@ -74,56 +72,11 @@ def _list_spaces(self) -> List[Dict]: def _feed_new_documents(self) -> None: logger.info('Feeding new documents with Confluence') - spaces = self._list_spaces() - raw_docs = [] for space in spaces: - raw_docs.extend(self._list_space_docs(space)) - - parse_with_workers(self._parse_documents_worker, raw_docs) - - def _parse_documents_worker(self, raw_docs: List[Dict]): - logging.info(f'Worker parsing {len(raw_docs)} documents') - - parsed_docs = [] - total_fed = 0 - for raw_page in raw_docs: - last_modified = datetime.strptime(raw_page['version']['when'], "%Y-%m-%dT%H:%M:%S.%fZ") - if last_modified < self._last_index_time: - continue - - doc_id = raw_page['id'] - fetched_raw_page = self._confluence.get_page_by_id(doc_id, expand='body.storage,history') - - author = fetched_raw_page['history']['createdBy']['displayName'] - author_image = fetched_raw_page['history']['createdBy']['profilePicture']['path'] - author_image_url = fetched_raw_page['_links']['base'] + author_image - html_content = fetched_raw_page['body']['storage']['value'] - plain_text = html_to_text(html_content) - - url = fetched_raw_page['_links']['base'] + fetched_raw_page['_links']['webui'] - - parsed_docs.append(BasicDocument(title=fetched_raw_page['title'], - content=plain_text, - author=author, - author_image_url=author_image_url, - timestamp=last_modified, - id=doc_id, - data_source_id=self._data_source_id, - location=raw_page['space_name'], - url=url, - type=DocumentType.DOCUMENT)) - if len(parsed_docs) >= 50: - total_fed += len(parsed_docs) - IndexQueue.get_instance().put(docs=parsed_docs) - parsed_docs = [] - - IndexQueue.get_instance().put(docs=parsed_docs) - total_fed += len(parsed_docs) - if total_fed > 0: - logging.info(f'Worker fed {total_fed} documents') - - def _list_space_docs(self, space: 
Dict) -> List[Dict]: + self.add_task_to_queue(self._feed_space_docs, space=space) + + def _feed_space_docs(self, space: Dict) -> List[Dict]: logging.info(f'Getting documents from space {space["name"]} ({space["key"]})') start = 0 limit = 200 # limit when expanding the version @@ -132,10 +85,10 @@ def _list_space_docs(self, space: Dict) -> List[Dict]: while True: new_batch = self._confluence.get_all_pages_from_space(space['key'], start=start, limit=limit, expand='version') - for doc in new_batch: - doc['space_name'] = space['name'] + for raw_doc in new_batch: + raw_doc['space_name'] = space['name'] + self.add_task_to_queue(self._feed_doc, raw_doc=raw_doc) - space_docs.extend(new_batch) if len(new_batch) < limit: break @@ -143,6 +96,35 @@ def _list_space_docs(self, space: Dict) -> List[Dict]: return space_docs + def _feed_doc(self, raw_doc: Dict): + last_modified = datetime.strptime(raw_doc['version']['when'], "%Y-%m-%dT%H:%M:%S.%fZ") + + if last_modified < self._last_index_time: + return + + doc_id = raw_doc['id'] + fetched_raw_page = self._confluence.get_page_by_id(doc_id, expand='body.storage,history') + + author = fetched_raw_page['history']['createdBy']['displayName'] + author_image = fetched_raw_page['history']['createdBy']['profilePicture']['path'] + author_image_url = fetched_raw_page['_links']['base'] + author_image + html_content = fetched_raw_page['body']['storage']['value'] + plain_text = html_to_text(html_content) + + url = fetched_raw_page['_links']['base'] + fetched_raw_page['_links']['webui'] + + doc = BasicDocument(title=fetched_raw_page['title'], + content=plain_text, + author=author, + author_image_url=author_image_url, + timestamp=last_modified, + id=doc_id, + data_source_id=self._data_source_id, + location=raw_doc['space_name'], + url=url, + type=DocumentType.DOCUMENT) + IndexQueue.get_instance().put_single(doc=doc) + # if __name__ == '__main__': # import os diff --git a/app/data_sources/confluence_cloud.py b/app/data_source/sources/confluence/confluence_cloud.py similarity index 87% rename from app/data_sources/confluence_cloud.py rename to app/data_source/sources/confluence/confluence_cloud.py index f18e54f..3405969 100644 --- a/app/data_sources/confluence_cloud.py +++ b/app/data_source/sources/confluence/confluence_cloud.py @@ -3,9 +3,9 @@ from atlassian import Confluence from pydantic import BaseModel -from data_source_api.base_data_source import ConfigField, HTMLInputType -from data_source_api.exception import InvalidDataSourceConfig -from data_sources.confluence import ConfluenceDataSource +from data_source.base_data_source import ConfigField, HTMLInputType +from data_source.exception import InvalidDataSourceConfig +from data_source.sources.confluence.confluence import ConfluenceDataSource class ConfluenceCloudConfig(BaseModel): diff --git a/app/data_source/sources/gogle_drive/__init__.py b/app/data_source/sources/gogle_drive/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data_sources/google_drive.py b/app/data_source/sources/gogle_drive/google_drive.py similarity index 59% rename from app/data_sources/google_drive.py rename to app/data_source/sources/gogle_drive/google_drive.py index a54c6ec..ec367ba 100644 --- a/app/data_sources/google_drive.py +++ b/app/data_source/sources/gogle_drive/google_drive.py @@ -1,27 +1,25 @@ -import json -import os import io +import json import logging +import os from datetime import datetime -from typing import Dict, List from functools import lru_cache +from typing import Dict, List -import googleapiclient 
-from googleapiclient.errors import HttpError from apiclient.discovery import build +from googleapiclient.errors import HttpError from googleapiclient.http import MediaIoBaseDownload from httplib2 import Http from oauth2client.service_account import ServiceAccountCredentials from pydantic import BaseModel -from data_source_api.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source_api.basic_document import BasicDocument, DocumentType, FileType -from data_source_api.exception import InvalidDataSourceConfig, KnownException -from index_queue import IndexQueue +from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.basic_document import BasicDocument, DocumentType, FileType +from data_source.exception import KnownException +from parsers.docx import docx_to_html from parsers.html import html_to_text from parsers.pptx import pptx_to_text -from parsers.docx import docx_to_html - +from queues.index_queue import IndexQueue logger = logging.getLogger(__name__) @@ -33,7 +31,8 @@ class GoogleDriveConfig(BaseModel): class GoogleDriveDataSource(BaseDataSource): mime_type_to_parser = { 'application/vnd.google-apps.document': html_to_text, - 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': lambda content: html_to_text(docx_to_html(content)), + 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': lambda content: html_to_text( + docx_to_html(content)), 'application/vnd.openxmlformats-officedocument.presentationml.presentation': pptx_to_text, } @@ -66,7 +65,7 @@ def __init__(self, *args, **kwargs): self._credentials = ServiceAccountCredentials.from_json_keyfile_dict(json_dict, scopes=scopes) self._http_auth = self._credentials.authorize(Http()) self._drive = build('drive', 'v3', http=self._http_auth) - + self._supported_mime_types = [ 'application/vnd.google-apps.document', 'application/vnd.openxmlformats-officedocument.wordprocessingml.document', @@ -75,13 +74,14 @@ def __init__(self, *args, **kwargs): def _should_index_file(self, file): if file['mimeType'] not in self._supported_mime_types: - logging.info(f"Skipping file {file['name']} because it's mime type is {file['mimeType']} which is not supported.") + logging.info( + f"Skipping file {file['name']} because it's mime type is {file['mimeType']} which is not supported.") return False last_modified = datetime.strptime(file['modifiedTime'], "%Y-%m-%dT%H:%M:%S.%fZ") if last_modified < self._last_index_time: return False - + return True @lru_cache(maxsize=512) @@ -100,9 +100,12 @@ def _get_parent_name(self, parent_id) -> dict: def _get_parents_string(self, file): return self._get_parent_name(file['parents'][0]) if file['parents'] else '' - def _index_files_from_drive(self, drive) -> List[dict]: - is_shared_drive = drive['id'] is not None + def _feed_new_documents(self) -> None: + for drive in self._get_all_drives(): + self._feed_drive(drive=drive) + def _feed_drive(self, drive): + is_shared_drive = drive['id'] is not None logging.info(f'Indexing drive {drive["name"]}') kwargs = { @@ -111,89 +114,79 @@ def _index_files_from_drive(self, drive) -> List[dict]: 'includeItemsFromAllDrives': True, 'supportsAllDrives': True, } if is_shared_drive else {} - - files = [] - next_page_token = None while True: if next_page_token: kwargs['pageToken'] = next_page_token + response = self._drive.files().list( fields='nextPageToken,files(kind,id,name,mimeType,lastModifyingUser,webViewLink,modifiedTime,parents)', pageSize=1000, **kwargs ).execute() - 
files.extend(response['files']) + logger.info(f'got {len(response["files"])} documents from drive {drive["name"]}.') + + for file in response['files']: + if self._should_index_file(file): + self._feed_file(file) + next_page_token = response.get('nextPageToken') if next_page_token is None: break - logging.getLogger().info(f'got {len(files)} documents from drive {drive["name"]}.') - - files = [file for file in files if self._should_index_file(file)] + def _feed_file(self, file): + logger.info(f'processing file {file["name"]}') + + file_id = file['id'] + file_to_download = file['name'] + if file['mimeType'] == 'application/vnd.google-apps.document': + content = self._drive.files().export(fileId=file_id, mimeType='text/html').execute().decode('utf-8') + content = html_to_text(content) + else: + try: + request = self._drive.files().get_media(fileId=file_id) + fh = io.BytesIO() + downloader = MediaIoBaseDownload(fh, request) + done = False + while done is False: + status, done = downloader.next_chunk() + + # write the downloaded content to a file + with open(file_to_download, 'wb') as f: + f.write(fh.getbuffer()) + + if file['mimeType'] == 'application/vnd.openxmlformats-officedocument.presentationml.presentation': + content = pptx_to_text(file_to_download) + elif file['mimeType'] == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': + content = docx_to_html(file_to_download) + content = html_to_text(content) + else: + logger.error(f'Unsupported mime type {file["mimeType"]}') + return + + # delete file + os.remove(file_to_download) + except Exception as error: + logging.exception(f'Error occurred parsing file "{file["name"]}" from google drive') + + parent_name = self._get_parents_string(file) - documents = [] - - logging.getLogger().info(f'Indexing {len(files)} documents from drive {drive["name"]}.') - - for file in files: - logging.getLogger().info(f'processing file {file["name"]}') + last_modified = datetime.strptime(file['modifiedTime'], "%Y-%m-%dT%H:%M:%S.%fZ") - file_id = file['id'] - file_to_download = file['name'] - if file['mimeType'] == 'application/vnd.google-apps.document': - content = self._drive.files().export(fileId=file_id, mimeType='text/html').execute().decode('utf-8') - content = html_to_text(content) - else: - try: - request = self._drive.files().get_media(fileId=file_id) - fh = io.BytesIO() - downloader = MediaIoBaseDownload(fh, request) - done = False - while done is False: - status, done = downloader.next_chunk() - - # write the downloaded content to a file - with open(file_to_download, 'wb') as f: - f.write(fh.getbuffer()) - - if file['mimeType'] == 'application/vnd.openxmlformats-officedocument.presentationml.presentation': - content = pptx_to_text(file_to_download) - elif file['mimeType'] == 'application/vnd.openxmlformats-officedocument.wordprocessingml.document': - content = docx_to_html(file_to_download) - content = html_to_text(content) - else: - continue - - # delete file - os.remove(file_to_download) - except Exception as error: - logging.exception(f'Error occurred parsing file "{file["name"]}" from google drive') - - parent_name = self._get_parents_string(file) - - last_modified = datetime.strptime(file['modifiedTime'], "%Y-%m-%dT%H:%M:%S.%fZ") - - documents.append(BasicDocument( - id=file_id, - data_source_id=self._data_source_id, - type=DocumentType.DOCUMENT, - title=file['name'], - content=content, - author=file['lastModifyingUser']['displayName'], - author_image_url=file['lastModifyingUser']['photoLink'], - location=parent_name, - 
url=file['webViewLink'], - timestamp=last_modified, - file_type=FileType.from_mime_type(mime_type=file['mimeType']) - )) - - IndexQueue.get_instance().put(documents) + doc = BasicDocument( + id=file_id, + data_source_id=self._data_source_id, + type=DocumentType.DOCUMENT, + title=file['name'], + content=content, + author=file['lastModifyingUser']['displayName'], + author_image_url=file['lastModifyingUser']['photoLink'], + location=parent_name, + url=file['webViewLink'], + timestamp=last_modified, + file_type=FileType.from_mime_type(mime_type=file['mimeType'])) + IndexQueue.get_instance().put_single(doc) def _get_all_drives(self) -> List[dict]: return [{'name': 'My Drive', 'id': None}] \ + self._drive.drives().list(fields='drives(id,name)').execute()['drives'] - - def _feed_new_documents(self) -> None: - for drive in self._get_all_drives(): - self._index_files_from_drive(drive) diff --git a/app/data_source/sources/mattermost/__init__.py b/app/data_source/sources/mattermost/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data_sources/mattermost.py b/app/data_source/sources/mattermost/mattermost.py similarity index 81% rename from app/data_sources/mattermost.py rename to app/data_source/sources/mattermost/mattermost.py index 3f1b182..3e2f033 100644 --- a/app/data_sources/mattermost.py +++ b/app/data_source/sources/mattermost/mattermost.py @@ -7,11 +7,10 @@ from mattermostdriver import Driver -from data_source_api.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source_api.basic_document import BasicDocument, DocumentType -from data_source_api.exception import InvalidDataSourceConfig -from data_source_api.utils import parse_with_workers -from index_queue import IndexQueue +from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.basic_document import BasicDocument, DocumentType +from data_source.exception import InvalidDataSourceConfig +from queues.index_queue import IndexQueue logger = logging.getLogger(__name__) @@ -90,14 +89,12 @@ def _list_posts_in_channel(self, channel_id: str, page: int) -> Dict: def _feed_new_documents(self) -> None: self._mattermost.login() - channels = self._list_channels() + channels = self._list_channels() logger.info(f'Found {len(channels)} channels') - parse_with_workers(self._parse_channel_worker, channels) - def _parse_channel_worker(self, channels: List[MattermostChannel]): for channel in channels: - self._feed_channel(channel) + self.add_task_to_queue(self._feed_channel, channel=channel) def _get_mattermost_url(self): options = self._mattermost.options @@ -115,27 +112,22 @@ def _get_mattermost_user(self, user_id: str): def _feed_channel(self, channel: MattermostChannel): if not self._is_valid_channel(channel): return + logger.info(f'Feeding channel {channel.name}') page = 0 - total_fed = 0 - - parsed_posts = [] - team_url = self._get_team_url(channel) - while True: posts = self._list_posts_in_channel(channel.id, page) last_message: Optional[BasicDocument] = None - posts["order"].reverse() for id in posts["order"]: post = posts["posts"][id] if not self._is_valid_message(post): if last_message is not None: - parsed_posts.append(last_message) + IndexQueue.get_instance().put_single(doc=last_message) last_message = None continue @@ -147,11 +139,8 @@ def _feed_channel(self, channel: MattermostChannel): last_message.content += f"\n{content}" continue else: - parsed_posts.append(last_message) - if len(parsed_posts) >= MattermostDataSource.FEED_BATCH_SIZE: - total_fed += 
len(parsed_posts) - IndexQueue.get_instance().put(docs=parsed_posts) - parsed_posts = [] + IndexQueue.get_instance().put_single(doc=last_message) + last_message = None author_image_url = f"{self._get_mattermost_url()}/api/v4/users/{post['user_id']}/image?_=0" timestamp = datetime.fromtimestamp(post["update_at"] / 1000) @@ -168,15 +157,9 @@ def _feed_channel(self, channel: MattermostChannel): type=DocumentType.MESSAGE ) - if last_message is not None: - parsed_posts.append(last_message) - if posts["prev_post_id"] == "": break page += 1 - IndexQueue.get_instance().put(docs=parsed_posts) - total_fed += len(parsed_posts) - - if len(parsed_posts) > 0: - logger.info(f"Worker fed {total_fed} documents") + if last_message is not None: + IndexQueue.get_instance().put_single(doc=last_message) diff --git a/app/data_source/sources/rocketchat/__init__.py b/app/data_source/sources/rocketchat/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data_sources/rocketchat.py b/app/data_source/sources/rocketchat/rocketchat.py similarity index 67% rename from app/data_sources/rocketchat.py rename to app/data_source/sources/rocketchat/rocketchat.py index 69991aa..3fa1185 100644 --- a/app/data_sources/rocketchat.py +++ b/app/data_source/sources/rocketchat/rocketchat.py @@ -6,10 +6,10 @@ from pydantic import BaseModel from rocketchat_API.rocketchat import RocketChat -from data_source_api.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source_api.basic_document import DocumentType, BasicDocument -from data_source_api.exception import InvalidDataSourceConfig -from index_queue import IndexQueue +from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.basic_document import DocumentType, BasicDocument +from data_source.exception import InvalidDataSourceConfig +from queues.index_queue import IndexQueue @dataclass @@ -69,8 +69,8 @@ def __init__(self, *args, **kwargs): server_url=rocket_chat_config.url) self._authors_cache: Dict[str, RocketchatAuthor] = {} - def _list_rooms(self, oldest: datetime) -> List[RocketchatRoom]: - oldest = oldest.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + def _list_rooms(self) -> List[RocketchatRoom]: + oldest = self._last_index_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") r = self._rocket_chat.call_api_get("rooms.get", updatedSince=oldest) json = r.json() data = json.get("update") @@ -109,8 +109,8 @@ def _list_threads(self, channel: RocketchatRoom) -> List[RocketchatThread]: total = json.get("total") return [RocketchatThread(id=trds["_id"], name=trds["msg"], channel_id=trds["rid"]) for trds in data] - def _list_messages(self, channel: RocketchatRoom, oldest: datetime): - oldest = oldest.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + def _list_messages(self, channel: RocketchatRoom): + oldest = self._last_index_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") data = [] while oldest: r = self._rocket_chat.call_api_get("chat.syncMessages", roomId=channel.id, lastUpdate=oldest) @@ -123,8 +123,8 @@ def _list_messages(self, channel: RocketchatRoom, oldest: datetime): oldest = None return data - def _list_thread_messages(self, thread: RocketchatThread, oldest: datetime): - oldest = oldest.strftime("%Y-%m-%dT%H:%M:%S.%fZ") + def _list_thread_messages(self, thread: RocketchatThread): + oldest = self._last_index_time.strftime("%Y-%m-%dT%H:%M:%S.%fZ") data = [] records = 0 total = 1 # Set 1 to enter the loop @@ -150,48 +150,50 @@ def _get_author_details(self, author_id: str) -> RocketchatAuthor: return author def _feed_new_documents(self) -> 
None: - documents = [] - for channel in self._list_rooms(self._last_index_time): - last_msg: Optional[BasicDocument] = None + for channel in self._list_rooms(): + self.add_task_to_queue(self._feed_channel, channel=channel) - messages = self._list_messages(channel, self._last_index_time) + def _feed_channel(self, channel): + messages = self._list_messages(channel) + threads = self._list_threads(channel) + for thread in threads: + messages += self._list_thread_messages(thread) - threads = self._list_threads(channel) - for thread in threads: - messages += self._list_thread_messages(thread, self._last_index_time) - - logging.info(f"Getting {len(messages)} messages from room {channel.name} ({channel.id})" - f" with {len(threads)} threads") - - for message in messages: - if "msg" not in message: - continue - text = message["msg"] - author_id = message["u"]["_id"] - author = self._get_author_details(author_id) + logging.info(f"Got {len(messages)} messages from room {channel.name} ({channel.id})" + f" with {len(threads)} threads") + last_msg: Optional[BasicDocument] = None + for message in messages: + if "msg" not in message: if last_msg is not None: - if last_msg.author == author.name: - last_msg.content += f"\n{text}" - continue - else: - documents.append(last_msg) - - timestamp = message["ts"] - message_id = message["_id"] - readable_timestamp = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") - message_url = f"{self._config.get('url')}/{channel.id}?msg={message_id}" - last_msg = BasicDocument(title=channel.name, content=text, author=author.name, - timestamp=readable_timestamp, id=message_id, - data_source_id=self._data_source_id, location=channel.name, - url=message_url, author_image_url=author.image_url, - type=DocumentType.MESSAGE) + IndexQueue.get_instance().put_single(doc=last_msg) + last_msg = None + continue - if last_msg is not None: - documents.append(last_msg) + text = message["msg"] + author_id = message["u"]["_id"] + author = self._get_author_details(author_id) - logging.info(f"Total messages : {len(documents)}") - IndexQueue.get_instance().put(docs=documents) + if last_msg is not None: + if last_msg.author == author.name: + last_msg.content += f"\n{text}" + continue + else: + IndexQueue.get_instance().put_single(doc=last_msg) + last_msg = None + + timestamp = message["ts"] + message_id = message["_id"] + readable_timestamp = datetime.datetime.strptime(timestamp, "%Y-%m-%dT%H:%M:%S.%fZ") + message_url = f"{self._config.get('url')}/{channel.id}?msg={message_id}" + last_msg = BasicDocument(title=channel.name, content=text, author=author.name, + timestamp=readable_timestamp, id=message_id, + data_source_id=self._data_source_id, location=channel.name, + url=message_url, author_image_url=author.image_url, + type=DocumentType.MESSAGE) + + if last_msg is not None: + IndexQueue.get_instance().put_single(doc=last_msg) if __name__ == "__main__": diff --git a/app/data_source/sources/slack/__init__.py b/app/data_source/sources/slack/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/data_sources/slack.py b/app/data_source/sources/slack/slack.py similarity index 82% rename from app/data_sources/slack.py rename to app/data_source/sources/slack/slack.py index abcf482..36fa2a1 100644 --- a/app/data_sources/slack.py +++ b/app/data_source/sources/slack/slack.py @@ -7,10 +7,10 @@ from pydantic import BaseModel from slack_sdk import WebClient -from data_source_api.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source_api.basic_document 
import DocumentType, BasicDocument -from data_source_api.utils import parse_with_workers -from index_queue import IndexQueue +from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.basic_document import DocumentType, BasicDocument +from data_source.utils import parse_with_workers +from queues.index_queue import IndexQueue logger = logging.getLogger(__name__) @@ -93,24 +93,19 @@ def _feed_new_documents(self) -> None: joined_conversations = self._join_conversations(conversations) logger.info(f'Joined {len(joined_conversations)} conversations') - parse_with_workers(self._parse_conversations_worker, joined_conversations) - - def _parse_conversations_worker(self, conversations: List[SlackConversation]) -> None: - for conv in conversations: - self._feed_conversation(conv) + for conv in joined_conversations: + self.add_task_to_queue(self._feed_conversation, conv=conv) def _feed_conversation(self, conv): logger.info(f'Feeding conversation {conv.name}') last_msg: Optional[BasicDocument] = None - total_fed = 0 - documents = [] messages = self._fetch_conversation_messages(conv) for message in messages: if not self._is_valid_message(message): if last_msg is not None: - documents.append(last_msg) + IndexQueue.get_instance().put_single(doc=last_msg) last_msg = None continue @@ -122,11 +117,8 @@ def _feed_conversation(self, conv): last_msg.content += f"\n{text}" continue else: - documents.append(last_msg) - if len(documents) == SlackDataSource.FEED_BATCH_SIZE: - total_fed += SlackDataSource.FEED_BATCH_SIZE - IndexQueue.get_instance().put(docs=documents) - documents = [] + IndexQueue.get_instance().put_single(doc=last_msg) + last_msg = None timestamp = message['ts'] message_id = message['client_msg_id'] @@ -139,12 +131,7 @@ def _feed_conversation(self, conv): type=DocumentType.MESSAGE) if last_msg is not None: - documents.append(last_msg) - - IndexQueue.get_instance().put(docs=documents) - total_fed += len(documents) - if total_fed > 0: - logger.info(f'Slack worker fed {total_fed} documents') + IndexQueue.get_instance().put_single(doc=last_msg) def _fetch_conversation_messages(self, conv): messages = [] diff --git a/app/indexing/background_indexer.py b/app/indexing/background_indexer.py index 50c8295..ce6673b 100644 --- a/app/indexing/background_indexer.py +++ b/app/indexing/background_indexer.py @@ -2,7 +2,7 @@ import threading from typing import List -from index_queue import IndexQueue +from queues.index_queue import IndexQueue from indexing.index_documents import Indexer diff --git a/app/indexing/index_documents.py b/app/indexing/index_documents.py index 5022d9c..1e74517 100644 --- a/app/indexing/index_documents.py +++ b/app/indexing/index_documents.py @@ -2,7 +2,7 @@ import re from typing import List -from data_source_api.basic_document import BasicDocument +from data_source.basic_document import BasicDocument from paths import IS_IN_DOCKER from schemas import Document, Paragraph from models import bi_encoder diff --git a/app/main.py b/app/main.py index 48d9199..aff94ec 100644 --- a/app/main.py +++ b/app/main.py @@ -1,9 +1,5 @@ -import json import logging -import os from dataclasses import dataclass -from datetime import datetime -from typing import List import torch from fastapi import FastAPI, Request @@ -14,18 +10,17 @@ from api.data_source import router as data_source_router from api.search import router as search_router -from data_source_api.exception import KnownException -from data_source_api.utils import get_class_by_data_source_name +from 
data_source.exception import KnownException +from data_source.context import DataSourceContext from db_engine import Session from indexing.background_indexer import BackgroundIndexer from indexing.bm25_index import Bm25Index from indexing.faiss_index import FaissIndex -from index_queue import IndexQueue +from queues.index_queue import IndexQueue from paths import UI_PATH -from schemas import DataSource -from schemas.data_source_type import DataSourceType from schemas.document import Document from schemas.paragraph import Paragraph +from slaves import Slaves from telemetry import Posthog logging.basicConfig(level=logging.INFO, @@ -94,40 +89,21 @@ def send_daily_telemetry(): pass -def load_data_source_types(): - supported_data_source_type = [] - for file in os.listdir("data_sources"): - if file.endswith(".py") and file != "__init__.py": - supported_data_source_type.append(file[:-3]) - - with Session() as session: - for data_source_type in supported_data_source_type: - if session.query(DataSourceType).filter_by(name=data_source_type).first(): - continue - - data_source_class = get_class_by_data_source_name(data_source_type) - config_fields = data_source_class.get_config_fields() - config_fields_str = json.dumps([config_field.dict() for config_field in config_fields]) - new_data_source = DataSourceType(name=data_source_type, - display_name=data_source_class.get_display_name(), - config_fields=config_fields_str) - session.add(new_data_source) - session.commit() - - @app.on_event("startup") async def startup_event(): if not torch.cuda.is_available(): logger.warning("CUDA is not available, using CPU. This will make indexing and search very slow!!!") FaissIndex.create() Bm25Index.create() - load_data_source_types() + DataSourceContext.init() BackgroundIndexer.start() + Slaves.start() @app.on_event("shutdown") async def shutdown_event(): BackgroundIndexer.stop() + Slaves.stop() @app.get("/status") @@ -157,6 +133,6 @@ async def clear_index(): logger.warning(f"Failed to mount UI (you probably need to build it): {e}") -# if __name__ == '__main__': -# import uvicorn -# uvicorn.run("main:app", host="localhost", port=8000) +if __name__ == '__main__': + import uvicorn + uvicorn.run("main:app", host="localhost", port=8000) diff --git a/app/paths.py b/app/paths.py index a999f43..81646f6 100644 --- a/app/paths.py +++ b/app/paths.py @@ -11,6 +11,7 @@ UI_PATH = Path('/ui/') if IS_IN_DOCKER else Path('../ui/build/') SQLITE_DB_PATH = STORAGE_PATH / 'db.sqlite3' SQLITE_TASKS_PATH = STORAGE_PATH / 'tasks.sqlite3' +SQLITE_INDEXING_PATH = STORAGE_PATH / 'indexing.sqlite3' FAISS_INDEX_PATH = str(STORAGE_PATH / 'faiss_index.bin') BM25_INDEX_PATH = str(STORAGE_PATH / 'bm25_index.bin') UUID_PATH = str(STORAGE_PATH / '.uuid') diff --git a/app/queues/__init__.py b/app/queues/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/app/index_queue.py b/app/queues/index_queue.py similarity index 82% rename from app/index_queue.py rename to app/queues/index_queue.py index 56b6f76..75ac715 100644 --- a/app/index_queue.py +++ b/app/queues/index_queue.py @@ -4,8 +4,8 @@ from persistqueue import SQLiteAckQueue -from data_source_api.basic_document import BasicDocument -from paths import SQLITE_TASKS_PATH +from data_source.basic_document import BasicDocument +from paths import SQLITE_INDEXING_PATH @dataclass @@ -30,7 +30,7 @@ def __init__(self): raise RuntimeError("Queue is a singleton, use .get() to get the instance") self.condition = threading.Condition() - super().__init__(path=SQLITE_TASKS_PATH, 
multithreading=True, name="index") + super().__init__(path=SQLITE_INDEXING_PATH, multithreading=True, name="index") def put_single(self, doc: BasicDocument): self.put([doc]) @@ -49,8 +49,8 @@ def consume_all(self, max_docs=5000, timeout=1) -> List[IndexQueueItem]: queue_items = [] count = 0 while not super().empty() and count < max_docs: - raw_items = super().get(raw=True) - queue_items.append(IndexQueueItem(queue_item_id=raw_items['pqid'], doc=raw_items['data'])) + raw_item = super().get(raw=True) + queue_items.append(IndexQueueItem(queue_item_id=raw_item['pqid'], doc=raw_item['data'])) count += 1 return queue_items diff --git a/app/queues/task_queue.py b/app/queues/task_queue.py new file mode 100644 index 0000000..01f26bb --- /dev/null +++ b/app/queues/task_queue.py @@ -0,0 +1,50 @@ +import threading +from dataclasses import dataclass +from typing import Optional + +from persistqueue import SQLiteAckQueue, Empty + +from paths import SQLITE_TASKS_PATH + + +@dataclass +class Task: + data_source_id: int + function_name: str + kwargs: dict + + +@dataclass +class TaskQueueItem: + queue_item_id: int + task: Task + + +class TaskQueue(SQLiteAckQueue): + __instance = None + __lock = threading.Lock() + + @classmethod + def get_instance(cls): + with cls.__lock: + if cls.__instance is None: + cls.__instance = cls() + return cls.__instance + + def __init__(self): + if TaskQueue.__instance is not None: + raise RuntimeError("TaskQueue is a singleton, use .get() to get the instance") + + self.condition = threading.Condition() + super().__init__(path=SQLITE_TASKS_PATH, multithreading=True, name="task") + + def add_task(self, task: Task): + self.put(task) + + def get_task(self, timeout=1) -> Optional[TaskQueueItem]: + try: + raw_item = super().get(raw=True, block=True, timeout=timeout) + return TaskQueueItem(queue_item_id=raw_item['pqid'], task=raw_item['data']) + + except Empty: + return None diff --git a/app/search_logic.py b/app/search_logic.py index 35c2a35..23b42b3 100644 --- a/app/search_logic.py +++ b/app/search_logic.py @@ -12,8 +12,8 @@ import torch from sentence_transformers import CrossEncoder -from data_source_api.basic_document import DocumentType, FileType -from data_source_api.utils import get_confluence_user_image +from data_source.basic_document import DocumentType, FileType +from data_source.utils import get_confluence_user_image from db_engine import Session from indexing.bm25_index import Bm25Index from indexing.faiss_index import FaissIndex diff --git a/app/slaves.py b/app/slaves.py new file mode 100644 index 0000000..6c1e480 --- /dev/null +++ b/app/slaves.py @@ -0,0 +1,52 @@ +import logging +import threading + +from data_source.context import DataSourceContext +from queues.task_queue import TaskQueue + +logger = logging.getLogger() + + +class Slaves: + _threads = [] + _stop_event = threading.Event() + + @classmethod + def start(cls): + for i in range(0, 20): + cls._threads.append(threading.Thread(target=cls.run)) + for thread in cls._threads: + thread.start() + + @classmethod + def stop(cls): + cls._stop_event.set() + logging.info('Stop event set, waiting for slaves to stop...') + + for thread in cls._threads: + thread.join() + logging.info('Slaves stopped') + + cls._thread = None + + @staticmethod + def run(): + task_queue = TaskQueue.get_instance() + logger.info(f'Slave started...') + + while not Slaves._stop_event.is_set(): + task_item = task_queue.get_task() + if not task_item: + continue + + try: + data_source = DataSourceContext.get_data_source(task_item.task.data_source_id) + 
# load kwargs dict to real kwargs + data_source.run_task(task_item.task.function_name, **task_item.task.kwargs) + task_queue.ack(id=task_item.queue_item_id) + except Exception as e: + logger.exception(f'Failed to ack task {task_item.task.function_name} ' + f'for data source {task_item.task.data_source_id}') + task_queue.nack(id=task_item.queue_item_id) + import time + time.sleep(1) From 2e6e60477556736e1f3727bf3b27c4c0ef4b6fa9 Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 01:33:58 +0300 Subject: [PATCH 03/13] Initial commit for task-queue --- app/{data_source_api => data_source}/utils.py | 15 +-------------- app/slaves.py | 3 +-- 2 files changed, 2 insertions(+), 16 deletions(-) rename app/{data_source_api => data_source}/utils.py (74%) diff --git a/app/data_source_api/utils.py b/app/data_source/utils.py similarity index 74% rename from app/data_source_api/utils.py rename to app/data_source/utils.py index 43cedba..51a9cbd 100644 --- a/app/data_source_api/utils.py +++ b/app/data_source/utils.py @@ -1,5 +1,4 @@ import base64 -import importlib import logging import concurrent.futures from functools import lru_cache @@ -11,24 +10,12 @@ logger = logging.getLogger(__name__) -def _snake_case_to_pascal_case(snake_case_string: str): +def snake_case_to_pascal_case(snake_case_string: str): """Converts a snake case string to a PascalCase string""" components = snake_case_string.split('_') return "".join(x.title() for x in components) -def get_class_by_data_source_name(data_source_name: str): - class_name = f"{_snake_case_to_pascal_case(data_source_name)}DataSource" - - module = importlib.import_module(f"data_sources.{data_source_name}") - - try: - return getattr(module, class_name) - except AttributeError: - raise AttributeError(f"Class {class_name} not found in module {module}," - f"make sure you named the class correctly (it should be DataSource)") - - def parse_with_workers(method_name: callable, items: list, **kwargs): workers = 10 # should be a config value diff --git a/app/slaves.py b/app/slaves.py index 6c1e480..ff22456 100644 --- a/app/slaves.py +++ b/app/slaves.py @@ -1,5 +1,6 @@ import logging import threading +import time from data_source.context import DataSourceContext from queues.task_queue import TaskQueue @@ -41,12 +42,10 @@ def run(): try: data_source = DataSourceContext.get_data_source(task_item.task.data_source_id) - # load kwargs dict to real kwargs data_source.run_task(task_item.task.function_name, **task_item.task.kwargs) task_queue.ack(id=task_item.queue_item_id) except Exception as e: logger.exception(f'Failed to ack task {task_item.task.function_name} ' f'for data source {task_item.task.data_source_id}') task_queue.nack(id=task_item.queue_item_id) - import time time.sleep(1) From f4df376391e50af8a0629e2f0abf1f9e8c397665 Mon Sep 17 00:00:00 2001 From: Roey Lalazar <127092381+Roey7@users.noreply.github.com> Date: Mon, 27 Mar 2023 01:40:45 +0300 Subject: [PATCH 04/13] Update main.py --- app/main.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/app/main.py b/app/main.py index aff94ec..1fbba35 100644 --- a/app/main.py +++ b/app/main.py @@ -133,6 +133,6 @@ async def clear_index(): logger.warning(f"Failed to mount UI (you probably need to build it): {e}") -if __name__ == '__main__': - import uvicorn - uvicorn.run("main:app", host="localhost", port=8000) +# if __name__ == '__main__': +# import uvicorn +# uvicorn.run("main:app", host="localhost", port=8000) From 552b57bc3f7e66cca8c959702578040829cb49c4 Mon Sep 17 00:00:00 2001 From: roey Date: 
Mon, 27 Mar 2023 02:12:43 +0300 Subject: [PATCH 05/13] merge --- app/data_source/sources/confluence/confluence.py | 1 + 1 file changed, 1 insertion(+) diff --git a/app/data_source/sources/confluence/confluence.py b/app/data_source/sources/confluence/confluence.py index 88511f0..6e25668 100644 --- a/app/data_source/sources/confluence/confluence.py +++ b/app/data_source/sources/confluence/confluence.py @@ -114,6 +114,7 @@ def _feed_doc(self, raw_doc: Dict): logging.warning( f'unable to access document {doc_id} ({raw_doc["title"]}). reason: "{e.reason}". skipping.') return + author = fetched_raw_page['history']['createdBy']['displayName'] author_image = fetched_raw_page['history']['createdBy']['profilePicture']['path'] author_image_url = fetched_raw_page['_links']['base'] + author_image From 016d6756a5e660fec1022b0758e040dd1fe64927 Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 13:34:59 +0300 Subject: [PATCH 06/13] CR --- ...0b16_add_fields_to_datasourcetype_model.py | 2 +- app/api/data_source.py | 35 +++++++++++-------- .../{sources/gogle_drive => api}/__init__.py | 0 app/data_source/{ => api}/base_data_source.py | 0 app/data_source/{ => api}/basic_document.py | 3 +- app/data_source/{ => api}/context.py | 23 ++++++++++-- app/data_source/{ => api}/dynamic_loader.py | 6 +++- app/data_source/{ => api}/exception.py | 0 app/data_source/{ => api}/utils.py | 0 .../sources/bookstack/bookstack.py | 6 ++-- .../sources/confluence/confluence.py | 6 ++-- .../sources/confluence/confluence_cloud.py | 4 +-- .../google_drive.py | 6 ++-- .../sources/mattermost/mattermost.py | 6 ++-- .../sources/rocketchat/rocketchat.py | 6 ++-- app/data_source/sources/slack/slack.py | 5 ++- app/indexing/index_documents.py | 23 +++++++----- app/main.py | 13 +++++-- app/models.py | 2 +- app/queues/index_queue.py | 16 ++++----- app/queues/task_queue.py | 14 ++++---- app/search_logic.py | 4 +-- app/slaves.py | 3 +- ui/src/App.tsx | 9 ++--- ui/src/data-source.ts | 5 +++ 25 files changed, 120 insertions(+), 77 deletions(-) rename app/data_source/{sources/gogle_drive => api}/__init__.py (100%) rename app/data_source/{ => api}/base_data_source.py (100%) rename app/data_source/{ => api}/basic_document.py (96%) rename app/data_source/{ => api}/context.py (79%) rename app/data_source/{ => api}/dynamic_loader.py (93%) rename app/data_source/{ => api}/exception.py (100%) rename app/data_source/{ => api}/utils.py (100%) rename app/data_source/sources/{gogle_drive => google_drive}/google_drive.py (97%) diff --git a/app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py b/app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py index 694a039..02584b6 100644 --- a/app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py +++ b/app/alembic/versions/9c2f5b290b16_add_fields_to_datasourcetype_model.py @@ -10,7 +10,7 @@ from alembic import op import sqlalchemy as sa -from data_source.dynamic_loader import DynamicLoader +from data_source.api.dynamic_loader import DynamicLoader from db_engine import Session from schemas import DataSourceType diff --git a/app/api/data_source.py b/app/api/data_source.py index c361b17..286d2cb 100644 --- a/app/api/data_source.py +++ b/app/api/data_source.py @@ -5,13 +5,13 @@ from fastapi import APIRouter from pydantic import BaseModel -from data_source.base_data_source import ConfigField -from data_source.context import DataSourceContext +from data_source.api.base_data_source import ConfigField +from data_source.api.context import DataSourceContext from 
db_engine import Session from schemas import DataSourceType, DataSource router = APIRouter( - prefix='/data-source', + prefix='/data-sources', ) @@ -36,7 +36,12 @@ def from_data_source_type(data_source_type: DataSourceType) -> 'DataSourceTypeDt ) -@router.get("/list-types") +class ConnectedDataSourceDto(BaseModel): + id: int + name: str + + +@router.get("/types") async def list_data_source_types() -> List[DataSourceTypeDto]: with Session() as session: data_source_types = session.query(DataSourceType).all() @@ -44,18 +49,12 @@ async def list_data_source_types() -> List[DataSourceTypeDto]: for data_source_type in data_source_types] -@router.get("/list-connected") -async def list_connected_data_sources() -> List[str]: +@router.get("/connected") +async def list_connected_data_sources() -> List[ConnectedDataSourceDto]: with Session() as session: data_sources = session.query(DataSource).all() - return [data_source.type.name for data_source in data_sources] - - -@router.get("/list") -async def list_connected_data_sources() -> List[dict]: - with Session() as session: - data_sources = session.query(DataSource).all() - return [{'id': data_source.id} for data_source in data_sources] + return [ConnectedDataSourceDto(id=data_source.id, name=data_source.type.name) + for data_source in data_sources] class AddDataSource(BaseModel): @@ -63,7 +62,13 @@ class AddDataSource(BaseModel): config: dict -@router.post("/add") +@router.delete("/{data_source_id}") +async def delete_data_source(data_source_id: int): + DataSourceContext.delete_data_source(data_source_id=data_source_id) + return {"success": "Data source deleted successfully"} + + +@router.post("") async def add_integration(dto: AddDataSource): data_source = DataSourceContext.create_data_source(name=dto.name, config=dto.config) diff --git a/app/data_source/sources/gogle_drive/__init__.py b/app/data_source/api/__init__.py similarity index 100% rename from app/data_source/sources/gogle_drive/__init__.py rename to app/data_source/api/__init__.py diff --git a/app/data_source/base_data_source.py b/app/data_source/api/base_data_source.py similarity index 100% rename from app/data_source/base_data_source.py rename to app/data_source/api/base_data_source.py diff --git a/app/data_source/basic_document.py b/app/data_source/api/basic_document.py similarity index 96% rename from app/data_source/basic_document.py rename to app/data_source/api/basic_document.py index 265562d..2790203 100644 --- a/app/data_source/basic_document.py +++ b/app/data_source/api/basic_document.py @@ -1,6 +1,7 @@ from datetime import datetime from dataclasses import dataclass from enum import Enum +from typing import Union class DocumentType(Enum): @@ -32,7 +33,7 @@ def from_mime_type(cls, mime_type: str): @dataclass class BasicDocument: - id: int | str + id: Union[int, str] data_source_id: int type: DocumentType title: str diff --git a/app/data_source/context.py b/app/data_source/api/context.py similarity index 79% rename from app/data_source/context.py rename to app/data_source/api/context.py index b0345c9..8579569 100644 --- a/app/data_source/context.py +++ b/app/data_source/api/context.py @@ -2,14 +2,19 @@ from datetime import datetime from typing import Dict, List -from data_source.base_data_source import BaseDataSource -from data_source.dynamic_loader import DynamicLoader, ClassInfo -from data_source.exception import KnownException +from data_source.api.base_data_source import BaseDataSource +from data_source.api.dynamic_loader import DynamicLoader, ClassInfo +from 
data_source.api.exception import KnownException from db_engine import Session from schemas import DataSourceType, DataSource class DataSourceContext: + """ + This class is responsible for loading data sources and caching them. + It dynamically loads data source types from the data_source/sources directory. + It loads data sources from the database and caches them. + """ _initialized = False _data_sources: Dict[int, BaseDataSource] = {} @@ -43,6 +48,18 @@ def create_data_source(cls, name: str, config: dict) -> BaseDataSource: return data_source + @classmethod + def delete_data_source(cls, data_source_id: int): + with Session() as session: + data_source = session.query(DataSource).filter_by(id=data_source_id).first() + if data_source is None: + raise KnownException(message=f"Data source {data_source_id} does not exist") + + session.delete(data_source) + session.commit() + + del cls._data_sources[data_source_id] + @classmethod def init(cls): cls._add_data_sources_to_db() diff --git a/app/data_source/dynamic_loader.py b/app/data_source/api/dynamic_loader.py similarity index 93% rename from app/data_source/dynamic_loader.py rename to app/data_source/api/dynamic_loader.py index a1cf81c..5fd02fa 100644 --- a/app/data_source/dynamic_loader.py +++ b/app/data_source/api/dynamic_loader.py @@ -5,7 +5,7 @@ from typing import Dict import importlib -from data_source.utils import snake_case_to_pascal_case +from data_source.api.utils import snake_case_to_pascal_case @dataclass @@ -15,6 +15,10 @@ class ClassInfo: class DynamicLoader: + """ + This class is used to dynamically load classes from files. + Specifically, it is used to load data sources from the data_source/sources directory. + """ SOURCES_PATH = 'data_source/sources' @staticmethod diff --git a/app/data_source/exception.py b/app/data_source/api/exception.py similarity index 100% rename from app/data_source/exception.py rename to app/data_source/api/exception.py diff --git a/app/data_source/utils.py b/app/data_source/api/utils.py similarity index 100% rename from app/data_source/utils.py rename to app/data_source/api/utils.py diff --git a/app/data_source/sources/bookstack/bookstack.py b/app/data_source/sources/bookstack/bookstack.py index 9bafd6a..5211fb6 100644 --- a/app/data_source/sources/bookstack/bookstack.py +++ b/app/data_source/sources/bookstack/bookstack.py @@ -8,9 +8,9 @@ from requests import Session, HTTPError from requests.auth import AuthBase -from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source.basic_document import BasicDocument, DocumentType -from data_source.exception import InvalidDataSourceConfig +from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.api.basic_document import BasicDocument, DocumentType +from data_source.api.exception import InvalidDataSourceConfig from parsers.html import html_to_text from queues.index_queue import IndexQueue diff --git a/app/data_source/sources/confluence/confluence.py b/app/data_source/sources/confluence/confluence.py index 6e25668..b7f449a 100644 --- a/app/data_source/sources/confluence/confluence.py +++ b/app/data_source/sources/confluence/confluence.py @@ -6,9 +6,9 @@ from pydantic import BaseModel from atlassian.errors import ApiError from requests import HTTPError -from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source.basic_document import BasicDocument, DocumentType -from data_source.exception import InvalidDataSourceConfig +from 
data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.api.basic_document import BasicDocument, DocumentType +from data_source.api.exception import InvalidDataSourceConfig from parsers.html import html_to_text from queues.index_queue import IndexQueue diff --git a/app/data_source/sources/confluence/confluence_cloud.py b/app/data_source/sources/confluence/confluence_cloud.py index 3405969..6b4ba58 100644 --- a/app/data_source/sources/confluence/confluence_cloud.py +++ b/app/data_source/sources/confluence/confluence_cloud.py @@ -3,8 +3,8 @@ from atlassian import Confluence from pydantic import BaseModel -from data_source.base_data_source import ConfigField, HTMLInputType -from data_source.exception import InvalidDataSourceConfig +from data_source.api.base_data_source import ConfigField, HTMLInputType +from data_source.api.exception import InvalidDataSourceConfig from data_source.sources.confluence.confluence import ConfluenceDataSource diff --git a/app/data_source/sources/gogle_drive/google_drive.py b/app/data_source/sources/google_drive/google_drive.py similarity index 97% rename from app/data_source/sources/gogle_drive/google_drive.py rename to app/data_source/sources/google_drive/google_drive.py index ec367ba..b52fc2f 100644 --- a/app/data_source/sources/gogle_drive/google_drive.py +++ b/app/data_source/sources/google_drive/google_drive.py @@ -13,9 +13,9 @@ from oauth2client.service_account import ServiceAccountCredentials from pydantic import BaseModel -from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source.basic_document import BasicDocument, DocumentType, FileType -from data_source.exception import KnownException +from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.api.basic_document import BasicDocument, DocumentType, FileType +from data_source.api.exception import KnownException from parsers.docx import docx_to_html from parsers.html import html_to_text from parsers.pptx import pptx_to_text diff --git a/app/data_source/sources/mattermost/mattermost.py b/app/data_source/sources/mattermost/mattermost.py index 3e2f033..482a782 100644 --- a/app/data_source/sources/mattermost/mattermost.py +++ b/app/data_source/sources/mattermost/mattermost.py @@ -7,9 +7,9 @@ from mattermostdriver import Driver -from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source.basic_document import BasicDocument, DocumentType -from data_source.exception import InvalidDataSourceConfig +from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.api.basic_document import BasicDocument, DocumentType +from data_source.api.exception import InvalidDataSourceConfig from queues.index_queue import IndexQueue logger = logging.getLogger(__name__) diff --git a/app/data_source/sources/rocketchat/rocketchat.py b/app/data_source/sources/rocketchat/rocketchat.py index 3fa1185..b03160d 100644 --- a/app/data_source/sources/rocketchat/rocketchat.py +++ b/app/data_source/sources/rocketchat/rocketchat.py @@ -6,9 +6,9 @@ from pydantic import BaseModel from rocketchat_API.rocketchat import RocketChat -from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source.basic_document import DocumentType, BasicDocument -from data_source.exception import InvalidDataSourceConfig +from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from 
data_source.api.basic_document import DocumentType, BasicDocument +from data_source.api.exception import InvalidDataSourceConfig from queues.index_queue import IndexQueue diff --git a/app/data_source/sources/slack/slack.py b/app/data_source/sources/slack/slack.py index 36fa2a1..ffe9e32 100644 --- a/app/data_source/sources/slack/slack.py +++ b/app/data_source/sources/slack/slack.py @@ -7,9 +7,8 @@ from pydantic import BaseModel from slack_sdk import WebClient -from data_source.base_data_source import BaseDataSource, ConfigField, HTMLInputType -from data_source.basic_document import DocumentType, BasicDocument -from data_source.utils import parse_with_workers +from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType +from data_source.api.basic_document import DocumentType, BasicDocument from queues.index_queue import IndexQueue logger = logging.getLogger(__name__) diff --git a/app/indexing/index_documents.py b/app/indexing/index_documents.py index 0beead4..6995206 100644 --- a/app/indexing/index_documents.py +++ b/app/indexing/index_documents.py @@ -2,7 +2,7 @@ import re from typing import List -from data_source.basic_document import BasicDocument +from data_source.api.basic_document import BasicDocument from paths import IS_IN_DOCKER from schemas import Document, Paragraph from models import bi_encoder @@ -24,14 +24,14 @@ def index_documents(documents: List[BasicDocument]): with Session() as session: documents_to_delete = session.query(Document).filter(Document.id_in_data_source.in_(ids_in_data_source)).all() - - logging.info(f'removing documents that were updated and need to be re-indexed.') - Indexer.remove_documents(documents_to_delete, session) - for document in documents_to_delete: - # Currently bulk deleting doesn't cascade. So we need to delete them one by one. - # See https://stackoverflow.com/a/19245058/3541901 - session.delete(document) - session.commit() + if documents_to_delete: + logging.info(f'removing documents that were updated and need to be re-indexed.') + Indexer.remove_documents(documents_to_delete, session) + for document in documents_to_delete: + # Currently bulk deleting doesn't cascade. So we need to delete them one by one. 
+ # See https://stackoverflow.com/a/19245058/3541901 + session.delete(document) + session.commit() with Session() as session: db_documents = [] @@ -63,7 +63,12 @@ def index_documents(documents: List[BasicDocument]): session.commit() # Create a list of all the paragraphs in the documents + logger.info(f"Indexing {len(db_documents)} documents => {len(paragraphs)} paragraphs") paragraphs = [paragraph for document in db_documents for paragraph in document.paragraphs] + if len(paragraphs) == 0: + logger.info(f"No paragraphs to index") + return + paragraph_ids = [paragraph.id for paragraph in paragraphs] paragraph_contents = [Indexer._add_metadata_for_indexing(paragraph) for paragraph in paragraphs] diff --git a/app/main.py b/app/main.py index 72affcf..7fe661f 100644 --- a/app/main.py +++ b/app/main.py @@ -1,5 +1,8 @@ +import json +from datetime import datetime import logging from dataclasses import dataclass +from typing import List import torch from fastapi import FastAPI, Request @@ -10,14 +13,16 @@ from api.data_source import router as data_source_router from api.search import router as search_router -from data_source.exception import KnownException -from data_source.context import DataSourceContext +from data_source.api.dynamic_loader import DynamicLoader +from data_source.api.exception import KnownException +from data_source.api.context import DataSourceContext from db_engine import Session from indexing.background_indexer import BackgroundIndexer from indexing.bm25_index import Bm25Index from indexing.faiss_index import FaissIndex from queues.index_queue import IndexQueue from paths import UI_PATH +from schemas import DataSource from schemas.document import Document from schemas.paragraph import Paragraph from slaves import Slaves @@ -26,6 +31,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s | %(levelname)s | %(filename)s:%(lineno)d | %(message)s') logger = logging.getLogger(__name__) +logging.getLogger("urllib3").propagate = False app = FastAPI() @@ -63,12 +69,13 @@ def _check_for_new_documents(force=False): continue logger.info(f"Checking for new docs in {data_source.type.name} (id: {data_source.id})") - data_source_cls = get_class_by_data_source_name(data_source.type.name) + data_source_cls = DynamicLoader.get_data_source_class(data_source.type.name) config = json.loads(data_source.config) data_source_instance = data_source_cls(config=config, data_source_id=data_source.id, last_index_time=data_source.last_indexed_at) data_source_instance.index() + @app.on_event("startup") @repeat_every(seconds=60) def check_for_new_documents(): diff --git a/app/models.py b/app/models.py index 51c5d3c..bf27efc 100644 --- a/app/models.py +++ b/app/models.py @@ -8,4 +8,4 @@ cross_encoder_small = CrossEncoder('cross-encoder/ms-marco-TinyBERT-L-2-v2') cross_encoder_large = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2') -qa_model = pipeline('question-answering', model='deepset/minilm-uncased-squad2') +qa_model = pipeline('question-answering', model='deepset/roberta-base-squad2') diff --git a/app/queues/index_queue.py b/app/queues/index_queue.py index 75ac715..5aed974 100644 --- a/app/queues/index_queue.py +++ b/app/queues/index_queue.py @@ -4,7 +4,7 @@ from persistqueue import SQLiteAckQueue -from data_source.basic_document import BasicDocument +from data_source.api.basic_document import BasicDocument from paths import SQLITE_INDEXING_PATH @@ -15,18 +15,18 @@ class IndexQueueItem: class IndexQueue(SQLiteAckQueue): - __instance = None - __lock = threading.Lock() + _instance = None + 
_lock = threading.Lock() @classmethod def get_instance(cls): - with cls.__lock: - if cls.__instance is None: - cls.__instance = cls() - return cls.__instance + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance def __init__(self): - if IndexQueue.__instance is not None: + if IndexQueue._instance is not None: raise RuntimeError("Queue is a singleton, use .get() to get the instance") self.condition = threading.Condition() diff --git a/app/queues/task_queue.py b/app/queues/task_queue.py index 01f26bb..f453242 100644 --- a/app/queues/task_queue.py +++ b/app/queues/task_queue.py @@ -21,18 +21,18 @@ class TaskQueueItem: class TaskQueue(SQLiteAckQueue): - __instance = None - __lock = threading.Lock() + _instance = None + _lock = threading.Lock() @classmethod def get_instance(cls): - with cls.__lock: - if cls.__instance is None: - cls.__instance = cls() - return cls.__instance + with cls._lock: + if cls._instance is None: + cls._instance = cls() + return cls._instance def __init__(self): - if TaskQueue.__instance is not None: + if TaskQueue._instance is not None: raise RuntimeError("TaskQueue is a singleton, use .get() to get the instance") self.condition = threading.Condition() diff --git a/app/search_logic.py b/app/search_logic.py index 23b42b3..2a178ee 100644 --- a/app/search_logic.py +++ b/app/search_logic.py @@ -12,8 +12,8 @@ import torch from sentence_transformers import CrossEncoder -from data_source.basic_document import DocumentType, FileType -from data_source.utils import get_confluence_user_image +from data_source.api.basic_document import DocumentType, FileType +from data_source.api.utils import get_confluence_user_image from db_engine import Session from indexing.bm25_index import Bm25Index from indexing.faiss_index import FaissIndex diff --git a/app/slaves.py b/app/slaves.py index ff22456..c3684bc 100644 --- a/app/slaves.py +++ b/app/slaves.py @@ -2,7 +2,7 @@ import threading import time -from data_source.context import DataSourceContext +from data_source.api.context import DataSourceContext from queues.task_queue import TaskQueue logger = logging.getLogger() @@ -48,4 +48,3 @@ def run(): logger.exception(f'Failed to ack task {task_item.task.function_name} ' f'for data source {task_item.task.data_source_id}') task_queue.nack(id=task_item.queue_item_id) - time.sleep(1) diff --git a/ui/src/App.tsx b/ui/src/App.tsx index 24c8d3f..39a0eb5 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -22,7 +22,7 @@ import 'react-toastify/dist/ReactToastify.css'; import { ClipLoader } from "react-spinners"; import { FiSettings } from "react-icons/fi"; import {AiFillWarning} from "react-icons/ai"; -import { DataSourceType } from "./data-source"; +import { ConnectedDataSourceType, DataSourceType } from "./data-source"; export interface AppState { query: string @@ -125,7 +125,7 @@ export default class App extends React.Component <{}, AppState>{ async listDataSourceTypes() { try { - const response = await api.get('/data-source/list-types'); + const response = await api.get('/data-source/types'); let dataSourceTypesDict: { [key: string]: DataSourceType } = {}; response.data.forEach((dataSourceType) => { dataSourceTypesDict[dataSourceType.name] = dataSourceType; @@ -137,8 +137,9 @@ export default class App extends React.Component <{}, AppState>{ async listConnectedDataSources() { try { - const response = await api.get('/data-source/list-connected'); - this.setState({ connectedDataSources: response.data }) + const response = await api.get('/data-source/connected'); + 
let nameList = response.data.map((dataSource) => dataSource.name); + this.setState({ connectedDataSources: nameList }) } catch (error) { } } diff --git a/ui/src/data-source.ts b/ui/src/data-source.ts index d4f4a05..0f3c176 100644 --- a/ui/src/data-source.ts +++ b/ui/src/data-source.ts @@ -18,4 +18,9 @@ export interface DataSourceType { display_name: string config_fields: ConfigField[] image_base64: string +} + +export interface ConnectedDataSourceType { + id: number + name: string } \ No newline at end of file From 68649d01f43c7ed5524107e8ec65a24d3a8fadfc Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 13:35:08 +0300 Subject: [PATCH 07/13] CR --- app/data_source/sources/google_drive/__init__.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) create mode 100644 app/data_source/sources/google_drive/__init__.py diff --git a/app/data_source/sources/google_drive/__init__.py b/app/data_source/sources/google_drive/__init__.py new file mode 100644 index 0000000..e69de29 From 11b0d476b89da1ab6991f6e1dcb0d236c59161a4 Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 14:11:09 +0300 Subject: [PATCH 08/13] Verify ssl env vars --- app/data_source/api/context.py | 6 ++---- app/data_source/sources/bookstack/bookstack.py | 7 +++++-- app/data_source/sources/confluence/confluence.py | 5 ++++- app/data_source/sources/confluence/confluence_cloud.py | 4 +++- app/data_source/sources/rocketchat/rocketchat.py | 3 ++- ui/src/App.tsx | 4 ++-- ui/src/components/data-source-panel.tsx | 2 +- 7 files changed, 19 insertions(+), 12 deletions(-) diff --git a/app/data_source/api/context.py b/app/data_source/api/context.py index 8579569..c412a16 100644 --- a/app/data_source/api/context.py +++ b/app/data_source/api/context.py @@ -41,10 +41,8 @@ def create_data_source(cls, name: str, config: dict) -> BaseDataSource: session.add(data_source_row) session.commit() - data_source_id = session.query(DataSource).filter_by(type_id=data_source_type.id) \ - .order_by(DataSource.id.desc()).first().id - data_source = data_source_class(config=config, data_source_id=data_source_id) - cls._data_sources[data_source_id] = data_source + data_source = data_source_class(config=config, data_source_id=data_source_row.id) + cls._data_sources[data_source_row.id] = data_source return data_source diff --git a/app/data_source/sources/bookstack/bookstack.py b/app/data_source/sources/bookstack/bookstack.py index 5211fb6..586545a 100644 --- a/app/data_source/sources/bookstack/bookstack.py +++ b/app/data_source/sources/bookstack/bookstack.py @@ -1,4 +1,5 @@ import logging +import os from datetime import datetime from time import sleep from typing import List, Dict @@ -29,6 +30,8 @@ def __call__(self, r): class BookStack(Session): + VERIFY_SSL = os.environ.get('BOOKSTACK_VERIFY_SSL') is not None + def __init__(self, url: str, token_id: str, token_secret: str, *args, **kwargs): super().__init__(*args, **kwargs) self.base_url = url @@ -40,7 +43,7 @@ def request(self, method, url_path, *args, **kwargs): sleep(1) url = urljoin(self.base_url, url_path) - r = super().request(method, url, verify=False, *args, **kwargs) + r = super().request(method, url, verify=BookStack.VERIFY_SSL, *args, **kwargs) if r.status_code != 200: if r.status_code == 429: @@ -50,7 +53,7 @@ def request(self, method, url_path, *args, **kwargs): sleep(60) self.rate_limit_reach = False logger.info("Done waiting for the API rate limit") - return self.request(method, url, verify=False, *args, **kwargs) + return self.request(method, url, verify=BookStack.VERIFY_SSL, *args, 
**kwargs) r.raise_for_status() return r diff --git a/app/data_source/sources/confluence/confluence.py b/app/data_source/sources/confluence/confluence.py index b7f449a..bb3bb91 100644 --- a/app/data_source/sources/confluence/confluence.py +++ b/app/data_source/sources/confluence/confluence.py @@ -1,6 +1,7 @@ import logging from datetime import datetime from typing import List, Dict +import os from atlassian import Confluence from pydantic import BaseModel @@ -53,7 +54,9 @@ def validate_config(config: Dict) -> None: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) confluence_config = ConfluenceConfig(**self._config) - self._confluence = Confluence(url=confluence_config.url, token=confluence_config.token, verify_ssl=False) + should_verify_ssl = os.environ.get('CONFLUENCE_VERIFY_SSL') is not None + self._confluence = Confluence(url=confluence_config.url, token=confluence_config.token, + verify_ssl=should_verify_ssl) def _list_spaces(self) -> List[Dict]: logger.info('Listing spaces') diff --git a/app/data_source/sources/confluence/confluence_cloud.py b/app/data_source/sources/confluence/confluence_cloud.py index 6b4ba58..49dfc99 100644 --- a/app/data_source/sources/confluence/confluence_cloud.py +++ b/app/data_source/sources/confluence/confluence_cloud.py @@ -1,4 +1,5 @@ from typing import List, Dict +import os from atlassian import Confluence from pydantic import BaseModel @@ -37,5 +38,6 @@ def validate_config(config: Dict) -> None: def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) confluence_config = ConfluenceCloudConfig(**self._config) + should_verify_ssl = os.environ.get('CONFLUENCE_CLOUD_VERIFY_SSL') is not None self._confluence = Confluence(url=confluence_config.url, username=confluence_config.username, - password=confluence_config.token, verify_ssl=False, cloud=True) + password=confluence_config.token, verify_ssl=should_verify_ssl, cloud=True) diff --git a/app/data_source/sources/rocketchat/rocketchat.py b/app/data_source/sources/rocketchat/rocketchat.py index b03160d..213036a 100644 --- a/app/data_source/sources/rocketchat/rocketchat.py +++ b/app/data_source/sources/rocketchat/rocketchat.py @@ -55,8 +55,9 @@ def get_display_name(cls) -> str: @staticmethod def validate_config(config: Dict) -> None: rocket_chat_config = RocketchatConfig(**config) + should_verify_ssl = os.environ.get('ROCKETCHAT_VERIFY_SSL') is not None rocket_chat = RocketChat(user_id=rocket_chat_config.token_id, auth_token=rocket_chat_config.token_secret, - server_url=rocket_chat_config.url) + server_url=rocket_chat_config.url, ssl_verify=should_verify_ssl) try: rocket_chat.me().json() except Exception as e: diff --git a/ui/src/App.tsx b/ui/src/App.tsx index 39a0eb5..be2f9c4 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -125,7 +125,7 @@ export default class App extends React.Component <{}, AppState>{ async listDataSourceTypes() { try { - const response = await api.get('/data-source/types'); + const response = await api.get('/data-sources/types'); let dataSourceTypesDict: { [key: string]: DataSourceType } = {}; response.data.forEach((dataSourceType) => { dataSourceTypesDict[dataSourceType.name] = dataSourceType; @@ -137,7 +137,7 @@ export default class App extends React.Component <{}, AppState>{ async listConnectedDataSources() { try { - const response = await api.get('/data-source/connected'); + const response = await api.get('/data-sources/connected'); let nameList = response.data.map((dataSource) => dataSource.name); this.setState({ connectedDataSources: nameList }) } 
catch (error) { diff --git a/ui/src/components/data-source-panel.tsx b/ui/src/components/data-source-panel.tsx index 7e2c1d3..d07c6c4 100644 --- a/ui/src/components/data-source-panel.tsx +++ b/ui/src/components/data-source-panel.tsx @@ -368,7 +368,7 @@ export default class DataSourcePanel extends React.Component { + api.post(`/data-sources`, payload).then(response => { toast.success("Data source added successfully, indexing..."); let selectedDataSource = this.state.selectedDataSource; From bc95c556317dda4a53cb57f41c5fed81bdf9766a Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 15:16:41 +0300 Subject: [PATCH 09/13] export num to const --- app/slaves.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/slaves.py b/app/slaves.py index c3684bc..198958c 100644 --- a/app/slaves.py +++ b/app/slaves.py @@ -11,10 +11,11 @@ class Slaves: _threads = [] _stop_event = threading.Event() + SLAVE_AMOUNT = 20 @classmethod def start(cls): - for i in range(0, 20): + for i in range(cls.SLAVE_AMOUNT): cls._threads.append(threading.Thread(target=cls.run)) for thread in cls._threads: thread.start() From 3911c16a894968eb3506f1f0e0aad36f9e7a44bf Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 15:30:28 +0300 Subject: [PATCH 10/13] Fix slack ratelimit --- app/data_source/sources/slack/slack.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/app/data_source/sources/slack/slack.py b/app/data_source/sources/slack/slack.py index ffe9e32..ebf98d6 100644 --- a/app/data_source/sources/slack/slack.py +++ b/app/data_source/sources/slack/slack.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from slack_sdk import WebClient +from slack_sdk.errors import SlackApiError from data_source.api.base_data_source import BaseDataSource, ConfigField, HTMLInputType from data_source.api.basic_document import DocumentType, BasicDocument @@ -140,13 +141,17 @@ def _fetch_conversation_messages(self, conv): logger.info(f'Fetching messages for conversation {conv.name}') while has_more: - response = self._slack.conversations_history(channel=conv.id, oldest=str(last_index_unix), - limit=1000, cursor=cursor) - if not response['ok'] and response['error'] == 'ratelimited': - retry_after_seconds = int(response['headers']['Retry-After']) - logger.warning(f'Slack API rate limit exceeded, retrying after {retry_after_seconds} seconds') - time.sleep(retry_after_seconds) - continue + try: + response = self._slack.conversations_history(channel=conv.id, oldest=str(last_index_unix), + limit=1000, cursor=cursor) + except SlackApiError as e: + logger.warning(f'Error fetching messages for conversation {conv.name}: {e}') + if e.response['error'] == 'ratelimited': + retry_after_seconds = int(response['headers']['Retry-After']) + logger.warning(f'Ratelimited: Slack API rate limit exceeded,' + f' retrying after {retry_after_seconds} seconds') + time.sleep(retry_after_seconds) + continue logger.info(f'Fetched {len(response["messages"])} messages for conversation {conv.name}') messages.extend(response['messages']) From 79dddd55083beed61c52bedb53eb202f0c0a5f3f Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 15:31:39 +0300 Subject: [PATCH 11/13] Fix slack ratelimit --- app/data_source/sources/slack/slack.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/app/data_source/sources/slack/slack.py b/app/data_source/sources/slack/slack.py index ebf98d6..672c148 100644 --- a/app/data_source/sources/slack/slack.py +++ b/app/data_source/sources/slack/slack.py @@ 
-146,7 +146,8 @@ def _fetch_conversation_messages(self, conv): limit=1000, cursor=cursor) except SlackApiError as e: logger.warning(f'Error fetching messages for conversation {conv.name}: {e}') - if e.response['error'] == 'ratelimited': + response = e.response + if response['error'] == 'ratelimited': retry_after_seconds = int(response['headers']['Retry-After']) logger.warning(f'Ratelimited: Slack API rate limit exceeded,' f' retrying after {retry_after_seconds} seconds') From e6147556e0cb8e918ee2ffbf1d16531f98c82d25 Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 17:44:34 +0300 Subject: [PATCH 12/13] Index every hour --- app/api/data_source.py | 5 +++-- app/data_source/api/base_data_source.py | 17 ++++++++++++++--- app/main.py | 14 ++++++-------- ui/src/components/data-source-panel.tsx | 1 + 4 files changed, 24 insertions(+), 13 deletions(-) diff --git a/app/api/data_source.py b/app/api/data_source.py index 286d2cb..376c63f 100644 --- a/app/api/data_source.py +++ b/app/api/data_source.py @@ -4,6 +4,7 @@ from fastapi import APIRouter from pydantic import BaseModel +from starlette.background import BackgroundTasks from data_source.api.base_data_source import ConfigField from data_source.api.context import DataSourceContext @@ -69,11 +70,11 @@ async def delete_data_source(data_source_id: int): @router.post("") -async def add_integration(dto: AddDataSource): +async def add_integration(dto: AddDataSource, background_tasks: BackgroundTasks): data_source = DataSourceContext.create_data_source(name=dto.name, config=dto.config) # in main.py we have a background task that runs every 5 minutes and indexes the data source # but here we want to index the data source immediately - data_source.add_task_to_queue(data_source.index) + background_tasks.add_task(data_source.index) return {"success": "Data source added successfully"} diff --git a/app/data_source/api/base_data_source.py b/app/data_source/api/base_data_source.py index b113eb2..b9903be 100644 --- a/app/data_source/api/base_data_source.py +++ b/app/data_source/api/base_data_source.py @@ -81,8 +81,12 @@ def __init__(self, config: Dict, data_source_id: int, last_index_time: datetime if last_index_time is None: last_index_time = datetime(2012, 1, 1) self._last_index_time = last_index_time + self._last_task_time = None - def _set_last_index_time(self) -> None: + def _save_index_time_in_db(self) -> None: + """ + Sets the index time in the database, to be now + """ with Session() as session: data_source: DataSource = session.query(DataSource).filter_by(id=self._data_source_id).first() data_source.last_indexed_at = datetime.now() @@ -95,12 +99,19 @@ def add_task_to_queue(self, function: Callable, **kwargs): TaskQueue.get_instance().add_task(task) def run_task(self, function_name: str, **kwargs) -> None: + self._last_task_time = datetime.now() function = getattr(self, function_name) function(**kwargs) - def index(self) -> None: + def index(self, force: bool = False) -> None: + if self._last_task_time is not None and not force: + # Don't index if the last task was less than an hour ago + time_since_last_task = datetime.now() - self._last_task_time + if time_since_last_task.total_seconds() < 60 * 60: + logging.info("Skipping indexing data source because it was indexed recently") + try: + self._save_index_time_in_db() self._feed_new_documents() - self._set_last_index_time() except Exception as e: logging.exception("Error while indexing data source") diff --git a/app/main.py b/app/main.py index 7fe661f..f156457 100644 --- a/app/main.py +++ 
b/app/main.py @@ -69,11 +69,9 @@ def _check_for_new_documents(force=False): continue logger.info(f"Checking for new docs in {data_source.type.name} (id: {data_source.id})") - data_source_cls = DynamicLoader.get_data_source_class(data_source.type.name) - config = json.loads(data_source.config) - data_source_instance = data_source_cls(config=config, data_source_id=data_source.id, - last_index_time=data_source.last_indexed_at) - data_source_instance.index() + data_source_instance = DataSourceContext.get_data_source(data_source_id=data_source.id) + data_source_instance._last_index_time = data_source.last_indexed_at + data_source_instance.index(force=force) @app.on_event("startup") @@ -147,6 +145,6 @@ async def check_for_new_documents_endpoint(): logger.warning(f"Failed to mount UI (you probably need to build it): {e}") -# if __name__ == '__main__': -# import uvicorn -# uvicorn.run("main:app", host="localhost", port=8000) +if __name__ == '__main__': + import uvicorn + uvicorn.run("main:app", host="localhost", port=8000) diff --git a/ui/src/components/data-source-panel.tsx b/ui/src/components/data-source-panel.tsx index d07c6c4..0b99505 100644 --- a/ui/src/components/data-source-panel.tsx +++ b/ui/src/components/data-source-panel.tsx @@ -266,6 +266,7 @@ export default class DataSourcePanel extends React.Component + *Gerev bot will join your channels. ) } From 3a5023cb0a680a843a3ea5bb278cfadbd042693d Mon Sep 17 00:00:00 2001 From: roey Date: Mon, 27 Mar 2023 18:06:46 +0300 Subject: [PATCH 13/13] add-button for generic data source in data-source-panel --- ui/src/components/data-source-panel.tsx | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/ui/src/components/data-source-panel.tsx b/ui/src/components/data-source-panel.tsx index 0b99505..7be0376 100644 --- a/ui/src/components/data-source-panel.tsx +++ b/ui/src/components/data-source-panel.tsx @@ -155,6 +155,10 @@ export default class DataSourcePanel extends React.Component this.setState({ isAdding: true})} className="flex hover:text-[#9875d4] py-2 pl-5 pr-3 m-2 flex-row items-center justify-center bg-[#36323b] hover:border-[#9875d4] rounded-lg font-poppins leading-[28px] border-[#777777] border-b-[.5px] transition duration-300 ease-in-out"> +
+                    Add
+                </div>
            )
        }
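
The patches above replace synchronous indexing with a task queue consumed by a pool of worker threads (the Slaves class): a data source enqueues a method name plus kwargs as a Task, a worker pops it, resolves the owning data source through DataSourceContext, and calls the method back by name, with the resulting documents flowing into IndexQueue. The following is a minimal, self-contained sketch of that producer/consumer flow, not the project's code; it assumes a plain in-memory queue.Queue in place of the persistqueue-backed TaskQueue and a FakeDataSource in place of a real data source.

# Minimal sketch of the task-queue / worker-thread flow introduced above.
# Assumptions (not in the patches): queue.Queue replaces the SQLite-backed
# TaskQueue, and FakeDataSource replaces a data source resolved via
# DataSourceContext.get_data_source().
import queue
import threading
from dataclasses import dataclass, field


@dataclass
class Task:
    data_source_id: int
    function_name: str
    kwargs: dict = field(default_factory=dict)


task_queue: "queue.Queue[Task]" = queue.Queue()


class FakeDataSource:
    """Producers enqueue a method name; workers call it back via run_task."""

    def __init__(self, data_source_id: int):
        self._data_source_id = data_source_id

    def add_task_to_queue(self, function, **kwargs) -> None:
        task_queue.put(Task(self._data_source_id, function.__name__, kwargs))

    def run_task(self, function_name: str, **kwargs) -> None:
        getattr(self, function_name)(**kwargs)

    def _feed_channel(self, channel: str) -> None:
        print(f"data source {self._data_source_id}: feeding channel {channel}")


sources = {1: FakeDataSource(1)}  # stands in for DataSourceContext's cache
stop_event = threading.Event()


def worker() -> None:
    while not stop_event.is_set():
        try:
            task = task_queue.get(timeout=1)
        except queue.Empty:
            continue
        try:
            sources[task.data_source_id].run_task(task.function_name, **task.kwargs)
        finally:
            task_queue.task_done()  # the real queue's ack()/nack() is simplified away


threads = [threading.Thread(target=worker) for _ in range(4)]
for t in threads:
    t.start()

sources[1].add_task_to_queue(sources[1]._feed_channel, channel="general")
task_queue.join()   # wait until the enqueued task has been processed
stop_event.set()
for t in threads:
    t.join()

Backing the real TaskQueue and IndexQueue with SQLite (persistqueue) rather than an in-memory queue is what lets interrupted indexing resume after a restart, which appears to be why the patches keep separate SQLITE_TASKS_PATH and SQLITE_INDEXING_PATH files.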