Commit

➕ Add tomli

Freed-Wu committed Jul 20, 2023
1 parent 42251be commit e2616de
Showing 10 changed files with 99 additions and 51 deletions.
5 changes: 4 additions & 1 deletion pyproject.toml
@@ -28,7 +28,7 @@ classifiers = [
"Programming Language :: Python :: Implementation :: CPython",
"Programming Language :: Python :: Implementation :: PyPy",
]
dynamic = ["version", "optional-dependencies"]
dynamic = ["version", "dependencies", "optional-dependencies"]

[[project.authors]]
name = "Wu Zhenyu"
@@ -65,6 +65,9 @@ write-to = "src/translate_shell/_metainfo.py"
[tool.setuptools-generate.metainfo-template]
file = "templates/metainfo.py.j2"

[tool.setuptools.dynamic.dependencies]
file = "requirements.txt"

# begin: scripts/update-pyproject.toml.pl
[tool.setuptools.dynamic.optional-dependencies.color]
file = "requirements/color.txt"
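For context, the new [tool.setuptools.dynamic.dependencies] table makes setuptools read the runtime dependencies from requirements.txt at build time. A rough sketch of the effect (illustrative parsing, not setuptools' actual implementation):

# illustration: roughly how a requirements file becomes the dynamic
# `dependencies` metadata (comments and blank lines are skipped)
from pathlib import Path

def read_requirements(path: str = "requirements.txt") -> list[str]:
    return [
        line.strip()
        for line in Path(path).read_text().splitlines()
        if line.strip() and not line.strip().startswith("#")
    ]

# with the requirements.txt added below, this yields:
# ['tomli; python_version < "3.11"']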
3 changes: 3 additions & 0 deletions requirements.txt
@@ -0,0 +1,3 @@
#!/usr/bin/env -S pip install -r

tomli; python_version < "3.11"
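The environment marker restricts the dependency to interpreters that predate the stdlib tomllib, which was added in Python 3.11. A small sketch of evaluating the same marker at runtime, assuming the third-party packaging library is available:

# illustration: the marker from requirements.txt, checked at runtime
from packaging.markers import Marker

marker = Marker('python_version < "3.11"')
print(marker.evaluate())  # True on 3.10 and older, False on 3.11+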
7 changes: 7 additions & 0 deletions src/translate_shell/__main__.py
@@ -149,6 +149,13 @@ def get_parser() -> ArgumentParser:
default=config.source_lang,
help="source languages. default: %(default)s",
).complete = LANG_COMPLETE # type: ignore
parser.add_argument(
"--options",
help="advanced usage, see "
"https://translate-shell.readthedocs.io/en/latest/resources/config.html "
". default: %(default)s",
action="append",
).complete = LANG_COMPLETE # type: ignore
parser.add_argument(
"text",
nargs="*",
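Because the flag is declared with action="append", each --options occurrence contributes one TOML line; they are joined and parsed downstream (see ui/__init__.py below). A hypothetical invocation, with illustrative option names:

# illustration: what `trans --options 'n_ctx = 512' --options 'verbose = true'`
# turns into after the join-and-parse step
import tomllib  # `tomli` on Python < 3.11

options = ["n_ctx = 512", "verbose = true"]  # collected by action="append"
print(tomllib.loads("\n".join(options)))
# {'n_ctx': 512, 'verbose': True}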
10 changes: 6 additions & 4 deletions src/translate_shell/tools/po/main.py
@@ -10,6 +10,11 @@

from translate_shell.translate import translate

try:
import tomllib as tomli
except ImportError:
import tomli

logger = logging.getLogger(__name__)


@@ -24,10 +29,7 @@ def run(args: Namespace) -> None:
default_target_lang = args.target_lang
source_lang = args.source_lang
translator = args.translator
option = {
option.partition("=")[0]: option.partition("=")[2]
for option in args.option
}
option = tomli.loads("\n".join(args.option))
wrapwidth = int(args.wrapwidth)
progress = args.progress.lower() == "true"
verbose = args.verbose.lower() == "true"
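The switch from splitting on "=" to TOML parsing also changes value types: the old comprehension produced strings only, while TOML preserves integers and booleans. A small sketch of the difference (option names are illustrative):

import tomllib  # `tomli` on Python < 3.11

raw = ["n_ctx=512", "use_mmap=true"]
old = {o.partition("=")[0]: o.partition("=")[2] for o in raw}
print(old)  # {'n_ctx': '512', 'use_mmap': 'true'} -- strings only

new = tomllib.loads("\n".join(["n_ctx = 512", "use_mmap = true"]))
print(new)  # {'n_ctx': 512, 'use_mmap': True} -- typed values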
1 change: 1 addition & 0 deletions src/translate_shell/translate.py
@@ -96,6 +96,7 @@ def translate(
if isinstance(translator, Callable):
translator = translator()
true_translators += [translator]

if len(translators) == 1:
translator = true_translators[0]
translate_once(
110 changes: 64 additions & 46 deletions src/translate_shell/translators/llm/_llama_cpp.py
@@ -9,7 +9,11 @@
from ...external.platformdirs import AppDirs
from . import LLMTranslator

MODEL = str(AppDirs("translate-shell").user_data_path / "model.bin")
MODEL_PATH = str(AppDirs("translate-shell").user_data_path / "model.bin")
# cache init
old_model_path = MODEL_PATH
old_kwargs = {}
old_model = None


class LlamaTranslator(LLMTranslator):
@@ -30,49 +34,63 @@ def init_model(option: dict) -> Any:
:type option: dict
:rtype: Any
"""
model = option.get("model", MODEL)
if isinstance(model, str):
model = os.path.expanduser(model)
kwargs = {}
if n_ctx := option.get("n_ctx"):
kwargs["n_ctx"] = int(n_ctx)
if n_parts := option.get("n_parts"):
kwargs["n_parts"] = int(n_parts)
if n_gpu_layers := option.get("n_gpu_layers"):
kwargs["n_gpu_layers"] = int(n_gpu_layers)
if seed := option.get("seed"):
kwargs["seed"] = int(seed)
if f16_kv := option.get("f16_kv"):
kwargs["f16_kv"] = bool(f16_kv)
if logits_all := option.get("logits_all"):
kwargs["logits_all"] = bool(logits_all)
if vocab_only := option.get("vocab_only"):
kwargs["vocab_only"] = bool(vocab_only)
if use_mmap := option.get("use_mmap"):
kwargs["use_mmap"] = bool(use_mmap)
if use_mlock := option.get("use_mlock"):
kwargs["use_mlock"] = bool(use_mlock)
if embedding := option.get("embedding"):
kwargs["embedding"] = bool(embedding)
if n_threads := option.get("n_threads"):
kwargs["n_threads"] = n_threads
if n_batch := option.get("n_batch"):
kwargs["n_batch"] = int(n_batch)
if last_n_tokens_size := option.get("last_n_tokens_size"):
kwargs["last_n_tokens_size"] = int(last_n_tokens_size)
if lora_base := option.get("lora_base"):
kwargs["lora_base"] = lora_base
if lora_path := option.get("lora_path"):
kwargs["lora_path"] = lora_path
if low_vram := option.get("low_vram"):
kwargs["low_vram"] = bool(low_vram)
if tensor_split := option.get("tensor_split"):
kwargs["tensor_split"] = tensor_split
if rope_freq_base := option.get("rope_freq_base"):
kwargs["rope_freq_base"] = float(rope_freq_base)
if rope_freq_scale := option.get("rope_freq_scale"):
kwargs["rope_freq_scale"] = float(rope_freq_scale)
if verbose := option.get("verbose"):
kwargs["verbose"] = bool(verbose)
model = Llama(model, **kwargs)
global old_model_path, old_kwargs, old_model
model_path = option.get("model", MODEL_PATH)
if isinstance(model_path, Llama):
# cache clear
old_model_path = ""
old_kwargs = {}
old_model = model_path
return model_path
model_path = os.path.expanduser(model_path)
kwargs = {}
if n_ctx := option.get("n_ctx"):
kwargs["n_ctx"] = int(n_ctx)
if n_parts := option.get("n_parts"):
kwargs["n_parts"] = int(n_parts)
if n_gpu_layers := option.get("n_gpu_layers"):
kwargs["n_gpu_layers"] = int(n_gpu_layers)
if seed := option.get("seed"):
kwargs["seed"] = int(seed)
if f16_kv := option.get("f16_kv"):
kwargs["f16_kv"] = bool(f16_kv)
if logits_all := option.get("logits_all"):
kwargs["logits_all"] = bool(logits_all)
if vocab_only := option.get("vocab_only"):
kwargs["vocab_only"] = bool(vocab_only)
if use_mmap := option.get("use_mmap"):
kwargs["use_mmap"] = bool(use_mmap)
if use_mlock := option.get("use_mlock"):
kwargs["use_mlock"] = bool(use_mlock)
if embedding := option.get("embedding"):
kwargs["embedding"] = bool(embedding)
if n_threads := option.get("n_threads"):
kwargs["n_threads"] = n_threads
if n_batch := option.get("n_batch"):
kwargs["n_batch"] = int(n_batch)
if last_n_tokens_size := option.get("last_n_tokens_size"):
kwargs["last_n_tokens_size"] = int(last_n_tokens_size)
if lora_base := option.get("lora_base"):
kwargs["lora_base"] = lora_base
if lora_path := option.get("lora_path"):
kwargs["lora_path"] = lora_path
if low_vram := option.get("low_vram"):
kwargs["low_vram"] = bool(low_vram)
if tensor_split := option.get("tensor_split"):
kwargs["tensor_split"] = tensor_split
if rope_freq_base := option.get("rope_freq_base"):
kwargs["rope_freq_base"] = float(rope_freq_base)
if rope_freq_scale := option.get("rope_freq_scale"):
kwargs["rope_freq_scale"] = float(rope_freq_scale)
if verbose := option.get("verbose"):
kwargs["verbose"] = bool(verbose)
# cache hit
if kwargs == old_kwargs and model_path == old_model_path and old_model:
model = old_model
else:
model = Llama(model_path, **kwargs)
# cache reinit
old_model_path = model_path
old_kwargs = kwargs
old_model = model
return model
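The module-level triple (old_model_path, old_kwargs, old_model) is a single-slot cache: repeated init_model calls with the same path and keyword arguments reuse the already-loaded weights instead of reloading them from disk. The pattern in isolation (load_weights is a hypothetical stand-in for the Llama constructor):

# illustration of the single-slot memoization used above
_last_key = None
_last_model = None

def get_model(path, **kwargs):
    global _last_key, _last_model
    key = (path, sorted(kwargs.items()))
    if key == _last_key and _last_model is not None:
        return _last_model  # cache hit: skip reloading the weights
    _last_model = load_weights(path, **kwargs)  # hypothetical loader
    _last_key = key
    return _last_model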
3 changes: 3 additions & 0 deletions src/translate_shell/translators/stardict/__init__.py
@@ -99,6 +99,9 @@ def get_tokens(self, text: str, tl: str, sl: str) -> tuple[list[str], str]:
sl = detect(text)
except LangDetectException:
sl = "en"
# convert zh-cn to zh_CN
lang, _, country = sl.partition("-")
sl = lang + "_" + country.upper() if country else lang

dictionaries = self.stardict.get(sl, self.stardict["en"]).get(tl, [])
for directory in STARDICT_DIRS:
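A quick check of the tag normalisation (the else branch keeps a bare language code unchanged, so the detected language still selects the right dictionary):

# illustration: the language-tag conversion used above
def normalize(tag: str) -> str:
    lang, _, country = tag.partition("-")
    return lang + "_" + country.upper() if country else lang

print(normalize("zh-cn"))  # zh_CN
print(normalize("en"))     # en (no country part)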
7 changes: 7 additions & 0 deletions src/translate_shell/ui/__init__.py
@@ -18,6 +18,11 @@
from ..translate import translate
from ..translators import TRANSLATORS, get_dummy

try:
import tomllib as tomli
except ImportError:
import tomli

with suppress(ImportError):
from rich import traceback
from rich.logging import RichHandler
@@ -112,6 +117,7 @@ def init(args: Namespace) -> None:
if value is not None:
setattr(args, attr, value)
args.text = " ".join(args.text)
args.options = tomli.loads("\n".join(args.options))
if not args.lsp:
_readline = init_readline()
_readline.set_completer(args.complete)
@@ -172,6 +178,7 @@ def get_processed_result_text(
target_lang,
source_lang,
translators,
args.options,
)
if args.format == "json":
rst = json.dumps(vars(translation))
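Taken together, the commit wires --options end to end; a condensed view of the flow, reconstructed from the hunks above rather than the full source:

# 1. argparse collects repeated flags:  args.options == ['n_ctx = 512']
# 2. init() parses them once:           args.options = tomli.loads("\n".join(args.options))
# 3. each translation forwards them:    translate(..., translators, args.options)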
3 changes: 3 additions & 0 deletions tests/txt/options.txt
@@ -27,3 +27,6 @@ options:
target languages. default: auto
--source-lang SOURCE_LANG
source languages. default: auto
--options OPTIONS advanced usage, see https://translate-
shell.readthedocs.io/en/latest/resources/config.html .
default: None
1 change: 1 addition & 0 deletions tests/txt/usage.txt
@@ -5,4 +5,5 @@ usage: trans [-h] [-V] [--print-completion {bash,zsh,tcsh}]
[--sleep-seconds SLEEP_SECONDS] [--config CONFIG]
[--format {json,yaml,text}] [--translators TRANSLATORS]
[--target-lang TARGET_LANG] [--source-lang SOURCE_LANG]
[--options OPTIONS]
[text ...]
