/
solver.py
59 lines (46 loc) · 1.51 KB
/
solver.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import json
import toml
from cvpm.bundle import Bundle
from cvpm.server import run_server
from cvpm.utility import Downloader
class Solver(object):
    """Base solver: downloads the models listed in a TOML manifest and serves them.

    Subclasses are expected to override :meth:`infer` (and optionally
    :meth:`train`); the base implementations do nothing and return ``None``.
    """

    def __init__(self, toml_file=None):
        """Create the solver and download the models listed in ``toml_file``.

        Args:
            toml_file: Path to a TOML manifest. When ``None``, defaults to
                ``"./pretrained/pretrained.toml"``.
        """
        self._isReady = False
        self.bundle = None
        self._enable_train = False
        if toml_file is None:
            toml_file = "./pretrained/pretrained.toml"
        self._prepare_models(toml_file)

    @property
    def enable_train(self):
        """Whether training is enabled (initialised to ``False``; never changed here)."""
        return self._enable_train

    @property
    def is_ready(self):
        """``True`` once :meth:`set_ready` has been called, ``False`` before."""
        return self._isReady

    @property
    def help_message(self):
        """Describe the attached bundle as JSON, or report that startup failed.

        Returns:
            When ready: a JSON string of ``self.bundle.members()``.
            When not ready: a ``(json_string, 101)`` tuple with an error payload.
        """
        # NOTE(review): the two branches return different types (str vs tuple).
        # The tuple presumably maps to a (body, status) response in the server
        # layer — confirm with run_server before unifying them.
        if not self.is_ready:
            return json.dumps({"error": "Failed to start", "code": "101"}), 101
        return json.dumps(self.bundle.members())

    def _prepare_models(self, toml_file):
        """Download every model listed in the manifest's ``models`` array.

        Each entry's ``url`` is passed to ``Downloader.download`` together with
        ``"pretrained"`` (presumably the destination directory — confirm).
        A manifest without a ``models`` key downloads nothing.
        """
        parsed_toml = toml.load(toml_file)
        downloader = Downloader()
        for model in parsed_toml.get("models", []):
            downloader.download(model["url"], "pretrained")

    def set_ready(self):
        """Mark the solver as ready, so :attr:`is_ready` returns ``True``."""
        self._isReady = True

    def set_bundle(self, bundle):
        """Attach a ``Bundle`` subclass and register this solver with it.

        Args:
            bundle: A class, not an instance. Classes that do not subclass
                ``Bundle`` are silently ignored. A non-class argument makes
                ``issubclass`` raise ``TypeError``.
        """
        if issubclass(bundle, Bundle):
            self.bundle = bundle
            # add_solver is looked up on the class itself, so the class is
            # passed explicitly as its ``self`` argument.
            bundle.add_solver(self=bundle, solver=self)

    def infer(self, input, config):
        """Run inference on ``input`` using ``config``. Override in subclasses.

        The base implementation does nothing and returns ``None``.
        """

    def train(self, train_x, train_y, **kwargs):
        """Train on ``train_x`` / ``train_y``. Override in subclasses.

        The base implementation does nothing and returns ``None``.
        """

    def start(self, port=None):
        """Print the chosen port, then hand this solver to ``run_server``.

        Args:
            port: Forwarded unchanged to ``run_server``. ``None`` is passed
                through as-is, and the printed message then reads "None".
        """
        print('Server will run on port: ' + str(port))
        run_server(self, port)