-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
65 lines (54 loc) · 2.89 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
import asyncio
from californiahousing import *
from classification import *
from whitebox import *
from svm import *
from decisiontrees import *
from ensemble import *
from dimensionality import *
import dependency_injector.containers as containers
import dependency_injector.providers as providers
import sys
class ExperimentRunner:
    """Thin async wrapper that delegates execution to a single experiment."""

    def __init__(self, experiment):
        """Keep a reference to *experiment*.

        Args:
            experiment: Object exposing an awaitable ``run_async()`` method.
        """
        self.__experiment = experiment

    async def run_async(self):
        """Run the wrapped experiment and wait for it to finish (returns None)."""
        experiment_run = self.__experiment.run_async
        await experiment_run()
class ExperimentsContainer(containers.DeclarativeContainer):
    """Registry mapping command-line experiment names to experiment factories.

    Each value in ``experiments`` is a ``providers.Factory`` wrapping one
    experiment class; the dict itself is a plain class attribute.
    """

    experiments = {
        "data_analysis": providers.Factory(DataAnalysisExperiment),
        "preprocessing": providers.Factory(PreProcessingExperiment),
        "models": providers.Factory(ModelsExperiment),
        "binary_classification": providers.Factory(BinaryClassificationExperiment),
        "multiclass_classification": providers.Factory(MulticlassClassificationExperiment),
        "multilabel_classification": providers.Factory(MultilabelClassificationExperiment),
        "multioutput_classification": providers.Factory(MultioutputClassificationExperiment),
        "whitebox_linear_regression": providers.Factory(LinearRegressionExperiment),
        "whitebox_polynomial_regression": providers.Factory(PolynomialRegressionExperiment),
        "whitebox_logistic_regression": providers.Factory(LogisticRegressionExperiment),
        "svm_linear_classification": providers.Factory(LinearSvmClassificationExperiment),
        "svm_non_linear_classification": providers.Factory(NonLinearSvmClassificationExperiment),
        "svm_linear_regression": providers.Factory(LinearSvmRegressionExperiment),
        "svm_non_linear_regression": providers.Factory(NonLinearSvmRegressionExperiment),
        "decision_tree_visualization": providers.Factory(DecisionTreeVisualizationExperiment),
        "decision_tree_regularization": providers.Factory(DecisionTreeRegularizationExperiment),
        "decision_tree_regression": providers.Factory(DecisionTreeRegressionExperiment),
        "voting_classifier": providers.Factory(VotingClassifierExperiment),
        "bagging_pasting": providers.Factory(BaggingAndPastingExperiment),
        "random_forest": providers.Factory(RandomForestExperiment),
        "ada_boosting": providers.Factory(AdaBoostingExperiment),
        "stacking": providers.Factory(StackingExperiment),
        "dimensionality": providers.Factory(DimensionalityReductionExperiment),
    }

    @classmethod
    def get_experiment(cls, argv=None, default="data_analysis"):
        """Return the factory for the experiment named on the command line.

        Args:
            argv: Argument vector to inspect; defaults to ``sys.argv``. The
                experiment name is taken from ``argv[1]`` when present.
            default: Name of the experiment used when ``argv`` names none,
                or names one that is not registered.

        Returns:
            The ``providers.Factory`` registered under the selected name.

        Raises:
            KeyError: if ``default`` is not itself a registered name.
        """
        if argv is None:
            argv = sys.argv

        if len(argv) <= 1:
            return cls.experiments[default]

        requested = argv[1]
        if requested not in cls.experiments:
            # Warn instead of silently falling back, so a mistyped name
            # doesn't quietly run a different experiment.
            print(
                f"Unknown experiment {requested!r}; running {default!r} instead. "
                f"Available: {', '.join(sorted(cls.experiments))}",
                file=sys.stderr,
            )
            return cls.experiments[default]

        return cls.experiments[requested]
class RunnersContainer(containers.DeclarativeContainer):
    """DI container exposing a factory for this invocation's ExperimentRunner."""

    # `ExperimentsContainer.get_experiment()` is called while this class body
    # executes, i.e. at module import time, so the experiment name in
    # sys.argv is read once, on import. What it returns is an experiment
    # *provider* (a providers.Factory); presumably dependency_injector
    # resolves that provider into an experiment instance whenever
    # `instance()` is called — confirm against the library docs.
    instance = providers.Factory(ExperimentRunner, experiment=ExperimentsContainer.get_experiment())
async def main():
    """Create a runner via ``RunnersContainer`` and await its experiment."""
    runner = RunnersContainer.instance()
    await runner.run_async()


if __name__ == "__main__":
    # Only launch an experiment when executed as a script, not on import.
    asyncio.run(main())