Skip to content

Commit

Permalink
Make it possible to pass other distance metrics to the brute force se…
Browse files Browse the repository at this point in the history
…arch
  • Loading branch information
JohnVinyard committed Aug 18, 2018
1 parent 6badaab commit 2e08213
Showing 1 changed file with 4 additions and 19 deletions.
23 changes: 4 additions & 19 deletions zounds/index/brute_force.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,29 +25,14 @@ def random_search(self, n_results=10):
return self.search(query, n_results)


# class BruteForceSearch(object):
# def __init__(self, gen):
# index = []
# self._ids = []
# for _id, example in gen:
# index.append(example)
# crts = ConstantRateTimeSeries(example)
# for ts, _ in crts.iter_slices():
# self._ids.append((_id, ts))
# self.index = np.concatenate(index)
#
# def random_search(self, n_results=10):
# query = choice(self.index)
# distances = cdist(query[None, ...], self.index)
# indices = np.argsort(distances[0])[:n_results]
# return SearchResults(query, (self._ids[i] for i in indices))

class BruteForceSearch(BaseBruteForceSearch):
def __init__(self, gen):
def __init__(self, gen, distance_metric='euclidean'):
super(BruteForceSearch, self).__init__(gen)
self.distance_metric = distance_metric

def search(self, query, n_results=10):
distances = cdist(query[None, ...], self.index)
distances = cdist(
query[None, ...], self.index, metric=self.distance_metric)
indices = np.argsort(distances[0])[:n_results]
return SearchResults(query, (self._ids[i] for i in indices))

Expand Down

0 comments on commit 2e08213

Please sign in to comment.