-
Notifications
You must be signed in to change notification settings - Fork 316
/
service.py
46 lines (36 loc) · 1.33 KB
/
service.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from typing import List, Type
from fastapi import Depends
from rubrix.server.commons import telemetry
from rubrix.server.commons.config import TasksFactory
from rubrix.server.daos.records import DatasetRecordsDAO
from rubrix.server.services.datasets import ServiceDataset
from rubrix.server.services.tasks.commons import ServiceRecord
class RecordsStorageService:
    """Persist dataset records through a ``DatasetRecordsDAO``.

    One shared instance is cached on the class and handed out by
    :meth:`get_instance`, whose ``dao`` parameter is declared as a FastAPI
    ``Depends`` dependency.
    """

    # Lazily-created shared instance; ``None`` until the first
    # ``get_instance`` call creates it.
    _INSTANCE: "RecordsStorageService" = None

    @classmethod
    def get_instance(
        cls,
        dao: DatasetRecordsDAO = Depends(DatasetRecordsDAO.get_instance),
    ) -> "RecordsStorageService":
        """Return the cached service, building it on first use.

        Args:
            dao: Records DAO, used only when the instance does not exist yet.
                When injected by FastAPI it is resolved through
                ``DatasetRecordsDAO.get_instance``. A direct call that omits
                it receives the ``Depends(...)`` marker object itself, so pass
                a DAO explicitly outside of FastAPI injection. Once an
                instance is cached this argument is ignored.

        Returns:
            The shared ``RecordsStorageService`` instance stored on ``cls``.
        """
        # Fast path: an instance was already created by an earlier call.
        if cls._INSTANCE:
            return cls._INSTANCE
        cls._INSTANCE = cls(dao)
        return cls._INSTANCE

    def __init__(self, dao: DatasetRecordsDAO):
        """Keep ``dao`` as the data-access object every write goes through."""
        self.__dao__ = dao

    async def store_records(
        self,
        dataset: ServiceDataset,
        records: List[ServiceRecord],
        record_type: Type[ServiceRecord],
    ) -> int:
        """Store ``records`` into ``dataset``.

        Steps, in order:

        1. Report the batch size for ``dataset.task`` to telemetry.
        2. If ``TasksFactory`` provides metrics for that task, compute them for
           each record and assign them to ``record.metrics``. The caller's
           record objects are modified in place.
        3. Forward the batch to the DAO's ``add_records``.

        Args:
            dataset: Target dataset; its ``task`` tags the telemetry event and
                selects the metrics implementation.
            records: Records to store; mutated in place when metrics apply.
            record_type: Record class, forwarded to the DAO as ``record_class``.

        Returns:
            The value returned by the DAO's ``add_records`` (annotated ``int``;
            presumably the number of stored records — confirm against the DAO).
        """
        await telemetry.track_bulk(task=dataset.task, records=len(records))

        task_metrics = TasksFactory.get_task_metrics(dataset.task)
        if task_metrics:
            # Side-effect loop: attach computed metrics to each record object.
            for rec in records:
                rec.metrics = task_metrics.record_metrics(rec)

        # Not awaited here: the DAO's return value is passed straight back.
        dao = self.__dao__
        return dao.add_records(
            dataset=dataset, records=records, record_class=record_type
        )