diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 00000000..ceadcf22 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,13 @@ +[run] +branch = True + +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Don't complain if non-runnable code isn't run: + if 0: + if __name__ == .__main__.: + + .*# pragma: no cover + .*# pragma: no branch + diff --git a/Makefile b/Makefile index fc2c7dd5..168a72fb 100644 --- a/Makefile +++ b/Makefile @@ -51,4 +51,4 @@ unit-tests: coverage run -m unittest -v check-test-coverage: - coverage report -m --omit="jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py,benchmarks/eval_accuracy_mmlu.py,benchmarks/eval_accuracy_longcontext.py,benchmarks/math_utils.py" --fail-under=96 + coverage report -m --omit="jetstream/tests/*,jetstream/core/proto/*,jetstream/engine/tokenizer_pb2.py,jetstream/external_tokenizers/*,benchmarks/benchmark_serving.py,benchmarks/eval_accuracy.py,benchmarks/eval_accuracy_mmlu.py,benchmarks/eval_accuracy_longcontext.py,benchmarks/math_utils.py,benchmarks/tests/*" --fail-under=90 diff --git a/jetstream/core/server_lib.py b/jetstream/core/server_lib.py index bd68257e..d9ae2cb0 100644 --- a/jetstream/core/server_lib.py +++ b/jetstream/core/server_lib.py @@ -23,6 +23,7 @@ import logging import os import signal +import sys import threading import time import traceback @@ -41,6 +42,21 @@ _HOST = "[::]" +# Create separate logger to log all INFO messages for this module. These show +# stages of server startup and inform user if server is ready to take requests. 
+# The default logger created in orchestrator.py only logs WARNINGs and above +logger = logging.getLogger(__name__) +logger.propagate = False +logger.setLevel(logging.INFO) +formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" +) + +info_handler = logging.StreamHandler(sys.stdout) +info_handler.setLevel(logging.INFO) +info_handler.setFormatter(formatter) +logger.addHandler(info_handler) + class JetStreamServer: """JetStream grpc server.""" @@ -120,7 +136,7 @@ def create_driver( prefill_params = [pe.load_params() for pe in engines.prefill_engines] generate_params = [ge.load_params() for ge in engines.generate_engines] shared_params = [ie.load_params() for ie in engines.interleaved_engines] - logging.info("Loaded all weights.") + logger.info("Loaded all weights.") if metrics_collector: metrics_collector.get_model_load_time_metric().set( time.time() - model_load_start_time @@ -135,13 +151,13 @@ def create_driver( generate_params = generate_params + shared_params if prefill_engines is None: - prefill_engines = [] + prefill_engines = [] # pragma: no branch if generate_engines is None: - generate_engines = [] + generate_engines = [] # pragma: no branch if prefill_params is None: - prefill_params = [] + prefill_params = [] # pragma: no branch if generate_params is None: - generate_params = [] + generate_params = [] # pragma: no branch if enable_model_warmup: prefill_engines = [engine_api.JetStreamEngine(pe) for pe in prefill_engines] @@ -215,11 +231,11 @@ def run( del lora_input_adapters_path server_start_time = time.time() - logging.info("Kicking off gRPC server.") + logger.info("Kicking off gRPC server.") # Setup Prometheus server metrics_collector: JetstreamMetricsCollector = None if metrics_server_config and metrics_server_config.port: - logging.info( + logger.info( "Starting Prometheus server on port %d", metrics_server_config.port ) start_http_server(metrics_server_config.port) @@ -227,7 +243,7 @@ def run( 
model_name=metrics_server_config.model_name ) else: - logging.info( + logger.info( "Not starting Prometheus server: --prometheus_port flag not set" ) @@ -256,7 +272,7 @@ def run( gc.set_threshold(allocs, gen1, gen2) print("GC tweaked (allocs, gen1, gen2): ", allocs, gen1, gen2) - logging.info("Starting server on port %d with %d threads", port, threads) + logger.info("Starting server on port %d with %d threads", port, threads) jetstream_server.start() if metrics_collector: @@ -266,10 +282,10 @@ def run( # Setup Jax Profiler if enable_jax_profiler: - logging.info("Starting JAX profiler server on port %s", jax_profiler_port) + logger.info("Starting JAX profiler server on port %s", jax_profiler_port) jax.profiler.start_server(jax_profiler_port) else: - logging.info("Not starting JAX profiler server: %s", enable_jax_profiler) + logger.info("Not starting JAX profiler server: %s", enable_jax_profiler) # Start profiling server by default for proxy backend. if jax.config.jax_platforms and "proxy" in jax.config.jax_platforms: @@ -279,6 +295,7 @@ def run( target=proxy_util.start_profiling_server, args=(jax_profiler_port,) ) thread.run() + logger.info("Server up and ready to process requests on port %s", port) return jetstream_server @@ -287,5 +304,5 @@ def get_devices() -> Any: """Gets devices.""" # TODO: Add more logs for the devices. devices = jax.devices() - logging.info("Using devices: %d", len(devices)) + logger.info("Using devices: %d", len(devices)) return devices