Skip to content

Commit

Permalink
make more flexible
Browse files Browse the repository at this point in the history
  • Loading branch information
juanmirocks committed Jan 13, 2017
1 parent ae6b23c commit edf49b4
Showing 1 changed file with 15 additions and 7 deletions.
22 changes: 15 additions & 7 deletions nalaf/learning/lib/sklsvm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,8 @@ def __init__(self, model_path=None, classification_threshold=0.0, use_tree_kerne

def train(self, training_corpus, feature_set):
self.global_feature_set = feature_set
self.allowed_features_keys = {fkey for edge in training_corpus.edges() for fkey in edge.features.keys()}
self.final_allowed_key_mapping = {}
num_feat = 0
for allowed_feat_key in self.allowed_features_keys:
self.final_allowed_key_mapping[allowed_feat_key] = num_feat
num_feat += 1
self.allowed_features_keys, self.final_allowed_key_mapping = \
__class__._gen_allowed_and_final_mapping_features_keys(training_corpus)

X, y = __class__._convert_edges_to_SVC_instances(training_corpus, self.final_allowed_key_mapping, self.preprocess)
print_debug("Train SVC with #samples {} - #features {} - params: {}".format(X.shape[0], X.shape[1], str(self.model.get_params())))
Expand All @@ -66,6 +62,17 @@ def annotate(self, corpus):

return corpus.form_predicted_relations()

@staticmethod
def _gen_allowed_and_final_mapping_features_keys(corpus):
allowed_keys = {fkey for edge in corpus.edges() for fkey in edge.features.keys()}
final_mapping_keys = {}
num_feat = 0
for allowed_feat_key in allowed_keys:
final_mapping_keys[allowed_feat_key] = num_feat
num_feat += 1

return (allowed_keys, final_mapping_keys)

@staticmethod
def _convert_edges_to_SVC_instances(corpus, final_allowed_key_mapping, preprocess):
"""
Expand Down Expand Up @@ -101,7 +108,8 @@ def _convert_edges_to_SVC_instances(corpus, final_allowed_key_mapping, preproces
X = __class__._preprocess(X)
print_verbose("SVC, minx & max features after preprocessing:", sklearn.utils.sparsefuncs.min_max_axis(X, axis=0))

# selector = VarianceThreshold()
# # See: http://scikit-learn.org/stable/modules/feature_selection.html#removing-features-with-low-variance
# selector = VarianceThreshold(threshold=(.9 * (1 - .9)))
# X = selector.fit_transform(X)

end = time.time()
Expand Down

0 comments on commit edf49b4

Please sign in to comment.