1.2406.5

@brunneis brunneis released this 11 Jun 17:15
· 2 commits to master since this release
741753b

Centroid Classifier Refactor:

  • Normalization Improvements: Introduced _normalize method for efficient tensor normalization using torch.nn.functional.normalize.
  • Training Enhancements:
    • train method now calculates centroids using mean embeddings for each label.
    • Centroids are stored and normalized upon training.
  • Prediction Optimization:
    • Improved predict and predict_one methods to utilize normalized centroids.
    • Replaced cosine similarity calculations with a plain dot product, which gives the same result on unit-normalized vectors and is cheaper to compute.
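
The refactor described above can be illustrated with a minimal sketch. This is not the library's actual code: the function names `normalize`, `train`, and `predict`, their signatures, and the toy data are assumptions made for illustration. The only call taken from the notes is torch.nn.functional.normalize. The sketch assumes the query embeddings are normalized as well, since the dot product only matches cosine similarity when both sides have unit length.

```python
import torch
import torch.nn.functional as F


def normalize(x: torch.Tensor) -> torch.Tensor:
    # L2-normalize along the last dimension so every row has unit length.
    return F.normalize(x, p=2, dim=-1)


def train(embeddings: torch.Tensor, labels: list):
    # Hypothetical helper: one centroid per label, computed as the mean of
    # that label's embeddings, then normalized once and kept for prediction.
    unique_labels = sorted(set(labels))
    centroids = []
    for label in unique_labels:
        mask = torch.tensor([item == label for item in labels])
        centroids.append(embeddings[mask].mean(dim=0))
    return unique_labels, normalize(torch.stack(centroids))


def predict(embeddings: torch.Tensor, unique_labels: list, centroids: torch.Tensor):
    # Rows on both sides are unit length, so this matrix product yields
    # cosine similarities without computing any norms at prediction time.
    scores = normalize(embeddings) @ centroids.T
    return [unique_labels[i] for i in scores.argmax(dim=1).tolist()]


# Toy 2-D "embeddings" (made-up data, two labels).
train_x = torch.tensor([[1.0, 0.1], [0.9, 0.0], [0.0, 1.0], [0.1, 0.8]])
train_y = ["a", "a", "b", "b"]
unique_labels, centroids = train(train_x, train_y)

queries = torch.tensor([[2.0, 0.2], [0.1, 3.0]])
print(predict(queries, unique_labels, centroids))  # ['a', 'b']
```

Normalizing the centroids once at training time means each prediction only pays for one normalization of the query batch and one matrix multiplication.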

Interface Changes:

  • Updated get_embeddings method to yield torch.Tensor instead of numpy.ndarray.
  • Removed redundant code and streamlined embedding extraction process.

Embedding Model Initialization:

  • Ensured the embedding model is set to evaluation mode immediately after loading (self._model.eval()), so that training-only behavior such as dropout is disabled during inference.
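
As a rough illustration of what evaluation mode changes, the sketch below uses a small stand-in network. The `model` here is hypothetical and is not the library's embedding model. The use of torch.no_grad() is likewise an assumption about typical inference code, not something these notes state.

```python
import torch
from torch import nn

# Hypothetical stand-in for the embedding model; includes a dropout layer
# so the effect of eval() is observable.
model = nn.Sequential(nn.Linear(4, 4), nn.Dropout(p=0.5))

# eval() switches off training-only behavior: dropout becomes a no-op and
# normalization layers use their running statistics.
model.eval()

x = torch.ones(2, 4)
with torch.no_grad():  # separate concern: skips gradient tracking
    first = model(x)
    second = model(x)

# With dropout inactive, repeated forward passes on the same input agree.
print(model.training, torch.equal(first, second))  # False True
```

Note that eval() does not disable gradient tracking by itself; that is what torch.no_grad() (or torch.inference_mode()) is for.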