-
Notifications
You must be signed in to change notification settings - Fork 86
/
run.py
87 lines (76 loc) · 2.75 KB
/
run.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
"""
Example script showing how one might setup a generic model training pipeline that is quickly configurable.
"""
import importlib
# Required import to register adapters
import os
import data_loaders
import model_training
from hamilton import base, driver
from hamilton.io.materialization import to
# This has to be imported, but the linter doesn't like it cause its unused
# We just need to import it to register the materializers
importlib.import_module("custom_materializers")
def get_model_config(model_type: str) -> dict:
    """Return the model-specific Hamilton configuration entries.

    :param model_type: which classifier to configure; one of "svm" or "logistic".
    :return: dict of config values ("clf" plus model hyperparameters).
    :raises ValueError: if ``model_type`` is not a supported model.
    """
    config_by_model = {
        "svm": {"clf": "svm", "gamma": 0.001},
        "logistic": {"clf": "logistic", "penalty": "l2"},
    }
    if model_type not in config_by_model:
        raise ValueError(f"Unsupported model {model_type}.")
    return config_by_model[model_type]
if __name__ == "__main__":
    import sys

    # CLI contract: first arg picks the dataset, second picks the model.
    if len(sys.argv) < 3:
        print("Error: required arguments are [iris|digits] [svm|logistic]")
        sys.exit(1)
    _data_set = sys.argv[1]  # the data set to load
    _model_type = sys.argv[2]  # the model type to fit and evaluate with

    # Base DAG configuration shared by every dataset/model combination.
    dag_config = {
        "test_size_fraction": 0.5,
        "shuffle_train_test_split": True,
    }
    # makedirs(exist_ok=True) replaces the exists()+mkdir() pair: it is
    # race-free (no check-then-create window) and idempotent.
    os.makedirs("data", exist_ok=True)
    # augment config with model-specific hyperparameters and the dataset choice
    dag_config.update(get_model_config(_model_type))
    dag_config["data_loader"] = _data_set
    # Build the Hamilton driver over the dataflow modules with this config.
    dr = (
        driver.Builder()
        .with_adapter(base.DefaultAdapter())
        .with_config(dag_config)
        .with_modules(data_loaders, model_training)
        .build()
    )
    # Declare how/where each DAG output should be persisted.
    materializers = [
        # materialize the model parameters to a json file
        to.json(
            dependencies=["model_parameters"], id="model_params_to_json", path="./data/params.json"
        ),
        # classification report to .txt file
        to.file(
            dependencies=["classification_report"],
            id="classification_report_to_txt",
            path="./data/classification_report.txt",
        ),
        # materialize the model to a pickle file
        to.pickle(dependencies=["fit_clf"], id="clf_to_pickle", path="./data/clf.pkl"),
        # materialize the predictions we made to a csv file
        to.csv(
            dependencies=["predicted_output_with_labels"],
            id="predicted_output_with_labels_to_csv",
            path="./data/predicted_output_with_labels.csv",
        ),
    ]
    # Render a visualization of the DAG (written to ./dag) before executing it.
    dr.visualize_materialization(
        *materializers,
        additional_vars=["classification_report"],
        output_file_path="./dag",
        render_kwargs={},
    )
    # Execute the DAG: write all materialized outputs and also compute the
    # classification report in-memory.
    materialization_results, additional_vars = dr.materialize(
        *materializers,
        additional_vars=["classification_report"],
    )
    # print(materialization_results["classification_report"])
    # print(additional_vars)