Skip to content

Commit

Permalink
Merge pull request #17 from Curly-Mo/dev
Browse files Browse the repository at this point in the history
matchers
  • Loading branch information
Curly-Mo committed Oct 8, 2021
2 parents fa64be9 + cd5da7a commit 2e16ee3
Show file tree
Hide file tree
Showing 2 changed files with 90 additions and 64 deletions.
121 changes: 69 additions & 52 deletions sample_id/ann/ann.py
Expand Up @@ -22,47 +22,36 @@
class Matcher(abc.ABC):
"""Nearest neighbor matcher that may use one of various implementations under the hood."""

tempdir = None

def __init__(self, metadata: MatcherMetadata):
self.index = 0
self.meta = metadata
self.model = self.init_model()

@classmethod
def create(
cls, n_features: int, sr: Optional[int] = None, hop_length: Optional[int] = None, metric: Optional[str] = None
) -> Matcher:
meta = MatcherMetadata(n_features=n_features, sr=sr, hop_length=hop_length, metric=metric)
return cls(meta)
@abc.abstractmethod
def init_model(self) -> Any:
"""Initialize the model."""
pass

@classmethod
def from_fingerprint(cls, fp: Fingerprint) -> Matcher:
matcher = cls.create(fp.descriptors.shape[1], sr=fp.sr, hop_length=fp.hop_length)
matcher.add_fingerprint(fp)
return matcher
@abc.abstractmethod
def save_model(self, filepath: str) -> str:
"""Save this matcher's model to disk."""
pass

@abc.abstractmethod
def init_model(self) -> Any:
def load_model(self, filepath: str) -> Any:
"""Load this matcher's model from disk."""
pass

def can_add_fingerprint(self, fingerprint: Fingerprint) -> Boolean:
"""Check if fingerprint can be added to matcher."""
if not self.meta.sr:
self.meta.sr = fingerprint.sr
if not self.meta.hop_length:
self.meta.hop_length = fingerprint.hop_length
if self.meta.sr != fingerprint.sr:
logger.warn(f"Can't add fingerprint with sr={fingerprint.sr}, must equal matcher sr={self.meta.sr}")
if self.meta.hop_length != fingerprint.hop_length:
logger.warn(
f"Can't add fingerprint with hop_length={fingerprint.hop_length}, must equal matcher hop_length={self.meta.hop_length}"
)
return True
@abc.abstractmethod
def nearest_neighbors(self, fp: Fingerprint, k: int) -> str:
"""Save this matcher's model to disk."""
pass

def add_fingerprint(self, fingerprint: Fingerprint) -> Matcher:
def add_fingerprint(self, fingerprint: Fingerprint, dedupe=True) -> Matcher:
"""Add a Fingerprint to the matcher."""
if self.can_add_fingerprint(fingerprint):
if dedupe:
fingerprint.remove_similar_keypoints()
logger.info(f"Adding {fingerprint} to index.")
self.meta.index_to_id = np.hstack([self.meta.index_to_id, fingerprint.keypoint_index_ids()])
self.meta.index_to_ms = np.hstack([self.meta.index_to_ms, fingerprint.keypoint_index_ms()])
Expand All @@ -72,34 +61,62 @@ def add_fingerprint(self, fingerprint: Fingerprint) -> Matcher:
self.index += 1
return self

def add_fingerprints(self, fingerprints: Iterable[Fingerprint]) -> Matcher:
def add_fingerprints(self, fingerprints: Iterable[Fingerprint], **kwargs) -> Matcher:
"""Add Fingerprints to the matcher."""
for fingerprint in fingerprints:
self.add_fingerprint(fingerprint)
self.add_fingerprint(fingerprint, **kwargs)
return self

def can_add_fingerprint(self, fingerprint: Fingerprint) -> Boolean:
"""Check if fingerprint can be added to matcher."""
if not self.meta.sr:
self.meta.sr = fingerprint.sr
if not self.meta.hop_length:
self.meta.hop_length = fingerprint.hop_length
if self.meta.sr != fingerprint.sr:
logger.warn(f"Can't add fingerprint with sr={fingerprint.sr}, must equal matcher sr={self.meta.sr}")
if self.meta.hop_length != fingerprint.hop_length:
logger.warn(
f"Can't add fingerprint with hop_length={fingerprint.hop_length}, must equal matcher hop_length={self.meta.hop_length}"
)
return True

def save(self, filepath: str, compress: bool = True) -> None:
"""Save this matcher to disk."""
with tempfile.TemporaryDirectory() as tmpdir:
tmp_model_path = os.path.join(tmpdir, MATCHER_FILENAME)
tmp_meta_path = os.path.join(tmpdir, META_FILENAME)
logger.info(f"Saving matcher model to {tmp_model_path}.")
tmp_model_path = self.save_model(tmp_model_path)
self.meta.save(tmp_meta_path)
with zipfile.ZipFile(filepath, "w") as zipf:
with zipfile.ZipFile(filepath, "w", compression=zipfile.ZIP_DEFLATED) as zipf:
logger.info(f"Zipping {tmp_model_path} and {tmp_meta_path} into {zipf.filename}")
zipf.write(tmp_model_path, arcname=MATCHER_FILENAME)
zipf.write(tmp_meta_path, arcname=META_FILENAME)

@abc.abstractmethod
def save_model(self, filepath: str) -> str:
"""Save this matcher's model to disk."""
pass
def unload(self) -> None:
"""Unload things from memory and cleanup any temporary files."""
self.model.unload()
if "tempdir" in vars(self):
self.tempdir.cleanup()

@abc.abstractmethod
def load_model(self, filepath: str) -> Any:
"""Load this matcher's model from disk."""
pass
@classmethod
def create(cls, sr: Optional[int] = None, hop_length: Optional[int] = None, **kwargs) -> Matcher:
"""Create an instance, pass any kwargs needed by the subclass."""
meta = MatcherMetadata(sr=sr, hop_length=hop_length, **kwargs)
return cls(meta)

@classmethod
def from_fingerprint(cls, fp: Fingerprint, **kwargs) -> Matcher:
"""Useful for determining metadata for the Matcher based on the data being added."""
matcher = cls.create(fp.descriptors.shape[1], sr=fp.sr, hop_length=fp.hop_length, **kwargs)
return matcher.add_fingerprint(fp, **kwargs)

@classmethod
def from_fingerprints(cls, fingerprints: Iterable[Fingerprint], **kwargs) -> Matcher:
"""My data is small, just create and train the entire matcher."""
fp = fingerprints[0]
matcher = cls.create(fp.descriptors.shape[1], sr=fp.sr, hop_length=fp.hop_length, **kwargs)
return matcher.add_fingerprints(fingerprints, **kwargs)

@classmethod
def load(cls, filepath: str) -> Matcher:
Expand All @@ -118,29 +135,21 @@ def load(cls, filepath: str) -> Matcher:
matcher.load_model(tmp_model_path)
return matcher

def unload(self) -> None:
self.model.unload()
if self.tempdir:
self.tempdir.cleanup()


class MatcherMetadata:
"""Metadata for a Matcher object."""

def __init__(
self,
n_features: Optional[int] = None,
metric: Optional[str] = None,
sr: Optional[int] = None,
hop_length: Optional[int] = None,
index_to_id=None,
index_to_ms=None,
index_to_kp=None,
**kwargs,
):
self.sr = sr
self.hop_length = hop_length
self.n_features = n_features
self.metric = metric
self.index_to_id = index_to_id
self.index_to_ms = index_to_ms
self.index_to_kp = index_to_kp
Expand All @@ -150,11 +159,13 @@ def __init__(
self.index_to_ms = np.array([], np.uint32)
if index_to_kp is None:
self.index_to_kp = np.empty(shape=(0, 4), dtype=np.float32)
for key, value in kwargs.items():
setattr(self, key, value)

def save(self, filepath: str, compress: bool = True) -> None:
"""Save this matcher's metadata to disk."""
save_fn = np.savez_compressed if compress else np.savez
logger.info(f"Saving matcher metadata to {filepath}.")
logger.info(f"Saving matcher metadata to {filepath}...")
save_fn(
filepath,
n_features=self.n_features,
Expand All @@ -169,9 +180,9 @@ def save(self, filepath: str, compress: bool = True) -> None:
@classmethod
def load(cls, filepath: str) -> MatcherMetadata:
"""Load this matcher's metadata from disk."""
logger.info(f"Loading matcher metadata from {filepath}.")
logger.info(f"Loading matcher metadata from {filepath}...")
with np.load(filepath) as data:
return cls(
meta = cls(
n_features=data["n_features"].item(),
metric=data["metric"].item(),
sr=data["sr"].item(),
Expand All @@ -180,3 +191,9 @@ def load(cls, filepath: str) -> MatcherMetadata:
index_to_ms=data["index_to_ms"],
index_to_kp=data["index_to_kp"],
)
logger.info(f"Loaded metadata: {meta}")
return meta

def __repr__(self):
attrs = ",".join(f"{k}={v}" for k, v in vars(self).items() if type(v) in (int, float, bool, str))
return f"MatcherMeta({attrs})"
33 changes: 21 additions & 12 deletions sample_id/ann/annoy.py
Expand Up @@ -12,31 +12,40 @@ class AnnoyMatcher(Matcher):
"""Nearest neighbor matcher using annoy."""

def __init__(self, metadata: MatcherMetadata):
metadata.metric = "euclidean"
metadata.metric = vars(metadata).get("metric", "euclidean")
metadata.n_features = vars(metadata).get("n_features", 128)
metadata.n_trees = vars(metadata).get("n_trees", -1)
metadata.n_jobs = vars(metadata).get("n_jobs", -1)
super().__init__(metadata)
self.on_disk = None
self.n_trees = -1
self.n_jobs = -1
self.built = False

def init_model(self) -> Any:
logger.info(f"Initializing Annoy Index with {self.meta}...")
return annoy.AnnoyIndex(self.meta.n_features, metric=self.meta.metric)

def build(self, n_trees: int = -1, n_jobs: int = -1) -> None:
logger.info("Building Annoy Index...")
self.model.build(n_trees, n_jobs)

def on_disk_build(self, filename: str) -> None:
self.model.on_disk_build(filename)
self.on_disk = filename

def save_model(self, filepath: str) -> str:
self.build(self.n_trees, self.n_jobs)
if not self.built:
self.build()
if self.on_disk:
logger.info(f"Annoy index already built on disk at {self.on_disk}.")
return self.on_disk
logger.info(f"Saving matcher model to {filepath}...")
self.model.save(filepath)
return filepath

def load_model(self, filepath: str) -> None:
logger.info(f"Loading Annoy Index from {filepath}...")
self.model.load(filepath)
self.built = True
return self.model

def build(self) -> None:
logger.info(f"Building Annoy Index with {self.meta}...")
self.model.build(self.meta.n_trees, self.meta.n_jobs)
self.built = True

def on_disk_build(self, filename: str) -> None:
logger.info(f"Building Annoy Index straight to disk: {filename}...")
self.model.on_disk_build(filename)
self.on_disk = filename

0 comments on commit 2e16ee3

Please sign in to comment.