generated from fastai/nbdev_template
-
-
Notifications
You must be signed in to change notification settings - Fork 37
/
base.py
67 lines (59 loc) · 2.11 KB
/
base.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import abc
import copy
import itertools
from typing import Callable, List, Optional

from chitra.data_processing import (
    DataProcessor,
    DefaultTextProcessor,
    DefaultVisionProcessor,
)
from chitra.serve import constants as const
class ModelServer:
    """Base server that pairs a model with a data processor chosen by api type.

    The ``api_type`` string (upper-cased) selects a default processor from the
    vision or NLP family listed in ``API_TYPES``; optional user-supplied
    pre/post-processing functions then override the processor's defaults.
    """

    # Task family -> api_type constants belonging to that family.
    API_TYPES = {
        "VISION": (const.IMAGE_CLF, const.OBJECT_DETECTION),
        "NLP": (const.TXT_CLF, const.QNA),
    }

    def __init__(
        self,
        api_type: str,
        model: Callable,
        preprocess_fn: Optional[Callable] = None,
        postprocess_fn: Optional[Callable] = None,
        preprocess_conf: Optional[dict] = None,
        postprocess_conf: Optional[dict] = None,
        **kwargs,
    ):
        """Store the model and build the data processor for ``api_type``.

        Args:
            api_type: One of the values returned by
                ``get_available_api_types()``; matched case-insensitively
                (it is upper-cased before lookup).
            model: The callable model to serve.
            preprocess_fn: Optional override for the processor's
                pre-processing function.
            postprocess_fn: Optional override for the processor's
                post-processing function.
            preprocess_conf: Optional config dict; a new empty dict is used
                when omitted or empty.
            postprocess_conf: Optional config dict; a new empty dict is used
                when omitted or empty.
            **kwargs: Accepted but not used by this base class
                (presumably consumed by subclasses — confirm).

        Raises:
            NotImplementedError: If ``api_type`` is not a supported type.
        """
        self.api_type = api_type.upper()
        self.model = model
        # Fresh dict per instance when nothing useful was passed in.
        self.preprocess_conf = preprocess_conf or {}
        self.postprocess_conf = postprocess_conf or {}
        # api_type must be set before this call: the default is chosen from it.
        self.data_processor: Optional[DataProcessor] = self.set_data_processor(
            preprocess_fn, postprocess_fn
        )

    @classmethod
    def get_available_api_types(cls) -> List[str]:
        """Return every supported api_type across all task families."""
        return list(itertools.chain.from_iterable(cls.API_TYPES.values()))

    def set_data_processor(
        self,
        preprocess_fn: Optional[Callable],
        postprocess_fn: Optional[Callable],
    ) -> DataProcessor:
        """Build this server's processor, applying any user overrides.

        Starts from the default processor for ``self.api_type`` and replaces
        its pre/post-processing functions when overrides are given.

        Returns:
            The configured processor (also stored on ``self.data_processor``
            by ``set_default_processor``).
        """
        data_processor = self.set_default_processor()
        # Explicit None checks: a callable object may define __bool__/__len__
        # and be falsy, which must not silently drop the override.
        if preprocess_fn is not None:
            data_processor.set_preprocess_fn(preprocess_fn)
        if postprocess_fn is not None:
            data_processor.set_postprocess_fn(postprocess_fn)
        return data_processor

    def set_default_processor(self) -> DataProcessor:
        """Select the default processor for ``self.api_type``.

        The result is assigned to ``self.data_processor`` and also returned.

        Raises:
            NotImplementedError: If ``self.api_type`` belongs to no known family.
        """
        api_type = self.api_type
        if api_type in ModelServer.API_TYPES["VISION"]:
            default_processor = DefaultVisionProcessor.vision
        elif api_type in ModelServer.API_TYPES["NLP"]:
            default_processor = DefaultTextProcessor.nlp
        else:
            raise NotImplementedError(
                f"{api_type} is not implemented! Available types are - "
                f"{ModelServer.get_available_api_types()}"
            )
        # The defaults are shared class attributes. set_data_processor mutates
        # the processor it receives, so hand each server its own copy. Otherwise
        # one server's custom functions would leak into every other server.
        # NOTE(review): shallow copy assumes set_*_fn rebinds attributes rather
        # than mutating nested containers — confirm against DataProcessor.
        self.data_processor = copy.copy(default_processor)
        return self.data_processor