diff --git a/llama_cpp/server/__main__.py b/llama_cpp/server/__main__.py index 03691cfb6..2942e5146 100644 --- a/llama_cpp/server/__main__.py +++ b/llama_cpp/server/__main__.py @@ -51,6 +51,21 @@ def contains_list_type(annotation) -> bool: else: return False +def parse_bool_arg(arg): + if isinstance(arg, bytes): + arg = arg.decode('utf-8') + + true_values = {'1', 'on', 't', 'true', 'y', 'yes'} + false_values = {'0', 'off', 'f', 'false', 'n', 'no'} + + arg_str = str(arg).lower().strip() + + if arg_str in true_values: + return True + elif arg_str in false_values: + return False + else: + raise ValueError(f'Invalid boolean argument: {arg}') if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -72,16 +87,8 @@ def contains_list_type(annotation) -> bool: parser.add_argument( f"--{name}", dest=name, - action="store_true", - help=f"Disable {description}", - default=field.default, - ) - parser.add_argument( - f"--no-{name}", - dest=name, - action="store_false", + type=parse_bool_arg, help=f"Disable {description}", - default=field.default, ) args = parser.parse_args()