24 changes: 24 additions & 0 deletions app/alembic/versions/792a820e9374_document_id_in_data_source.py
@@ -0,0 +1,24 @@
"""document id_in_data_source

Revision ID: 792a820e9374
Revises: 9c2f5b290b16
Create Date: 2023-03-26 11:27:05.341609

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = '792a820e9374'
down_revision = '9c2f5b290b16'
branch_labels = None
depends_on = None


def upgrade() -> None:
+    # server_default (rather than default=) so existing rows are backfilled with the sentinel
+    op.add_column('document', sa.Column('id_in_data_source', sa.String(length=64), server_default='__none__'))
+
+
+def downgrade() -> None:
+    op.drop_column('document', 'id_in_data_source')
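For reference, the revision can also be exercised without the CLI; a minimal sketch, assuming an `alembic.ini` at the app root (the path is an assumption, not part of this PR):

```python
# Apply (or roll back) the new revision programmatically.
from alembic import command
from alembic.config import Config

cfg = Config("alembic.ini")           # assumed config location
command.upgrade(cfg, "792a820e9374")  # or "head" to run all pending revisions
# command.downgrade(cfg, "9c2f5b290b16") would drop the column again
```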
@@ -10,7 +10,7 @@
 from alembic import op
 import sqlalchemy as sa
 
-from data_source_api.utils import get_class_by_data_source_name
+from data_source.api.dynamic_loader import DynamicLoader
 from db_engine import Session
 from schemas import DataSourceType
 
@@ -29,7 +29,7 @@ def upgrade() -> None:
     # update existing data sources
     data_source_types = session.query(DataSourceType).all()
     for data_source_type in data_source_types:
-        data_source_class = get_class_by_data_source_name(data_source_type.name)
+        data_source_class = DynamicLoader.get_data_source_class(data_source_type.name)
         config_fields = data_source_class.get_config_fields()
 
         data_source_type.config_fields = json.dumps([config_field.dict() for config_field in config_fields])
66 changes: 29 additions & 37 deletions app/api/data_source.py
@@ -1,20 +1,18 @@
 import base64
-import json
-from datetime import datetime
 from typing import List
 
-from fastapi import APIRouter, BackgroundTasks
+from fastapi import APIRouter
 from pydantic import BaseModel
-from starlette.responses import Response
+from starlette.background import BackgroundTasks
 
-from data_source_api.base_data_source import ConfigField
-from data_source_api.exception import KnownException
-from data_source_api.utils import get_class_by_data_source_name
+from data_source.api.base_data_source import ConfigField
+from data_source.api.context import DataSourceContext
 from db_engine import Session
 from schemas import DataSourceType, DataSource
 
 router = APIRouter(
-    prefix='/data-source',
+    prefix='/data-sources',
 )
 
 
@@ -39,50 +37,44 @@ def from_data_source_type(data_source_type: DataSourceType) -> 'DataSourceTypeDto':
         )
 
 
-@router.get("/list-types")
+class ConnectedDataSourceDto(BaseModel):
+    id: int
+    name: str
+
+
+@router.get("/types")
 async def list_data_source_types() -> List[DataSourceTypeDto]:
     with Session() as session:
         data_source_types = session.query(DataSourceType).all()
         return [DataSourceTypeDto.from_data_source_type(data_source_type)
                 for data_source_type in data_source_types]
 
 
-@router.get("/list-connected")
-async def list_connected_data_sources() -> List[str]:
+@router.get("/connected")
+async def list_connected_data_sources() -> List[ConnectedDataSourceDto]:
     with Session() as session:
         data_sources = session.query(DataSource).all()
-        return [data_source.type.name for data_source in data_sources]
+        return [ConnectedDataSourceDto(id=data_source.id, name=data_source.type.name)
+                for data_source in data_sources]
 
 
 class AddDataSource(BaseModel):
     name: str
     config: dict
 
 
-@router.post("/add")
+@router.delete("/{data_source_id}")
+async def delete_data_source(data_source_id: int):
+    DataSourceContext.delete_data_source(data_source_id=data_source_id)
+    return {"success": "Data source deleted successfully"}
+
+
+@router.post("")
 async def add_integration(dto: AddDataSource, background_tasks: BackgroundTasks):
-    with Session() as session:
-        data_source_type = session.query(DataSourceType).filter_by(name=dto.name).first()
-        if data_source_type is None:
-            return {"error": "Data source type does not exist"}
-
-        data_source_class = get_class_by_data_source_name(dto.name)
-        try:
-            data_source_class.validate_config(dto.config)
-        except KnownException as e:
-            return Response(e.message, status_code=501)
-
-        config_str = json.dumps(dto.config)
-        ds = DataSource(type_id=data_source_type.id, config=config_str, created_at=datetime.now())
-        session.add(ds)
-        session.commit()
-
-        data_source_id = session.query(DataSource).filter_by(type_id=data_source_type.id)\
-            .order_by(DataSource.id.desc()).first().id
-        data_source = data_source_class(config=dto.config, data_source_id=data_source_id)
-
-        # in main.py we have a background task that runs every 5 minutes and indexes the data source
-        # but here we want to index the data source immediately
-        background_tasks.add_task(data_source.index)
-
-        return {"success": "Data source added successfully"}
+    data_source = DataSourceContext.create_data_source(name=dto.name, config=dto.config)
+
+    # in main.py we have a background task that runs every 5 minutes and indexes the data source
+    # but here we want to index the data source immediately
+    background_tasks.add_task(data_source.index)
+
+    return {"success": "Data source added successfully"}
File renamed without changes.
File renamed without changes.
@@ -2,12 +2,13 @@
 from abc import abstractmethod, ABC
 from datetime import datetime
 from enum import Enum
-from typing import Dict, List, Optional
+from typing import Dict, List, Optional, Callable
 import re
 
 from pydantic import BaseModel
 
 from db_engine import Session
+from queues.task_queue import TaskQueue, Task
 from schemas import DataSource
 
 
@@ -80,16 +81,37 @@ def __init__(self, config: Dict, data_source_id: int, last_index_time: datetime = None):
         if last_index_time is None:
             last_index_time = datetime(2012, 1, 1)
         self._last_index_time = last_index_time
+        self._last_task_time = None
 
-    def _set_last_index_time(self) -> None:
+    def _save_index_time_in_db(self) -> None:
         """
         Sets the index time in the database, to be now
         """
         with Session() as session:
             data_source: DataSource = session.query(DataSource).filter_by(id=self._data_source_id).first()
             data_source.last_indexed_at = datetime.now()
             session.commit()
 
-    def index(self) -> None:
+    def add_task_to_queue(self, function: Callable, **kwargs):
+        task = Task(data_source_id=self._data_source_id,
+                    function_name=function.__name__,
+                    kwargs=kwargs)
+        TaskQueue.get_instance().add_task(task)
+
+    def run_task(self, function_name: str, **kwargs) -> None:
+        self._last_task_time = datetime.now()
+        function = getattr(self, function_name)
+        function(**kwargs)
+
+    def index(self, force: bool = False) -> None:
+        if self._last_task_time is not None and not force:
+            # Don't index if the last task was less than an hour ago
+            time_since_last_task = datetime.now() - self._last_task_time
+            if time_since_last_task.total_seconds() < 60 * 60:
+                logging.info("Skipping indexing data source because it was indexed recently")
+                return  # without this return, indexing would proceed despite the skip message
+
         try:
-            self._set_last_index_time()
+            self._save_index_time_in_db()
             self._feed_new_documents()
         except Exception as e:
             logging.exception("Error while indexing data source")
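The new hooks split fetching into queued units of work: `_feed_new_documents` enqueues tasks, and the queue worker later calls `run_task` by function name. A minimal sketch of a subclass built on them; the class, its page list, and `_fetch_page` are hypothetical, and the static-method signatures are assumed from how the migration and API code call them:

```python
from typing import Dict, List

from data_source.api.base_data_source import BaseDataSource, ConfigField


class ExampleDataSource(BaseDataSource):
    @staticmethod
    def get_config_fields() -> List[ConfigField]:
        return []  # real sources describe their config schema here

    @staticmethod
    def validate_config(config: Dict) -> None:
        pass  # real sources raise KnownException on bad credentials etc.

    def _feed_new_documents(self) -> None:
        # Enqueue one task per page instead of fetching inline; the worker
        # later invokes run_task("_fetch_page", page_id=...) on this instance.
        for page_id in ["page-1", "page-2"]:
            self.add_task_to_queue(self._fetch_page, page_id=page_id)

    def _fetch_page(self, page_id: str) -> None:
        ...  # fetch one page and hand its documents to the index
```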
@@ -1,6 +1,7 @@
 from datetime import datetime
 from dataclasses import dataclass
 from enum import Enum
+from typing import Union
 
 
 class DocumentType(Enum):
@@ -32,7 +33,7 @@ def from_mime_type(cls, mime_type: str):

 @dataclass
 class BasicDocument:
-    id: int
+    id: Union[int, str]
     data_source_id: int
     type: DocumentType
     title: str
@@ -44,3 +45,7 @@ class BasicDocument:
     url: str
     file_type: FileType = None
 
+    @property
+    def id_in_data_source(self):
+        return str(self.data_source_id) + '_' + str(self.id)

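Since document ids may now be strings, the property composes a key from the data source id and the document's native id; it appears to pair with the `String(64)` column added by the migration above. Schematic only, with hypothetical values:

```python
# data_source_id=7 and id="abc123" yield the composite key "7_abc123".
assert str(7) + '_' + str("abc123") == "7_abc123"
```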
98 changes: 98 additions & 0 deletions app/data_source/api/context.py
@@ -0,0 +1,98 @@
+import json
+from datetime import datetime
+from typing import Dict, List
+
+from data_source.api.base_data_source import BaseDataSource
+from data_source.api.dynamic_loader import DynamicLoader, ClassInfo
+from data_source.api.exception import KnownException
+from db_engine import Session
+from schemas import DataSourceType, DataSource
+
+
+class DataSourceContext:
+    """
+    This class is responsible for loading data sources and caching them.
+    It dynamically loads data source types from the data_source/sources directory.
+    It loads data sources from the database and caches them.
+    """
+    _initialized = False
+    _data_sources: Dict[int, BaseDataSource] = {}
+
+    @classmethod
+    def get_data_source(cls, data_source_id: int) -> BaseDataSource:
+        if not cls._initialized:
+            cls.init()
+            cls._initialized = True
+
+        return cls._data_sources[data_source_id]
+
+    @classmethod
+    def create_data_source(cls, name: str, config: dict) -> BaseDataSource:
+        with Session() as session:
+            data_source_type = session.query(DataSourceType).filter_by(name=name).first()
+            if data_source_type is None:
+                raise KnownException(message=f"Data source type {name} does not exist")
+
+            data_source_class = DynamicLoader.get_data_source_class(name)
+            data_source_class.validate_config(config)
+            config_str = json.dumps(config)
+
+            data_source_row = DataSource(type_id=data_source_type.id, config=config_str, created_at=datetime.now())
+            session.add(data_source_row)
+            session.commit()
+
+            data_source = data_source_class(config=config, data_source_id=data_source_row.id)
+            cls._data_sources[data_source_row.id] = data_source
+
+            return data_source
+
+    @classmethod
+    def delete_data_source(cls, data_source_id: int):
+        with Session() as session:
+            data_source = session.query(DataSource).filter_by(id=data_source_id).first()
+            if data_source is None:
+                raise KnownException(message=f"Data source {data_source_id} does not exist")
+
+            session.delete(data_source)
+            session.commit()
+
+        del cls._data_sources[data_source_id]
+
+    @classmethod
+    def init(cls):
+        cls._add_data_sources_to_db()
+        cls._load_context_from_db()
+
+    @classmethod
+    def _load_context_from_db(cls):
+        with Session() as session:
+            data_sources: List[DataSource] = session.query(DataSource).all()
+            for data_source in data_sources:
+                data_source_cls = DynamicLoader.get_data_source_class(data_source.type.name)
+                config = json.loads(data_source.config)
+                data_source_instance = data_source_cls(config=config, data_source_id=data_source.id,
+                                                       last_index_time=data_source.last_indexed_at)
+                cls._data_sources[data_source.id] = data_source_instance
+
+        cls._initialized = True
+
+    @classmethod
+    def _add_data_sources_to_db(cls):
+        data_sources: Dict[str, ClassInfo] = DynamicLoader.find_data_sources()
+
+        with Session() as session:
+            for source_name in data_sources.keys():
+                if session.query(DataSourceType).filter_by(name=source_name).first():
+                    continue
+
+                class_info = data_sources[source_name]
+                data_source_class = DynamicLoader.get_class(file_path=class_info.file_path,
+                                                            class_name=class_info.name)
+
+                config_fields = data_source_class.get_config_fields()
+                config_fields_str = json.dumps([config_field.dict() for config_field in config_fields])
+                new_data_source = DataSourceType(name=source_name,
+                                                 display_name=data_source_class.get_display_name(),
+                                                 config_fields=config_fields_str)
+                session.add(new_data_source)
+            session.commit()
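A hedged sketch of the context lifecycle end to end, using only the calls defined above; the source name and config keys are placeholders:

```python
from data_source.api.context import DataSourceContext

DataSourceContext.init()  # register discovered source types, hydrate cached instances

ds = DataSourceContext.create_data_source(name="confluence",
                                          config={"url": "https://wiki.example.org"})

# later lookups are served from the in-memory cache
same = DataSourceContext.get_data_source(data_source_id=1)  # id 1 is illustrative

DataSourceContext.delete_data_source(data_source_id=1)
```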