From dd24993c532ccda2f7f4b934e4e432805de0dcf0 Mon Sep 17 00:00:00 2001 From: juhoinkinen <34240031+juhoinkinen@users.noreply.github.com> Date: Fri, 5 Jul 2024 14:08:06 +0300 Subject: [PATCH] Use fasttext via floret --- annif/backend/fasttext.py | 13 +++---------- pyproject.toml | 4 ++-- 2 files changed, 5 insertions(+), 12 deletions(-) diff --git a/annif/backend/fasttext.py b/annif/backend/fasttext.py index e102b02ba..0c37c07e3 100644 --- a/annif/backend/fasttext.py +++ b/annif/backend/fasttext.py @@ -6,7 +6,7 @@ import os.path from typing import TYPE_CHECKING, Any -import fasttext +import floret import annif.util from annif.exception import NotInitializedException, NotSupportedException @@ -65,14 +65,7 @@ def default_params(self) -> dict[str, Any]: @staticmethod def _load_model(path: str) -> _FastText: - # monkey patch fasttext.FastText.eprint to avoid spurious warning - # see https://github.com/facebookresearch/fastText/issues/1067 - orig_eprint = fasttext.FastText.eprint - fasttext.FastText.eprint = lambda x: None - model = fasttext.load_model(path) - # restore the original eprint - fasttext.FastText.eprint = orig_eprint - return model + return floret.load_model(path) def initialize(self, parallel: bool = False) -> None: if self._model is None: @@ -132,7 +125,7 @@ def _create_model(self, params: dict[str, Any], jobs: int) -> None: if jobs != 0: # jobs set by user to non-default value params["thread"] = jobs self.debug("Model parameters: {}".format(params)) - self._model = fasttext.train_supervised(trainpath, **params) + self._model = floret.train_supervised(trainpath, **params) self._model.save_model(modelpath) def _train( diff --git a/pyproject.toml b/pyproject.toml index 970fd2503..40a4a0955 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -49,7 +49,7 @@ simplemma = "0.9.*" jsonschema = "4.21.*" huggingface-hub = "0.22.*" -fasttext-wheel = { version = "0.9.2", optional = true } +floret = { version = "~0.10.5", optional = true } voikko = { version = "0.5.*", optional = true } tensorflow-cpu = { version = "2.15.*", optional = true, python = "<3.12" } lmdb = { version = "1.4.1", optional = true } @@ -71,7 +71,7 @@ isort = "*" schemathesis = "3.*.*" [tool.poetry.extras] -fasttext = ["fasttext-wheel"] +fasttext = ["floret"] voikko = ["voikko"] nn = ["tensorflow-cpu", "lmdb"] omikuji = ["omikuji"]