-
Notifications
You must be signed in to change notification settings - Fork 1
/
BucketFactory.py
35 lines (26 loc) · 1.58 KB
/
BucketFactory.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import EncoderFactory
from bucketers.ClusterBasedBucketer import ClusterBasedBucketer
from bucketers.StateBasedBucketer import StateBasedBucketer
from bucketers.PrefixLengthBucketer import PrefixLengthBucketer
from bucketers.NoBucketer import NoBucketer
#from bucketers.KNNBucketer import KNNBucketer
from sklearn.cluster import KMeans
def get_bucketer(method, encoding_method=None, case_id_col=None, cat_cols=None, num_cols=None, n_clusters=None, random_state=None, n_neighbors=None):
print(method)
if method == "cluster":
bucket_encoder = EncoderFactory.get_encoder(method=encoding_method, case_id_col=case_id_col, dynamic_cat_cols=cat_cols, dynamic_num_cols=num_cols)
clustering = KMeans(n_clusters, random_state=random_state)
return ClusterBasedBucketer(encoder=bucket_encoder, clustering=clustering)
elif method == "state":
bucket_encoder = EncoderFactory.get_encoder(method=encoding_method, case_id_col=case_id_col, dynamic_cat_cols=cat_cols, dynamic_num_cols=num_cols)
return StateBasedBucketer(encoder=bucket_encoder)
elif method == "single":
return NoBucketer(case_id_col=case_id_col)
elif method == "prefix":
return PrefixLengthBucketer(case_id_col=case_id_col)
#elif method == "knn":
# bucket_encoder = EncoderFactory.get_encoder(method=encoding_method, case_id_col=case_id_col, dynamic_cat_cols=cat_cols, dynamic_num_cols=num_cols)
# return KNNBucketer(encoder=bucket_encoder, n_neighbors=n_neighbors)
else:
print("Invalid bucketer type")
return None