-
Notifications
You must be signed in to change notification settings - Fork 0
/
get_models.py
52 lines (39 loc) · 1.46 KB
/
get_models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import zipfile
import requests
class ModelFetcher:
    """Download and unpack ML model archives from Google Cloud Storage.

    For each requested model name, ``<base_url>/<name>.zip`` is fetched and
    extracted into ``working_dir``, but only when ``<working_dir>/<name>``
    does not already exist locally. An existing local copy is never
    refreshed, even if the remote archive has changed; delete the local
    folder to force a new download.

    All work happens in ``__init__``: constructing an instance performs the
    downloads immediately.

    NOTE(review): the "already present" check assumes each archive contains a
    top-level folder named after the model. If it does not, the model is
    downloaded again on every run -- confirm against the bucket contents.
    """

    # Local directory that receives the extracted models (a relative path,
    # so it is resolved against the current working directory).
    working_dir = "ai_models"
    # Public bucket the ``<name>.zip`` archives are downloaded from.
    base_url = "https://storage.googleapis.com/bolius-ml-models"
    # Seconds ``requests`` waits for a connection / for the next received
    # bytes before giving up; without it a stalled server blocks forever.
    timeout = 60

    def __init__(self, models):
        """Ensure ``working_dir`` exists, then fetch any missing models.

        :param models: iterable of model names, e.g. ``["energy", "radon"]``.
        :raises ConnectionError: if a model archive cannot be downloaded.
        """
        # Create the configured directory itself (not a hard-coded path);
        # exist_ok makes this idempotent and race-free.
        os.makedirs(self.working_dir, exist_ok=True)
        self.models = models
        self._checkModels()

    def _shouldGetModel(self, modelName):
        """Return True when ``<working_dir>/<modelName>`` does not exist yet."""
        dest_folder = os.path.join(self.working_dir, modelName)
        return not os.path.exists(dest_folder)

    def _getModel(self, modelName):
        """Download ``<base_url>/<modelName>.zip`` and extract it into ``working_dir``.

        The archive is written to a temporary ``<modelName>.zip`` inside
        ``working_dir``, extracted there, and then deleted -- also when
        writing or extraction fails.

        :param modelName: model name; the remote object is ``<modelName>.zip``.
        :raises ConnectionError: if the server does not answer with HTTP 200.
        """
        print("Downloading model File")
        response = requests.get(
            f"{self.base_url}/{modelName}.zip", timeout=self.timeout
        )
        if response.status_code != 200:
            raise ConnectionError("Could not get zip file")
        zip_dest = os.path.join(self.working_dir, f"{modelName}.zip")
        try:
            with open(zip_dest, "wb") as archive_file:
                archive_file.write(response.content)
            # zipfile.extractall strips absolute paths and ".." components,
            # so entries cannot escape working_dir.
            # NOTE(review): a failure mid-extraction can leave a partial
            # <modelName> folder that _shouldGetModel will then treat as present.
            with zipfile.ZipFile(zip_dest, "r") as archive:
                archive.extractall(self.working_dir)
        finally:
            # Always clean up the downloaded archive.
            if os.path.exists(zip_dest):
                os.remove(zip_dest)

    def _checkModels(self):
        """Download every model in ``self.models`` that is not present locally."""
        for model in self.models:
            if self._shouldGetModel(model):
                print(f"Getting model {model}")
                self._getModel(model)
# Fetch the default set of models whenever this module is executed or
# imported; the instance is discarded because the work is done in __init__.
ModelFetcher(
    models=["energy", "komfort", "proposals", "radon"],
)